Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-04-08 03:00:28 -04:00
Compare commits
403 Commits
fix/copilo...combined-p
SHA1:
64d9d6d880, 9fc324e28a, adf66bdd24, dc10ad715a, 493c91e0dd, b278e66f4d, 3e183ed2a3, 82887a2d92,
993c43b623, 13fcc62a31, 8fefa23468, 749a56ca20, a8a62eeefc, 173614bcc5, 3396cb3f4c, 0c5d628b74,
ed40549499, fbe634fb19, a338c72c42, a9d13f0cbf, e83e50a8f1, 7f4398efa3, c2a054c511, b256560619,
c63d5f538b, eeba884671, 90822e3f37, a8bb6b5544, 83b00f4789, 4cd53bb7f6, 96d83e9bbd, e99f4ac767,
67c2540177, 0da949ba42, 95524e94b3, eda02f9ce6, 9ab6082a23, 2c517ff9a1, 7020ae2189, a49ac5ba13,
2a969e5018, 79005b1be5, 4f8cdbee47, 3ed444dd60, 83e747ebcd, 827f2b0f87, b0d5d3b95e, eb9244be1a,
dd17e83299, 74009bedac, 72d0c8dad8, e860f164e4, b9336984be, 9924dedddc, c054799b4f, 004d3957b3,
f3b5d584a3, 476d9dcf80, 072b623f8b, a68f48e6b7, 60e2474640, a892bbd4dd, 538e8619da, 4edb1f6e4a,
480d58607d, 8561eb35f2, 0b4acd73f4, e9fe2991d6, 26b0c95936, 735965bbe5, a8f9ed0f60, 308357de84,
1a6c50c6cc, 9391dfa4b2, 6b031085bd, 6a69d7c68d, ad77e881c9, f1aedfeedd, 49c7ab4011, 2d04584c84,
2578f61abb, 927c6e7db0, f753e6162f, b996bc556b, e4f79261c1, 09bc939498, 79c5a10f75, 2bf5a37646,
d5d24e6e66, c9cbd7531e, 289a19d402, 7800af1835, 114f91ff53, 63a0153e4f, 1364616ff1, 4e9169c1a2,
705e97ec46, 8cea0bede0, 5e52050788, 0b77af29aa, bd4cc21fc6, 19ea753639, 07fd734fa1, 8a4a16ec5c,
5551da674e, e57e48272a, 6f03ceeb88, 554ff0b20b, c2f421cb42, dd228de17d, c26ff22f9c, 760360fbe9,
e3d589b180, 913d93f47c, 03e5d37dc4, 6e2dab413e, b10dc7c2d5, 8de935c84b, dd34b0dc48, 015e0d591e,
2cb65f5c34, 3a49086c3d, 0e567df1da, b5b754d5eb, 456bb1c4d0, 263cd0ecac, 66afca6e0c, 11b846dd49,
a71396ee48, beb43bb847, a55653f8c1, f3dd708cf6, c4ff31c79c, 9f2257daaa, 925e9a047c, 3e6faf2de7,
40a1f504c0, 22e8c5c353, 1de2a7fb09, b3d9e9e856, 48b166a82c, 697b15ce81, 5beabf936c, b9e29c96bd,
32bfe1b209, 62302db470, 89c7f34d26, 543fc2da70, 7f986bc565, f4571cb9e1, 5f41afe748, d046c01a65,
b220fe4347, 7af138adba, 5c406a20ba, 61513b9dad, 6f679a0e32, b8065212b1, d5281a9a13, 05495d8478,
bae409d04e, e11eb2caaa, 2c04768711, c9bf3aa339, 4ac0ba570a, d61a2c6cd0, 1c301b4b61, e753aee7a0,
f76566c834, a58b997141, 3f24a003ad, 1a645e1e37, bee76962b0, 864e68bed1, 7c6201110c, bded680b77,
1e008dc172, 9966e122ab, 65108c31dc, 7767c97f50, 69ab21ebe7, 6fe4e1b774, c778cc9849, 50b635da6d,
08e254143b, 89fcfc4e0a, e7ca07f4bf, c564ac7277, ac3a826ad0, 6f32184019, 6d0eedae83, fb328f9d74,
a369fbe169, 2a0b74cae4, b08f9fc02a, 857acb2bbc, 0cb230c4f0, 2cd5c0eab8, 7bf8e460ea, 84d328517a,
842ff6c600, b510fbee2a, bb7f0ad1f2, 3f8af89b63, 375e5e1f10, fd1d706315, faf2f43f6a, eea230d37f,
76965429f1, eefa60368f, 88fe1e9b5e, 93264b1177, 3269d17880, 1e5788f2cf, ca8214d95f, f58ce5cc70,
bf29801b07, dcc2bdd8ab, e74a918c4a, ff05b5b8d5, 56090f870c, a3e3d3ff6b, cfaa1ff0d4, 6ab9a3285f,
390324d5a1, 9f51796dbe, f42f0013df, 3154e5b87a, 78cd14d501, 137edb3e6e, 449e9b17f1, 5b3f87d7c7,
ee7209a575, 7ea89b07ce, 5324e0cc2f, c7cbb8b02e, d66ffb1ee4, 5d489c72b5, 646ffe1693, 59b1811e8b,
6404e58fb1, e0bfa1524e, ac947a0c11, c9f45f056a, 89264091ad, e3183f1955, 3ea243c760, 991969612c,
8de9880f43, 86d8efe697, 10ec6c7215, 51e5371362, cdd14726ce, 1ebd5635f6, 349b6c63de, 2f7cfa6f1b,
049aa1ad7d, a16be2675b, ac416a561e, c47fcc1925, 77fd8648a7, 4842599bec, 339e155823, 9344e62d66,
ee6cc20cbc, eb96b019c5, 9cf6ac9ad9, d3173605eb, 98c27653f2, dced534df3, 4ebe294707, 2e8e115cd1,
5ca49a8ec9, a9db5af0fa, dcbfcfb158, 723b852ba4, c7e0f8169a, ce1555c07a, 403a36a3fc, 490643d65a,
2b14ecf5ee, 14d6d66bdc, 28443e2e33, 611a20d7df, ce201cd19c, 0c76852768, 414b8bbaac, 4c85f2399a,
db0e5a1b0b, 22a5e76af9, 7919da16b4, 052f953afb, abd9fbe08a, 81308af770, a726c1d1d5, a015bf9e1c,
d99278a40d, bd7d9a5697, 9cfa53a2ff, e6cf899a6d, b655b30aeb, 5b8daf5d4c, 9b74b7bb41, a1578984cc,
c0869e9168, 0db5a6ff9a, 3664624445, f1e2ce0703, c226cf0925, dade634b4a, 34101c4389, 2218254c8a,
4d63cffa7a, ebf3b920d8, 9bd579b041, 41601cbb5c, c636b6f310, 292be77b86, dd3349e6bc, bfdf4b99db,
aba78b0fdd, 12934dfd72, c5507415fd, 7ff096afd9, 38fb504063, b4388a9c93, a7a68e585a, 14ad37b0c7,
389cd28879, 656858eba1, f0a3afda7d, a9cbb3ee2f, 1810452920, 4f6f3ca240, 9ffecbac02, eb22cf4483,
16636b64c6, c2709fbc28, 3adbaacc0e, 4da3535a9c, 56e0b568a4, 4acac9ff5b, 0b0777ac87, 698b1599cb,
a2f94f08d9, 0c6f20f728, d100b2515b, 14113f96a9, ee40a4b9a8, 0008cafc3b, f55bc84fe7, 3cfee4c4b5,
c48b5239b9, e44615f8b8, 22f0da0a03, 9264b42050, 3a40188024, 8d6433c1a5, c7430eaffb, dc272559c6,
a98b0aee95, 264869cab9, a85ba9e36d, 18c5f67107, 0348e7b228, e35376d3ec, 687af1bdc3, 694032e45f,
231a4b6f51, da6f77da47, 1747f4e6f3, 0d6d8e820c, 24c286fbed, c75f1ff749, cfc6d3538c, e9540041d6,
8ac86a03b5, 2aac78eae4, dbfc791357, 880c957c86, 857a8ef0aa, 1008f9fcd4, c26791e6ae, cf66c08125,
b4362785e4, f38fa96df4, 98c8f94ef2, 7b0111d9b5, 85e9e4c5b7, e900ee615a, e1d5113051, 4963d227ea,
19dea0e4ca, 87d5a39267, 87ac8148e3, 491132f62f, 55815a3207, 5c3aa11600, b5cbf8505b, f49f63de76,
8f76384942, ffb8d366d6, 432ef5ab5e
```diff
@@ -0,0 +1,85 @@
+import logging
+import typing
+from datetime import datetime
+
+from autogpt_libs.auth import get_user_id, requires_admin_user
+from fastapi import APIRouter, Query, Security
+from pydantic import BaseModel
+
+from backend.data.platform_cost import (
+    CostLogRow,
+    PlatformCostDashboard,
+    get_platform_cost_dashboard,
+    get_platform_cost_logs,
+)
+from backend.util.models import Pagination
+
+logger = logging.getLogger(__name__)
+
+
+router = APIRouter(
+    prefix="/admin",
+    tags=["platform-cost", "admin"],
+    dependencies=[Security(requires_admin_user)],
+)
+
+
+class PlatformCostLogsResponse(BaseModel):
+    logs: list[CostLogRow]
+    pagination: Pagination
+
+
+@router.get(
+    "/platform_costs/dashboard",
+    response_model=PlatformCostDashboard,
+    summary="Get Platform Cost Dashboard",
+)
+async def get_cost_dashboard(
+    admin_user_id: str = Security(get_user_id),
+    start: typing.Optional[datetime] = Query(None),
+    end: typing.Optional[datetime] = Query(None),
+    provider: typing.Optional[str] = Query(None),
+    user_id: typing.Optional[str] = Query(None),
+):
+    logger.info(f"Admin {admin_user_id} fetching platform cost dashboard")
+    return await get_platform_cost_dashboard(
+        start=start,
+        end=end,
+        provider=provider,
+        user_id=user_id,
+    )
+
+
+@router.get(
+    "/platform_costs/logs",
+    response_model=PlatformCostLogsResponse,
+    summary="Get Platform Cost Logs",
+)
+async def get_cost_logs(
+    admin_user_id: str = Security(get_user_id),
+    start: typing.Optional[datetime] = Query(None),
+    end: typing.Optional[datetime] = Query(None),
+    provider: typing.Optional[str] = Query(None),
+    user_id: typing.Optional[str] = Query(None),
+    page: int = Query(1, ge=1),
+    page_size: int = Query(50, ge=1, le=200),
+):
+    logger.info(f"Admin {admin_user_id} fetching platform cost logs")
+    logs, total = await get_platform_cost_logs(
+        start=start,
+        end=end,
+        provider=provider,
+        user_id=user_id,
+        page=page,
+        page_size=page_size,
+    )
+    total_pages = (total + page_size - 1) // page_size
+    return PlatformCostLogsResponse(
+        logs=logs,
+        pagination=Pagination(
+            total_items=total,
+            total_pages=total_pages,
+            current_page=page,
+            page_size=page_size,
+        ),
+    )
```
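A minimal client-side sketch of paging through the new logs endpoint. The base URL, port, and token are placeholders (only the paths, query parameters, and pagination fields come from the route above); the server computes `total_pages` as ceil division, `(total + page_size - 1) // page_size`.

```python
# Hypothetical client for the admin cost-log endpoint; requires httpx.
import httpx

BASE = "http://localhost:8000"  # placeholder host, not from the diff
headers = {"Authorization": "Bearer ADMIN_JWT"}  # placeholder admin token

page, page_size = 1, 50
with httpx.Client(base_url=BASE, headers=headers) as client:
    while True:
        resp = client.get(
            "/admin/platform_costs/logs",
            params={"page": page, "page_size": page_size},
        )
        resp.raise_for_status()
        body = resp.json()
        for row in body["logs"]:
            print(row)
        # pagination.total_pages == ceil(total_items / page_size)
        if page >= body["pagination"]["total_pages"]:
            break
        page += 1
```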
```diff
@@ -0,0 +1,135 @@
+from unittest.mock import AsyncMock
+
+import fastapi
+import fastapi.testclient
+import pytest
+import pytest_mock
+from autogpt_libs.auth.jwt_utils import get_jwt_payload
+
+from .platform_cost_routes import router as platform_cost_router
+
+app = fastapi.FastAPI()
+app.include_router(platform_cost_router)
+
+client = fastapi.testclient.TestClient(app)
+
+
+@pytest.fixture(autouse=True)
+def setup_app_admin_auth(mock_jwt_admin):
+    """Setup admin auth overrides for all tests in this module"""
+    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
+    yield
+    app.dependency_overrides.clear()
+
+
+def test_get_dashboard_success(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    mock_dashboard = AsyncMock(
+        return_value=AsyncMock(
+            by_provider=[],
+            by_user=[],
+            total_cost_microdollars=0,
+            total_requests=0,
+            total_users=0,
+            model_dump=lambda **_: {
+                "by_provider": [],
+                "by_user": [],
+                "total_cost_microdollars": 0,
+                "total_requests": 0,
+                "total_users": 0,
+            },
+        )
+    )
+    mocker.patch(
+        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
+        mock_dashboard,
+    )
+
+    response = client.get("/admin/platform_costs/dashboard")
+    assert response.status_code == 200
+    data = response.json()
+    assert "by_provider" in data
+    assert "by_user" in data
+    assert data["total_cost_microdollars"] == 0
+
+
+def test_get_logs_success(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    mocker.patch(
+        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
+        AsyncMock(return_value=([], 0)),
+    )
+
+    response = client.get("/admin/platform_costs/logs")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["logs"] == []
+    assert data["pagination"]["total_items"] == 0
+
+
+def test_get_dashboard_with_filters(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    mock_dashboard = AsyncMock(
+        return_value=AsyncMock(
+            by_provider=[],
+            by_user=[],
+            total_cost_microdollars=0,
+            total_requests=0,
+            total_users=0,
+            model_dump=lambda **_: {
+                "by_provider": [],
+                "by_user": [],
+                "total_cost_microdollars": 0,
+                "total_requests": 0,
+                "total_users": 0,
+            },
+        )
+    )
+    mocker.patch(
+        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
+        mock_dashboard,
+    )
+
+    response = client.get(
+        "/admin/platform_costs/dashboard",
+        params={
+            "start": "2026-01-01T00:00:00",
+            "end": "2026-04-01T00:00:00",
+            "provider": "openai",
+            "user_id": "test-user-123",
+        },
+    )
+    assert response.status_code == 200
+    mock_dashboard.assert_called_once()
+    call_kwargs = mock_dashboard.call_args.kwargs
+    assert call_kwargs["provider"] == "openai"
+    assert call_kwargs["user_id"] == "test-user-123"
+    assert call_kwargs["start"] is not None
+    assert call_kwargs["end"] is not None
+
+
+def test_get_logs_with_pagination(
+    mocker: pytest_mock.MockerFixture,
+) -> None:
+    mocker.patch(
+        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
+        AsyncMock(return_value=([], 0)),
+    )
+
+    response = client.get(
+        "/admin/platform_costs/logs",
+        params={"page": 2, "page_size": 25, "provider": "anthropic"},
+    )
+    assert response.status_code == 200
+    data = response.json()
+    assert data["pagination"]["current_page"] == 2
+    assert data["pagination"]["page_size"] == 25
+
+
+def test_get_dashboard_requires_admin() -> None:
+    app.dependency_overrides.clear()
+    response = client.get("/admin/platform_costs/dashboard")
+    assert response.status_code in (401, 403)
```
```diff
@@ -9,11 +9,14 @@ from pydantic import BaseModel
 from backend.copilot.config import ChatConfig
 from backend.copilot.rate_limit import (
+    SubscriptionTier,
     get_global_rate_limits,
     get_usage_status,
+    get_user_tier,
     reset_user_usage,
+    set_user_tier,
 )
-from backend.data.user import get_user_by_email, get_user_email_by_id
+from backend.data.user import get_user_by_email, get_user_email_by_id, search_users

 logger = logging.getLogger(__name__)

@@ -33,6 +36,17 @@ class UserRateLimitResponse(BaseModel):
     weekly_token_limit: int
     daily_tokens_used: int
     weekly_tokens_used: int
+    tier: SubscriptionTier
+
+
+class UserTierResponse(BaseModel):
+    user_id: str
+    tier: SubscriptionTier
+
+
+class SetUserTierRequest(BaseModel):
+    user_id: str
+    tier: SubscriptionTier


 async def _resolve_user_id(

@@ -86,10 +100,10 @@ async def get_user_rate_limit(
     logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)

-    daily_limit, weekly_limit = await get_global_rate_limits(
+    daily_limit, weekly_limit, tier = await get_global_rate_limits(
         resolved_id, config.daily_token_limit, config.weekly_token_limit
     )
-    usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
+    usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)

     return UserRateLimitResponse(
         user_id=resolved_id,

@@ -98,6 +112,7 @@ async def get_user_rate_limit(
         weekly_token_limit=weekly_limit,
         daily_tokens_used=usage.daily.used,
         weekly_tokens_used=usage.weekly.used,
+        tier=tier,
     )

@@ -125,10 +140,10 @@ async def reset_user_rate_limit(
         logger.exception("Failed to reset user usage")
         raise HTTPException(status_code=500, detail="Failed to reset usage") from e

-    daily_limit, weekly_limit = await get_global_rate_limits(
+    daily_limit, weekly_limit, tier = await get_global_rate_limits(
         user_id, config.daily_token_limit, config.weekly_token_limit
     )
-    usage = await get_usage_status(user_id, daily_limit, weekly_limit)
+    usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)

     try:
         resolved_email = await get_user_email_by_id(user_id)

@@ -143,4 +158,96 @@ async def reset_user_rate_limit(
         weekly_token_limit=weekly_limit,
         daily_tokens_used=usage.daily.used,
         weekly_tokens_used=usage.weekly.used,
+        tier=tier,
     )
+
+
+@router.get(
+    "/rate_limit/tier",
+    response_model=UserTierResponse,
+    summary="Get User Rate Limit Tier",
+)
+async def get_user_rate_limit_tier(
+    user_id: str,
+    admin_user_id: str = Security(get_user_id),
+) -> UserTierResponse:
+    """Get a user's current rate-limit tier. Admin-only.
+
+    Returns 404 if the user does not exist in the database.
+    """
+    logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
+
+    resolved_email = await get_user_email_by_id(user_id)
+    if resolved_email is None:
+        raise HTTPException(status_code=404, detail=f"User {user_id} not found")
+
+    tier = await get_user_tier(user_id)
+    return UserTierResponse(user_id=user_id, tier=tier)
+
+
+@router.post(
+    "/rate_limit/tier",
+    response_model=UserTierResponse,
+    summary="Set User Rate Limit Tier",
+)
+async def set_user_rate_limit_tier(
+    request: SetUserTierRequest,
+    admin_user_id: str = Security(get_user_id),
+) -> UserTierResponse:
+    """Set a user's rate-limit tier. Admin-only."""
+    old_tier = await get_user_tier(request.user_id)
+
+    # Resolve email for audit logging (non-blocking — don't fail the
+    # tier change if email lookup fails).
+    try:
+        resolved_email = await get_user_email_by_id(request.user_id)
+    except Exception:
+        logger.warning(
+            "Failed to resolve email for user %s", request.user_id, exc_info=True
+        )
+        resolved_email = None
+    logger.info(
+        "Admin %s changing tier for user %s (%s): %s -> %s",
+        admin_user_id,
+        request.user_id,
+        resolved_email or "unknown",
+        old_tier.value,
+        request.tier.value,
+    )
+    try:
+        await set_user_tier(request.user_id, request.tier)
+    except Exception as e:
+        logger.exception("Failed to set user tier")
+        raise HTTPException(status_code=500, detail="Failed to set tier") from e
+
+    return UserTierResponse(user_id=request.user_id, tier=request.tier)
+
+
+class UserSearchResult(BaseModel):
+    user_id: str
+    user_email: Optional[str] = None
+
+
+@router.get(
+    "/rate_limit/search_users",
+    response_model=list[UserSearchResult],
+    summary="Search Users by Name or Email",
+)
+async def admin_search_users(
+    query: str,
+    limit: int = 20,
+    admin_user_id: str = Security(get_user_id),
+) -> list[UserSearchResult]:
+    """Search users by partial email or name. Admin-only.
+
+    Queries the User table directly — returns results even for users
+    without credit transaction history.
+    """
+    if len(query.strip()) < 3:
+        raise HTTPException(
+            status_code=400,
+            detail="Search query must be at least 3 characters.",
+        )
+    logger.info("Admin %s searching users with query=%r", admin_user_id, query)
+    results = await search_users(query, limit=max(1, min(limit, 50)))
+    return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]
```
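A short sketch of driving the new tier endpoints from an admin script. Host and token are placeholders; the paths, payload shape, and tier values ("FREE", "PRO", "ENTERPRISE") come from the routes and tests in this changeset.

```python
# Hypothetical admin client; requires httpx and a valid admin JWT.
import httpx

headers = {"Authorization": "Bearer ADMIN_JWT"}  # placeholder token

with httpx.Client(base_url="http://localhost:8000", headers=headers) as client:
    # Look up the current tier (404 if the user does not exist).
    current = client.get(
        "/admin/rate_limit/tier", params={"user_id": "user-123"}
    ).json()

    # Change the tier; anything outside the SubscriptionTier enum is
    # rejected with 422 by Pydantic validation before the handler runs.
    updated = client.post(
        "/admin/rate_limit/tier",
        json={"user_id": "user-123", "tier": "PRO"},
    ).json()
    print(current["tier"], "->", updated["tier"])
```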
```diff
@@ -9,7 +9,7 @@ import pytest_mock
 from autogpt_libs.auth.jwt_utils import get_jwt_payload
 from pytest_snapshot.plugin import Snapshot

-from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
+from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow

 from .rate_limit_admin_routes import router as rate_limit_admin_router

@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
     mocker.patch(
         f"{_MOCK_MODULE}.get_global_rate_limits",
         new_callable=AsyncMock,
-        return_value=(2_500_000, 12_500_000),
+        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
     )
     mocker.patch(
         f"{_MOCK_MODULE}.get_usage_status",

@@ -89,6 +89,7 @@ def test_get_rate_limit(
     assert data["weekly_token_limit"] == 12_500_000
     assert data["daily_tokens_used"] == 500_000
     assert data["weekly_tokens_used"] == 3_000_000
+    assert data["tier"] == "FREE"

     configured_snapshot.assert_match(
         json.dumps(data, indent=2, sort_keys=True) + "\n",

@@ -162,6 +163,7 @@ def test_reset_user_usage_daily_only(
     assert data["daily_tokens_used"] == 0
     # Weekly is untouched
     assert data["weekly_tokens_used"] == 3_000_000
+    assert data["tier"] == "FREE"

     mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)

@@ -192,6 +194,7 @@ def test_reset_user_usage_daily_and_weekly(
     data = response.json()
     assert data["daily_tokens_used"] == 0
     assert data["weekly_tokens_used"] == 0
+    assert data["tier"] == "FREE"

     mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)

@@ -228,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
     mocker.patch(
         f"{_MOCK_MODULE}.get_global_rate_limits",
         new_callable=AsyncMock,
-        return_value=(2_500_000, 12_500_000),
+        return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
     )
     mocker.patch(
         f"{_MOCK_MODULE}.get_usage_status",

@@ -261,3 +264,294 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
         json={"user_id": "test"},
     )
     assert response.status_code == 403
+
+
+# ---------------------------------------------------------------------------
+# Tier management endpoints
+# ---------------------------------------------------------------------------
+
+
+def test_get_user_tier(
+    mocker: pytest_mock.MockerFixture,
+    target_user_id: str,
+) -> None:
+    """Test getting a user's rate-limit tier."""
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_email_by_id",
+        new_callable=AsyncMock,
+        return_value=_TARGET_EMAIL,
+    )
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_tier",
+        new_callable=AsyncMock,
+        return_value=SubscriptionTier.PRO,
+    )
+
+    response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
+
+    assert response.status_code == 200
+    data = response.json()
+    assert data["user_id"] == target_user_id
+    assert data["tier"] == "PRO"
+
+
+def test_get_user_tier_user_not_found(
+    mocker: pytest_mock.MockerFixture,
+    target_user_id: str,
+) -> None:
+    """Test that getting tier for a non-existent user returns 404."""
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_email_by_id",
+        new_callable=AsyncMock,
+        return_value=None,
+    )
+
+    response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
+
+    assert response.status_code == 404
+
+
+def test_set_user_tier(
+    mocker: pytest_mock.MockerFixture,
+    target_user_id: str,
+) -> None:
+    """Test setting a user's rate-limit tier (upgrade)."""
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_email_by_id",
+        new_callable=AsyncMock,
+        return_value=_TARGET_EMAIL,
+    )
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_tier",
+        new_callable=AsyncMock,
+        return_value=SubscriptionTier.FREE,
+    )
+    mock_set = mocker.patch(
+        f"{_MOCK_MODULE}.set_user_tier",
+        new_callable=AsyncMock,
+    )
+
+    response = client.post(
+        "/admin/rate_limit/tier",
+        json={"user_id": target_user_id, "tier": "ENTERPRISE"},
+    )
+
+    assert response.status_code == 200
+    data = response.json()
+    assert data["user_id"] == target_user_id
+    assert data["tier"] == "ENTERPRISE"
+    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
+
+
+def test_set_user_tier_downgrade(
+    mocker: pytest_mock.MockerFixture,
+    target_user_id: str,
+) -> None:
+    """Test downgrading a user's tier from PRO to FREE."""
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_email_by_id",
+        new_callable=AsyncMock,
+        return_value=_TARGET_EMAIL,
+    )
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_tier",
+        new_callable=AsyncMock,
+        return_value=SubscriptionTier.PRO,
+    )
+    mock_set = mocker.patch(
+        f"{_MOCK_MODULE}.set_user_tier",
+        new_callable=AsyncMock,
+    )
+
+    response = client.post(
+        "/admin/rate_limit/tier",
+        json={"user_id": target_user_id, "tier": "FREE"},
+    )
+
+    assert response.status_code == 200
+    data = response.json()
+    assert data["user_id"] == target_user_id
+    assert data["tier"] == "FREE"
+    mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
+
+
+def test_set_user_tier_invalid_tier(
+    target_user_id: str,
+) -> None:
+    """Test that setting an invalid tier returns 422."""
+    response = client.post(
+        "/admin/rate_limit/tier",
+        json={"user_id": target_user_id, "tier": "invalid"},
+    )
+
+    assert response.status_code == 422
+
+
+def test_set_user_tier_invalid_tier_uppercase(
+    target_user_id: str,
+) -> None:
+    """Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
+
+    Regression: ensures Pydantic enum validation rejects values that are not
+    members of SubscriptionTier, even when they look like valid enum names.
+    """
+    response = client.post(
+        "/admin/rate_limit/tier",
+        json={"user_id": target_user_id, "tier": "INVALID"},
+    )
+
+    assert response.status_code == 422
+    body = response.json()
+    assert "detail" in body
+
+
+def test_set_user_tier_email_lookup_failure_non_blocking(
+    mocker: pytest_mock.MockerFixture,
+    target_user_id: str,
+) -> None:
+    """Test that email lookup failure doesn't block tier change."""
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_email_by_id",
+        new_callable=AsyncMock,
+        side_effect=Exception("DB connection failed"),
+    )
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_tier",
+        new_callable=AsyncMock,
+        return_value=SubscriptionTier.FREE,
+    )
+    mock_set = mocker.patch(
+        f"{_MOCK_MODULE}.set_user_tier",
+        new_callable=AsyncMock,
+    )
+
+    response = client.post(
+        "/admin/rate_limit/tier",
+        json={"user_id": target_user_id, "tier": "PRO"},
+    )
+
+    assert response.status_code == 200
+    mock_set.assert_awaited_once()
+
+
+def test_set_user_tier_db_failure(
+    mocker: pytest_mock.MockerFixture,
+    target_user_id: str,
+) -> None:
+    """Test that DB failure on set tier returns 500."""
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_email_by_id",
+        new_callable=AsyncMock,
+        return_value=_TARGET_EMAIL,
+    )
+    mocker.patch(
+        f"{_MOCK_MODULE}.get_user_tier",
+        new_callable=AsyncMock,
+        return_value=SubscriptionTier.FREE,
+    )
+    mocker.patch(
+        f"{_MOCK_MODULE}.set_user_tier",
+        new_callable=AsyncMock,
+        side_effect=Exception("DB connection refused"),
+    )
+
+    response = client.post(
+        "/admin/rate_limit/tier",
+        json={"user_id": target_user_id, "tier": "PRO"},
+    )
+
+    assert response.status_code == 500
+
+
+def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
+    """Test that tier admin endpoints require admin role."""
+    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
+
+    response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
+    assert response.status_code == 403
+
+    response = client.post(
+        "/admin/rate_limit/tier",
+        json={"user_id": "test", "tier": "PRO"},
+    )
+    assert response.status_code == 403
+
+
+# ─── search_users endpoint ──────────────────────────────────────────
+
+
+def test_search_users_returns_matching_users(
+    mocker: pytest_mock.MockerFixture,
+    admin_user_id: str,
+) -> None:
+    """Partial search should return all matching users from the User table."""
+    mocker.patch(
+        _MOCK_MODULE + ".search_users",
+        new_callable=AsyncMock,
+        return_value=[
+            ("user-1", "zamil.majdy@gmail.com"),
+            ("user-2", "zamil.majdy@agpt.co"),
+        ],
+    )
+
+    response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
+
+    assert response.status_code == 200
+    results = response.json()
+    assert len(results) == 2
+    assert results[0]["user_email"] == "zamil.majdy@gmail.com"
+    assert results[1]["user_email"] == "zamil.majdy@agpt.co"
+
+
+def test_search_users_empty_results(
+    mocker: pytest_mock.MockerFixture,
+    admin_user_id: str,
+) -> None:
+    """Search with no matches returns empty list."""
+    mocker.patch(
+        _MOCK_MODULE + ".search_users",
+        new_callable=AsyncMock,
+        return_value=[],
+    )
+
+    response = client.get(
+        "/admin/rate_limit/search_users", params={"query": "nonexistent"}
+    )
+
+    assert response.status_code == 200
+    assert response.json() == []
+
+
+def test_search_users_short_query_rejected(
+    admin_user_id: str,
+) -> None:
+    """Query shorter than 3 characters should return 400."""
+    response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
+    assert response.status_code == 400
+
+
+def test_search_users_negative_limit_clamped(
+    mocker: pytest_mock.MockerFixture,
+    admin_user_id: str,
+) -> None:
+    """Negative limit should be clamped to 1, not passed through."""
+    mock_search = mocker.patch(
+        _MOCK_MODULE + ".search_users",
+        new_callable=AsyncMock,
+        return_value=[],
+    )
+
+    response = client.get(
+        "/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
+    )
+
+    assert response.status_code == 200
+    mock_search.assert_awaited_once_with("test", limit=1)
+
+
+def test_search_users_requires_admin_role(mock_jwt_user) -> None:
+    """Test that the search_users endpoint requires admin role."""
+    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
+
+    response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
+    assert response.status_code == 403
```
```diff
@@ -4,7 +4,7 @@ import asyncio
 import logging
 import re
 from collections.abc import AsyncGenerator
-from typing import Annotated
+from typing import Annotated, Literal
 from uuid import uuid4

 from autogpt_libs import auth

@@ -111,6 +111,11 @@ class StreamChatRequest(BaseModel):
     file_ids: list[str] | None = Field(
         default=None, max_length=20
     )  # Workspace file IDs attached to this message
+    mode: Literal["fast", "extended_thinking"] | None = Field(
+        default=None,
+        description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
+        "If None, uses the server default (extended_thinking).",
+    )


 class CreateSessionRequest(BaseModel):

@@ -456,8 +461,9 @@ async def get_copilot_usage(
     Returns current token usage vs limits for daily and weekly windows.
     Global defaults sourced from LaunchDarkly (falling back to config).
+    Includes the user's rate-limit tier.
     """
-    daily_limit, weekly_limit = await get_global_rate_limits(
+    daily_limit, weekly_limit, tier = await get_global_rate_limits(
         user_id, config.daily_token_limit, config.weekly_token_limit
     )
     return await get_usage_status(

@@ -465,6 +471,7 @@ async def get_copilot_usage(
         daily_token_limit=daily_limit,
         weekly_token_limit=weekly_limit,
         rate_limit_reset_cost=config.rate_limit_reset_cost,
+        tier=tier,
     )

@@ -516,7 +523,7 @@ async def reset_copilot_usage(
             detail="Rate limit reset is not available (credit system is disabled).",
         )

-    daily_limit, weekly_limit = await get_global_rate_limits(
+    daily_limit, weekly_limit, tier = await get_global_rate_limits(
         user_id, config.daily_token_limit, config.weekly_token_limit
     )

@@ -550,10 +557,13 @@ async def reset_copilot_usage(
     try:
         # Verify the user is actually at or over their daily limit.
+        # (rate_limit_reset_cost intentionally omitted — this object is only
+        # used for limit checks, not returned to the client.)
        usage_status = await get_usage_status(
             user_id=user_id,
             daily_token_limit=daily_limit,
             weekly_token_limit=weekly_limit,
+            tier=tier,
         )
         if daily_limit > 0 and usage_status.daily.used < daily_limit:
             raise HTTPException(

@@ -629,6 +639,7 @@ async def reset_copilot_usage(
         daily_token_limit=daily_limit,
         weekly_token_limit=weekly_limit,
         rate_limit_reset_cost=config.rate_limit_reset_cost,
+        tier=tier,
     )

     return RateLimitResetResponse(

@@ -739,7 +750,7 @@ async def stream_chat_post(
     # Global defaults sourced from LaunchDarkly, falling back to config.
     if user_id:
         try:
-            daily_limit, weekly_limit = await get_global_rate_limits(
+            daily_limit, weekly_limit, _ = await get_global_rate_limits(
                 user_id, config.daily_token_limit, config.weekly_token_limit
             )
             await check_rate_limit(

@@ -834,6 +845,7 @@ async def stream_chat_post(
         is_user_message=request.is_user_message,
         context=request.context,
         file_ids=sanitized_file_ids,
+        mode=request.mode,
     )

     setup_time = (time.perf_counter() - stream_start_time) * 1000
```
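An illustrative request body for the new `mode` field. Only the `mode` and `file_ids` semantics are taken from the diff above; the other keys and the endpoint path are placeholders, since the full `StreamChatRequest` schema is not shown here.

```python
# Hypothetical stream-chat payload exercising the new `mode` field.
payload = {
    "message": "Summarize my latest agent run",  # illustrative field name
    "file_ids": None,   # up to 20 workspace file IDs may be attached
    "mode": "fast",     # "fast" | "extended_thinking" | None
}
# mode=None falls back to the server default (extended_thinking);
# "fast" routes the turn to the baseline LLM instead of the Claude Agent SDK.
```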
```diff
@@ -9,6 +9,7 @@ import pytest
 import pytest_mock

 from backend.api.features.chat import routes as chat_routes
+from backend.copilot.rate_limit import SubscriptionTier

 app = fastapi.FastAPI()
 app.include_router(chat_routes.router)

@@ -331,14 +332,28 @@ def _mock_usage(
     *,
     daily_used: int = 500,
     weekly_used: int = 2000,
+    daily_limit: int = 10000,
+    weekly_limit: int = 50000,
+    tier: "SubscriptionTier" = SubscriptionTier.FREE,
 ) -> AsyncMock:
-    """Mock get_usage_status to return a predictable CoPilotUsageStatus."""
+    """Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
+
+    Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
+    ``get_usage_status`` so that tests exercise the endpoint without hitting
+    LaunchDarkly or Prisma.
+    """
     from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow

+    mocker.patch(
+        "backend.api.features.chat.routes.get_global_rate_limits",
+        new_callable=AsyncMock,
+        return_value=(daily_limit, weekly_limit, tier),
+    )
+
     resets_at = datetime.now(UTC) + timedelta(days=1)
     status = CoPilotUsageStatus(
-        daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
-        weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
+        daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
+        weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
     )
     return mocker.patch(
         "backend.api.features.chat.routes.get_usage_status",

@@ -369,6 +384,7 @@ def test_usage_returns_daily_and_weekly(
         daily_token_limit=10000,
         weekly_token_limit=50000,
         rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
+        tier=SubscriptionTier.FREE,
     )


@@ -376,11 +392,9 @@ def test_usage_uses_config_limits(
     mocker: pytest_mock.MockerFixture,
     test_user_id: str,
 ) -> None:
-    """The endpoint forwards daily_token_limit and weekly_token_limit from config."""
-    mock_get = _mock_usage(mocker)
+    """The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
+    mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)

     mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
     mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
     mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)

     response = client.get("/usage")

@@ -391,6 +405,7 @@ def test_usage_uses_config_limits(
         daily_token_limit=99999,
         weekly_token_limit=77777,
         rate_limit_reset_cost=500,
+        tier=SubscriptionTier.FREE,
     )
```
```diff
@@ -481,6 +481,11 @@ async def create_library_agent(
                     sensitive_action_safe_mode=sensitive_action_safe_mode,
                 ).model_dump()
             ),
+            **(
+                {"Folder": {"connect": {"id": folder_id}}}
+                if folder_id and graph_entry is graph
+                else {}
+            ),
         },
     },
     include=library_agent_include(
```
```diff
@@ -189,6 +189,7 @@ async def test_create_store_submission(mocker):
         notifyOnAgentApproved=True,
         notifyOnAgentRejected=True,
         timezone="Europe/Delft",
+        subscriptionTier=prisma.enums.SubscriptionTier.FREE,  # type: ignore[reportCallIssue,reportAttributeAccessIssue]
     )
     mock_agent = prisma.models.AgentGraph(
         id="agent-id",
```
```diff
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
 import backend.api.features.admin.credit_admin_routes
 import backend.api.features.admin.execution_analytics_routes
+import backend.api.features.admin.platform_cost_routes
 import backend.api.features.admin.rate_limit_admin_routes
 import backend.api.features.admin.store_admin_routes
 import backend.api.features.builder

@@ -329,6 +330,11 @@ app.include_router(
     tags=["v2", "admin"],
     prefix="/api/copilot",
 )
+app.include_router(
+    backend.api.features.admin.platform_cost_routes.router,
+    tags=["v2", "admin"],
+    prefix="/api/platform-costs",
+)
 app.include_router(
     backend.api.features.executions.review.routes.router,
     tags=["v2", "executions", "review"],
```
```diff
@@ -698,13 +698,30 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         if should_pause:
             return

-        # Validate the input data (original or reviewer-modified) once
-        if error := self.input_schema.validate_data(input_data):
-            raise BlockInputError(
-                message=f"Unable to execute block with invalid input data: {error}",
-                block_name=self.name,
-                block_id=self.id,
-            )
+        # Validate the input data (original or reviewer-modified) once.
+        # In dry-run mode, credential fields may contain sentinel None values
+        # that would fail JSON schema required checks. We still validate the
+        # non-credential fields so blocks that execute for real during dry-run
+        # (e.g. AgentExecutorBlock) get proper input validation.
+        is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False)
+        if is_dry_run:
+            cred_field_names = set(self.input_schema.get_credentials_fields().keys())
+            non_cred_data = {
+                k: v for k, v in input_data.items() if k not in cred_field_names
+            }
+            if error := self.input_schema.validate_data(non_cred_data):
+                raise BlockInputError(
+                    message=f"Unable to execute block with invalid input data: {error}",
+                    block_name=self.name,
+                    block_id=self.id,
+                )
+        else:
+            if error := self.input_schema.validate_data(input_data):
+                raise BlockInputError(
+                    message=f"Unable to execute block with invalid input data: {error}",
+                    block_name=self.name,
+                    block_id=self.id,
+                )

         # Use the validated input data
         async for output_name, output_data in self.run(
```
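A self-contained toy illustrating the dry-run rule above: credential fields are dropped before validation so sentinel None credentials don't trip "required" checks, while everything else is still validated. `check_required` is a stand-in for `input_schema.validate_data`, not the real implementation.

```python
# Toy model of the dry-run validation branch; names here are illustrative.
def check_required(data: dict, required: list[str]) -> str | None:
    missing = [f for f in required if data.get(f) is None]
    return f"missing required fields: {missing}" if missing else None


def validate_for_execution(
    input_data: dict, required: list[str], cred_fields: set[str], dry_run: bool
) -> str | None:
    if dry_run:
        # Strip credential fields (and their "required" entries) before checking.
        input_data = {k: v for k, v in input_data.items() if k not in cred_fields}
        required = [f for f in required if f not in cred_fields]
    return check_required(input_data, required)


# A dry run passes even though `credentials` is a sentinel None:
assert validate_for_execution(
    {"topic": "cats", "credentials": None},
    required=["topic", "credentials"],
    cred_fields={"credentials"},
    dry_run=True,
) is None
```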
```diff
@@ -49,11 +49,17 @@ class AgentExecutorBlock(Block):
     @classmethod
     def get_missing_input(cls, data: BlockInput) -> set[str]:
         required_fields = cls.get_input_schema(data).get("required", [])
-        return set(required_fields) - set(data)
+        # Check against the nested `inputs` dict, not the top-level node
+        # data — required fields like "topic" live inside data["inputs"],
+        # not at data["topic"].
+        provided = data.get("inputs", {})
+        return set(required_fields) - set(provided)

     @classmethod
     def get_mismatch_error(cls, data: BlockInput) -> str | None:
-        return validate_with_jsonschema(cls.get_input_schema(data), data)
+        return validate_with_jsonschema(
+            cls.get_input_schema(data), data.get("inputs", {})
+        )

     class Output(BlockSchema):
         # Use BlockSchema to avoid automatic error field that could clash with graph outputs

@@ -88,6 +94,7 @@ class AgentExecutorBlock(Block):
             execution_context=execution_context.model_copy(
                 update={"parent_execution_id": graph_exec_id},
             ),
+            dry_run=execution_context.dry_run,
         )

         logger = execution_utils.LogMetadata(

@@ -149,14 +156,19 @@ class AgentExecutorBlock(Block):
                 ExecutionStatus.TERMINATED,
                 ExecutionStatus.FAILED,
             ]:
-                logger.debug(
-                    f"Execution {log_id} received event {event.event_type} with status {event.status}"
+                logger.info(
+                    f"Execution {log_id} skipping event {event.event_type} status={event.status} "
+                    f"node={getattr(event, 'node_exec_id', '?')}"
                 )
                 continue

             if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
                 # If the graph execution is COMPLETED, TERMINATED, or FAILED,
                 # we can stop listening for further events.
+                logger.info(
+                    f"Execution {log_id} graph completed with status {event.status}, "
+                    f"yielded {len(yielded_node_exec_ids)} outputs"
+                )
                 self.merge_stats(
                     NodeExecutionStats(
                         extra_cost=event.stats.cost if event.stats else 0,
```
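A tiny runnable illustration of the `get_missing_input` fix: for an executor node, user-facing fields live in the nested `inputs` dict, so checking top-level keys produced false positives. The field names here are illustrative.

```python
# Why the check must look inside data["inputs"]:
node_data = {
    "graph_id": "g-1",            # top-level executor plumbing (illustrative)
    "inputs": {"topic": "cats"},  # the sub-agent's actual inputs
}
required = ["topic"]

missing_old = set(required) - set(node_data)                    # false positive
missing_new = set(required) - set(node_data.get("inputs", {}))  # correct
assert missing_old == {"topic"}
assert missing_new == set()
```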
```diff
@@ -18,6 +18,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName

@@ -358,6 +359,7 @@ class AIShortformVideoCreatorBlock(Block):
             execution_context=execution_context,
             return_format="for_block_output",
         )
+        self.merge_stats(NodeExecutionStats(output_size=1))
         yield "video_url", stored_url

@@ -565,6 +567,7 @@ class AIAdMakerVideoCreatorBlock(Block):
             execution_context=execution_context,
             return_format="for_block_output",
         )
+        self.merge_stats(NodeExecutionStats(output_size=1))
         yield "video_url", stored_url

@@ -760,4 +763,5 @@ class AIScreenshotToVideoAdBlock(Block):
             execution_context=execution_context,
             return_format="for_block_output",
         )
+        self.merge_stats(NodeExecutionStats(output_size=1))
         yield "video_url", stored_url
```
```diff
@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
     PrimaryPhone,
     SearchOrganizationsRequest,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


 class SearchOrganizationsBlock(Block):

@@ -218,6 +218,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
     ) -> BlockOutput:
         query = SearchOrganizationsRequest(**input_data.model_dump())
         organizations = await self.search_organizations(query, credentials)
+        self.merge_stats(NodeExecutionStats(output_size=len(organizations)))
         for organization in organizations:
             yield "organization", organization
         yield "organizations", organizations
```
```diff
@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
     SearchPeopleRequest,
     SenorityLevels,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


 class SearchPeopleBlock(Block):

@@ -366,4 +366,5 @@ class SearchPeopleBlock(Block):
             *(enrich_or_fallback(person) for person in people)
         )

+        self.merge_stats(NodeExecutionStats(output_size=len(people)))
         yield "people", people
```
```diff
@@ -13,7 +13,7 @@ from backend.blocks.apollo._auth import (
     ApolloCredentialsInput,
 )
 from backend.blocks.apollo.models import Contact, EnrichPersonRequest
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


 class GetPersonDetailBlock(Block):

@@ -141,4 +141,5 @@ class GetPersonDetailBlock(Block):
         **kwargs,
     ) -> BlockOutput:
         query = EnrichPersonRequest(**input_data.model_dump())
+        self.merge_stats(NodeExecutionStats(output_size=1))
         yield "contact", await self.enrich_person(query, credentials)
```
```diff
@@ -17,6 +17,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName

@@ -342,6 +343,7 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
         # Determine result object shape & filter out empty formats
         main_result, results = self.process_execution_results(results)
+        self.merge_stats(NodeExecutionStats(output_size=1))
         if main_result:
             yield "main_result", main_result
         yield "results", results

@@ -467,6 +469,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
             setup_commands=input_data.setup_commands,
             timeout=input_data.timeout,
         )
+        self.merge_stats(NodeExecutionStats(output_size=1))
         if sandbox_id:
             yield "sandbox_id", sandbox_id
         else:

@@ -577,6 +580,7 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
         # Determine result object shape & filter out empty formats
         main_result, results = self.process_execution_results(results)
+        self.merge_stats(NodeExecutionStats(output_size=1))
         if main_result:
             yield "main_result", main_result
         yield "results", results
```
```diff
@@ -15,7 +15,12 @@ from backend.blocks._base import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
-from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
+from backend.data.model import (
+    APIKeyCredentials,
+    CredentialsField,
+    NodeExecutionStats,
+    SchemaField,
+)
 from backend.util.type import MediaFileType

 from ._api import (

@@ -195,6 +200,7 @@ class GetLinkedinProfileBlock(Block):
                 include_social_media=input_data.include_social_media,
                 include_extra=input_data.include_extra,
             )
+            self.merge_stats(NodeExecutionStats(output_size=1))
             yield "profile", profile
         except Exception as e:
             logger.error(f"Error fetching LinkedIn profile: {str(e)}")

@@ -341,6 +347,7 @@ class LinkedinPersonLookupBlock(Block):
                 include_similarity_checks=input_data.include_similarity_checks,
                 enrich_profile=input_data.enrich_profile,
             )
+            self.merge_stats(NodeExecutionStats(output_size=1))
             yield "lookup_result", lookup_result
         except Exception as e:
             logger.error(f"Error looking up LinkedIn profile: {str(e)}")

@@ -443,6 +450,7 @@ class LinkedinRoleLookupBlock(Block):
                 company_name=input_data.company_name,
                 enrich_profile=input_data.enrich_profile,
             )
+            self.merge_stats(NodeExecutionStats(output_size=1))
             yield "role_lookup_result", role_lookup_result
         except Exception as e:
             logger.error(f"Error looking up role in company: {str(e)}")

@@ -523,6 +531,7 @@ class GetLinkedinProfilePictureBlock(Block):
             credentials=credentials,
             linkedin_profile_url=input_data.linkedin_profile_url,
         )
+        self.merge_stats(NodeExecutionStats(output_size=1))
         yield "profile_picture_url", profile_picture
     except Exception as e:
         logger.error(f"Error getting profile picture: {str(e)}")
```
```diff
@@ -4,6 +4,7 @@
 from exa_py import AsyncExa
 from pydantic import BaseModel

+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,

@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):
         if response.cost_dollars:
             yield "cost_dollars", response.cost_dollars
+            self.merge_stats(
+                NodeExecutionStats(provider_cost=response.cost_dollars.total)
+            )
```
```diff
@@ -4,6 +4,7 @@
 from exa_py import AsyncExa

+from backend.data.model import NodeExecutionStats
 from backend.sdk import (
     APIKeyCredentials,
     Block,

@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):
         if response.cost_dollars:
             yield "cost_dollars", response.cost_dollars
+            self.merge_stats(
+                NodeExecutionStats(provider_cost=response.cost_dollars.total)
+            )
```
```diff
@@ -18,7 +18,7 @@ from backend.blocks.fal._auth import (
     FalCredentialsInput,
 )
 from backend.data.execution import ExecutionContext
-from backend.data.model import SchemaField
+from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util.file import store_media_file
 from backend.util.request import ClientResponseError, Requests
 from backend.util.type import MediaFileType

@@ -230,6 +230,7 @@ class AIVideoGeneratorBlock(Block):
                 execution_context=execution_context,
                 return_format="for_block_output",
             )
+            self.merge_stats(NodeExecutionStats(output_size=1))
             yield "video_url", stored_url
         except Exception as e:
             error_message = str(e)
```
```diff
@@ -14,6 +14,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName

@@ -117,6 +118,7 @@ class GoogleMapsSearchBlock(Block):
             input_data.radius,
             input_data.max_results,
         )
+        self.merge_stats(NodeExecutionStats(output_size=len(places)))
         for place in places:
             yield "place", place
```
```diff
@@ -14,6 +14,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName

@@ -227,6 +228,7 @@ class IdeogramModelBlock(Block):
                 image_url=result,
             )

+        self.merge_stats(NodeExecutionStats(output_size=1))
         yield "result", result

     async def run_model(
```
```diff
@@ -2,6 +2,8 @@ import copy
 from datetime import date, time
 from typing import Any, Optional

+from pydantic import AliasChoices, Field
+
 from backend.blocks._base import (
     Block,
     BlockCategory,

@@ -467,7 +469,8 @@ class AgentFileInputBlock(AgentInputBlock):
 class AgentDropdownInputBlock(AgentInputBlock):
     """
-    A specialized text input block that relies on placeholder_values to present a dropdown.
+    A specialized text input block that presents a dropdown selector
+    restricted to a fixed set of values.
     """

     class Input(AgentInputBlock.Input):

@@ -477,16 +480,23 @@ class AgentDropdownInputBlock(AgentInputBlock):
             advanced=False,
             title="Default Value",
         )
-        placeholder_values: list = SchemaField(
-            description="Possible values for the dropdown.",
+        # Use Field() directly (not SchemaField) to pass validation_alias,
+        # which handles backward compat for legacy "placeholder_values" across
+        # all construction paths (model_construct, __init__, model_validate).
+        options: list = Field(
             default_factory=list,
-            advanced=False,
-            title="Dropdown Options",
+            description=(
+                "If provided, renders the input as a dropdown selector "
+                "restricted to these values. Leave empty for free-text input."
+            ),
+            validation_alias=AliasChoices("options", "placeholder_values"),
+            json_schema_extra={"advanced": False, "secret": False},
         )

     def generate_schema(self):
         schema = super().generate_schema()
-        if possible_values := self.placeholder_values:
+        if possible_values := self.options:
             schema["enum"] = possible_values
         return schema

@@ -504,13 +514,13 @@ class AgentDropdownInputBlock(AgentInputBlock):
                 {
                     "value": "Option A",
                     "name": "dropdown_1",
-                    "placeholder_values": ["Option A", "Option B", "Option C"],
+                    "options": ["Option A", "Option B", "Option C"],
                     "description": "Dropdown example 1",
                 },
                 {
                     "value": "Option C",
                     "name": "dropdown_2",
-                    "placeholder_values": ["Option A", "Option B", "Option C"],
+                    "options": ["Option A", "Option B", "Option C"],
                     "description": "Dropdown example 2",
                 },
             ],
```
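A standalone sketch of the backward-compat mechanism used above: pydantic v2's `AliasChoices` lets a field populate from either its new name or the legacy one. The model below is illustrative, not the block's actual schema.

```python
# Minimal demonstration of AliasChoices backward compatibility (pydantic v2).
from pydantic import AliasChoices, BaseModel, Field


class DropdownInput(BaseModel):
    options: list = Field(
        default_factory=list,
        validation_alias=AliasChoices("options", "placeholder_values"),
    )


# New-style payloads use the current field name:
assert DropdownInput.model_validate({"options": ["A", "B"]}).options == ["A", "B"]
# Legacy payloads keep working through the alias:
assert DropdownInput.model_validate({"placeholder_values": ["A"]}).options == ["A"]
```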
```diff
@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
     JinaCredentialsField,
     JinaCredentialsInput,
 )
-from backend.data.model import SchemaField
+from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util.request import Requests

@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
         }
         data = {"input": input_data.texts, "model": input_data.model}
         response = await Requests().post(url, headers=headers, json=data)
-        embeddings = [e["embedding"] for e in response.json()["data"]]
+        resp_json = response.json()
+        embeddings = [e["embedding"] for e in resp_json["data"]]
+        usage = resp_json.get("usage", {})
+        if usage.get("total_tokens"):
+            self.merge_stats(
+                NodeExecutionStats(
+                    input_token_count=usage.get("total_tokens", 0),
+                )
+            )
         yield "embeddings", embeddings
```
@@ -687,6 +687,7 @@ class LLMResponse(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
reasoning: Optional[str] = None
|
||||
provider_cost: float | None = None
|
||||
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
@@ -1045,6 +1046,16 @@ async def llm_call(
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
cost = None
|
||||
try:
|
||||
raw_resp = getattr(response, "_response", None)
|
||||
if raw_resp and hasattr(raw_resp, "headers"):
|
||||
cost_header = raw_resp.headers.get("x-total-cost")
|
||||
if cost_header:
|
||||
cost = float(cost_header)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
@@ -1053,6 +1064,7 @@ async def llm_call(
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
provider_cost=cost,
|
||||
)
|
||||
elif provider == "llama_api":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -1377,12 +1389,13 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
response_text = llm_response.response
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
cost_stats = NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
if llm_response.provider_cost is not None:
|
||||
cost_stats.provider_cost = llm_response.provider_cost
|
||||
self.merge_stats(cost_stats)
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
|
||||
@@ -89,6 +89,12 @@ class MCPToolBlock(Block):
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
tool_description: str = SchemaField(
|
||||
description="Description of the selected MCP tool. "
|
||||
"Populated automatically when a tool is selected.",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
tool_arguments: dict[str, Any] = SchemaField(
|
||||
description="Arguments to pass to the selected MCP tool. "
|
||||
|
||||
@@ -8,6 +8,7 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -153,6 +154,7 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
messages,
|
||||
**params,
|
||||
)
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
|
||||
results = result.get("results", [])
|
||||
yield "results", results
|
||||
@@ -255,6 +257,7 @@ class SearchMemoryBlock(Block, Mem0Base):
|
||||
result: list[dict[str, Any]] = client.search(
|
||||
input_data.query, version="v2", filters=filters
|
||||
)
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
yield "memories", result
|
||||
|
||||
except Exception as e:
|
||||
@@ -340,6 +343,7 @@ class GetAllMemoriesBlock(Block, Mem0Base):
|
||||
filters=filters,
|
||||
version="v2",
|
||||
)
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
|
||||
yield "memories", memories
|
||||
|
||||
@@ -434,6 +438,7 @@ class GetLatestMemoryBlock(Block, Mem0Base):
|
||||
filters=filters,
|
||||
version="v2",
|
||||
)
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
|
||||
if memories:
|
||||
# Return the latest memory (first in the list as they're sorted by recency)
|
||||
|
||||
@@ -10,7 +10,7 @@ from backend.blocks.nvidia._auth import (
|
||||
NvidiaCredentialsField,
|
||||
NvidiaCredentialsInput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.request import Requests
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
@@ -69,6 +69,7 @@ class NvidiaDeepfakeDetectBlock(Block):
|
||||
data = response.json()
|
||||
|
||||
result = data.get("data", [{}])[0]
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
|
||||
# Get deepfake probability from first bounding box if any
|
||||
deepfake_prob = 0.0
|
||||
|
||||
@@ -17,7 +17,12 @@ from backend.blocks.replicate._auth import (
|
||||
ReplicateCredentialsInput,
|
||||
)
|
||||
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.util.exceptions import BlockExecutionError, BlockInputError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -108,6 +113,7 @@ class ReplicateModelBlock(Block):
|
||||
result = await self.run_model(
|
||||
model_ref, input_data.model_inputs, credentials.api_key
|
||||
)
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
yield "result", result
|
||||
yield "status", "succeeded"
|
||||
yield "model_name", input_data.model_name
|
||||
|
||||
@@ -16,6 +16,7 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -185,6 +186,7 @@ class ScreenshotWebPageBlock(Block):
|
||||
block_chats=input_data.block_chats,
|
||||
cache=input_data.cache,
|
||||
)
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
yield "image", screenshot_data["image"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -15,6 +15,7 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -146,6 +147,7 @@ class GetWeatherInformationBlock(Block, GetRequest):
|
||||
weather_data = await self.get_request(url, json=True)
|
||||
|
||||
if "main" in weather_data and "weather" in weather_data:
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
yield "temperature", str(weather_data["main"]["temp"])
|
||||
yield "humidity", str(weather_data["main"]["humidity"])
|
||||
yield "condition", weather_data["weather"][0]["description"]
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
|
||||
SaveSequencesResponse,
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
@@ -100,6 +100,7 @@ class CreateCampaignBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
response = await self.create_campaign(input_data.name, credentials)
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
|
||||
yield "id", response.id
|
||||
yield "name", response.name
|
||||
@@ -226,6 +227,7 @@ class AddLeadToCampaignBlock(Block):
|
||||
response = await self.add_leads_to_campaign(
|
||||
input_data.campaign_id, input_data.lead_list, credentials
|
||||
)
|
||||
self.merge_stats(NodeExecutionStats(output_size=len(input_data.lead_list)))
|
||||
|
||||
yield "campaign_id", input_data.campaign_id
|
||||
yield "upload_count", response.upload_count
|
||||
@@ -321,6 +323,7 @@ class SaveCampaignSequencesBlock(Block):
|
||||
response = await self.save_campaign_sequences(
|
||||
input_data.campaign_id, input_data.sequences, credentials
|
||||
)
|
||||
self.merge_stats(NodeExecutionStats(output_size=1))
|
||||
|
||||
if response.data:
|
||||
yield "data", response.data
|
||||
|
||||
304
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
304
autogpt_platform/backend/backend/blocks/sql_query_block.py
Normal file
@@ -0,0 +1,304 @@
|
||||
import asyncio
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError, ProgrammingError
|
||||
|
||||
from backend.blocks._base import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.blocks.sql_query_helpers import (
|
||||
_DATABASE_TYPE_DEFAULT_PORT,
|
||||
_DATABASE_TYPE_TO_DRIVER,
|
||||
DatabaseType,
|
||||
_execute_query,
|
||||
_sanitize_error,
|
||||
_validate_query_is_read_only,
|
||||
_validate_single_statement,
|
||||
)
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import resolve_and_check_blocked
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="database",
|
||||
username=SecretStr("test_user"),
|
||||
password=SecretStr("test_pass"),
|
||||
title="Mock Database credentials",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
DatabaseCredentials = UserPasswordCredentials
|
||||
DatabaseCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.DATABASE],
|
||||
Literal["user_password"],
|
||||
]
|
||||
|
||||
|
||||
def DatabaseCredentialsField() -> DatabaseCredentialsInput:
|
||||
return CredentialsField(
|
||||
description="Database username and password",
|
||||
)
|
||||
|
||||
|
||||
class SQLQueryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
database_type: DatabaseType = SchemaField(
|
||||
default=DatabaseType.POSTGRES,
|
||||
description="Database engine",
|
||||
advanced=False,
|
||||
)
|
||||
host: SecretStr = SchemaField(
|
||||
description="Database hostname or IP address",
|
||||
placeholder="db.example.com",
|
||||
secret=True,
|
||||
)
|
||||
port: int | None = SchemaField(
|
||||
default=None,
|
||||
description=(
|
||||
"Database port (leave empty for default: "
|
||||
"PostgreSQL: 5432, MySQL: 3306, MSSQL: 1433)"
|
||||
),
|
||||
ge=1,
|
||||
le=65535,
|
||||
)
|
||||
database: str = SchemaField(
|
||||
description="Name of the database to connect to",
|
||||
placeholder="my_database",
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="SQL query to execute",
|
||||
placeholder="SELECT * FROM analytics.daily_active_users LIMIT 10",
|
||||
)
|
||||
read_only: bool = SchemaField(
|
||||
default=True,
|
||||
description=(
|
||||
"When enabled (default), only SELECT queries are allowed "
|
||||
"and the database session is set to read-only mode. "
|
||||
"Disable to allow write operations (INSERT, UPDATE, DELETE, etc.)."
|
||||
),
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=30,
|
||||
description="Query timeout in seconds (max 120)",
|
||||
ge=1,
|
||||
le=120,
|
||||
)
|
||||
max_rows: int = SchemaField(
|
||||
default=1000,
|
||||
description="Maximum number of rows to return (max 10000)",
|
||||
ge=1,
|
||||
le=10000,
|
||||
)
|
||||
credentials: DatabaseCredentialsInput = DatabaseCredentialsField()
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list[dict[str, Any]] = SchemaField(
|
||||
description="Query results as a list of row dictionaries"
|
||||
)
|
||||
columns: list[str] = SchemaField(
|
||||
description="Column names from the query result"
|
||||
)
|
||||
row_count: int = SchemaField(description="Number of rows returned")
|
||||
affected_rows: int = SchemaField(
|
||||
description="Number of rows affected by a write query (INSERT/UPDATE/DELETE)"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the query failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4dc35c0f-4fd8-465e-9616-5a216f1ba2bc",
|
||||
description=(
|
||||
"Execute a SQL query. Read-only by default for safety "
|
||||
"-- disable to allow write operations. "
|
||||
"Supports PostgreSQL, MySQL, and MSSQL via SQLAlchemy."
|
||||
),
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=SQLQueryBlock.Input,
|
||||
output_schema=SQLQueryBlock.Output,
|
||||
test_input={
|
||||
"query": "SELECT 1 AS test_col",
|
||||
"database_type": DatabaseType.POSTGRES,
|
||||
"host": "localhost",
|
||||
"database": "test_db",
|
||||
"timeout": 30,
|
||||
"max_rows": 1000,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("results", [{"test_col": 1}]),
|
||||
("columns", ["test_col"]),
|
||||
("row_count", 1),
|
||||
],
|
||||
test_mock={
|
||||
"execute_query": lambda *_args, **_kwargs: (
|
||||
[{"test_col": 1}],
|
||||
["test_col"],
|
||||
-1,
|
||||
),
|
||||
"check_host_allowed": lambda *_args, **_kwargs: ["127.0.0.1"],
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def check_host_allowed(host: str) -> list[str]:
|
||||
"""Validate that the given host is not a private/blocked address.
|
||||
|
||||
Returns the list of resolved IP addresses so the caller can pin the
|
||||
connection to the validated IP (preventing DNS rebinding / TOCTOU).
|
||||
Raises ValueError or OSError if the host is blocked.
|
||||
Extracted as a method so it can be mocked during block tests.
|
||||
"""
|
||||
return await resolve_and_check_blocked(host)
|
||||
|
||||
@staticmethod
|
||||
def execute_query(
|
||||
connection_url: URL | str,
|
||||
query: str,
|
||||
timeout: int,
|
||||
max_rows: int,
|
||||
read_only: bool = True,
|
||||
database_type: DatabaseType = DatabaseType.POSTGRES,
|
||||
) -> tuple[list[dict[str, Any]], list[str], int]:
|
||||
"""Execute a SQL query and return (rows, columns, affected_rows).
|
||||
|
||||
Delegates to ``_execute_query`` in ``sql_query_helpers``.
|
||||
Extracted as a method so it can be mocked during block tests.
|
||||
"""
|
||||
return _execute_query(
|
||||
connection_url=connection_url,
|
||||
query=query,
|
||||
timeout=timeout,
|
||||
max_rows=max_rows,
|
||||
read_only=read_only,
|
||||
database_type=database_type,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: DatabaseCredentials,
|
||||
**_kwargs: Any,
|
||||
) -> BlockOutput:
|
||||
# Validate query structure and read-only constraints.
|
||||
error = self._validate_query(input_data)
|
||||
if error:
|
||||
yield "error", error
|
||||
return
|
||||
|
||||
# Validate host and resolve for SSRF protection.
|
||||
host, pinned_host, error = await self._resolve_host(input_data)
|
||||
if error:
|
||||
yield "error", error
|
||||
return
|
||||
|
||||
# Build connection URL and execute.
|
||||
port = input_data.port or _DATABASE_TYPE_DEFAULT_PORT[input_data.database_type]
|
||||
username = credentials.username.get_secret_value()
|
||||
connection_url = URL.create(
|
||||
drivername=_DATABASE_TYPE_TO_DRIVER[input_data.database_type],
|
||||
username=username,
|
||||
password=credentials.password.get_secret_value(),
|
||||
host=pinned_host,
|
||||
port=port,
|
||||
database=input_data.database,
|
||||
)
|
||||
conn_str = connection_url.render_as_string(hide_password=True)
|
||||
db_name = input_data.database
|
||||
|
||||
def _sanitize(err: Exception) -> str:
|
||||
return _sanitize_error(
|
||||
str(err).strip(),
|
||||
conn_str,
|
||||
host=pinned_host,
|
||||
original_host=host,
|
||||
username=username,
|
||||
port=port,
|
||||
database=db_name,
|
||||
)
|
||||
|
||||
try:
|
||||
results, columns, affected = await asyncio.to_thread(
|
||||
self.execute_query,
|
||||
connection_url=connection_url,
|
||||
query=input_data.query,
|
||||
timeout=input_data.timeout,
|
||||
max_rows=input_data.max_rows,
|
||||
read_only=input_data.read_only,
|
||||
database_type=input_data.database_type,
|
||||
)
|
||||
yield "results", results
|
||||
yield "columns", columns
|
||||
yield "row_count", len(results)
|
||||
if affected >= 0:
|
||||
yield "affected_rows", affected
|
||||
except OperationalError as e:
|
||||
yield "error", self._classify_operational_error(
|
||||
_sanitize(e),
|
||||
input_data.timeout,
|
||||
)
|
||||
except ProgrammingError as e:
|
||||
yield "error", f"SQL error: {_sanitize(e)}"
|
||||
except DBAPIError as e:
|
||||
yield "error", f"Database error: {_sanitize(e)}"
|
||||
except ModuleNotFoundError:
|
||||
yield "error", (
|
||||
f"Database driver not available for "
|
||||
f"{input_data.database_type.value}. "
|
||||
f"Please contact the platform administrator."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_query(input_data: "SQLQueryBlock.Input") -> str | None:
|
||||
"""Validate query structure and read-only constraints."""
|
||||
stmt_error, parsed_stmt = _validate_single_statement(input_data.query)
|
||||
if stmt_error:
|
||||
return stmt_error
|
||||
assert parsed_stmt is not None
|
||||
if input_data.read_only:
|
||||
return _validate_query_is_read_only(parsed_stmt)
|
||||
return None
|
||||
|
||||
async def _resolve_host(
|
||||
self, input_data: "SQLQueryBlock.Input"
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Validate and resolve the database host. Returns (host, pinned_ip, error)."""
|
||||
host = input_data.host.get_secret_value().strip()
|
||||
if not host:
|
||||
return "", "", "Database host is required."
|
||||
if host.startswith("/"):
|
||||
return host, "", "Unix socket connections are not allowed."
|
||||
try:
|
||||
resolved_ips = await self.check_host_allowed(host)
|
||||
except (ValueError, OSError) as e:
|
||||
return host, "", f"Blocked host: {str(e).strip()}"
|
||||
return host, resolved_ips[0], None
|
||||
|
||||
@staticmethod
|
||||
def _classify_operational_error(sanitized_msg: str, timeout: int) -> str:
|
||||
"""Classify an already-sanitized OperationalError for user display."""
|
||||
lower = sanitized_msg.lower()
|
||||
if "timeout" in lower or "cancel" in lower:
|
||||
return f"Query timed out after {timeout}s."
|
||||
if "connect" in lower:
|
||||
return f"Failed to connect to database: {sanitized_msg}"
|
||||
return f"Database error: {sanitized_msg}"

autogpt_platform/backend/backend/blocks/sql_query_block_test.py (new file, 1396 lines; diff suppressed because it is too large)

autogpt_platform/backend/backend/blocks/sql_query_helpers.py (new file, 376 lines)
@@ -0,0 +1,376 @@
+import re
+from datetime import date, datetime, time
+from decimal import Decimal
+from enum import Enum
+from typing import Any
+
+import sqlparse
+from sqlalchemy import create_engine, text
+from sqlalchemy.engine.url import URL
+
+
+class DatabaseType(str, Enum):
+    POSTGRES = "postgres"
+    MYSQL = "mysql"
+    MSSQL = "mssql"
+
+
+# Defense-in-depth: reject queries containing data-modifying keywords.
+# These are checked against parsed SQL tokens (not raw text) so column names
+# and string literals do not cause false positives.
+_DISALLOWED_KEYWORDS = {
+    "INSERT",
+    "UPDATE",
+    "DELETE",
+    "DROP",
+    "ALTER",
+    "CREATE",
+    "TRUNCATE",
+    "GRANT",
+    "REVOKE",
+    "COPY",
+    "EXECUTE",
+    "CALL",
+    "SET",
+    "RESET",
+    "DISCARD",
+    "NOTIFY",
+    "DO",
+}
+
+# Map DatabaseType enum values to the expected SQLAlchemy driver prefix.
+_DATABASE_TYPE_TO_DRIVER = {
+    DatabaseType.POSTGRES: "postgresql",
+    DatabaseType.MYSQL: "mysql+pymysql",
+    DatabaseType.MSSQL: "mssql+pymssql",
+}
+
+# Default ports for each database type.
+_DATABASE_TYPE_DEFAULT_PORT = {
+    DatabaseType.POSTGRES: 5432,
+    DatabaseType.MYSQL: 3306,
+    DatabaseType.MSSQL: 1433,
+}
+
+
+def _sanitize_error(
+    error_msg: str,
+    connection_string: str,
+    *,
+    host: str = "",
+    original_host: str = "",
+    username: str = "",
+    port: int = 0,
+    database: str = "",
+) -> str:
+    """Remove connection string, credentials, and infrastructure details
+    from error messages so they are safe to expose to the LLM.
+
+    Scrubs:
+    - The full connection string
+    - URL-embedded credentials (``://user:pass@``)
+    - ``password=<value>`` key-value pairs
+    - The database hostname / IP used for the connection
+    - The original (pre-resolution) hostname provided by the user
+    - Any IPv4 addresses that appear in the message
+    - Any bracketed IPv6 addresses (e.g. ``[::1]``, ``[fe80::1%eth0]``)
+    - The database username
+    - The database port number
+    - The database name
+    """
+    sanitized = error_msg.replace(connection_string, "<connection_string>")
+    sanitized = re.sub(r"password=[^\s&]+", "password=***", sanitized)
+    sanitized = re.sub(r"://[^@]+@", "://***:***@", sanitized)
+
+    # Replace the known host (may be an IP already) before the generic IP pass.
+    # Also replace the original (pre-DNS-resolution) hostname if it differs.
+    if original_host and original_host != host:
+        sanitized = sanitized.replace(original_host, "<host>")
+    if host:
+        sanitized = sanitized.replace(host, "<host>")
+
+    # Replace any remaining IPv4 addresses (e.g. resolved IPs the driver logs)
+    sanitized = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", "<ip>", sanitized)
+
+    # Replace bracketed IPv6 addresses (e.g. "[::1]", "[fe80::1%eth0]")
+    sanitized = re.sub(r"\[[0-9a-fA-F:]+(?:%[^\]]+)?\]", "<ip>", sanitized)
+
+    # Replace the database username (handles double-quoted, single-quoted,
+    # and unquoted formats across PostgreSQL, MySQL, and MSSQL error messages).
+    if username:
+        sanitized = re.sub(
+            r"""for user ["']?""" + re.escape(username) + r"""["']?""",
+            "for user <user>",
+            sanitized,
+        )
+        # Catch remaining bare occurrences in various quote styles:
+        # - PostgreSQL: "FATAL: role "myuser" does not exist"
+        # - MySQL: "Access denied for user 'myuser'@'host'"
+        # - MSSQL: "Login failed for user 'myuser'"
+        sanitized = sanitized.replace(f'"{username}"', "<user>")
+        sanitized = sanitized.replace(f"'{username}'", "<user>")
+
+    # Replace the port number (handles "port 5432" and ":5432" formats)
+    if port:
+        port_str = re.escape(str(port))
+        sanitized = re.sub(
+            r"(?:port |:)" + port_str + r"(?![0-9])",
+            lambda m: ("port " if m.group().startswith("p") else ":") + "<port>",
+            sanitized,
+        )
+
+    # Replace the database name to avoid leaking internal infrastructure names.
+    # Use word-boundary regex to prevent mangling when the database name is a
+    # common substring (e.g. "test", "data", "on").
+    if database:
+        sanitized = re.sub(r"\b" + re.escape(database) + r"\b", "<database>", sanitized)
+
+    return sanitized
+
+
+def _extract_keyword_tokens(parsed: sqlparse.sql.Statement) -> list[str]:
+    """Extract keyword tokens from a parsed SQL statement.
+
+    Uses sqlparse token type classification to collect Keyword/DML/DDL/DCL
+    tokens. String literals and identifiers have different token types, so
+    they are naturally excluded from the result.
+    """
+    return [
+        token.normalized.upper()
+        for token in parsed.flatten()
+        if token.ttype
+        in (
+            sqlparse.tokens.Keyword,
+            sqlparse.tokens.Keyword.DML,
+            sqlparse.tokens.Keyword.DDL,
+            sqlparse.tokens.Keyword.DCL,
+        )
+    ]
+
+
+def _has_disallowed_into(stmt: sqlparse.sql.Statement) -> bool:
+    """Check if a statement contains a disallowed ``INTO`` clause.
+
+    ``SELECT ... INTO @variable`` is a valid read-only MySQL syntax that stores
+    a query result into a session-scoped user variable. All other forms of
+    ``INTO`` are data-modifying or file-writing and must be blocked:
+
+    * ``SELECT ... INTO new_table`` (PostgreSQL / MSSQL – creates a table)
+    * ``SELECT ... INTO OUTFILE`` (MySQL – writes to the filesystem)
+    * ``SELECT ... INTO DUMPFILE`` (MySQL – writes to the filesystem)
+    * ``INSERT INTO ...`` (already blocked by INSERT being in the
+      disallowed set, but we reject INTO as well for defense-in-depth)
+
+    Returns ``True`` if the statement contains a disallowed ``INTO``.
+    """
+    flat = list(stmt.flatten())
+    for i, token in enumerate(flat):
+        if not (
+            token.ttype in (sqlparse.tokens.Keyword,)
+            and token.normalized.upper() == "INTO"
+        ):
+            continue
+
+        # Look at the first non-whitespace token after INTO.
+        j = i + 1
+        while j < len(flat) and flat[j].ttype is sqlparse.tokens.Text.Whitespace:
+            j += 1
+
+        if j >= len(flat):
+            # INTO at the very end – malformed, block it.
+            return True
+
+        next_token = flat[j]
+        # MySQL user variable: either a single Name starting with "@"
+        # (e.g. ``@total``) or a bare ``@`` Operator token followed by a Name.
+        if next_token.ttype is sqlparse.tokens.Name and next_token.value.startswith(
+            "@"
+        ):
+            continue
+        if next_token.ttype is sqlparse.tokens.Operator and next_token.value == "@":
+            continue
+
+        # Everything else (table name, OUTFILE, DUMPFILE, etc.) is disallowed.
+        return True
+
+    return False
+
+
+def _validate_query_is_read_only(stmt: sqlparse.sql.Statement) -> str | None:
+    """Validate that a parsed SQL statement is read-only (SELECT/WITH only).
+
+    Accepts an already-parsed statement from ``_validate_single_statement``
+    to avoid re-parsing. Checks:
+    1. Statement type must be SELECT (sqlparse classifies WITH...SELECT as SELECT)
+    2. No disallowed keywords (INSERT, UPDATE, DELETE, DROP, etc.)
+    3. No disallowed INTO clauses (allows MySQL ``SELECT ... INTO @variable``)
+
+    Returns an error message if the query is not read-only, None otherwise.
+    """
+    # sqlparse returns 'SELECT' for SELECT and WITH...SELECT queries
+    if stmt.get_type() != "SELECT":
+        return "Only SELECT queries are allowed."
+
+    # Defense-in-depth: check parsed keyword tokens for disallowed keywords
+    for kw in _extract_keyword_tokens(stmt):
+        # Normalize multi-word tokens (e.g. "SET LOCAL" -> "SET")
+        base_kw = kw.split()[0] if " " in kw else kw
+        if base_kw in _DISALLOWED_KEYWORDS:
+            return f"Disallowed SQL keyword: {kw}"
+
+    # Contextual check for INTO: allow MySQL @variable syntax, block everything else
+    if _has_disallowed_into(stmt):
+        return "Disallowed SQL keyword: INTO"
+
+    return None
+
+
+def _validate_single_statement(
+    query: str,
+) -> tuple[str | None, sqlparse.sql.Statement | None]:
+    """Validate that the query contains exactly one non-empty SQL statement.
+
+    Returns (error_message, parsed_statement). If error_message is not None,
+    the query is invalid and parsed_statement will be None.
+    """
+    stripped = query.strip().rstrip(";").strip()
+    if not stripped:
+        return "Query is empty.", None
+
+    # Parse the SQL using sqlparse for proper tokenization
+    statements = sqlparse.parse(stripped)
+
+    # Filter out empty statements and comment-only statements
+    statements = [
+        s
+        for s in statements
+        if s.tokens
+        and str(s).strip()
+        and not all(
+            t.is_whitespace or t.ttype in sqlparse.tokens.Comment for t in s.flatten()
+        )
+    ]
+
+    if not statements:
+        return "Query is empty.", None
+
+    # Reject multiple statements -- prevents injection via semicolons
+    if len(statements) > 1:
+        return "Only single statements are allowed.", None
+
+    return None, statements[0]
+
+
+def _serialize_value(value: Any) -> Any:
+    """Convert database-specific types to JSON-serializable Python types."""
+    if isinstance(value, Decimal):
+        # Use int for whole numbers; use str for fractional to preserve exact
+        # precision (float would silently round high-precision analytics values).
+        if value == value.to_integral_value():
+            return int(value)
+        return str(value)
+    if isinstance(value, (datetime, date, time)):
+        return value.isoformat()
+    if isinstance(value, memoryview):
+        return bytes(value).hex()
+    if isinstance(value, bytes):
+        return value.hex()
+    return value
+
+
+def _configure_session(
+    conn: Any,
+    dialect_name: str,
+    timeout_ms: str,
+    read_only: bool,
+) -> None:
+    """Set session-level timeout and read-only mode for the given dialect."""
+    if dialect_name == "postgresql":
+        conn.execute(text("SET statement_timeout = " + timeout_ms))
+        if read_only:
+            conn.execute(text("SET default_transaction_read_only = ON"))
+    elif dialect_name == "mysql":
+        # NOTE: MAX_EXECUTION_TIME only applies to SELECT statements.
+        # Write queries (INSERT/UPDATE/DELETE) are not bounded by this
+        # setting; they rely on the database's wait_timeout instead.
+        conn.execute(text("SET SESSION MAX_EXECUTION_TIME = " + timeout_ms))
+        if read_only:
+            conn.execute(text("SET SESSION TRANSACTION READ ONLY"))
+    elif dialect_name == "mssql":
+        # MSSQL: SET LOCK_TIMEOUT limits lock-wait time (ms).
+        # pymssql's connect_args "login_timeout" handles the connection
+        # timeout, but LOCK_TIMEOUT covers in-query lock waits.
+        conn.execute(text("SET LOCK_TIMEOUT " + timeout_ms))
+        # MSSQL lacks a session-level read-only mode like
+        # PostgreSQL/MySQL. Read-only enforcement is handled by
+        # the SQL validation layer (_validate_query_is_read_only)
+        # and the ROLLBACK in the finally block.
+
+
+def _run_in_transaction(
+    conn: Any,
+    dialect_name: str,
+    query: str,
+    max_rows: int,
+    read_only: bool,
+) -> tuple[list[dict[str, Any]], list[str], int]:
+    """Execute a query inside an explicit transaction, returning results."""
+    # MSSQL uses T-SQL "BEGIN TRANSACTION"; others use "BEGIN".
+    begin_stmt = "BEGIN TRANSACTION" if dialect_name == "mssql" else "BEGIN"
+    conn.execute(text(begin_stmt))
+    try:
+        result = conn.execute(text(query))
+        affected = result.rowcount if not result.returns_rows else -1
+        columns = list(result.keys()) if result.returns_rows else []
+        rows = result.fetchmany(max_rows) if result.returns_rows else []
+        results = [
+            {col: _serialize_value(val) for col, val in zip(columns, row)}
+            for row in rows
+        ]
+    except Exception:
+        conn.execute(text("ROLLBACK"))
+        raise
+    else:
+        conn.execute(text("ROLLBACK" if read_only else "COMMIT"))
+        return results, columns, affected
+
+
+def _execute_query(
+    connection_url: URL | str,
+    query: str,
+    timeout: int,
+    max_rows: int,
+    read_only: bool = True,
+    database_type: DatabaseType = DatabaseType.POSTGRES,
+) -> tuple[list[dict[str, Any]], list[str], int]:
+    """Execute a SQL query and return (rows, columns, affected_rows).
+
+    Uses SQLAlchemy to connect to any supported database.
+    For SELECT queries, rows are limited to ``max_rows`` via DBAPI fetchmany.
+    For write queries, affected_rows contains the rowcount from the driver.
+    When ``read_only`` is True, the database session is set to read-only
+    mode and the transaction is always rolled back.
+    """
+    # Determine driver-specific connection timeout argument.
+    # pymssql uses "login_timeout", while PostgreSQL/MySQL use "connect_timeout".
+    timeout_key = (
+        "login_timeout" if database_type == DatabaseType.MSSQL else "connect_timeout"
+    )
+    engine = create_engine(connection_url, connect_args={timeout_key: 10})
+    try:
+        with engine.connect() as conn:
+            # Use AUTOCOMMIT so SET commands take effect immediately.
+            conn = conn.execution_options(isolation_level="AUTOCOMMIT")
+
+            # Compute timeout in milliseconds. The value is Pydantic-validated
+            # (ge=1, le=120), but we use int() as defense-in-depth.
+            # NOTE: SET commands do not support bind parameters in most
+            # databases, so we use str(int(...)) for safe interpolation.
+            timeout_ms = str(int(timeout * 1000))
+
+            _configure_session(conn, engine.dialect.name, timeout_ms, read_only)
+            return _run_in_transaction(
+                conn, engine.dialect.name, query, max_rows, read_only
+            )
+    finally:
+        engine.dispose()
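
A quick standalone illustration (not the real helper code) of how the sqlparse-based guards above behave: chained statements are rejected outright, and the statement type check catches non-SELECT queries:

    import sqlparse

    def is_read_only_single_select(query: str) -> bool:
        # Simplified sketch of the checks in sql_query_helpers (not the real code).
        statements = [s for s in sqlparse.parse(query.strip()) if str(s).strip()]
        if len(statements) != 1:
            return False  # semicolon-chained statements are rejected
        return statements[0].get_type() == "SELECT"

    assert is_read_only_single_select("SELECT * FROM users")
    assert not is_read_only_single_select("SELECT 1; DROP TABLE users")
    assert not is_read_only_single_select("DELETE FROM users")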

@@ -15,6 +15,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
@@ -181,6 +182,7 @@ class CreateTalkingAvatarVideoBlock(Block):
                     execution_context=execution_context,
                     return_format="for_block_output",
                 )
+                self.merge_stats(NodeExecutionStats(output_size=1))
                 yield "video_url", stored_url
                 return
             elif status_response["status"] == "error":

@@ -300,13 +300,27 @@ def test_agent_input_block_ignores_legacy_placeholder_values():


 def test_dropdown_input_block_produces_enum():
-    """Verify AgentDropdownInputBlock.Input.generate_schema() produces enum."""
-    options = ["Option A", "Option B"]
+    """Verify AgentDropdownInputBlock.Input.generate_schema() produces enum
+    using the canonical 'options' field name."""
+    opts = ["Option A", "Option B"]
     instance = AgentDropdownInputBlock.Input.model_construct(
-        name="choice", value=None, placeholder_values=options
+        name="choice", value=None, options=opts
     )
     schema = instance.generate_schema()
-    assert schema.get("enum") == options
+    assert schema.get("enum") == opts
+
+
+def test_dropdown_input_block_legacy_placeholder_values_produces_enum():
+    """Verify backward compat: passing legacy 'placeholder_values' to
+    AgentDropdownInputBlock still produces enum via model_construct remap."""
+    opts = ["Option A", "Option B"]
+    instance = AgentDropdownInputBlock.Input.model_construct(
+        name="choice", value=None, placeholder_values=opts
+    )
+    schema = instance.generate_schema()
+    assert (
+        schema.get("enum") == opts
+    ), "Legacy placeholder_values should be remapped to options"


 def test_generate_schema_integration_legacy_placeholder_values():
@@ -329,11 +343,11 @@ def test_generate_schema_integration_legacy_placeholder_values():

 def test_generate_schema_integration_dropdown_produces_enum():
     """Test the full Graph._generate_schema path with AgentDropdownInputBlock
-    — verifies enum IS produced for dropdown blocks."""
+    — verifies enum IS produced for dropdown blocks using canonical field name."""
     dropdown_input_default = {
         "name": "color",
         "value": None,
-        "placeholder_values": ["Red", "Green", "Blue"],
+        "options": ["Red", "Green", "Blue"],
     }
     result = BaseGraph._generate_schema(
         (AgentDropdownInputBlock.Input, dropdown_input_default),
@@ -344,3 +358,36 @@ def test_generate_schema_integration_dropdown_produces_enum():
         "Green",
         "Blue",
     ], "Graph schema should contain enum from AgentDropdownInputBlock"
+
+
+def test_generate_schema_integration_dropdown_legacy_placeholder_values():
+    """Test the full Graph._generate_schema path with AgentDropdownInputBlock
+    using legacy 'placeholder_values' — verifies backward compat produces enum."""
+    legacy_dropdown_input_default = {
+        "name": "color",
+        "value": None,
+        "placeholder_values": ["Red", "Green", "Blue"],
+    }
+    result = BaseGraph._generate_schema(
+        (AgentDropdownInputBlock.Input, legacy_dropdown_input_default),
+    )
+    color_props = result["properties"]["color"]
+    assert color_props.get("enum") == [
+        "Red",
+        "Green",
+        "Blue",
+    ], "Legacy placeholder_values should still produce enum via model_construct remap"
+
+
+def test_dropdown_input_block_init_legacy_placeholder_values():
+    """Verify backward compat: constructing AgentDropdownInputBlock.Input via
+    model_validate with legacy 'placeholder_values' correctly maps to 'options'."""
+    opts = ["Option A", "Option B"]
+    instance = AgentDropdownInputBlock.Input.model_validate(
+        {"name": "choice", "value": None, "placeholder_values": opts}
+    )
+    assert (
+        instance.options == opts
+    ), "Legacy placeholder_values should be remapped to options via model_validate"
+    schema = instance.generate_schema()
+    assert schema.get("enum") == opts

@@ -13,6 +13,7 @@ from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
@@ -104,4 +105,5 @@ class UnrealTextToSpeechBlock(Block):
             input_data.text,
             input_data.voice_id,
         )
+        self.merge_stats(NodeExecutionStats(output_size=len(input_data.text)))
         yield "mp3_url", api_response["OutputUri"]

@@ -19,6 +19,7 @@ from backend.blocks._base import (
 from backend.data.model import (
     CredentialsField,
     CredentialsMetaInput,
+    NodeExecutionStats,
     SchemaField,
     UserPasswordCredentials,
 )
@@ -170,6 +171,7 @@ class TranscribeYoutubeVideoBlock(Block):
         transcript = self.get_transcript(video_id, credentials)
         transcript_text = self.format_transcript(transcript=transcript)

+        self.merge_stats(NodeExecutionStats(output_size=1))
         # Only yield after all operations succeed
         yield "video_id", video_id
         yield "transcript", transcript_text

@@ -21,7 +21,7 @@ from backend.blocks.zerobounce._auth import (
     ZeroBounceCredentials,
     ZeroBounceCredentialsInput,
 )
-from backend.data.model import CredentialsField, SchemaField
+from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField


 class Response(BaseModel):
@@ -177,5 +177,6 @@ class ValidateEmailsBlock(Block):
         )

         response_model = Response(**response.__dict__)
+        self.merge_stats(NodeExecutionStats(output_size=1))

         yield "response", response_model

@@ -51,6 +51,12 @@ from backend.copilot.service import (
 from backend.copilot.token_tracking import persist_and_record_usage
 from backend.copilot.tools import execute_tool, get_available_tools
 from backend.copilot.tracking import track_user_message
+from backend.copilot.transcript import (
+    download_transcript,
+    upload_transcript,
+    validate_transcript,
+)
+from backend.copilot.transcript_builder import TranscriptBuilder
 from backend.util.exceptions import NotFoundError
 from backend.util.prompt import (
     compress_context,
@@ -108,7 +114,7 @@ async def _baseline_llm_caller(
     if tools:
         typed_tools = cast(list[ChatCompletionToolParam], tools)
         response = await client.chat.completions.create(
-            model=config.model,
+            model=config.fast_model,
             messages=typed_messages,
             tools=typed_tools,
             stream=True,
@@ -116,7 +122,7 @@ async def _baseline_llm_caller(
         )
     else:
        response = await client.chat.completions.create(
-            model=config.model,
+            model=config.fast_model,
             messages=typed_messages,
             stream=True,
             stream_options={"include_usage": True},
@@ -282,6 +288,9 @@ def _baseline_conversation_updater(
     messages: list[dict[str, Any]],
     response: LLMLoopResponse,
     tool_results: list[ToolCallResult] | None = None,
+    *,
+    transcript_builder: TranscriptBuilder,
+    model: str = "",
 ) -> None:
     """Update OpenAI message list with assistant response + tool results.

@@ -301,6 +310,29 @@ def _baseline_conversation_updater(
             for tc in response.tool_calls
         ]
         messages.append(assistant_msg)
+        # Record assistant message (with tool_calls) to transcript
+        content_blocks: list[dict[str, Any]] = []
+        if response.response_text:
+            content_blocks.append({"type": "text", "text": response.response_text})
+        for tc in response.tool_calls:
+            try:
+                args = orjson.loads(tc.arguments) if tc.arguments else {}
+            except Exception:
+                args = {}
+            content_blocks.append(
+                {
+                    "type": "tool_use",
+                    "id": tc.id,
+                    "name": tc.name,
+                    "input": args,
+                }
+            )
+        if content_blocks:
+            transcript_builder.append_assistant(
+                content_blocks=content_blocks,
+                model=model,
+                stop_reason="tool_use",
+            )
         for tr in tool_results:
             messages.append(
                 {
@@ -309,9 +341,22 @@ def _baseline_conversation_updater(
                     "content": tr.content,
                 }
             )
+            # Record tool result to transcript AFTER the assistant tool_use
+            # block to maintain correct Anthropic API ordering:
+            # assistant(tool_use) → user(tool_result)
+            transcript_builder.append_tool_result(
+                tool_use_id=tr.tool_call_id,
+                content=tr.content,
+            )
     else:
         if response.response_text:
             messages.append({"role": "assistant", "content": response.response_text})
+            # Record final text to transcript
+            transcript_builder.append_assistant(
+                content_blocks=[{"type": "text", "text": response.response_text}],
+                model=model,
+                stop_reason="end_turn",
+            )


 async def _update_title_async(
@@ -340,19 +385,23 @@ async def _compress_session_messages(
         msg_dict: dict[str, Any] = {"role": msg.role}
         if msg.content:
             msg_dict["content"] = msg.content
+        if msg.tool_calls:
+            msg_dict["tool_calls"] = msg.tool_calls
+        if msg.tool_call_id:
+            msg_dict["tool_call_id"] = msg.tool_call_id
         messages_dict.append(msg_dict)

     try:
         result = await compress_context(
             messages=messages_dict,
-            model=config.model,
+            model=config.fast_model,
             client=_get_openai_client(),
         )
     except Exception as e:
         logger.warning("[Baseline] Context compression with LLM failed: %s", e)
         result = await compress_context(
             messages=messages_dict,
-            model=config.model,
+            model=config.fast_model,
             client=None,
         )
@@ -366,7 +415,12 @@ async def _compress_session_messages(
         result.messages_dropped,
     )
     return [
-        ChatMessage(role=m["role"], content=m.get("content"))
+        ChatMessage(
+            role=m["role"],
+            content=m.get("content"),
+            tool_calls=m.get("tool_calls"),
+            tool_call_id=m.get("tool_call_id"),
+        )
         for m in result.messages
     ]
@@ -397,7 +451,8 @@ async def stream_chat_completion_baseline(
             f"Session {session_id} not found. Please create a new session first."
         )

-    # Append user message
+    # Append user message (skip if it's an exact duplicate of the last message,
+    # e.g. from a network retry)
     new_role = "user" if is_user_message else "assistant"
     if message and (
         len(session.messages) == 0
@@ -416,6 +471,54 @@ async def stream_chat_completion_baseline(

     session = await upsert_chat_session(session)

+    # --- Transcript support (feature parity with SDK path) ---
+    transcript_builder = TranscriptBuilder()
+    transcript_covers_prefix = True
+
+    if user_id and len(session.messages) > 1:
+        try:
+            dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
+            if dl and validate_transcript(dl.content):
+                # Reject stale transcripts: if msg_count is known and
+                # doesn't cover the current session, loading it would
+                # silently drop intermediate turns from the transcript.
+                session_msg_count = len(session.messages)
+                if dl.message_count and dl.message_count < session_msg_count - 1:
+                    logger.warning(
+                        "[Baseline] Transcript stale: covers %d of %d messages, skipping",
+                        dl.message_count,
+                        session_msg_count,
+                    )
+                    transcript_covers_prefix = False
+                else:
+                    transcript_builder.load_previous(
+                        dl.content, log_prefix="[Baseline]"
+                    )
+                    logger.info(
+                        "[Baseline] Loaded transcript: %dB, msg_count=%d",
+                        len(dl.content),
+                        dl.message_count,
+                    )
+            elif dl:
+                logger.warning("[Baseline] Downloaded transcript but invalid")
+                transcript_covers_prefix = False
+            else:
+                logger.debug("[Baseline] No transcript available")
+                transcript_covers_prefix = False
+        except Exception as e:
+            logger.warning("[Baseline] Transcript download failed: %s", e)
+            transcript_covers_prefix = False
+
+    # Append user message to transcript.
+    # Always append when the message is present and is from the user,
+    # even on duplicate-suppressed retries (is_new_message=False).
+    # The loaded transcript may be stale (uploaded before the previous
+    # attempt stored this message), so skipping it would leave the
+    # transcript without the user turn, creating a malformed
+    # assistant-after-assistant structure when the LLM reply is added.
+    if message and is_user_message:
+        transcript_builder.append_user(content=message)
+
     # Generate title for new sessions
     if is_user_message and not session.title:
         user_messages = [m for m in session.messages if m.role == "user"]
@@ -448,12 +551,30 @@ async def stream_chat_completion_baseline(
     # Compress context if approaching the model's token limit
     messages_for_context = await _compress_session_messages(session.messages)

-    # Build OpenAI message list from session history
+    # Build OpenAI message list from session history.
+    # Include tool_calls on assistant messages and tool-role results so the
+    # model retains full context of what tools were invoked and their outcomes.
     openai_messages: list[dict[str, Any]] = [
         {"role": "system", "content": system_prompt}
     ]
     for msg in messages_for_context:
-        if msg.role in ("user", "assistant") and msg.content:
+        if msg.role == "assistant":
+            entry: dict[str, Any] = {"role": "assistant"}
+            if msg.content:
+                entry["content"] = msg.content
+            if msg.tool_calls:
+                entry["tool_calls"] = msg.tool_calls
+            if msg.content or msg.tool_calls:
+                openai_messages.append(entry)
+        elif msg.role == "tool" and msg.tool_call_id:
+            openai_messages.append(
+                {
+                    "role": "tool",
+                    "tool_call_id": msg.tool_call_id,
+                    "content": msg.content or "",
+                }
+            )
+        elif msg.role == "user" and msg.content:
             openai_messages.append({"role": msg.role, "content": msg.content})

     tools = get_available_tools()
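
For reference (not part of the diff), the OpenAI chat format this rebuild targets requires each tool message to follow an assistant message whose tool_calls contains a matching id. A minimal illustrative history, with made-up ids and tool names:

    openai_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_1",  # illustrative id
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
                }
            ],
        },
        # The tool result references the assistant's tool_call id.
        {"role": "tool", "tool_call_id": "call_1", "content": '{"temp_c": 18}'},
        {"role": "assistant", "content": "It's 18 °C in Paris."},
    ]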

@@ -487,6 +608,12 @@ async def stream_chat_completion_baseline(
         _baseline_tool_executor, state=state, user_id=user_id, session=session
     )

+    _bound_conversation_updater = partial(
+        _baseline_conversation_updater,
+        transcript_builder=transcript_builder,
+        model=config.fast_model,
+    )
+
     try:
         loop_result = None
         async for loop_result in tool_call_loop(
@@ -494,7 +621,7 @@ async def stream_chat_completion_baseline(
             tools=tools,
             llm_call=_bound_llm_caller,
             execute_tool=_bound_tool_executor,
-            update_conversation=_baseline_conversation_updater,
+            update_conversation=_bound_conversation_updater,
             max_iterations=_MAX_TOOL_ROUNDS,
         ):
             # Drain buffered events after each iteration (real-time streaming)
@@ -563,10 +690,10 @@ async def stream_chat_completion_baseline(
         and not (_stream_error and not state.assistant_text)
     ):
         state.turn_prompt_tokens = max(
-            estimate_token_count(openai_messages, model=config.model), 1
+            estimate_token_count(openai_messages, model=config.fast_model), 1
         )
         state.turn_completion_tokens = estimate_token_count_str(
-            state.assistant_text, model=config.model
+            state.assistant_text, model=config.fast_model
         )
         logger.info(
             "[Baseline] No streaming usage reported; estimated tokens: "
@@ -597,6 +724,25 @@ async def stream_chat_completion_baseline(
         except Exception as persist_err:
             logger.error("[Baseline] Failed to persist session: %s", persist_err)

+        # --- Upload transcript for next-turn continuity ---
+        if user_id and transcript_covers_prefix:
+            try:
+                _transcript_content = transcript_builder.to_jsonl()
+                if _transcript_content and validate_transcript(_transcript_content):
+                    await asyncio.shield(
+                        upload_transcript(
+                            user_id=user_id,
+                            session_id=session_id,
+                            content=_transcript_content,
+                            message_count=len(session.messages),
+                            log_prefix="[Baseline]",
+                        )
+                    )
+                else:
+                    logger.debug("[Baseline] No valid transcript to upload")
+            except Exception as upload_err:
+                logger.error("[Baseline] Transcript upload failed: %s", upload_err)
+
     # Yield usage and finish AFTER try/finally (not inside finally).
     # PEP 525 prohibits yielding from finally in async generators during
     # aclose() — doing so raises RuntimeError on client disconnect.
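
A standalone illustration (not from the diff) of the constraint noted in the comment above: an async generator that yields inside finally raises RuntimeError when closed mid-iteration.

    import asyncio

    async def bad_gen():
        try:
            yield 1
        finally:
            yield 2  # yielding during aclose() is prohibited by PEP 525

    async def main():
        agen = bad_gen()
        print(await agen.__anext__())  # 1
        try:
            await agen.aclose()  # throws GeneratorExit into the generator
        except RuntimeError as e:
            print("RuntimeError:", e)  # "async generator ignored GeneratorExit"

    asyncio.run(main())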

@@ -14,12 +14,21 @@ class ChatConfig(BaseSettings):

     # OpenAI API Configuration
     model: str = Field(
-        default="anthropic/claude-opus-4.6", description="Default model to use"
+        default="anthropic/claude-opus-4.6",
+        description="Default model for extended thinking mode",
     )
+    fast_model: str = Field(
+        default="anthropic/claude-sonnet-4",
+        description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
+    )
     title_model: str = Field(
         default="openai/gpt-4o-mini",
         description="Model to use for generating session titles (should be fast/cheap)",
     )
+    simulation_model: str = Field(
+        default="google/gemini-2.5-flash",
+        description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
+    )
     api_key: str | None = Field(default=None, description="OpenAI API key")
     base_url: str | None = Field(
         default=OPENROUTER_BASE_URL,
@@ -77,11 +86,11 @@ class ChatConfig(BaseSettings):
     # allows ~70-100 turns/day.
     # Checked at the HTTP layer (routes.py) before each turn.
     #
-    # TODO: These are deploy-time constants applied identically to every user.
-    # If per-user or per-plan limits are needed (e.g., free tier vs paid), these
-    # must move to the database (e.g., a UserPlan table) and get_usage_status /
-    # check_rate_limit would look up each user's specific limits instead of
-    # reading config.daily_token_limit / config.weekly_token_limit.
+    # These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
+    # ENTERPRISE) multiply these by their tier multiplier (see
+    # rate_limit.TIER_MULTIPLIERS). User tier is stored in the
+    # User.subscriptionTier DB column and resolved inside
+    # get_global_rate_limits().
     daily_token_limit: int = Field(
         default=2_500_000,
         description="Max tokens per day, resets at midnight UTC (0 = unlimited)",
@@ -129,6 +138,32 @@ class ChatConfig(BaseSettings):
         description="Use --resume for multi-turn conversations instead of "
         "history compression. Falls back to compression when unavailable.",
     )
+    claude_agent_fallback_model: str = Field(
+        default="claude-sonnet-4-20250514",
+        description="Fallback model when the primary model is unavailable (e.g. 529 "
+        "overloaded). The SDK automatically retries with this cheaper model.",
+    )
+    claude_agent_max_turns: int = Field(
+        default=50,
+        ge=1,
+        le=500,
+        description="Maximum number of agentic turns (tool-use loops) per query. "
+        "Prevents runaway tool loops from burning budget.",
+    )
+    claude_agent_max_budget_usd: float = Field(
+        default=5.0,
+        ge=0.01,
+        le=100.0,
+        description="Maximum spend in USD per SDK query. The CLI aborts the "
+        "request if this budget is exceeded.",
+    )
+    claude_agent_max_transient_retries: int = Field(
+        default=3,
+        ge=0,
+        le=10,
+        description="Maximum number of retries for transient API errors "
+        "(429, 5xx, ECONNRESET) before surfacing the error to the user.",
+    )
     use_openrouter: bool = Field(
         default=True,
         description="Enable routing API calls through the OpenRouter proxy. "

@@ -44,12 +44,32 @@ def parse_node_id_from_exec_id(node_exec_id: str) -> str:
 # Transient Anthropic API error detection
 # ---------------------------------------------------------------------------
 # Patterns in error text that indicate a transient Anthropic API error
-# (ECONNRESET / dropped TCP connection) which is retryable.
+# which is retryable. Covers:
+# - Connection-level: ECONNRESET, dropped TCP connections
+# - HTTP 429: rate-limit / too-many-requests
+# - HTTP 5xx: server errors, overloaded
 _TRANSIENT_ERROR_PATTERNS = (
+    # Connection-level
     "socket connection was closed unexpectedly",
     "ECONNRESET",
     "connection was forcibly closed",
     "network socket disconnected",
+    # 429 rate-limit patterns
+    "rate limit",
+    "rate_limit",
+    "too many requests",
+    "status code 429",
+    # 5xx server error patterns
+    "overloaded",
+    "internal server error",
+    "bad gateway",
+    "service unavailable",
+    "gateway timeout",
+    "status code 529",
+    "status code 500",
+    "status code 502",
+    "status code 503",
+    "status code 504",
 )

 FRIENDLY_TRANSIENT_MSG = "Anthropic connection interrupted — please retry"
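
A hypothetical matcher sketch (the helper name _is_transient_error is assumed here, not shown in the diff) of how such a pattern table is typically consulted, using case-insensitive substring matching:

    def _is_transient_error(error_text: str) -> bool:
        # Hypothetical helper: substring match against the pattern table above.
        lowered = error_text.lower()
        return any(p.lower() in lowered for p in _TRANSIENT_ERROR_PATTERNS)

    assert _is_transient_error("API error: status code 529 (overloaded)")
    assert not _is_transient_error("invalid_request_error: bad model name")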
|
||||
|
||||
@@ -149,7 +149,8 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:

    Allowed:
    - Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
    - Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
    - Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``
      or ``tool-outputs/...``.
    The SDK nests tool-results under a conversation UUID directory;
    the UUID segment is validated with ``_UUID_RE``.
    """

@@ -174,17 +175,20 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
    # Defence-in-depth: ensure project_dir didn't escape the base.
    if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
        return False
    # Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
    # Only allow: <encoded-cwd>/<uuid>/<tool-dir>/<file>
    # The SDK always creates a conversation UUID directory between
    # the project dir and tool-results/.
    # the project dir and the tool directory.
    # Accept both "tool-results" (SDK's persisted outputs) and
    # "tool-outputs" (the model sometimes confuses workspace paths
    # with filesystem paths and generates this variant).
    if resolved.startswith(project_dir + os.sep):
        relative = resolved[len(project_dir) + 1 :]
        parts = relative.split(os.sep)
        # Require exactly: [<uuid>, "tool-results", <file>, ...]
        # Require exactly: [<uuid>, "tool-results"|"tool-outputs", <file>, ...]
        if (
            len(parts) >= 3
            and _UUID_RE.match(parts[0])
            and parts[1] == "tool-results"
            and parts[1] in ("tool-results", "tool-outputs")
        ):
            return True
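Concretely, the rules above admit and reject paths like the following; the directory names here are illustrative placeholders, not values from the codebase:

import os

base = os.path.expanduser("~/.claude/projects/-tmp-copilot-abc")  # <encoded-cwd>
uuid_dir = "123e4567-e89b-12d3-a456-426614174000"                 # conversation UUID

ok = os.path.join(base, uuid_dir, "tool-results", "out.json")       # allowed
also_ok = os.path.join(base, uuid_dir, "tool-outputs", "out.json")  # allowed
bad = os.path.join(base, "tool-results", "out.json")                # rejected: no UUID segment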
@@ -134,6 +134,21 @@ def test_is_allowed_local_path_tool_results_with_uuid():
        _current_project_dir.set("")


def test_is_allowed_local_path_tool_outputs_with_uuid():
    """Files under <encoded-cwd>/<uuid>/tool-outputs/ are also allowed."""
    encoded = "test-encoded-dir"
    conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
    path = os.path.join(
        SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-outputs", "output.json"
    )

    _current_project_dir.set(encoded)
    try:
        assert is_allowed_local_path(path, sdk_cwd=None)
    finally:
        _current_project_dir.set("")


def test_is_allowed_local_path_tool_results_without_uuid_rejected():
    """Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
    encoded = "test-encoded-dir"

@@ -159,7 +174,7 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():


def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
    """A valid UUID dir but non-'tool-results' second segment is rejected."""
    """A valid UUID dir but non-'tool-results'/'tool-outputs' second segment is rejected."""
    encoded = "test-encoded-dir"
    uuid_str = "12345678-1234-5678-9abc-def012345678"
    path = os.path.join(
@@ -251,20 +251,31 @@ class CoPilotProcessor:
            stream_fn = stream_chat_completion_dummy
            log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
        else:
            use_sdk = (
                config.use_claude_code_subscription
                or await is_feature_enabled(
                    Flag.COPILOT_SDK,
                    entry.user_id or "anonymous",
                    default=config.use_claude_agent_sdk,
            # Per-request mode override from the frontend takes priority.
            # 'fast' → baseline (OpenAI-compatible), 'extended_thinking' → SDK.
            if entry.mode == "fast":
                use_sdk = False
            elif entry.mode == "extended_thinking":
                use_sdk = True
            else:
                # No mode specified — fall back to feature flag / config.
                use_sdk = (
                    config.use_claude_code_subscription
                    or await is_feature_enabled(
                        Flag.COPILOT_SDK,
                        entry.user_id or "anonymous",
                        default=config.use_claude_agent_sdk,
                    )
                )
            )
            stream_fn = (
                sdk_service.stream_chat_completion_sdk
                if use_sdk
                else stream_chat_completion_baseline
            )
            log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
            log.info(
                f"Using {'SDK' if use_sdk else 'baseline'} service "
                f"(mode={entry.mode or 'default'})"
            )

        # Stream chat completion and publish chunks to Redis.
        # stream_and_publish wraps the raw stream with registry
@@ -6,6 +6,7 @@ Defines two exchanges and queues following the graph executor pattern:
"""

import logging
from typing import Literal

from pydantic import BaseModel

@@ -156,6 +157,9 @@ class CoPilotExecutionEntry(BaseModel):
    file_ids: list[str] | None = None
    """Workspace file IDs attached to the user's message"""

    mode: Literal["fast", "extended_thinking"] | None = None
    """Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""


class CancelCoPilotEvent(BaseModel):
    """Event to cancel a CoPilot operation."""

@@ -175,6 +179,7 @@ async def enqueue_copilot_turn(
    is_user_message: bool = True,
    context: dict[str, str] | None = None,
    file_ids: list[str] | None = None,
    mode: Literal["fast", "extended_thinking"] | None = None,
) -> None:
    """Enqueue a CoPilot task for processing by the executor service.

@@ -186,6 +191,7 @@ async def enqueue_copilot_turn(
        is_user_message: Whether the message is from the user (vs system/assistant)
        context: Optional context for the message (e.g., {url: str, content: str})
        file_ids: Optional workspace file IDs attached to the user's message
        mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
    """
    from backend.util.clients import get_async_copilot_queue

@@ -197,6 +203,7 @@ async def enqueue_copilot_turn(
        is_user_message=is_user_message,
        context=context,
        file_ids=file_ids,
        mode=mode,
    )

    queue_client = await get_async_copilot_queue()
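A producer opting a single turn into the SDK path would pass the new keyword through. In sketch form only, since the leading session/user/message parameter names are not shown in this diff and are elided here:

await enqueue_copilot_turn(
    ...,                       # session / user / message arguments (names not shown in this diff)
    is_user_message=True,
    mode="extended_thinking",  # per-request override: route this turn via the SDK service
)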
@@ -59,6 +59,16 @@ _null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
    maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)

# GitHub user identity caches (keyed by user_id only, not provider tuple).
# Declared here so invalidate_user_provider_cache() can reference them.
_GH_IDENTITY_CACHE_TTL = 600.0  # 10 min — profile data rarely changes
_gh_identity_cache: TTLCache[str, dict[str, str]] = TTLCache(
    maxsize=_CACHE_MAX_SIZE, ttl=_GH_IDENTITY_CACHE_TTL
)
_gh_identity_null_cache: TTLCache[str, bool] = TTLCache(
    maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)


def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
    """Remove the cached entry for *user_id*/*provider* from both caches.

@@ -66,11 +76,19 @@ def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
    Call this after storing new credentials so that the next
    ``get_provider_token()`` call performs a fresh DB lookup instead of
    serving a stale TTL-cached result.

    For GitHub specifically, also clears the git-identity caches so that
    ``get_github_user_git_identity()`` re-fetches the user's profile on
    the next call instead of serving stale identity data.
    """
    key = (user_id, provider)
    _token_cache.pop(key, None)
    _null_cache.pop(key, None)

    if provider == "github":
        _gh_identity_cache.pop(user_id, None)
        _gh_identity_null_cache.pop(user_id, None)


# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache

@@ -123,6 +141,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
        [c for c in creds_list if c.type == "oauth2"],
        key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
    )
    refresh_failed = False
    for creds in oauth2_creds:
        if creds.type == "oauth2":
            try:

@@ -141,6 +160,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
                # Do NOT fall back to the stale token — it is likely expired
                # or revoked. Returning None forces the caller to re-auth,
                # preventing the LLM from receiving a non-functional token.
                refresh_failed = True
                continue
            _token_cache[cache_key] = token
            return token

@@ -152,8 +172,12 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
            _token_cache[cache_key] = token
            return token

    # No credentials found — cache to avoid repeated DB hits.
    _null_cache[cache_key] = True
    # Only cache "not connected" when the user truly has no credentials for this
    # provider. If we had OAuth credentials but refresh failed (e.g. transient
    # network error, event-loop mismatch), do NOT cache the negative result —
    # the next call should retry the refresh instead of being blocked for 60 s.
    if not refresh_failed:
        _null_cache[cache_key] = True
    return None
@@ -171,3 +195,76 @@ async def get_integration_env_vars(user_id: str) -> dict[str, str]:
    for var in var_names:
        env[var] = token
    return env


# ---------------------------------------------------------------------------
# GitHub user identity (for git committer env vars)
# ---------------------------------------------------------------------------


async def get_github_user_git_identity(user_id: str) -> dict[str, str] | None:
    """Fetch the GitHub user's name and email for git committer env vars.

    Uses the ``/user`` GitHub API endpoint with the user's stored token.
    Returns a dict with ``GIT_AUTHOR_NAME``, ``GIT_AUTHOR_EMAIL``,
    ``GIT_COMMITTER_NAME``, and ``GIT_COMMITTER_EMAIL`` if the user has a
    connected GitHub account. Returns ``None`` otherwise.

    Results are cached for 10 minutes; "not connected" results are cached for
    60 s (same as null-token cache).
    """
    if user_id in _gh_identity_null_cache:
        return None
    if cached := _gh_identity_cache.get(user_id):
        return cached

    token = await get_provider_token(user_id, "github")
    if not token:
        _gh_identity_null_cache[user_id] = True
        return None

    import aiohttp

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                "https://api.github.com/user",
                headers={
                    "Authorization": f"token {token}",
                    "Accept": "application/vnd.github+json",
                },
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                if resp.status != 200:
                    logger.warning(
                        "[git-identity] GitHub /user returned %s for user %s",
                        resp.status,
                        user_id,
                    )
                    return None
                data = await resp.json()
    except Exception as exc:
        logger.warning(
            "[git-identity] Failed to fetch GitHub profile for user %s: %s",
            user_id,
            exc,
        )
        return None

    name = data.get("name") or data.get("login") or "AutoGPT User"
    # GitHub may return email=null if the user has set their email to private.
    # Fall back to the noreply address GitHub generates for every account.
    email = data.get("email")
    if not email:
        gh_id = data.get("id", "")
        login = data.get("login", "user")
        email = f"{gh_id}+{login}@users.noreply.github.com"

    identity = {
        "GIT_AUTHOR_NAME": name,
        "GIT_AUTHOR_EMAIL": email,
        "GIT_COMMITTER_NAME": name,
        "GIT_COMMITTER_EMAIL": email,
    }
    _gh_identity_cache[user_id] = identity
    return identity
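Downstream, the returned mapping is shaped to drop straight into a git subprocess environment. A hypothetical consumer, purely for illustration (the function name and call site below are not from the codebase):

import os
import subprocess


async def commit_as_user(user_id: str, repo_dir: str, message: str) -> None:
    # Merge the identity vars so the commit is attributed to the connected
    # GitHub account; fall back to the ambient environment otherwise.
    identity = await get_github_user_git_identity(user_id) or {}
    env = {**os.environ, **identity}
    subprocess.run(["git", "commit", "-m", message], cwd=repo_dir, env=env, check=True)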
@@ -9,6 +9,8 @@ from backend.copilot.integration_creds import (
    _NULL_CACHE_TTL,
    _TOKEN_CACHE_TTL,
    PROVIDER_ENV_VARS,
    _gh_identity_cache,
    _gh_identity_null_cache,
    _null_cache,
    _token_cache,
    get_integration_env_vars,

@@ -49,9 +51,13 @@ def clear_caches():
    """Ensure clean caches before and after every test."""
    _token_cache.clear()
    _null_cache.clear()
    _gh_identity_cache.clear()
    _gh_identity_null_cache.clear()
    yield
    _token_cache.clear()
    _null_cache.clear()
    _gh_identity_cache.clear()
    _gh_identity_null_cache.clear()


class TestInvalidateUserProviderCache:

@@ -77,6 +83,34 @@ class TestInvalidateUserProviderCache:
        invalidate_user_provider_cache(_USER, _PROVIDER)
        assert other_key in _token_cache

    def test_clears_gh_identity_cache_for_github_provider(self):
        """When provider is 'github', identity caches must also be cleared."""
        _gh_identity_cache[_USER] = {
            "GIT_AUTHOR_NAME": "Old Name",
            "GIT_AUTHOR_EMAIL": "old@example.com",
            "GIT_COMMITTER_NAME": "Old Name",
            "GIT_COMMITTER_EMAIL": "old@example.com",
        }
        invalidate_user_provider_cache(_USER, "github")
        assert _USER not in _gh_identity_cache

    def test_clears_gh_identity_null_cache_for_github_provider(self):
        """When provider is 'github', the identity null-cache must also be cleared."""
        _gh_identity_null_cache[_USER] = True
        invalidate_user_provider_cache(_USER, "github")
        assert _USER not in _gh_identity_null_cache

    def test_does_not_clear_gh_identity_cache_for_other_providers(self):
        """When provider is NOT 'github', identity caches must be left alone."""
        _gh_identity_cache[_USER] = {
            "GIT_AUTHOR_NAME": "Some Name",
            "GIT_AUTHOR_EMAIL": "some@example.com",
            "GIT_COMMITTER_NAME": "Some Name",
            "GIT_COMMITTER_EMAIL": "some@example.com",
        }
        invalidate_user_provider_cache(_USER, "some-other-provider")
        assert _USER in _gh_identity_cache
class TestGetProviderToken:
    @pytest.mark.asyncio(loop_scope="session")

@@ -129,8 +163,15 @@ class TestGetProviderToken:
        assert result == "oauth-tok"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_oauth2_refresh_failure_returns_none(self):
        """On refresh failure, return None instead of caching a stale token."""
    async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
        """On refresh failure, return None but do NOT cache in null_cache.

        The user has credentials — they just couldn't be refreshed right now
        (e.g. transient network error or event-loop mismatch in the copilot
        executor). Caching a negative result would block all credential
        lookups for 60 s even though the creds exist and may refresh fine
        on the next attempt.
        """
        oauth_creds = _make_oauth2_creds("stale-oauth-tok")
        mock_manager = MagicMock()
        mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])

@@ -141,6 +182,8 @@ class TestGetProviderToken:

        # Stale tokens must NOT be returned — forces re-auth.
        assert result is None
        # Must NOT cache negative result when refresh failed — next call retries.
        assert (_USER, _PROVIDER) not in _null_cache

    @pytest.mark.asyncio(loop_scope="session")
    async def test_no_credentials_caches_null_entry(self):

@@ -176,6 +219,96 @@ class TestGetProviderToken:
        assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL


class TestThreadSafetyLocks:
    """Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
    'Future attached to a different loop' when copilot workers accessed
    credentials from different event loops."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_store_locks_returns_per_thread_instance(self):
        """IntegrationCredentialsStore.locks() must return different instances
        for different threads (via @thread_cached)."""
        import asyncio
        import concurrent.futures

        from backend.integrations.credentials_store import IntegrationCredentialsStore

        store = IntegrationCredentialsStore()

        async def get_locks_id():
            mock_redis = AsyncMock()
            with patch(
                "backend.integrations.credentials_store.get_redis_async",
                return_value=mock_redis,
            ):
                locks = await store.locks()
                return id(locks)

        # Get locks from main thread
        main_id = await get_locks_id()

        # Get locks from a worker thread
        def run_in_thread():
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(get_locks_id())
            finally:
                loop.close()

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            worker_id = await asyncio.get_event_loop().run_in_executor(
                pool, run_in_thread
            )

        assert main_id != worker_id, (
            "Store.locks() returned the same instance across threads. "
            "This would cause 'Future attached to a different loop' errors."
        )

    @pytest.mark.asyncio(loop_scope="session")
    async def test_manager_delegates_to_store_locks(self):
        """IntegrationCredentialsManager.locks() should delegate to store."""
        from backend.integrations.creds_manager import IntegrationCredentialsManager

        manager = IntegrationCredentialsManager()
        mock_redis = AsyncMock()

        with patch(
            "backend.integrations.credentials_store.get_redis_async",
            return_value=mock_redis,
        ):
            locks = await manager.locks()

        # Should have gotten it from the store
        assert locks is not None


class TestRefreshUnlockedPath:
    """Bug reproduction: copilot worker threads need lock-free refresh because
    Redis-backed asyncio.Lock created on one event loop can't be used on another."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_refresh_if_needed_lock_false_skips_redis(self):
        """refresh_if_needed(lock=False) must not touch Redis locks at all."""
        from backend.integrations.creds_manager import IntegrationCredentialsManager

        manager = IntegrationCredentialsManager()
        creds = _make_oauth2_creds()

        mock_handler = MagicMock()
        mock_handler.needs_refresh = MagicMock(return_value=False)

        with patch(
            "backend.integrations.creds_manager._get_provider_oauth_handler",
            new_callable=AsyncMock,
            return_value=mock_handler,
        ):
            result = await manager.refresh_if_needed(_USER, creds, lock=False)

        # Should return credentials without touching locks
        assert result.id == creds.id


class TestGetIntegrationEnvVars:
    @pytest.mark.asyncio(loop_scope="session")
    async def test_injects_all_env_vars_for_provider(self):
@@ -66,6 +66,7 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
    # Platform tools (must match keys in TOOL_REGISTRY)
    "add_understanding",
    "ask_question",
    "bash_exec",
    "browser_act",
    "browser_navigate",

@@ -102,6 +103,7 @@ ToolName = Literal[
    "web_fetch",
    "write_workspace_file",
    # SDK built-ins
    "Agent",
    "Edit",
    "Glob",
    "Grep",

@@ -544,6 +544,7 @@ class TestApplyToolPermissions:
class TestSdkBuiltinToolNames:
    def test_expected_builtins_present(self):
        expected = {
            "Agent",
            "Read",
            "Write",
            "Edit",
@@ -18,6 +18,18 @@ After `write_workspace_file`, embed the `download_url` in Markdown:
- Image: ``
- Video: ``

### Handling binary/image data in tool outputs — CRITICAL
When a tool output contains base64-encoded binary data (images, PDFs, etc.):
1. **NEVER** try to inline or render the base64 content in your response.
2. **Save** the data to workspace using `write_workspace_file` (pass the base64 data URI as content).
3. **Show** the result via the workspace download URL in Markdown: ``.

### Passing large data between tools — CRITICAL
When tool outputs produce large text that you need to feed into another tool:
- **NEVER** copy-paste the full text into the next tool call argument.
- **Save** the output to a file (workspace or local), then use `@@agptfile:` references.
- This avoids token limits and ensures data integrity.

### File references — @@agptfile:
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
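For illustration, a reference of this shape, `@@agptfile:workspace://file_abc123[0-50000]`, passes the 0-50000 range of that workspace file by reference (the file id here is a placeholder, not a real value).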
@@ -114,6 +126,21 @@ After building the file, reference it with `@@agptfile:` in other tools:
- When spawning sub-agents for research, ensure each has a distinct
  non-overlapping scope to avoid redundant searches.


### Tool Discovery Priority

When the user asks to interact with a service or API, follow this order:

1. **find_block first** — Search platform blocks with `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Docs, Calendar, Gmail, Slack, GitHub, etc.) that work without extra setup.

2. **run_mcp_tool** — If no matching block exists, check if a hosted MCP server is available for the service. Only use known MCP server URLs from the registry.

3. **SendAuthenticatedWebRequestBlock** — If no block or MCP server exists, use `SendAuthenticatedWebRequestBlock` with existing host-scoped credentials. Check available credentials via `connect_integration`.

4. **Manual API call** — As a last resort, guide the user to set up credentials and use `SendAuthenticatedWebRequestBlock` with direct API calls.

**Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.

### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
  All tasks must run in the foreground.

@@ -138,6 +165,11 @@ parent autopilot handles orchestration.
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### SDK tool-result files in E2B
When you `Read` an SDK tool-result file, it is automatically copied into the
sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.

### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
  pre-authenticated — use them directly without any manual login step.

@@ -203,19 +235,22 @@ def _build_storage_supplement(
- Files here **survive across sessions indefinitely**

### Moving files between storages
- **{file_move_name_1_to_2}**: Copy to persistent workspace
- **{file_move_name_2_to_1}**: Download for processing
- **{file_move_name_1_to_2}**: `write_workspace_file(filename="output.json", source_path="/path/to/local/file")`
- **{file_move_name_2_to_1}**: `read_workspace_file(path="tool-outputs/data.json", save_to_path="{working_dir}/data.json")`

### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.

### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `Read` (NOT `bash_exec`, NOT `read_workspace_file`).
These files are on the host filesystem — `bash_exec` runs in the sandbox and
CANNOT access them. `read_workspace_file` reads from cloud workspace storage,
where SDK tool-results are NOT stored.
a local file under `~/.claude/projects/.../tool-results/` (or `tool-outputs/`).
To read these files, use `Read` — it reads from the host filesystem.

### Large tool outputs saved to workspace
When a tool output contains `<tool-output-truncated workspace_path="...">`, the
full output is in workspace storage (NOT on the local filesystem). To access it:
- Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections.
- To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy.
{_SHARED_TOOL_NOTES}{extra_notes}"""
@@ -6,16 +6,23 @@ from pathlib import Path
class TestAgentGenerationGuideContainsClarifySection:
    """The agent generation guide must include the clarification section."""

    def test_guide_includes_clarify_before_building(self):
    def test_guide_includes_clarify_section(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        assert "Clarifying Before Building" in content
        assert "Before or During Building" in content

    def test_guide_mentions_find_block_for_clarification(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        # find_block must appear in the clarification section (before the workflow)
        clarify_section = content.split("Clarifying Before Building")[1].split(
        clarify_section = content.split("Before or During Building")[1].split(
            "### Workflow"
        )[0]
        assert "find_block" in clarify_section

    def test_guide_mentions_ask_question_tool(self):
        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
        content = guide_path.read_text(encoding="utf-8")
        clarify_section = content.split("Before or During Building")[1].split(
            "### Workflow"
        )[0]
        assert "ask_question" in clarify_section
@@ -9,11 +9,14 @@ UTC). Fails open when Redis is unavailable to avoid blocking users.
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from enum import Enum

from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError

from backend.data.redis_client import get_redis_async
from backend.util.cache import cached

logger = logging.getLogger(__name__)

@@ -21,6 +24,40 @@ logger = logging.getLogger(__name__)
_USAGE_KEY_PREFIX = "copilot:usage"


# ---------------------------------------------------------------------------
# Subscription tier definitions
# ---------------------------------------------------------------------------


class SubscriptionTier(str, Enum):
    """Subscription tiers with increasing token allowances.

    Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
    Once ``prisma generate`` is run, this can be replaced with::

        from prisma.enums import SubscriptionTier
    """

    FREE = "FREE"
    PRO = "PRO"
    BUSINESS = "BUSINESS"
    ENTERPRISE = "ENTERPRISE"


# Multiplier applied to the base limits (from LD / config) for each tier.
# Intentionally int (not float): keeps limits as whole token counts and avoids
# floating-point rounding. If fractional multipliers are ever needed, change
# the type and round the result in get_global_rate_limits().
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
    SubscriptionTier.FREE: 1,
    SubscriptionTier.PRO: 5,
    SubscriptionTier.BUSINESS: 20,
    SubscriptionTier.ENTERPRISE: 60,
}

DEFAULT_TIER = SubscriptionTier.FREE
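The multiplier arithmetic is straightforward; using the base limits that appear in the tests further below:

base_daily = 2_500_000  # example base limit (matches the test fixtures below)
business_daily = base_daily * TIER_MULTIPLIERS[SubscriptionTier.BUSINESS]
assert business_daily == 50_000_000  # BUSINESS gets 20x the base allowance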
class UsageWindow(BaseModel):
    """Usage within a single time window."""

@@ -36,6 +73,7 @@ class CoPilotUsageStatus(BaseModel):

    daily: UsageWindow
    weekly: UsageWindow
    tier: SubscriptionTier = DEFAULT_TIER
    reset_cost: int = Field(
        default=0,
        description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",

@@ -66,6 +104,7 @@ async def get_usage_status(
    daily_token_limit: int,
    weekly_token_limit: int,
    rate_limit_reset_cost: int = 0,
    tier: SubscriptionTier = DEFAULT_TIER,
) -> CoPilotUsageStatus:
    """Get current usage status for a user.

@@ -74,6 +113,7 @@ async def get_usage_status(
        daily_token_limit: Max tokens per day (0 = unlimited).
        weekly_token_limit: Max tokens per week (0 = unlimited).
        rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
        tier: The user's rate-limit tier (included in the response).

    Returns:
        CoPilotUsageStatus with current usage and limits.

@@ -103,6 +143,7 @@ async def get_usage_status(
            limit=weekly_token_limit,
            resets_at=_weekly_reset_time(now=now),
        ),
        tier=tier,
        reset_cost=rate_limit_reset_cost,
    )

@@ -161,8 +202,9 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool:
        daily_token_limit: The configured daily token limit. When positive,
            the weekly counter is reduced by this amount.

    Fails open: returns False if Redis is unavailable (consistent with
    the fail-open design of this module).
    Returns False if Redis is unavailable so the caller can handle
    compensation (fail-closed for billed operations, unlike the read-only
    rate-limit checks which fail-open).
    """
    now = datetime.now(UTC)
    try:

@@ -342,20 +384,100 @@ async def record_token_usage(
    )
class _UserNotFoundError(Exception):
    """Raised when a user record is missing or has no subscription tier.

    Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
    by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
    decorator from storing the fallback value. This avoids a race condition
    where a non-existent user's DEFAULT_TIER is cached, then the user is
    created with a higher tier but receives the stale cached FREE tier for
    up to 5 minutes.
    """


@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
    """Fetch the user's rate-limit tier from the database (cached via Redis).

    Uses ``shared_cache=True`` so that tier changes propagate across all pods
    immediately when the cache entry is invalidated (via ``cache_delete``).

    Only successful DB lookups of existing users with a valid tier are cached.
    Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
    the ``@cached`` decorator does **not** store a fallback value. This
    prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
    cached and then persists after the user is created with a higher tier.
    """
    user = await PrismaUser.prisma().find_unique(where={"id": user_id})
    if user and user.subscriptionTier:  # type: ignore[reportAttributeAccessIssue]
        return SubscriptionTier(user.subscriptionTier)  # type: ignore[reportAttributeAccessIssue]
    raise _UserNotFoundError(user_id)
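The raise-instead-of-return trick relies on the (assumed) convention that ``@cached`` stores only values returned normally, never the result of a raised exception. A toy illustration of that behaviour, not the real decorator:

def toy_cached(fn):
    """Toy stand-in for @cached: only successful results are memoised."""
    store = {}

    async def wrapper(key):
        if key in store:
            return store[key]
        value = await fn(key)  # an exception propagates here; nothing is cached
        store[key] = value
        return value

    return wrapper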
async def get_user_tier(user_id: str) -> SubscriptionTier:
    """Look up the user's rate-limit tier from the database.

    Successful results are cached for 5 minutes (via ``_fetch_user_tier``)
    to avoid a DB round-trip on every rate-limit check.

    Falls back to ``DEFAULT_TIER`` **without caching** when the DB is
    unreachable or returns an unrecognised value, so the next call retries
    the query instead of serving a stale fallback for up to 5 minutes.
    """
    try:
        return await _fetch_user_tier(user_id)
    except Exception as exc:
        logger.warning(
            "Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
            user_id[:8],
            DEFAULT_TIER.value,
            exc,
        )
        return DEFAULT_TIER


# Expose cache management on the public function so callers (including tests)
# never need to reach into the private ``_fetch_user_tier``.
get_user_tier.cache_clear = _fetch_user_tier.cache_clear  # type: ignore[attr-defined]
get_user_tier.cache_delete = _fetch_user_tier.cache_delete  # type: ignore[attr-defined]


async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
    """Persist the user's rate-limit tier to the database.

    Also invalidates the ``get_user_tier`` cache for this user so that
    subsequent rate-limit checks immediately see the new tier.

    Raises:
        prisma.errors.RecordNotFoundError: If the user does not exist.
    """
    await PrismaUser.prisma().update(
        where={"id": user_id},
        data={"subscriptionTier": tier.value},
    )
    # Invalidate cached tier so rate-limit checks pick up the change immediately.
    get_user_tier.cache_delete(user_id)  # type: ignore[attr-defined]


async def get_global_rate_limits(
    user_id: str,
    config_daily: int,
    config_weekly: int,
) -> tuple[int, int]:
) -> tuple[int, int, SubscriptionTier]:
    """Resolve global rate limits from LaunchDarkly, falling back to config.

    The base limits (from LD or config) are multiplied by the user's
    tier multiplier so that higher tiers receive proportionally larger
    allowances.

    Args:
        user_id: User ID for LD flag evaluation context.
        config_daily: Fallback daily limit from ChatConfig.
        config_weekly: Fallback weekly limit from ChatConfig.

    Returns:
        (daily_token_limit, weekly_token_limit) tuple.
        (daily_token_limit, weekly_token_limit, tier) 3-tuple.
    """
    # Lazy import to avoid circular dependency:
    # rate_limit -> feature_flag -> settings -> ... -> rate_limit

@@ -377,7 +499,15 @@ async def get_global_rate_limits(
    except (TypeError, ValueError):
        logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
        weekly = config_weekly
    return daily, weekly

    # Apply tier multiplier
    tier = await get_user_tier(user_id)
    multiplier = TIER_MULTIPLIERS.get(tier, 1)
    if multiplier != 1:
        daily = daily * multiplier
        weekly = weekly * multiplier

    return daily, weekly, tier


async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:
@@ -7,12 +7,19 @@ import pytest
from redis.exceptions import RedisError

from .rate_limit import (
    DEFAULT_TIER,
    TIER_MULTIPLIERS,
    CoPilotUsageStatus,
    RateLimitExceeded,
    SubscriptionTier,
    UsageWindow,
    check_rate_limit,
    get_global_rate_limits,
    get_usage_status,
    get_user_tier,
    record_token_usage,
    reset_daily_usage,
    set_user_tier,
)

_USER = "test-user-rl"
@@ -335,6 +342,524 @@ class TestRecordTokenUsage:
        await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)


# ---------------------------------------------------------------------------
# SubscriptionTier and tier multipliers
# ---------------------------------------------------------------------------


class TestSubscriptionTier:
    def test_tier_values(self):
        assert SubscriptionTier.FREE.value == "FREE"
        assert SubscriptionTier.PRO.value == "PRO"
        assert SubscriptionTier.BUSINESS.value == "BUSINESS"
        assert SubscriptionTier.ENTERPRISE.value == "ENTERPRISE"

    def test_tier_multipliers(self):
        assert TIER_MULTIPLIERS[SubscriptionTier.FREE] == 1
        assert TIER_MULTIPLIERS[SubscriptionTier.PRO] == 5
        assert TIER_MULTIPLIERS[SubscriptionTier.BUSINESS] == 20
        assert TIER_MULTIPLIERS[SubscriptionTier.ENTERPRISE] == 60

    def test_default_tier_is_free(self):
        assert DEFAULT_TIER == SubscriptionTier.FREE

    def test_usage_status_includes_tier(self):
        now = datetime.now(UTC)
        status = CoPilotUsageStatus(
            daily=UsageWindow(used=0, limit=100, resets_at=now + timedelta(hours=1)),
            weekly=UsageWindow(used=0, limit=500, resets_at=now + timedelta(days=1)),
        )
        assert status.tier == SubscriptionTier.FREE

    def test_usage_status_with_custom_tier(self):
        now = datetime.now(UTC)
        status = CoPilotUsageStatus(
            daily=UsageWindow(used=0, limit=100, resets_at=now + timedelta(hours=1)),
            weekly=UsageWindow(used=0, limit=500, resets_at=now + timedelta(days=1)),
            tier=SubscriptionTier.PRO,
        )
        assert status.tier == SubscriptionTier.PRO


# ---------------------------------------------------------------------------
# get_user_tier
# ---------------------------------------------------------------------------


class TestGetUserTier:
    @pytest.fixture(autouse=True)
    def _clear_tier_cache(self):
        """Clear the get_user_tier cache before each test."""
        get_user_tier.cache_clear()  # type: ignore[attr-defined]

    @pytest.mark.asyncio
    async def test_returns_tier_from_db(self):
        """Should return the tier stored in the user record."""
        mock_user = MagicMock()
        mock_user.subscriptionTier = "PRO"

        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == SubscriptionTier.PRO

    @pytest.mark.asyncio
    async def test_returns_default_when_user_not_found(self):
        """Should return DEFAULT_TIER when user is not in the DB."""
        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == DEFAULT_TIER

    @pytest.mark.asyncio
    async def test_returns_default_when_tier_is_none(self):
        """Should return DEFAULT_TIER when subscriptionTier is None."""
        mock_user = MagicMock()
        mock_user.subscriptionTier = None

        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == DEFAULT_TIER

    @pytest.mark.asyncio
    async def test_returns_default_on_db_error(self):
        """Should fall back to DEFAULT_TIER when DB raises."""
        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == DEFAULT_TIER

    @pytest.mark.asyncio
    async def test_db_error_is_not_cached(self):
        """Transient DB errors should NOT cache the default tier.

        Regression test: a transient DB failure previously cached DEFAULT_TIER
        for 5 minutes, incorrectly downgrading higher-tier users until expiry.
        """
        failing_prisma = AsyncMock()
        failing_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=failing_prisma,
        ):
            tier1 = await get_user_tier(_USER)
            assert tier1 == DEFAULT_TIER

        # Now DB recovers and returns PRO
        mock_user = MagicMock()
        mock_user.subscriptionTier = "PRO"
        ok_prisma = AsyncMock()
        ok_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=ok_prisma,
        ):
            tier2 = await get_user_tier(_USER)

        # Should get PRO now — the error result was not cached
        assert tier2 == SubscriptionTier.PRO

    @pytest.mark.asyncio
    async def test_returns_default_on_invalid_tier_value(self):
        """Should fall back to DEFAULT_TIER when stored value is invalid."""
        mock_user = MagicMock()
        mock_user.subscriptionTier = "invalid-tier"

        mock_prisma = AsyncMock()
        mock_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            tier = await get_user_tier(_USER)

        assert tier == DEFAULT_TIER

    @pytest.mark.asyncio
    async def test_user_not_found_is_not_cached(self):
        """Non-existent user should NOT cache DEFAULT_TIER.

        Regression test: when ``get_user_tier`` is called before a user record
        exists, the DEFAULT_TIER fallback must not be cached. Otherwise, a
        newly created user with a higher tier (e.g. PRO) would receive the
        stale cached FREE tier for up to 5 minutes.
        """
        # First call: user does not exist yet
        missing_prisma = AsyncMock()
        missing_prisma.find_unique = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=missing_prisma,
        ):
            tier1 = await get_user_tier(_USER)
            assert tier1 == DEFAULT_TIER

        # Second call: user now exists with PRO tier
        mock_user = MagicMock()
        mock_user.subscriptionTier = "PRO"
        ok_prisma = AsyncMock()
        ok_prisma.find_unique = AsyncMock(return_value=mock_user)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=ok_prisma,
        ):
            tier2 = await get_user_tier(_USER)

        # Should get PRO — the not-found result was not cached
        assert tier2 == SubscriptionTier.PRO
# ---------------------------------------------------------------------------
# set_user_tier
# ---------------------------------------------------------------------------


class TestSetUserTier:
    @pytest.fixture(autouse=True)
    def _clear_tier_cache(self):
        """Clear the get_user_tier cache before each test."""
        get_user_tier.cache_clear()  # type: ignore[attr-defined]

    @pytest.mark.asyncio
    async def test_updates_db_and_invalidates_cache(self):
        """set_user_tier should persist to DB and invalidate the tier cache."""
        mock_prisma = AsyncMock()
        mock_prisma.update = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            await set_user_tier(_USER, SubscriptionTier.PRO)

        mock_prisma.update.assert_awaited_once_with(
            where={"id": _USER},
            data={"subscriptionTier": "PRO"},
        )

    @pytest.mark.asyncio
    async def test_record_not_found_propagates(self):
        """RecordNotFoundError from Prisma should propagate to callers."""
        import prisma.errors

        mock_prisma = AsyncMock()
        mock_prisma.update = AsyncMock(
            side_effect=prisma.errors.RecordNotFoundError(
                {"error": "Record not found"}
            ),
        )

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma,
        ):
            with pytest.raises(prisma.errors.RecordNotFoundError):
                await set_user_tier(_USER, SubscriptionTier.ENTERPRISE)

    @pytest.mark.asyncio
    async def test_cache_invalidated_after_set(self):
        """After set_user_tier, get_user_tier should query DB again (not cache)."""
        # First, populate the cache with BUSINESS
        mock_user_biz = MagicMock()
        mock_user_biz.subscriptionTier = "BUSINESS"
        mock_prisma_get = AsyncMock()
        mock_prisma_get.find_unique = AsyncMock(return_value=mock_user_biz)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma_get,
        ):
            tier_before = await get_user_tier(_USER)
            assert tier_before == SubscriptionTier.BUSINESS

        # Now set tier to ENTERPRISE (this should invalidate the cache)
        mock_prisma_set = AsyncMock()
        mock_prisma_set.update = AsyncMock(return_value=None)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma_set,
        ):
            await set_user_tier(_USER, SubscriptionTier.ENTERPRISE)

        # Now get_user_tier should hit DB again (cache was invalidated)
        mock_user_ent = MagicMock()
        mock_user_ent.subscriptionTier = "ENTERPRISE"
        mock_prisma_get2 = AsyncMock()
        mock_prisma_get2.find_unique = AsyncMock(return_value=mock_user_ent)

        with patch(
            "backend.copilot.rate_limit.PrismaUser.prisma",
            return_value=mock_prisma_get2,
        ):
            tier_after = await get_user_tier(_USER)

        assert tier_after == SubscriptionTier.ENTERPRISE
# ---------------------------------------------------------------------------
# get_global_rate_limits with tiers
# ---------------------------------------------------------------------------


class TestGetGlobalRateLimitsWithTiers:
    @staticmethod
    def _ld_side_effect(daily: int, weekly: int):
        """Return an async side_effect that dispatches by flag_key."""

        async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
            if "daily" in flag_key.lower():
                return daily
            if "weekly" in flag_key.lower():
                return weekly
            return default

        return _side_effect

    @pytest.mark.asyncio
    async def test_free_tier_no_multiplier(self):
        """Free tier should not change limits."""
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.FREE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(2_500_000, 12_500_000),
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, 2_500_000, 12_500_000
            )

        assert daily == 2_500_000
        assert weekly == 12_500_000
        assert tier == SubscriptionTier.FREE

    @pytest.mark.asyncio
    async def test_pro_tier_5x_multiplier(self):
        """Pro tier should multiply limits by 5."""
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.PRO,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(2_500_000, 12_500_000),
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, 2_500_000, 12_500_000
            )

        assert daily == 12_500_000
        assert weekly == 62_500_000
        assert tier == SubscriptionTier.PRO

    @pytest.mark.asyncio
    async def test_business_tier_20x_multiplier(self):
        """Business tier should multiply limits by 20."""
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.BUSINESS,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(2_500_000, 12_500_000),
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, 2_500_000, 12_500_000
            )

        assert daily == 50_000_000
        assert weekly == 250_000_000
        assert tier == SubscriptionTier.BUSINESS

    @pytest.mark.asyncio
    async def test_enterprise_tier_60x_multiplier(self):
        """Enterprise tier should multiply limits by 60."""
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.ENTERPRISE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(2_500_000, 12_500_000),
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, 2_500_000, 12_500_000
            )

        assert daily == 150_000_000
        assert weekly == 750_000_000
        assert tier == SubscriptionTier.ENTERPRISE


# ---------------------------------------------------------------------------
# End-to-end: tier limits are respected by check_rate_limit
# ---------------------------------------------------------------------------


class TestTierLimitsRespected:
    """Verify that tier-adjusted limits from get_global_rate_limits flow
    correctly into check_rate_limit, so higher tiers allow more usage and
    lower tiers are blocked when they would exceed their allocation."""

    _BASE_DAILY = 2_500_000
    _BASE_WEEKLY = 12_500_000

    @staticmethod
    def _ld_side_effect(daily: int, weekly: int):

        async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
            if "daily" in flag_key.lower():
                return daily
            if "weekly" in flag_key.lower():
                return weekly
            return default

        return _side_effect

    @pytest.mark.asyncio
    async def test_pro_user_allowed_above_free_limit(self):
        """A PRO user with usage above the FREE limit should be allowed."""
        # Usage: 3M tokens (above FREE limit of 2.5M, below PRO limit of 12.5M)
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=["3000000", "3000000"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.PRO,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            # PRO: 5x multiplier
            assert daily == 12_500_000
            assert tier == SubscriptionTier.PRO
            # Should NOT raise — 3M < 12.5M
            await check_rate_limit(
                _USER, daily_token_limit=daily, weekly_token_limit=weekly
            )

    @pytest.mark.asyncio
    async def test_free_user_blocked_at_free_limit(self):
        """A FREE user at or above the base limit should be blocked."""
        # Usage: 2.5M tokens (at FREE limit of 2.5M)
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=["2500000", "2500000"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.FREE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            # FREE: 1x multiplier
            assert daily == 2_500_000
            assert tier == SubscriptionTier.FREE
            # Should raise — 2.5M >= 2.5M
            with pytest.raises(RateLimitExceeded):
                await check_rate_limit(
                    _USER, daily_token_limit=daily, weekly_token_limit=weekly
                )

    @pytest.mark.asyncio
    async def test_enterprise_user_has_highest_headroom(self):
        """An ENTERPRISE user should have 60x the base limit."""
        # Usage: 100M tokens (huge, but below ENTERPRISE daily of 150M)
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=["100000000", "100000000"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.ENTERPRISE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert daily == 150_000_000
            assert tier == SubscriptionTier.ENTERPRISE
            # Should NOT raise — 100M < 150M
            await check_rate_limit(
                _USER, daily_token_limit=daily, weekly_token_limit=weekly
            )


# ---------------------------------------------------------------------------
# reset_daily_usage
# ---------------------------------------------------------------------------
@@ -421,3 +946,267 @@ class TestResetDailyUsage:
|
||||
result = await reset_daily_usage(_USER, daily_token_limit=10000)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tier-limit enforcement (integration-style)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTierLimitsEnforced:
|
||||
"""Verify that tier-multiplied limits are actually respected by
|
||||
``check_rate_limit`` — i.e. that usage within the tier allowance passes
|
||||
and usage at/above the tier allowance is rejected."""

    _BASE_DAILY = 1_000_000
    _BASE_WEEKLY = 5_000_000

    @staticmethod
    def _ld_side_effect(daily: int, weekly: int):
        """Mock LD flag lookup returning the given raw limits."""

        async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
            if "daily" in flag_key.lower():
                return daily
            if "weekly" in flag_key.lower():
                return weekly
            return default

        return _side_effect

    @pytest.mark.asyncio
    async def test_pro_within_limit_allowed(self):
        """Usage under PRO daily limit should not raise."""
        pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
        mock_redis = AsyncMock()
        # Simulate usage just under the PRO daily limit
        mock_redis.get = AsyncMock(side_effect=[str(pro_daily - 1), "0"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.PRO,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert tier == SubscriptionTier.PRO
            assert daily == pro_daily
            # Should not raise — usage is under the limit
            await check_rate_limit(_USER, daily, weekly)

    @pytest.mark.asyncio
    async def test_pro_at_limit_rejected(self):
        """Usage at exactly the PRO daily limit should raise."""
        pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(pro_daily), "0"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.PRO,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            with pytest.raises(RateLimitExceeded) as exc_info:
                await check_rate_limit(_USER, daily, weekly)
            assert exc_info.value.window == "daily"

    @pytest.mark.asyncio
    async def test_business_higher_limit_allows_pro_overflow(self):
        """Usage exceeding PRO but under BUSINESS should pass for BUSINESS."""
        pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
        biz_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.BUSINESS]
        # Usage between PRO and BUSINESS limits
        usage = pro_daily + 1_000_000
        assert usage < biz_daily, "test sanity: usage must be under BUSINESS limit"

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.BUSINESS,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert tier == SubscriptionTier.BUSINESS
            assert daily == biz_daily
            # Should not raise — BUSINESS tier can handle this
            await check_rate_limit(_USER, daily, weekly)

    @pytest.mark.asyncio
    async def test_weekly_limit_enforced_for_tier(self):
        """Weekly limit should also be tier-multiplied and enforced."""
        pro_weekly = self._BASE_WEEKLY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
        mock_redis = AsyncMock()
        # Daily usage fine, weekly at limit
        mock_redis.get = AsyncMock(side_effect=["0", str(pro_weekly)])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.PRO,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            with pytest.raises(RateLimitExceeded) as exc_info:
                await check_rate_limit(_USER, daily, weekly)
            assert exc_info.value.window == "weekly"

    @pytest.mark.asyncio
    async def test_free_tier_base_limit_enforced(self):
        """Free tier (1x multiplier) should enforce the base limit exactly."""
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(self._BASE_DAILY), "0"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.FREE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert daily == self._BASE_DAILY  # 1x multiplier
            with pytest.raises(RateLimitExceeded):
                await check_rate_limit(_USER, daily, weekly)

    @pytest.mark.asyncio
    async def test_free_tier_cannot_bypass_pro_limit(self):
        """A FREE-tier user whose usage is within PRO limits but over FREE
        limits must still be rejected.

        Negative test: ensures the tier multiplier is applied *before* the
        rate-limit check, so a lower-tier user cannot 'bypass' limits that
        would be acceptable for a higher tier.
        """
        free_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.FREE]
        pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
        # Usage above FREE limit but below PRO limit
        usage = free_daily + 500_000
        assert usage < pro_daily, "test sanity: usage must be under PRO limit"

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])

        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.FREE,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert tier == SubscriptionTier.FREE
            assert daily == free_daily  # 1x, not 5x
            with pytest.raises(RateLimitExceeded) as exc_info:
                await check_rate_limit(_USER, daily, weekly)
            assert exc_info.value.window == "daily"

    @pytest.mark.asyncio
    async def test_tier_change_updates_effective_limits(self):
        """After upgrading from FREE to BUSINESS, the effective limits must
        increase accordingly.

        Verifies that the tier multiplier is correctly applied after a tier
        change, and that usage that was over the FREE limit is within the new
        BUSINESS limit.
        """
        free_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.FREE]
        biz_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.BUSINESS]
        # Usage above FREE limit but below BUSINESS limit
        usage = free_daily + 500_000
        assert usage < biz_daily, "test sanity: usage must be under BUSINESS limit"

        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])

        # Simulate the user having been upgraded to BUSINESS
        with (
            patch(
                "backend.copilot.rate_limit.get_user_tier",
                new_callable=AsyncMock,
                return_value=SubscriptionTier.BUSINESS,
            ),
            patch(
                "backend.util.feature_flag.get_feature_flag_value",
                side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
            ),
            patch(
                "backend.copilot.rate_limit.get_redis_async",
                return_value=mock_redis,
            ),
        ):
            daily, weekly, tier = await get_global_rate_limits(
                _USER, self._BASE_DAILY, self._BASE_WEEKLY
            )
            assert tier == SubscriptionTier.BUSINESS
            assert daily == biz_daily  # 20x
            # Should NOT raise — usage is within the BUSINESS tier allowance
            await check_rate_limit(_USER, daily, weekly)

@@ -9,7 +9,7 @@ import pytest
from fastapi import HTTPException

from backend.api.features.chat.routes import reset_copilot_usage
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from backend.util.exceptions import InsufficientBalanceError


@@ -53,6 +53,18 @@ def _mock_settings(enable_credit: bool = True):
    return mock


def _mock_rate_limits(
    daily: int = 2_500_000,
    weekly: int = 12_500_000,
    tier: SubscriptionTier = SubscriptionTier.PRO,
):
    """Mock get_global_rate_limits to return fixed limits (no tier multiplier)."""
    return patch(
        f"{_MODULE}.get_global_rate_limits",
        AsyncMock(return_value=(daily, weekly, tier)),
    )


@pytest.mark.asyncio
class TestResetCopilotUsage:
    async def test_feature_disabled_returns_400(self):
@@ -70,6 +82,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(daily=0),
        ):
            with pytest.raises(HTTPException) as exc_info:
                await reset_copilot_usage(user_id="user-1")
@@ -83,6 +96,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -112,6 +126,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -141,6 +156,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -171,6 +187,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
        ):
            with pytest.raises(HTTPException) as exc_info:
@@ -208,6 +225,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -228,6 +246,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", _make_config()),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
        ):
            with pytest.raises(HTTPException) as exc_info:
@@ -245,6 +264,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -275,6 +295,7 @@ class TestResetCopilotUsage:
        with (
            patch(f"{_MODULE}.config", cfg),
            patch(f"{_MODULE}.settings", _mock_settings()),
            _mock_rate_limits(),
            patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
            patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
            patch(f"{_MODULE}.release_reset_lock", AsyncMock()),

@@ -3,41 +3,62 @@
You can create, edit, and customize agents directly. You ARE the brain —
generate the agent JSON yourself using block schemas, then validate and save.

### Clarifying Before Building
### Clarifying — Before or During Building

Before starting the workflow below, check whether the user's goal is
**ambiguous** — missing the output format, delivery channel, data source,
or trigger. If so:
1. Call `find_block` with a query targeting the ambiguous dimension to
   discover what the platform actually supports.
2. Ask the user **one concrete question** grounded in the discovered
Use `ask_question` whenever the user's intent is ambiguous — whether
that's before starting or midway through the workflow. Common moments:

- **Before building**: output format, delivery channel, data source, or
  trigger is unspecified.
- **During block discovery**: multiple blocks could fit and the user
  should choose.
- **During JSON generation**: a wiring decision depends on user
  preference.

Steps:
1. Call `find_block` (or another discovery tool) to learn what the
   platform actually supports for the ambiguous dimension.
2. Call `ask_question` with a concrete question listing the discovered
   options (e.g. "The platform supports Gmail, Slack, and Google Docs —
   which should the agent use for delivery?").
3. **Wait for the user's answer** before proceeding.
3. **Wait for the user's answer** before continuing.

**Skip this** when the goal already specifies all dimensions (e.g.
"scrape prices from Amazon and email me daily").

### Workflow for Creating/Editing Agents

1. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
1. **If editing**: First narrow to the specific agent by UUID, then fetch its
   graph: `find_library_agent(query="<agent_id>", include_graph=true)`. This
   returns the full graph structure (nodes + links). **Never edit blindly** —
   always inspect the current graph first so you know exactly what to change.
   Avoid using `include_graph=true` with broad keyword searches, as fetching
   multiple graphs at once is expensive and consumes LLM context budget.
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
   search for relevant blocks. This returns block IDs, names, descriptions,
   and full input/output schemas.
2. **Find library agents**: Call `find_library_agent` to discover reusable
3. **Find library agents**: Call `find_library_agent` to discover reusable
   agents that can be composed as sub-agents via `AgentExecutorBlock`.
3. **Generate JSON**: Build the agent JSON using block schemas:
   - Use block IDs from step 1 as `block_id` in nodes
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
   - Use block IDs from step 2 as `block_id` in nodes
   - Wire outputs to inputs using links
   - Set design-time config in `input_default`
   - Use `AgentInputBlock` for values the user provides at runtime
4. **Write to workspace**: Save the JSON to a workspace file so the user
   - When editing, apply targeted changes and preserve unchanged parts
5. **Write to workspace**: Save the JSON to a workspace file so the user
   can review it: `write_workspace_file(filename="agent.json", content=...)`
5. **Validate**: Call `validate_agent_graph` with the agent JSON to check
6. **Validate**: Call `validate_agent_graph` with the agent JSON to check
   for errors
6. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
7. **Fix if needed**: Call `fix_agent_graph` to auto-fix common issues,
   or fix manually based on the error descriptions. Iterate until valid.
7. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
8. **Save**: Call `create_agent` (new) or `edit_agent` (existing) with
   the final `agent_json`
8. **Dry-run**: ALWAYS call `run_agent` with `dry_run=True` and
   `wait_for_result=120` to verify the agent works end-to-end.
9. **Inspect & fix**: Check the dry-run output for errors. If issues are
   found, call `edit_agent` to fix and dry-run again. Repeat until the
   simulation passes or the problems are clearly unfixable.
   See "REQUIRED: Dry-Run Verification Loop" section below for details.

### Agent JSON Structure

@@ -89,8 +110,8 @@ These define the agent's interface — what it accepts and what it produces.

**AgentDropdownInputBlock** (ID: `655d6fdf-a334-421c-b733-520549c07cd1`):
- Specialized input block that presents a dropdown/select to the user
- Required `input_default` fields: `name` (str), `placeholder_values` (list of options, must have at least one)
- Optional: `title`, `description`, `value` (default selection)
- Required `input_default` fields: `name` (str)
- Optional: `options` (list of dropdown values; when omitted/empty, input behaves as free-text), `title`, `description`, `value` (default selection)
- Output: `result` — the user-selected value at runtime
- Use this instead of AgentInputBlock when the user should pick from a fixed set of options
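
A minimal `input_default` sketch for this block (field names are the ones
listed above; the values are illustrative):

```json
{
  "block_id": "655d6fdf-a334-421c-b733-520549c07cd1",
  "input_default": {
    "name": "delivery_channel",
    "title": "Delivery channel",
    "description": "Where should the result be sent?",
    "options": ["email", "slack"],
    "value": "email"
  }
}
```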

@@ -231,19 +252,62 @@ call in a loop until the task is complete:
Regular blocks work exactly like sub-agents as tools — wire each input
field from `source_name: "tools"` on the Orchestrator side.
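
As a sketch, a link feeding a regular block from the Orchestrator might look
like this (`source_name`/`sink_name` are the link fields named in this
document; `source_id`/`sink_id` are assumed names for the node references):

```json
{
  "source_id": "<orchestrator node id>",
  "source_name": "tools",
  "sink_id": "<tool block node id>",
  "sink_name": "<block input field>"
}
```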

### Testing with Dry Run
### REQUIRED: Dry-Run Verification Loop (create -> dry-run -> fix)

After saving an agent, suggest a dry run to validate wiring without consuming
real API calls, credentials, or credits:
After creating or editing an agent, you MUST dry-run it before telling the
user the agent is ready. NEVER skip this step.

1. **Run**: Call `run_agent` or `run_block` with `dry_run=True` and provide
   sample inputs. This executes the graph with mock outputs, verifying that
   links resolve correctly and required inputs are satisfied.
2. **Check results**: Call `view_agent_output` with `show_execution_details=True`
   to inspect the full node-by-node execution trace. This shows what each node
   received as input and produced as output, making it easy to spot wiring issues.
3. **Iterate**: If the dry run reveals wiring issues or missing inputs, fix
   the agent JSON and re-save before suggesting a real execution.
#### Step-by-step workflow

1. **Create/Edit**: Call `create_agent` or `edit_agent` to save the agent.
2. **Dry-run**: Call `run_agent` with `dry_run=True`, `wait_for_result=120`,
   and realistic sample inputs that exercise every path in the agent. This
   simulates execution using an LLM for each block — no real API calls,
   credentials, or credits are consumed.
3. **Inspect output**: Examine the dry-run result for problems. If
   `wait_for_result` returns only a summary, call
   `view_agent_output(execution_id=..., show_execution_details=True)` to
   see the full node-by-node execution trace. Look for:
   - **Errors / failed nodes** — a node raised an exception or returned an
     error status. Common causes: wrong `source_name`/`sink_name` in links,
     missing `input_default` values, or referencing a nonexistent block output.
   - **Null / empty outputs** — data did not flow through a link. Verify that
     `source_name` and `sink_name` match the block schemas exactly (case-
     sensitive, including nested `_#_` notation).
   - **Nodes that never executed** — the node was not reached. Likely a
     missing or broken link from an upstream node.
   - **Unexpected values** — data arrived but in the wrong type or
     structure. Check type compatibility between linked ports.
4. **Fix**: If any issues are found, call `edit_agent` with the corrected
   agent JSON, then go back to step 2.
5. **Repeat**: Continue the dry-run -> fix cycle until the simulation passes
   or the problems are clearly unfixable. If you stop making progress,
   report the remaining issues to the user and ask for guidance.
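
A minimal sketch of one iteration of the loop as tool calls (the
`agent_id`/`inputs` argument names are assumptions; the other parameters are
the ones named in the steps above):

```
run_agent(agent_id="<saved agent id>", inputs={"topic": "sample"},
          dry_run=True, wait_for_result=120)
view_agent_output(execution_id="<id from run_agent>",
                  show_execution_details=True)
# fix with edit_agent(...) and repeat until the dry run passes
```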

#### Good vs bad dry-run output

**Good output** (agent is ready):
- All nodes executed successfully (no errors in the execution trace)
- Data flows through every link with non-null, correctly-typed values
- The final `AgentOutputBlock` contains a meaningful result
- Status is `COMPLETED`

**Bad output** (needs fixing):
- Status is `FAILED` — check the error message for the failing node
- An output node received `null` — trace back to find the broken link
- A node received data in the wrong format (e.g. string where list expected)
- Nodes downstream of a failing node were skipped entirely

**Special block behaviour in dry-run mode:**
- **OrchestratorBlock** and **AgentExecutorBlock** execute for real so the
  orchestrator can make LLM calls and agent executors can spawn child graphs.
  Their downstream tool blocks and child-graph blocks are still simulated.
  Note: real LLM inference calls are made (consuming API quota), even though
  platform credits are not charged. Agent-mode iterations are capped at 1 in
  dry-run to keep it fast.
- **MCPToolBlock** is simulated using the selected tool's name and JSON Schema
  so the LLM can produce a realistic mock response without connecting to the
  MCP server.

### Example: Simple AI Text Processor


@@ -2,14 +2,30 @@

from __future__ import annotations

from collections.abc import AsyncIterator
from unittest.mock import patch
from uuid import uuid4

import pytest
import pytest_asyncio

from backend.util import json


@pytest_asyncio.fixture(scope="session", loop_scope="session", name="server")
async def _server_noop() -> None:
    """No-op server stub — SDK tests don't need the full backend."""
    return None


@pytest_asyncio.fixture(
    scope="session", loop_scope="session", autouse=True, name="graph_cleanup"
)
async def _graph_cleanup_noop() -> AsyncIterator[None]:
    """No-op graph cleanup stub."""
    yield


@pytest.fixture()
def mock_chat_config():
    """Mock ChatConfig so compact_transcript tests skip real config lookup."""

@@ -8,6 +8,9 @@ SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""

import asyncio
import base64
import hashlib
import itertools
import json
import logging
@@ -28,6 +31,12 @@ from backend.copilot.context import (

logger = logging.getLogger(__name__)

# Default number of lines returned by ``read_file`` when the caller does not
# specify a limit. Also used as the threshold in ``bridge_to_sandbox`` to
# decide whether the model is requesting the full file (and thus whether the
# bridge copy is worthwhile).
_DEFAULT_READ_LIMIT = 2000


async def _check_sandbox_symlink_escape(
    sandbox: Any,
@@ -89,7 +98,7 @@ def _get_sandbox_and_path(
    return sandbox, remote


async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
async def _sandbox_write(sandbox: Any, path: str, content: str | bytes) -> None:
    """Write *content* to *path* inside the sandbox.

    The E2B filesystem API (``sandbox.files.write``) and the command API
@@ -102,11 +111,14 @@ async def _sandbox_write(sandbox: Any, path: str, content: str) -> None:
    To work around this, writes targeting ``/tmp`` are performed via
    ``tee`` through the command API, which runs as the sandbox ``user``
    and can therefore always overwrite user-owned files.

    *content* may be ``str`` (text) or ``bytes`` (binary). Both paths
    are handled correctly: text is encoded to bytes for the base64 shell
    pipe, and raw bytes are passed through without any encoding.
    """
    if path == "/tmp" or path.startswith("/tmp/"):
        import base64 as _b64

        encoded = _b64.b64encode(content.encode()).decode()
        raw = content.encode() if isinstance(content, str) else content
        encoded = base64.b64encode(raw).decode()
        result = await sandbox.commands.run(
            f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(path)}",
            cwd=E2B_WORKDIR,
@@ -128,14 +140,25 @@ async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
    """Read lines from a sandbox file, falling back to the local host for SDK-internal paths."""
    file_path: str = args.get("file_path", "")
    offset: int = max(0, int(args.get("offset", 0)))
    limit: int = max(1, int(args.get("limit", 2000)))
    limit: int = max(1, int(args.get("limit", _DEFAULT_READ_LIMIT)))

    if not file_path:
        return _mcp("file_path is required", error=True)

    # SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
    # SDK-internal paths (tool-results/tool-outputs, ephemeral working dir)
    # stay on the host. When E2B is active, also copy the file into the
    # sandbox so bash_exec can access it for further processing.
    if _is_allowed_local(file_path):
        return _read_local(file_path, offset, limit)
        result = _read_local(file_path, offset, limit)
        if not result.get("isError"):
            sandbox = _get_sandbox()
            if sandbox is not None:
                annotation = await bridge_and_annotate(
                    sandbox, file_path, offset, limit
                )
                if annotation:
                    result["content"][0]["text"] += annotation
        return result

    result = _get_sandbox_and_path(file_path)
    if isinstance(result, dict):
@@ -302,6 +325,103 @@ async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
    return _mcp(output if output else "No matches found.")


# Bridging: copy SDK-internal files into E2B sandbox

# Files larger than this are written to /home/user/ via sandbox.files.write()
# instead of /tmp/ via shell base64, to avoid shell argument length limits
# and E2B command timeouts. Base64 expands content by ~33%, so keep this
# well under the typical Linux ARG_MAX (128 KB).
_BRIDGE_SHELL_MAX_BYTES = 32 * 1024  # 32 KB
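# (Worked example, for illustration: 32 KB of raw bytes expands to roughly
# 32 * 4 / 3 ≈ 43 KB of base64 text, comfortably below a 128 KB ARG_MAX.)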
# Files larger than this are skipped entirely to avoid excessive transfer times.
_BRIDGE_SKIP_BYTES = 50 * 1024 * 1024  # 50 MB


async def bridge_to_sandbox(
    sandbox: Any, file_path: str, offset: int, limit: int
) -> str | None:
    """Best-effort copy of a host-side SDK file into the E2B sandbox.

    When the model reads an SDK-internal file (e.g. tool-results), it often
    wants to process the data with bash. Copying the file into the sandbox
    under a stable name lets ``bash_exec`` access it without extra steps.

    Only copies when offset=0 and limit is large enough to indicate the model
    wants the full file. Errors are logged but never propagated.

    Returns the sandbox path on success, or ``None`` on skip/failure.

    Size handling:
    - <= 32 KB: written to ``/tmp/<hash>-<basename>`` via shell base64
      (``_sandbox_write``). Kept small to stay within ARG_MAX.
    - 32 KB - 50 MB: written to ``/home/user/<hash>-<basename>`` via
      ``sandbox.files.write()`` to avoid shell argument length limits.
    - > 50 MB: skipped entirely with a warning.

    The sandbox filename is prefixed with a short hash of the full source
    path to avoid collisions when different source files share the same
    basename (e.g. multiple ``result.json`` files).
    """
    if offset != 0 or limit < _DEFAULT_READ_LIMIT:
        return None
    try:
        expanded = os.path.realpath(os.path.expanduser(file_path))
        basename = os.path.basename(expanded)
        source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
        unique_name = f"{source_id}-{basename}"
        file_size = os.path.getsize(expanded)
        if file_size > _BRIDGE_SKIP_BYTES:
            logger.warning(
                "[E2B] Skipping bridge for large file (%d bytes): %s",
                file_size,
                basename,
            )
            return None

        def _read_bytes() -> bytes:
            with open(expanded, "rb") as fh:
                return fh.read()

        raw_content = await asyncio.to_thread(_read_bytes)
        try:
            text_content: str | None = raw_content.decode("utf-8")
        except UnicodeDecodeError:
            text_content = None
        data: str | bytes = text_content if text_content is not None else raw_content
        if file_size <= _BRIDGE_SHELL_MAX_BYTES:
            sandbox_path = f"/tmp/{unique_name}"
            await _sandbox_write(sandbox, sandbox_path, data)
        else:
            sandbox_path = f"/home/user/{unique_name}"
            await sandbox.files.write(sandbox_path, data)
        logger.info(
            "[E2B] Bridged SDK file to sandbox: %s -> %s", basename, sandbox_path
        )
        return sandbox_path
    except Exception:
        logger.warning(
            "[E2B] Failed to bridge SDK file to sandbox: %s",
            file_path,
            exc_info=True,
        )
        return None


async def bridge_and_annotate(
    sandbox: Any, file_path: str, offset: int, limit: int
) -> str | None:
    """Bridge a host file to the sandbox and return a newline-prefixed annotation.

    Combines ``bridge_to_sandbox`` with the standard annotation suffix so
    callers don't need to duplicate the pattern. Returns a string like
    ``"\\n[Sandbox copy available at /tmp/abc-file.txt]"`` on success, or
    ``None`` if bridging was skipped or failed.
    """
    sandbox_path = await bridge_to_sandbox(sandbox, file_path, offset, limit)
    if sandbox_path is None:
        return None
    return f"\n[Sandbox copy available at {sandbox_path}]"


# Local read (for SDK-internal paths)



@@ -3,6 +3,7 @@
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""

import hashlib
import os
import shutil
from types import SimpleNamespace
@@ -13,12 +14,26 @@ import pytest
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir

from .e2b_file_tools import (
    _BRIDGE_SHELL_MAX_BYTES,
    _BRIDGE_SKIP_BYTES,
    _DEFAULT_READ_LIMIT,
    _check_sandbox_symlink_escape,
    _read_local,
    _sandbox_write,
    bridge_and_annotate,
    bridge_to_sandbox,
    resolve_sandbox_path,
)


def _expected_bridge_path(file_path: str, prefix: str = "/tmp") -> str:
    """Compute the expected sandbox path for a bridged file."""
    expanded = os.path.realpath(os.path.expanduser(file_path))
    basename = os.path.basename(expanded)
    source_id = hashlib.sha256(expanded.encode()).hexdigest()[:12]
    return f"{prefix}/{source_id}-{basename}"


# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
@@ -91,9 +106,9 @@ class TestResolveSandboxPath:
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
# In E2B mode, _read_local only allows tool-results paths (via
# is_allowed_local_path without sdk_cwd). Regular files live on the
# sandbox, not the host.
# In E2B mode, _read_local only allows tool-results/tool-outputs paths
# (via is_allowed_local_path without sdk_cwd). Regular files live on
# the sandbox, not the host.
# ---------------------------------------------------------------------------


@@ -119,7 +134,7 @@ class TestReadLocal:
        )
        token = _current_project_dir.set(encoded)
        try:
            result = _read_local(filepath, offset=0, limit=2000)
            result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
            assert result["isError"] is False
            assert "line 1" in result["content"][0]["text"]
            assert "line 2" in result["content"][0]["text"]
@@ -127,6 +142,25 @@
            _current_project_dir.reset(token)
            os.unlink(filepath)

    def test_read_tool_outputs_file(self):
        """Reading a tool-outputs file should also succeed."""
        encoded = "-tmp-copilot-e2b-test-read-outputs"
        tool_outputs_dir = os.path.join(
            SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-outputs"
        )
        os.makedirs(tool_outputs_dir, exist_ok=True)
        filepath = os.path.join(tool_outputs_dir, "sdk-abc123.json")
        with open(filepath, "w") as f:
            f.write('{"data": "test"}\n')
        token = _current_project_dir.set(encoded)
        try:
            result = _read_local(filepath, offset=0, limit=_DEFAULT_READ_LIMIT)
            assert result["isError"] is False
            assert "test" in result["content"][0]["text"]
        finally:
            _current_project_dir.reset(token)
            shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)

    def test_read_disallowed_path_blocked(self):
        """Reading /etc/passwd should be blocked by the allowlist."""
        result = _read_local("/etc/passwd", offset=0, limit=10)
@@ -335,3 +369,199 @@ class TestSandboxWrite:
        encoded_in_cmd = call_args.split("echo ")[1].split(" |")[0].strip("'")
        decoded = base64.b64decode(encoded_in_cmd).decode()
        assert decoded == content


# ---------------------------------------------------------------------------
# bridge_to_sandbox — copy SDK-internal files into E2B sandbox
# ---------------------------------------------------------------------------


def _make_bridge_sandbox() -> SimpleNamespace:
    """Build a sandbox mock suitable for bridge_to_sandbox tests."""
    run_result = SimpleNamespace(stdout="", stderr="", exit_code=0)
    commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
    files = SimpleNamespace(write=AsyncMock())
    return SimpleNamespace(commands=commands, files=files)


class TestBridgeToSandbox:
    @pytest.mark.asyncio
    async def test_happy_path_small_file(self, tmp_path):
        """A small file is bridged to /tmp/<hash>-<basename> via _sandbox_write."""
        f = tmp_path / "result.json"
        f.write_text('{"ok": true}')
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected = _expected_bridge_path(str(f))
        assert result == expected
        sandbox.commands.run.assert_called_once()
        cmd = sandbox.commands.run.call_args[0][0]
        assert "result.json" in cmd
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_skip_when_offset_nonzero(self, tmp_path):
        """Bridging is skipped when offset != 0 (partial read)."""
        f = tmp_path / "data.txt"
        f.write_text("content")
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
        )

        assert result is None
        sandbox.commands.run.assert_not_called()
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_skip_when_limit_too_small(self, tmp_path):
        """Bridging is skipped when limit < _DEFAULT_READ_LIMIT (partial read)."""
        f = tmp_path / "data.txt"
        f.write_text("content")
        sandbox = _make_bridge_sandbox()

        await bridge_to_sandbox(sandbox, str(f), offset=0, limit=100)

        sandbox.commands.run.assert_not_called()
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_nonexistent_file_does_not_raise(self, tmp_path):
        """Bridging a non-existent file logs but does not propagate errors."""
        sandbox = _make_bridge_sandbox()

        await bridge_to_sandbox(
            sandbox, str(tmp_path / "ghost.txt"), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        sandbox.commands.run.assert_not_called()
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_sandbox_write_failure_returns_none(self, tmp_path):
        """If sandbox write fails, returns None (best-effort)."""
        f = tmp_path / "data.txt"
        f.write_text("content")
        sandbox = _make_bridge_sandbox()
        sandbox.commands.run.side_effect = RuntimeError("E2B timeout")

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_large_file_uses_files_api(self, tmp_path):
        """Files > 32 KB but <= 50 MB are written to /home/user/ via files.write."""
        f = tmp_path / "big.json"
        f.write_bytes(b"x" * (_BRIDGE_SHELL_MAX_BYTES + 1))
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected = _expected_bridge_path(str(f), prefix="/home/user")
        assert result == expected
        sandbox.files.write.assert_called_once()
        call_args = sandbox.files.write.call_args[0]
        assert call_args[0] == expected
        sandbox.commands.run.assert_not_called()

    @pytest.mark.asyncio
    async def test_small_binary_file_preserves_bytes(self, tmp_path):
        """A small binary file is bridged to /tmp via base64 without corruption."""
        binary_data = bytes(range(256))
        f = tmp_path / "image.png"
        f.write_bytes(binary_data)
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected = _expected_bridge_path(str(f))
        assert result == expected
        sandbox.commands.run.assert_called_once()
        cmd = sandbox.commands.run.call_args[0][0]
        assert "base64" in cmd
        sandbox.files.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_large_binary_file_writes_raw_bytes(self, tmp_path):
        """A large binary file is bridged to /home/user/ as raw bytes."""
        binary_data = bytes(range(256)) * 200
        f = tmp_path / "photo.jpg"
        f.write_bytes(binary_data)
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected = _expected_bridge_path(str(f), prefix="/home/user")
        assert result == expected
        sandbox.files.write.assert_called_once()
        call_args = sandbox.files.write.call_args[0]
        assert call_args[0] == expected
        assert call_args[1] == binary_data
        sandbox.commands.run.assert_not_called()

    @pytest.mark.asyncio
    async def test_very_large_file_skipped(self, tmp_path):
        """Files > 50 MB are skipped entirely."""
        f = tmp_path / "huge.bin"
        # Create a sparse file to avoid actually writing 50 MB
        with open(f, "wb") as fh:
            fh.seek(_BRIDGE_SKIP_BYTES + 1)
            fh.write(b"\0")
        sandbox = _make_bridge_sandbox()

        result = await bridge_to_sandbox(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        assert result is None

        sandbox.commands.run.assert_not_called()
        sandbox.files.write.assert_not_called()


# ---------------------------------------------------------------------------
# bridge_and_annotate — shared helper wrapping bridge_to_sandbox + annotation
# ---------------------------------------------------------------------------


class TestBridgeAndAnnotate:
    @pytest.mark.asyncio
    async def test_returns_annotation_on_success(self, tmp_path):
        """On success, returns a newline-prefixed annotation with the sandbox path."""
        f = tmp_path / "data.json"
        f.write_text('{"ok": true}')
        sandbox = _make_bridge_sandbox()

        annotation = await bridge_and_annotate(
            sandbox, str(f), offset=0, limit=_DEFAULT_READ_LIMIT
        )

        expected_path = _expected_bridge_path(str(f))
        assert annotation == f"\n[Sandbox copy available at {expected_path}]"

    @pytest.mark.asyncio
    async def test_returns_none_when_skipped(self, tmp_path):
        """When bridging is skipped (e.g. offset != 0), returns None."""
        f = tmp_path / "data.json"
        f.write_text("content")
        sandbox = _make_bridge_sandbox()

        annotation = await bridge_and_annotate(
            sandbox, str(f), offset=10, limit=_DEFAULT_READ_LIMIT
        )

        assert annotation is None

@@ -20,6 +20,7 @@ config = ChatConfig()
def build_sdk_env(
    session_id: str | None = None,
    user_id: str | None = None,
    sdk_cwd: str | None = None,
) -> dict[str, str]:
    """Build env vars for the SDK CLI subprocess.

@@ -29,25 +30,35 @@
       ``ANTHROPIC_API_KEY`` from the parent environment.
    3. **OpenRouter** (default) — overrides base URL and auth token to
       route through the proxy, with Langfuse trace headers.

    When *sdk_cwd* is provided, ``CLAUDE_CODE_TMPDIR`` is set so that
    the CLI writes temp/sub-agent output inside the per-session workspace
    directory rather than an inaccessible system temp path.
    """
    # --- Mode 1: Claude Code subscription auth ---
    if config.use_claude_code_subscription:
        validate_subscription()
        return {
        env: dict[str, str] = {
            "ANTHROPIC_API_KEY": "",
            "ANTHROPIC_AUTH_TOKEN": "",
            "ANTHROPIC_BASE_URL": "",
        }
        if sdk_cwd:
            env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
        return env

    # --- Mode 2: Direct Anthropic (no proxy hop) ---
    if not config.openrouter_active:
        return {}
        env = {}
        if sdk_cwd:
            env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
        return env

    # --- Mode 3: OpenRouter proxy ---
    base = (config.base_url or "").rstrip("/")
    if base.endswith("/v1"):
        base = base[:-3]
    env: dict[str, str] = {
    env = {
        "ANTHROPIC_BASE_URL": base,
        "ANTHROPIC_AUTH_TOKEN": config.api_key or "",
        "ANTHROPIC_API_KEY": "",  # force CLI to use AUTH_TOKEN
@@ -65,4 +76,7 @@ def build_sdk_env(
    if parts:
        env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)

    if sdk_cwd:
        env["CLAUDE_CODE_TMPDIR"] = sdk_cwd

    return env
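

# Illustrative call site for build_sdk_env (a sketch, not taken from this
# diff: merging over os.environ and the exact subprocess invocation are
# assumptions made for the example):
#
#     env = {**os.environ, **build_sdk_env(session_id="s1", user_id="u1",
#                                          sdk_cwd="/tmp/copilot-ws")}
#     subprocess.run(cli_command, env=env)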
|
||||
|
||||
@@ -240,3 +240,54 @@ class TestBuildSdkEnvModePriority:
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLAUDE_CODE_TMPDIR integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaudeCodeTmpdir:
|
||||
"""Verify build_sdk_env() sets CLAUDE_CODE_TMPDIR from *sdk_cwd*."""
|
||||
|
||||
def test_tmpdir_set_when_sdk_cwd_is_truthy(self):
|
||||
"""CLAUDE_CODE_TMPDIR is set to sdk_cwd when sdk_cwd is truthy."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(sdk_cwd="/tmp/copilot-workspace")
|
||||
|
||||
assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/copilot-workspace"
|
||||
|
||||
def test_tmpdir_not_set_when_sdk_cwd_is_none(self):
|
||||
"""CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is None."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(sdk_cwd=None)
|
||||
|
||||
assert "CLAUDE_CODE_TMPDIR" not in result
|
||||
|
||||
def test_tmpdir_not_set_when_sdk_cwd_is_empty_string(self):
|
||||
"""CLAUDE_CODE_TMPDIR is NOT in the env when sdk_cwd is empty string."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(sdk_cwd="")
|
||||
|
||||
assert "CLAUDE_CODE_TMPDIR" not in result
|
||||
|
||||
@patch("backend.copilot.sdk.env.validate_subscription")
|
||||
def test_tmpdir_set_in_subscription_mode(self, mock_validate):
|
||||
"""CLAUDE_CODE_TMPDIR is set even in subscription mode."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env(sdk_cwd="/tmp/sub-workspace")
|
||||
|
||||
assert result["CLAUDE_CODE_TMPDIR"] == "/tmp/sub-workspace"
|
||||
assert result["ANTHROPIC_API_KEY"] == ""
|
||||
|
||||
@@ -28,13 +28,12 @@ Each result includes a `remotes` array with the exact server URL to use.
|
||||
|
||||
### Important: Check blocks first
|
||||
|
||||
Before using `run_mcp_tool`, always check if the platform already has blocks for the service
|
||||
using `find_block`. The platform has hundreds of built-in blocks (Google Sheets, Google Docs,
|
||||
Google Calendar, Gmail, etc.) that work without MCP setup.
|
||||
Always follow the **Tool Discovery Priority** described in the tool notes:
|
||||
call `find_block` before resorting to `run_mcp_tool`.
|
||||
|
||||
Only use `run_mcp_tool` when:
|
||||
- The service is in the known hosted MCP servers list above, OR
|
||||
- You searched `find_block` first and found no matching blocks
|
||||
- You searched `find_block` first and found no matching blocks, AND
|
||||
- The service is in the known hosted MCP servers list above or found via the registry API
|
||||
|
||||
**Never guess or construct MCP server URLs.** Only use URLs from the known servers list above
|
||||
or from the `remotes[].url` field in MCP registry search results.
|
||||
|
||||
@@ -0,0 +1,391 @@
|
||||
"""Tests for P0 guardrails: _resolve_fallback_model, security env vars, TMPDIR."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.constants import is_transient_api_error
|
||||
|
||||
|
||||
def _make_config(**overrides) -> ChatConfig:
|
||||
"""Create a ChatConfig with safe defaults, applying *overrides*."""
|
||||
defaults = {
|
||||
"use_claude_code_subscription": False,
|
||||
"use_openrouter": False,
|
||||
"api_key": None,
|
||||
"base_url": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ChatConfig(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_fallback_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SVC = "backend.copilot.sdk.service"
|
||||
|
||||
|
||||
class TestResolveFallbackModel:
|
||||
"""Provider-aware fallback model resolution."""
|
||||
|
||||
def test_returns_none_when_empty(self):
|
||||
cfg = _make_config(claude_agent_fallback_model="")
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
assert _resolve_fallback_model() is None
|
||||
|
||||
def test_strips_provider_prefix(self):
|
||||
"""OpenRouter-style 'anthropic/claude-sonnet-4-...' is stripped."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="anthropic/claude-sonnet-4-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="sk-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result == "claude-sonnet-4-20250514"
|
||||
assert "/" not in result
|
||||
|
||||
def test_dots_replaced_for_direct_anthropic(self):
|
||||
"""Direct Anthropic requires hyphen-separated versions."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="claude-sonnet-4.5-20250514",
|
||||
use_openrouter=False,
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result is not None
|
||||
assert "." not in result
|
||||
assert result == "claude-sonnet-4-5-20250514"
|
||||
|
||||
def test_dots_preserved_for_openrouter(self):
|
||||
"""OpenRouter uses dot-separated versions — don't normalise."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="claude-sonnet-4.5-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="sk-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result == "claude-sonnet-4.5-20250514"
|
||||
|
||||
def test_default_value(self):
|
||||
"""Default fallback model resolves to a valid string."""
|
||||
cfg = _make_config()
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result is not None
|
||||
assert "sonnet" in result.lower() or "claude" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security & isolation env vars
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecurityEnvVars:
|
||||
"""Verify the env-var contract in the service module.
|
||||
|
||||
The production code sets CLAUDE_CODE_TMPDIR and security env vars
|
||||
inline after ``build_sdk_env()`` returns. We grep for these string
|
||||
literals in ``service.py`` to ensure they aren't accidentally removed.
|
||||
"""
|
||||
|
||||
_SERVICE_PATH = "autogpt_platform/backend/backend/copilot/sdk/service.py"
|
||||
|
||||
@staticmethod
|
||||
def _read_service_source() -> str:
|
||||
import pathlib
|
||||
|
||||
# Walk up from this test file to the repo root
|
||||
repo = pathlib.Path(__file__).resolve().parents[5]
|
||||
return (repo / TestSecurityEnvVars._SERVICE_PATH).read_text()
|
||||
|
||||
def test_tmpdir_env_var_present_in_source(self):
|
||||
"""CLAUDE_CODE_TMPDIR must be set when sdk_cwd is provided."""
|
||||
src = self._read_service_source()
|
||||
assert 'sdk_env["CLAUDE_CODE_TMPDIR"]' in src
|
||||
|
||||
def test_home_not_overridden_in_source(self):
|
||||
"""HOME must NOT be overridden — would break git/ssh/npm."""
|
||||
src = self._read_service_source()
|
||||
assert 'sdk_env["HOME"]' not in src
|
||||
|
||||
def test_security_env_vars_present_in_source(self):
|
||||
"""All four security env vars must be set in the service module."""
|
||||
src = self._read_service_source()
|
||||
for var in (
|
||||
"CLAUDE_CODE_DISABLE_CLAUDE_MDS",
|
||||
"CLAUDE_CODE_SKIP_PROMPT_HISTORY",
|
||||
"CLAUDE_CODE_DISABLE_AUTO_MEMORY",
|
||||
"CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC",
|
||||
):
|
||||
assert var in src, f"{var} not found in service.py"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigDefaults:
|
||||
"""Verify ChatConfig P0 fields have correct defaults."""
|
||||
|
||||
def test_fallback_model_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_fallback_model
|
||||
assert "sonnet" in cfg.claude_agent_fallback_model.lower()
|
||||
|
||||
def test_max_turns_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_turns == 50
|
||||
|
||||
def test_max_budget_usd_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_budget_usd == 5.0
|
||||
|
||||
def test_max_transient_retries_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_transient_retries == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_sdk_env — all 3 auth modes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_ENV = "backend.copilot.sdk.env"
|
||||
|
||||
|
||||
class TestBuildSdkEnv:
|
||||
"""Verify build_sdk_env returns correct dicts for each auth mode."""
|
||||
|
||||
def test_subscription_mode_clears_keys(self):
|
||||
"""Mode 1: subscription clears API key / auth token / base URL."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with (
|
||||
patch(f"{_ENV}.config", cfg),
|
||||
patch(f"{_ENV}.validate_subscription"),
|
||||
):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(session_id="s1", user_id="u1")
|
||||
|
||||
assert env["ANTHROPIC_API_KEY"] == ""
|
||||
assert env["ANTHROPIC_AUTH_TOKEN"] == ""
|
||||
assert env["ANTHROPIC_BASE_URL"] == ""
|
||||
|
||||
def test_direct_anthropic_returns_empty_dict(self):
|
||||
"""Mode 2: direct Anthropic returns {} (inherits from parent env)."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=False,
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert env == {}
|
||||
|
||||
def test_openrouter_sets_base_url_and_auth(self):
|
||||
"""Mode 3: OpenRouter sets base URL, auth token, and clears API key."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(session_id="sess-1", user_id="user-1")
|
||||
|
||||
assert env["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
assert env["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test"
|
||||
assert env["ANTHROPIC_API_KEY"] == ""
|
||||
assert "x-session-id: sess-1" in env["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
assert "x-user-id: user-1" in env["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
|
||||
def test_openrouter_no_headers_when_ids_empty(self):
|
||||
"""Mode 3: No custom headers when session_id/user_id are not given."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert "ANTHROPIC_CUSTOM_HEADERS" not in env
|
||||
|
||||
def test_all_modes_return_mutable_dict(self):
|
||||
"""build_sdk_env must return a mutable dict (not None) so callers
|
||||
can add security env vars like CLAUDE_CODE_TMPDIR."""
|
||||
for cfg in (
|
||||
_make_config(use_claude_code_subscription=True),
|
||||
_make_config(use_openrouter=False),
|
||||
_make_config(
|
||||
use_openrouter=True,
|
||||
api_key="k",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
),
|
||||
):
|
||||
with (
|
||||
patch(f"{_ENV}.config", cfg),
|
||||
patch(f"{_ENV}.validate_subscription"),
|
||||
):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert isinstance(env, dict)
|
||||
env["CLAUDE_CODE_TMPDIR"] = "/tmp/test"
|
||||
assert env["CLAUDE_CODE_TMPDIR"] == "/tmp/test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# is_transient_api_error
# ---------------------------------------------------------------------------


class TestIsTransientApiError:
    """Verify that is_transient_api_error detects all transient patterns."""

    @pytest.mark.parametrize(
        "error_text",
        [
            "socket connection was closed unexpectedly",
            "ECONNRESET",
            "connection was forcibly closed",
            "network socket disconnected",
        ],
    )
    def test_connection_level_errors(self, error_text: str):
        assert is_transient_api_error(error_text)

    @pytest.mark.parametrize(
        "error_text",
        [
            "rate limit exceeded",
            "rate_limit_error",
            "Too Many Requests",
            "status code 429",
        ],
    )
    def test_429_rate_limit_errors(self, error_text: str):
        assert is_transient_api_error(error_text)

    @pytest.mark.parametrize(
        "error_text",
        [
            "API is overloaded",
            "Internal Server Error",
            "Bad Gateway",
            "Service Unavailable",
            "Gateway Timeout",
            "status code 529",
            "status code 500",
            "status code 502",
            "status code 503",
            "status code 504",
        ],
    )
    def test_5xx_server_errors(self, error_text: str):
        assert is_transient_api_error(error_text)

    @pytest.mark.parametrize(
        "error_text",
        [
            "invalid_api_key",
            "Authentication failed",
            "prompt is too long",
            "model not found",
            "",
        ],
    )
    def test_non_transient_errors(self, error_text: str):
        assert not is_transient_api_error(error_text)

    def test_case_insensitive(self):
        assert is_transient_api_error("SOCKET CONNECTION WAS CLOSED UNEXPECTEDLY")
        assert is_transient_api_error("econnreset")

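# ---------------------------------------------------------------------------
# Illustrative only: the predicate under test is, in essence, a
# case-insensitive substring match over known transient patterns. The real
# pattern list lives next to is_transient_api_error; this reconstruction is
# an assumption inferred from the cases asserted above.


_TRANSIENT_PATTERNS_SKETCH = (
    "socket connection was closed",
    "econnreset",
    "forcibly closed",
    "socket disconnected",
    "rate limit",
    "rate_limit_error",
    "too many requests",
    "status code 429",
    "overloaded",
    "internal server error",
    "bad gateway",
    "service unavailable",
    "gateway timeout",
    "status code 5",
)


def _is_transient_api_error_sketch(error_text: str) -> bool:
    # Lowercase once so ECONNRESET and econnreset both match.
    lowered = error_text.lower()
    return any(p in lowered for p in _TRANSIENT_PATTERNS_SKETCH)
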
# ---------------------------------------------------------------------------
# Config validators for max_turns / max_budget_usd
# ---------------------------------------------------------------------------


class TestConfigValidators:
    """Verify ge/le bounds on max_turns and max_budget_usd."""

    def test_max_turns_rejects_zero(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_turns=0)

    def test_max_turns_rejects_negative(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_turns=-1)

    def test_max_turns_rejects_above_500(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_turns=501)

    def test_max_turns_accepts_boundary_values(self):
        cfg_low = _make_config(claude_agent_max_turns=1)
        assert cfg_low.claude_agent_max_turns == 1
        cfg_high = _make_config(claude_agent_max_turns=500)
        assert cfg_high.claude_agent_max_turns == 500

    def test_max_budget_rejects_zero(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_budget_usd=0.0)

    def test_max_budget_rejects_negative(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_budget_usd=-1.0)

    def test_max_budget_rejects_above_100(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_budget_usd=100.01)

    def test_max_budget_accepts_boundary_values(self):
        cfg_low = _make_config(claude_agent_max_budget_usd=0.01)
        assert cfg_low.claude_agent_max_budget_usd == 0.01
        cfg_high = _make_config(claude_agent_max_budget_usd=100.0)
        assert cfg_high.claude_agent_max_budget_usd == 100.0

    def test_max_transient_retries_rejects_negative(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_transient_retries=-1)

    def test_max_transient_retries_rejects_above_10(self):
        with pytest.raises(ValidationError):
            _make_config(claude_agent_max_transient_retries=11)

    def test_max_transient_retries_accepts_boundary_values(self):
        cfg_low = _make_config(claude_agent_max_transient_retries=0)
        assert cfg_low.claude_agent_max_transient_retries == 0
        cfg_high = _make_config(claude_agent_max_transient_retries=10)
        assert cfg_high.claude_agent_max_transient_retries == 10

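# ---------------------------------------------------------------------------
# The bounds exercised above correspond to pydantic Field constraints.
# A self-contained sketch (field names match the tests; the defaults here
# are assumptions, not the real ones):


from pydantic import BaseModel, Field


class _ChatConfigBoundsSketch(BaseModel):
    claude_agent_max_turns: int = Field(default=100, ge=1, le=500)
    claude_agent_max_budget_usd: float = Field(default=10.0, ge=0.01, le=100.0)
    claude_agent_max_transient_retries: int = Field(default=3, ge=0, le=10)
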
@@ -403,7 +403,7 @@ class TestCompactTranscript:
            },
        )()
        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -438,7 +438,7 @@ class TestCompactTranscript:
            },
        )()
        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ):
@@ -462,7 +462,7 @@ class TestCompactTranscript:
            ]
        )
        with patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            side_effect=RuntimeError("LLM unavailable"),
        ):
@@ -568,11 +568,11 @@ class TestRunCompressionTimeout:

        with (
            patch(
-                "backend.copilot.sdk.transcript.get_openai_client",
+                "backend.copilot.transcript.get_openai_client",
                return_value="fake-client",
            ),
            patch(
-                "backend.copilot.sdk.transcript.compress_context",
+                "backend.copilot.transcript.compress_context",
                side_effect=_mock_compress,
            ),
        ):
@@ -602,11 +602,11 @@ class TestRunCompressionTimeout:

        with (
            patch(
-                "backend.copilot.sdk.transcript.get_openai_client",
+                "backend.copilot.transcript.get_openai_client",
                return_value=None,
            ),
            patch(
-                "backend.copilot.sdk.transcript.compress_context",
+                "backend.copilot.transcript.compress_context",
                new_callable=AsyncMock,
                return_value=truncation_result,
            ) as mock_compress,
@@ -29,6 +29,7 @@ from backend.copilot.response_model import (
    StreamToolOutputAvailable,
)

+from .compaction import compaction_events
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
from .tool_adapter import _pending_tool_outputs as _pto
@@ -259,13 +260,13 @@ def test_result_error_emits_error_and_finish():
        is_error=True,
        num_turns=0,
        session_id="s1",
-        result="API rate limited",
+        result="Invalid API key provided",
    )
    results = adapter.convert_message(msg)
    # No step was open, so no FinishStep — just Error + Finish
    assert len(results) == 2
    assert isinstance(results[0], StreamError)
-    assert "API rate limited" in results[0].errorText
+    assert "Invalid API key provided" in results[0].errorText
    assert isinstance(results[1], StreamFinish)
@@ -689,3 +690,102 @@ def test_already_resolved_tool_skipped_in_user_message():
    assert (
        len(output_events) == 0
    ), "Already-resolved tool should not emit duplicate output"


# -- _end_text_if_open before compaction -------------------------------------


def test_end_text_if_open_emits_text_end_before_finish_step():
    """StreamTextEnd must be emitted before StreamFinishStep during compaction.

    When ``emit_end_if_ready`` fires compaction events while a text block is
    still open, ``_end_text_if_open`` must close it first. If StreamFinishStep
    arrives before StreamTextEnd, the Vercel AI SDK clears ``activeTextParts``
    and raises "Received text-end for missing text part".
    """
    adapter = _adapter()

    # Open a text block by processing an AssistantMessage with text
    msg = AssistantMessage(content=[TextBlock(text="partial response")], model="test")
    adapter.convert_message(msg)
    assert adapter.has_started_text
    assert not adapter.has_ended_text

    # Simulate what service.py does before yielding compaction events
    pre_close: list[StreamBaseResponse] = []
    adapter._end_text_if_open(pre_close)
    combined = pre_close + list(compaction_events("Compacted transcript"))

    text_end_idx = next(
        (i for i, e in enumerate(combined) if isinstance(e, StreamTextEnd)), None
    )
    finish_step_idx = next(
        (i for i, e in enumerate(combined) if isinstance(e, StreamFinishStep)), None
    )

    assert text_end_idx is not None, "StreamTextEnd must be present"
    assert finish_step_idx is not None, "StreamFinishStep must be present"
    assert text_end_idx < finish_step_idx, (
        f"StreamTextEnd (idx={text_end_idx}) must precede "
        f"StreamFinishStep (idx={finish_step_idx}) — otherwise the Vercel AI SDK "
        "clears activeTextParts before text-end arrives"
    )


def test_step_open_must_reset_after_compaction_finish_step():
    """Adapter step_open must be reset when compaction emits StreamFinishStep.

    Compaction events bypass the adapter, so service.py must explicitly clear
    step_open after yielding a StreamFinishStep from compaction. Without this,
    the next AssistantMessage skips StreamStartStep because the adapter still
    thinks a step is open.
    """
    adapter = _adapter()

    # Open a step + text block via an AssistantMessage
    msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
    adapter.convert_message(msg)
    assert adapter.step_open is True

    # Simulate what service.py does: close text, then check compaction events
    pre_close: list[StreamBaseResponse] = []
    adapter._end_text_if_open(pre_close)

    events = list(compaction_events("Compacted transcript"))
    if any(isinstance(ev, StreamFinishStep) for ev in events):
        adapter.step_open = False

    assert (
        adapter.step_open is False
    ), "step_open must be False after compaction emits StreamFinishStep"

    # Next AssistantMessage must open a new step
    msg2 = AssistantMessage(content=[TextBlock(text="continued")], model="test")
    results = adapter.convert_message(msg2)
    assert any(
        isinstance(r, StreamStartStep) for r in results
    ), "A new StreamStartStep must be emitted after compaction closed the step"


def test_end_text_if_open_no_op_when_no_text_open():
    """_end_text_if_open emits nothing when no text block is open."""
    adapter = _adapter()
    results: list[StreamBaseResponse] = []
    adapter._end_text_if_open(results)
    assert results == []


def test_end_text_if_open_no_op_after_text_already_ended():
    """_end_text_if_open emits nothing when the text block is already closed."""
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
    adapter.convert_message(msg)
    # Close it once
    first: list[StreamBaseResponse] = []
    adapter._end_text_if_open(first)
    assert len(first) == 1
    assert isinstance(first[0], StreamTextEnd)
    # Second call must be a no-op
    second: list[StreamBaseResponse] = []
    adapter._end_text_if_open(second)
    assert second == []

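# ---------------------------------------------------------------------------
# For orientation: the invariant the four tests above protect can be stated
# as a sketch. This is an illustrative reconstruction, not the real method on
# SDKResponseAdapter (the flags `has_started_text` / `has_ended_text` come
# from the assertions above; everything else is assumed):
#
#     def _end_text_if_open(self, out):
#         if self.has_started_text and not self.has_ended_text:
#             out.append(StreamTextEnd(...))  # close the open text block once
#             self.has_ended_text = True
#
# Because compaction_events(...) ends with a StreamFinishStep, callers must
# run this first so text-end always precedes finish-step on the wire.
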
@@ -113,7 +113,7 @@ class TestScenarioCompactAndRetry:
            )(),
        ),
        patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ),
@@ -170,7 +170,7 @@ class TestScenarioCompactFailsFallback:
            )(),
        ),
        patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            side_effect=RuntimeError("LLM unavailable"),
        ),
@@ -261,7 +261,7 @@ class TestScenarioDoubleFailDBFallback:
            )(),
        ),
        patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ),
@@ -337,7 +337,7 @@ class TestScenarioCompactionIdentical:
            )(),
        ),
        patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ),
@@ -730,7 +730,7 @@ class TestRetryEdgeCases:
            )(),
        ),
        patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            return_value=mock_result,
        ),
@@ -841,7 +841,7 @@ class TestRetryStateReset:
            )(),
        ),
        patch(
-            "backend.copilot.sdk.transcript._run_compression",
+            "backend.copilot.transcript._run_compression",
            new_callable=AsyncMock,
            side_effect=RuntimeError("boom"),
        ),
@@ -1010,7 +1010,7 @@ def _make_sdk_patches(
        (f"{_SVC}.create_security_hooks", dict(return_value=MagicMock())),
        (f"{_SVC}.get_copilot_tool_names", dict(return_value=[])),
        (f"{_SVC}.get_sdk_disallowed_tools", dict(return_value=[])),
-        (f"{_SVC}.build_sdk_env", dict(return_value=None)),
+        (f"{_SVC}.build_sdk_env", dict(return_value={})),
        (f"{_SVC}._resolve_sdk_model", dict(return_value=None)),
        (f"{_SVC}.set_execution_context", {}),
        (
@@ -1487,3 +1487,188 @@ class TestStreamChatCompletionRetryIntegration:
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)

    @pytest.mark.asyncio
    async def test_result_message_success_subtype_prompt_too_long_triggers_compaction(
        self,
    ):
        """CLI returns ResultMessage(subtype="success") with result="Prompt is too long".

        The SDK internally compacts but the transcript is still too long. It
        returns subtype="success" (process completed) with result="Prompt is
        too long" (the actual rejection message). The retry loop must detect
        this as a context-length error and trigger compaction — the subtype
        "success" must not fool it into treating this as a real response.
        """
        import contextlib

        from claude_agent_sdk import ResultMessage

        from backend.copilot.response_model import StreamError, StreamStart
        from backend.copilot.sdk.service import stream_chat_completion_sdk

        session = self._make_session()
        success_result = self._make_result_message()
        attempt_count = [0]

        error_result = ResultMessage(
            subtype="success",
            result="Prompt is too long",
            duration_ms=100,
            duration_api_ms=0,
            is_error=False,
            num_turns=1,
            session_id="test-session-id",
        )

        def _client_factory(*args, **kwargs):
            attempt_count[0] += 1

            async def _receive_error():
                yield error_result

            async def _receive_success():
                yield success_result

            client = MagicMock()
            client._transport = MagicMock()
            client._transport.write = AsyncMock()
            client.query = AsyncMock()
            if attempt_count[0] == 1:
                client.receive_response = _receive_error
            else:
                client.receive_response = _receive_success
            cm = AsyncMock()
            cm.__aenter__.return_value = client
            cm.__aexit__.return_value = None
            return cm

        original_transcript = _build_transcript(
            [("user", "prior question"), ("assistant", "prior answer")]
        )
        compacted_transcript = _build_transcript(
            [("user", "[summary]"), ("assistant", "summary reply")]
        )

        patches = _make_sdk_patches(
            session,
            original_transcript=original_transcript,
            compacted_transcript=compacted_transcript,
            client_side_effect=_client_factory,
        )

        events = []
        with contextlib.ExitStack() as stack:
            for target, kwargs in patches:
                stack.enter_context(patch(target, **kwargs))
            async for event in stream_chat_completion_sdk(
                session_id="test-session-id",
                message="hello",
                is_user_message=True,
                user_id="test-user",
                session=session,
            ):
                events.append(event)

        assert attempt_count[0] == 2, (
            f"Expected 2 SDK attempts (subtype='success' with 'Prompt is too long' "
            f"result should trigger compaction retry), got {attempt_count[0]}"
        )
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)

    @pytest.mark.asyncio
    async def test_assistant_message_error_content_prompt_too_long_triggers_compaction(
        self,
    ):
        """AssistantMessage.error="invalid_request" with content "Prompt is too long".

        The SDK returns error type "invalid_request" but puts the actual
        rejection message ("Prompt is too long") in the content blocks.
        The retry loop must detect this via content inspection (sdk_error
        being set confirms it's an error message, not user content).
        """
        import contextlib

        from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock

        from backend.copilot.response_model import StreamError, StreamStart
        from backend.copilot.sdk.service import stream_chat_completion_sdk

        session = self._make_session()
        success_result = self._make_result_message()
        attempt_count = [0]

        def _client_factory(*args, **kwargs):
            attempt_count[0] += 1

            async def _receive_error():
                # SDK returns invalid_request with "Prompt is too long" in content.
                # ResultMessage.result is a non-PTL value ("done") to isolate
                # the AssistantMessage content detection path exclusively.
                yield AssistantMessage(
                    content=[TextBlock(text="Prompt is too long")],
                    model="<synthetic>",
                    error="invalid_request",
                )
                yield ResultMessage(
                    subtype="success",
                    result="done",
                    duration_ms=100,
                    duration_api_ms=0,
                    is_error=False,
                    num_turns=1,
                    session_id="test-session-id",
                )

            async def _receive_success():
                yield success_result

            client = MagicMock()
            client._transport = MagicMock()
            client._transport.write = AsyncMock()
            client.query = AsyncMock()
            if attempt_count[0] == 1:
                client.receive_response = _receive_error
            else:
                client.receive_response = _receive_success
            cm = AsyncMock()
            cm.__aenter__.return_value = client
            cm.__aexit__.return_value = None
            return cm

        original_transcript = _build_transcript(
            [("user", "prior question"), ("assistant", "prior answer")]
        )
        compacted_transcript = _build_transcript(
            [("user", "[summary]"), ("assistant", "summary reply")]
        )

        patches = _make_sdk_patches(
            session,
            original_transcript=original_transcript,
            compacted_transcript=compacted_transcript,
            client_side_effect=_client_factory,
        )

        events = []
        with contextlib.ExitStack() as stack:
            for target, kwargs in patches:
                stack.enter_context(patch(target, **kwargs))
            async for event in stream_chat_completion_sdk(
                session_id="test-session-id",
                message="hello",
                is_user_message=True,
                user_id="test-user",
                session=session,
            ):
                events.append(event)

        assert attempt_count[0] == 2, (
            f"Expected 2 SDK attempts (AssistantMessage error content 'Prompt is "
            f"too long' should trigger compaction retry), got {attempt_count[0]}"
        )
        errors = [e for e in events if isinstance(e, StreamError)]
        assert not errors, f"Unexpected StreamError: {errors}"
        assert any(isinstance(e, StreamStart) for e in events)

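# ---------------------------------------------------------------------------
# The two tests above pin down the detection rule for "Prompt is too long".
# A sketch of that rule (illustrative; the real checks live inside
# _run_stream_attempt, and the names used here are assumptions):


def _should_compact_and_retry_sketch(sdk_error: str | None, texts: list[str]) -> bool:
    """True when an error-bearing message mentions the context-length rejection."""
    if sdk_error is None:
        # Plain content must never trigger compaction; only confirmed errors may.
        return False
    return any("prompt is too long" in t.lower() for t in texts)
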
@@ -105,6 +105,10 @@ def test_agent_options_accepts_all_our_fields():
        "env",
        "resume",
        "max_buffer_size",
+        "stderr",
+        "fallback_model",
+        "max_turns",
+        "max_budget_usd",
    ]
    sig = inspect.signature(ClaudeAgentOptions)
    for field in fields_we_use:
@@ -22,6 +22,38 @@ from .tool_adapter import (

logger = logging.getLogger(__name__)

# The SDK CLI uses "Task" in older versions and "Agent" in v2.x+.
# Shared across all sessions — used by security hooks for sub-agent detection.
_SUBAGENT_TOOLS: frozenset[str] = frozenset({"Task", "Agent"})

# Unicode ranges stripped by _sanitize():
#   - BiDi overrides (U+202A-U+202E, U+2066-U+2069) can trick reviewers
#     into misreading code/logs.
#   - Zero-width characters (U+200B-U+200F, U+FEFF) can hide content.
_BIDI_AND_ZW_CHARS = set(
    chr(c)
    for r in (range(0x202A, 0x202F), range(0x2066, 0x206A), range(0x200B, 0x2010))
    for c in r
) | {"\ufeff"}


def _sanitize(value: str, max_len: int = 200) -> str:
    """Strip control characters and truncate for safe logging.

    Removes C0 (U+0000-U+001F), DEL (U+007F), C1 (U+0080-U+009F),
    Unicode BiDi overrides, and zero-width characters to prevent
    log injection and visual spoofing.
    """
    cleaned = "".join(
        c
        for c in value
        if c >= " "
        and c != "\x7f"
        and not ("\x80" <= c <= "\x9f")
        and c not in _BIDI_AND_ZW_CHARS
    )
    return cleaned[:max_len]


def _deny(reason: str) -> dict[str, Any]:
    """Return a hook denial response."""
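# Concrete behavior of the sanitizer above, for reference (input values
# chosen purely for illustration):
#
#     _sanitize("evil\nlog\x00entry\u202a\u200b")  ->  "evillogentry"
#     _sanitize("x" * 500)                         ->  "x" * 200  (truncated)
#
# Characters are dropped, not escaped: newline, NUL, the BiDi override
# U+202A and the zero-width space U+200B all disappear from the output.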
@@ -136,11 +168,13 @@ def create_security_hooks(
    - PostToolUse: Log successful tool executions
    - PostToolUseFailure: Log and handle failed tool executions
    - PreCompact: Log context compaction events (SDK handles compaction automatically)
+    - SubagentStart: Log sub-agent lifecycle start
+    - SubagentStop: Log sub-agent lifecycle end

    Args:
        user_id: Current user ID for isolation validation
        sdk_cwd: SDK working directory for workspace-scoped tool validation
-        max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
+        max_subtasks: Maximum concurrent sub-agent spawns allowed per session
        on_compact: Callback invoked when SDK starts compacting context.
            Receives the transcript_path from the hook input.

@@ -151,9 +185,19 @@ def create_security_hooks(
    from claude_agent_sdk import HookMatcher
    from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

-    # Per-session tracking for Task sub-agent concurrency.
+    # Per-session tracking for sub-agent concurrency.
    # Set of tool_use_ids that consumed a slot — len() is the active count.
-    task_tool_use_ids: set[str] = set()
+    #
+    # LIMITATION: For background (async) agents the SDK returns the
+    # Agent/Task tool immediately with {isAsync: true}, which triggers
+    # PostToolUse and releases the slot while the agent is still running.
+    # SubagentStop fires later when the background process finishes but
+    # does not currently hold a slot. This means the concurrency limit
+    # only gates *launches*, not true concurrent execution. To fix this
+    # we would need to track background agent_ids separately and release
+    # in SubagentStop, but the SDK does not guarantee SubagentStop fires
+    # for every background agent (e.g. on session abort).
+    subagent_tool_use_ids: set[str] = set()

    async def pre_tool_use_hook(
        input_data: HookInput,
@@ -165,29 +209,22 @@ def create_security_hooks(
        tool_name = cast(str, input_data.get("tool_name", ""))
        tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))

-        # Rate-limit Task (sub-agent) spawns per session
-        if tool_name == "Task":
-            # Block background task execution first — denied calls
-            # should not consume a subtask slot.
-            if tool_input.get("run_in_background"):
-                logger.info(f"[SDK] Blocked background Task, user={user_id}")
-                return cast(
-                    SyncHookJSONOutput,
-                    _deny(
-                        "Background task execution is not supported. "
-                        "Run tasks in the foreground instead "
-                        "(remove the run_in_background parameter)."
-                    ),
-                )
-            if len(task_tool_use_ids) >= max_subtasks:
+        # Rate-limit sub-agent spawns per session.
+        # The SDK CLI renamed "Task" → "Agent" in v2.x; handle both.
+        if tool_name in _SUBAGENT_TOOLS:
+            # Background agents are allowed — the SDK returns immediately
+            # with {isAsync: true} and the model polls via TaskOutput.
+            # Still count them against the concurrency limit.
+            if len(subagent_tool_use_ids) >= max_subtasks:
                logger.warning(
-                    f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
+                    f"[SDK] Sub-agent limit reached ({max_subtasks}), "
+                    f"user={user_id}"
                )
                return cast(
                    SyncHookJSONOutput,
                    _deny(
-                        f"Maximum {max_subtasks} concurrent sub-tasks. "
-                        "Wait for running sub-tasks to finish, "
+                        f"Maximum {max_subtasks} concurrent sub-agents. "
+                        "Wait for running sub-agents to finish, "
                        "or continue in the main conversation."
                    ),
                )
@@ -208,20 +245,20 @@ def create_security_hooks(
        if result:
            return cast(SyncHookJSONOutput, result)

-        # Reserve the Task slot only after all validations pass
-        if tool_name == "Task" and tool_use_id is not None:
-            task_tool_use_ids.add(tool_use_id)
+        # Reserve the sub-agent slot only after all validations pass
+        if tool_name in _SUBAGENT_TOOLS and tool_use_id is not None:
+            subagent_tool_use_ids.add(tool_use_id)

        logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
        return cast(SyncHookJSONOutput, {})

-    def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
-        """Release a Task concurrency slot if one was reserved."""
-        if tool_name == "Task" and tool_use_id in task_tool_use_ids:
-            task_tool_use_ids.discard(tool_use_id)
+    def _release_subagent_slot(tool_name: str, tool_use_id: str | None) -> None:
+        """Release a sub-agent concurrency slot if one was reserved."""
+        if tool_name in _SUBAGENT_TOOLS and tool_use_id in subagent_tool_use_ids:
+            subagent_tool_use_ids.discard(tool_use_id)
            logger.info(
-                "[SDK] Task slot released, active=%d/%d, user=%s",
-                len(task_tool_use_ids),
+                "[SDK] Sub-agent slot released, active=%d/%d, user=%s",
+                len(subagent_tool_use_ids),
                max_subtasks,
                user_id,
            )
@@ -241,13 +278,14 @@ def create_security_hooks(
        _ = context
        tool_name = cast(str, input_data.get("tool_name", ""))

-        _release_task_slot(tool_name, tool_use_id)
+        _release_subagent_slot(tool_name, tool_use_id)
        is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
+        safe_tool_use_id = _sanitize(str(tool_use_id or ""), max_len=12)
        logger.info(
            "[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
            tool_name,
            is_builtin,
-            (tool_use_id or "")[:12],
+            safe_tool_use_id,
        )

        # Stash output for SDK built-in tools so the response adapter can
@@ -256,7 +294,7 @@ def create_security_hooks(
        if is_builtin:
            tool_response = input_data.get("tool_response")
            if tool_response is not None:
-                resp_preview = str(tool_response)[:100]
+                resp_preview = _sanitize(str(tool_response), max_len=100)
                logger.info(
                    "[SDK] Stashing builtin output for %s (%d chars): %s...",
                    tool_name,
@@ -280,13 +318,17 @@ def create_security_hooks(
        """Log failed tool executions for debugging."""
        _ = context
        tool_name = cast(str, input_data.get("tool_name", ""))
-        error = input_data.get("error", "Unknown error")
+        error = _sanitize(str(input_data.get("error", "Unknown error")))
+        safe_tool_use_id = _sanitize(str(tool_use_id or ""))
        logger.warning(
-            f"[SDK] Tool failed: {tool_name}, error={error}, "
-            f"user={user_id}, tool_use_id={tool_use_id}"
+            "[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
+            tool_name,
+            error,
+            user_id,
+            safe_tool_use_id,
        )

-        _release_task_slot(tool_name, tool_use_id)
+        _release_subagent_slot(tool_name, tool_use_id)

        return cast(SyncHookJSONOutput, {})
@@ -301,16 +343,14 @@ def create_security_hooks(
        This hook provides visibility into when compaction happens.
        """
        _ = context, tool_use_id
-        trigger = input_data.get("trigger", "auto")
+        trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
+        # Sanitize untrusted input: strip control chars for logging AND
+        # for the value passed downstream. read_compacted_entries()
+        # validates against _projects_base() as defence-in-depth, but
+        # sanitizing here prevents log injection and rejects obviously
+        # malformed paths early.
-        transcript_path = (
-            str(input_data.get("transcript_path", ""))
-            .replace("\n", "")
-            .replace("\r", "")
+        transcript_path = _sanitize(
+            str(input_data.get("transcript_path", "")), max_len=500
        )
        logger.info(
            "[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
@@ -322,6 +362,44 @@ def create_security_hooks(
            on_compact(transcript_path)
        return cast(SyncHookJSONOutput, {})

    async def subagent_start_hook(
        input_data: HookInput,
        tool_use_id: str | None,
        context: HookContext,
    ) -> SyncHookJSONOutput:
        """Log when a sub-agent starts execution."""
        _ = context, tool_use_id
        agent_id = _sanitize(str(input_data.get("agent_id", "?")))
        agent_type = _sanitize(str(input_data.get("agent_type", "?")))
        logger.info(
            "[SDK] SubagentStart: agent_id=%s, type=%s, user=%s",
            agent_id,
            agent_type,
            user_id,
        )
        return cast(SyncHookJSONOutput, {})

    async def subagent_stop_hook(
        input_data: HookInput,
        tool_use_id: str | None,
        context: HookContext,
    ) -> SyncHookJSONOutput:
        """Log when a sub-agent stops."""
        _ = context, tool_use_id
        agent_id = _sanitize(str(input_data.get("agent_id", "?")))
        agent_type = _sanitize(str(input_data.get("agent_type", "?")))
        transcript = _sanitize(
            str(input_data.get("agent_transcript_path", "")), max_len=500
        )
        logger.info(
            "[SDK] SubagentStop: agent_id=%s, type=%s, user=%s, transcript=%s",
            agent_id,
            agent_type,
            user_id,
            transcript,
        )
        return cast(SyncHookJSONOutput, {})

    hooks: dict[str, Any] = {
        "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
        "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
@@ -329,6 +407,8 @@ def create_security_hooks(
            HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
        ],
        "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
+        "SubagentStart": [HookMatcher(matcher="*", hooks=[subagent_start_hook])],
+        "SubagentStop": [HookMatcher(matcher="*", hooks=[subagent_stop_hook])],
    }

    return hooks

|
||||
tool access, and dangerous input patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
@@ -136,8 +137,20 @@ def test_read_tool_results_allowed():
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
def test_read_tool_outputs_allowed():
|
||||
"""tool-outputs/ paths should be allowed, same as tool-results/."""
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-outputs/12345.txt"
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
|
||||
|
||||
def test_read_claude_projects_settings_json_denied():
|
||||
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/ is."""
|
||||
"""SDK-internal artifacts like settings.json are NOT accessible — only tool-results/tool-outputs is."""
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
@@ -233,16 +246,15 @@ def _hooks():
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_background_blocked(_hooks):
|
||||
"""Task with run_in_background=true must be denied."""
|
||||
async def test_task_background_allowed(_hooks):
|
||||
"""Task with run_in_background=true is allowed (SDK handles async lifecycle)."""
|
||||
pre, _, _ = _hooks
|
||||
result = await pre(
|
||||
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
|
||||
tool_use_id=None,
|
||||
tool_use_id="tu-bg-1",
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "foreground" in _reason(result).lower()
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@@ -356,3 +368,303 @@ async def test_task_slot_released_on_failure(_hooks):
        context={},
    )
    assert not _is_denied(result)


# ---------------------------------------------------------------------------
# "Agent" tool name (SDK v2.x+ renamed "Task" → "Agent")
# ---------------------------------------------------------------------------


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_background_allowed(_hooks):
    """Agent with run_in_background=true is allowed (SDK handles async lifecycle)."""
    pre, _, _ = _hooks
    result = await pre(
        {
            "tool_name": "Agent",
            "tool_input": {"run_in_background": True, "prompt": "x"},
        },
        tool_use_id="tu-agent-bg-1",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_foreground_allowed(_hooks):
    """Agent without run_in_background should be allowed."""
    pre, _, _ = _hooks
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "do stuff"}},
        tool_use_id="tu-agent-1",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_background_agent_counts_against_limit(_hooks):
    """Background agents still consume concurrency slots."""
    pre, _, _ = _hooks
    # Two background agents fill the limit
    for i in range(2):
        result = await pre(
            {
                "tool_name": "Agent",
                "tool_input": {"run_in_background": True, "prompt": "bg"},
            },
            tool_use_id=f"tu-bglimit-{i}",
            context={},
        )
        assert not _is_denied(result)
    # Third (background or foreground) should be denied
    result = await pre(
        {
            "tool_name": "Agent",
            "tool_input": {"run_in_background": True, "prompt": "over"},
        },
        tool_use_id="tu-bglimit-2",
        context={},
    )
    assert _is_denied(result)
    assert "Maximum" in _reason(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_limit_enforced(_hooks):
    """Agent spawns beyond max_subtasks should be denied."""
    pre, _, _ = _hooks
    # First two should pass
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-limit-{i}",
            context={},
        )
        assert not _is_denied(result)

    # Third should be denied (limit=2)
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over limit"}},
        tool_use_id="tu-agent-limit-2",
        context={},
    )
    assert _is_denied(result)
    assert "Maximum" in _reason(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_slot_released_on_completion(_hooks):
    """Completing an Agent should free a slot so new Agents can be spawned."""
    pre, post, _ = _hooks
    # Fill both slots
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-comp-{i}",
            context={},
        )
        assert not _is_denied(result)

    # Third should be denied — at capacity
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-agent-comp-2",
        context={},
    )
    assert _is_denied(result)

    # Complete first agent — frees a slot
    await post(
        {"tool_name": "Agent", "tool_input": {}},
        tool_use_id="tu-agent-comp-0",
        context={},
    )

    # Now a new Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after release"}},
        tool_use_id="tu-agent-comp-3",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_agent_slot_released_on_failure(_hooks):
    """A failed Agent should also free its concurrency slot."""
    pre, _, post_failure = _hooks
    # Fill both slots
    for i in range(2):
        result = await pre(
            {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
            tool_use_id=f"tu-agent-fail-{i}",
            context={},
        )
        assert not _is_denied(result)

    # At capacity
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-agent-fail-2",
        context={},
    )
    assert _is_denied(result)

    # Fail first agent — should free a slot
    await post_failure(
        {"tool_name": "Agent", "tool_input": {}, "error": "something broke"},
        tool_use_id="tu-agent-fail-0",
        context={},
    )

    # New Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after failure"}},
        tool_use_id="tu-agent-fail-3",
        context={},
    )
    assert not _is_denied(result)


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_mixed_task_agent_share_slots(_hooks):
    """Task and Agent share the same concurrency pool."""
    pre, post, _ = _hooks
    # Fill one slot with Task, one with Agent
    result = await pre(
        {"tool_name": "Task", "tool_input": {"prompt": "ok"}},
        tool_use_id="tu-mix-task",
        context={},
    )
    assert not _is_denied(result)

    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "ok"}},
        tool_use_id="tu-mix-agent",
        context={},
    )
    assert not _is_denied(result)

    # Third (either name) should be denied
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "over"}},
        tool_use_id="tu-mix-over",
        context={},
    )
    assert _is_denied(result)

    # Release the Task slot
    await post(
        {"tool_name": "Task", "tool_input": {}},
        tool_use_id="tu-mix-task",
        context={},
    )

    # Now an Agent should be allowed
    result = await pre(
        {"tool_name": "Agent", "tool_input": {"prompt": "after task release"}},
        tool_use_id="tu-mix-new",
        context={},
    )
    assert not _is_denied(result)


# ---------------------------------------------------------------------------
# SubagentStart / SubagentStop hooks
# ---------------------------------------------------------------------------


@pytest.fixture()
def _subagent_hooks():
    """Create hooks and return (subagent_start, subagent_stop) handlers."""
    hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
    start = hooks["SubagentStart"][0].hooks[0]
    stop = hooks["SubagentStop"][0].hooks[0]
    return start, stop


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_start_hook_returns_empty(_subagent_hooks):
    """SubagentStart hook should return an empty dict (logging only)."""
    start, _ = _subagent_hooks
    result = await start(
        {"agent_id": "sa-123", "agent_type": "research"},
        tool_use_id=None,
        context={},
    )
    assert result == {}


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_stop_hook_returns_empty(_subagent_hooks):
    """SubagentStop hook should return an empty dict (logging only)."""
    _, stop = _subagent_hooks
    result = await stop(
        {
            "agent_id": "sa-123",
            "agent_type": "research",
            "agent_transcript_path": "/tmp/transcript.txt",
        },
        tool_use_id=None,
        context={},
    )
    assert result == {}


@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_subagent_hooks_sanitize_inputs(_subagent_hooks, caplog):
    """SubagentStart/Stop should sanitize control chars from inputs."""
    start, stop = _subagent_hooks
    # Inject control characters (C0, DEL, C1, BiDi overrides, zero-width)
    # — hook should not raise AND logs must be clean
    with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
        result = await start(
            {
                "agent_id": "sa\n-injected\r\x00\x7f",
                "agent_type": "safe\x80_type\x9f\ttab",
            },
            tool_use_id=None,
            context={},
        )
    assert result == {}
    # Control chars must be stripped from the logged values
    for record in caplog.records:
        assert "\x00" not in record.message
        assert "\r" not in record.message
        assert "\n" not in record.message
        assert "\x7f" not in record.message
        assert "\x80" not in record.message
        assert "\x9f" not in record.message
    assert "safe_type" in caplog.text

    caplog.clear()
    with caplog.at_level(logging.DEBUG, logger="backend.copilot.sdk.security_hooks"):
        result = await stop(
            {
                "agent_id": "sa\n-injected\x7f",
                "agent_type": "type\r\x80\x9f",
                "agent_transcript_path": "/tmp/\x00malicious\npath\u202a\u200b",
            },
            tool_use_id=None,
            context={},
        )
    assert result == {}
    for record in caplog.records:
        assert "\x00" not in record.message
        assert "\r" not in record.message
        assert "\n" not in record.message
        assert "\x7f" not in record.message
        assert "\u202a" not in record.message
        assert "\u200b" not in record.message
    assert "/tmp/maliciouspath" in caplog.text

@@ -1310,10 +1310,16 @@ async def _run_stream_attempt(
                    # AssistantMessage.error (not as a Python exception).
                    # Re-raise so the outer retry loop can compact the
                    # transcript and retry with reduced context.
-                    # Only check error_text (the error field), not the
-                    # content preview — content may contain arbitrary text
-                    # that false-positives the pattern match.
-                    if _is_prompt_too_long(Exception(error_text)):
+                    # Check both error_text and error_preview: sdk_error
+                    # being set confirms this is an error message (not user
+                    # content), so checking content is safe. The actual
+                    # error description (e.g. "Prompt is too long") may be
+                    # in the content, not the error type field
+                    # (e.g. error="invalid_request", content="Prompt is
+                    # too long").
+                    if _is_prompt_too_long(Exception(error_text)) or _is_prompt_too_long(
+                        Exception(error_preview)
+                    ):
                        logger.warning(
                            "%s Prompt-too-long detected via AssistantMessage "
                            "error — raising for retry",
@@ -1414,13 +1420,16 @@ async def _run_stream_attempt(
                        ctx.log_prefix,
                        sdk_msg.result or "(no error message provided)",
                    )
-                    # If the CLI itself rejected the prompt as too long
-                    # (pre-API check, duration_api_ms=0), re-raise as an
-                    # exception so the retry loop can trigger compaction.
-                    # Without this, the ResultMessage is silently consumed
-                    # and the retry/compaction mechanism is never invoked.
-                    if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
-                        raise RuntimeError("Prompt is too long")

+                # Check for prompt-too-long regardless of subtype — the
+                # SDK may return subtype="success" with result="Prompt is
+                # too long" when the CLI rejects the prompt before calling
+                # the API (cost_usd=0, no tokens consumed). If we only
+                # check the "error" subtype path, the stream appears to
+                # complete normally, the synthetic error text is stored
+                # in the transcript, and the session grows without bound.
+                if _is_prompt_too_long(RuntimeError(sdk_msg.result or "")):
+                    raise RuntimeError("Prompt is too long")

                # Capture token usage from ResultMessage.
                # Anthropic reports cached tokens separately:
@@ -1453,6 +1462,23 @@ async def _run_stream_attempt(
                # Emit compaction end if SDK finished compacting.
                # Sync TranscriptBuilder with the CLI's active context.
                compact_result = await ctx.compaction.emit_end_if_ready(ctx.session)
                if compact_result.events:
                    # Compaction events end with StreamFinishStep, which maps to
                    # Vercel AI SDK's "finish-step" — that clears activeTextParts.
                    # Close any open text block BEFORE the compaction events so
                    # the text-end arrives before finish-step, preventing
                    # "text-end for missing text part" errors on the frontend.
                    pre_close: list[StreamBaseResponse] = []
                    state.adapter._end_text_if_open(pre_close)
                    # Compaction events bypass the adapter, so sync step state
                    # when a StreamFinishStep is present — otherwise the adapter
                    # will skip StreamStartStep on the next AssistantMessage.
                    if any(
                        isinstance(ev, StreamFinishStep) for ev in compact_result.events
                    ):
                        state.adapter.step_open = False
                    for r in pre_close:
                        yield r
                    for ev in compact_result.events:
                        yield ev
                    entries_replaced = False
@@ -1858,7 +1884,10 @@ async def stream_chat_completion_sdk(
    )

    # Fail fast when no API credentials are available at all.
-    sdk_env = build_sdk_env(session_id=session_id, user_id=user_id)
+    # sdk_cwd routes the CLI's temp dir into the per-session workspace
+    # so sub-agent output files land inside sdk_cwd (see build_sdk_env).
+    sdk_env = build_sdk_env(session_id=session_id, user_id=user_id, sdk_cwd=sdk_cwd)

    if not config.api_key and not config.use_claude_code_subscription:
        raise RuntimeError(
            "No API key configured. Set OPEN_ROUTER_API_KEY, "

|
||||
|
||||
from .service import (
|
||||
_is_sdk_disconnect_error,
|
||||
_normalize_model_name,
|
||||
_prepare_file_attachments,
|
||||
_resolve_sdk_model,
|
||||
_safe_close_sdk_client,
|
||||
@@ -405,6 +406,49 @@ def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestNormalizeModelName:
|
||||
"""Tests for _normalize_model_name — shared provider-aware normalization."""
|
||||
|
||||
def test_strips_provider_prefix(self, monkeypatch, _clean_config_env):
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6"
|
||||
|
||||
def test_dots_preserved_for_openrouter(self, monkeypatch, _clean_config_env):
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4.6"
|
||||
|
||||
def test_no_prefix_no_dots(self, monkeypatch, _clean_config_env):
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert (
|
||||
_normalize_model_name("claude-sonnet-4-20250514")
|
||||
== "claude-sonnet-4-20250514"
|
||||
)
|
||||
|
||||
|
||||
class TestResolveSdkModel:
|
||||
"""Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""
|
||||
|
||||
|
||||
@@ -439,7 +439,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -498,7 +498,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
)()
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
side_effect=mock_compression,
|
||||
):
|
||||
await compact_transcript(transcript, model="test-model")
|
||||
@@ -551,7 +551,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -601,7 +601,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -638,7 +638,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
@@ -699,7 +699,7 @@ class TestCompactTranscriptThinkingBlocks:
|
||||
},
|
||||
)()
|
||||
with patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
"backend.copilot.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
|
||||
@@ -38,7 +38,7 @@ from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate

-from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
+from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS, bridge_and_annotate

if TYPE_CHECKING:
    from e2b import AsyncSandbox
@@ -387,7 +387,16 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
            selected = list(itertools.islice(f, offset, offset + limit))
        # Cleanup happens in _cleanup_sdk_tool_results after session ends;
        # don't delete here — the SDK may read in multiple chunks.
-        return _mcp_ok("".join(selected))
+        #
+        # When E2B is active, also copy the file into the sandbox so
+        # bash_exec can process it (the model often uses Read then bash).
+        text = "".join(selected)
+        sandbox = _current_sandbox.get(None)
+        if sandbox is not None:
+            annotation = await bridge_and_annotate(sandbox, resolved, offset, limit)
+            if annotation:
+                text += annotation
+        return _mcp_ok(text)
    except FileNotFoundError:
        return _mcp_err(f"File not found: {file_path}")
    except Exception as e:
@@ -581,13 +590,14 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
# Security hooks validate that file paths stay within sdk_cwd.
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
# which provides kernel-level network isolation via unshare --net.
-# Task allows spawning sub-agents (rate-limited by security hooks).
+# Task/Agent allows spawning sub-agents (rate-limited by security hooks).
+# The CLI renamed "Task" → "Agent" in v2.x; both are listed for compat.
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
# access. read_file also handles local tool-results and ephemeral reads.
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
-_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
+_SDK_BUILTIN_ALWAYS = ["Task", "Agent", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]

# SDK built-in tools that must be explicitly blocked.

@@ -619,3 +619,95 @@ class TestSDKDisallowedTools:
|
||||
def test_webfetch_tool_is_disallowed(self):
|
||||
"""WebFetch is disallowed due to SSRF risk."""
|
||||
assert "WebFetch" in SDK_DISALLOWED_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — bridge_and_annotate integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadFileHandlerBridge:
|
||||
"""Verify that _read_file_handler calls bridge_and_annotate when a sandbox is active."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_context(self):
|
||||
set_execution_context(
|
||||
user_id="test",
|
||||
session=None, # type: ignore[arg-type]
|
||||
sandbox=None,
|
||||
sdk_cwd="/tmp/copilot-bridge-test",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bridge_called_when_sandbox_active(self, tmp_path, monkeypatch):
|
||||
"""When a sandbox is set, bridge_and_annotate is called and its annotation appended."""
|
||||
from backend.copilot.context import _current_sandbox
|
||||
|
||||
from .tool_adapter import _read_file_handler
|
||||
|
||||
test_file = tmp_path / "tool-results" / "data.json"
|
||||
test_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
test_file.write_text('{"ok": true}\n')
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.tool_adapter.is_allowed_local_path",
|
||||
lambda path, cwd: True,
|
||||
)
|
||||
|
||||
fake_sandbox = object()
|
||||
token = _current_sandbox.set(fake_sandbox) # type: ignore[arg-type]
|
||||
try:
|
||||
bridge_calls: list[tuple] = []
|
||||
|
||||
async def fake_bridge_and_annotate(sandbox, file_path, offset, limit):
|
||||
bridge_calls.append((sandbox, file_path, offset, limit))
|
||||
return "\n[Sandbox copy available at /tmp/abc-data.json]"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.tool_adapter.bridge_and_annotate",
|
||||
fake_bridge_and_annotate,
|
||||
)
|
||||
|
||||
result = await _read_file_handler(
|
||||
{"file_path": str(test_file), "offset": 0, "limit": 2000}
|
||||
)
|
||||
|
||||
assert result["isError"] is False
|
||||
assert len(bridge_calls) == 1
|
||||
assert bridge_calls[0][0] is fake_sandbox
|
||||
assert "/tmp/abc-data.json" in result["content"][0]["text"]
|
||||
finally:
|
||||
_current_sandbox.reset(token)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bridge_not_called_without_sandbox(self, tmp_path, monkeypatch):
|
||||
"""When no sandbox is set, bridge_and_annotate is not called."""
|
||||
from .tool_adapter import _read_file_handler
|
||||
|
||||
test_file = tmp_path / "tool-results" / "data.json"
|
||||
test_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
test_file.write_text('{"ok": true}\n')
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.tool_adapter.is_allowed_local_path",
|
||||
lambda path, cwd: True,
|
||||
)
|
||||
|
||||
bridge_calls: list[tuple] = []
|
||||
|
||||
async def fake_bridge_and_annotate(sandbox, file_path, offset, limit):
|
||||
bridge_calls.append((sandbox, file_path, offset, limit))
|
||||
return "\n[Sandbox copy available at /tmp/abc-data.json]"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.tool_adapter.bridge_and_annotate",
|
||||
fake_bridge_and_annotate,
|
||||
)
|
||||
|
||||
result = await _read_file_handler(
|
||||
{"file_path": str(test_file), "offset": 0, "limit": 2000}
|
||||
)
|
||||
|
||||
assert result["isError"] is False
|
||||
assert len(bridge_calls) == 0
|
||||
assert "Sandbox copy" not in result["content"][0]["text"]
|
||||
|
||||
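For orientation, here is what the bridge helper exercised by these tests plausibly does. This is a minimal sketch, assuming e2b's `AsyncSandbox.files.write` API; the destination path and error handling are illustrative, and only the `(sandbox, file_path, offset, limit)` signature is taken from the tests above.

```python
import os


async def bridge_and_annotate_sketch(sandbox, file_path: str, offset: int, limit: int) -> str | None:
    """Best-effort copy of a locally-read file into the E2B sandbox (sketch)."""
    try:
        with open(file_path, "r") as f:
            data = f.read()  # the real helper presumably honours offset/limit
        dest = f"/tmp/{os.path.basename(file_path)}"  # hypothetical naming scheme
        await sandbox.files.write(dest, data)  # assumed e2b AsyncSandbox file API
        # Appended to the Read result so the model knows bash_exec can
        # operate on a local copy instead of re-reading via MCP.
        return f"\n[Sandbox copy available at {dest}]"
    except Exception:
        return None  # bridging is best-effort and must never fail the Read
```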
File diff suppressed because it is too large
@@ -1,235 +1,10 @@
"""Build complete JSONL transcript from SDK messages.
"""Re-export from shared ``backend.copilot.transcript_builder`` for backward compat.

The transcript represents the FULL active context at any point in time.
Each upload REPLACES the previous transcript atomically.

Flow:
    Turn 1: Upload [msg1, msg2]
    Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
    Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)

The transcript is never incremental - always the complete atomic state.
The canonical implementation now lives at ``backend.copilot.transcript_builder``
so both the SDK and baseline paths can import without cross-package
dependencies.
"""

import logging
from typing import Any
from uuid import uuid4
from backend.copilot.transcript_builder import TranscriptBuilder, TranscriptEntry

from pydantic import BaseModel

from backend.util import json

from .transcript import STRIPPABLE_TYPES

logger = logging.getLogger(__name__)


class TranscriptEntry(BaseModel):
    """Single transcript entry (user or assistant turn)."""

    type: str
    uuid: str
    parentUuid: str | None
    isCompactSummary: bool | None = None
    message: dict[str, Any]


class TranscriptBuilder:
    """Build complete JSONL transcript from SDK messages.

    This builder maintains the FULL conversation state, not incremental changes.
    The output is always the complete active context.
    """

    def __init__(self) -> None:
        self._entries: list[TranscriptEntry] = []
        self._last_uuid: str | None = None

    def _last_is_assistant(self) -> bool:
        return bool(self._entries) and self._entries[-1].type == "assistant"

    def _last_message_id(self) -> str:
        """Return the message.id of the last entry, or '' if none."""
        if self._entries:
            return self._entries[-1].message.get("id", "")
        return ""

    @staticmethod
    def _parse_entry(data: dict) -> TranscriptEntry | None:
        """Parse a single transcript entry, filtering strippable types.

        Returns ``None`` for entries that should be skipped (strippable types
        that are not compaction summaries).
        """
        entry_type = data.get("type", "")
        if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
            return None
        return TranscriptEntry(
            type=entry_type,
            uuid=data.get("uuid") or str(uuid4()),
            parentUuid=data.get("parentUuid"),
            isCompactSummary=data.get("isCompactSummary"),
            message=data.get("message", {}),
        )

    def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
        """Load complete previous transcript.

        This loads the FULL previous context. As new messages come in,
        we append to this state. The final output is the complete context
        (previous + new), not just the delta.
        """
        if not content or not content.strip():
            return

        lines = content.strip().split("\n")
        for line_num, line in enumerate(lines, 1):
            if not line.strip():
                continue

            data = json.loads(line, fallback=None)
            if data is None:
                logger.warning(
                    "%s Failed to parse transcript line %d/%d",
                    log_prefix,
                    line_num,
                    len(lines),
                )
                continue

            entry = self._parse_entry(data)
            if entry is None:
                continue
            self._entries.append(entry)
            self._last_uuid = entry.uuid

        logger.info(
            "%s Loaded %d entries from previous transcript (last_uuid=%s)",
            log_prefix,
            len(self._entries),
            self._last_uuid[:12] if self._last_uuid else None,
        )

    def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
        """Append a user entry."""
        msg_uuid = uuid or str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="user",
                uuid=msg_uuid,
                parentUuid=self._last_uuid,
                message={"role": "user", "content": content},
            )
        )
        self._last_uuid = msg_uuid

    def append_tool_result(self, tool_use_id: str, content: str) -> None:
        """Append a tool result as a user entry (one per tool call)."""
        self.append_user(
            content=[
                {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
            ]
        )

    def append_assistant(
        self,
        content_blocks: list[dict],
        model: str = "",
        stop_reason: str | None = None,
    ) -> None:
        """Append an assistant entry.

        Consecutive assistant entries automatically share the same message ID
        so the CLI can merge them (thinking → text → tool_use) into a single
        API message on ``--resume``. A new ID is assigned whenever an
        assistant entry follows a non-assistant entry (user message or tool
        result), because that marks the start of a new API response.
        """
        message_id = (
            self._last_message_id()
            if self._last_is_assistant()
            else f"msg_sdk_{uuid4().hex[:24]}"
        )

        msg_uuid = str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="assistant",
                uuid=msg_uuid,
                parentUuid=self._last_uuid,
                message={
                    "role": "assistant",
                    "model": model,
                    "id": message_id,
                    "type": "message",
                    "content": content_blocks,
                    "stop_reason": stop_reason,
                    "stop_sequence": None,
                },
            )
        )
        self._last_uuid = msg_uuid

    def replace_entries(
        self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
    ) -> None:
        """Replace all entries with compacted entries from the CLI session file.

        Called after mid-stream compaction so TranscriptBuilder mirrors the
        CLI's active context (compaction summary + post-compaction entries).

        Builds the new list first and validates it's non-empty before swapping,
        so corrupt input cannot wipe the conversation history.
        """
        new_entries: list[TranscriptEntry] = []
        for data in compacted_entries:
            entry = self._parse_entry(data)
            if entry is not None:
                new_entries.append(entry)

        if not new_entries:
            logger.warning(
                "%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
                log_prefix,
                len(compacted_entries),
                len(self._entries),
            )
            return

        old_count = len(self._entries)
        self._entries = new_entries
        self._last_uuid = new_entries[-1].uuid

        logger.info(
            "%s TranscriptBuilder compacted: %d entries -> %d entries",
            log_prefix,
            old_count,
            len(self._entries),
        )

    def to_jsonl(self) -> str:
        """Export complete context as JSONL.

        Consecutive assistant entries are kept separate to match the
        native CLI format — the SDK merges them internally on resume.

        Returns the FULL conversation state (all entries), not incremental.
        This output REPLACES any previous transcript.
        """
        if not self._entries:
            return ""

        lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
        return "\n".join(lines) + "\n"

    @property
    def entry_count(self) -> int:
        """Total number of entries in the complete context."""
        return len(self._entries)

    @property
    def is_empty(self) -> bool:
        """Whether this builder has any entries."""
        return len(self._entries) == 0
__all__ = ["TranscriptBuilder", "TranscriptEntry"]

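The replace-not-append contract of the builder is easiest to see in a short usage sketch against the class shown above. The message shapes and model name are illustrative; the import path follows the re-export note in the new docstring.

```python
from backend.copilot.transcript_builder import TranscriptBuilder

builder = TranscriptBuilder()
builder.load_previous("")  # FULL previous transcript JSONL, if any

builder.append_user("Summarize my agents")
builder.append_assistant(
    content_blocks=[{"type": "text", "text": "Here is a summary..."}],
    model="claude-sonnet",  # illustrative model name
    stop_reason="end_turn",
)

# to_jsonl() always emits the COMPLETE state; uploading it REPLACES the
# previous transcript rather than appending a delta.
full_jsonl = builder.to_jsonl()
assert builder.entry_count == 2
```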
@@ -303,7 +303,7 @@ class TestDeleteTranscript:
        mock_storage.delete = AsyncMock()

        with patch(
            "backend.copilot.sdk.transcript.get_workspace_storage",
            "backend.copilot.transcript.get_workspace_storage",
            new_callable=AsyncMock,
            return_value=mock_storage,
        ):
@@ -323,7 +323,7 @@ class TestDeleteTranscript:
        )

        with patch(
            "backend.copilot.sdk.transcript.get_workspace_storage",
            "backend.copilot.transcript.get_workspace_storage",
            new_callable=AsyncMock,
            return_value=mock_storage,
        ):
@@ -341,7 +341,7 @@ class TestDeleteTranscript:
        )

        with patch(
            "backend.copilot.sdk.transcript.get_workspace_storage",
            "backend.copilot.transcript.get_workspace_storage",
            new_callable=AsyncMock,
            return_value=mock_storage,
        ):
@@ -858,11 +858,11 @@ class TestRunCompression:

        with (
            patch(
                "backend.copilot.sdk.transcript.get_openai_client",
                "backend.copilot.transcript.get_openai_client",
                return_value=None,
            ),
            patch(
                "backend.copilot.sdk.transcript.compress_context",
                "backend.copilot.transcript.compress_context",
                new_callable=AsyncMock,
                return_value=truncation_result,
            ) as mock_compress,
@@ -894,11 +894,11 @@ class TestRunCompression:

        with (
            patch(
                "backend.copilot.sdk.transcript.get_openai_client",
                "backend.copilot.transcript.get_openai_client",
                return_value=mock_client,
            ),
            patch(
                "backend.copilot.sdk.transcript.compress_context",
                "backend.copilot.transcript.compress_context",
                new_callable=AsyncMock,
                return_value=llm_result,
            ) as mock_compress,
@@ -932,11 +932,11 @@ class TestRunCompression:

        with (
            patch(
                "backend.copilot.sdk.transcript.get_openai_client",
                "backend.copilot.transcript.get_openai_client",
                return_value=mock_client,
            ),
            patch(
                "backend.copilot.sdk.transcript.compress_context",
                "backend.copilot.transcript.compress_context",
                side_effect=_compress_side_effect,
            ),
        ):
@@ -970,19 +970,19 @@ class TestRunCompression:
        fake_client = MagicMock()
        with (
            patch(
                "backend.copilot.sdk.transcript.get_openai_client",
                "backend.copilot.transcript.get_openai_client",
                return_value=fake_client,
            ),
            patch(
                "backend.copilot.sdk.transcript.compress_context",
                "backend.copilot.transcript.compress_context",
                side_effect=_compress_side_effect,
            ),
            patch(
                "backend.copilot.sdk.transcript._COMPACTION_TIMEOUT_SECONDS",
                "backend.copilot.transcript._COMPACTION_TIMEOUT_SECONDS",
                0.05,
            ),
            patch(
                "backend.copilot.sdk.transcript._TRUNCATION_TIMEOUT_SECONDS",
                "backend.copilot.transcript._TRUNCATION_TIMEOUT_SECONDS",
                5,
            ),
        ):
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs:

        nonexistent = str(tmp_path / "does-not-exist" / "projects")
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: nonexistent,
        )

@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
        projects_dir = tmp_path / "projects"
        projects_dir.mkdir()
        monkeypatch.setattr(
            "backend.copilot.sdk.transcript._projects_base",
            "backend.copilot.transcript._projects_base",
            lambda: str(projects_dir),
        )

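Every hunk above is the same mechanical rename, and the reason is worth stating once: `unittest.mock.patch` replaces a name on the module where it is looked up, so when the canonical implementation moved from `backend.copilot.sdk.transcript` to `backend.copilot.transcript`, each patch target string had to follow. A runnable toy (module names invented) showing why patching the old definition site would no longer work:

```python
import sys
import types
from unittest.mock import patch

# Build two in-memory modules: impl defines get_client, consumer imports it.
impl = types.ModuleType("impl")
impl.get_client = lambda: "real"
sys.modules["impl"] = impl

consumer = types.ModuleType("consumer")
sys.modules["consumer"] = consumer
exec("from impl import get_client\ndef use():\n    return get_client()", consumer.__dict__)

with patch("consumer.get_client", return_value="fake"):
    assert consumer.use() == "fake"  # patched where the name is used

with patch("impl.get_client", return_value="fake"):
    assert consumer.use() == "real"  # definition site: consumer is unaffected
```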
@@ -7,7 +7,7 @@ import pytest
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import StreamError, StreamTextDelta
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
from .transcript import download_transcript

logger = logging.getLogger(__name__)


@@ -4,12 +4,15 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
1. Append a ``Usage`` record to the session.
2. Log the turn's token counts.
3. Record weighted usage in Redis for rate-limiting.
4. Write a PlatformCostLog entry for admin cost tracking.

This module extracts that common logic so both paths stay in sync.
"""

import logging

from backend.data.platform_cost import PlatformCostEntry, log_platform_cost_safe

from .model import ChatSession, Usage
from .rate_limit import record_token_usage

@@ -95,4 +98,47 @@ async def persist_and_record_usage(
    except Exception as usage_err:
        logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")

    # Log to PlatformCostLog for admin cost dashboard
    if user_id and total_tokens > 0:
        cost_float = None
        if cost_usd is not None:
            try:
                cost_float = float(cost_usd)
            except (ValueError, TypeError):
                pass

        cost_microdollars = (
            int(cost_float * 1_000_000) if cost_float is not None else None
        )
        session_id = session.session_id if session else None

        if cost_float is not None:
            tracking_type = "cost_usd"
            tracking_amount = cost_float
        else:
            tracking_type = "tokens"
            tracking_amount = total_tokens

        await log_platform_cost_safe(
            PlatformCostEntry(
                user_id=user_id,
                graph_exec_id=session_id,
                block_id="copilot",
                block_name=f"copilot:{log_prefix.strip(' []')}".rstrip(":"),
                provider="open_router",
                credential_id="copilot_system",
                cost_microdollars=cost_microdollars,
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
                model=None,
                metadata={
                    "tracking_type": tracking_type,
                    "tracking_amount": tracking_amount,
                    "cache_read_tokens": cache_read_tokens,
                    "cache_creation_tokens": cache_creation_tokens,
                    "source": "copilot",
                },
            )
        )

    return total_tokens

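The only lossy step above is the cost conversion: `cost_usd` may arrive as a string or Decimal, is coerced to float, and is stored as integer microdollars, with the raw token count as the fallback tracking unit when no cost is reported. The same arithmetic in isolation:

```python
def to_microdollars(cost_usd) -> int | None:
    """Coerce a cost (str/Decimal/float) to integer microdollars, as above."""
    try:
        cost = float(cost_usd)
    except (ValueError, TypeError):
        return None
    return int(cost * 1_000_000)


assert to_microdollars("0.0123") == 12300
assert to_microdollars(None) is None
# When the cost is unavailable, the code above falls back to
# tracking_type="tokens" with the raw total_tokens as the amount.
```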
@@ -10,6 +10,7 @@ from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .ask_question import AskQuestionTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
@@ -55,6 +56,7 @@ logger = logging.getLogger(__name__)
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
    "add_understanding": AddUnderstandingTool(),
    "ask_question": AskQuestionTool(),
    "create_agent": CreateAgentTool(),
    "customize_agent": CustomizeAgentTool(),
    "edit_agent": EditAgentTool(),

@@ -34,11 +34,16 @@ _GMAIL_SEND_BLOCK_ID = "6c27abc2-e51d-499e-a85f-5a0041ba94f0"
_TEXT_REPLACE_BLOCK_ID = "7e7c87ab-3469-4bcc-9abe-67705091b713"

# Defaults applied to OrchestratorBlock nodes by the fixer.
_SDM_DEFAULTS: dict[str, int | bool] = {
# execution_mode and model match the copilot's default (extended thinking
# with Opus) so generated agents inherit the same reasoning capabilities.
# If the user explicitly sets these fields, the fixer won't override them.
_SDM_DEFAULTS: dict[str, int | bool | str] = {
    "agent_mode_max_iterations": 10,
    "conversation_compaction": True,
    "retry": 3,
    "multiple_tool_calls": False,
    "execution_mode": "extended_thinking",
    "model": "claude-opus-4-6",
}


@@ -1649,6 +1654,8 @@ class AgentFixer:
        2. ``conversation_compaction`` defaults to ``True``
        3. ``retry`` defaults to ``3``
        4. ``multiple_tool_calls`` defaults to ``False``
        5. ``execution_mode`` defaults to ``"extended_thinking"``
        6. ``model`` defaults to ``"claude-opus-4-6"``

        Args:
            agent: The agent dictionary to fix

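These defaults are fill-ins, not overrides: per the comment above, a field the user already set on the node is left untouched. A minimal sketch of that merge, assuming the node inputs live in a plain dict (the real `AgentFixer` method and node shape may differ):

```python
_SDM_DEFAULTS = {
    "agent_mode_max_iterations": 10,
    "conversation_compaction": True,
    "retry": 3,
    "multiple_tool_calls": False,
    "execution_mode": "extended_thinking",
    "model": "claude-opus-4-6",
}


def apply_sdm_defaults(node_input: dict) -> dict:
    # setdefault only fills missing keys, so explicit user values win.
    for key, value in _SDM_DEFAULTS.items():
        node_input.setdefault(key, value)
    return node_input


assert apply_sdm_defaults({"retry": 5})["retry"] == 5  # user value kept
assert apply_sdm_defaults({})["execution_mode"] == "extended_thinking"
```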
@@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING, Literal

@@ -9,7 +10,7 @@ if TYPE_CHECKING:
    from backend.api.features.library.model import LibraryAgent
    from backend.api.features.store.model import StoreAgent, StoreAgentDetails

from backend.data.db_accessors import library_db, store_db
from backend.data.db_accessors import graph_db, library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError

from .models import (
@@ -34,12 +35,13 @@ async def search_agents(
    source: SearchSource,
    session_id: str | None = None,
    user_id: str | None = None,
    include_graph: bool = False,
) -> ToolResponseBase:
    """Search for agents in marketplace or user library."""
    if source == "marketplace":
        return await _search_marketplace(query, session_id)
    else:
        return await _search_library(query, session_id, user_id)
        return await _search_library(query, session_id, user_id, include_graph)


async def _search_marketplace(query: str, session_id: str | None) -> ToolResponseBase:
@@ -105,7 +107,10 @@ async def _search_marketplace(query: str, session_id: str | None) -> ToolRespons


async def _search_library(
    query: str, session_id: str | None, user_id: str | None
    query: str,
    session_id: str | None,
    user_id: str | None,
    include_graph: bool = False,
) -> ToolResponseBase:
    """Search user's library agents, with direct UUID lookup fallback."""
    if not user_id:
@@ -149,6 +154,10 @@ async def _search_library(
        session_id=session_id,
    )

    truncation_notice: str | None = None
    if include_graph and agents:
        truncation_notice = await _enrich_agents_with_graph(agents, user_id)

    if not agents:
        if not query:
            return NoResultsResponse(
@@ -182,13 +191,17 @@ async def _search_library(
    else:
        title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"

    message = (
        "Found agents in the user's library. You can provide a link to view "
        "an agent at: /library/agents/{agent_id}. Use agent_output to get "
        "execution results, or run_agent to execute. Let the user know we can "
        "create a custom agent for them based on their needs."
    )
    if truncation_notice:
        message = f"{message}\n\nNote: {truncation_notice}"

    return AgentsFoundResponse(
        message=(
            "Found agents in the user's library. You can provide a link to view "
            "an agent at: /library/agents/{agent_id}. Use agent_output to get "
            "execution results, or run_agent to execute. Let the user know we can "
            "create a custom agent for them based on their needs."
        ),
        message=message,
        title=title,
        agents=agents,
        count=len(agents),
@@ -196,6 +209,81 @@ async def _search_library(
    )


_MAX_GRAPH_FETCHES = 10


_GRAPH_FETCH_TIMEOUT = 15  # seconds


async def _enrich_agents_with_graph(
    agents: list[AgentInfo], user_id: str
) -> str | None:
    """Fetch and attach full Graph (nodes + links) to each agent in-place.

    Only the first ``_MAX_GRAPH_FETCHES`` agents with a ``graph_id`` are
    enriched. If some agents are skipped, a truncation notice is returned
    so the caller can surface it to the copilot.

    Graphs are fetched with ``for_export=True`` so that credentials, API keys,
    and other secrets in ``input_default`` are stripped before the data reaches
    the LLM context.

    Returns a truncation notice string when some agents were skipped, or
    ``None`` when all eligible agents were enriched.
    """
    with_graph_id = [a for a in agents if a.graph_id]
    fetchable = with_graph_id[:_MAX_GRAPH_FETCHES]
    if not fetchable:
        return None

    gdb = graph_db()

    async def _fetch(agent: AgentInfo) -> None:
        graph_id = agent.graph_id
        if not graph_id:
            return
        try:
            graph = await gdb.get_graph(
                graph_id,
                version=agent.graph_version,
                user_id=user_id,
                for_export=True,
            )
            if graph is None:
                logger.warning("Graph not found for agent %s", graph_id)
            agent.graph = graph
        except Exception as e:
            logger.warning("Failed to fetch graph for agent %s: %s", graph_id, e)

    try:
        await asyncio.wait_for(
            asyncio.gather(*[_fetch(a) for a in fetchable]),
            timeout=_GRAPH_FETCH_TIMEOUT,
        )
    except asyncio.TimeoutError:
        logger.warning(
            "include_graph: timed out after %ds fetching graphs", _GRAPH_FETCH_TIMEOUT
        )

    skipped = len(with_graph_id) - len(fetchable)
    if skipped > 0:
        logger.warning(
            "include_graph: fetched graphs for %d/%d agents "
            "(_MAX_GRAPH_FETCHES=%d, %d skipped)",
            len(fetchable),
            len(with_graph_id),
            _MAX_GRAPH_FETCHES,
            skipped,
        )
        return (
            f"Graph data included for {len(fetchable)} of "
            f"{len(with_graph_id)} eligible agents (limit: {_MAX_GRAPH_FETCHES}). "
            f"To fetch graphs for remaining agents, narrow your search to a "
            f"specific agent by UUID."
        )
    return None


def _marketplace_agent_to_info(agent: StoreAgent | StoreAgentDetails) -> AgentInfo:
    """Convert a marketplace agent (StoreAgent or StoreAgentDetails) to an AgentInfo."""
    return AgentInfo(

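One subtlety in `_enrich_agents_with_graph` is worth spelling out: `asyncio.wait_for` cancels the gathered fetches at the deadline, but because `_fetch` assigns `agent.graph` in place as each fetch completes, results that landed before the timeout survive the cancellation. The tests that follow verify exactly this; a runnable toy of the pattern:

```python
import asyncio


async def main() -> None:
    results: dict[str, str] = {}

    async def fetch(key: str, delay: float) -> None:
        await asyncio.sleep(delay)
        results[key] = "graph"  # mutate shared state as each task finishes

    try:
        await asyncio.wait_for(
            asyncio.gather(fetch("fast", 0.01), fetch("slow", 10)),
            timeout=0.1,
        )
    except asyncio.TimeoutError:
        pass  # slow task cancelled, but the fast result already landed

    assert results == {"fast": "graph"}


asyncio.run(main())
```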
@@ -1,11 +1,12 @@
"""Tests for agent search direct lookup functionality."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from .agent_search import search_agents
from .models import AgentsFoundResponse, NoResultsResponse
from .agent_search import _enrich_agents_with_graph, search_agents
from .models import AgentInfo, AgentsFoundResponse, NoResultsResponse

_TEST_USER_ID = "test-user-agent-search"

@@ -133,10 +134,10 @@ class TestMarketplaceSlugLookup:
class TestLibraryUUIDLookup:
    """Tests for UUID direct lookup in library search."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_uuid_lookup_found_by_graph_id(self):
        """UUID query matching a graph_id returns the agent directly."""
        agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
    @staticmethod
    def _make_mock_library_agent(
        agent_id: str = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
    ) -> MagicMock:
        mock_agent = MagicMock()
        mock_agent.id = "lib-agent-id"
        mock_agent.name = "My Library Agent"
@@ -150,6 +151,13 @@ class TestLibraryUUIDLookup:
        mock_agent.graph_version = 1
        mock_agent.input_schema = {}
        mock_agent.output_schema = {}
        return mock_agent

    @pytest.mark.asyncio(loop_scope="session")
    async def test_uuid_lookup_found_by_graph_id(self):
        """UUID query matching a graph_id returns the agent directly."""
        agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
        mock_agent = self._make_mock_library_agent(agent_id)

        mock_lib_db = MagicMock()
        mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
@@ -168,3 +176,427 @@ class TestLibraryUUIDLookup:
        assert isinstance(response, AgentsFoundResponse)
        assert response.count == 1
        assert response.agents[0].name == "My Library Agent"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_include_graph_fetches_graph(self):
        """include_graph=True attaches BaseGraph to agent results."""
        from backend.data.graph import BaseGraph

        agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
        mock_agent = self._make_mock_library_agent(agent_id)
        mock_lib_db = MagicMock()
        mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)

        fake_graph = BaseGraph(id=agent_id, name="My Library Agent", description="test")
        mock_graph_db = MagicMock()
        mock_graph_db.get_graph = AsyncMock(return_value=fake_graph)

        with (
            patch(
                "backend.copilot.tools.agent_search.library_db",
                return_value=mock_lib_db,
            ),
            patch(
                "backend.copilot.tools.agent_search.graph_db",
                return_value=mock_graph_db,
            ),
        ):
            response = await search_agents(
                query=agent_id,
                source="library",
                session_id="s",
                user_id=_TEST_USER_ID,
                include_graph=True,
            )

        assert isinstance(response, AgentsFoundResponse)
        assert response.agents[0].graph is not None
        assert response.agents[0].graph.id == agent_id
        mock_graph_db.get_graph.assert_awaited_once_with(
            agent_id,
            version=1,
            user_id=_TEST_USER_ID,
            for_export=True,
        )

    @pytest.mark.asyncio(loop_scope="session")
    async def test_include_graph_false_skips_fetch(self):
        """include_graph=False (default) does not fetch graph data."""
        agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
        mock_agent = self._make_mock_library_agent(agent_id)
        mock_lib_db = MagicMock()
        mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)

        mock_graph_db = MagicMock()
        mock_graph_db.get_graph = AsyncMock()

        with (
            patch(
                "backend.copilot.tools.agent_search.library_db",
                return_value=mock_lib_db,
            ),
            patch(
                "backend.copilot.tools.agent_search.graph_db",
                return_value=mock_graph_db,
            ),
        ):
            response = await search_agents(
                query=agent_id,
                source="library",
                session_id="s",
                user_id=_TEST_USER_ID,
                include_graph=False,
            )

        assert isinstance(response, AgentsFoundResponse)
        assert response.agents[0].graph is None
        mock_graph_db.get_graph.assert_not_awaited()

    @pytest.mark.asyncio(loop_scope="session")
    async def test_include_graph_handles_fetch_failure(self):
        """include_graph=True still returns agents when graph fetch fails."""
        agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
        mock_agent = self._make_mock_library_agent(agent_id)
        mock_lib_db = MagicMock()
        mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)

        mock_graph_db = MagicMock()
        mock_graph_db.get_graph = AsyncMock(side_effect=Exception("DB down"))

        with (
            patch(
                "backend.copilot.tools.agent_search.library_db",
                return_value=mock_lib_db,
            ),
            patch(
                "backend.copilot.tools.agent_search.graph_db",
                return_value=mock_graph_db,
            ),
        ):
            response = await search_agents(
                query=agent_id,
                source="library",
                session_id="s",
                user_id=_TEST_USER_ID,
                include_graph=True,
            )

        assert isinstance(response, AgentsFoundResponse)
        assert response.agents[0].graph is None

    @pytest.mark.asyncio(loop_scope="session")
    async def test_include_graph_handles_none_return(self):
        """include_graph=True handles get_graph returning None."""
        agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
        mock_agent = self._make_mock_library_agent(agent_id)
        mock_lib_db = MagicMock()
        mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)

        mock_graph_db = MagicMock()
        mock_graph_db.get_graph = AsyncMock(return_value=None)

        with (
            patch(
                "backend.copilot.tools.agent_search.library_db",
                return_value=mock_lib_db,
            ),
            patch(
                "backend.copilot.tools.agent_search.graph_db",
                return_value=mock_graph_db,
            ),
        ):
            response = await search_agents(
                query=agent_id,
                source="library",
                session_id="s",
                user_id=_TEST_USER_ID,
                include_graph=True,
            )

        assert isinstance(response, AgentsFoundResponse)
        assert response.agents[0].graph is None


class TestEnrichAgentsWithGraph:
    """Tests for _enrich_agents_with_graph edge cases."""

    @staticmethod
    def _make_mock_library_agent(
        agent_id: str = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
        graph_id: str | None = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
    ) -> MagicMock:
        mock_agent = MagicMock()
        mock_agent.id = f"lib-{agent_id[:8]}"
        mock_agent.name = f"Agent {agent_id[:8]}"
        mock_agent.description = "A library agent"
        mock_agent.creator_name = "testuser"
        mock_agent.status.value = "HEALTHY"
        mock_agent.can_access_graph = True
        mock_agent.has_external_trigger = False
        mock_agent.new_output = False
        mock_agent.graph_id = graph_id
        mock_agent.graph_version = 1
        mock_agent.input_schema = {}
        mock_agent.output_schema = {}
        return mock_agent

    @pytest.mark.asyncio(loop_scope="session")
    async def test_truncation_surfaces_in_response(self):
        """When >_MAX_GRAPH_FETCHES agents have graphs, the response contains a truncation notice."""
        from backend.copilot.tools.agent_search import _MAX_GRAPH_FETCHES
        from backend.data.graph import BaseGraph

        agent_count = _MAX_GRAPH_FETCHES + 5
        mock_agents = []
        for i in range(agent_count):
            uid = f"a1b2c3d4-e5f6-4a7b-8c9d-{i:012d}"
            mock_agents.append(self._make_mock_library_agent(uid, uid))

        mock_lib_db = MagicMock()
        mock_search_results = MagicMock()
        mock_search_results.agents = mock_agents
        mock_lib_db.list_library_agents = AsyncMock(return_value=mock_search_results)

        fake_graph = BaseGraph(id="x", name="g", description="d")
        mock_gdb = MagicMock()
        mock_gdb.get_graph = AsyncMock(return_value=fake_graph)

        with (
            patch(
                "backend.copilot.tools.agent_search.library_db",
                return_value=mock_lib_db,
            ),
            patch(
                "backend.copilot.tools.agent_search.graph_db",
                return_value=mock_gdb,
            ),
        ):
            response = await search_agents(
                query="",
                source="library",
                session_id="s",
                user_id=_TEST_USER_ID,
                include_graph=True,
            )

        assert isinstance(response, AgentsFoundResponse)
        assert mock_gdb.get_graph.await_count == _MAX_GRAPH_FETCHES
        enriched = [a for a in response.agents if a.graph is not None]
        assert len(enriched) == _MAX_GRAPH_FETCHES
        assert "Graph data included for" in response.message
        assert str(_MAX_GRAPH_FETCHES) in response.message

    @pytest.mark.asyncio(loop_scope="session")
    async def test_mixed_graph_id_presence(self):
        """Agents without graph_id are skipped during enrichment."""
        from backend.data.graph import BaseGraph

        agent_with = self._make_mock_library_agent(
            "aaaa0000-0000-0000-0000-000000000001",
            "aaaa0000-0000-0000-0000-000000000001",
        )
        agent_without = self._make_mock_library_agent(
            "bbbb0000-0000-0000-0000-000000000002",
            graph_id=None,
        )

        mock_lib_db = MagicMock()
        mock_search_results = MagicMock()
        mock_search_results.agents = [agent_with, agent_without]
        mock_lib_db.list_library_agents = AsyncMock(return_value=mock_search_results)

        fake_graph = BaseGraph(
            id="aaaa0000-0000-0000-0000-000000000001", name="g", description="d"
        )
        mock_gdb = MagicMock()
        mock_gdb.get_graph = AsyncMock(return_value=fake_graph)

        with (
            patch(
                "backend.copilot.tools.agent_search.library_db",
                return_value=mock_lib_db,
            ),
            patch(
                "backend.copilot.tools.agent_search.graph_db",
                return_value=mock_gdb,
            ),
        ):
            response = await search_agents(
                query="",
                source="library",
                session_id="s",
                user_id=_TEST_USER_ID,
                include_graph=True,
            )

        assert isinstance(response, AgentsFoundResponse)
        assert len(response.agents) == 2
        assert response.agents[0].graph is not None
        assert response.agents[1].graph is None
        mock_gdb.get_graph.assert_awaited_once()

    @pytest.mark.asyncio(loop_scope="session")
    async def test_partial_failure_across_multiple_agents(self):
        """When some graph fetches fail, successful ones still have graphs attached."""
        from backend.data.graph import BaseGraph

        id_ok = "aaaa0000-0000-0000-0000-000000000001"
        id_fail = "bbbb0000-0000-0000-0000-000000000002"
        agent_ok = self._make_mock_library_agent(id_ok, id_ok)
        agent_fail = self._make_mock_library_agent(id_fail, id_fail)

        mock_lib_db = MagicMock()
        mock_search_results = MagicMock()
        mock_search_results.agents = [agent_ok, agent_fail]
        mock_lib_db.list_library_agents = AsyncMock(return_value=mock_search_results)

        fake_graph = BaseGraph(id=id_ok, name="g", description="d")

        async def _side_effect(graph_id, **kwargs):
            if graph_id == id_fail:
                raise Exception("DB error")
            return fake_graph

        mock_gdb = MagicMock()
        mock_gdb.get_graph = AsyncMock(side_effect=_side_effect)

        with (
            patch(
                "backend.copilot.tools.agent_search.library_db",
                return_value=mock_lib_db,
            ),
            patch(
                "backend.copilot.tools.agent_search.graph_db",
                return_value=mock_gdb,
            ),
        ):
            response = await search_agents(
                query="",
                source="library",
                session_id="s",
                user_id=_TEST_USER_ID,
                include_graph=True,
            )

        assert isinstance(response, AgentsFoundResponse)
        assert response.agents[0].graph is not None
        assert response.agents[0].graph.id == id_ok
        assert response.agents[1].graph is None

    @pytest.mark.asyncio(loop_scope="session")
    async def test_keyword_search_with_include_graph(self):
        """include_graph works via keyword search (non-UUID path)."""
        from backend.data.graph import BaseGraph

        agent_id = "cccc0000-0000-0000-0000-000000000003"
        mock_agent = self._make_mock_library_agent(agent_id, agent_id)

        mock_lib_db = MagicMock()
        mock_search_results = MagicMock()
        mock_search_results.agents = [mock_agent]
        mock_lib_db.list_library_agents = AsyncMock(return_value=mock_search_results)

        fake_graph = BaseGraph(id=agent_id, name="g", description="d")
        mock_gdb = MagicMock()
        mock_gdb.get_graph = AsyncMock(return_value=fake_graph)

        with (
            patch(
                "backend.copilot.tools.agent_search.library_db",
                return_value=mock_lib_db,
            ),
            patch(
                "backend.copilot.tools.agent_search.graph_db",
                return_value=mock_gdb,
            ),
        ):
            response = await search_agents(
                query="email",
                source="library",
                session_id="s",
                user_id=_TEST_USER_ID,
                include_graph=True,
            )

        assert isinstance(response, AgentsFoundResponse)
        assert response.agents[0].graph is not None
        assert response.agents[0].graph.id == agent_id
        mock_gdb.get_graph.assert_awaited_once()

    @pytest.mark.asyncio(loop_scope="session")
    async def test_timeout_preserves_successful_fetches(self):
        """On timeout, agents that already fetched their graph keep the result."""
        fast_agent = AgentInfo(
            id="a1",
            name="Fast",
            description="d",
            source="library",
            graph_id="fast-graph",
        )
        slow_agent = AgentInfo(
            id="a2",
            name="Slow",
            description="d",
            source="library",
            graph_id="slow-graph",
        )
        fake_graph = MagicMock()
        fake_graph.id = "graph-1"

        async def mock_get_graph(
            graph_id, *, version=None, user_id=None, for_export=False
        ):
            if graph_id == "fast-graph":
                return fake_graph
            await asyncio.sleep(999)
            return MagicMock()

        mock_gdb = MagicMock()
        mock_gdb.get_graph = AsyncMock(side_effect=mock_get_graph)

        with (
            patch("backend.copilot.tools.agent_search.graph_db", return_value=mock_gdb),
            patch("backend.copilot.tools.agent_search._GRAPH_FETCH_TIMEOUT", 0.1),
        ):
            await _enrich_agents_with_graph([fast_agent, slow_agent], _TEST_USER_ID)

        assert fast_agent.graph is fake_graph
        assert slow_agent.graph is None

    @pytest.mark.asyncio(loop_scope="session")
    async def test_enrich_success(self):
        """All agents get their graphs when no timeout occurs."""
        agent = AgentInfo(
            id="a1", name="Test", description="d", source="library", graph_id="g1"
        )
        fake_graph = MagicMock()
        fake_graph.id = "graph-1"

        mock_gdb = MagicMock()
        mock_gdb.get_graph = AsyncMock(return_value=fake_graph)

        with patch(
            "backend.copilot.tools.agent_search.graph_db", return_value=mock_gdb
        ):
            result = await _enrich_agents_with_graph([agent], _TEST_USER_ID)

        assert agent.graph is fake_graph
        assert result is None

    @pytest.mark.asyncio(loop_scope="session")
    async def test_enrich_skips_agents_without_graph_id(self):
        """Agents without graph_id are not fetched."""
        agent_no_id = AgentInfo(
            id="a1", name="Test", description="d", source="library", graph_id=None
        )

        mock_gdb = MagicMock()
        mock_gdb.get_graph = AsyncMock()

        with patch(
            "backend.copilot.tools.agent_search.graph_db", return_value=mock_gdb
        ):
            result = await _enrich_agents_with_graph([agent_no_id], _TEST_USER_ID)

        mock_gdb.get_graph.assert_not_called()
        assert result is None

@@ -0,0 +1,93 @@
"""AskQuestionTool - Ask the user a clarifying question before proceeding."""

from typing import Any

from backend.copilot.model import ChatSession

from .base import BaseTool
from .models import ClarificationNeededResponse, ClarifyingQuestion, ToolResponseBase


class AskQuestionTool(BaseTool):
    """Ask the user a clarifying question and wait for their answer.

    Use this tool when the user's request is ambiguous and you need more
    information before proceeding. Call find_block or other discovery tools
    first to ground your question in real platform options, then call this
    tool with a concrete question listing those options.
    """

    @property
    def name(self) -> str:
        return "ask_question"

    @property
    def description(self) -> str:
        return (
            "Ask the user a clarifying question. Use when the request is "
            "ambiguous and you need to confirm intent, choose between options, "
            "or gather missing details before proceeding."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "question": {
                    "type": "string",
                    "description": (
                        "The concrete question to ask the user. Should list "
                        "real options when applicable."
                    ),
                },
                "options": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "Options for the user to choose from "
                        "(e.g. ['Email', 'Slack', 'Google Docs'])."
                    ),
                },
                "keyword": {
                    "type": "string",
                    "description": "Short label identifying what the question is about.",
                },
            },
            "required": ["question"],
        }

    @property
    def requires_auth(self) -> bool:
        return False

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        del user_id  # unused; required by BaseTool contract
        question_raw = kwargs.get("question")
        if not isinstance(question_raw, str) or not question_raw.strip():
            raise ValueError("ask_question requires a non-empty 'question' string")
        question = question_raw.strip()
        raw_options = kwargs.get("options", [])
        if not isinstance(raw_options, list):
            raw_options = []
        options: list[str] = [str(o) for o in raw_options if o]
        raw_keyword = kwargs.get("keyword", "")
        keyword: str = str(raw_keyword) if raw_keyword else ""
        session_id = session.session_id if session else None

        example = ", ".join(options) if options else None
        clarifying_question = ClarifyingQuestion(
            question=question,
            keyword=keyword,
            example=example,
        )
        return ClarificationNeededResponse(
            message=question,
            session_id=session_id,
            questions=[clarifying_question],
        )
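A short usage sketch of the tool as the tests below exercise it (the `ChatSession.new` call mirrors the test fixture; the assertions follow directly from `_execute` above):

```python
import asyncio

from backend.copilot.model import ChatSession
from backend.copilot.tools.ask_question import AskQuestionTool


async def demo() -> None:
    session = ChatSession.new(user_id="demo-user", dry_run=False)
    result = await AskQuestionTool()._execute(
        user_id=None,
        session=session,
        question="Where should the report be delivered?",
        options=["Email", "Slack", "Google Docs"],
        keyword="delivery",
    )
    # Per _execute above: the message echoes the question, and options
    # are joined into the ClarifyingQuestion.example string.
    assert result.message == "Where should the report be delivered?"
    assert result.questions[0].example == "Email, Slack, Google Docs"


asyncio.run(demo())
```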
@@ -0,0 +1,99 @@
"""Tests for AskQuestionTool."""

import pytest

from backend.copilot.model import ChatSession
from backend.copilot.tools.ask_question import AskQuestionTool
from backend.copilot.tools.models import ClarificationNeededResponse


@pytest.fixture()
def tool() -> AskQuestionTool:
    return AskQuestionTool()


@pytest.fixture()
def session() -> ChatSession:
    return ChatSession.new(user_id="test-user", dry_run=False)


@pytest.mark.asyncio
async def test_execute_with_options(tool: AskQuestionTool, session: ChatSession):
    result = await tool._execute(
        user_id=None,
        session=session,
        question="Which channel?",
        options=["Email", "Slack", "Google Docs"],
        keyword="channel",
    )

    assert isinstance(result, ClarificationNeededResponse)
    assert result.message == "Which channel?"
    assert result.session_id == session.session_id
    assert len(result.questions) == 1

    q = result.questions[0]
    assert q.question == "Which channel?"
    assert q.keyword == "channel"
    assert q.example == "Email, Slack, Google Docs"


@pytest.mark.asyncio
async def test_execute_without_options(tool: AskQuestionTool, session: ChatSession):
    result = await tool._execute(
        user_id=None,
        session=session,
        question="What format do you want?",
    )

    assert isinstance(result, ClarificationNeededResponse)
    assert result.message == "What format do you want?"
    assert len(result.questions) == 1

    q = result.questions[0]
    assert q.question == "What format do you want?"
    assert q.keyword == ""
    assert q.example is None


@pytest.mark.asyncio
async def test_execute_with_keyword_only(tool: AskQuestionTool, session: ChatSession):
    result = await tool._execute(
        user_id=None,
        session=session,
        question="How often should it run?",
        keyword="trigger",
    )

    assert isinstance(result, ClarificationNeededResponse)
    q = result.questions[0]
    assert q.keyword == "trigger"
    assert q.example is None


@pytest.mark.asyncio
async def test_execute_rejects_empty_question(
    tool: AskQuestionTool, session: ChatSession
):
    with pytest.raises(ValueError, match="non-empty"):
        await tool._execute(user_id=None, session=session, question="")

    with pytest.raises(ValueError, match="non-empty"):
        await tool._execute(user_id=None, session=session, question=" ")


@pytest.mark.asyncio
async def test_execute_coerces_invalid_options(
    tool: AskQuestionTool, session: ChatSession
):
    """LLM may send options as a string instead of a list; should not crash."""
    result = await tool._execute(
        user_id=None,
        session=session,
        question="Pick one",
        options="not-a-list",  # type: ignore[arg-type]
    )

    assert isinstance(result, ClarificationNeededResponse)
    q = result.questions[0]
    assert q.example is None
@@ -91,10 +91,16 @@ async def _persist_and_summarize(
        f"\nFull output ({total:,} chars) saved to workspace. "
        f"Use read_workspace_file("
        f'path="{file_path}", offset=<char_offset>, length=50000) '
        f"to read any section."
        f"to read any section. "
        f"To process the file in the sandbox/working dir, use "
        f"read_workspace_file("
        f'path="{file_path}", save_to_path="<working_dir>/{tool_call_id}.json") '
        f"first, then use bash_exec to work with the local copy."
    )
    # Use workspace:// prefix so the model doesn't confuse the workspace path
    # with a local filesystem path (e.g. ~/.claude/projects/.../tool-outputs/).
    return (
        f'<tool-output-truncated total_chars={total} path="{file_path}">\n'
        f'<tool-output-truncated total_chars={total} workspace_path="{file_path}">\n'
        f"{preview}\n"
        f"{retrieval}\n"
        f"</tool-output-truncated>"

@@ -67,7 +67,7 @@ class TestPersistAndSummarize:
        assert "<tool-output-truncated" in result
        assert "</tool-output-truncated>" in result
        assert "total_chars=200000" in result
        assert 'path="tool-outputs/tc-123.json"' in result
        assert 'workspace_path="tool-outputs/tc-123.json"' in result
        assert "read_workspace_file" in result
        # Middle-out sentinel from truncate()
        assert "omitted" in result

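Putting the pieces together, a truncated tool result handed to the model now looks roughly like this (all values invented, assembled from the f-strings above):

```python
# Illustrative envelope as produced by the code above:
envelope = (
    '<tool-output-truncated total_chars=200000 workspace_path="tool-outputs/tc-123.json">\n'
    "...preview, middle omitted...\n"
    "\n"
    "Full output (200,000 chars) saved to workspace. "
    'Use read_workspace_file(path="tool-outputs/tc-123.json", '
    "offset=<char_offset>, length=50000) to read any section. "
    "To process the file in the sandbox/working dir, use "
    'read_workspace_file(path="tool-outputs/tc-123.json", '
    'save_to_path="<working_dir>/tc-123.json") first, then use bash_exec '
    "to work with the local copy.\n"
    "</tool-output-truncated>"
)
```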
@@ -22,7 +22,10 @@ from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException

from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
from backend.copilot.integration_creds import get_integration_env_vars
from backend.copilot.integration_creds import (
    get_github_user_git_identity,
    get_integration_env_vars,
)
from backend.copilot.model import ChatSession

from .base import BaseTool
@@ -159,6 +162,12 @@ class BashExecTool(BaseTool):
        secret_values = [v for v in integration_env.values() if v]
        envs.update(integration_env)

        # Set git author/committer identity from the user's GitHub profile
        # so commits made in the sandbox are attributed correctly.
        git_identity = await get_github_user_git_identity(user_id)
        if git_identity:
            envs.update(git_identity)

        try:
            result = await sandbox.commands.run(
                f"bash -c {shlex.quote(command)}",

@@ -38,7 +38,10 @@ class TestBashExecE2BTokenInjection:
        with patch(
            "backend.copilot.tools.bash_exec.get_integration_env_vars",
            new=AsyncMock(return_value=env_vars),
        ) as mock_get_env:
        ) as mock_get_env, patch(
            "backend.copilot.tools.bash_exec.get_github_user_git_identity",
            new=AsyncMock(return_value=None),
        ):
            result = await tool._execute_on_e2b(
                sandbox=sandbox,
                command="echo hi",
@@ -53,6 +56,66 @@ class TestBashExecE2BTokenInjection:
        assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
        assert isinstance(result, BashExecResponse)

    @pytest.mark.asyncio(loop_scope="session")
    async def test_git_identity_set_from_github_profile(self):
        """When user has a connected GitHub account, git env vars are set from their profile."""
        tool = _make_tool()
        session = make_session(user_id=_USER)
        sandbox = _make_sandbox(stdout="ok")
        identity = {
            "GIT_AUTHOR_NAME": "Test User",
            "GIT_AUTHOR_EMAIL": "test@example.com",
            "GIT_COMMITTER_NAME": "Test User",
            "GIT_COMMITTER_EMAIL": "test@example.com",
        }

        with patch(
            "backend.copilot.tools.bash_exec.get_integration_env_vars",
            new=AsyncMock(return_value={}),
        ), patch(
            "backend.copilot.tools.bash_exec.get_github_user_git_identity",
            new=AsyncMock(return_value=identity),
        ):
            await tool._execute_on_e2b(
                sandbox=sandbox,
                command="git commit -m test",
                timeout=10,
                session_id=session.session_id,
                user_id=_USER,
            )

        call_kwargs = sandbox.commands.run.call_args[1]
        assert call_kwargs["envs"]["GIT_AUTHOR_NAME"] == "Test User"
        assert call_kwargs["envs"]["GIT_AUTHOR_EMAIL"] == "test@example.com"
        assert call_kwargs["envs"]["GIT_COMMITTER_NAME"] == "Test User"
        assert call_kwargs["envs"]["GIT_COMMITTER_EMAIL"] == "test@example.com"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_no_git_identity_when_github_not_connected(self):
        """When user has no GitHub account, git identity env vars are absent."""
        tool = _make_tool()
        session = make_session(user_id=_USER)
        sandbox = _make_sandbox(stdout="ok")

        with patch(
            "backend.copilot.tools.bash_exec.get_integration_env_vars",
            new=AsyncMock(return_value={}),
        ), patch(
            "backend.copilot.tools.bash_exec.get_github_user_git_identity",
            new=AsyncMock(return_value=None),
        ):
            await tool._execute_on_e2b(
                sandbox=sandbox,
                command="echo hi",
                timeout=10,
                session_id=session.session_id,
                user_id=_USER,
            )

        call_kwargs = sandbox.commands.run.call_args[1]
        assert "GIT_AUTHOR_NAME" not in call_kwargs["envs"]
        assert "GIT_COMMITTER_EMAIL" not in call_kwargs["envs"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_no_token_injection_when_user_id_is_none(self):
        """When user_id is None, get_integration_env_vars must NOT be called."""
@@ -63,7 +126,10 @@ class TestBashExecE2BTokenInjection:
        with patch(
            "backend.copilot.tools.bash_exec.get_integration_env_vars",
            new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
        ) as mock_get_env:
        ) as mock_get_env, patch(
            "backend.copilot.tools.bash_exec.get_github_user_git_identity",
            new=AsyncMock(return_value=None),
        ) as mock_get_identity:
            result = await tool._execute_on_e2b(
                sandbox=sandbox,
                command="echo hi",
@@ -73,6 +139,8 @@ class TestBashExecE2BTokenInjection:
            )

        mock_get_env.assert_not_called()
        mock_get_identity.assert_not_called()
        call_kwargs = sandbox.commands.run.call_args[1]
        assert "GH_TOKEN" not in call_kwargs["envs"]
        assert "GIT_AUTHOR_NAME" not in call_kwargs["envs"]
        assert isinstance(result, BashExecResponse)

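The four variables injected by the new code are git's standard identity overrides, which git reads at commit time; nothing needs to run `git config` inside the sandbox. A small sketch of the effect on the env dict passed to `sandbox.commands.run` (token value invented):

```python
# Git reads these standard environment variables at commit time, so no
# `git config user.name` / `user.email` is needed inside the sandbox.
identity = {
    "GIT_AUTHOR_NAME": "Test User",
    "GIT_AUTHOR_EMAIL": "test@example.com",
    "GIT_COMMITTER_NAME": "Test User",
    "GIT_COMMITTER_EMAIL": "test@example.com",
}

envs: dict[str, str] = {"GITHUB_TOKEN": "gh-secret"}  # integration env vars
if identity:  # mirrors the `if git_identity:` guard in _execute_on_e2b
    envs.update(identity)

# envs is then passed as sandbox.commands.run(..., envs=envs); any
# `git commit` in the command is attributed to the connected user.
```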
@@ -20,7 +20,8 @@ class FindLibraryAgentTool(BaseTool):
    def description(self) -> str:
        return (
            "Search user's library agents. Returns graph_id, schemas for sub-agent composition. "
-            "Omit query to list all."
+            "Omit query to list all. Set include_graph=true to also fetch the full "
+            "graph structure (nodes + links) for debugging or editing."
        )

    @property
@@ -32,6 +33,15 @@ class FindLibraryAgentTool(BaseTool):
                "type": "string",
                "description": "Search by name/description. Omit to list all.",
            },
+            "include_graph": {
+                "type": "boolean",
+                "description": (
+                    "When true, includes the full graph structure "
+                    "(nodes + links) for each found agent. "
+                    "Use when you need to inspect, debug, or edit an agent."
+                ),
+                "default": False,
+            },
        },
        "required": [],
    }
@@ -45,6 +55,7 @@ class FindLibraryAgentTool(BaseTool):
        user_id: str | None,
        session: ChatSession,
        query: str = "",
+        include_graph: bool = False,
        **kwargs,
    ) -> ToolResponseBase:
        return await search_agents(
@@ -52,4 +63,5 @@ class FindLibraryAgentTool(BaseTool):
            source="library",
            session_id=session.session_id,
            user_id=user_id,
+            include_graph=include_graph,
        )
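For context, a tool call exercising the new flag might look like the following; this is a hypothetical invocation with an invented query value:

# Hypothetical LLM tool-call arguments for find_library_agent.
# include_graph=True additionally requests nodes + links for each hit.
tool_args = {"query": "newsletter digest", "include_graph": True}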
@@ -42,7 +42,10 @@ class GetAgentBuildingGuideTool(BaseTool):

    @property
    def description(self) -> str:
-        return "Get the agent JSON building guide (nodes, links, AgentExecutorBlock, MCPToolBlock usage). Call before generating agent JSON."
+        return (
+            "Get the agent JSON building guide (nodes, links, AgentExecutorBlock, MCPToolBlock usage, "
+            "and the create->dry-run->fix iterative workflow). Call before generating agent JSON."
+        )

    @property
    def parameters(self) -> dict[str, Any]:
@@ -48,27 +48,41 @@ logger = logging.getLogger(__name__)
def get_inputs_from_schema(
    input_schema: dict[str, Any],
    exclude_fields: set[str] | None = None,
+    input_data: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
-    """Extract input field info from JSON schema."""
+    """Extract input field info from JSON schema.
+
+    When *input_data* is provided, each field's ``value`` key is populated
+    with the value the CoPilot already supplied — so the frontend can
+    prefill the form instead of showing empty inputs. Fields marked
+    ``advanced`` in the schema are flagged so the frontend can hide them
+    by default (matching the builder behaviour).
+    """
    if not isinstance(input_schema, dict):
        return []

    exclude = exclude_fields or set()
    properties = input_schema.get("properties", {})
    required = set(input_schema.get("required", []))
+    provided = input_data or {}

-    return [
-        {
+    results: list[dict[str, Any]] = []
+    for name, schema in properties.items():
+        if name in exclude:
+            continue
+        entry: dict[str, Any] = {
            "name": name,
            "title": schema.get("title", name),
            "type": schema.get("type", "string"),
            "description": schema.get("description", ""),
            "required": name in required,
            "default": schema.get("default"),
+            "advanced": schema.get("advanced", False),
        }
-        for name, schema in properties.items()
-        if name not in exclude
-    ]
+        if name in provided:
+            entry["value"] = provided[name]
+        results.append(entry)
+    return results


async def execute_block(
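The new shape is easy to sanity-check in isolation. A small demo of the prefill behaviour, using only the function shown above; the schema and values are made up, and the import assumes the repo is on the path:

# Demo of get_inputs_from_schema's prefill path, with an invented schema.
from backend.copilot.tools.helpers import get_inputs_from_schema

schema = {
    "properties": {
        "query": {"title": "Query", "type": "string", "description": "Search text"},
        "limit": {"type": "integer", "default": 10, "advanced": True},
    },
    "required": ["query"],
}
fields = get_inputs_from_schema(schema, input_data={"query": "cats"})
assert fields[0]["value"] == "cats"   # prefilled from input_data
assert "value" not in fields[1]       # nothing supplied for "limit"
assert fields[1]["advanced"] is True  # frontend can hide it by default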
@@ -446,7 +460,9 @@ async def prepare_block_for_execution(
        requirements={
            "credentials": missing_creds_list,
            "inputs": get_inputs_from_schema(
-                input_schema, exclude_fields=credentials_fields
+                input_schema,
+                exclude_fields=credentials_fields,
+                input_data=input_data,
            ),
            "execution_modes": ["immediate"],
        },
@@ -6,6 +6,7 @@ from typing import Any, Literal

from pydantic import BaseModel, Field

+from backend.data.graph import BaseGraph
from backend.data.model import CredentialsMetaInput


@@ -122,6 +123,10 @@ class AgentInfo(BaseModel):
        default=None,
        description="Input schema for the agent, including field names, types, and defaults",
    )
+    graph: BaseGraph | None = Field(
+        default=None,
+        description="Full graph structure (nodes + links) when include_graph is requested",
+    )


class AgentsFoundResponse(ToolResponseBase):
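Functionally, the new field just rides along on the existing model. A quick standalone approximation, with BaseGraph stubbed since its real definition lives in backend.data.graph:

# Standalone approximation: BaseGraph is stubbed for illustration only;
# the real class carries the agent's nodes and links.
from pydantic import BaseModel, Field

class BaseGraph(BaseModel):  # stub, not the real backend.data.graph class
    nodes: list = []
    links: list = []

class AgentInfo(BaseModel):  # trimmed to the field under discussion
    graph: BaseGraph | None = Field(default=None)

assert AgentInfo().graph is None                       # default: graph omitted
assert AgentInfo(graph=BaseGraph()).graph is not None  # include_graph=true path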
@@ -153,7 +153,11 @@ class RunAgentTool(BaseTool):
            },
            "dry_run": {
                "type": "boolean",
-                "description": "Execute in preview mode.",
+                "description": (
+                    "When true, simulates execution using an LLM for each block "
+                    "— no real API calls, credentials, or credits. "
+                    "See agent_generation_guide for the full workflow."
+                ),
            },
        },
        "required": ["dry_run"],
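Since dry_run is now required, every run_agent call states its mode explicitly. A hypothetical pair of calls previewing and then really running the same agent; the parameter names other than dry_run and all values are invented for illustration:

# Hypothetical run_agent tool-call arguments; IDs and inputs are invented.
preview_args = {"graph_id": "g-123", "inputs": {"topic": "ai news"}, "dry_run": True}
real_args = {"graph_id": "g-123", "inputs": {"topic": "ai news"}, "dry_run": False}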
@@ -10,7 +10,11 @@ import backend.copilot.tools.run_block as run_block_module
from backend.copilot.tools.helpers import execute_block
from backend.copilot.tools.models import BlockOutputResponse, ErrorResponse
from backend.copilot.tools.run_block import RunBlockTool
-from backend.executor.simulator import build_simulation_prompt, simulate_block
+from backend.executor.simulator import (
+    build_simulation_prompt,
+    prepare_dry_run,
+    simulate_block,
+)

# ---------------------------------------------------------------------------
# Helpers
@@ -75,7 +79,8 @@ def make_openai_response(
async def test_simulate_block_basic():
    """simulate_block returns correct (output_name, output_data) tuples.

-    Empty "error" pins are dropped at source — only non-empty errors are yielded.
+    Empty error pins should be omitted (not yielded) — only pins with
+    meaningful values are forwarded.
    """
    mock_block = make_mock_block()
    mock_client = AsyncMock()
@@ -91,7 +96,7 @@ async def test_simulate_block_basic():
            outputs.append((name, data))

    assert ("result", "simulated output") in outputs
-    # Empty error pin is dropped at the simulator level
+    # Empty error pin should NOT be yielded — the simulator omits empty values
    assert ("error", "") not in outputs


@@ -147,7 +152,7 @@ async def test_simulate_block_all_retries_exhausted():

@pytest.mark.asyncio
async def test_simulate_block_missing_output_pins():
-    """LLM response missing some output pins; verify non-error pins filled with None."""
+    """LLM response missing some output pins; they are omitted (not yielded)."""
    mock_block = make_mock_block(
        output_props={
            "result": {"type": "string"},
@@ -169,30 +174,9 @@ async def test_simulate_block_missing_output_pins():
        outputs[name] = data

    assert outputs["result"] == "hello"
-    assert outputs["count"] is None  # missing pin filled with None
-    assert "error" not in outputs  # missing error pin is omitted entirely
-
-
-@pytest.mark.asyncio
-async def test_simulate_block_keeps_nonempty_error():
-    """simulate_block keeps non-empty error pins (simulated logical errors)."""
-    mock_block = make_mock_block()
-    mock_client = AsyncMock()
-    mock_client.chat.completions.create = AsyncMock(
-        return_value=make_openai_response(
-            '{"result": "", "error": "API rate limit exceeded"}'
-        )
-    )
-
-    with patch(
-        "backend.executor.simulator.get_openai_client", return_value=mock_client
-    ):
-        outputs = []
-        async for name, data in simulate_block(mock_block, {"query": "test"}):
-            outputs.append((name, data))
-
-    assert ("result", "") in outputs
-    assert ("error", "API rate limit exceeded") in outputs
+    # Missing pins are omitted — only pins with meaningful values are yielded
+    assert "count" not in outputs
+    assert "error" not in outputs


@pytest.mark.asyncio
@@ -228,17 +212,19 @@ async def test_simulate_block_truncates_long_inputs():
    assert len(parsed["text"]) < 25000


-def test_build_simulation_prompt_excludes_error_from_must_include():
-    """The 'MUST include' prompt line should NOT list 'error' — the prompt
-    already instructs the LLM to OMIT error unless simulating a logical error.
-    Including it in 'MUST include' would be contradictory."""
+def test_build_simulation_prompt_lists_available_output_pins():
+    """The prompt should list available output pins (excluding error) so the LLM
+    knows which keys it MUST include. Error is excluded because the prompt
+    tells the LLM to omit it unless simulating a logical failure."""
    block = make_mock_block()  # default output_props has "result" and "error"
    system_prompt, _ = build_simulation_prompt(block, {"query": "test"})
-    must_include_line = [
-        line for line in system_prompt.splitlines() if "MUST include" in line
+    available_line = [
+        line for line in system_prompt.splitlines() if "Available output pins" in line
    ][0]
-    assert '"result"' in must_include_line
-    assert '"error"' not in must_include_line
+    assert '"result"' in available_line
+    # "error" is intentionally excluded from the required output pins list
+    # since the prompt instructs the LLM to omit it unless simulating errors
+    assert '"error"' not in available_line


# ---------------------------------------------------------------------------
@@ -493,3 +479,146 @@ async def test_execute_block_dry_run_simulator_error_returns_error_response():

    assert isinstance(response, ErrorResponse)
    assert "[SIMULATOR ERROR" in response.message


# ---------------------------------------------------------------------------
# prepare_dry_run tests
# ---------------------------------------------------------------------------


def test_prepare_dry_run_orchestrator_block():
    """prepare_dry_run caps iterations and overrides model to simulation model."""
    from backend.blocks.orchestrator import OrchestratorBlock

    block = OrchestratorBlock()
    input_data = {"prompt": "hello", "model": "gpt-4o", "agent_mode_max_iterations": 10}
    with patch(
        "backend.executor.simulator._get_platform_openrouter_key",
        return_value="sk-or-test-key",
    ):
        result = prepare_dry_run(block, input_data)

    assert result is not None
    # Model is overridden to the simulation model (not the user's model).
    assert result["model"] != "gpt-4o"
    assert result["agent_mode_max_iterations"] == 1
    assert result["_dry_run_api_key"] == "sk-or-test-key"
    # Original input_data should not be mutated.
    assert input_data["model"] == "gpt-4o"


def test_prepare_dry_run_agent_executor_block():
    """prepare_dry_run returns a copy of input_data for AgentExecutorBlock.

    AgentExecutorBlock must execute for real during dry-run so it can spawn
    a child graph execution (whose blocks are then simulated). Its Output
    schema has no properties, so LLM simulation would yield zero outputs.
    """
    from backend.blocks.agent import AgentExecutorBlock

    block = AgentExecutorBlock()
    input_data = {
        "user_id": "u1",
        "graph_id": "g1",
        "graph_version": 1,
        "inputs": {"text": "hello"},
        "input_schema": {},
        "output_schema": {},
    }
    result = prepare_dry_run(block, input_data)

    assert result is not None
    # Input data is returned as-is (no model swap needed).
    assert result["user_id"] == "u1"
    assert result["graph_id"] == "g1"
    # Original input_data should not be mutated.
    assert result is not input_data


def test_prepare_dry_run_regular_block_returns_none():
    """prepare_dry_run returns None for a regular block (use simulator)."""
    mock_block = make_mock_block()
    assert prepare_dry_run(mock_block, {"query": "test"}) is None
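Taken together, these tests imply a three-way dispatch inside prepare_dry_run. The implementation itself is not shown in this diff, so the following is a sketch inferred purely from the assertions above; the _SIMULATION_MODEL constant is an assumed name, while _get_platform_openrouter_key is the helper the tests patch:

# Sketch inferred from the tests above; not the actual backend.executor.simulator code.
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.orchestrator import OrchestratorBlock

_SIMULATION_MODEL = "simulation-model"  # assumed placeholder name

def prepare_dry_run(block, input_data: dict) -> dict | None:
    if isinstance(block, AgentExecutorBlock):
        # Runs for real so it can spawn a child execution; copy to avoid mutation.
        return dict(input_data)
    if isinstance(block, OrchestratorBlock):
        prepared = dict(input_data)
        prepared["model"] = _SIMULATION_MODEL         # override the user's model
        prepared["agent_mode_max_iterations"] = 1     # cap the loop to one step
        prepared["_dry_run_api_key"] = _get_platform_openrouter_key()
        return prepared
    return None  # regular blocks fall through to the LLM simulator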


# ---------------------------------------------------------------------------
# Input/output block passthrough tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_simulate_agent_input_block_passthrough():
    """AgentInputBlock should pass through the value directly, no LLM call."""
    from backend.blocks.io import AgentInputBlock

    block = AgentInputBlock()
    outputs = []
    async for name, data in simulate_block(
        block, {"value": "hello world", "name": "q"}
    ):
        outputs.append((name, data))

    assert outputs == [("result", "hello world")]


@pytest.mark.asyncio
async def test_simulate_agent_dropdown_input_block_passthrough():
    """AgentDropdownInputBlock (subclass of AgentInputBlock) should pass through."""
    from backend.blocks.io import AgentDropdownInputBlock

    block = AgentDropdownInputBlock()
    outputs = []
    async for name, data in simulate_block(
        block,
        {
            "value": "Option B",
            "name": "sev",
            "options": ["Option A", "Option B"],
        },
    ):
        outputs.append((name, data))

    assert outputs == [("result", "Option B")]


@pytest.mark.asyncio
async def test_simulate_agent_input_block_none_value_falls_back_to_name():
    """AgentInputBlock with value=None falls back to the input name."""
    from backend.blocks.io import AgentInputBlock

    block = AgentInputBlock()
    outputs = []
    async for name, data in simulate_block(block, {"value": None, "name": "q"}):
        outputs.append((name, data))

    # When value is None, the simulator falls back to the "name" field
    assert outputs == [("result", "q")]


@pytest.mark.asyncio
async def test_simulate_agent_output_block_passthrough():
    """AgentOutputBlock should pass through value as output."""
    from backend.blocks.io import AgentOutputBlock

    block = AgentOutputBlock()
    outputs = []
    async for name, data in simulate_block(
        block, {"value": "result text", "name": "out1"}
    ):
        outputs.append((name, data))

    assert ("output", "result text") in outputs
    assert ("name", "out1") in outputs


@pytest.mark.asyncio
async def test_simulate_agent_output_block_no_name():
    """AgentOutputBlock without name in input should still yield output."""
    from backend.blocks.io import AgentOutputBlock

    block = AgentOutputBlock()
    outputs = []
    async for name, data in simulate_block(block, {"value": 42}):
        outputs.append((name, data))

    assert outputs == [("output", 42)]
autogpt_platform/backend/backend/copilot/transcript.py (new file, 1098 lines): diff suppressed because it is too large.

autogpt_platform/backend/backend/copilot/transcript_builder.py (new file, 235 lines):
@@ -0,0 +1,235 @@
"""Build complete JSONL transcript from SDK messages.

The transcript represents the FULL active context at any point in time.
Each upload REPLACES the previous transcript atomically.

Flow:
    Turn 1: Upload [msg1, msg2]
    Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
    Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)

The transcript is never incremental - always the complete atomic state.
"""

import logging
from typing import Any
from uuid import uuid4

from pydantic import BaseModel

from backend.util import json

from .transcript import STRIPPABLE_TYPES

logger = logging.getLogger(__name__)


class TranscriptEntry(BaseModel):
    """Single transcript entry (user or assistant turn)."""

    type: str
    uuid: str
    parentUuid: str | None
    isCompactSummary: bool | None = None
    message: dict[str, Any]


class TranscriptBuilder:
    """Build complete JSONL transcript from SDK messages.

    This builder maintains the FULL conversation state, not incremental changes.
    The output is always the complete active context.
    """

    def __init__(self) -> None:
        self._entries: list[TranscriptEntry] = []
        self._last_uuid: str | None = None

    def _last_is_assistant(self) -> bool:
        return bool(self._entries) and self._entries[-1].type == "assistant"

    def _last_message_id(self) -> str:
        """Return the message.id of the last entry, or '' if none."""
        if self._entries:
            return self._entries[-1].message.get("id", "")
        return ""

    @staticmethod
    def _parse_entry(data: dict) -> TranscriptEntry | None:
        """Parse a single transcript entry, filtering strippable types.

        Returns ``None`` for entries that should be skipped (strippable types
        that are not compaction summaries).
        """
        entry_type = data.get("type", "")
        if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
            return None
        return TranscriptEntry(
            type=entry_type,
            uuid=data.get("uuid") or str(uuid4()),
            parentUuid=data.get("parentUuid"),
            isCompactSummary=data.get("isCompactSummary"),
            message=data.get("message", {}),
        )

    def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
        """Load complete previous transcript.

        This loads the FULL previous context. As new messages come in,
        we append to this state. The final output is the complete context
        (previous + new), not just the delta.
        """
        if not content or not content.strip():
            return

        lines = content.strip().split("\n")
        for line_num, line in enumerate(lines, 1):
            if not line.strip():
                continue

            data = json.loads(line, fallback=None)
            if data is None:
                logger.warning(
                    "%s Failed to parse transcript line %d/%d",
                    log_prefix,
                    line_num,
                    len(lines),
                )
                continue

            entry = self._parse_entry(data)
            if entry is None:
                continue
            self._entries.append(entry)
            self._last_uuid = entry.uuid

        logger.info(
            "%s Loaded %d entries from previous transcript (last_uuid=%s)",
            log_prefix,
            len(self._entries),
            self._last_uuid[:12] if self._last_uuid else None,
        )

    def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
        """Append a user entry."""
        msg_uuid = uuid or str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="user",
                uuid=msg_uuid,
                parentUuid=self._last_uuid,
                message={"role": "user", "content": content},
            )
        )
        self._last_uuid = msg_uuid

    def append_tool_result(self, tool_use_id: str, content: str) -> None:
        """Append a tool result as a user entry (one per tool call)."""
        self.append_user(
            content=[
                {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
            ]
        )

    def append_assistant(
        self,
        content_blocks: list[dict],
        model: str = "",
        stop_reason: str | None = None,
    ) -> None:
        """Append an assistant entry.

        Consecutive assistant entries automatically share the same message ID
        so the CLI can merge them (thinking → text → tool_use) into a single
        API message on ``--resume``. A new ID is assigned whenever an
        assistant entry follows a non-assistant entry (user message or tool
        result), because that marks the start of a new API response.
        """
        message_id = (
            self._last_message_id()
            if self._last_is_assistant()
            else f"msg_sdk_{uuid4().hex[:24]}"
        )

        msg_uuid = str(uuid4())

        self._entries.append(
            TranscriptEntry(
                type="assistant",
                uuid=msg_uuid,
                parentUuid=self._last_uuid,
                message={
                    "role": "assistant",
                    "model": model,
                    "id": message_id,
                    "type": "message",
                    "content": content_blocks,
                    "stop_reason": stop_reason,
                    "stop_sequence": None,
                },
            )
        )
        self._last_uuid = msg_uuid

    def replace_entries(
        self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
    ) -> None:
        """Replace all entries with compacted entries from the CLI session file.

        Called after mid-stream compaction so TranscriptBuilder mirrors the
        CLI's active context (compaction summary + post-compaction entries).

        Builds the new list first and validates it's non-empty before swapping,
        so corrupt input cannot wipe the conversation history.
        """
        new_entries: list[TranscriptEntry] = []
        for data in compacted_entries:
            entry = self._parse_entry(data)
            if entry is not None:
                new_entries.append(entry)

        if not new_entries:
            logger.warning(
                "%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
                log_prefix,
                len(compacted_entries),
                len(self._entries),
            )
            return

        old_count = len(self._entries)
        self._entries = new_entries
        self._last_uuid = new_entries[-1].uuid

        logger.info(
            "%s TranscriptBuilder compacted: %d entries -> %d entries",
            log_prefix,
            old_count,
            len(self._entries),
        )

    def to_jsonl(self) -> str:
        """Export complete context as JSONL.

        Consecutive assistant entries are kept separate to match the
        native CLI format — the SDK merges them internally on resume.

        Returns the FULL conversation state (all entries), not incremental.
        This output REPLACES any previous transcript.
        """
        if not self._entries:
            return ""

        lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
        return "\n".join(lines) + "\n"

    @property
    def entry_count(self) -> int:
        """Total number of entries in the complete context."""
        return len(self._entries)

    @property
    def is_empty(self) -> bool:
        """Whether this builder has any entries."""
        return len(self._entries) == 0
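A short usage sketch of the builder's per-turn cycle, using only the methods defined above; the content values are made up, and the import assumes the repo is on the path:

# Minimal walkthrough of the replace-style transcript cycle described above.
from backend.copilot.transcript_builder import TranscriptBuilder

builder = TranscriptBuilder()
builder.load_previous("")                      # empty on the first turn
builder.append_user("What changed?")           # new user turn
builder.append_assistant([{"type": "text", "text": "Summarising..."}], model="claude")
jsonl = builder.to_jsonl()                     # the complete state, not a delta
# Each upload of `jsonl` REPLACES the previous transcript atomically.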
@@ -838,6 +838,7 @@ class NodeExecutionStats(BaseModel):
    output_token_count: int = 0
    extra_cost: int = 0
    extra_steps: int = 0
+    provider_cost: float | None = None
    # Moderation fields
    cleared_inputs: Optional[dict[str, list[str]]] = None
    cleared_outputs: Optional[dict[str, list[str]]] = None
@@ -851,6 +852,9 @@ class NodeExecutionStats(BaseModel):
        current_stats = self.model_dump()

        for key, value in stats_dict.items():
+            if value is None:
+                # Never overwrite an existing value with None
+                continue
            if key not in current_stats:
                # Field doesn't exist yet, just set it
                setattr(self, key, value)
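The None guard matters once provider_cost is optional: a stats update that omits the cost must not erase a value recorded earlier. A toy illustration of the merge rule, independent of the real model and inferred only from the hunk above:

# Toy version of the None-skip rule in the merge loop above; the real method
# lives on NodeExecutionStats and handles more cases than shown here.
current = {"provider_cost": 0.42, "extra_steps": 1}
update = {"provider_cost": None, "extra_steps": 2}

for key, value in update.items():
    if value is None:
        continue  # never overwrite an existing value with None
    current[key] = value

assert current == {"provider_cost": 0.42, "extra_steps": 2}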
autogpt_platform/backend/backend/data/platform_cost.py (new file, 306 lines):
@@ -0,0 +1,306 @@
import asyncio
import json
import logging
from datetime import datetime
from typing import Any

from pydantic import BaseModel

from backend.data.db import execute_raw_with_schema, query_raw_with_schema

logger = logging.getLogger(__name__)


class PlatformCostEntry(BaseModel):
    user_id: str
    graph_exec_id: str | None = None
    node_exec_id: str | None = None
    graph_id: str | None = None
    node_id: str | None = None
    block_id: str
    block_name: str
    provider: str
    credential_id: str
    cost_microdollars: int | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
    data_size: int | None = None
    duration: float | None = None
    model: str | None = None
    metadata: dict[str, Any] | None = None


async def log_platform_cost(entry: PlatformCostEntry) -> None:
    await execute_raw_with_schema(
        """
        INSERT INTO {schema_prefix}"PlatformCostLog"
            ("id", "createdAt", "userId", "graphExecId", "nodeExecId",
             "graphId", "nodeId", "blockId", "blockName", "provider",
             "credentialId", "costMicrodollars", "inputTokens", "outputTokens",
             "dataSize", "duration", "model", "metadata")
        VALUES (
            gen_random_uuid(), NOW(), $1, $2, $3, $4, $5, $6, $7, $8, $9,
            $10, $11, $12, $13, $14, $15, $16::jsonb
        )
        """,
        entry.user_id,
        entry.graph_exec_id,
        entry.node_exec_id,
        entry.graph_id,
        entry.node_id,
        entry.block_id,
        entry.block_name,
        entry.provider,
        entry.credential_id,
        entry.cost_microdollars,
        entry.input_tokens,
        entry.output_tokens,
        entry.data_size,
        entry.duration,
        entry.model,
        _json_or_none(entry.metadata),
    )


async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
    """Fire-and-forget wrapper that never raises."""
    try:
        await log_platform_cost(entry)
    except Exception:
        logger.exception(
            "Failed to log platform cost for user=%s provider=%s block=%s",
            entry.user_id,
            entry.provider,
            entry.block_name,
        )


def _json_or_none(data: dict[str, Any] | None) -> str | None:
    if data is None:
        return None
    return json.dumps(data)


class ProviderCostSummary(BaseModel):
    provider: str
    total_cost_microdollars: int
    total_input_tokens: int
    total_output_tokens: int
    request_count: int


class UserCostSummary(BaseModel):
    user_id: str | None = None
    email: str | None = None
    total_cost_microdollars: int
    total_input_tokens: int
    total_output_tokens: int
    request_count: int


class CostLogRow(BaseModel):
    id: str
    created_at: datetime
    user_id: str | None = None
    email: str | None = None
    graph_exec_id: str | None = None
    node_exec_id: str | None = None
    block_name: str
    provider: str
    cost_microdollars: int | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
    model: str | None = None


class PlatformCostDashboard(BaseModel):
    by_provider: list[ProviderCostSummary]
    by_user: list[UserCostSummary]
    total_cost_microdollars: int
    total_requests: int
    total_users: int


def _build_where(
    start: datetime | None,
    end: datetime | None,
    provider: str | None,
    user_id: str | None,
    table_alias: str = "",
) -> tuple[str, list[Any]]:
    prefix = f"{table_alias}." if table_alias else ""
    clauses: list[str] = []
    params: list[Any] = []
    idx = 1

    if start:
        clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
        params.append(start)
        idx += 1
    if end:
        clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
        params.append(end)
        idx += 1
    if provider:
        clauses.append(f'LOWER({prefix}"provider") = LOWER(${idx})')
        params.append(provider)
        idx += 1
    if user_id:
        clauses.append(f'{prefix}"userId" = ${idx}')
        params.append(user_id)
        idx += 1

    return (" AND ".join(clauses) if clauses else "TRUE", params)


async def get_platform_cost_dashboard(
    start: datetime | None = None,
    end: datetime | None = None,
    provider: str | None = None,
    user_id: str | None = None,
) -> PlatformCostDashboard:
    where_p, params_p = _build_where(start, end, provider, user_id, "p")

    by_provider_rows, user_count_rows, by_user_rows = await asyncio.gather(
        query_raw_with_schema(
            f"""
            SELECT
                p."provider",
                COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
                COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
                COUNT(*)::bigint AS request_count
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_p}
            GROUP BY p."provider"
            ORDER BY total_cost DESC
            """,
            *params_p,
        ),
        query_raw_with_schema(
            f"""
            SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_p}
            """,
            *params_p,
        ),
        query_raw_with_schema(
            f"""
            SELECT
                p."userId" AS user_id,
                u."email",
                COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
                COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
                COUNT(*)::bigint AS request_count
            FROM {{schema_prefix}}"PlatformCostLog" p
            LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
            WHERE {where_p}
            GROUP BY p."userId", u."email"
            ORDER BY total_cost DESC
            LIMIT 100
            """,
            *params_p,
        ),
    )

    total_users = user_count_rows[0]["cnt"] if user_count_rows else 0
    total_cost = sum(r["total_cost"] for r in by_provider_rows)
    total_requests = sum(r["request_count"] for r in by_provider_rows)

    return PlatformCostDashboard(
        by_provider=[
            ProviderCostSummary(
                provider=r["provider"],
                total_cost_microdollars=r["total_cost"],
                total_input_tokens=r["total_input_tokens"],
                total_output_tokens=r["total_output_tokens"],
                request_count=r["request_count"],
            )
            for r in by_provider_rows
        ],
        by_user=[
            UserCostSummary(
                user_id=r.get("user_id"),
                email=r.get("email"),
                total_cost_microdollars=r["total_cost"],
                total_input_tokens=r["total_input_tokens"],
                total_output_tokens=r["total_output_tokens"],
                request_count=r["request_count"],
            )
            for r in by_user_rows
        ],
        total_cost_microdollars=total_cost,
        total_requests=total_requests,
        total_users=total_users,
    )


async def get_platform_cost_logs(
    start: datetime | None = None,
    end: datetime | None = None,
    provider: str | None = None,
    user_id: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[CostLogRow], int]:
    where_sql, params = _build_where(start, end, provider, user_id, "p")

    count_rows = await query_raw_with_schema(
        f"""
        SELECT COUNT(*)::bigint AS cnt
        FROM {{schema_prefix}}"PlatformCostLog" p
        WHERE {where_sql}
        """,
        *params,
    )
    total = count_rows[0]["cnt"] if count_rows else 0

    offset = (page - 1) * page_size
    limit_idx = len(params) + 1
    offset_idx = len(params) + 2
    rows = await query_raw_with_schema(
        f"""
        SELECT
            p."id",
            p."createdAt" AS created_at,
            p."userId" AS user_id,
            u."email",
            p."graphExecId" AS graph_exec_id,
            p."nodeExecId" AS node_exec_id,
            p."blockName" AS block_name,
            p."provider",
            p."costMicrodollars" AS cost_microdollars,
            p."inputTokens" AS input_tokens,
            p."outputTokens" AS output_tokens,
            p."model"
        FROM {{schema_prefix}}"PlatformCostLog" p
        LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
        WHERE {where_sql}
        ORDER BY p."createdAt" DESC, p."id" DESC
        LIMIT ${limit_idx} OFFSET ${offset_idx}
        """,
        *params,
        page_size,
        offset,
    )

    logs = [
        CostLogRow(
            id=r["id"],
            created_at=r["created_at"],
            user_id=r.get("user_id"),
            email=r.get("email"),
            graph_exec_id=r.get("graph_exec_id"),
            node_exec_id=r.get("node_exec_id"),
            block_name=r["block_name"],
            provider=r["provider"],
            cost_microdollars=r.get("cost_microdollars"),
            input_tokens=r.get("input_tokens"),
            output_tokens=r.get("output_tokens"),
            model=r.get("model"),
        )
        for r in rows
    ]
    return logs, total
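Since _build_where is a pure function, its positional-placeholder contract is easy to check without a database. A small demo using only the function above (the timestamp is invented, and the import assumes the repo is on the path):

# Demo of _build_where's $n placeholder numbering; no DB connection needed.
from datetime import datetime
from backend.data.platform_cost import _build_where

where, params = _build_where(
    start=datetime(2025, 1, 1),
    end=None,
    provider="openai",
    user_id=None,
    table_alias="p",
)
assert where == 'p."createdAt" >= $1::timestamptz AND LOWER(p."provider") = LOWER($2)'
assert len(params) == 2  # params are appended in the same order the clauses were added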
Some files were not shown because too many files have changed in this diff.