Compare commits
90 Commits
dev
...
test-scree
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fcbb1613cc | ||
|
|
1fa89d9488 | ||
|
|
3e016508d4 | ||
|
|
b80d7abda9 | ||
|
|
0e310c788a | ||
|
|
91af007c18 | ||
|
|
e7ca81ed89 | ||
|
|
5164fa878f | ||
|
|
cf605ef5a3 | ||
|
|
e7bd05c6f1 | ||
|
|
22fb3549e3 | ||
|
|
1c3fe1444e | ||
|
|
b89321a688 | ||
|
|
630d6d4705 | ||
|
|
7c685c6677 | ||
|
|
bbdf13c7a8 | ||
|
|
e1ea4cf326 | ||
|
|
db6b4444e0 | ||
|
|
9b1175473b | ||
|
|
752a238166 | ||
|
|
2a73d1baa9 | ||
|
|
254e6057f4 | ||
|
|
a616e5a060 | ||
|
|
c9461836c6 | ||
|
|
50a8df3d67 | ||
|
|
3f7a8dc44d | ||
|
|
1c15d6a6cc | ||
|
|
a31be77408 | ||
|
|
1d45f2f18c | ||
|
|
27e34e9514 | ||
|
|
16d696edcc | ||
|
|
f87bbd5966 | ||
|
|
b64d1ed9fa | ||
|
|
3895d95826 | ||
|
|
181208528f | ||
|
|
0365a26c85 | ||
|
|
fb63ae54f0 | ||
|
|
6de79fb73f | ||
|
|
d57da6c078 | ||
|
|
689cd67a13 | ||
|
|
dca89d1586 | ||
|
|
2f63fcd383 | ||
|
|
f04cd08e40 | ||
|
|
44714f1b25 | ||
|
|
78b95f8a76 | ||
|
|
6f0c1dfa11 | ||
|
|
5e595231da | ||
|
|
7b36bed8a5 | ||
|
|
372900c141 | ||
|
|
7afd2b249d | ||
|
|
8d22653810 | ||
|
|
b00e16b438 | ||
|
|
b5acfb7855 | ||
|
|
1ee0bd6619 | ||
|
|
4190f75b0b | ||
|
|
71315aa982 | ||
|
|
960f893295 | ||
|
|
759effab60 | ||
|
|
45b6ada739 | ||
|
|
da544d3411 | ||
|
|
54e5059d7c | ||
|
|
1d7d2f77f3 | ||
|
|
567bc73ec4 | ||
|
|
61ef54af05 | ||
|
|
405403e6b7 | ||
|
|
ab16e63b0a | ||
|
|
45d3193727 | ||
|
|
9a08011d7d | ||
|
|
6fa66ac7da | ||
|
|
4bad08394c | ||
|
|
993c43b623 | ||
|
|
a8a62eeefc | ||
|
|
173614bcc5 | ||
|
|
fbe634fb19 | ||
|
|
a338c72c42 | ||
|
|
7f4398efa3 | ||
|
|
c2a054c511 | ||
|
|
83b00f4789 | ||
|
|
95524e94b3 | ||
|
|
2c517ff9a1 | ||
|
|
7020ae2189 | ||
|
|
b9336984be | ||
|
|
9924dedddc | ||
|
|
c054799b4f | ||
|
|
f3b5d584a3 | ||
|
|
476d9dcf80 | ||
|
|
072b623f8b | ||
|
|
26b0c95936 | ||
|
|
308357de84 | ||
|
|
1a6c50c6cc |
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||
from cachetools import TTLCache
|
||||
from fastapi import APIRouter, Query, Security
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.platform_cost import (
|
||||
CostLogRow,
|
||||
PlatformCostDashboard,
|
||||
get_platform_cost_dashboard,
|
||||
get_platform_cost_logs,
|
||||
)
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)


# Dashboard responses are memoized per unique filter combination for a short
# TTL. Because cost rows are only ever appended, serving a slightly stale
# aggregate to the analytics UI is acceptable.
_DASHBOARD_CACHE_TTL = 30  # seconds
_dashboard_cache: TTLCache[tuple, PlatformCostDashboard] = TTLCache(
    ttl=_DASHBOARD_CACHE_TTL, maxsize=256
)
|
||||
|
||||
|
||||
# Every route registered on this router requires an authenticated admin user.
router = APIRouter(
    dependencies=[Security(requires_admin_user)],
    tags=["platform-cost", "admin"],
    prefix="/platform-costs",
)
|
||||
|
||||
|
||||
class PlatformCostLogsResponse(BaseModel):
    """One page of raw platform cost log rows plus pagination metadata."""

    logs: list[CostLogRow]
    pagination: Pagination
|
||||
|
||||
|
||||
@router.get(
    "/dashboard",
    response_model=PlatformCostDashboard,
    summary="Get Platform Cost Dashboard",
)
async def get_cost_dashboard(
    admin_user_id: str = Security(get_user_id),
    start: datetime | None = Query(None),
    end: datetime | None = Query(None),
    provider: str | None = Query(None),
    user_id: str | None = Query(None),
):
    """Return aggregated platform cost metrics, optionally filtered.

    Responses are memoized in `_dashboard_cache` keyed on the exact
    (start, end, provider, user_id) combination, so repeated identical
    requests within the TTL never hit the data layer.
    """
    logger.info("Admin %s fetching platform cost dashboard", admin_user_id)

    key = (start, end, provider, user_id)
    if (hit := _dashboard_cache.get(key)) is not None:
        return hit

    dashboard = await get_platform_cost_dashboard(
        start=start,
        end=end,
        provider=provider,
        user_id=user_id,
    )
    _dashboard_cache[key] = dashboard
    return dashboard
|
||||
|
||||
|
||||
@router.get(
    "/logs",
    response_model=PlatformCostLogsResponse,
    summary="Get Platform Cost Logs",
)
async def get_cost_logs(
    admin_user_id: str = Security(get_user_id),
    start: datetime | None = Query(None),
    end: datetime | None = Query(None),
    provider: str | None = Query(None),
    user_id: str | None = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    """Return one page of raw cost log rows with pagination metadata."""
    logger.info("Admin %s fetching platform cost logs", admin_user_id)

    rows, total_items = await get_platform_cost_logs(
        start=start,
        end=end,
        provider=provider,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
    # Ceiling division: -(-a // b) == ceil(a / b) for positive b.
    page_count = -(-total_items // page_size)
    pagination = Pagination(
        total_items=total_items,
        total_pages=page_count,
        current_page=page,
        page_size=page_size,
    )
    return PlatformCostLogsResponse(logs=rows, pagination=pagination)
|
||||
@@ -0,0 +1,192 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
from backend.data.platform_cost import PlatformCostDashboard
|
||||
|
||||
from . import platform_cost_routes
|
||||
from .platform_cost_routes import router as platform_cost_router
|
||||
|
||||
# Minimal FastAPI app hosting only the router under test, and a shared
# synchronous test client used by every test in this module.
app = fastapi.FastAPI()
app.include_router(platform_cost_router)

client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    """Install the admin JWT override and reset route caches for every test."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    # Each test must start with a cold dashboard cache; otherwise a previous
    # test's cached response could mask the mocked data layer.
    platform_cost_routes._dashboard_cache.clear()
    yield
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_dashboard_success(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Dashboard endpoint serializes the dashboard model returned by the data layer."""
    empty_dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=0,
        total_requests=0,
        total_users=0,
    )
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        AsyncMock(return_value=empty_dashboard),
    )

    response = client.get("/platform-costs/dashboard")

    assert response.status_code == 200
    body = response.json()
    assert "by_provider" in body
    assert "by_user" in body
    assert body["total_cost_microdollars"] == 0
||||
|
||||
|
||||
def test_get_logs_success(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Logs endpoint returns an empty page when the data layer has no rows."""
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
        AsyncMock(return_value=([], 0)),
    )

    response = client.get("/platform-costs/logs")

    assert response.status_code == 200
    body = response.json()
    assert body["logs"] == []
    assert body["pagination"]["total_items"] == 0
|
||||
|
||||
|
||||
def test_get_dashboard_with_filters(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Query-string filters are forwarded to the data layer as keyword args."""
    empty_dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=0,
        total_requests=0,
        total_users=0,
    )
    dashboard_mock = AsyncMock(return_value=empty_dashboard)
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        dashboard_mock,
    )

    response = client.get(
        "/platform-costs/dashboard",
        params={
            "start": "2026-01-01T00:00:00",
            "end": "2026-04-01T00:00:00",
            "provider": "openai",
            "user_id": "test-user-123",
        },
    )

    assert response.status_code == 200
    dashboard_mock.assert_called_once()
    forwarded = dashboard_mock.call_args.kwargs
    assert forwarded["provider"] == "openai"
    assert forwarded["user_id"] == "test-user-123"
    assert forwarded["start"] is not None
    assert forwarded["end"] is not None
