fix(backend): fix TestGetPlatformCostDashboard mocks to match 3-query implementation

get_platform_cost_dashboard runs 3 concurrent queries (by_provider,
by_user, COUNT DISTINCT userId) but the unit tests only provided 2
side_effect values, causing StopAsyncIteration on the third call.
Updated all three test cases to supply a third mock return value and
corrected await_count assertion from 2 to 3.
This commit is contained in:
Zamil Majdy
2026-04-07 23:15:42 +07:00
parent 91af007c18
commit 0e310c788a

View File

@@ -171,9 +171,8 @@ class TestGetPlatformCostDashboard:
"request_count": 3,
}
]
# Dashboard now runs 2 queries (by_provider + by_user) — total_users is
# derived from len(by_user_rows) instead of a separate COUNT query.
mock_query = AsyncMock(side_effect=[provider_rows, user_rows])
# Dashboard runs 3 queries: by_provider, by_user, COUNT(DISTINCT userId).
mock_query = AsyncMock(side_effect=[provider_rows, user_rows, [{"cnt": 1}]])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 5000
@@ -188,7 +187,7 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_query = AsyncMock(side_effect=[[], []])
mock_query = AsyncMock(side_effect=[[], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 0
@@ -200,12 +199,12 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_passes_filters_to_queries(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_query = AsyncMock(side_effect=[[], []])
mock_query = AsyncMock(side_effect=[[], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
assert mock_query.await_count == 2
assert mock_query.await_count == 3
first_call_sql = mock_query.call_args_list[0][0][0]
assert "createdAt" in first_call_sql