|
||||
|
||||
|
||||
def test_get_logs_with_pagination(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Pagination query params are echoed back in the response metadata."""
    mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_logs",
        AsyncMock(return_value=([], 0)),
    )

    response = client.get(
        "/platform-costs/logs",
        params={"page": 2, "page_size": 25, "provider": "anthropic"},
    )

    assert response.status_code == 200
    pagination = response.json()["pagination"]
    assert pagination["current_page"] == 2
    assert pagination["page_size"] == 25
|
||||
|
||||
|
||||
def test_get_dashboard_requires_admin() -> None:
    """Unauthenticated requests are rejected with 401 on both endpoints.

    The JWT dependency is overridden to raise, simulating a request with no
    (or an invalid) token.
    """
    # NOTE: the redundant function-local `import fastapi` was removed — the
    # module already imports fastapi at top level.
    from fastapi import HTTPException

    def reject_jwt(request: fastapi.Request):
        raise HTTPException(status_code=401, detail="Not authenticated")

    app.dependency_overrides[get_jwt_payload] = reject_jwt
    try:
        response = client.get("/platform-costs/dashboard")
        assert response.status_code == 401
        response = client.get("/platform-costs/logs")
        assert response.status_code == 401
    finally:
        # Always restore overrides so later tests see the fixture's admin auth.
        app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_get_dashboard_rejects_non_admin(mock_jwt_user, mock_jwt_admin) -> None:
    """Non-admin JWT must be rejected with 403 by requires_admin_user."""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
    try:
        for path in ("/platform-costs/dashboard", "/platform-costs/logs"):
            assert client.get(path).status_code == 403
    finally:
        # Put the admin override back so the fixture teardown sees the
        # expected state for any remaining in-test requests.
        app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||
|
||||
|
||||
def test_get_logs_invalid_page_size_too_large() -> None:
    """page_size > 200 must be rejected with 422."""
    resp = client.get("/platform-costs/logs", params={"page_size": 201})
    assert resp.status_code == 422


def test_get_logs_invalid_page_size_zero() -> None:
    """page_size = 0 (below ge=1) must be rejected with 422."""
    resp = client.get("/platform-costs/logs", params={"page_size": 0})
    assert resp.status_code == 422


def test_get_logs_invalid_page_negative() -> None:
    """page < 1 must be rejected with 422."""
    resp = client.get("/platform-costs/logs", params={"page": 0})
    assert resp.status_code == 422


def test_get_dashboard_invalid_date_format() -> None:
    """Malformed start date must be rejected with 422."""
    resp = client.get("/platform-costs/dashboard", params={"start": "not-a-date"})
    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_get_dashboard_cache_hit(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Second identical request returns cached result without calling the DB again."""
    dashboard = PlatformCostDashboard(
        by_provider=[],
        by_user=[],
        total_cost_microdollars=42,
        total_requests=1,
        total_users=1,
    )
    dashboard_mock = mocker.patch(
        "backend.api.features.admin.platform_cost_routes.get_platform_cost_dashboard",
        AsyncMock(return_value=dashboard),
    )

    for _ in range(2):
        client.get("/platform-costs/dashboard")

    # Only the first request should reach the data layer; the second must be
    # served from the TTL cache.
    dashboard_mock.assert_awaited_once()
|
||||
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.platform_cost_routes
|
||||
import backend.api.features.admin.rate_limit_admin_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
@@ -329,6 +330,11 @@ app.include_router(
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/copilot",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.admin.platform_cost_routes.router,
|
||||
tags=["v2", "admin"],
|
||||
prefix="/api/admin",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.executions.review.routes.router,
|
||||
tags=["v2", "executions", "review"],
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.apollo.models import (
|
||||
PrimaryPhone,
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
@@ -218,6 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
) -> BlockOutput:
|
||||
query = SearchOrganizationsRequest(**input_data.model_dump())
|
||||
organizations = await self.search_organizations(query, credentials)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(organizations)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
for organization in organizations:
|
||||
yield "organization", organization
|
||||
yield "organizations", organizations
|
||||
|
||||
@@ -21,7 +21,7 @@ from backend.blocks.apollo.models import (
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -366,4 +366,9 @@ class SearchPeopleBlock(Block):
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(people)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
yield "people", people
|
||||
|
||||
@@ -0,0 +1,712 @@
|
||||
"""Unit tests for merge_stats cost tracking in individual blocks.
|
||||
|
||||
Covers the exa code_context, exa contents, and apollo organization blocks
|
||||
to verify provider cost is correctly extracted and reported.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Reusable fake Exa API-key credentials object passed to block.run(...).
TEST_EXA_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="exa",
    api_key=SecretStr("mock-exa-api-key"),
    title="Mock Exa API key",
    expires_at=None,
)

# Credentials metadata dict in the shape the block Input schema expects;
# mirrors the object above so the two always agree.
TEST_EXA_CREDENTIALS_INPUT = {
    "provider": TEST_EXA_CREDENTIALS.provider,
    "id": TEST_EXA_CREDENTIALS.id,
    "type": TEST_EXA_CREDENTIALS.type,
    "title": TEST_EXA_CREDENTIALS.title,
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaCodeContextBlock — cost_dollars is a string like "0.005"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaCodeContextBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_float_cost(self):
|
||||
"""float(cost_dollars) parsed from API string and passed to merge_stats."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-1",
|
||||
"query": "how to use hooks",
|
||||
"response": "Here are some examples...",
|
||||
"resultsCount": 3,
|
||||
"costDollars": "0.005",
|
||||
"searchTime": 1.2,
|
||||
"outputTokens": 100,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="how to use hooks",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
results = []
|
||||
async for output in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
results.append(output)
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(0.005)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cost_dollars_does_not_raise(self):
|
||||
"""When cost_dollars cannot be parsed as float, merge_stats is not called."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-2",
|
||||
"query": "query",
|
||||
"response": "response",
|
||||
"resultsCount": 0,
|
||||
"costDollars": "N/A",
|
||||
"searchTime": 0.5,
|
||||
"outputTokens": 0,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
merge_calls: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: merge_calls.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="query",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merge_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_is_tracked(self):
|
||||
"""A zero cost_dollars string '0.0' should still be recorded."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-3",
|
||||
"query": "query",
|
||||
"response": "...",
|
||||
"resultsCount": 1,
|
||||
"costDollars": "0.0",
|
||||
"searchTime": 0.1,
|
||||
"outputTokens": 10,
|
||||
}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.code_context.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaCodeContextBlock.Input(
|
||||
query="query",
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaContentsBlock — response.cost_dollars.total (CostDollars model)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaContentsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_cost_dollars_total(self):
|
||||
"""provider_cost equals response.cost_dollars.total when present."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
|
||||
cost_dollars = CostDollars(total=0.012)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_response.context = None
|
||||
mock_response.statuses = None
|
||||
mock_response.cost_dollars = cost_dollars
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.contents.AsyncExa",
|
||||
return_value=MagicMock(
|
||||
get_contents=AsyncMock(return_value=mock_response)
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaContentsBlock.Input(
|
||||
urls=["https://example.com"],
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(0.012)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_stats_when_cost_dollars_absent(self):
|
||||
"""When response.cost_dollars is None, merge_stats is not called."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
|
||||
block = ExaContentsBlock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.results = []
|
||||
mock_response.context = None
|
||||
mock_response.statuses = None
|
||||
mock_response.cost_dollars = None
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.exa.contents.AsyncExa",
|
||||
return_value=MagicMock(
|
||||
get_contents=AsyncMock(return_value=mock_response)
|
||||
),
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = ExaContentsBlock.Input(
|
||||
urls=["https://example.com"],
|
||||
credentials=TEST_EXA_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=TEST_EXA_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert accumulated == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearchOrganizationsBlock — provider_cost = float(len(organizations))
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchOrganizationsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_org_count(self):
|
||||
"""provider_cost == number of returned organizations, type == 'items'."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.models import Organization
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
|
||||
block = SearchOrganizationsBlock()
|
||||
|
||||
fake_orgs = [Organization(id=str(i), name=f"Org{i}") for i in range(3)]
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchOrganizationsBlock,
|
||||
"search_organizations",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_orgs,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchOrganizationsBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
results = []
|
||||
async for output in block.run(
|
||||
input_data,
|
||||
credentials=APOLLO_CREDS,
|
||||
):
|
||||
results.append(output)
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(3.0)
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_org_list_tracks_zero(self):
|
||||
"""An empty organization list results in provider_cost=0.0."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
|
||||
block = SearchOrganizationsBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchOrganizationsBlock,
|
||||
"search_organizations",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchOrganizationsBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=APOLLO_CREDS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JinaEmbeddingBlock — token count from usage.total_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJinaEmbeddingBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_token_count(self):
|
||||
"""provider token count is recorded when API returns usage.total_tokens."""
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
|
||||
block = JinaEmbeddingBlock()
|
||||
|
||||
api_response = {
|
||||
"data": [{"embedding": [0.1, 0.2, 0.3]}],
|
||||
"usage": {"total_tokens": 42},
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.jina.embeddings.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = JinaEmbeddingBlock.Input(
|
||||
texts=["hello world"],
|
||||
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=JINA_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].input_token_count == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_stats_when_usage_absent(self):
|
||||
"""When API response omits usage field, merge_stats is not called."""
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
|
||||
from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
|
||||
block = JinaEmbeddingBlock()
|
||||
|
||||
api_response = {
|
||||
"data": [{"embedding": [0.1, 0.2, 0.3]}],
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.blocks.jina.embeddings.Requests.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_resp,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = JinaEmbeddingBlock.Input(
|
||||
texts=["hello"],
|
||||
credentials=JINA_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=JINA_CREDS):
|
||||
pass
|
||||
|
||||
assert accumulated == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UnrealTextToSpeechBlock — character count from input text length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUnrealTextToSpeechBlockCostTracking:
    """merge_stats should record a character-count cost for TTS synthesis."""

    @pytest.mark.asyncio
    async def test_merge_stats_called_with_character_count(self):
        """provider_cost equals len(text) with type='characters'."""
        from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
        from backend.blocks.text_to_speech_block import (
            TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
        )
        from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock

        block = UnrealTextToSpeechBlock()
        test_text = "Hello, world!"

        with (
            # Stub the external TTS API so no network call happens.
            patch.object(
                UnrealTextToSpeechBlock,
                "call_unreal_speech_api",
                new_callable=AsyncMock,
                return_value={"OutputUri": "https://example.com/audio.mp3"},
            ),
            patch.object(block, "merge_stats") as mock_merge,
        ):
            input_data = UnrealTextToSpeechBlock.Input(
                text=test_text,
                credentials=TTS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            # Drain the block's async output generator; only the recorded
            # stats matter for this test.
            async for _ in block.run(input_data, credentials=TTS_CREDS):
                pass

        mock_merge.assert_called_once()
        stats = mock_merge.call_args[0][0]
        assert stats.provider_cost == float(len(test_text))
        assert stats.provider_cost_type == "characters"

    @pytest.mark.asyncio
    async def test_empty_text_gives_zero_characters(self):
        """An empty text string results in provider_cost=0.0."""
        from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
        from backend.blocks.text_to_speech_block import (
            TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
        )
        from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock

        block = UnrealTextToSpeechBlock()

        with (
            patch.object(
                UnrealTextToSpeechBlock,
                "call_unreal_speech_api",
                new_callable=AsyncMock,
                return_value={"OutputUri": "https://example.com/audio.mp3"},
            ),
            patch.object(block, "merge_stats") as mock_merge,
        ):
            input_data = UnrealTextToSpeechBlock.Input(
                text="",
                credentials=TTS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in block.run(input_data, credentials=TTS_CREDS):
                pass

        # NOTE(review): merge_stats is expected even for empty input —
        # confirm the block records zero-character usage rather than skipping.
        mock_merge.assert_called_once()
        stats = mock_merge.call_args[0][0]
        assert stats.provider_cost == 0.0
        assert stats.provider_cost_type == "characters"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GoogleMapsSearchBlock — item count from search_places results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGoogleMapsSearchBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_place_count(self):
|
||||
"""provider_cost equals number of returned places, type == 'items'."""
|
||||
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
|
||||
from backend.blocks.google_maps import (
|
||||
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.google_maps import GoogleMapsSearchBlock
|
||||
|
||||
block = GoogleMapsSearchBlock()
|
||||
|
||||
fake_places = [{"name": f"Place{i}", "address": f"Addr{i}"} for i in range(4)]
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
GoogleMapsSearchBlock,
|
||||
"search_places",
|
||||
return_value=fake_places,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = GoogleMapsSearchBlock.Input(
|
||||
query="coffee shops",
|
||||
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=MAPS_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 4.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_results_tracks_zero(self):
|
||||
"""Zero places returned results in provider_cost=0.0."""
|
||||
from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
|
||||
from backend.blocks.google_maps import (
|
||||
TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.google_maps import GoogleMapsSearchBlock
|
||||
|
||||
block = GoogleMapsSearchBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
GoogleMapsSearchBlock,
|
||||
"search_places",
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = GoogleMapsSearchBlock.Input(
|
||||
query="nothing here",
|
||||
credentials=MAPS_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=MAPS_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SmartLeadAddLeadsBlock — item count from lead_list length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSmartLeadAddLeadsBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_lead_count(self):
|
||||
"""provider_cost equals number of leads uploaded, type == 'items'."""
|
||||
from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
|
||||
from backend.blocks.smartlead._auth import (
|
||||
TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
|
||||
from backend.blocks.smartlead.models import (
|
||||
AddLeadsToCampaignResponse,
|
||||
LeadInput,
|
||||
)
|
||||
|
||||
block = AddLeadToCampaignBlock()
|
||||
|
||||
fake_leads = [
|
||||
LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
|
||||
LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
|
||||
]
|
||||
fake_response = AddLeadsToCampaignResponse(
|
||||
ok=True,
|
||||
upload_count=2,
|
||||
total_leads=2,
|
||||
block_count=0,
|
||||
duplicate_count=0,
|
||||
invalid_email_count=0,
|
||||
invalid_emails=[],
|
||||
already_added_to_campaign=0,
|
||||
unsubscribed_leads=[],
|
||||
is_lead_limit_exhausted=False,
|
||||
lead_import_stopped_count=0,
|
||||
bounce_count=0,
|
||||
)
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
AddLeadToCampaignBlock,
|
||||
"add_leads_to_campaign",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_response,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = AddLeadToCampaignBlock.Input(
|
||||
campaign_id=123,
|
||||
lead_list=fake_leads,
|
||||
credentials=SL_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=SL_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 2.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearchPeopleBlock — item count from people list length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchPeopleBlockCostTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_stats_called_with_people_count(self):
|
||||
"""provider_cost equals number of returned people, type == 'items'."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
|
||||
block = SearchPeopleBlock()
|
||||
fake_people = [Contact(id=str(i), first_name=f"Person{i}") for i in range(5)]
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchPeopleBlock,
|
||||
"search_people",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_people,
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchPeopleBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == pytest.approx(5.0)
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_people_list_tracks_zero(self):
|
||||
"""An empty people list results in provider_cost=0.0."""
|
||||
from backend.blocks.apollo._auth import TEST_CREDENTIALS as APOLLO_CREDS
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS_INPUT as APOLLO_CREDS_INPUT,
|
||||
)
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
|
||||
block = SearchPeopleBlock()
|
||||
accumulated: list[NodeExecutionStats] = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
SearchPeopleBlock,
|
||||
"search_people",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
block, "merge_stats", side_effect=lambda s: accumulated.append(s)
|
||||
),
|
||||
):
|
||||
input_data = SearchPeopleBlock.Input(
|
||||
credentials=APOLLO_CREDS_INPUT, # type: ignore[arg-type]
|
||||
)
|
||||
async for _ in block.run(input_data, credentials=APOLLO_CREDS):
|
||||
pass
|
||||
|
||||
assert len(accumulated) == 1
|
||||
assert accumulated[0].provider_cost == 0.0
|
||||
assert accumulated[0].provider_cost_type == "items"
|
||||
@@ -9,6 +9,7 @@ from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -116,3 +117,10 @@ class ExaCodeContextBlock(Block):
|
||||
yield "cost_dollars", context.cost_dollars
|
||||
yield "search_time", context.search_time
|
||||
yield "output_tokens", context.output_tokens
|
||||
|
||||
# Parse cost_dollars (API returns as string, e.g. "0.005")
|
||||
try:
|
||||
cost_usd = float(context.cost_dollars)
|
||||
self.merge_stats(NodeExecutionStats(provider_cost=cost_usd))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
from exa_py import AsyncExa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -223,3 +224,6 @@ class ExaContentsBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -0,0 +1,575 @@
|
||||
"""Tests for cost tracking in Exa blocks.
|
||||
|
||||
Covers the cost_dollars → provider_cost → merge_stats path for both
|
||||
ExaContentsBlock and ExaCodeContextBlock.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
|
||||
class TestExaCodeContextCostTracking:
|
||||
"""ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_cost_string_is_parsed_and_merged(self):
|
||||
"""A numeric cost string like '0.005' is merged as provider_cost."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-1",
|
||||
"query": "test query",
|
||||
"response": "some code",
|
||||
"resultsCount": 3,
|
||||
"costDollars": "0.005",
|
||||
"searchTime": 1.2,
|
||||
"outputTokens": 100,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
outputs = []
|
||||
async for key, value in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((key, value))
|
||||
|
||||
assert any(k == "cost_dollars" for k, _ in outputs)
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.005)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cost_string_does_not_raise(self):
|
||||
"""A non-numeric cost_dollars value is swallowed silently."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-2",
|
||||
"query": "test",
|
||||
"response": "code",
|
||||
"resultsCount": 0,
|
||||
"costDollars": "N/A",
|
||||
"searchTime": 0.5,
|
||||
"outputTokens": 0,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
outputs = []
|
||||
async for key, value in block.run(
|
||||
block.Input(query="test", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
outputs.append((key, value))
|
||||
|
||||
# No merge_stats call because float() raised ValueError
|
||||
assert len(merged) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_string_is_merged(self):
|
||||
"""'0.0' is a valid cost — should still be tracked."""
|
||||
from backend.blocks.exa.code_context import ExaCodeContextBlock
|
||||
|
||||
block = ExaCodeContextBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
api_response = {
|
||||
"requestId": "req-3",
|
||||
"query": "free query",
|
||||
"response": "result",
|
||||
"resultsCount": 1,
|
||||
"costDollars": "0.0",
|
||||
"searchTime": 0.1,
|
||||
"outputTokens": 10,
|
||||
}
|
||||
|
||||
with patch("backend.blocks.exa.code_context.Requests") as mock_requests_cls:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = api_response
|
||||
mock_requests_cls.return_value.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="free query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestExaContentsCostTracking:
|
||||
"""ExaContentsBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.012)
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.012)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cost_dollars_is_merged(self):
|
||||
"""A total of 0.0 (free tier) should still be merged."""
|
||||
from backend.blocks.exa.contents import ExaContentsBlock
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
|
||||
block = ExaContentsBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.statuses = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.0)
|
||||
|
||||
with patch("backend.blocks.exa.contents.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.get_contents = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestExaSearchCostTracking:
|
||||
"""ExaSearchBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
from backend.blocks.exa.search import ExaSearchBlock
|
||||
|
||||
block = ExaSearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.resolved_search_type = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.008)
|
||||
|
||||
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.008)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.search import ExaSearchBlock
|
||||
|
||||
block = ExaSearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.resolved_search_type = None
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
|
||||
class TestExaSimilarCostTracking:
|
||||
"""ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
from backend.blocks.exa.similar import ExaFindSimilarBlock
|
||||
|
||||
block = ExaFindSimilarBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.request_id = "req-1"
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.015)
|
||||
|
||||
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.015)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.similar import ExaFindSimilarBlock
|
||||
|
||||
block = ExaFindSimilarBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.request_id = "req-2"
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaCreateResearchBlock — cost_dollars from completed poll response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
COMPLETED_RESEARCH_RESPONSE = {
|
||||
"researchId": "test-research-id",
|
||||
"status": "completed",
|
||||
"model": "exa-research",
|
||||
"instructions": "test instructions",
|
||||
"createdAt": 1700000000000,
|
||||
"finishedAt": 1700000060000,
|
||||
"costDollars": {
|
||||
"total": 0.05,
|
||||
"numSearches": 3,
|
||||
"numPages": 10,
|
||||
"reasoningTokens": 500,
|
||||
},
|
||||
"output": {"content": "Research findings...", "parsed": None},
|
||||
}
|
||||
|
||||
PENDING_RESEARCH_RESPONSE = {
|
||||
"researchId": "test-research-id",
|
||||
"status": "pending",
|
||||
"model": "exa-research",
|
||||
"instructions": "test instructions",
|
||||
"createdAt": 1700000000000,
|
||||
}
|
||||
|
||||
|
||||
class TestExaCreateResearchBlockCostTracking:
|
||||
"""ExaCreateResearchBlock merges cost from completed poll response."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_when_research_completes(self):
|
||||
"""merge_stats called with provider_cost=total when poll returns completed."""
|
||||
from backend.blocks.exa.research import ExaCreateResearchBlock
|
||||
|
||||
block = ExaCreateResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.post = AsyncMock(return_value=create_resp)
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
instructions="test instructions",
|
||||
wait_for_completion=True,
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When completed response has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaCreateResearchBlock
|
||||
|
||||
block = ExaCreateResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
create_resp = MagicMock()
|
||||
create_resp.json.return_value = PENDING_RESEARCH_RESPONSE
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.post = AsyncMock(return_value=create_resp)
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
instructions="test instructions",
|
||||
wait_for_completion=True,
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaGetResearchBlock — cost_dollars from single GET response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaGetResearchBlockCostTracking:
|
||||
"""ExaGetResearchBlock merges cost when the fetched research has cost_dollars."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_from_completed_research(self):
|
||||
"""merge_stats called with provider_cost=total when research has costDollars."""
|
||||
from backend.blocks.exa.research import ExaGetResearchBlock
|
||||
|
||||
block = ExaGetResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
get_resp = MagicMock()
|
||||
get_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=get_resp)
|
||||
|
||||
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When research has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaGetResearchBlock
|
||||
|
||||
block = ExaGetResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
get_resp = MagicMock()
|
||||
get_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=get_resp)
|
||||
|
||||
with patch("backend.blocks.exa.research.Requests", return_value=mock_instance):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExaWaitForResearchBlock — cost_dollars from polling response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExaWaitForResearchBlockCostTracking:
|
||||
"""ExaWaitForResearchBlock merges cost when the polled research has cost_dollars."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_merged_when_research_completes(self):
|
||||
"""merge_stats called with provider_cost=total once polling returns completed."""
|
||||
from backend.blocks.exa.research import ExaWaitForResearchBlock
|
||||
|
||||
block = ExaWaitForResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = COMPLETED_RESEARCH_RESPONSE
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.05)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_merge_when_no_cost_dollars(self):
|
||||
"""When completed research has no costDollars, merge_stats is not called."""
|
||||
from backend.blocks.exa.research import ExaWaitForResearchBlock
|
||||
|
||||
block = ExaWaitForResearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
no_cost_response = {**COMPLETED_RESEARCH_RESPONSE, "costDollars": None}
|
||||
poll_resp = MagicMock()
|
||||
poll_resp.json.return_value = no_cost_response
|
||||
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.get = AsyncMock(return_value=poll_resp)
|
||||
|
||||
with (
|
||||
patch("backend.blocks.exa.research.Requests", return_value=mock_instance),
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
):
|
||||
async for _ in block.run(
|
||||
block.Input(
|
||||
research_id="test-research-id",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore[arg-type]
|
||||
),
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert merged == []
|
||||
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -232,6 +233,11 @@ class ExaCreateResearchBlock(Block):
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=research.cost_dollars.total
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await asyncio.sleep(check_interval)
|
||||
@@ -346,6 +352,9 @@ class ExaGetResearchBlock(Block):
|
||||
yield "cost_searches", research.cost_dollars.num_searches
|
||||
yield "cost_pages", research.cost_dollars.num_pages
|
||||
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=research.cost_dollars.total)
|
||||
)
|
||||
|
||||
yield "error_message", research.error
|
||||
|
||||
@@ -432,6 +441,9 @@ class ExaWaitForResearchBlock(Block):
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=research.cost_dollars.total)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -206,3 +207,6 @@ class ExaSearchBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -167,3 +168,6 @@ class ExaFindSimilarBlock(Block):
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(provider_cost=response.cost_dollars.total)
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -117,6 +118,11 @@ class GoogleMapsSearchBlock(Block):
|
||||
input_data.radius,
|
||||
input_data.max_results,
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(places)), provider_cost_type="items"
|
||||
)
|
||||
)
|
||||
for place in places:
|
||||
yield "place", place
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from backend.blocks.jina._auth import (
|
||||
JinaCredentialsField,
|
||||
JinaCredentialsInput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
@@ -45,5 +45,13 @@ class JinaEmbeddingBlock(Block):
|
||||
}
|
||||
data = {"input": input_data.texts, "model": input_data.model}
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
embeddings = [e["embedding"] for e in response.json()["data"]]
|
||||
resp_json = response.json()
|
||||
embeddings = [e["embedding"] for e in resp_json["data"]]
|
||||
usage = resp_json.get("usage", {})
|
||||
if usage.get("total_tokens"):
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=usage.get("total_tokens", 0),
|
||||
)
|
||||
)
|
||||
yield "embeddings", embeddings
|
||||
|
||||
@@ -13,6 +13,7 @@ import ollama
|
||||
import openai
|
||||
from anthropic.types import ToolParam
|
||||
from groq import AsyncGroq
|
||||
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -737,6 +738,7 @@ class LLMResponse(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
reasoning: Optional[str] = None
|
||||
provider_cost: float | None = None
|
||||
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
@@ -771,6 +773,32 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
|
||||
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
|
||||
|
||||
OpenRouter returns the per-request USD cost in a response header. The
|
||||
OpenAI SDK exposes the raw httpx response via an undocumented `_response`
|
||||
attribute. We use try/except AttributeError so that if the SDK ever drops
|
||||
or renames that attribute, the warning is visible in logs rather than
|
||||
silently degrading to no cost tracking.
|
||||
"""
|
||||
try:
|
||||
raw_resp = response._response # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"OpenAI SDK response missing _response attribute"
|
||||
" — OpenRouter cost tracking unavailable"
|
||||
)
|
||||
return None
|
||||
try:
|
||||
cost_header = raw_resp.headers.get("x-total-cost")
|
||||
if not cost_header:
|
||||
return None
|
||||
return float(cost_header)
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
@@ -1103,6 +1131,7 @@ async def llm_call(
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
provider_cost=extract_openrouter_cost(response),
|
||||
)
|
||||
elif provider == "llama_api":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -1410,6 +1439,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
error_feedback_message = ""
|
||||
llm_model = input_data.model
|
||||
last_attempt_cost: float | None = None
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
@@ -1427,12 +1457,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
response_text = llm_response.response
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
# Merge token counts for every attempt (each call costs tokens).
|
||||
# provider_cost (actual USD) is tracked separately and only merged
|
||||
# on success to avoid double-counting across retries.
|
||||
token_stats = NodeExecutionStats(
|
||||
input_token_count=llm_response.prompt_tokens,
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
self.merge_stats(token_stats)
|
||||
last_attempt_cost = llm_response.provider_cost
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
@@ -1501,6 +1534,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
provider_cost=last_attempt_cost,
|
||||
)
|
||||
)
|
||||
yield "response", response_obj
|
||||
@@ -1521,6 +1555,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
llm_retry_count=retry_count,
|
||||
provider_cost=last_attempt_cost,
|
||||
)
|
||||
)
|
||||
yield "response", {"response": response_text}
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.blocks.smartlead.models import (
|
||||
SaveSequencesResponse,
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
@@ -226,6 +226,12 @@ class AddLeadToCampaignBlock(Block):
|
||||
response = await self.add_leads_to_campaign(
|
||||
input_data.campaign_id, input_data.lead_list, credentials
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(input_data.lead_list)),
|
||||
provider_cost_type="items",
|
||||
)
|
||||
)
|
||||
|
||||
yield "campaign_id", input_data.campaign_id
|
||||
yield "upload_count", response.upload_count
|
||||
|
||||
@@ -199,6 +199,66 @@ class TestLLMStatsTracking:
|
||||
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
|
||||
assert block.execution_stats.llm_retry_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_cost_uses_last_attempt_only(self):
|
||||
"""provider_cost is only merged from the final successful attempt.
|
||||
|
||||
Intermediate retry costs are intentionally dropped to avoid
|
||||
double-counting: the cost of failed attempts is captured in
|
||||
last_attempt_cost only when the loop eventually succeeds.
|
||||
"""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# First attempt: fails validation, returns cost $0.01
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"wrong": "key"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
provider_cost=0.01,
|
||||
)
|
||||
# Second attempt: succeeds, returns cost $0.02
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=20,
|
||||
completion_tokens=10,
|
||||
reasoning=None,
|
||||
provider_cost=0.02,
|
||||
)
|
||||
|
||||
block.llm_call = mock_llm_call # type: ignore
|
||||
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
# Only the final successful attempt's cost is merged
|
||||
assert block.execution_stats.provider_cost == pytest.approx(0.02)
|
||||
# Tokens from both attempts accumulate
|
||||
assert block.execution_stats.input_token_count == 30
|
||||
assert block.execution_stats.output_token_count == 15
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_text_summarizer_multiple_chunks(self):
|
||||
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
|
||||
@@ -987,3 +1047,51 @@ class TestLlmModelMissing:
|
||||
assert (
|
||||
llm.LlmModel("extra/google/gemini-2.5-pro") == llm.LlmModel.GEMINI_2_5_PRO
|
||||
)
|
||||
|
||||
|
||||
class TestExtractOpenRouterCost:
|
||||
"""Tests for extract_openrouter_cost — the x-total-cost header parser."""
|
||||
|
||||
def _mk_response(self, headers: dict | None):
|
||||
response = MagicMock()
|
||||
if headers is None:
|
||||
response._response = None
|
||||
else:
|
||||
raw = MagicMock()
|
||||
raw.headers = headers
|
||||
response._response = raw
|
||||
return response
|
||||
|
||||
def test_extracts_numeric_cost(self):
|
||||
response = self._mk_response({"x-total-cost": "0.0042"})
|
||||
assert llm.extract_openrouter_cost(response) == 0.0042
|
||||
|
||||
def test_returns_none_when_header_missing(self):
|
||||
response = self._mk_response({})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_header_empty_string(self):
|
||||
response = self._mk_response({"x-total-cost": ""})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_header_non_numeric(self):
|
||||
response = self._mk_response({"x-total-cost": "not-a-number"})
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_no_response_attr(self):
|
||||
response = MagicMock(spec=[]) # no _response attr
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_raw_is_none(self):
|
||||
response = self._mk_response(None)
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_none_when_raw_has_no_headers(self):
|
||||
response = MagicMock()
|
||||
response._response = MagicMock(spec=[]) # no headers attr
|
||||
assert llm.extract_openrouter_cost(response) is None
|
||||
|
||||
def test_returns_zero_for_zero_cost(self):
|
||||
"""Zero-cost is a valid value (free tier) and must not become None."""
|
||||
response = self._mk_response({"x-total-cost": "0"})
|
||||
assert llm.extract_openrouter_cost(response) == 0.0
|
||||
|
||||
@@ -13,6 +13,7 @@ from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -104,4 +105,10 @@ class UnrealTextToSpeechBlock(Block):
|
||||
input_data.text,
|
||||
input_data.voice_id,
|
||||
)
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
provider_cost=float(len(input_data.text)),
|
||||
provider_cost_type="characters",
|
||||
)
|
||||
)
|
||||
yield "mp3_url", api_response["OutputUri"]
|
||||
|
||||
@@ -9,6 +9,7 @@ shared tool registry as the SDK path.
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -22,6 +23,7 @@ from typing import TYPE_CHECKING, Any, cast
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
||||
from opentelemetry import trace as otel_trace
|
||||
|
||||
from backend.copilot.config import CopilotMode
|
||||
from backend.copilot.context import get_workspace_manager, set_execution_context
|
||||
@@ -334,6 +336,7 @@ class _BaselineStreamState:
|
||||
text_started: bool = False
|
||||
turn_prompt_tokens: int = 0
|
||||
turn_completion_tokens: int = 0
|
||||
cost_usd: float | None = None
|
||||
thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
|
||||
session_messages: list[ChatMessage] = field(default_factory=list)
|
||||
|
||||
@@ -354,6 +357,7 @@ async def _baseline_llm_caller(
|
||||
state.thinking_stripper = _ThinkingStripper()
|
||||
|
||||
round_text = ""
|
||||
response = None # initialized before try so finally block can access it
|
||||
try:
|
||||
client = _get_openai_client()
|
||||
typed_messages = cast(list[ChatCompletionMessageParam], messages)
|
||||
@@ -430,6 +434,18 @@ async def _baseline_llm_caller(
|
||||
state.text_started = False
|
||||
state.text_block_id = str(uuid.uuid4())
|
||||
finally:
|
||||
# Extract OpenRouter cost from response headers (in finally so we
|
||||
# capture cost even when the stream errors mid-way — we already paid).
|
||||
# Accumulate across multi-round tool-calling turns.
|
||||
try:
|
||||
cost_header = response._response.headers.get("x-total-cost") # type: ignore[attr-defined]
|
||||
if cost_header:
|
||||
cost = float(cost_header)
|
||||
if math.isfinite(cost):
|
||||
state.cost_usd = (state.cost_usd or 0.0) + max(0.0, cost)
|
||||
except (AttributeError, ValueError):
|
||||
pass
|
||||
|
||||
# Always persist partial text so the session history stays consistent,
|
||||
# even when the stream is interrupted by an exception.
|
||||
state.assistant_text += round_text
|
||||
@@ -1183,8 +1199,22 @@ async def stream_chat_completion_baseline(
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Close Langfuse trace context
|
||||
# Set cost attributes on OTEL span before closing
|
||||
if _trace_ctx is not None:
|
||||
try:
|
||||
span = otel_trace.get_current_span()
|
||||
if span and span.is_recording():
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.prompt_tokens", state.turn_prompt_tokens
|
||||
)
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.completion_tokens",
|
||||
state.turn_completion_tokens,
|
||||
)
|
||||
if state.cost_usd is not None:
|
||||
span.set_attribute("gen_ai.usage.cost_usd", state.cost_usd)
|
||||
except Exception:
|
||||
logger.debug("[Baseline] Failed to set OTEL cost attributes")
|
||||
try:
|
||||
_trace_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
@@ -1226,6 +1256,8 @@ async def stream_chat_completion_baseline(
|
||||
prompt_tokens=state.turn_prompt_tokens,
|
||||
completion_tokens=state.turn_completion_tokens,
|
||||
log_prefix="[Baseline]",
|
||||
cost_usd=state.cost_usd,
|
||||
model=active_model,
|
||||
)
|
||||
|
||||
# Persist structured tool-call history (assistant + tool messages)
|
||||
|
||||
@@ -4,7 +4,7 @@ These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState`
|
||||
without requiring API keys, database connections, or network access.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
@@ -631,3 +631,169 @@ class TestPrepareBaselineAttachments:
|
||||
|
||||
assert hint == ""
|
||||
assert blocks == []
|
||||
|
||||
|
||||
class TestBaselineCostExtraction:
|
||||
"""Tests for x-total-cost header extraction in _baseline_llm_caller."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_usd_extracted_from_response_header(self):
|
||||
"""state.cost_usd is set from x-total-cost header when present."""
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_llm_caller,
|
||||
_BaselineStreamState,
|
||||
)
|
||||
|
||||
state = _BaselineStreamState(model="gpt-4o-mini")
|
||||
|
||||
# Build a mock raw httpx response with the cost header
|
||||
mock_raw_response = MagicMock()
|
||||
mock_raw_response.headers = {"x-total-cost": "0.0123"}
|
||||
|
||||
# Build a mock async streaming response that yields no chunks but has
|
||||
# a _response attribute pointing to the mock httpx response
|
||||
mock_stream_response = MagicMock()
|
||||
mock_stream_response._response = mock_raw_response
|
||||
|
||||
async def empty_aiter():
|
||||
return
|
||||
yield # make it an async generator
|
||||
|
||||
mock_stream_response.__aiter__ = lambda self: empty_aiter()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=mock_stream_response
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _baseline_llm_caller(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert state.cost_usd == pytest.approx(0.0123)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_usd_accumulates_across_calls(self):
|
||||
"""cost_usd accumulates when _baseline_llm_caller is called multiple times."""
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_llm_caller,
|
||||
_BaselineStreamState,
|
||||
)
|
||||
|
||||
state = _BaselineStreamState(model="gpt-4o-mini")
|
||||
|
||||
def make_stream_mock(cost: str) -> MagicMock:
|
||||
mock_raw = MagicMock()
|
||||
mock_raw.headers = {"x-total-cost": cost}
|
||||
mock_stream = MagicMock()
|
||||
mock_stream._response = mock_raw
|
||||
|
||||
async def empty_aiter():
|
||||
return
|
||||
yield
|
||||
|
||||
mock_stream.__aiter__ = lambda self: empty_aiter()
|
||||
return mock_stream
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")]
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _baseline_llm_caller(
|
||||
messages=[{"role": "user", "content": "first"}],
|
||||
tools=[],
|
||||
state=state,
|
||||
)
|
||||
await _baseline_llm_caller(
|
||||
messages=[{"role": "user", "content": "second"}],
|
||||
tools=[],
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert state.cost_usd == pytest.approx(0.03)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_when_header_absent(self):
|
||||
"""state.cost_usd remains None when response has no x-total-cost header."""
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_llm_caller,
|
||||
_BaselineStreamState,
|
||||
)
|
||||
|
||||
state = _BaselineStreamState(model="gpt-4o-mini")
|
||||
|
||||
mock_raw = MagicMock()
|
||||
mock_raw.headers = {}
|
||||
mock_stream = MagicMock()
|
||||
mock_stream._response = mock_raw
|
||||
|
||||
async def empty_aiter():
|
||||
return
|
||||
yield
|
||||
|
||||
mock_stream.__aiter__ = lambda self: empty_aiter()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _baseline_llm_caller(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert state.cost_usd is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_extracted_even_when_stream_raises(self):
|
||||
"""cost_usd is captured in the finally block even when streaming fails."""
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_llm_caller,
|
||||
_BaselineStreamState,
|
||||
)
|
||||
|
||||
state = _BaselineStreamState(model="gpt-4o-mini")
|
||||
|
||||
mock_raw = MagicMock()
|
||||
mock_raw.headers = {"x-total-cost": "0.005"}
|
||||
mock_stream = MagicMock()
|
||||
mock_stream._response = mock_raw
|
||||
|
||||
async def failing_aiter():
|
||||
raise RuntimeError("stream error")
|
||||
yield # make it an async generator
|
||||
|
||||
mock_stream.__aiter__ = lambda self: failing_aiter()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_stream)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
pytest.raises(RuntimeError, match="stream error"),
|
||||
):
|
||||
await _baseline_llm_caller(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert state.cost_usd == pytest.approx(0.005)
|
||||
|
||||
@@ -151,8 +151,8 @@ class CoPilotProcessor:
|
||||
This method is called once per worker thread to set up the async event
|
||||
loop and initialize any required resources.
|
||||
|
||||
Database is accessed only through DatabaseManager, so we don't need to connect
|
||||
to Prisma directly.
|
||||
DB operations route through DatabaseManagerAsyncClient (RPC) via the
|
||||
db_accessors pattern — no direct Prisma connection is needed here.
|
||||
"""
|
||||
configure_logging()
|
||||
set_service_name("CoPilotExecutor")
|
||||
|
||||
@@ -15,6 +15,7 @@ from prisma.models import User as PrismaUser
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from backend.data.db_accessors import user_db
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.cache import cached
|
||||
|
||||
@@ -409,9 +410,12 @@ async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
|
||||
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
|
||||
cached and then persists after the user is created with a higher tier.
|
||||
"""
|
||||
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
|
||||
if user and user.subscriptionTier: # type: ignore[reportAttributeAccessIssue]
|
||||
return SubscriptionTier(user.subscriptionTier) # type: ignore[reportAttributeAccessIssue]
|
||||
try:
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
except Exception:
|
||||
raise _UserNotFoundError(user_id)
|
||||
if user.subscription_tier:
|
||||
return SubscriptionTier(user.subscription_tier)
|
||||
raise _UserNotFoundError(user_id)
|
||||
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from claude_agent_sdk import (
|
||||
)
|
||||
from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from opentelemetry import trace as otel_trace
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
@@ -2372,8 +2373,26 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
raise
|
||||
finally:
|
||||
# --- Close OTEL context ---
|
||||
# --- Close OTEL context (with cost attributes) ---
|
||||
if _otel_ctx is not None:
|
||||
try:
|
||||
span = otel_trace.get_current_span()
|
||||
if span and span.is_recording():
|
||||
span.set_attribute("gen_ai.usage.prompt_tokens", turn_prompt_tokens)
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.completion_tokens", turn_completion_tokens
|
||||
)
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.cache_read_tokens", turn_cache_read_tokens
|
||||
)
|
||||
span.set_attribute(
|
||||
"gen_ai.usage.cache_creation_tokens",
|
||||
turn_cache_creation_tokens,
|
||||
)
|
||||
if turn_cost_usd is not None:
|
||||
span.set_attribute("gen_ai.usage.cost_usd", turn_cost_usd)
|
||||
except Exception:
|
||||
logger.debug("Failed to set OTEL cost attributes", exc_info=True)
|
||||
try:
|
||||
_otel_ctx.__exit__(*sys.exc_info())
|
||||
except Exception:
|
||||
@@ -2391,6 +2410,8 @@ async def stream_chat_completion_sdk(
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
log_prefix=log_prefix,
|
||||
cost_usd=turn_cost_usd,
|
||||
model=config.model,
|
||||
provider="anthropic",
|
||||
)
|
||||
|
||||
# --- Persist session messages ---
|
||||
|
||||
@@ -4,17 +4,85 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
|
||||
1. Append a ``Usage`` record to the session.
|
||||
2. Log the turn's token counts.
|
||||
3. Record weighted usage in Redis for rate-limiting.
|
||||
4. Write a PlatformCostLog entry for admin cost tracking.
|
||||
|
||||
This module extracts that common logic so both paths stay in sync.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import threading
|
||||
|
||||
from backend.data.db_accessors import platform_cost_db
|
||||
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
|
||||
|
||||
from .model import ChatSession, Usage
|
||||
from .rate_limit import record_token_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hold strong references to in-flight cost log tasks to prevent GC.
|
||||
_pending_log_tasks: set[asyncio.Task[None]] = set()
|
||||
# Guards all reads and writes to _pending_log_tasks. Done callbacks (discard)
|
||||
# fire from the event loop thread; drain_pending_cost_logs iterates the set
|
||||
# from any caller — the lock prevents RuntimeError from concurrent modification.
|
||||
_pending_log_tasks_lock = threading.Lock()
|
||||
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
|
||||
# shared across event loops running in different threads.
|
||||
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
|
||||
|
||||
|
||||
def _get_log_semaphore() -> asyncio.Semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
sem = _log_semaphores.get(loop)
|
||||
if sem is None:
|
||||
sem = asyncio.Semaphore(50)
|
||||
_log_semaphores[loop] = sem
|
||||
return sem
|
||||
|
||||
|
||||
def _schedule_cost_log(entry: PlatformCostEntry) -> None:
|
||||
"""Schedule a fire-and-forget cost log via DatabaseManagerAsyncClient RPC."""
|
||||
|
||||
async def _safe_log() -> None:
|
||||
async with _get_log_semaphore():
|
||||
try:
|
||||
await platform_cost_db().log_platform_cost(entry)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to log platform cost for user=%s provider=%s block=%s",
|
||||
entry.user_id,
|
||||
entry.provider,
|
||||
entry.block_name,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(_safe_log())
|
||||
with _pending_log_tasks_lock:
|
||||
_pending_log_tasks.add(task)
|
||||
|
||||
def _remove(t: asyncio.Task[None]) -> None:
|
||||
with _pending_log_tasks_lock:
|
||||
_pending_log_tasks.discard(t)
|
||||
|
||||
task.add_done_callback(_remove)
|
||||
|
||||
|
||||
# Identifiers used by PlatformCostLog for copilot turns (not tied to a real
|
||||
# block/credential in the block_cost_config or credentials_store tables).
|
||||
COPILOT_BLOCK_ID = "copilot"
|
||||
COPILOT_CREDENTIAL_ID = "copilot_system"
|
||||
|
||||
|
||||
def _copilot_block_name(log_prefix: str) -> str:
|
||||
"""Extract stable block_name from ``"[SDK][session][T1]"`` -> ``"copilot:SDK"``."""
|
||||
match = re.search(r"\[([A-Za-z][A-Za-z0-9_]*)\]", log_prefix)
|
||||
if match:
|
||||
return f"{COPILOT_BLOCK_ID}:{match.group(1)}"
|
||||
tag = log_prefix.strip(" []")
|
||||
return f"{COPILOT_BLOCK_ID}:{tag}" if tag else COPILOT_BLOCK_ID
|
||||
|
||||
|
||||
async def persist_and_record_usage(
|
||||
*,
|
||||
@@ -26,6 +94,8 @@ async def persist_and_record_usage(
|
||||
cache_creation_tokens: int = 0,
|
||||
log_prefix: str = "",
|
||||
cost_usd: float | str | None = None,
|
||||
model: str | None = None,
|
||||
provider: str = "open_router",
|
||||
) -> int:
|
||||
"""Persist token usage to session and record for rate limiting.
|
||||
|
||||
@@ -38,6 +108,7 @@ async def persist_and_record_usage(
|
||||
cache_creation_tokens: Tokens written to prompt cache (Anthropic only).
|
||||
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
|
||||
cost_usd: Optional cost for logging (float from SDK, str otherwise).
|
||||
provider: Cost provider name (e.g. "anthropic", "open_router").
|
||||
|
||||
Returns:
|
||||
The computed total_tokens (prompt + completion; cache excluded).
|
||||
@@ -47,12 +118,13 @@ async def persist_and_record_usage(
|
||||
cache_read_tokens = max(0, cache_read_tokens)
|
||||
cache_creation_tokens = max(0, cache_creation_tokens)
|
||||
|
||||
if (
|
||||
no_tokens = (
|
||||
prompt_tokens <= 0
|
||||
and completion_tokens <= 0
|
||||
and cache_read_tokens <= 0
|
||||
and cache_creation_tokens <= 0
|
||||
):
|
||||
)
|
||||
if no_tokens and cost_usd is None:
|
||||
return 0
|
||||
|
||||
# total_tokens = prompt + completion. Cache tokens are tracked
|
||||
@@ -73,14 +145,14 @@ async def persist_and_record_usage(
|
||||
|
||||
if cache_read_tokens or cache_creation_tokens:
|
||||
logger.info(
|
||||
f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
|
||||
f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
|
||||
f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
|
||||
f"{log_prefix} Turn usage: uncached={prompt_tokens}, cache_read={cache_read_tokens},"
|
||||
f" cache_create={cache_creation_tokens}, output={completion_tokens},"
|
||||
f" total={total_tokens}, cost_usd={cost_usd}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
|
||||
f"completion={completion_tokens}, total={total_tokens}"
|
||||
f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens},"
|
||||
f" total={total_tokens}"
|
||||
)
|
||||
|
||||
if user_id:
|
||||
@@ -93,6 +165,54 @@ async def persist_and_record_usage(
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
|
||||
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
|
||||
|
||||
# Log to PlatformCostLog for admin cost dashboard.
|
||||
# Include entries where cost_usd is set even if token count is 0
|
||||
# (e.g. fully-cached Anthropic responses where only cache tokens
|
||||
# accumulate a charge without incrementing total_tokens).
|
||||
if user_id and (total_tokens > 0 or cost_usd is not None):
|
||||
cost_float = None
|
||||
if cost_usd is not None:
|
||||
try:
|
||||
val = float(cost_usd)
|
||||
if math.isfinite(val) and val >= 0:
|
||||
cost_float = val
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
cost_microdollars = usd_to_microdollars(cost_float)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if cost_float is not None:
|
||||
tracking_type = "cost_usd"
|
||||
tracking_amount = cost_float
|
||||
else:
|
||||
tracking_type = "tokens"
|
||||
tracking_amount = total_tokens
|
||||
|
||||
_schedule_cost_log(
|
||||
PlatformCostEntry(
|
||||
user_id=user_id,
|
||||
graph_exec_id=session_id,
|
||||
block_id=COPILOT_BLOCK_ID,
|
||||
block_name=_copilot_block_name(log_prefix),
|
||||
provider=provider,
|
||||
credential_id=COPILOT_CREDENTIAL_ID,
|
||||
cost_microdollars=cost_microdollars,
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
model=model,
|
||||
tracking_type=tracking_type,
|
||||
tracking_amount=tracking_amount,
|
||||
metadata={
|
||||
"tracking_type": tracking_type,
|
||||
"tracking_amount": tracking_amount,
|
||||
"cache_read_tokens": cache_read_tokens,
|
||||
"cache_creation_tokens": cache_creation_tokens,
|
||||
"source": "copilot",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return total_tokens
|
||||
|
||||
@@ -4,6 +4,7 @@ Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
|
||||
calling conventions, session persistence, and rate-limit recording.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@@ -279,3 +280,260 @@ class TestRateLimitRecording:
|
||||
completion_tokens=0,
|
||||
)
|
||||
mock_record.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PlatformCostLog integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPlatformCostLogging:
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_cost_entry_with_cost_usd(self):
|
||||
"""When cost_usd is provided, tracking_type should be 'cost_usd'."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=_make_session(),
|
||||
user_id="user-cost",
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
cost_usd=0.005,
|
||||
model="gpt-4",
|
||||
provider="anthropic",
|
||||
log_prefix="[SDK]",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
mock_log.assert_awaited_once()
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.user_id == "user-cost"
|
||||
assert entry.provider == "anthropic"
|
||||
assert entry.model == "gpt-4"
|
||||
assert entry.cost_microdollars == 5000
|
||||
assert entry.input_tokens == 200
|
||||
assert entry.output_tokens == 100
|
||||
assert entry.tracking_type == "cost_usd"
|
||||
assert entry.metadata["tracking_type"] == "cost_usd"
|
||||
assert entry.metadata["tracking_amount"] == 0.005
|
||||
assert entry.block_name == "copilot:SDK"
|
||||
assert entry.graph_exec_id == "sess-test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_cost_entry_without_cost_usd(self):
|
||||
"""When cost_usd is None, tracking_type should be 'tokens'."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-tokens",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
mock_log.assert_awaited_once()
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.cost_microdollars is None
|
||||
assert entry.tracking_type == "tokens"
|
||||
assert entry.metadata["tracking_type"] == "tokens"
|
||||
assert entry.metadata["tracking_amount"] == 150
|
||||
assert entry.graph_exec_id is None
|
||||
assert entry.block_name == "copilot:Baseline"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_cost_log_when_no_user_id(self):
|
||||
"""No PlatformCostLog entry when user_id is None."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
mock_log.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_usd_invalid_string_falls_back_to_tokens(self):
|
||||
"""Invalid cost_usd string should fall back to tokens tracking."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-invalid",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd="not-a-number",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
mock_log.assert_awaited_once()
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.cost_microdollars is None
|
||||
assert entry.metadata["tracking_type"] == "tokens"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_usd_string_number_is_parsed(self):
|
||||
"""String-encoded cost_usd (e.g. from OpenRouter) should be parsed."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-str",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cost_usd="0.01",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
mock_log.assert_awaited_once()
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.cost_microdollars == 10_000
|
||||
assert entry.metadata["tracking_type"] == "cost_usd"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_log_prefix_produces_copilot_block_name(self):
|
||||
"""Empty log_prefix results in block_name='copilot'."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-empty",
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
log_prefix="",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.block_name == "copilot"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_included_in_metadata(self):
|
||||
"""Cache token counts should be present in the metadata."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-cache",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=5000,
|
||||
cache_creation_tokens=300,
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.metadata["cache_read_tokens"] == 5000
|
||||
assert entry.metadata["cache_creation_tokens"] == 300
|
||||
assert entry.metadata["source"] == "copilot"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_cost_only_when_tokens_zero(self):
|
||||
"""Zero prompt+completion tokens with cost_usd set still logs the entry."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-cached",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
cost_usd=0.005,
|
||||
model="claude-3-5-sonnet",
|
||||
provider="anthropic",
|
||||
log_prefix="[SDK]",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# Guard: total_tokens == 0 but cost_usd is set — must still log
|
||||
mock_log.assert_awaited_once()
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.user_id == "user-cached"
|
||||
assert entry.tracking_type == "cost_usd"
|
||||
assert entry.cost_microdollars == 5000
|
||||
assert entry.input_tokens == 0
|
||||
assert entry.output_tokens == 0
|
||||
|
||||
@@ -142,3 +142,9 @@ def credit_db():
|
||||
credit_db = get_database_manager_async_client()
|
||||
|
||||
return credit_db
|
||||
|
||||
|
||||
def platform_cost_db():
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
return get_database_manager_async_client()
|
||||
|
||||
@@ -96,6 +96,7 @@ from backend.data.notifications import (
|
||||
remove_notifications_from_batch,
|
||||
)
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.data.platform_cost import log_platform_cost
|
||||
from backend.data.understanding import (
|
||||
get_business_understanding,
|
||||
upsert_business_understanding,
|
||||
@@ -332,6 +333,9 @@ class DatabaseManager(AppService):
|
||||
get_blocks_needing_optimization = _(get_blocks_needing_optimization)
|
||||
update_block_optimized_description = _(update_block_optimized_description)
|
||||
|
||||
# ============ Platform Cost Tracking ============ #
|
||||
log_platform_cost = _(log_platform_cost)
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_session = _(chat_db.get_chat_session)
|
||||
create_chat_session = _(chat_db.create_chat_session)
|
||||
@@ -529,6 +533,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
# ============ Block Descriptions ============ #
|
||||
get_blocks_needing_optimization = d.get_blocks_needing_optimization
|
||||
|
||||
# ============ Platform Cost Tracking ============ #
|
||||
log_platform_cost = d.log_platform_cost
|
||||
|
||||
# ============ CoPilot Chat Sessions ============ #
|
||||
get_chat_session = d.get_chat_session
|
||||
create_chat_session = d.create_chat_session
|
||||
|
||||
@@ -104,6 +104,11 @@ class User(BaseModel):
|
||||
description="User timezone (IANA timezone identifier or 'not-set')",
|
||||
)
|
||||
|
||||
# Subscription / rate-limit tier
|
||||
subscription_tier: str | None = Field(
|
||||
default=None, description="Subscription tier (FREE, PRO, BUSINESS, ENTERPRISE)"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_user: "PrismaUser") -> "User":
|
||||
"""Convert a database User object to application User model."""
|
||||
@@ -158,6 +163,7 @@ class User(BaseModel):
|
||||
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
|
||||
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
|
||||
timezone=prisma_user.timezone or USER_TIMEZONE_NOT_SET,
|
||||
subscription_tier=prisma_user.subscriptionTier,
|
||||
)
|
||||
|
||||
|
||||
@@ -819,6 +825,17 @@ class RefundRequest(BaseModel):
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
ProviderCostType = Literal[
|
||||
"cost_usd", # Actual USD cost reported by the provider
|
||||
"tokens", # LLM token counts (sum of input + output)
|
||||
"characters", # Per-character billing (TTS providers)
|
||||
"sandbox_seconds", # Per-second compute billing (e.g. E2B)
|
||||
"walltime_seconds", # Per-second billing incl. queue/polling
|
||||
"per_run", # Per-API-call billing with fixed cost
|
||||
"items", # Per-item billing (lead/organization/result count)
|
||||
]
|
||||
|
||||
|
||||
class NodeExecutionStats(BaseModel):
|
||||
"""Execution statistics for a node execution."""
|
||||
|
||||
@@ -838,32 +855,39 @@ class NodeExecutionStats(BaseModel):
|
||||
output_token_count: int = 0
|
||||
extra_cost: int = 0
|
||||
extra_steps: int = 0
|
||||
provider_cost: float | None = None
|
||||
# Type of the provider-reported cost/usage captured above. When set
|
||||
# by a block, resolve_tracking honors this directly instead of
|
||||
# guessing from provider name.
|
||||
provider_cost_type: Optional[ProviderCostType] = None
|
||||
# Moderation fields
|
||||
cleared_inputs: Optional[dict[str, list[str]]] = None
|
||||
cleared_outputs: Optional[dict[str, list[str]]] = None
|
||||
|
||||
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
|
||||
"""Mutate this instance by adding another NodeExecutionStats."""
|
||||
"""Mutate this instance by adding another NodeExecutionStats.
|
||||
|
||||
Avoids calling model_dump() twice per merge (called on every
|
||||
merge_stats() from ~20+ blocks); reads via getattr/vars instead.
|
||||
"""
|
||||
if not isinstance(other, NodeExecutionStats):
|
||||
return NotImplemented
|
||||
|
||||
stats_dict = other.model_dump()
|
||||
current_stats = self.model_dump()
|
||||
|
||||
for key, value in stats_dict.items():
|
||||
if key not in current_stats:
|
||||
# Field doesn't exist yet, just set it
|
||||
for key in type(other).model_fields:
|
||||
value = getattr(other, key)
|
||||
if value is None:
|
||||
# Never overwrite an existing value with None
|
||||
continue
|
||||
current = getattr(self, key, None)
|
||||
if current is None:
|
||||
# Field doesn't exist yet or is None, just set it
|
||||
setattr(self, key, value)
|
||||
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
|
||||
current_stats[key].update(value)
|
||||
setattr(self, key, current_stats[key])
|
||||
elif isinstance(value, (int, float)) and isinstance(
|
||||
current_stats[key], (int, float)
|
||||
):
|
||||
setattr(self, key, current_stats[key] + value)
|
||||
elif isinstance(value, list) and isinstance(current_stats[key], list):
|
||||
current_stats[key].extend(value)
|
||||
setattr(self, key, current_stats[key])
|
||||
elif isinstance(value, dict) and isinstance(current, dict):
|
||||
current.update(value)
|
||||
elif isinstance(value, (int, float)) and isinstance(current, (int, float)):
|
||||
setattr(self, key, current + value)
|
||||
elif isinstance(value, list) and isinstance(current, list):
|
||||
current.extend(value)
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import HostScopedCredentials
|
||||
from backend.data.model import HostScopedCredentials, NodeExecutionStats
|
||||
|
||||
|
||||
class TestHostScopedCredentials:
|
||||
@@ -166,3 +166,84 @@ class TestHostScopedCredentials:
|
||||
)
|
||||
|
||||
assert creds.matches_url(test_url) == expected
|
||||
|
||||
|
||||
class TestNodeExecutionStatsIadd:
|
||||
def test_adds_numeric_fields(self):
|
||||
a = NodeExecutionStats(input_token_count=100, output_token_count=50)
|
||||
b = NodeExecutionStats(input_token_count=200, output_token_count=30)
|
||||
a += b
|
||||
assert a.input_token_count == 300
|
||||
assert a.output_token_count == 80
|
||||
|
||||
def test_none_does_not_overwrite(self):
|
||||
a = NodeExecutionStats(provider_cost=0.5, error="some error")
|
||||
b = NodeExecutionStats(provider_cost=None, error=None)
|
||||
a += b
|
||||
assert a.provider_cost == 0.5
|
||||
assert a.error == "some error"
|
||||
|
||||
def test_none_is_skipped_preserving_existing_value(self):
|
||||
a = NodeExecutionStats(input_token_count=100)
|
||||
b = NodeExecutionStats()
|
||||
a += b
|
||||
assert a.input_token_count == 100
|
||||
|
||||
def test_dict_fields_are_merged(self):
|
||||
a = NodeExecutionStats(
|
||||
cleared_inputs={"field1": ["val1"]},
|
||||
)
|
||||
b = NodeExecutionStats(
|
||||
cleared_inputs={"field2": ["val2"]},
|
||||
)
|
||||
a += b
|
||||
assert a.cleared_inputs == {"field1": ["val1"], "field2": ["val2"]}
|
||||
|
||||
def test_returns_self(self):
|
||||
a = NodeExecutionStats()
|
||||
b = NodeExecutionStats(input_token_count=10)
|
||||
result = a.__iadd__(b)
|
||||
assert result is a
|
||||
|
||||
def test_not_implemented_for_non_stats(self):
|
||||
a = NodeExecutionStats()
|
||||
result = a.__iadd__("not a stats") # type: ignore[arg-type]
|
||||
assert result is NotImplemented
|
||||
|
||||
def test_error_none_does_not_clear_existing_error(self):
|
||||
a = NodeExecutionStats(error="existing error")
|
||||
b = NodeExecutionStats(error=None)
|
||||
a += b
|
||||
assert a.error == "existing error"
|
||||
|
||||
def test_provider_cost_none_does_not_clear_existing_cost(self):
|
||||
a = NodeExecutionStats(provider_cost=0.05)
|
||||
b = NodeExecutionStats(provider_cost=None)
|
||||
a += b
|
||||
assert a.provider_cost == 0.05
|
||||
|
||||
def test_provider_cost_accumulates_when_both_set(self):
|
||||
a = NodeExecutionStats(provider_cost=0.01)
|
||||
b = NodeExecutionStats(provider_cost=0.02)
|
||||
a += b
|
||||
assert abs((a.provider_cost or 0) - 0.03) < 1e-9
|
||||
|
||||
def test_provider_cost_first_write_from_none(self):
|
||||
a = NodeExecutionStats()
|
||||
b = NodeExecutionStats(provider_cost=0.05)
|
||||
a += b
|
||||
assert a.provider_cost == 0.05
|
||||
|
||||
def test_provider_cost_type_first_write_from_none(self):
|
||||
"""Writing provider_cost_type into a stats with None sets it."""
|
||||
a = NodeExecutionStats()
|
||||
b = NodeExecutionStats(provider_cost_type="characters")
|
||||
a += b
|
||||
assert a.provider_cost_type == "characters"
|
||||
|
||||
def test_provider_cost_type_none_does_not_overwrite(self):
|
||||
"""A None provider_cost_type from other must not clear an existing value."""
|
||||
a = NodeExecutionStats(provider_cost_type="tokens")
|
||||
b = NodeExecutionStats()
|
||||
a += b
|
||||
assert a.provider_cost_type == "tokens"
|
||||
|
||||
390
autogpt_platform/backend/backend/data/platform_cost.py
Normal file
@@ -0,0 +1,390 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MICRODOLLARS_PER_USD = 1_000_000
|
||||
|
||||
# Dashboard query limits — keep in sync with the SQL queries below
|
||||
MAX_PROVIDER_ROWS = 500
|
||||
MAX_USER_ROWS = 100
|
||||
|
||||
# Default date range for dashboard queries when no start date is provided.
|
||||
# Prevents full-table scans on large deployments.
|
||||
DEFAULT_DASHBOARD_DAYS = 30
|
||||
|
||||
|
||||
def usd_to_microdollars(cost_usd: float | None) -> int | None:
|
||||
"""Convert a USD amount (float) to microdollars (int). None-safe."""
|
||||
if cost_usd is None:
|
||||
return None
|
||||
return round(cost_usd * MICRODOLLARS_PER_USD)
|
||||
|
||||
|
||||
class PlatformCostEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str | None = None
|
||||
node_exec_id: str | None = None
|
||||
graph_id: str | None = None
|
||||
node_id: str | None = None
|
||||
block_id: str
|
||||
block_name: str
|
||||
provider: str
|
||||
credential_id: str
|
||||
cost_microdollars: int | None = None
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
data_size: int | None = None
|
||||
duration: float | None = None
|
||||
model: str | None = None
|
||||
tracking_type: str | None = None
|
||||
tracking_amount: float | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
async def log_platform_cost(entry: PlatformCostEntry) -> None:
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"PlatformCostLog"
|
||||
("id", "createdAt", "userId", "graphExecId", "nodeExecId",
|
||||
"graphId", "nodeId", "blockId", "blockName", "provider",
|
||||
"credentialId", "costMicrodollars", "inputTokens", "outputTokens",
|
||||
"dataSize", "duration", "model", "trackingType", "trackingAmount",
|
||||
"metadata")
|
||||
VALUES (
|
||||
gen_random_uuid(), NOW(), $1, $2, $3, $4, $5, $6, $7, $8, $9,
|
||||
$10, $11, $12, $13, $14, $15, $16, $17, $18::jsonb
|
||||
)
|
||||
""",
|
||||
entry.user_id,
|
||||
entry.graph_exec_id,
|
||||
entry.node_exec_id,
|
||||
entry.graph_id,
|
||||
entry.node_id,
|
||||
entry.block_id,
|
||||
entry.block_name,
|
||||
# Normalize to lowercase so the (provider, createdAt) index is always
|
||||
# used without LOWER() on the read side.
|
||||
entry.provider.lower(),
|
||||
entry.credential_id,
|
||||
entry.cost_microdollars,
|
||||
entry.input_tokens,
|
||||
entry.output_tokens,
|
||||
entry.data_size,
|
||||
entry.duration,
|
||||
entry.model,
|
||||
entry.tracking_type,
|
||||
entry.tracking_amount,
|
||||
_json_or_none(entry.metadata),
|
||||
)
|
||||
|
||||
|
||||
# Bound the number of concurrent cost-log DB inserts to prevent unbounded
|
||||
# task/connection growth under sustained load or DB slowness.
|
||||
_log_semaphore = asyncio.Semaphore(50)
|
||||
|
||||
|
||||
async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
|
||||
"""Fire-and-forget wrapper that never raises."""
|
||||
try:
|
||||
async with _log_semaphore:
|
||||
await log_platform_cost(entry)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to log platform cost for user=%s provider=%s block=%s",
|
||||
entry.user_id,
|
||||
entry.provider,
|
||||
entry.block_name,
|
||||
)
|
||||
|
||||
|
||||
def _json_or_none(data: dict[str, Any] | None) -> str | None:
|
||||
if data is None:
|
||||
return None
|
||||
return json.dumps(data)
|
||||
|
||||
|
||||
def _mask_email(email: str | None) -> str | None:
|
||||
"""Mask an email address to reduce PII exposure in admin API responses.
|
||||
|
||||
Turns 'user@example.com' into 'us***@example.com'.
|
||||
Handles short local parts gracefully (e.g. 'a@b.com' → 'a***@b.com').
|
||||
"""
|
||||
if not email:
|
||||
return email
|
||||
at = email.find("@")
|
||||
if at < 0:
|
||||
return "***"
|
||||
local = email[:at]
|
||||
domain = email[at:]
|
||||
visible = local[:2] if len(local) >= 2 else local[:1]
|
||||
return f"{visible}***{domain}"
|
||||
|
||||
|
||||
class ProviderCostSummary(BaseModel):
|
||||
provider: str
|
||||
tracking_type: str | None = None
|
||||
total_cost_microdollars: int
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
total_duration_seconds: float = 0.0
|
||||
total_tracking_amount: float = 0.0
|
||||
request_count: int
|
||||
|
||||
|
||||
class UserCostSummary(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
total_cost_microdollars: int
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
request_count: int
|
||||
|
||||
|
||||
class CostLogRow(BaseModel):
|
||||
id: str
|
||||
created_at: datetime
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
graph_exec_id: str | None = None
|
||||
node_exec_id: str | None = None
|
||||
block_name: str
|
||||
provider: str
|
||||
tracking_type: str | None = None
|
||||
cost_microdollars: int | None = None
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
duration: float | None = None
|
||||
model: str | None = None
|
||||
|
||||
|
||||
class PlatformCostDashboard(BaseModel):
|
||||
by_provider: list[ProviderCostSummary]
|
||||
by_user: list[UserCostSummary]
|
||||
total_cost_microdollars: int
|
||||
total_requests: int
|
||||
total_users: int
|
||||
|
||||
|
||||
def _build_where(
|
||||
start: datetime | None,
|
||||
end: datetime | None,
|
||||
provider: str | None,
|
||||
user_id: str | None,
|
||||
table_alias: str = "",
|
||||
) -> tuple[str, list[Any]]:
|
||||
prefix = f"{table_alias}." if table_alias else ""
|
||||
clauses: list[str] = []
|
||||
params: list[Any] = []
|
||||
idx = 1
|
||||
|
||||
if start:
|
||||
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
|
||||
params.append(start)
|
||||
idx += 1
|
||||
if end:
|
||||
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
|
||||
params.append(end)
|
||||
idx += 1
|
||||
if provider:
|
||||
# Provider names are normalized to lowercase at write time so a plain
|
||||
# equality check is sufficient and the (provider, createdAt) index is used.
|
||||
clauses.append(f'{prefix}"provider" = ${idx}')
|
||||
params.append(provider.lower())
|
||||
idx += 1
|
||||
if user_id:
|
||||
clauses.append(f'{prefix}"userId" = ${idx}')
|
||||
params.append(user_id)
|
||||
idx += 1
|
||||
|
||||
return (" AND ".join(clauses) if clauses else "TRUE", params)
|
||||
|
||||
|
||||
async def get_platform_cost_dashboard(
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
provider: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> PlatformCostDashboard:
|
||||
"""Aggregate platform cost logs for the admin dashboard.
|
||||
|
||||
Note: by_provider rows are keyed on (provider, tracking_type). A single
|
||||
provider can therefore appear in multiple rows if it has entries with
|
||||
different billing models (e.g. "openai" with both "tokens" and "cost_usd"
|
||||
if pricing is later added for some entries). Frontend treats each row
|
||||
independently rather than as a provider primary key.
|
||||
|
||||
Defaults to the last DEFAULT_DASHBOARD_DAYS days when no start date is
|
||||
provided to avoid full-table scans on large deployments.
|
||||
"""
|
||||
if start is None:
|
||||
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
|
||||
where_p, params_p = _build_where(start, end, provider, user_id, "p")
|
||||
|
||||
by_provider_rows, by_user_rows, total_user_rows = await asyncio.gather(
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT
|
||||
p."provider",
|
||||
p."trackingType" AS tracking_type,
|
||||
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
|
||||
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(p."duration"), 0)::float AS total_duration,
|
||||
COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
|
||||
COUNT(*)::bigint AS request_count
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
WHERE {where_p}
|
||||
GROUP BY p."provider", p."trackingType"
|
||||
ORDER BY total_cost DESC
|
||||
LIMIT {MAX_PROVIDER_ROWS}
|
||||
""",
|
||||
*params_p,
|
||||
),
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT
|
||||
p."userId" AS user_id,
|
||||
u."email",
|
||||
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
|
||||
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
|
||||
COUNT(*)::bigint AS request_count
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
|
||||
WHERE {where_p}
|
||||
GROUP BY p."userId", u."email"
|
||||
ORDER BY total_cost DESC
|
||||
LIMIT {MAX_USER_ROWS}
|
||||
""",
|
||||
*params_p,
|
||||
),
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
WHERE {where_p}
|
||||
""",
|
||||
*params_p,
|
||||
),
|
||||
)
|
||||
|
||||
# Use the exact COUNT(DISTINCT userId) so total_users is not capped at
|
||||
# MAX_USER_ROWS (which would silently report 100 for >100 active users).
|
||||
total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
|
||||
total_cost = sum(r["total_cost"] for r in by_provider_rows)
|
||||
total_requests = sum(r["request_count"] for r in by_provider_rows)
|
||||
|
||||
return PlatformCostDashboard(
|
||||
by_provider=[
|
||||
ProviderCostSummary(
|
||||
provider=r["provider"],
|
||||
tracking_type=r.get("tracking_type"),
|
||||
total_cost_microdollars=r["total_cost"],
|
||||
total_input_tokens=r["total_input_tokens"],
|
||||
total_output_tokens=r["total_output_tokens"],
|
||||
total_duration_seconds=r.get("total_duration", 0.0),
|
||||
total_tracking_amount=r.get("total_tracking_amount", 0.0),
|
||||
request_count=r["request_count"],
|
||||
)
|
||||
for r in by_provider_rows
|
||||
],
|
||||
by_user=[
|
||||
UserCostSummary(
|
||||
user_id=r.get("user_id"),
|
||||
email=_mask_email(r.get("email")),
|
||||
total_cost_microdollars=r["total_cost"],
|
||||
total_input_tokens=r["total_input_tokens"],
|
||||
total_output_tokens=r["total_output_tokens"],
|
||||
request_count=r["request_count"],
|
||||
)
|
||||
for r in by_user_rows
|
||||
],
|
||||
total_cost_microdollars=total_cost,
|
||||
total_requests=total_requests,
|
||||
total_users=total_users,
|
||||
)
|
||||
|
||||
|
||||
async def get_platform_cost_logs(
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
provider: str | None = None,
|
||||
user_id: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> tuple[list[CostLogRow], int]:
|
||||
if start is None:
|
||||
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
|
||||
where_sql, params = _build_where(start, end, provider, user_id, "p")
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
limit_idx = len(params) + 1
|
||||
offset_idx = len(params) + 2
|
||||
|
||||
count_rows, rows = await asyncio.gather(
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*)::bigint AS cnt
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
WHERE {where_sql}
|
||||
""",
|
||||
*params,
|
||||
),
|
||||
query_raw_with_schema(
|
||||
f"""
|
||||
SELECT
|
||||
p."id",
|
||||
p."createdAt" AS created_at,
|
||||
p."userId" AS user_id,
|
||||
u."email",
|
||||
p."graphExecId" AS graph_exec_id,
|
||||
p."nodeExecId" AS node_exec_id,
|
||||
p."blockName" AS block_name,
|
||||
p."provider",
|
||||
p."trackingType" AS tracking_type,
|
||||
p."costMicrodollars" AS cost_microdollars,
|
||||
p."inputTokens" AS input_tokens,
|
||||
p."outputTokens" AS output_tokens,
|
||||
p."duration",
|
||||
p."model"
|
||||
FROM {{schema_prefix}}"PlatformCostLog" p
|
||||
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
|
||||
WHERE {where_sql}
|
||||
ORDER BY p."createdAt" DESC, p."id" DESC
|
||||
LIMIT ${limit_idx} OFFSET ${offset_idx}
|
||||
""",
|
||||
*params,
|
||||
page_size,
|
||||
offset,
|
||||
),
|
||||
)
|
||||
total = count_rows[0]["cnt"] if count_rows else 0
|
||||
|
||||
logs = [
|
||||
CostLogRow(
|
||||
id=r["id"],
|
||||
created_at=r["created_at"],
|
||||
user_id=r.get("user_id"),
|
||||
email=_mask_email(r.get("email")),
|
||||
graph_exec_id=r.get("graph_exec_id"),
|
||||
node_exec_id=r.get("node_exec_id"),
|
||||
block_name=r["block_name"],
|
||||
provider=r["provider"],
|
||||
tracking_type=r.get("tracking_type"),
|
||||
cost_microdollars=r.get("cost_microdollars"),
|
||||
input_tokens=r.get("input_tokens"),
|
||||
output_tokens=r.get("output_tokens"),
|
||||
duration=r.get("duration"),
|
||||
model=r.get("model"),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
return logs, total
|
||||
266
autogpt_platform/backend/backend/data/platform_cost_test.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Unit tests for helpers and async functions in platform_cost module."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .platform_cost import (
|
||||
PlatformCostEntry,
|
||||
_build_where,
|
||||
_json_or_none,
|
||||
get_platform_cost_dashboard,
|
||||
get_platform_cost_logs,
|
||||
log_platform_cost,
|
||||
log_platform_cost_safe,
|
||||
)
|
||||
|
||||
|
||||
class TestJsonOrNone:
|
||||
def test_returns_none_for_none(self):
|
||||
assert _json_or_none(None) is None
|
||||
|
||||
def test_returns_json_string_for_dict(self):
|
||||
result = _json_or_none({"key": "value", "num": 42})
|
||||
assert result is not None
|
||||
assert '"key"' in result
|
||||
assert '"value"' in result
|
||||
|
||||
def test_returns_json_for_empty_dict(self):
|
||||
assert _json_or_none({}) == "{}"
|
||||
|
||||
|
||||
class TestBuildWhere:
    """_build_where builds a positional ($1, $2, ...) WHERE clause from the
    optional (start, end, provider, user_id) filters, returning the SQL
    fragment and the matching parameter list."""

    def test_no_filters_returns_true(self):
        # With nothing to filter on, the clause degenerates to a no-op TRUE.
        sql, params = _build_where(None, None, None, None)
        assert sql == "TRUE"
        assert params == []

    def test_start_only(self):
        dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
        sql, params = _build_where(dt, None, None, None)
        assert '"createdAt" >= $1::timestamptz' in sql
        assert params == [dt]

    def test_end_only(self):
        dt = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, params = _build_where(None, dt, None, None)
        assert '"createdAt" <= $1::timestamptz' in sql
        assert params == [dt]

    def test_provider_only(self):
        # Provider names are normalized to lowercase at write time, so the
        # filter uses a plain equality check. The input is also lowercased so
        # "OpenAI" and "openai" both match stored rows.
        sql, params = _build_where(None, None, "OpenAI", None)
        assert '"provider" = $1' in sql
        assert params == ["openai"]

    def test_user_id_only(self):
        sql, params = _build_where(None, None, None, "user-123")
        assert '"userId" = $1' in sql
        assert params == ["user-123"]

    def test_all_filters(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, params = _build_where(start, end, "Anthropic", "u1")
        # Placeholders must be numbered consecutively in filter order.
        assert "$1" in sql
        assert "$2" in sql
        assert "$3" in sql
        assert "$4" in sql
        assert len(params) == 4
        # Provider is lowercased at filter time to match stored lowercase values.
        assert params == [start, end, "anthropic", "u1"]

    def test_table_alias(self):
        # With table_alias="p" the column references become qualified.
        dt = datetime(2026, 1, 1, tzinfo=timezone.utc)
        sql, params = _build_where(dt, None, None, None, table_alias="p")
        assert 'p."createdAt"' in sql
        assert params == [dt]

    def test_clauses_joined_with_and(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        end = datetime(2026, 6, 1, tzinfo=timezone.utc)
        sql, _ = _build_where(start, end, None, None)
        assert " AND " in sql
|
||||
|
||||
|
||||
def _make_entry(**overrides: object) -> PlatformCostEntry:
    """Build a minimal valid PlatformCostEntry; keyword overrides win over defaults."""
    payload: dict[str, object] = {
        "user_id": "user-1",
        "block_id": "block-1",
        "block_name": "TestBlock",
        "provider": "openai",
        "credential_id": "cred-1",
    }
    payload.update(overrides)
    return PlatformCostEntry.model_validate(payload)
|
||||
|
||||
|
||||
class TestLogPlatformCost:
    """log_platform_cost forwards entry fields positionally to the raw INSERT."""

    @pytest.mark.asyncio
    async def test_calls_execute_raw_with_schema(self):
        mock_exec = AsyncMock()
        with patch("backend.data.platform_cost.execute_raw_with_schema", new=mock_exec):
            entry = _make_entry(
                input_tokens=100,
                output_tokens=50,
                cost_microdollars=5000,
                model="gpt-4",
                metadata={"key": "val"},
            )
            await log_platform_cost(entry)
        mock_exec.assert_awaited_once()
        # call_args[0] is the positional-args tuple; index 0 is the SQL text,
        # so bound parameters start at index 1.
        args = mock_exec.call_args
        assert args[0][1] == "user-1"  # user_id is first param
        assert args[0][6] == "block-1"  # block_id
        assert args[0][7] == "TestBlock"  # block_name

    @pytest.mark.asyncio
    async def test_metadata_none_passes_none(self):
        # metadata=None must be passed through as SQL NULL, not "null"/"{}" text.
        mock_exec = AsyncMock()
        with patch("backend.data.platform_cost.execute_raw_with_schema", new=mock_exec):
            entry = _make_entry(metadata=None)
            await log_platform_cost(entry)
        args = mock_exec.call_args
        assert args[0][-1] is None  # last arg is metadata json
|
||||
|
||||
|
||||
class TestLogPlatformCostSafe:
    """log_platform_cost_safe swallows DB errors but still writes on success."""

    @pytest.mark.asyncio
    async def test_does_not_raise_on_error(self):
        failing_exec = AsyncMock(side_effect=RuntimeError("DB down"))
        entry = _make_entry()
        with patch(
            "backend.data.platform_cost.execute_raw_with_schema",
            new=failing_exec,
        ):
            # Must not propagate the RuntimeError.
            await log_platform_cost_safe(entry)

    @pytest.mark.asyncio
    async def test_succeeds_when_no_error(self):
        ok_exec = AsyncMock()
        entry = _make_entry()
        with patch("backend.data.platform_cost.execute_raw_with_schema", new=ok_exec):
            await log_platform_cost_safe(entry)
        ok_exec.assert_awaited_once()
|
||||
|
||||
|
||||
class TestGetPlatformCostDashboard:
    """Dashboard aggregation over mocked query_raw_with_schema result sets.

    The dashboard issues three queries in order (by_provider, by_user,
    COUNT(DISTINCT userId)), so each mock uses side_effect with three rows
    lists in that order."""

    @pytest.mark.asyncio
    async def test_returns_dashboard_with_data(self):
        provider_rows = [
            {
                "provider": "openai",
                "tracking_type": "tokens",
                "total_cost": 5000,
                "total_input_tokens": 1000,
                "total_output_tokens": 500,
                "total_duration": 10.5,
                "request_count": 3,
            }
        ]
        user_rows = [
            {
                "user_id": "u1",
                "email": "a@b.com",
                "total_cost": 5000,
                "total_input_tokens": 1000,
                "total_output_tokens": 500,
                "request_count": 3,
            }
        ]
        # Dashboard runs 3 queries: by_provider, by_user, COUNT(DISTINCT userId).
        mock_query = AsyncMock(side_effect=[provider_rows, user_rows, [{"cnt": 1}]])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            dashboard = await get_platform_cost_dashboard()
        assert dashboard.total_cost_microdollars == 5000
        assert dashboard.total_requests == 3
        assert dashboard.total_users == 1
        assert len(dashboard.by_provider) == 1
        assert dashboard.by_provider[0].provider == "openai"
        assert dashboard.by_provider[0].tracking_type == "tokens"
        assert dashboard.by_provider[0].total_duration_seconds == 10.5
        assert len(dashboard.by_user) == 1
        # Per-user emails come back masked for the admin view.
        assert dashboard.by_user[0].email == "a***@b.com"

    @pytest.mark.asyncio
    async def test_returns_empty_dashboard(self):
        # All three queries returning no rows yields zeroed totals, not errors.
        mock_query = AsyncMock(side_effect=[[], [], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            dashboard = await get_platform_cost_dashboard()
        assert dashboard.total_cost_microdollars == 0
        assert dashboard.total_requests == 0
        assert dashboard.total_users == 0
        assert dashboard.by_provider == []
        assert dashboard.by_user == []

    @pytest.mark.asyncio
    async def test_passes_filters_to_queries(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        mock_query = AsyncMock(side_effect=[[], [], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            await get_platform_cost_dashboard(
                start=start, provider="openai", user_id="u1"
            )
        assert mock_query.await_count == 3
        # First positional arg of the first query is the SQL text; the start
        # filter must have produced a createdAt clause in it.
        first_call_sql = mock_query.call_args_list[0][0][0]
        assert "createdAt" in first_call_sql
|
||||
|
||||
|
||||
class TestGetPlatformCostLogs:
    """Paginated log listing: a COUNT query first, then the page query."""

    @pytest.mark.asyncio
    async def test_returns_logs_and_total(self):
        count_rows = [{"cnt": 1}]
        log_rows = [
            {
                "id": "log-1",
                "created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
                "user_id": "u1",
                "email": "a@b.com",
                "graph_exec_id": "g1",
                "node_exec_id": "n1",
                "block_name": "TestBlock",
                "provider": "openai",
                "tracking_type": "tokens",
                "cost_microdollars": 5000,
                "input_tokens": 100,
                "output_tokens": 50,
                "duration": 1.5,
                "model": "gpt-4",
            }
        ]
        mock_query = AsyncMock(side_effect=[count_rows, log_rows])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs(page=1, page_size=10)
        assert total == 1
        assert len(logs) == 1
        assert logs[0].id == "log-1"
        assert logs[0].provider == "openai"
        assert logs[0].model == "gpt-4"

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_data(self):
        mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs()
        assert total == 0
        assert logs == []

    @pytest.mark.asyncio
    async def test_pagination_offset(self):
        # The second query (index 1 in call_args_list) carries LIMIT/OFFSET
        # values as trailing positional params.
        mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs(page=3, page_size=25)
        assert total == 100
        second_call_args = mock_query.call_args_list[1][0]
        assert 25 in second_call_args  # page_size
        assert 50 in second_call_args  # offset = (3-1) * 25

    @pytest.mark.asyncio
    async def test_empty_count_returns_zero(self):
        # A COUNT query returning no rows at all must be treated as zero.
        mock_query = AsyncMock(side_effect=[[], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs()
        assert total == 0
|
||||
291
autogpt_platform/backend/backend/executor/cost_tracking.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Helpers for platform cost tracking on system-credential block executions."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from backend.blocks._base import Block, BlockSchema
|
||||
from backend.copilot.token_tracking import _pending_log_tasks as _copilot_tasks
|
||||
from backend.copilot.token_tracking import (
|
||||
_pending_log_tasks_lock as _copilot_tasks_lock,
|
||||
)
|
||||
from backend.data.execution import NodeExecutionEntry
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars
|
||||
from backend.executor.utils import block_usage_cost
|
||||
from backend.integrations.credentials_store import is_system_credential
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.db_manager import DatabaseManagerAsyncClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider groupings by billing model — used when the block didn't explicitly
# declare stats.provider_cost_type and we fall back to provider-name
# heuristics in resolve_tracking(). Values match ProviderName enum values.
# Billed per character of the input script/text (see resolve_tracking's
# script_input/text/script lookup).
_CHARACTER_BILLED_PROVIDERS = frozenset(
    {ProviderName.D_ID.value, ProviderName.ELEVENLABS.value}
)
# Billed by elapsed time; walltime includes queue + generation + polling.
_WALLTIME_BILLED_PROVIDERS = frozenset(
    {
        ProviderName.FAL.value,
        ProviderName.REVID.value,
        ProviderName.REPLICATE.value,
    }
)
|
||||
|
||||
# Strong references to in-flight log tasks: without these the event loop could
# garbage-collect a fire-and-forget task before it finishes. Completed tasks
# remove themselves via a done callback. The lock guards every read and write,
# since worker-thread done callbacks discard() concurrently while
# drain_pending_cost_logs() takes its snapshot.
_pending_log_tasks: set[asyncio.Task] = set()
_pending_log_tasks_lock = threading.Lock()
# One semaphore per event loop: asyncio.Semaphore is not thread-safe, and each
# executor worker thread runs its own loop, so a single shared instance is
# not an option.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}


def _get_log_semaphore() -> asyncio.Semaphore:
    """Return (creating on first use) the current loop's 50-permit log semaphore."""
    loop = asyncio.get_running_loop()
    if loop not in _log_semaphores:
        _log_semaphores[loop] = asyncio.Semaphore(50)
    return _log_semaphores[loop]
|
||||
|
||||
|
||||
async def _drain_task_group(
    tasks: list[asyncio.Task], label: str, timeout: float
) -> None:
    """Await one group of cost-log tasks, warning about any that time out.

    All tasks must belong to the currently running event loop (a requirement
    of asyncio.wait()); callers filter by loop before calling.
    """
    if not tasks:
        return
    logger.info("Draining %d %s cost log task(s)", len(tasks), label)
    _, still_pending = await asyncio.wait(tasks, timeout=timeout)
    if still_pending:
        logger.warning(
            "%d %s cost log task(s) did not complete within %.1fs",
            len(still_pending),
            label,
            timeout,
        )


async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
    """Await all in-flight cost log tasks with a timeout.

    Drains both the executor cost log tasks (_pending_log_tasks in this module,
    used for block execution cost tracking via DatabaseManagerAsyncClient) and
    the copilot cost log tasks (token_tracking._pending_log_tasks, used for
    copilot LLM turns via platform_cost_db()).

    Call this during graceful shutdown to flush pending INSERT tasks before
    the process exits. Tasks that don't complete within `timeout` seconds are
    abandoned and their failures are already logged by _safe_log.

    Args:
        timeout: Max seconds to wait for EACH group (executor, then copilot),
            so the worst case is roughly twice this value.
    """
    # asyncio.wait() requires all tasks to belong to the running event loop.
    # Both task sets are shared across executor worker threads (each with its
    # own loop), so filter to only tasks owned by the current loop. Each lock
    # is held only while snapshotting, because worker threads call discard()
    # via done callbacks concurrently with this iteration.
    current_loop = asyncio.get_running_loop()

    with _pending_log_tasks_lock:
        executor_tasks = [
            t for t in _pending_log_tasks if t.get_loop() is current_loop
        ]
    await _drain_task_group(executor_tasks, "executor", timeout)

    # Also drain copilot cost log tasks (token_tracking._pending_log_tasks)
    with _copilot_tasks_lock:
        copilot_tasks = [t for t in _copilot_tasks if t.get_loop() is current_loop]
    await _drain_task_group(copilot_tasks, "copilot", timeout)
|
||||
|
||||
|
||||
def _schedule_log(
    db_client: "DatabaseManagerAsyncClient", entry: PlatformCostEntry
) -> None:
    """Fire-and-forget the platform-cost INSERT on the current event loop.

    The write is throttled by the per-loop semaphore (50 concurrent) and every
    exception is swallowed after logging — cost logging is best-effort and must
    never disrupt block execution. The created task is retained in
    _pending_log_tasks so it is neither garbage-collected mid-flight nor lost
    at shutdown (drain_pending_cost_logs awaits the set).

    Must be called with a running event loop (asyncio.create_task requirement).
    """

    async def _safe_log() -> None:
        async with _get_log_semaphore():
            try:
                await db_client.log_platform_cost(entry)
            except Exception:
                logger.exception(
                    "Failed to log platform cost for user=%s provider=%s block=%s",
                    entry.user_id,
                    entry.provider,
                    entry.block_name,
                )

    task = asyncio.create_task(_safe_log())
    # Add to the set BEFORE attaching the removal callback: if the task were
    # somehow already done, the callback fires immediately and the discard
    # must find the entry present.
    with _pending_log_tasks_lock:
        _pending_log_tasks.add(task)

    def _remove(t: asyncio.Task) -> None:
        # Runs as a done callback on the task's own loop; the lock protects
        # against concurrent snapshots in drain_pending_cost_logs().
        with _pending_log_tasks_lock:
            _pending_log_tasks.discard(t)

    task.add_done_callback(_remove)
|
||||
|
||||
|
||||
def _extract_model_name(raw: str | dict | None) -> str | None:
|
||||
"""Return a string model name from a block input field, or None.
|
||||
|
||||
Handles str (returned as-is), dict (e.g. an enum wrapper, skipped), and
|
||||
None (no model field). Unexpected types are coerced to str as a fallback.
|
||||
"""
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
return None
|
||||
return str(raw)
|
||||
|
||||
|
||||
def resolve_tracking(
    provider: str,
    stats: NodeExecutionStats,
    input_data: dict[str, Any],
) -> tuple[str, float]:
    """Decide how a system-credential block execution is billed.

    Returns a (tracking_type, tracking_amount) pair, resolved in order:

    1. The block's own declaration: `stats.provider_cost_type` paired with
       `stats.provider_cost` wins outright when both are present.
    2. A bare `stats.provider_cost` is an actual USD figure (OpenRouter, Exa).
    3. Any token counts mean LLM-style billing by total tokens.
    4. Provider-name heuristics for per-character and per-second billing.
    5. Everything else counts one unit per API call ("per_run").
    """

    def _char_count(value: Any) -> float:
        # Missing or non-string fields count as zero characters.
        return float(len(value)) if isinstance(value, str) else 0.0

    # 1. Explicit block declaration (only honored when an amount is present).
    if stats.provider_cost_type and stats.provider_cost is not None:
        return stats.provider_cost_type, stats.provider_cost

    # 2. Provider reported a real USD cost.
    if stats.provider_cost is not None:
        return "cost_usd", stats.provider_cost

    # 3. Token-based billing for LLM providers.
    if stats.input_token_count or stats.output_token_count:
        total_tokens = (stats.input_token_count or 0) + (
            stats.output_token_count or 0
        )
        return "tokens", float(total_tokens)

    # 4a. TTS: billed per character of the input text.
    if provider == ProviderName.UNREAL_SPEECH.value:
        return "characters", _char_count(input_data.get("text", ""))

    # 4b. D-ID / ElevenLabs: billed per character of the script. Blocks name
    # the field differently (`script` is VideoNarrationBlock's).
    if provider in _CHARACTER_BILLED_PROVIDERS:
        script = (
            input_data.get("script_input", "")
            or input_data.get("text", "")
            or input_data.get("script", "")
        )
        return "characters", _char_count(script)

    # 4c. Time-billed providers; walltime spans queue + generation + polling.
    elapsed = round(stats.walltime, 3) if stats.walltime else 0.0
    if provider == ProviderName.E2B.value:
        return "sandbox_seconds", elapsed
    if provider in _WALLTIME_BILLED_PROVIDERS:
        return "walltime_seconds", elapsed

    # 5. Default: one unit per block execution (Google Maps, Ideogram, ...).
    return "per_run", 1.0
|
||||
|
||||
|
||||
async def log_system_credential_cost(
    node_exec: NodeExecutionEntry,
    block: Block,
    stats: NodeExecutionStats,
    db_client: "DatabaseManagerAsyncClient",
) -> None:
    """Check if a system credential was used and log the platform cost.

    Routes through DatabaseManagerAsyncClient so the write goes via the
    message-passing DB service rather than calling Prisma directly (which
    is not connected in the executor process).

    Logs only the first matching system credential field (one log per
    execution). Any unexpected error is caught and logged — cost logging
    is strictly best-effort and must never disrupt block execution.

    Args:
        node_exec: The node execution whose inputs are scanned for a
            system-owned credential.
        block: The executed block; its input schema lists credential fields.
        stats: Execution stats (tokens, walltime, provider cost) feeding
            resolve_tracking().
        db_client: Async DB-manager client used for the fire-and-forget write.

    Note: costMicrodollars is left null for providers that don't return
    a USD cost. The credit_cost in metadata captures our internal credit
    charge as a proxy.
    """
    try:
        # Dry runs never hit external providers, so there is nothing to bill.
        if node_exec.execution_context.dry_run:
            return

        input_data = node_exec.inputs
        input_model = cast(type[BlockSchema], block.input_schema)

        for field_name in input_model.get_credentials_fields():
            cred_data = input_data.get(field_name)
            if not cred_data or not isinstance(cred_data, dict):
                continue
            cred_id = cred_data.get("id", "")
            # Only platform-owned (system) credentials incur platform cost.
            if not cred_id or not is_system_credential(cred_id):
                continue

            model_name = _extract_model_name(input_data.get("model"))

            # Internal credit charge for this block run, kept as metadata.
            credit_cost, _ = block_usage_cost(block=block, input_data=input_data)

            provider_name = cred_data.get("provider", "unknown")
            tracking_type, tracking_amount = resolve_tracking(
                provider=provider_name,
                stats=stats,
                input_data=input_data,
            )

            # Only treat provider_cost as USD when the tracking type says so.
            # For other types (items, characters, per_run, ...) the
            # provider_cost field holds the raw amount, not a dollar value.
            # Use tracking_amount (the normalized value from resolve_tracking)
            # rather than raw stats.provider_cost to avoid unit mismatches.
            cost_microdollars = None
            if tracking_type == "cost_usd":
                cost_microdollars = usd_to_microdollars(tracking_amount)

            meta: dict[str, Any] = {
                "tracking_type": tracking_type,
                "tracking_amount": tracking_amount,
            }
            if credit_cost is not None:
                meta["credit_cost"] = credit_cost
            if stats.provider_cost is not None:
                # Use 'provider_cost_raw' — the value's unit varies by tracking
                # type (USD for cost_usd, count for items/characters/per_run, etc.)
                meta["provider_cost_raw"] = stats.provider_cost

            # Fire-and-forget; _schedule_log retains the task and swallows
            # any DB failure after logging it.
            _schedule_log(
                db_client,
                PlatformCostEntry(
                    user_id=node_exec.user_id,
                    graph_exec_id=node_exec.graph_exec_id,
                    node_exec_id=node_exec.node_exec_id,
                    graph_id=node_exec.graph_id,
                    node_id=node_exec.node_id,
                    block_id=node_exec.block_id,
                    block_name=block.name,
                    provider=provider_name,
                    credential_id=cred_id,
                    cost_microdollars=cost_microdollars,
                    input_tokens=stats.input_token_count,
                    output_tokens=stats.output_token_count,
                    data_size=stats.output_size if stats.output_size > 0 else None,
                    duration=stats.walltime if stats.walltime > 0 else None,
                    model=model_name,
                    tracking_type=tracking_type,
                    tracking_amount=tracking_amount,
                    metadata=meta,
                ),
            )
            return  # One log per execution is enough
    except Exception:
        logger.exception("log_system_credential_cost failed unexpectedly")
|
||||
@@ -45,6 +45,10 @@ from backend.data.notifications import (
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.cost_tracking import (
|
||||
drain_pending_cost_logs,
|
||||
log_system_credential_cost,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util import json
|
||||
@@ -692,6 +696,15 @@ class ExecutionProcessor:
|
||||
stats=graph_stats,
|
||||
)
|
||||
|
||||
# Log platform cost if system credentials were used (only on success)
|
||||
if status == ExecutionStatus.COMPLETED:
|
||||
await log_system_credential_cost(
|
||||
node_exec=node_exec,
|
||||
block=node.block,
|
||||
stats=execution_stats,
|
||||
db_client=db_client,
|
||||
)
|
||||
|
||||
return execution_stats
|
||||
|
||||
@async_time_measured
|
||||
@@ -2044,14 +2057,23 @@ class ExecutionManager(AppProcess):
|
||||
prefix + " [cancel-consumer]",
|
||||
)
|
||||
|
||||
# Drain any in-flight cost log tasks before exit so we don't silently
|
||||
# drop INSERT operations during deployments.
|
||||
loop = getattr(self, "node_execution_loop", None)
|
||||
if loop is not None and loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
drain_pending_cost_logs(), loop
|
||||
).result(timeout=10)
|
||||
logger.info(f"{prefix} ✅ Cost log tasks drained")
|
||||
except Exception as e:
|
||||
logger.warning(f"{prefix} ⚠️ Failed to drain cost log tasks: {e}")
|
||||
|
||||
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
|
||||
super().cleanup()
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
return get_database_manager_client()
|
||||
|
||||
|
||||
@@ -0,0 +1,567 @@
|
||||
"""Unit tests for resolve_tracking and log_system_credential_cost."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext, NodeExecutionEntry
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.executor.cost_tracking import log_system_credential_cost, resolve_tracking
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveTracking:
    """Billing-unit heuristics and precedence ordering of resolve_tracking."""

    def _stats(self, **overrides: Any) -> NodeExecutionStats:
        # Shorthand: stats fields default unless explicitly overridden.
        return NodeExecutionStats(**overrides)

    def test_provider_cost_returns_cost_usd(self):
        stats = self._stats(provider_cost=0.0042)
        tt, amt = resolve_tracking("openai", stats, {})
        assert tt == "cost_usd"
        assert amt == 0.0042

    def test_token_counts_return_tokens(self):
        # Amount is the sum of input and output tokens.
        stats = self._stats(input_token_count=300, output_token_count=100)
        tt, amt = resolve_tracking("anthropic", stats, {})
        assert tt == "tokens"
        assert amt == 400.0

    def test_token_counts_only_input(self):
        stats = self._stats(input_token_count=500)
        tt, amt = resolve_tracking("groq", stats, {})
        assert tt == "tokens"
        assert amt == 500.0

    def test_unreal_speech_returns_characters(self):
        stats = self._stats()
        tt, amt = resolve_tracking("unreal_speech", stats, {"text": "Hello world"})
        assert tt == "characters"
        assert amt == 11.0

    def test_unreal_speech_empty_text(self):
        stats = self._stats()
        tt, amt = resolve_tracking("unreal_speech", stats, {"text": ""})
        assert tt == "characters"
        assert amt == 0.0

    def test_unreal_speech_non_string_text(self):
        # Non-string text fields are treated as zero characters, not an error.
        stats = self._stats()
        tt, amt = resolve_tracking("unreal_speech", stats, {"text": 123})
        assert tt == "characters"
        assert amt == 0.0

    def test_d_id_uses_script_input(self):
        stats = self._stats()
        tt, amt = resolve_tracking("d_id", stats, {"script_input": "Hello"})
        assert tt == "characters"
        assert amt == 5.0

    def test_elevenlabs_uses_text(self):
        stats = self._stats()
        tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Say this"})
        assert tt == "characters"
        assert amt == 8.0

    def test_elevenlabs_fallback_to_text_when_no_script_input(self):
        stats = self._stats()
        tt, amt = resolve_tracking("elevenlabs", stats, {"text": "Fallback text"})
        assert tt == "characters"
        assert amt == 13.0

    def test_elevenlabs_uses_script_field(self):
        """VideoNarrationBlock (elevenlabs) uses `script` field, not script_input/text."""
        stats = self._stats()
        tt, amt = resolve_tracking("elevenlabs", stats, {"script": "Narration"})
        assert tt == "characters"
        assert amt == 9.0

    def test_block_declared_cost_type_items(self):
        """Block explicitly setting provider_cost_type='items' short-circuits heuristics."""
        stats = self._stats(provider_cost=5.0, provider_cost_type="items")
        tt, amt = resolve_tracking("google_maps", stats, {})
        assert tt == "items"
        assert amt == 5.0

    def test_block_declared_cost_type_characters(self):
        """TTS block can declare characters directly, bypassing input_data lookup."""
        stats = self._stats(provider_cost=42.0, provider_cost_type="characters")
        tt, amt = resolve_tracking("unreal_speech", stats, {})
        assert tt == "characters"
        assert amt == 42.0

    def test_block_declared_cost_type_wins_over_tokens(self):
        """provider_cost_type takes precedence over token-based heuristic."""
        stats = self._stats(
            provider_cost=1.0,
            provider_cost_type="per_run",
            input_token_count=500,
        )
        tt, amt = resolve_tracking("openai", stats, {})
        assert tt == "per_run"
        assert amt == 1.0

    def test_e2b_returns_sandbox_seconds(self):
        stats = self._stats(walltime=45.123)
        tt, amt = resolve_tracking("e2b", stats, {})
        assert tt == "sandbox_seconds"
        assert amt == 45.123

    def test_e2b_no_walltime(self):
        stats = self._stats(walltime=0)
        tt, amt = resolve_tracking("e2b", stats, {})
        assert tt == "sandbox_seconds"
        assert amt == 0.0

    def test_fal_returns_walltime(self):
        stats = self._stats(walltime=12.5)
        tt, amt = resolve_tracking("fal", stats, {})
        assert tt == "walltime_seconds"
        assert amt == 12.5

    def test_revid_returns_walltime(self):
        stats = self._stats(walltime=60.0)
        tt, amt = resolve_tracking("revid", stats, {})
        assert tt == "walltime_seconds"
        assert amt == 60.0

    def test_replicate_returns_walltime(self):
        stats = self._stats(walltime=30.0)
        tt, amt = resolve_tracking("replicate", stats, {})
        assert tt == "walltime_seconds"
        assert amt == 30.0

    def test_unknown_provider_returns_per_run(self):
        # No declared type, no cost, no tokens, no special provider: one unit.
        stats = self._stats()
        tt, amt = resolve_tracking("google_maps", stats, {})
        assert tt == "per_run"
        assert amt == 1.0

    def test_provider_cost_takes_precedence_over_tokens(self):
        stats = self._stats(
            provider_cost=0.01, input_token_count=500, output_token_count=200
        )
        tt, amt = resolve_tracking("openai", stats, {})
        assert tt == "cost_usd"
        assert amt == 0.01

    def test_provider_cost_zero_is_not_none(self):
        """provider_cost=0.0 is falsy but should still be tracked as cost_usd
        (e.g. free-tier or fully-cached responses from OpenRouter)."""
        stats = self._stats(provider_cost=0.0)
        tt, amt = resolve_tracking("open_router", stats, {})
        assert tt == "cost_usd"
        assert amt == 0.0

    def test_tokens_take_precedence_over_provider_specific(self):
        stats = self._stats(input_token_count=100, walltime=10.0)
        tt, amt = resolve_tracking("fal", stats, {})
        assert tt == "tokens"
        assert amt == 100.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# log_system_credential_cost
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_db_client() -> MagicMock:
    """Return a DB-client stand-in whose log_platform_cost is awaitable."""
    client = MagicMock()
    client.log_platform_cost = AsyncMock()
    return client
|
||||
|
||||
|
||||
def _make_block(has_credentials: bool = True) -> MagicMock:
    """Fake Block whose input schema optionally declares one credentials field."""
    schema = MagicMock()
    fields = {"credentials": MagicMock()} if has_credentials else {}
    schema.get_credentials_fields.return_value = fields

    block = MagicMock()
    block.name = "TestBlock"
    block.input_schema = schema
    return block
|
||||
|
||||
|
||||
def _make_node_exec(
    inputs: dict | None = None,
    dry_run: bool = False,
) -> NodeExecutionEntry:
    """Build a NodeExecutionEntry with fixed ids and the given inputs / dry_run flag."""
    entry_inputs = {} if inputs is None else inputs
    return NodeExecutionEntry(
        user_id="user-1",
        graph_exec_id="gx-1",
        graph_id="g-1",
        graph_version=1,
        node_exec_id="nx-1",
        node_id="n-1",
        block_id="b-1",
        inputs=entry_inputs,
        execution_context=ExecutionContext(dry_run=dry_run),
    )
|
||||
|
||||
|
||||
class TestLogSystemCredentialCost:
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_dry_run(self):
|
||||
db_client = _make_db_client()
|
||||
node_exec = _make_node_exec(dry_run=True)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats()
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
db_client.log_platform_cost.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_no_credential_fields(self):
|
||||
db_client = _make_db_client()
|
||||
node_exec = _make_node_exec(inputs={})
|
||||
block = _make_block(has_credentials=False)
|
||||
stats = NodeExecutionStats()
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
db_client.log_platform_cost.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_cred_data_missing(self):
|
||||
db_client = _make_db_client()
|
||||
node_exec = _make_node_exec(inputs={})
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats()
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
db_client.log_platform_cost.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_not_system_credential(self):
|
||||
db_client = _make_db_client()
|
||||
with patch(
|
||||
"backend.executor.cost_tracking.is_system_credential",
|
||||
return_value=False,
|
||||
):
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "user-cred-123", "provider": "openai"},
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats()
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
db_client.log_platform_cost.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_with_system_credential(self):
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
return_value=(10, None),
|
||||
),
|
||||
):
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred-1", "provider": "openai"},
|
||||
"model": "gpt-4",
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats(input_token_count=500, output_token_count=200)
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
db_client.log_platform_cost.assert_awaited_once()
|
||||
entry = db_client.log_platform_cost.call_args[0][0]
|
||||
assert entry.user_id == "user-1"
|
||||
assert entry.provider == "openai"
|
||||
assert entry.block_name == "TestBlock"
|
||||
assert entry.model == "gpt-4"
|
||||
assert entry.input_tokens == 500
|
||||
assert entry.output_tokens == 200
|
||||
assert entry.tracking_type == "tokens"
|
||||
assert entry.metadata["tracking_type"] == "tokens"
|
||||
assert entry.metadata["tracking_amount"] == 700.0
|
||||
assert entry.metadata["credit_cost"] == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_with_provider_cost(self):
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
return_value=(5, None),
|
||||
),
|
||||
):
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred-2", "provider": "open_router"},
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats(provider_cost=0.0015)
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
entry = db_client.log_platform_cost.call_args[0][0]
|
||||
assert entry.cost_microdollars == 1500
|
||||
assert entry.tracking_type == "cost_usd"
|
||||
assert entry.metadata["tracking_type"] == "cost_usd"
|
||||
assert entry.metadata["provider_cost_raw"] == 0.0015
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_name_enum_converted_to_str(self):
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
return_value=(0, None),
|
||||
),
|
||||
):
|
||||
from enum import Enum
|
||||
|
||||
class FakeModel(Enum):
|
||||
GPT4 = "gpt-4"
|
||||
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred", "provider": "openai"},
|
||||
"model": FakeModel.GPT4,
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats()
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
entry = db_client.log_platform_cost.call_args[0][0]
|
||||
assert entry.model == "FakeModel.GPT4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_name_dict_becomes_none(self):
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
return_value=(0, None),
|
||||
),
|
||||
):
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred", "provider": "openai"},
|
||||
"model": {"nested": "value"},
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats()
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
entry = db_client.log_platform_cost.call_args[0][0]
|
||||
assert entry.model is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_raise_when_block_usage_cost_raises(self):
|
||||
"""log_system_credential_cost must swallow exceptions from block_usage_cost."""
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
side_effect=RuntimeError("pricing lookup failed"),
|
||||
),
|
||||
):
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred", "provider": "openai"},
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats()
|
||||
# Should not raise — outer except must catch block_usage_cost error
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_instead_of_int_for_microdollars(self):
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
return_value=(0, None),
|
||||
),
|
||||
):
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred", "provider": "openai"},
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
# 0.0015 * 1_000_000 = 1499.9999999... with float math
|
||||
# round() should give 1500, int() would give 1499
|
||||
stats = NodeExecutionStats(provider_cost=0.0015)
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
entry = db_client.log_platform_cost.call_args[0][0]
|
||||
assert entry.cost_microdollars == 1500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_run_metadata_has_no_provider_cost_raw(self):
|
||||
"""For per-run providers (google_maps etc), provider_cost_raw is absent
|
||||
from metadata since stats.provider_cost is None."""
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
return_value=(0, None),
|
||||
),
|
||||
):
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred", "provider": "google_maps"},
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats() # no provider_cost
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
entry = db_client.log_platform_cost.call_args[0][0]
|
||||
assert entry.tracking_type == "per_run"
|
||||
assert "provider_cost_raw" not in (entry.metadata or {})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# merge_stats accumulation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeStats:
|
||||
"""Tests for NodeExecutionStats accumulation via += (used by Block.merge_stats)."""
|
||||
|
||||
def test_accumulates_output_size(self):
|
||||
stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(output_size=10)
|
||||
stats += NodeExecutionStats(output_size=25)
|
||||
assert stats.output_size == 35
|
||||
|
||||
def test_accumulates_tokens(self):
|
||||
stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(input_token_count=100, output_token_count=50)
|
||||
stats += NodeExecutionStats(input_token_count=200, output_token_count=150)
|
||||
assert stats.input_token_count == 300
|
||||
assert stats.output_token_count == 200
|
||||
|
||||
def test_preserves_provider_cost(self):
|
||||
stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(provider_cost=0.005)
|
||||
stats += NodeExecutionStats(output_size=10)
|
||||
assert stats.provider_cost == 0.005
|
||||
assert stats.output_size == 10
|
||||
|
||||
def test_provider_cost_accumulates(self):
|
||||
"""Multiple merge_stats with provider_cost should sum (multi-round
|
||||
tool-calling in copilot / retries can report cost separately)."""
|
||||
stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(provider_cost=0.001)
|
||||
stats += NodeExecutionStats(provider_cost=0.002)
|
||||
stats += NodeExecutionStats(provider_cost=0.003)
|
||||
assert stats.provider_cost == pytest.approx(0.006)
|
||||
|
||||
def test_provider_cost_none_does_not_overwrite(self):
|
||||
"""A None provider_cost must not wipe a previously-set value."""
|
||||
stats = NodeExecutionStats(provider_cost=0.01)
|
||||
stats += NodeExecutionStats() # provider_cost=None by default
|
||||
assert stats.provider_cost == 0.01
|
||||
|
||||
def test_provider_cost_type_last_write_wins(self):
|
||||
"""provider_cost_type is a Literal — last set value wins on merge."""
|
||||
stats = NodeExecutionStats(provider_cost_type="tokens")
|
||||
stats += NodeExecutionStats(provider_cost_type="items")
|
||||
assert stats.provider_cost_type == "items"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_node_execution -> log_system_credential_cost integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestManagerCostTrackingIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_called_with_accumulated_stats(self):
|
||||
"""Verify that log_system_credential_cost receives stats that could
|
||||
have been accumulated by merge_stats across multiple yield steps."""
|
||||
db_client = _make_db_client()
|
||||
with (
|
||||
patch(
|
||||
"backend.executor.cost_tracking.is_system_credential", return_value=True
|
||||
),
|
||||
patch(
|
||||
"backend.executor.cost_tracking.block_usage_cost",
|
||||
return_value=(5, None),
|
||||
),
|
||||
):
|
||||
stats = NodeExecutionStats()
|
||||
stats += NodeExecutionStats(output_size=10, input_token_count=100)
|
||||
stats += NodeExecutionStats(output_size=25, input_token_count=200)
|
||||
|
||||
assert stats.output_size == 35
|
||||
assert stats.input_token_count == 300
|
||||
|
||||
node_exec = _make_node_exec(
|
||||
inputs={
|
||||
"credentials": {"id": "sys-cred-acc", "provider": "openai"},
|
||||
"model": "gpt-4",
|
||||
}
|
||||
)
|
||||
block = _make_block()
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
db_client.log_platform_cost.assert_awaited_once()
|
||||
entry = db_client.log_platform_cost.call_args[0][0]
|
||||
assert entry.input_tokens == 300
|
||||
assert entry.tracking_type == "tokens"
|
||||
assert entry.metadata["tracking_amount"] == 300.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_cost_log_when_status_is_failed(self):
|
||||
"""Manager only calls log_system_credential_cost on COMPLETED status.
|
||||
|
||||
This test verifies the guard condition `if status == COMPLETED` directly:
|
||||
calling log_system_credential_cost only happens on success, never on
|
||||
FAILED or ERROR executions.
|
||||
"""
|
||||
from backend.data.execution import ExecutionStatus
|
||||
|
||||
db_client = _make_db_client()
|
||||
node_exec = _make_node_exec(
|
||||
inputs={"credentials": {"id": "sys-cred", "provider": "openai"}}
|
||||
)
|
||||
block = _make_block()
|
||||
stats = NodeExecutionStats(input_token_count=100)
|
||||
|
||||
# Simulate the manager guard: only call on COMPLETED
|
||||
status = ExecutionStatus.FAILED
|
||||
if status == ExecutionStatus.COMPLETED:
|
||||
await log_system_credential_cost(node_exec, block, stats, db_client)
|
||||
|
||||
db_client.log_platform_cost.assert_not_awaited()
|
||||
@@ -0,0 +1,42 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "PlatformCostLog" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT,
|
||||
"graphExecId" TEXT,
|
||||
"nodeExecId" TEXT,
|
||||
"graphId" TEXT,
|
||||
"nodeId" TEXT,
|
||||
"blockId" TEXT NOT NULL,
|
||||
"blockName" TEXT NOT NULL,
|
||||
"provider" TEXT NOT NULL,
|
||||
"credentialId" TEXT NOT NULL,
|
||||
"costMicrodollars" BIGINT,
|
||||
"inputTokens" INTEGER,
|
||||
"outputTokens" INTEGER,
|
||||
"dataSize" INTEGER,
|
||||
"duration" DOUBLE PRECISION,
|
||||
"model" TEXT,
|
||||
"trackingType" TEXT,
|
||||
"metadata" JSONB,
|
||||
|
||||
CONSTRAINT "PlatformCostLog_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformCostLog_userId_createdAt_idx" ON "PlatformCostLog"("userId", "createdAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformCostLog_provider_createdAt_idx" ON "PlatformCostLog"("provider", "createdAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformCostLog_createdAt_idx" ON "PlatformCostLog"("createdAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformCostLog_graphExecId_idx" ON "PlatformCostLog"("graphExecId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "PlatformCostLog_provider_trackingType_idx" ON "PlatformCostLog"("provider", "trackingType");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "PlatformCostLog" ADD CONSTRAINT "PlatformCostLog_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "PlatformCostLog" ADD COLUMN "trackingAmount" DOUBLE PRECISION;
|
||||
@@ -75,6 +75,8 @@ model User {
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
Workspace UserWorkspace?
|
||||
|
||||
PlatformCostLogs PlatformCostLog[]
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
OAuthAuthorizationCodes OAuthAuthorizationCode[]
|
||||
@@ -815,6 +817,45 @@ model CreditRefundRequest {
|
||||
@@index([userId, transactionKey])
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////// Platform Cost Tracking TABLES //////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
model PlatformCostLog {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
userId String?
|
||||
User User? @relation(fields: [userId], references: [id], onDelete: SetNull)
|
||||
graphExecId String?
|
||||
nodeExecId String?
|
||||
graphId String?
|
||||
nodeId String?
|
||||
blockId String
|
||||
blockName String
|
||||
provider String
|
||||
credentialId String
|
||||
|
||||
// Cost in microdollars (1 USD = 1,000,000). Null if unknown.
|
||||
costMicrodollars BigInt?
|
||||
|
||||
inputTokens Int?
|
||||
outputTokens Int?
|
||||
dataSize Int? // bytes
|
||||
duration Float? // seconds
|
||||
model String?
|
||||
trackingType String? // e.g. "cost_usd", "tokens", "characters", "items", "per_run", "sandbox_seconds", "walltime_seconds"
|
||||
trackingAmount Float? // Amount in the unit implied by trackingType
|
||||
metadata Json?
|
||||
|
||||
@@index([userId, createdAt])
|
||||
@@index([provider, createdAt])
|
||||
@@index([createdAt])
|
||||
@@index([graphExecId])
|
||||
@@index([provider, trackingType])
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////// Store TABLES ///////////////////////////
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import { Sidebar } from "@/components/__legacy__/Sidebar";
|
||||
import { Users, DollarSign, UserSearch, FileText } from "lucide-react";
|
||||
import { Gauge } from "@phosphor-icons/react/dist/ssr";
|
||||
import {
|
||||
Users,
|
||||
CurrencyDollar,
|
||||
MagnifyingGlass,
|
||||
Gauge,
|
||||
Receipt,
|
||||
FileText,
|
||||
} from "@phosphor-icons/react/dist/ssr";
|
||||
|
||||
import { IconSliders } from "@/components/__legacy__/ui/icons";
|
||||
|
||||
@@ -15,18 +21,23 @@ const sidebarLinkGroups = [
|
||||
{
|
||||
text: "User Spending",
|
||||
href: "/admin/spending",
|
||||
icon: <DollarSign className="h-6 w-6" />,
|
||||
icon: <CurrencyDollar className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "User Impersonation",
|
||||
href: "/admin/impersonation",
|
||||
icon: <UserSearch className="h-6 w-6" />,
|
||||
icon: <MagnifyingGlass className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Rate Limits",
|
||||
href: "/admin/rate-limits",
|
||||
icon: <Gauge className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Platform Costs",
|
||||
href: "/admin/platform-costs",
|
||||
icon: <Receipt className="h-6 w-6" />,
|
||||
},
|
||||
{
|
||||
text: "Execution Analytics",
|
||||
href: "/admin/execution-analytics",
|
||||
|
||||
@@ -0,0 +1,429 @@
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
cleanup,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { PlatformCostContent } from "../components/PlatformCostContent";
|
||||
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
|
||||
import type { PlatformCostLogsResponse } from "@/app/api/__generated__/models/platformCostLogsResponse";
|
||||
|
||||
// Mock the generated Orval hooks so tests don't hit the network
|
||||
const mockUseGetDashboard = vi.fn();
|
||||
const mockUseGetLogs = vi.fn();
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
|
||||
useGetV2GetPlatformCostDashboard: (...args: unknown[]) =>
|
||||
mockUseGetDashboard(...args),
|
||||
useGetV2GetPlatformCostLogs: (...args: unknown[]) => mockUseGetLogs(...args),
|
||||
}));
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
mockUseGetDashboard.mockReset();
|
||||
mockUseGetLogs.mockReset();
|
||||
});
|
||||
|
||||
const emptyDashboard: PlatformCostDashboard = {
|
||||
total_cost_microdollars: 0,
|
||||
total_requests: 0,
|
||||
total_users: 0,
|
||||
by_provider: [],
|
||||
by_user: [],
|
||||
};
|
||||
|
||||
const emptyLogs: PlatformCostLogsResponse = {
|
||||
logs: [],
|
||||
pagination: {
|
||||
current_page: 1,
|
||||
page_size: 50,
|
||||
total_items: 0,
|
||||
total_pages: 0,
|
||||
},
|
||||
};
|
||||
|
||||
const dashboardWithData: PlatformCostDashboard = {
|
||||
total_cost_microdollars: 5_000_000,
|
||||
total_requests: 100,
|
||||
total_users: 5,
|
||||
by_provider: [
|
||||
{
|
||||
provider: "openai",
|
||||
tracking_type: "tokens",
|
||||
total_cost_microdollars: 3_000_000,
|
||||
total_input_tokens: 50000,
|
||||
total_output_tokens: 20000,
|
||||
total_duration_seconds: 0,
|
||||
request_count: 60,
|
||||
},
|
||||
{
|
||||
provider: "google_maps",
|
||||
tracking_type: "per_run",
|
||||
total_cost_microdollars: 0,
|
||||
total_input_tokens: 0,
|
||||
total_output_tokens: 0,
|
||||
total_duration_seconds: 0,
|
||||
request_count: 40,
|
||||
},
|
||||
],
|
||||
by_user: [
|
||||
{
|
||||
user_id: "user-1",
|
||||
email: "alice@example.com",
|
||||
total_cost_microdollars: 3_000_000,
|
||||
total_input_tokens: 50000,
|
||||
total_output_tokens: 20000,
|
||||
request_count: 60,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const logsWithData: PlatformCostLogsResponse = {
|
||||
logs: [
|
||||
{
|
||||
id: "log-1",
|
||||
created_at: "2026-03-01T00:00:00Z" as unknown as Date,
|
||||
user_id: "user-1",
|
||||
email: "alice@example.com",
|
||||
graph_exec_id: "gx-123",
|
||||
node_exec_id: "nx-456",
|
||||
block_name: "LLMBlock",
|
||||
provider: "openai",
|
||||
tracking_type: "tokens",
|
||||
cost_microdollars: 5000,
|
||||
input_tokens: 100,
|
||||
output_tokens: 50,
|
||||
duration: 1.5,
|
||||
model: "gpt-4",
|
||||
},
|
||||
],
|
||||
pagination: {
|
||||
current_page: 1,
|
||||
page_size: 50,
|
||||
total_items: 1,
|
||||
total_pages: 1,
|
||||
},
|
||||
};
|
||||
|
||||
function renderComponent(searchParams = {}) {
|
||||
return render(<PlatformCostContent searchParams={searchParams} />);
|
||||
}
|
||||
|
||||
describe("PlatformCostContent", () => {
|
||||
it("shows loading state initially", () => {
|
||||
mockUseGetDashboard.mockReturnValue({ data: undefined, isLoading: true });
|
||||
mockUseGetLogs.mockReturnValue({ data: undefined, isLoading: true });
|
||||
renderComponent();
|
||||
// Loading state renders Skeleton placeholders (animate-pulse divs) instead of content
|
||||
expect(screen.queryByText("Loading...")).toBeNull();
|
||||
// Summary cards and table content are not yet shown
|
||||
expect(screen.queryByText("Known Cost")).toBeNull();
|
||||
});
|
||||
|
||||
it("renders empty dashboard", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: emptyDashboard,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: emptyLogs,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
// Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
|
||||
const zeroCostItems = screen.getAllByText("$0.0000");
|
||||
expect(zeroCostItems.length).toBe(2);
|
||||
expect(screen.getByText("No cost data yet")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders dashboard with provider data", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("$5.0000")).toBeDefined();
|
||||
expect(screen.getByText("100")).toBeDefined();
|
||||
expect(screen.getByText("5")).toBeDefined();
|
||||
expect(screen.getByText("openai")).toBeDefined();
|
||||
expect(screen.getByText("google_maps")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders tracking type badges", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("tokens")).toBeDefined();
|
||||
expect(screen.getByText("per_run")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows error state on fetch failure", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: undefined,
|
||||
isLoading: false,
|
||||
error: new Error("Network error"),
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: undefined,
|
||||
isLoading: false,
|
||||
error: new Error("Network error"),
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("Network error")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders tab buttons", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: emptyDashboard,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("By Provider")).toBeDefined();
|
||||
expect(screen.getByText("By User")).toBeDefined();
|
||||
expect(screen.getByText("Raw Logs")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders summary cards with correct labels", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
|
||||
expect(screen.getByText("Estimated Total")).toBeDefined();
|
||||
expect(screen.getByText("Total Requests")).toBeDefined();
|
||||
expect(screen.getByText("Active Users")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders filter inputs", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: emptyDashboard,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("Start Date")).toBeDefined();
|
||||
expect(screen.getByText("End Date")).toBeDefined();
|
||||
expect(screen.getAllByText(/Provider/i).length).toBeGreaterThanOrEqual(1);
|
||||
expect(screen.getByText("User ID")).toBeDefined();
|
||||
expect(screen.getByText("Apply")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders by-user tab when specified", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "by-user" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("alice@example.com")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders logs tab when specified", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "logs" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("LLMBlock")).toBeDefined();
|
||||
expect(screen.getByText("gpt-4")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders no logs message when empty", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: emptyDashboard,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
|
||||
renderComponent({ tab: "logs" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("No logs found")).toBeDefined();
|
||||
});
|
||||
|
||||
it("shows pagination when multiple pages", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
const multiPageLogs: PlatformCostLogsResponse = {
|
||||
logs: logsWithData.logs,
|
||||
pagination: {
|
||||
current_page: 1,
|
||||
page_size: 50,
|
||||
total_items: 200,
|
||||
total_pages: 4,
|
||||
},
|
||||
};
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: multiPageLogs,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "logs" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("Previous")).toBeDefined();
|
||||
expect(screen.getByText("Next")).toBeDefined();
|
||||
expect(screen.getByText(/Page 1 of 4/)).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders user table with unknown email", async () => {
|
||||
const dashWithNullEmail: PlatformCostDashboard = {
|
||||
...dashboardWithData,
|
||||
by_user: [
|
||||
{
|
||||
user_id: "user-2",
|
||||
email: null,
|
||||
total_cost_microdollars: 1000,
|
||||
total_input_tokens: 100,
|
||||
total_output_tokens: 50,
|
||||
request_count: 5,
|
||||
},
|
||||
],
|
||||
};
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashWithNullEmail,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({ data: emptyLogs, isLoading: false });
|
||||
renderComponent({ tab: "by-user" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("Unknown")).toBeDefined();
|
||||
});
|
||||
|
||||
it("by-user tab content visible when tab=by-user param set", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "by-user" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("alice@example.com")).toBeDefined();
|
||||
// overview tab content should not be visible
|
||||
expect(screen.queryByText("openai")).toBeNull();
|
||||
});
|
||||
|
||||
it("logs tab content visible when tab=logs param set", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "logs" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("LLMBlock")).toBeDefined();
|
||||
expect(screen.getByText("gpt-4")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders log with null user as dash", async () => {
|
||||
const logWithNullUser: PlatformCostLogsResponse = {
|
||||
logs: [
|
||||
{
|
||||
id: "log-2",
|
||||
created_at: "2026-03-01T00:00:00Z" as unknown as Date,
|
||||
user_id: null,
|
||||
email: null,
|
||||
graph_exec_id: null,
|
||||
node_exec_id: null,
|
||||
block_name: "copilot:SDK",
|
||||
provider: "anthropic",
|
||||
tracking_type: "cost_usd",
|
||||
cost_microdollars: 15000,
|
||||
input_tokens: null,
|
||||
output_tokens: null,
|
||||
duration: null,
|
||||
model: "claude-opus-4-20250514",
|
||||
},
|
||||
],
|
||||
pagination: {
|
||||
current_page: 1,
|
||||
page_size: 50,
|
||||
total_items: 1,
|
||||
total_pages: 1,
|
||||
},
|
||||
};
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: emptyDashboard,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logWithNullUser,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "logs" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("copilot:SDK")).toBeDefined();
|
||||
expect(screen.getByText("anthropic")).toBeDefined();
|
||||
// null email + null user_id renders as "-" in the User column; multiple
|
||||
// other cells (tokens, duration, session) also render "-", so use
|
||||
// getAllByText to avoid the single-match constraint.
|
||||
expect(screen.getAllByText("-").length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,87 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
const mockGetDashboard = vi.fn();
|
||||
const mockGetLogs = vi.fn();
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
|
||||
getV2GetPlatformCostDashboard: (...args: unknown[]) =>
|
||||
mockGetDashboard(...args),
|
||||
getV2GetPlatformCostLogs: (...args: unknown[]) => mockGetLogs(...args),
|
||||
}));
|
||||
|
||||
import { getPlatformCostDashboard, getPlatformCostLogs } from "../actions";
|
||||
|
||||
describe("getPlatformCostDashboard", () => {
|
||||
it("returns data on success", async () => {
|
||||
const mockData = { total_cost_microdollars: 1000, total_requests: 5 };
|
||||
mockGetDashboard.mockResolvedValue({ status: 200, data: mockData });
|
||||
const result = await getPlatformCostDashboard();
|
||||
expect(result).toEqual(mockData);
|
||||
});
|
||||
|
||||
it("returns undefined on non-200", async () => {
|
||||
mockGetDashboard.mockResolvedValue({ status: 401 });
|
||||
const result = await getPlatformCostDashboard();
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it("passes filter params to API", async () => {
|
||||
mockGetDashboard.mockReset();
|
||||
mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
|
||||
await getPlatformCostDashboard({
|
||||
start: "2026-01-01T00:00:00",
|
||||
end: "2026-06-01T00:00:00",
|
||||
provider: "openai",
|
||||
user_id: "user-1",
|
||||
});
|
||||
expect(mockGetDashboard).toHaveBeenCalledTimes(1);
|
||||
const params = mockGetDashboard.mock.calls[0][0];
|
||||
expect(params.start).toBe("2026-01-01T00:00:00");
|
||||
expect(params.end).toBe("2026-06-01T00:00:00");
|
||||
expect(params.provider).toBe("openai");
|
||||
expect(params.user_id).toBe("user-1");
|
||||
});
|
||||
|
||||
it("passes undefined for empty filter strings", async () => {
|
||||
mockGetDashboard.mockReset();
|
||||
mockGetDashboard.mockResolvedValue({ status: 200, data: {} });
|
||||
await getPlatformCostDashboard({
|
||||
start: "",
|
||||
provider: "",
|
||||
user_id: "",
|
||||
});
|
||||
expect(mockGetDashboard).toHaveBeenCalledTimes(1);
|
||||
const params = mockGetDashboard.mock.calls[0][0];
|
||||
expect(params.start).toBeUndefined();
|
||||
expect(params.provider).toBeUndefined();
|
||||
expect(params.user_id).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("getPlatformCostLogs", () => {
|
||||
it("returns data on success", async () => {
|
||||
const mockData = { logs: [], pagination: { current_page: 1 } };
|
||||
mockGetLogs.mockResolvedValue({ status: 200, data: mockData });
|
||||
const result = await getPlatformCostLogs();
|
||||
expect(result).toEqual(mockData);
|
||||
});
|
||||
|
||||
it("passes page and page_size", async () => {
|
||||
mockGetLogs.mockReset();
|
||||
mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
|
||||
await getPlatformCostLogs({ page: 3, page_size: 25 });
|
||||
expect(mockGetLogs).toHaveBeenCalledTimes(1);
|
||||
const params = mockGetLogs.mock.calls[0][0];
|
||||
expect(params.page).toBe(3);
|
||||
expect(params.page_size).toBe(25);
|
||||
});
|
||||
|
||||
it("passes start date string through to API", async () => {
|
||||
mockGetLogs.mockReset();
|
||||
mockGetLogs.mockResolvedValue({ status: 200, data: { logs: [] } });
|
||||
await getPlatformCostLogs({ start: "2026-03-01T00:00:00" });
|
||||
expect(mockGetLogs).toHaveBeenCalledTimes(1);
|
||||
const params = mockGetLogs.mock.calls[0][0];
|
||||
expect(params.start).toBe("2026-03-01T00:00:00");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,300 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
|
||||
import {
|
||||
toDateOrUndefined,
|
||||
formatMicrodollars,
|
||||
formatTokens,
|
||||
formatDuration,
|
||||
estimateCostForRow,
|
||||
trackingValue,
|
||||
toLocalInput,
|
||||
toUtcIso,
|
||||
} from "../helpers";
|
||||
|
||||
function makeRow(overrides: Partial<ProviderCostSummary>): ProviderCostSummary {
|
||||
return {
|
||||
provider: "openai",
|
||||
tracking_type: null,
|
||||
total_cost_microdollars: 0,
|
||||
total_input_tokens: 0,
|
||||
total_output_tokens: 0,
|
||||
total_duration_seconds: 0,
|
||||
request_count: 0,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("toDateOrUndefined", () => {
|
||||
it("returns undefined for empty string", () => {
|
||||
expect(toDateOrUndefined("")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("returns undefined for undefined", () => {
|
||||
expect(toDateOrUndefined(undefined)).toBeUndefined();
|
||||
});
|
||||
|
||||
it("returns undefined for invalid date string", () => {
|
||||
expect(toDateOrUndefined("not-a-date")).toBeUndefined();
|
||||
});
|
||||
|
||||
it("returns a Date for a valid ISO string", () => {
|
||||
const result = toDateOrUndefined("2026-01-15T00:00:00Z");
|
||||
expect(result).toBeInstanceOf(Date);
|
||||
expect(result!.toISOString()).toBe("2026-01-15T00:00:00.000Z");
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatMicrodollars", () => {
|
||||
it("formats zero", () => {
|
||||
expect(formatMicrodollars(0)).toBe("$0.0000");
|
||||
});
|
||||
|
||||
it("formats a small amount", () => {
|
||||
expect(formatMicrodollars(50_000)).toBe("$0.0500");
|
||||
});
|
||||
|
||||
it("formats one dollar", () => {
|
||||
expect(formatMicrodollars(1_000_000)).toBe("$1.0000");
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatTokens", () => {
|
||||
it("formats small numbers as-is", () => {
|
||||
expect(formatTokens(500)).toBe("500");
|
||||
});
|
||||
|
||||
it("formats thousands with K suffix", () => {
|
||||
expect(formatTokens(1_500)).toBe("1.5K");
|
||||
});
|
||||
|
||||
it("formats millions with M suffix", () => {
|
||||
expect(formatTokens(2_500_000)).toBe("2.5M");
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatDuration", () => {
|
||||
it("formats seconds", () => {
|
||||
expect(formatDuration(30)).toBe("30.0s");
|
||||
});
|
||||
|
||||
it("formats minutes", () => {
|
||||
expect(formatDuration(90)).toBe("1.5m");
|
||||
});
|
||||
|
||||
it("formats hours", () => {
|
||||
expect(formatDuration(5400)).toBe("1.5h");
|
||||
});
|
||||
});
|
||||
|
||||
describe("estimateCostForRow", () => {
|
||||
it("returns microdollars directly for cost_usd tracking", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "cost_usd",
|
||||
total_cost_microdollars: 500_000,
|
||||
});
|
||||
expect(estimateCostForRow(row, {})).toBe(500_000);
|
||||
});
|
||||
|
||||
it("returns reported cost for token tracking when cost > 0", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "tokens",
|
||||
total_cost_microdollars: 100_000,
|
||||
total_input_tokens: 1000,
|
||||
total_output_tokens: 500,
|
||||
});
|
||||
expect(estimateCostForRow(row, {})).toBe(100_000);
|
||||
});
|
||||
|
||||
it("estimates cost from default rate for token tracking with zero cost", () => {
|
||||
const row = makeRow({
|
||||
provider: "openai",
|
||||
tracking_type: "tokens",
|
||||
total_cost_microdollars: 0,
|
||||
total_input_tokens: 500,
|
||||
total_output_tokens: 500,
|
||||
});
|
||||
// 1000 tokens / 1000 * 0.005 USD * 1_000_000 = 5000
|
||||
expect(estimateCostForRow(row, {})).toBe(5000);
|
||||
});
|
||||
|
||||
it("returns null for unknown token provider with zero cost", () => {
|
||||
const row = makeRow({
|
||||
provider: "unknown_provider",
|
||||
tracking_type: "tokens",
|
||||
total_cost_microdollars: 0,
|
||||
});
|
||||
expect(estimateCostForRow(row, {})).toBeNull();
|
||||
});
|
||||
|
||||
it("uses per-run override when provided", () => {
|
||||
const row = makeRow({
|
||||
provider: "google_maps",
|
||||
tracking_type: "per_run",
|
||||
request_count: 10,
|
||||
});
|
||||
// override = 0.05 * 10 * 1_000_000 = 500_000
|
||||
expect(estimateCostForRow(row, { "google_maps:per_run": 0.05 })).toBe(
|
||||
500_000,
|
||||
);
|
||||
});
|
||||
|
||||
it("uses default per-run cost when no override", () => {
|
||||
const row = makeRow({
|
||||
provider: "google_maps",
|
||||
tracking_type: null,
|
||||
request_count: 5,
|
||||
});
|
||||
// 0.032 * 5 * 1_000_000 = 160_000
|
||||
expect(estimateCostForRow(row, {})).toBe(160_000);
|
||||
});
|
||||
|
||||
it("returns null for unknown per_run provider", () => {
|
||||
const row = makeRow({
|
||||
provider: "totally_unknown",
|
||||
tracking_type: "per_run",
|
||||
request_count: 3,
|
||||
});
|
||||
expect(estimateCostForRow(row, {})).toBeNull();
|
||||
});
|
||||
|
||||
it("returns null for duration tracking with no rate and no cost", () => {
|
||||
const row = makeRow({
|
||||
provider: "openai",
|
||||
tracking_type: "duration_seconds",
|
||||
total_cost_microdollars: 0,
|
||||
total_duration_seconds: 100,
|
||||
});
|
||||
expect(estimateCostForRow(row, {})).toBeNull();
|
||||
});
|
||||
|
||||
it("estimates cost from default rate for characters tracking", () => {
|
||||
const row = makeRow({
|
||||
provider: "elevenlabs",
|
||||
tracking_type: "characters",
|
||||
total_cost_microdollars: 0,
|
||||
total_tracking_amount: 2000,
|
||||
});
|
||||
// 2000 chars / 1000 * 0.18 USD * 1_000_000 = 360_000
|
||||
expect(estimateCostForRow(row, {})).toBe(360_000);
|
||||
});
|
||||
|
||||
it("estimates cost from default rate for items tracking", () => {
|
||||
const row = makeRow({
|
||||
provider: "apollo",
|
||||
tracking_type: "items",
|
||||
total_cost_microdollars: 0,
|
||||
total_tracking_amount: 50,
|
||||
});
|
||||
// 50 * 0.02 * 1_000_000 = 1_000_000
|
||||
expect(estimateCostForRow(row, {})).toBe(1_000_000);
|
||||
});
|
||||
|
||||
it("estimates cost from default rate for duration tracking", () => {
|
||||
const row = makeRow({
|
||||
provider: "e2b",
|
||||
tracking_type: "sandbox_seconds",
|
||||
total_cost_microdollars: 0,
|
||||
total_duration_seconds: 1_000_000,
|
||||
});
|
||||
// 1_000_000 * 0.000014 * 1_000_000 = 14_000_000
|
||||
expect(estimateCostForRow(row, {})).toBe(14_000_000);
|
||||
});
|
||||
});
|
||||
|
||||
describe("trackingValue", () => {
|
||||
it("returns formatted microdollars for cost_usd", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "cost_usd",
|
||||
total_cost_microdollars: 1_000_000,
|
||||
});
|
||||
expect(trackingValue(row)).toBe("$1.0000");
|
||||
});
|
||||
|
||||
it("returns formatted token count for tokens", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "tokens",
|
||||
total_input_tokens: 500,
|
||||
total_output_tokens: 500,
|
||||
});
|
||||
expect(trackingValue(row)).toBe("1.0K tokens");
|
||||
});
|
||||
|
||||
it("returns formatted duration for sandbox_seconds", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "sandbox_seconds",
|
||||
total_duration_seconds: 120,
|
||||
});
|
||||
expect(trackingValue(row)).toBe("2.0m");
|
||||
});
|
||||
|
||||
it("returns run count for per_run (default tracking)", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: null,
|
||||
request_count: 42,
|
||||
});
|
||||
expect(trackingValue(row)).toBe("42 runs");
|
||||
});
|
||||
|
||||
it("returns formatted character count for characters tracking", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "characters",
|
||||
total_tracking_amount: 2500,
|
||||
});
|
||||
expect(trackingValue(row)).toBe("2.5K chars");
|
||||
});
|
||||
|
||||
it("returns formatted item count for items tracking", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "items",
|
||||
total_tracking_amount: 1234,
|
||||
});
|
||||
expect(trackingValue(row)).toBe("1,234 items");
|
||||
});
|
||||
|
||||
it("returns formatted duration for sandbox_seconds", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "sandbox_seconds",
|
||||
total_duration_seconds: 7200,
|
||||
});
|
||||
expect(trackingValue(row)).toBe("2.0h");
|
||||
});
|
||||
|
||||
it("returns formatted duration for walltime_seconds", () => {
|
||||
const row = makeRow({
|
||||
tracking_type: "walltime_seconds",
|
||||
total_duration_seconds: 45,
|
||||
});
|
||||
expect(trackingValue(row)).toBe("45.0s");
|
||||
});
|
||||
});
|
||||
|
||||
describe("toLocalInput", () => {
|
||||
it("returns empty string for empty input", () => {
|
||||
expect(toLocalInput("")).toBe("");
|
||||
});
|
||||
|
||||
it("returns empty string for invalid ISO", () => {
|
||||
expect(toLocalInput("not-a-date")).toBe("");
|
||||
});
|
||||
|
||||
it("converts UTC ISO to local datetime-local format", () => {
|
||||
const result = toLocalInput("2026-01-15T12:30:00Z");
|
||||
// Format should be YYYY-MM-DDTHH:mm
|
||||
expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$/);
|
||||
});
|
||||
});
|
||||
|
||||
describe("toUtcIso", () => {
|
||||
it("returns empty string for empty input", () => {
|
||||
expect(toUtcIso("")).toBe("");
|
||||
});
|
||||
|
||||
it("returns empty string for invalid local time", () => {
|
||||
expect(toUtcIso("not-a-date")).toBe("");
|
||||
});
|
||||
|
||||
it("converts local datetime-local to ISO string", () => {
|
||||
const result = toUtcIso("2026-01-15T12:30");
|
||||
expect(result).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,45 @@
|
||||
import {
|
||||
getV2GetPlatformCostDashboard,
|
||||
getV2GetPlatformCostLogs,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
|
||||
// Backend expects ISO datetime strings. The generated client's URL builder
|
||||
// calls .toString() on values, which for Date objects produces the human
|
||||
// "Tue Mar 31 2026 22:00:00 GMT+0000 (Coordinated Universal Time)" format
|
||||
// that FastAPI rejects with 422. We already pass UTC ISO from the URL, so
|
||||
// forward the raw strings through the `as unknown as Date` cast to match
|
||||
// the generated typing without triggering Date.toString().
|
||||
export async function getPlatformCostDashboard(params?: {
|
||||
start?: string;
|
||||
end?: string;
|
||||
provider?: string;
|
||||
user_id?: string;
|
||||
}) {
|
||||
const response = await getV2GetPlatformCostDashboard({
|
||||
start: (params?.start || undefined) as unknown as Date | undefined,
|
||||
end: (params?.end || undefined) as unknown as Date | undefined,
|
||||
provider: params?.provider || undefined,
|
||||
user_id: params?.user_id || undefined,
|
||||
});
|
||||
return okData(response);
|
||||
}
|
||||
|
||||
export async function getPlatformCostLogs(params?: {
|
||||
start?: string;
|
||||
end?: string;
|
||||
provider?: string;
|
||||
user_id?: string;
|
||||
page?: number;
|
||||
page_size?: number;
|
||||
}) {
|
||||
const response = await getV2GetPlatformCostLogs({
|
||||
start: (params?.start || undefined) as unknown as Date | undefined,
|
||||
end: (params?.end || undefined) as unknown as Date | undefined,
|
||||
provider: params?.provider || undefined,
|
||||
user_id: params?.user_id || undefined,
|
||||
page: params?.page,
|
||||
page_size: params?.page_size,
|
||||
});
|
||||
return okData(response);
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
import type { CostLogRow } from "@/app/api/__generated__/models/costLogRow";
|
||||
import type { Pagination } from "@/app/api/__generated__/models/pagination";
|
||||
import { formatDuration, formatMicrodollars, formatTokens } from "../helpers";
|
||||
import { TrackingBadge } from "./TrackingBadge";
|
||||
|
||||
function formatLogDate(value: unknown): string {
|
||||
if (value instanceof Date) return value.toLocaleString();
|
||||
if (typeof value === "string" || typeof value === "number")
|
||||
return new Date(value).toLocaleString();
|
||||
return "-";
|
||||
}
|
||||
|
||||
interface Props {
|
||||
logs: CostLogRow[];
|
||||
pagination: Pagination | null;
|
||||
onPageChange: (page: number) => void;
|
||||
}
|
||||
|
||||
function LogsTable({ logs, pagination, onPageChange }: Props) {
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="overflow-x-auto">
|
||||
<table className="w-full text-left text-sm">
|
||||
<thead className="border-b text-xs uppercase text-muted-foreground">
|
||||
<tr>
|
||||
<th scope="col" className="px-3 py-3">
|
||||
Time
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3">
|
||||
User
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3">
|
||||
Block
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3">
|
||||
Provider
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3">
|
||||
Type
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3">
|
||||
Model
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3 text-right">
|
||||
Cost
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3 text-right">
|
||||
Tokens
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3 text-right">
|
||||
Duration
|
||||
</th>
|
||||
<th scope="col" className="px-3 py-3">
|
||||
Session
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{logs.map((log) => (
|
||||
<tr key={log.id} className="border-b hover:bg-muted">
|
||||
<td className="whitespace-nowrap px-3 py-2 text-xs">
|
||||
{formatLogDate(log.created_at)}
|
||||
</td>
|
||||
<td className="px-3 py-2 text-xs">
|
||||
{log.email ||
|
||||
(log.user_id ? String(log.user_id).slice(0, 8) : "-")}
|
||||
</td>
|
||||
<td className="px-3 py-2 text-xs font-medium">
|
||||
{log.block_name}
|
||||
</td>
|
||||
<td className="px-3 py-2 text-xs">{log.provider}</td>
|
||||
<td className="px-3 py-2 text-xs">
|
||||
<TrackingBadge trackingType={log.tracking_type} />
|
||||
</td>
|
||||
<td className="px-3 py-2 text-xs">{log.model || "-"}</td>
|
||||
<td className="px-3 py-2 text-right text-xs">
|
||||
{log.cost_microdollars != null
|
||||
? formatMicrodollars(Number(log.cost_microdollars))
|
||||
: "-"}
|
||||
</td>
|
||||
<td className="px-3 py-2 text-right text-xs">
|
||||
{log.input_tokens != null || log.output_tokens != null
|
||||
? `${formatTokens(Number(log.input_tokens ?? 0))} / ${formatTokens(Number(log.output_tokens ?? 0))}`
|
||||
: "-"}
|
||||
</td>
|
||||
<td className="px-3 py-2 text-right text-xs">
|
||||
{log.duration != null
|
||||
? formatDuration(Number(log.duration))
|
||||
: "-"}
|
||||
</td>
|
||||
<td className="px-3 py-2 text-xs text-muted-foreground">
|
||||
{log.graph_exec_id
|
||||
? String(log.graph_exec_id).slice(0, 8)
|
||||
: "-"}
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
{logs.length === 0 && (
|
||||
<tr>
|
||||
<td
|
||||
colSpan={10}
|
||||
className="px-4 py-8 text-center text-muted-foreground"
|
||||
>
|
||||
No logs found
|
||||
</td>
|
||||
</tr>
|
||||
)}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
{pagination && pagination.total_pages > 1 && (
|
||||
<div className="flex items-center justify-between px-4">
|
||||
<span className="text-sm text-muted-foreground">
|
||||
Page {pagination.current_page} of {pagination.total_pages} (
|
||||
{pagination.total_items} total)
|
||||
</span>
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
disabled={pagination.current_page <= 1}
|
||||
onClick={() => onPageChange(pagination.current_page - 1)}
|
||||
className="rounded border px-3 py-1 text-sm disabled:opacity-50"
|
||||
>
|
||||
Previous
|
||||
</button>
|
||||
<button
|
||||
disabled={pagination.current_page >= pagination.total_pages}
|
||||
onClick={() => onPageChange(pagination.current_page + 1)}
|
||||
className="rounded border px-3 py-1 text-sm disabled:opacity-50"
|
||||
>
|
||||
Next
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { LogsTable };
|
||||
@@ -0,0 +1,233 @@
|
||||
"use client";
|
||||
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { formatMicrodollars } from "../helpers";
|
||||
import { SummaryCard } from "./SummaryCard";
|
||||
import { ProviderTable } from "./ProviderTable";
|
||||
import { UserTable } from "./UserTable";
|
||||
import { LogsTable } from "./LogsTable";
|
||||
import { usePlatformCostContent } from "./usePlatformCostContent";
|
||||
|
||||
interface Props {
|
||||
searchParams: {
|
||||
start?: string;
|
||||
end?: string;
|
||||
provider?: string;
|
||||
user_id?: string;
|
||||
page?: string;
|
||||
tab?: string;
|
||||
};
|
||||
}
|
||||
|
||||
function PlatformCostContent({ searchParams }: Props) {
|
||||
const {
|
||||
dashboard,
|
||||
logs,
|
||||
pagination,
|
||||
loading,
|
||||
error,
|
||||
totalEstimatedCost,
|
||||
tab,
|
||||
startInput,
|
||||
setStartInput,
|
||||
endInput,
|
||||
setEndInput,
|
||||
providerInput,
|
||||
setProviderInput,
|
||||
userInput,
|
||||
setUserInput,
|
||||
rateOverrides,
|
||||
handleRateOverride,
|
||||
updateUrl,
|
||||
handleFilter,
|
||||
} = usePlatformCostContent(searchParams);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
<div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<label htmlFor="start-date" className="text-sm text-muted-foreground">
|
||||
Start Date <span className="text-xs">(local time)</span>
|
||||
</label>
|
||||
<input
|
||||
id="start-date"
|
||||
type="datetime-local"
|
||||
className="rounded border px-3 py-1.5 text-sm"
|
||||
value={startInput}
|
||||
onChange={(e) => setStartInput(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<label htmlFor="end-date" className="text-sm text-muted-foreground">
|
||||
End Date <span className="text-xs">(local time)</span>
|
||||
</label>
|
||||
<input
|
||||
id="end-date"
|
||||
type="datetime-local"
|
||||
className="rounded border px-3 py-1.5 text-sm"
|
||||
value={endInput}
|
||||
onChange={(e) => setEndInput(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<label
|
||||
htmlFor="provider-filter"
|
||||
className="text-sm text-muted-foreground"
|
||||
>
|
||||
Provider
|
||||
</label>
|
||||
<input
|
||||
id="provider-filter"
|
||||
type="text"
|
||||
placeholder="e.g. openai"
|
||||
className="rounded border px-3 py-1.5 text-sm"
|
||||
value={providerInput}
|
||||
onChange={(e) => setProviderInput(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<label
|
||||
htmlFor="user-id-filter"
|
||||
className="text-sm text-muted-foreground"
|
||||
>
|
||||
User ID
|
||||
</label>
|
||||
<input
|
||||
id="user-id-filter"
|
||||
type="text"
|
||||
placeholder="Filter by user"
|
||||
className="rounded border px-3 py-1.5 text-sm"
|
||||
value={userInput}
|
||||
onChange={(e) => setUserInput(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
onClick={handleFilter}
|
||||
className="rounded bg-primary px-4 py-1.5 text-sm text-primary-foreground hover:bg-primary/90"
|
||||
>
|
||||
Apply
|
||||
</button>
|
||||
<button
|
||||
onClick={() => {
|
||||
setStartInput("");
|
||||
setEndInput("");
|
||||
setProviderInput("");
|
||||
setUserInput("");
|
||||
updateUrl({
|
||||
start: "",
|
||||
end: "",
|
||||
provider: "",
|
||||
user_id: "",
|
||||
page: "1",
|
||||
});
|
||||
}}
|
||||
className="rounded border px-4 py-1.5 text-sm hover:bg-muted"
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<Alert variant="error">
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{loading ? (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
|
||||
{[...Array(4)].map((_, i) => (
|
||||
<Skeleton key={i} className="h-20 rounded-lg" />
|
||||
))}
|
||||
</div>
|
||||
<Skeleton className="h-8 w-48 rounded" />
|
||||
<Skeleton className="h-64 rounded-lg" />
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{dashboard && (
|
||||
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
|
||||
<SummaryCard
|
||||
label="Known Cost"
|
||||
value={formatMicrodollars(dashboard.total_cost_microdollars)}
|
||||
subtitle="From providers that report USD cost"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Estimated Total"
|
||||
value={formatMicrodollars(totalEstimatedCost)}
|
||||
subtitle="Including per-run cost estimates"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Total Requests"
|
||||
value={dashboard.total_requests.toLocaleString()}
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Active Users"
|
||||
value={dashboard.total_users.toLocaleString()}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
role="tablist"
|
||||
aria-label="Cost view tabs"
|
||||
className="flex gap-2 border-b"
|
||||
>
|
||||
{["overview", "by-user", "logs"].map((t) => (
|
||||
<button
|
||||
key={t}
|
||||
id={`tab-${t}`}
|
||||
role="tab"
|
||||
aria-selected={tab === t}
|
||||
aria-controls={`tabpanel-${t}`}
|
||||
onClick={() => updateUrl({ tab: t, page: "1" })}
|
||||
className={`px-4 py-2 text-sm font-medium ${tab === t ? "border-b-2 border-primary text-primary" : "text-muted-foreground hover:text-foreground"}`}
|
||||
>
|
||||
{t === "overview"
|
||||
? "By Provider"
|
||||
: t === "by-user"
|
||||
? "By User"
|
||||
: "Raw Logs"}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{tab === "overview" && dashboard && (
|
||||
<div
|
||||
role="tabpanel"
|
||||
id="tabpanel-overview"
|
||||
aria-labelledby="tab-overview"
|
||||
>
|
||||
<ProviderTable
|
||||
data={dashboard.by_provider}
|
||||
rateOverrides={rateOverrides}
|
||||
onRateOverride={handleRateOverride}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{tab === "by-user" && dashboard && (
|
||||
<div
|
||||
role="tabpanel"
|
||||
id="tabpanel-by-user"
|
||||
aria-labelledby="tab-by-user"
|
||||
>
|
||||
<UserTable data={dashboard.by_user} />
|
||||
</div>
|
||||
)}
|
||||
{tab === "logs" && (
|
||||
<div role="tabpanel" id="tabpanel-logs" aria-labelledby="tab-logs">
|
||||
<LogsTable
|
||||
logs={logs}
|
||||
pagination={pagination}
|
||||
onPageChange={(p) => updateUrl({ page: p.toString() })}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { PlatformCostContent };
|
||||
@@ -0,0 +1,131 @@
|
||||
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
|
||||
import {
|
||||
defaultRateFor,
|
||||
estimateCostForRow,
|
||||
formatMicrodollars,
|
||||
rateKey,
|
||||
rateUnitLabel,
|
||||
trackingValue,
|
||||
} from "../helpers";
|
||||
import { TrackingBadge } from "./TrackingBadge";
|
||||
|
||||
interface Props {
|
||||
data: ProviderCostSummary[];
|
||||
rateOverrides: Record<string, number>;
|
||||
onRateOverride: (key: string, val: number | null) => void;
|
||||
}
|
||||
|
||||
function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
|
||||
return (
|
||||
<div className="overflow-x-auto">
|
||||
<table className="w-full text-left text-sm">
|
||||
<thead className="border-b text-xs uppercase text-muted-foreground">
|
||||
<tr>
|
||||
<th scope="col" className="px-4 py-3">
|
||||
Provider
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3">
|
||||
Type
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Usage
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Requests
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Known Cost
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Est. Cost
|
||||
</th>
|
||||
<th
|
||||
scope="col"
|
||||
className="px-4 py-3 text-right"
|
||||
title="Per-session only"
|
||||
>
|
||||
Rate <span className="text-[10px] font-normal">(unsaved)</span>
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{data.map((row) => {
|
||||
const est = estimateCostForRow(row, rateOverrides);
|
||||
const tt = row.tracking_type || "per_run";
|
||||
// For cost_usd rows the provider reports USD directly so rate
|
||||
// input doesn't apply; otherwise show an editable input.
|
||||
const showRateInput = tt !== "cost_usd";
|
||||
const key = rateKey(row.provider, tt);
|
||||
const fallback = defaultRateFor(row.provider, tt);
|
||||
const currentRate = rateOverrides[key] ?? fallback;
|
||||
return (
|
||||
<tr key={key} className="border-b hover:bg-muted">
|
||||
<td className="px-4 py-3 font-medium">{row.provider}</td>
|
||||
<td className="px-4 py-3">
|
||||
<TrackingBadge trackingType={row.tracking_type} />
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">{trackingValue(row)}</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{row.request_count.toLocaleString()}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{row.total_cost_microdollars > 0
|
||||
? formatMicrodollars(row.total_cost_microdollars)
|
||||
: "-"}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{est !== null ? (
|
||||
formatMicrodollars(est)
|
||||
) : (
|
||||
<span className="text-muted-foreground">-</span>
|
||||
)}
|
||||
</td>
|
||||
<td className="px-4 py-2 text-right">
|
||||
{showRateInput ? (
|
||||
<div className="flex items-center justify-end gap-1">
|
||||
<input
|
||||
type="number"
|
||||
step="0.0001"
|
||||
min="0"
|
||||
aria-label={`Rate for ${row.provider} (${tt})`}
|
||||
className="w-24 rounded border px-2 py-1 text-right text-xs"
|
||||
placeholder={fallback !== null ? String(fallback) : "0"}
|
||||
value={currentRate ?? ""}
|
||||
onChange={(e) => {
|
||||
const val = parseFloat(e.target.value);
|
||||
if (!isNaN(val)) onRateOverride(key, val);
|
||||
else if (e.target.value === "")
|
||||
onRateOverride(key, null);
|
||||
}}
|
||||
/>
|
||||
<span
|
||||
className="text-[10px] text-muted-foreground"
|
||||
title={rateUnitLabel(tt)}
|
||||
>
|
||||
{rateUnitLabel(tt)}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-xs text-muted-foreground">auto</span>
|
||||
)}
|
||||
</td>
|
||||
</tr>
|
||||
);
|
||||
})}
|
||||
{data.length === 0 && (
|
||||
<tr>
|
||||
<td
|
||||
colSpan={7}
|
||||
className="px-4 py-8 text-center text-muted-foreground"
|
||||
>
|
||||
No cost data yet
|
||||
</td>
|
||||
</tr>
|
||||
)}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { ProviderTable };
|
||||
@@ -0,0 +1,19 @@
|
||||
interface Props {
|
||||
label: string;
|
||||
value: string;
|
||||
subtitle?: string;
|
||||
}
|
||||
|
||||
function SummaryCard({ label, value, subtitle }: Props) {
|
||||
return (
|
||||
<div className="rounded-lg border p-4">
|
||||
<div className="text-sm text-muted-foreground">{label}</div>
|
||||
<div className="text-2xl font-bold">{value}</div>
|
||||
{subtitle && (
|
||||
<div className="mt-1 text-xs text-muted-foreground">{subtitle}</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { SummaryCard };
|
||||
@@ -0,0 +1,25 @@
|
||||
function TrackingBadge({
|
||||
trackingType,
|
||||
}: {
|
||||
trackingType: string | null | undefined;
|
||||
}) {
|
||||
const colors: Record<string, string> = {
|
||||
cost_usd: "bg-green-500/10 text-green-700",
|
||||
tokens: "bg-blue-500/10 text-blue-700",
|
||||
characters: "bg-purple-500/10 text-purple-700",
|
||||
sandbox_seconds: "bg-orange-500/10 text-orange-700",
|
||||
walltime_seconds: "bg-orange-500/10 text-orange-700",
|
||||
items: "bg-pink-500/10 text-pink-700",
|
||||
per_run: "bg-muted text-muted-foreground",
|
||||
};
|
||||
const label = trackingType || "per_run";
|
||||
return (
|
||||
<span
|
||||
className={`inline-block rounded px-1.5 py-0.5 text-[10px] font-medium ${colors[label] || colors.per_run}`}
|
||||
>
|
||||
{label}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export { TrackingBadge };
|
||||
@@ -0,0 +1,75 @@
|
||||
import type { PlatformCostDashboard } from "@/app/api/__generated__/models/platformCostDashboard";
|
||||
import { formatMicrodollars, formatTokens } from "../helpers";
|
||||
|
||||
interface Props {
|
||||
data: PlatformCostDashboard["by_user"];
|
||||
}
|
||||
|
||||
function UserTable({ data }: Props) {
|
||||
return (
|
||||
<div className="overflow-x-auto">
|
||||
<table className="w-full text-left text-sm">
|
||||
<thead className="border-b text-xs uppercase text-muted-foreground">
|
||||
<tr>
|
||||
<th scope="col" className="px-4 py-3">
|
||||
User
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Known Cost
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Requests
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Input Tokens
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Output Tokens
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{data.map((row, idx) => (
|
||||
<tr
|
||||
key={row.user_id ?? `unknown-${idx}`}
|
||||
className="border-b hover:bg-muted"
|
||||
>
|
||||
<td className="px-4 py-3">
|
||||
<div className="font-medium">{row.email || "Unknown"}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{row.user_id}
|
||||
</div>
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{row.total_cost_microdollars > 0
|
||||
? formatMicrodollars(row.total_cost_microdollars)
|
||||
: "-"}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{row.request_count.toLocaleString()}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{formatTokens(row.total_input_tokens)}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{formatTokens(row.total_output_tokens)}
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
{data.length === 0 && (
|
||||
<tr>
|
||||
<td
|
||||
colSpan={5}
|
||||
className="px-4 py-8 text-center text-muted-foreground"
|
||||
>
|
||||
No cost data yet
|
||||
</td>
|
||||
</tr>
|
||||
)}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { UserTable };
|
||||
@@ -0,0 +1,136 @@
|
||||
"use client";
|
||||
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { useState } from "react";
|
||||
import {
|
||||
useGetV2GetPlatformCostDashboard,
|
||||
useGetV2GetPlatformCostLogs,
|
||||
} from "@/app/api/__generated__/endpoints/admin/admin";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { estimateCostForRow, toLocalInput, toUtcIso } from "../helpers";
|
||||
|
||||
interface InitialSearchParams {
|
||||
start?: string;
|
||||
end?: string;
|
||||
provider?: string;
|
||||
user_id?: string;
|
||||
page?: string;
|
||||
tab?: string;
|
||||
}
|
||||
|
||||
export function usePlatformCostContent(searchParams: InitialSearchParams) {
|
||||
const router = useRouter();
|
||||
const urlParams = useSearchParams();
|
||||
|
||||
const tab = urlParams.get("tab") || searchParams.tab || "overview";
|
||||
const page = parseInt(urlParams.get("page") || searchParams.page || "1", 10);
|
||||
const startDate = urlParams.get("start") || searchParams.start || "";
|
||||
const endDate = urlParams.get("end") || searchParams.end || "";
|
||||
const providerFilter =
|
||||
urlParams.get("provider") || searchParams.provider || "";
|
||||
const userFilter = urlParams.get("user_id") || searchParams.user_id || "";
|
||||
|
||||
const [startInput, setStartInput] = useState(toLocalInput(startDate));
|
||||
const [endInput, setEndInput] = useState(toLocalInput(endDate));
|
||||
const [providerInput, setProviderInput] = useState(providerFilter);
|
||||
const [userInput, setUserInput] = useState(userFilter);
|
||||
const [rateOverrides, setRateOverrides] = useState<Record<string, number>>(
|
||||
{},
|
||||
);
|
||||
|
||||
// Pass ISO date strings through `as unknown as Date` so Orval's URL builder
|
||||
// forwards them as-is. Date.toString() produces a format FastAPI rejects;
|
||||
// strings pass through .toString() unchanged.
|
||||
const filterParams = {
|
||||
start: (startDate || undefined) as unknown as Date | undefined,
|
||||
end: (endDate || undefined) as unknown as Date | undefined,
|
||||
provider: providerFilter || undefined,
|
||||
user_id: userFilter || undefined,
|
||||
};
|
||||
|
||||
const {
|
||||
data: dashboard,
|
||||
isLoading: dashLoading,
|
||||
error: dashError,
|
||||
} = useGetV2GetPlatformCostDashboard(filterParams, {
|
||||
query: { select: okData },
|
||||
});
|
||||
|
||||
const {
|
||||
data: logsResponse,
|
||||
isLoading: logsLoading,
|
||||
error: logsError,
|
||||
} = useGetV2GetPlatformCostLogs(
|
||||
{ ...filterParams, page, page_size: 50 },
|
||||
{ query: { select: okData } },
|
||||
);
|
||||
|
||||
const loading = dashLoading || logsLoading;
|
||||
const error = dashError
|
||||
? dashError instanceof Error
|
||||
? dashError.message
|
||||
: "Failed to load dashboard"
|
||||
: logsError
|
||||
? logsError instanceof Error
|
||||
? logsError.message
|
||||
: "Failed to load logs"
|
||||
: null;
|
||||
|
||||
function updateUrl(overrides: Record<string, string>) {
|
||||
const params = new URLSearchParams(urlParams.toString());
|
||||
for (const [k, v] of Object.entries(overrides)) {
|
||||
if (v) params.set(k, v);
|
||||
else params.delete(k);
|
||||
}
|
||||
router.push(`/admin/platform-costs?${params.toString()}`);
|
||||
}
|
||||
|
||||
function handleFilter() {
|
||||
updateUrl({
|
||||
start: toUtcIso(startInput),
|
||||
end: toUtcIso(endInput),
|
||||
provider: providerInput,
|
||||
user_id: userInput,
|
||||
page: "1",
|
||||
});
|
||||
}
|
||||
|
||||
function handleRateOverride(key: string, val: number | null) {
|
||||
setRateOverrides((prev) => {
|
||||
if (val === null) {
|
||||
const { [key]: _, ...rest } = prev;
|
||||
return rest;
|
||||
}
|
||||
return { ...prev, [key]: val };
|
||||
});
|
||||
}
|
||||
|
||||
const totalEstimatedCost =
|
||||
dashboard?.by_provider.reduce((sum, row) => {
|
||||
const est = estimateCostForRow(row, rateOverrides);
|
||||
return sum + (est ?? 0);
|
||||
}, 0) ?? 0;
|
||||
|
||||
return {
|
||||
dashboard: dashboard ?? null,
|
||||
logs: logsResponse?.logs ?? [],
|
||||
pagination: logsResponse?.pagination ?? null,
|
||||
loading,
|
||||
error,
|
||||
totalEstimatedCost,
|
||||
tab,
|
||||
page,
|
||||
startInput,
|
||||
setStartInput,
|
||||
endInput,
|
||||
setEndInput,
|
||||
providerInput,
|
||||
setProviderInput,
|
||||
userInput,
|
||||
setUserInput,
|
||||
rateOverrides,
|
||||
handleRateOverride,
|
||||
updateUrl,
|
||||
handleFilter,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,204 @@
|
||||
import type { ProviderCostSummary } from "@/app/api/__generated__/models/providerCostSummary";
|
||||
|
||||
const MICRODOLLARS_PER_USD = 1_000_000;
|
||||
|
||||
// Per-request cost estimates (USD) for providers billed per API call.
|
||||
export const DEFAULT_COST_PER_RUN: Record<string, number> = {
|
||||
google_maps: 0.032, // $0.032/request - Google Maps Places API
|
||||
ideogram: 0.08, // $0.08/image - Ideogram standard generation
|
||||
nvidia: 0.0, // Free tier - NVIDIA NIM deepfake detection
|
||||
screenshotone: 0.01, // ~$0.01/screenshot - ScreenshotOne starter
|
||||
zerobounce: 0.008, // $0.008/validation - ZeroBounce
|
||||
mem0: 0.01, // ~$0.01/request - Mem0
|
||||
openweathermap: 0.0, // Free tier
|
||||
webshare_proxy: 0.0, // Flat subscription
|
||||
enrichlayer: 0.1, // ~$0.10/profile lookup
|
||||
jina: 0.0, // Free tier
|
||||
};
|
||||
|
||||
export const DEFAULT_COST_PER_1K_TOKENS: Record<string, number> = {
|
||||
openai: 0.005,
|
||||
anthropic: 0.008,
|
||||
groq: 0.0003,
|
||||
ollama: 0.0,
|
||||
aiml_api: 0.005,
|
||||
llama_api: 0.003,
|
||||
v0: 0.005,
|
||||
};
|
||||
|
||||
// Per-character rates (USD / 1K characters) for TTS providers.
|
||||
export const DEFAULT_COST_PER_1K_CHARS: Record<string, number> = {
|
||||
unreal_speech: 0.008, // ~$8/1M chars on Starter
|
||||
elevenlabs: 0.18, // ~$0.18/1K chars on Starter
|
||||
d_id: 0.04, // ~$0.04/1K chars estimated
|
||||
};
|
||||
|
||||
// Per-item rates (USD / item) for item-count billed APIs.
|
||||
export const DEFAULT_COST_PER_ITEM: Record<string, number> = {
|
||||
google_maps: 0.017, // avg of $0.032 nearby + ~$0.015 detail enrich
|
||||
apollo: 0.02, // ~$0.02/contact on low-volume tiers
|
||||
smartlead: 0.001, // ~$0.001/lead added
|
||||
};
|
||||
|
||||
// Per-second rates (USD / second) for duration-billed providers.
|
||||
export const DEFAULT_COST_PER_SECOND: Record<string, number> = {
|
||||
e2b: 0.000014, // $0.000014/sec (2-core sandbox)
|
||||
fal: 0.0005, // varies by model, conservative
|
||||
replicate: 0.001, // varies by hardware
|
||||
revid: 0.01, // per-second of video
|
||||
};
|
||||
|
||||
export function toDateOrUndefined(val?: string): Date | undefined {
|
||||
if (!val) return undefined;
|
||||
const d = new Date(val);
|
||||
return isNaN(d.getTime()) ? undefined : d;
|
||||
}
|
||||
|
||||
export function formatMicrodollars(microdollars: number) {
|
||||
return `$${(microdollars / MICRODOLLARS_PER_USD).toFixed(4)}`;
|
||||
}
|
||||
|
||||
export function formatTokens(tokens: number) {
|
||||
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`;
|
||||
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(1)}K`;
|
||||
return tokens.toString();
|
||||
}
|
||||
|
||||
export function formatDuration(seconds: number) {
|
||||
if (seconds >= 3600) return `${(seconds / 3600).toFixed(1)}h`;
|
||||
if (seconds >= 60) return `${(seconds / 60).toFixed(1)}m`;
|
||||
return `${seconds.toFixed(1)}s`;
|
||||
}
|
||||
|
||||
// Unit label for each tracking type — what the rate input represents.
|
||||
export function rateUnitLabel(trackingType: string | null | undefined): string {
|
||||
switch (trackingType) {
|
||||
case "tokens":
|
||||
return "$/1K tokens";
|
||||
case "characters":
|
||||
return "$/1K chars";
|
||||
case "items":
|
||||
return "$/item";
|
||||
case "sandbox_seconds":
|
||||
case "walltime_seconds":
|
||||
return "$/second";
|
||||
case "per_run":
|
||||
return "$/run";
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
// Default rate for a (provider, tracking_type) pair.
|
||||
export function defaultRateFor(
|
||||
provider: string,
|
||||
trackingType: string | null | undefined,
|
||||
): number | null {
|
||||
switch (trackingType) {
|
||||
case "tokens":
|
||||
return DEFAULT_COST_PER_1K_TOKENS[provider] ?? null;
|
||||
case "characters":
|
||||
return DEFAULT_COST_PER_1K_CHARS[provider] ?? null;
|
||||
case "items":
|
||||
return DEFAULT_COST_PER_ITEM[provider] ?? null;
|
||||
case "sandbox_seconds":
|
||||
case "walltime_seconds":
|
||||
return DEFAULT_COST_PER_SECOND[provider] ?? null;
|
||||
case "per_run":
|
||||
return DEFAULT_COST_PER_RUN[provider] ?? null;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// Overrides are keyed on `${provider}:${tracking_type}` since the same
|
||||
// provider can have multiple rows with different billing models.
|
||||
export function rateKey(
|
||||
provider: string,
|
||||
trackingType: string | null | undefined,
|
||||
): string {
|
||||
return `${provider}:${trackingType ?? "per_run"}`;
|
||||
}
|
||||
|
||||
export function estimateCostForRow(
|
||||
row: ProviderCostSummary,
|
||||
rateOverrides: Record<string, number>,
|
||||
) {
|
||||
const tt = row.tracking_type || "per_run";
|
||||
|
||||
// Providers that report USD directly: use known cost.
|
||||
if (tt === "cost_usd") return row.total_cost_microdollars;
|
||||
|
||||
// Prefer the real USD the provider reported if any, but only for token paths
|
||||
// where OpenRouter piggybacks on the tokens row via x-total-cost.
|
||||
if (tt === "tokens" && row.total_cost_microdollars > 0) {
|
||||
return row.total_cost_microdollars;
|
||||
}
|
||||
|
||||
const rate =
|
||||
rateOverrides[rateKey(row.provider, tt)] ??
|
||||
defaultRateFor(row.provider, tt);
|
||||
if (rate === null || rate === undefined) return null;
|
||||
|
||||
// Compute the amount for this tracking type, then multiply by rate.
|
||||
let amount: number;
|
||||
switch (tt) {
|
||||
case "tokens":
|
||||
// Rate is per-1K tokens.
|
||||
amount = (row.total_input_tokens + row.total_output_tokens) / 1000;
|
||||
break;
|
||||
case "characters":
|
||||
// Rate is per-1K chars. trackingAmount aggregates char counts.
|
||||
amount = (row.total_tracking_amount || 0) / 1000;
|
||||
break;
|
||||
case "items":
|
||||
amount = row.total_tracking_amount || 0;
|
||||
break;
|
||||
case "sandbox_seconds":
|
||||
case "walltime_seconds":
|
||||
amount = row.total_duration_seconds || 0;
|
||||
break;
|
||||
case "per_run":
|
||||
amount = row.request_count;
|
||||
break;
|
||||
default:
|
||||
return row.total_cost_microdollars > 0
|
||||
? row.total_cost_microdollars
|
||||
: null;
|
||||
}
|
||||
|
||||
return Math.round(rate * amount * MICRODOLLARS_PER_USD);
|
||||
}
|
||||
|
||||
export function trackingValue(row: ProviderCostSummary) {
|
||||
const tt = row.tracking_type || "per_run";
|
||||
if (tt === "cost_usd") return formatMicrodollars(row.total_cost_microdollars);
|
||||
if (tt === "tokens") {
|
||||
const tokens = row.total_input_tokens + row.total_output_tokens;
|
||||
return `${formatTokens(tokens)} tokens`;
|
||||
}
|
||||
if (tt === "sandbox_seconds" || tt === "walltime_seconds")
|
||||
return formatDuration(row.total_duration_seconds || 0);
|
||||
if (tt === "characters")
|
||||
return `${formatTokens(Math.round(row.total_tracking_amount || 0))} chars`;
|
||||
if (tt === "items")
|
||||
return `${Math.round(row.total_tracking_amount || 0).toLocaleString()} items`;
|
||||
return `${row.request_count.toLocaleString()} runs`;
|
||||
}
|
||||
|
||||
// URL holds UTC ISO; datetime-local inputs need local "YYYY-MM-DDTHH:mm".
|
||||
export function toLocalInput(iso: string) {
|
||||
if (!iso) return "";
|
||||
const d = new Date(iso);
|
||||
if (isNaN(d.getTime())) return "";
|
||||
const pad = (n: number) => String(n).padStart(2, "0");
|
||||
return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())}T${pad(d.getHours())}:${pad(d.getMinutes())}`;
|
||||
}
|
||||
|
||||
// datetime-local emits naive local time; convert to UTC ISO so the
|
||||
// backend filter window matches what the admin sees in their browser.
|
||||
export function toUtcIso(local: string) {
|
||||
if (!local) return "";
|
||||
const d = new Date(local);
|
||||
return isNaN(d.getTime()) ? "" : d.toISOString();
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
import { Suspense } from "react";
|
||||
import { PlatformCostContent } from "./components/PlatformCostContent";
|
||||
|
||||
type SearchParams = {
|
||||
start?: string;
|
||||
end?: string;
|
||||
provider?: string;
|
||||
user_id?: string;
|
||||
page?: string;
|
||||
tab?: string;
|
||||
};
|
||||
|
||||
function PlatformCostDashboard({
|
||||
searchParams,
|
||||
}: {
|
||||
searchParams: SearchParams;
|
||||
}) {
|
||||
return (
|
||||
<div className="mx-auto p-6">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold">Platform Costs</h1>
|
||||
<p className="text-muted-foreground">
|
||||
Track real API costs incurred by system credentials across providers
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Suspense
|
||||
key={JSON.stringify(searchParams)}
|
||||
fallback={
|
||||
<div className="py-10 text-center">Loading cost data...</div>
|
||||
}
|
||||
>
|
||||
<PlatformCostContent searchParams={searchParams} />
|
||||
</Suspense>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default async function PlatformCostDashboardPage({
|
||||
searchParams,
|
||||
}: {
|
||||
searchParams: Promise<SearchParams>;
|
||||
}) {
|
||||
const withAdminAccess = await withRoleAccess(["admin"]);
|
||||
const ProtectedDashboard = await withAdminAccess(PlatformCostDashboard);
|
||||
return <ProtectedDashboard searchParams={await searchParams} />;
|
||||
}
|
||||
@@ -7,6 +7,179 @@
|
||||
"version": "0.1"
|
||||
},
|
||||
"paths": {
|
||||
"/api/admin/platform-costs/dashboard": {
|
||||
"get": {
|
||||
"tags": ["v2", "admin", "platform-cost", "admin"],
|
||||
"summary": "Get Platform Cost Dashboard",
|
||||
"operationId": "getV2Get platform cost dashboard",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "start",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "format": "date-time" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Start"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "end",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "format": "date-time" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "End"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "provider",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Provider"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "user_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "User Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/PlatformCostDashboard"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/admin/platform-costs/logs": {
|
||||
"get": {
|
||||
"tags": ["v2", "admin", "platform-cost", "admin"],
|
||||
"summary": "Get Platform Cost Logs",
|
||||
"operationId": "getV2Get platform cost logs",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "start",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "format": "date-time" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Start"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "end",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "format": "date-time" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "End"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "provider",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Provider"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "user_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "User Id"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "page",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"default": 1,
|
||||
"title": "Page"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "page_size",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"maximum": 200,
|
||||
"minimum": 1,
|
||||
"default": 50,
|
||||
"title": "Page Size"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/PlatformCostLogsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/analytics/log_raw_analytics": {
|
||||
"post": {
|
||||
"tags": ["analytics"],
|
||||
@@ -8733,6 +8906,61 @@
|
||||
],
|
||||
"title": "ContentType"
|
||||
},
|
||||
"CostLogRow": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"title": "Created At"
|
||||
},
|
||||
"user_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "User Id"
|
||||
},
|
||||
"email": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Email"
|
||||
},
|
||||
"graph_exec_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Graph Exec Id"
|
||||
},
|
||||
"node_exec_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Node Exec Id"
|
||||
},
|
||||
"block_name": { "type": "string", "title": "Block Name" },
|
||||
"provider": { "type": "string", "title": "Provider" },
|
||||
"tracking_type": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Tracking Type"
|
||||
},
|
||||
"cost_microdollars": {
|
||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||
"title": "Cost Microdollars"
|
||||
},
|
||||
"input_tokens": {
|
||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||
"title": "Input Tokens"
|
||||
},
|
||||
"output_tokens": {
|
||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||
"title": "Output Tokens"
|
||||
},
|
||||
"duration": {
|
||||
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||
"title": "Duration"
|
||||
},
|
||||
"model": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Model"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "created_at", "block_name", "provider"],
|
||||
"title": "CostLogRow"
|
||||
},
|
||||
"CountResponse": {
|
||||
"properties": {
|
||||
"all_blocks": { "type": "integer", "title": "All Blocks" },
|
||||
@@ -11664,6 +11892,48 @@
|
||||
"title": "PendingHumanReviewModel",
|
||||
"description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n id: Unique identifier for the review record\n user_id: ID of the user who must perform the review\n node_exec_id: ID of the node execution that created this review\n node_id: ID of the node definition (for grouping reviews from same node)\n graph_exec_id: ID of the graph execution containing the node\n graph_id: ID of the graph template being executed\n graph_version: Version number of the graph template\n payload: The actual data payload awaiting review\n instructions: Instructions or message for the reviewer\n editable: Whether the reviewer can edit the data\n status: Current review status (WAITING, APPROVED, or REJECTED)\n review_message: Optional message from the reviewer\n created_at: Timestamp when review was created\n updated_at: Timestamp when review was last modified\n reviewed_at: Timestamp when review was completed (if applicable)"
|
||||
},
|
||||
"PlatformCostDashboard": {
|
||||
"properties": {
|
||||
"by_provider": {
|
||||
"items": { "$ref": "#/components/schemas/ProviderCostSummary" },
|
||||
"type": "array",
|
||||
"title": "By Provider"
|
||||
},
|
||||
"by_user": {
|
||||
"items": { "$ref": "#/components/schemas/UserCostSummary" },
|
||||
"type": "array",
|
||||
"title": "By User"
|
||||
},
|
||||
"total_cost_microdollars": {
|
||||
"type": "integer",
|
||||
"title": "Total Cost Microdollars"
|
||||
},
|
||||
"total_requests": { "type": "integer", "title": "Total Requests" },
|
||||
"total_users": { "type": "integer", "title": "Total Users" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"by_provider",
|
||||
"by_user",
|
||||
"total_cost_microdollars",
|
||||
"total_requests",
|
||||
"total_users"
|
||||
],
|
||||
"title": "PlatformCostDashboard"
|
||||
},
|
||||
"PlatformCostLogsResponse": {
|
||||
"properties": {
|
||||
"logs": {
|
||||
"items": { "$ref": "#/components/schemas/CostLogRow" },
|
||||
"type": "array",
|
||||
"title": "Logs"
|
||||
},
|
||||
"pagination": { "$ref": "#/components/schemas/Pagination" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["logs", "pagination"],
|
||||
"title": "PlatformCostLogsResponse"
|
||||
},
|
||||
"PostmarkBounceEnum": {
|
||||
"type": "integer",
|
||||
"enum": [
|
||||
@@ -12058,6 +12328,47 @@
|
||||
"title": "ProviderConstants",
|
||||
"description": "Model that exposes all provider names as a constant in the OpenAPI schema.\nThis is designed to be converted by Orval into a TypeScript constant."
|
||||
},
|
||||
"ProviderCostSummary": {
|
||||
"properties": {
|
||||
"provider": { "type": "string", "title": "Provider" },
|
||||
"tracking_type": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Tracking Type"
|
||||
},
|
||||
"total_cost_microdollars": {
|
||||
"type": "integer",
|
||||
"title": "Total Cost Microdollars"
|
||||
},
|
||||
"total_input_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Input Tokens"
|
||||
},
|
||||
"total_output_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Output Tokens"
|
||||
},
|
||||
"total_duration_seconds": {
|
||||
"type": "number",
|
||||
"title": "Total Duration Seconds",
|
||||
"default": 0.0
|
||||
},
|
||||
"total_tracking_amount": {
|
||||
"type": "number",
|
||||
"title": "Total Tracking Amount",
|
||||
"default": 0.0
|
||||
},
|
||||
"request_count": { "type": "integer", "title": "Request Count" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"provider",
|
||||
"total_cost_microdollars",
|
||||
"total_input_tokens",
|
||||
"total_output_tokens",
|
||||
"request_count"
|
||||
],
|
||||
"title": "ProviderCostSummary"
|
||||
},
|
||||
"ProviderEnumResponse": {
|
||||
"properties": {
|
||||
"provider": {
|
||||
@@ -14938,6 +15249,39 @@
|
||||
"title": "UsageWindow",
|
||||
"description": "Usage within a single time window."
|
||||
},
|
||||
"UserCostSummary": {
|
||||
"properties": {
|
||||
"user_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "User Id"
|
||||
},
|
||||
"email": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Email"
|
||||
},
|
||||
"total_cost_microdollars": {
|
||||
"type": "integer",
|
||||
"title": "Total Cost Microdollars"
|
||||
},
|
||||
"total_input_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Input Tokens"
|
||||
},
|
||||
"total_output_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Output Tokens"
|
||||
},
|
||||
"request_count": { "type": "integer", "title": "Request Count" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"total_cost_microdollars",
|
||||
"total_input_tokens",
|
||||
"total_output_tokens",
|
||||
"request_count"
|
||||
],
|
||||
"title": "UserCostSummary"
|
||||
},
|
||||
"UserHistoryResponse": {
|
||||
"properties": {
|
||||
"history": {
|
||||
|
||||
BIN
test-screenshots/PR-12696/01-after-login.png
Normal file
|
After Width: | Height: | Size: 78 KiB |
BIN
test-screenshots/PR-12696/02-admin-platform-costs.png
Normal file
|
After Width: | Height: | Size: 88 KiB |
BIN
test-screenshots/PR-12696/03-by-provider-table.png
Normal file
|
After Width: | Height: | Size: 65 KiB |
BIN
test-screenshots/PR-12696/04-by-user-tab.png
Normal file
|
After Width: | Height: | Size: 86 KiB |
BIN
test-screenshots/PR-12696/05-by-user-rows.png
Normal file
|
After Width: | Height: | Size: 46 KiB |
BIN
test-screenshots/PR-12696/06-raw-logs-tab.png
Normal file
|
After Width: | Height: | Size: 124 KiB |
BIN
test-screenshots/PR-12696/07-provider-filter.png
Normal file
|
After Width: | Height: | Size: 88 KiB |
BIN
test-screenshots/PR-12696/08-retest-dashboard.png
Normal file
|
After Width: | Height: | Size: 88 KiB |