feat(platform): add rate-limit tiering system for CoPilot (#12581)

## Summary
- Adds a four-tier subscription system (FREE/PRO/BUSINESS/ENTERPRISE)
for CoPilot with configurable multipliers (1x/5x/20x/60x) applied on top
of the base LaunchDarkly/config limits
- Stores user tier in the database (`User.subscriptionTier` column as a
Prisma enum, defaults to PRO for beta testing) with admin API endpoints
for tier management
- Includes tier info in usage status responses and OTEL/Langfuse trace
metadata for observability

## Tier Structure
| Tier | Multiplier | Daily Tokens | Weekly Tokens | Notes |
|------|-----------|-------------|--------------|-------|
| FREE | 1x | 2.5M | 12.5M | Base tier (unused during beta) |
| PRO | 5x | 12.5M | 62.5M | Default on sign-up (beta) |
| BUSINESS | 20x | 50M | 250M | Manual upgrade for select users |
| ENTERPRISE | 60x | 150M | 750M | Highest tier, custom |

## Changes
- **`rate_limit.py`**: `SubscriptionTier` enum
(FREE/PRO/BUSINESS/ENTERPRISE), `TIER_MULTIPLIERS`, `get_user_tier()`,
`set_user_tier()`, update `get_global_rate_limits()` to apply tier
multiplier and return 3-tuple, add `tier` field to `CoPilotUsageStatus`
- **`rate_limit_admin_routes.py`**: Add `GET/POST
/admin/rate_limit/tier` endpoints, include `tier` in
`UserRateLimitResponse`
- **`routes.py`** (chat): Include tier in `/usage` endpoint response
- **`sdk/service.py`**: Send `subscription_tier` in OTEL/Langfuse trace
metadata
- **`schema.prisma`**: Add `SubscriptionTier` enum and
`subscriptionTier` column to `User` model (default: PRO)
- **`config.py`**: Update docs to reflect tier system
- **Migration**: `20260326200000_add_rate_limit_tier` — creates enum,
migrates STANDARD→PRO, adds BUSINESS, sets default to PRO

## Test plan
- [x] 72 unit tests all passing (43 rate_limit + 11 admin routes + 18
chat routes)
- [ ] Verify FREE tier users get base limits (2.5M daily, 12.5M weekly)
- [ ] Verify PRO tier users get 5x limits (12.5M daily, 62.5M weekly)
- [ ] Verify BUSINESS tier users get 20x limits (50M daily, 250M weekly)
- [ ] Verify ENTERPRISE tier users get 60x limits (150M daily, 750M
weekly)
- [ ] Verify admin can read and set user tiers via API
- [ ] Verify tier info appears in Langfuse traces
- [ ] Verify migration applies cleanly (creates enum, migrates STANDARD
users to PRO, adds BUSINESS, default PRO)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
This commit is contained in:
Zamil Majdy
2026-04-03 15:36:01 +02:00
committed by GitHub
parent 08bb05141c
commit 2b0e8a5a9f
30 changed files with 3166 additions and 121 deletions

View File

@@ -9,11 +9,14 @@ from pydantic import BaseModel
from backend.copilot.config import ChatConfig
from backend.copilot.rate_limit import (
SubscriptionTier,
get_global_rate_limits,
get_usage_status,
get_user_tier,
reset_user_usage,
set_user_tier,
)
from backend.data.user import get_user_by_email, get_user_email_by_id
from backend.data.user import get_user_by_email, get_user_email_by_id, search_users
logger = logging.getLogger(__name__)
@@ -33,6 +36,17 @@ class UserRateLimitResponse(BaseModel):
weekly_token_limit: int
daily_tokens_used: int
weekly_tokens_used: int
tier: SubscriptionTier
class UserTierResponse(BaseModel):
user_id: str
tier: SubscriptionTier
class SetUserTierRequest(BaseModel):
user_id: str
tier: SubscriptionTier
async def _resolve_user_id(
@@ -86,10 +100,10 @@ async def get_user_rate_limit(
logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id)
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
resolved_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit)
usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier)
return UserRateLimitResponse(
user_id=resolved_id,
@@ -98,6 +112,7 @@ async def get_user_rate_limit(
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@@ -125,10 +140,10 @@ async def reset_user_rate_limit(
logger.exception("Failed to reset user usage")
raise HTTPException(status_code=500, detail="Failed to reset usage") from e
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
usage = await get_usage_status(user_id, daily_limit, weekly_limit)
usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier)
try:
resolved_email = await get_user_email_by_id(user_id)
@@ -143,4 +158,102 @@ async def reset_user_rate_limit(
weekly_token_limit=weekly_limit,
daily_tokens_used=usage.daily.used,
weekly_tokens_used=usage.weekly.used,
tier=tier,
)
@router.get(
"/rate_limit/tier",
response_model=UserTierResponse,
summary="Get User Rate Limit Tier",
)
async def get_user_rate_limit_tier(
user_id: str,
admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
"""Get a user's current rate-limit tier. Admin-only.
Returns 404 if the user does not exist in the database.
"""
logger.info("Admin %s checking tier for user %s", admin_user_id, user_id)
resolved_email = await get_user_email_by_id(user_id)
if resolved_email is None:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
tier = await get_user_tier(user_id)
return UserTierResponse(user_id=user_id, tier=tier)
@router.post(
"/rate_limit/tier",
response_model=UserTierResponse,
summary="Set User Rate Limit Tier",
)
async def set_user_rate_limit_tier(
request: SetUserTierRequest,
admin_user_id: str = Security(get_user_id),
) -> UserTierResponse:
"""Set a user's rate-limit tier. Admin-only.
Returns 404 if the user does not exist in the database.
"""
try:
resolved_email = await get_user_email_by_id(request.user_id)
except Exception:
logger.warning(
"Failed to resolve email for user %s",
request.user_id,
exc_info=True,
)
resolved_email = None
if resolved_email is None:
raise HTTPException(status_code=404, detail=f"User {request.user_id} not found")
old_tier = await get_user_tier(request.user_id)
logger.info(
"Admin %s changing tier for user %s (%s): %s -> %s",
admin_user_id,
request.user_id,
resolved_email,
old_tier.value,
request.tier.value,
)
try:
await set_user_tier(request.user_id, request.tier)
except Exception as e:
logger.exception("Failed to set user tier")
raise HTTPException(status_code=500, detail="Failed to set tier") from e
return UserTierResponse(user_id=request.user_id, tier=request.tier)
class UserSearchResult(BaseModel):
user_id: str
user_email: Optional[str] = None
@router.get(
"/rate_limit/search_users",
response_model=list[UserSearchResult],
summary="Search Users by Name or Email",
)
async def admin_search_users(
query: str,
limit: int = 20,
admin_user_id: str = Security(get_user_id),
) -> list[UserSearchResult]:
"""Search users by partial email or name. Admin-only.
Queries the User table directly — returns results even for users
without credit transaction history.
"""
if len(query.strip()) < 3:
raise HTTPException(
status_code=400,
detail="Search query must be at least 3 characters.",
)
logger.info("Admin %s searching users with query=%r", admin_user_id, query)
results = await search_users(query, limit=max(1, min(limit, 50)))
return [UserSearchResult(user_id=uid, user_email=email) for uid, email in results]

View File

@@ -9,7 +9,7 @@ import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from .rate_limit_admin_routes import router as rate_limit_admin_router
@@ -57,7 +57,7 @@ def _patch_rate_limit_deps(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -89,6 +89,7 @@ def test_get_rate_limit(
assert data["weekly_token_limit"] == 12_500_000
assert data["daily_tokens_used"] == 500_000
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
configured_snapshot.assert_match(
json.dumps(data, indent=2, sort_keys=True) + "\n",
@@ -162,6 +163,7 @@ def test_reset_user_usage_daily_only(
assert data["daily_tokens_used"] == 0
# Weekly is untouched
assert data["weekly_tokens_used"] == 3_000_000
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False)
@@ -192,6 +194,7 @@ def test_reset_user_usage_daily_and_weekly(
data = response.json()
assert data["daily_tokens_used"] == 0
assert data["weekly_tokens_used"] == 0
assert data["tier"] == "FREE"
mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True)
@@ -228,7 +231,7 @@ def test_get_rate_limit_email_lookup_failure(
mocker.patch(
f"{_MOCK_MODULE}.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(2_500_000, 12_500_000),
return_value=(2_500_000, 12_500_000, SubscriptionTier.FREE),
)
mocker.patch(
f"{_MOCK_MODULE}.get_usage_status",
@@ -261,3 +264,303 @@ def test_admin_endpoints_require_admin_role(mock_jwt_user) -> None:
json={"user_id": "test"},
)
assert response.status_code == 403
# ---------------------------------------------------------------------------
# Tier management endpoints
# ---------------------------------------------------------------------------
def test_get_user_tier(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test getting a user's rate-limit tier."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
)
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "PRO"
def test_get_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that getting tier for a non-existent user returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=None,
)
response = client.get("/admin/rate_limit/tier", params={"user_id": target_user_id})
assert response.status_code == 404
def test_set_user_tier(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test setting a user's rate-limit tier (upgrade)."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "ENTERPRISE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "ENTERPRISE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.ENTERPRISE)
def test_set_user_tier_downgrade(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test downgrading a user's tier from PRO to FREE."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
)
mock_set = mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "FREE"},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == target_user_id
assert data["tier"] == "FREE"
mock_set.assert_awaited_once_with(target_user_id, SubscriptionTier.FREE)
def test_set_user_tier_invalid_tier(
target_user_id: str,
) -> None:
"""Test that setting an invalid tier returns 422."""
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "invalid"},
)
assert response.status_code == 422
def test_set_user_tier_invalid_tier_uppercase(
target_user_id: str,
) -> None:
"""Test that setting an unrecognised uppercase tier (e.g. 'INVALID') returns 422.
Regression: ensures Pydantic enum validation rejects values that are not
members of SubscriptionTier, even when they look like valid enum names.
"""
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "INVALID"},
)
assert response.status_code == 422
body = response.json()
assert "detail" in body
def test_set_user_tier_email_lookup_failure_returns_404(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that email lookup failure returns 404 (user unverifiable)."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
side_effect=Exception("DB connection failed"),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 404
def test_set_user_tier_user_not_found(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that setting tier for a non-existent user returns 404."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=None,
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 404
def test_set_user_tier_db_failure(
mocker: pytest_mock.MockerFixture,
target_user_id: str,
) -> None:
"""Test that DB failure on set tier returns 500."""
mocker.patch(
f"{_MOCK_MODULE}.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TARGET_EMAIL,
)
mocker.patch(
f"{_MOCK_MODULE}.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
)
mocker.patch(
f"{_MOCK_MODULE}.set_user_tier",
new_callable=AsyncMock,
side_effect=Exception("DB connection refused"),
)
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": target_user_id, "tier": "PRO"},
)
assert response.status_code == 500
def test_tier_endpoints_require_admin_role(mock_jwt_user) -> None:
"""Test that tier admin endpoints require admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit/tier", params={"user_id": "test"})
assert response.status_code == 403
response = client.post(
"/admin/rate_limit/tier",
json={"user_id": "test", "tier": "PRO"},
)
assert response.status_code == 403
# ─── search_users endpoint ──────────────────────────────────────────
def test_search_users_returns_matching_users(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Partial search should return all matching users from the User table."""
mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[
("user-1", "zamil.majdy@gmail.com"),
("user-2", "zamil.majdy@agpt.co"),
],
)
response = client.get("/admin/rate_limit/search_users", params={"query": "zamil"})
assert response.status_code == 200
results = response.json()
assert len(results) == 2
assert results[0]["user_email"] == "zamil.majdy@gmail.com"
assert results[1]["user_email"] == "zamil.majdy@agpt.co"
def test_search_users_empty_results(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Search with no matches returns empty list."""
mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[],
)
response = client.get(
"/admin/rate_limit/search_users", params={"query": "nonexistent"}
)
assert response.status_code == 200
assert response.json() == []
def test_search_users_short_query_rejected(
admin_user_id: str,
) -> None:
"""Query shorter than 3 characters should return 400."""
response = client.get("/admin/rate_limit/search_users", params={"query": "ab"})
assert response.status_code == 400
def test_search_users_negative_limit_clamped(
mocker: pytest_mock.MockerFixture,
admin_user_id: str,
) -> None:
"""Negative limit should be clamped to 1, not passed through."""
mock_search = mocker.patch(
_MOCK_MODULE + ".search_users",
new_callable=AsyncMock,
return_value=[],
)
response = client.get(
"/admin/rate_limit/search_users", params={"query": "test", "limit": -1}
)
assert response.status_code == 200
mock_search.assert_awaited_once_with("test", limit=1)
def test_search_users_requires_admin_role(mock_jwt_user) -> None:
"""Test that the search_users endpoint requires admin role."""
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
response = client.get("/admin/rate_limit/search_users", params={"query": "test"})
assert response.status_code == 403

View File

@@ -456,8 +456,9 @@ async def get_copilot_usage(
Returns current token usage vs limits for daily and weekly windows.
Global defaults sourced from LaunchDarkly (falling back to config).
Includes the user's rate-limit tier.
"""
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
return await get_usage_status(
@@ -465,6 +466,7 @@ async def get_copilot_usage(
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
@@ -516,7 +518,7 @@ async def reset_copilot_usage(
detail="Rate limit reset is not available (credit system is disabled).",
)
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
@@ -556,6 +558,7 @@ async def reset_copilot_usage(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
tier=tier,
)
if daily_limit > 0 and usage_status.daily.used < daily_limit:
raise HTTPException(
@@ -631,6 +634,7 @@ async def reset_copilot_usage(
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
rate_limit_reset_cost=config.rate_limit_reset_cost,
tier=tier,
)
return RateLimitResetResponse(
@@ -741,7 +745,7 @@ async def stream_chat_post(
# Global defaults sourced from LaunchDarkly, falling back to config.
if user_id:
try:
daily_limit, weekly_limit = await get_global_rate_limits(
daily_limit, weekly_limit, _ = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(

View File

@@ -9,6 +9,7 @@ import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
from backend.copilot.rate_limit import SubscriptionTier
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
@@ -331,14 +332,28 @@ def _mock_usage(
*,
daily_used: int = 500,
weekly_used: int = 2000,
daily_limit: int = 10000,
weekly_limit: int = 50000,
tier: "SubscriptionTier" = SubscriptionTier.FREE,
) -> AsyncMock:
"""Mock get_usage_status to return a predictable CoPilotUsageStatus."""
"""Mock get_usage_status and get_global_rate_limits for usage endpoint tests.
Mocks both ``get_global_rate_limits`` (returns the given limits + tier) and
``get_usage_status`` so that tests exercise the endpoint without hitting
LaunchDarkly or Prisma.
"""
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
mocker.patch(
"backend.api.features.chat.routes.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(daily_limit, weekly_limit, tier),
)
resets_at = datetime.now(UTC) + timedelta(days=1)
status = CoPilotUsageStatus(
daily=UsageWindow(used=daily_used, limit=10000, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=50000, resets_at=resets_at),
daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at),
weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at),
)
return mocker.patch(
"backend.api.features.chat.routes.get_usage_status",
@@ -369,6 +384,7 @@ def test_usage_returns_daily_and_weekly(
daily_token_limit=10000,
weekly_token_limit=50000,
rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost,
tier=SubscriptionTier.FREE,
)
@@ -376,11 +392,9 @@ def test_usage_uses_config_limits(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
"""The endpoint forwards daily_token_limit and weekly_token_limit from config."""
mock_get = _mock_usage(mocker)
"""The endpoint forwards resolved limits from get_global_rate_limits to get_usage_status."""
mock_get = _mock_usage(mocker, daily_limit=99999, weekly_limit=77777)
mocker.patch.object(chat_routes.config, "daily_token_limit", 99999)
mocker.patch.object(chat_routes.config, "weekly_token_limit", 77777)
mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 500)
response = client.get("/usage")
@@ -391,6 +405,7 @@ def test_usage_uses_config_limits(
daily_token_limit=99999,
weekly_token_limit=77777,
rate_limit_reset_cost=500,
tier=SubscriptionTier.FREE,
)

View File

@@ -189,6 +189,7 @@ async def test_create_store_submission(mocker):
notifyOnAgentApproved=True,
notifyOnAgentRejected=True,
timezone="Europe/Delft",
subscriptionTier=prisma.enums.SubscriptionTier.FREE, # type: ignore[reportCallIssue,reportAttributeAccessIssue]
)
mock_agent = prisma.models.AgentGraph(
id="agent-id",

View File

@@ -81,11 +81,11 @@ class ChatConfig(BaseSettings):
# allows ~70-100 turns/day.
# Checked at the HTTP layer (routes.py) before each turn.
#
# TODO: These are deploy-time constants applied identically to every user.
# If per-user or per-plan limits are needed (e.g., free tier vs paid), these
# must move to the database (e.g., a UserPlan table) and get_usage_status /
# check_rate_limit would look up each user's specific limits instead of
# reading config.daily_token_limit / config.weekly_token_limit.
# These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS,
# ENTERPRISE) multiply these by their tier multiplier (see
# rate_limit.TIER_MULTIPLIERS). User tier is stored in the
# User.subscriptionTier DB column and resolved inside
# get_global_rate_limits().
daily_token_limit: int = Field(
default=2_500_000,
description="Max tokens per day, resets at midnight UTC (0 = unlimited)",

View File

@@ -9,11 +9,14 @@ UTC). Fails open when Redis is unavailable to avoid blocking users.
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from enum import Enum
from prisma.models import User as PrismaUser
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
from backend.data.redis_client import get_redis_async
from backend.util.cache import cached
logger = logging.getLogger(__name__)
@@ -21,6 +24,40 @@ logger = logging.getLogger(__name__)
_USAGE_KEY_PREFIX = "copilot:usage"
# ---------------------------------------------------------------------------
# Subscription tier definitions
# ---------------------------------------------------------------------------
class SubscriptionTier(str, Enum):
"""Subscription tiers with increasing token allowances.
Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``.
Once ``prisma generate`` is run, this can be replaced with::
from prisma.enums import SubscriptionTier
"""
FREE = "FREE"
PRO = "PRO"
BUSINESS = "BUSINESS"
ENTERPRISE = "ENTERPRISE"
# Multiplier applied to the base limits (from LD / config) for each tier.
# Intentionally int (not float): keeps limits as whole token counts and avoids
# floating-point rounding. If fractional multipliers are ever needed, change
# the type and round the result in get_global_rate_limits().
TIER_MULTIPLIERS: dict[SubscriptionTier, int] = {
SubscriptionTier.FREE: 1,
SubscriptionTier.PRO: 5,
SubscriptionTier.BUSINESS: 20,
SubscriptionTier.ENTERPRISE: 60,
}
DEFAULT_TIER = SubscriptionTier.FREE
class UsageWindow(BaseModel):
"""Usage within a single time window."""
@@ -36,6 +73,7 @@ class CoPilotUsageStatus(BaseModel):
daily: UsageWindow
weekly: UsageWindow
tier: SubscriptionTier = DEFAULT_TIER
reset_cost: int = Field(
default=0,
description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.",
@@ -66,6 +104,7 @@ async def get_usage_status(
daily_token_limit: int,
weekly_token_limit: int,
rate_limit_reset_cost: int = 0,
tier: SubscriptionTier = DEFAULT_TIER,
) -> CoPilotUsageStatus:
"""Get current usage status for a user.
@@ -74,6 +113,7 @@ async def get_usage_status(
daily_token_limit: Max tokens per day (0 = unlimited).
weekly_token_limit: Max tokens per week (0 = unlimited).
rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled).
tier: The user's rate-limit tier (included in the response).
Returns:
CoPilotUsageStatus with current usage and limits.
@@ -103,6 +143,7 @@ async def get_usage_status(
limit=weekly_token_limit,
resets_at=_weekly_reset_time(now=now),
),
tier=tier,
reset_cost=rate_limit_reset_cost,
)
@@ -343,20 +384,100 @@ async def record_token_usage(
)
class _UserNotFoundError(Exception):
"""Raised when a user record is missing or has no subscription tier.
Used internally by ``_fetch_user_tier`` to signal a cache-miss condition:
by raising instead of returning ``DEFAULT_TIER``, we prevent the ``@cached``
decorator from storing the fallback value. This avoids a race condition
where a non-existent user's DEFAULT_TIER is cached, then the user is
created with a higher tier but receives the stale cached FREE tier for
up to 5 minutes.
"""
@cached(maxsize=1000, ttl_seconds=300, shared_cache=True)
async def _fetch_user_tier(user_id: str) -> SubscriptionTier:
"""Fetch the user's rate-limit tier from the database (cached via Redis).
Uses ``shared_cache=True`` so that tier changes propagate across all pods
immediately when the cache entry is invalidated (via ``cache_delete``).
Only successful DB lookups of existing users with a valid tier are cached.
Raises ``_UserNotFoundError`` when the user is missing or has no tier, so
the ``@cached`` decorator does **not** store a fallback value. This
prevents a race condition where a non-existent user's ``DEFAULT_TIER`` is
cached and then persists after the user is created with a higher tier.
"""
user = await PrismaUser.prisma().find_unique(where={"id": user_id})
if user and user.subscriptionTier: # type: ignore[reportAttributeAccessIssue]
return SubscriptionTier(user.subscriptionTier) # type: ignore[reportAttributeAccessIssue]
raise _UserNotFoundError(user_id)
async def get_user_tier(user_id: str) -> SubscriptionTier:
"""Look up the user's rate-limit tier from the database.
Successful results are cached for 5 minutes (via ``_fetch_user_tier``)
to avoid a DB round-trip on every rate-limit check.
Falls back to ``DEFAULT_TIER`` **without caching** when the DB is
unreachable or returns an unrecognised value, so the next call retries
the query instead of serving a stale fallback for up to 5 minutes.
"""
try:
return await _fetch_user_tier(user_id)
except Exception as exc:
logger.warning(
"Failed to resolve rate-limit tier for user %s, defaulting to %s: %s",
user_id[:8],
DEFAULT_TIER.value,
exc,
)
return DEFAULT_TIER
# Expose cache management on the public function so callers (including tests)
# never need to reach into the private ``_fetch_user_tier``.
get_user_tier.cache_clear = _fetch_user_tier.cache_clear # type: ignore[attr-defined]
get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-defined]
async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None:
"""Persist the user's rate-limit tier to the database.
Also invalidates the ``get_user_tier`` cache for this user so that
subsequent rate-limit checks immediately see the new tier.
Raises:
prisma.errors.RecordNotFoundError: If the user does not exist.
"""
await PrismaUser.prisma().update(
where={"id": user_id},
data={"subscriptionTier": tier.value},
)
# Invalidate cached tier so rate-limit checks pick up the change immediately.
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
async def get_global_rate_limits(
user_id: str,
config_daily: int,
config_weekly: int,
) -> tuple[int, int]:
) -> tuple[int, int, SubscriptionTier]:
"""Resolve global rate limits from LaunchDarkly, falling back to config.
The base limits (from LD or config) are multiplied by the user's
tier multiplier so that higher tiers receive proportionally larger
allowances.
Args:
user_id: User ID for LD flag evaluation context.
config_daily: Fallback daily limit from ChatConfig.
config_weekly: Fallback weekly limit from ChatConfig.
Returns:
(daily_token_limit, weekly_token_limit) tuple.
(daily_token_limit, weekly_token_limit, tier) 3-tuple.
"""
# Lazy import to avoid circular dependency:
# rate_limit -> feature_flag -> settings -> ... -> rate_limit
@@ -378,7 +499,15 @@ async def get_global_rate_limits(
except (TypeError, ValueError):
logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw)
weekly = config_weekly
return daily, weekly
# Apply tier multiplier
tier = await get_user_tier(user_id)
multiplier = TIER_MULTIPLIERS.get(tier, 1)
if multiplier != 1:
daily = daily * multiplier
weekly = weekly * multiplier
return daily, weekly, tier
async def reset_user_usage(user_id: str, *, reset_weekly: bool = False) -> None:

View File

@@ -7,12 +7,19 @@ import pytest
from redis.exceptions import RedisError
from .rate_limit import (
DEFAULT_TIER,
TIER_MULTIPLIERS,
CoPilotUsageStatus,
RateLimitExceeded,
SubscriptionTier,
UsageWindow,
check_rate_limit,
get_global_rate_limits,
get_usage_status,
get_user_tier,
record_token_usage,
reset_daily_usage,
set_user_tier,
)
_USER = "test-user-rl"
@@ -335,6 +342,524 @@ class TestRecordTokenUsage:
await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50)
# ---------------------------------------------------------------------------
# SubscriptionTier and tier multipliers
# ---------------------------------------------------------------------------
class TestSubscriptionTier:
def test_tier_values(self):
assert SubscriptionTier.FREE.value == "FREE"
assert SubscriptionTier.PRO.value == "PRO"
assert SubscriptionTier.BUSINESS.value == "BUSINESS"
assert SubscriptionTier.ENTERPRISE.value == "ENTERPRISE"
def test_tier_multipliers(self):
assert TIER_MULTIPLIERS[SubscriptionTier.FREE] == 1
assert TIER_MULTIPLIERS[SubscriptionTier.PRO] == 5
assert TIER_MULTIPLIERS[SubscriptionTier.BUSINESS] == 20
assert TIER_MULTIPLIERS[SubscriptionTier.ENTERPRISE] == 60
def test_default_tier_is_free(self):
assert DEFAULT_TIER == SubscriptionTier.FREE
def test_usage_status_includes_tier(self):
now = datetime.now(UTC)
status = CoPilotUsageStatus(
daily=UsageWindow(used=0, limit=100, resets_at=now + timedelta(hours=1)),
weekly=UsageWindow(used=0, limit=500, resets_at=now + timedelta(days=1)),
)
assert status.tier == SubscriptionTier.FREE
def test_usage_status_with_custom_tier(self):
now = datetime.now(UTC)
status = CoPilotUsageStatus(
daily=UsageWindow(used=0, limit=100, resets_at=now + timedelta(hours=1)),
weekly=UsageWindow(used=0, limit=500, resets_at=now + timedelta(days=1)),
tier=SubscriptionTier.PRO,
)
assert status.tier == SubscriptionTier.PRO
# ---------------------------------------------------------------------------
# get_user_tier
# ---------------------------------------------------------------------------
class TestGetUserTier:
@pytest.fixture(autouse=True)
def _clear_tier_cache(self):
"""Clear the get_user_tier cache before each test."""
get_user_tier.cache_clear() # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_returns_tier_from_db(self):
"""Should return the tier stored in the user record."""
mock_user = MagicMock()
mock_user.subscriptionTier = "PRO"
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
tier = await get_user_tier(_USER)
assert tier == SubscriptionTier.PRO
@pytest.mark.asyncio
async def test_returns_default_when_user_not_found(self):
"""Should return DEFAULT_TIER when user is not in the DB."""
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(return_value=None)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
tier = await get_user_tier(_USER)
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
async def test_returns_default_when_tier_is_none(self):
"""Should return DEFAULT_TIER when subscriptionTier is None."""
mock_user = MagicMock()
mock_user.subscriptionTier = None
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
tier = await get_user_tier(_USER)
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
async def test_returns_default_on_db_error(self):
"""Should fall back to DEFAULT_TIER when DB raises."""
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
tier = await get_user_tier(_USER)
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
async def test_db_error_is_not_cached(self):
"""Transient DB errors should NOT cache the default tier.
Regression test: a transient DB failure previously cached DEFAULT_TIER
for 5 minutes, incorrectly downgrading higher-tier users until expiry.
"""
failing_prisma = AsyncMock()
failing_prisma.find_unique = AsyncMock(side_effect=Exception("DB down"))
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=failing_prisma,
):
tier1 = await get_user_tier(_USER)
assert tier1 == DEFAULT_TIER
# Now DB recovers and returns PRO
mock_user = MagicMock()
mock_user.subscriptionTier = "PRO"
ok_prisma = AsyncMock()
ok_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=ok_prisma,
):
tier2 = await get_user_tier(_USER)
# Should get PRO now — the error result was not cached
assert tier2 == SubscriptionTier.PRO
@pytest.mark.asyncio
async def test_returns_default_on_invalid_tier_value(self):
"""Should fall back to DEFAULT_TIER when stored value is invalid."""
mock_user = MagicMock()
mock_user.subscriptionTier = "invalid-tier"
mock_prisma = AsyncMock()
mock_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
tier = await get_user_tier(_USER)
assert tier == DEFAULT_TIER
@pytest.mark.asyncio
async def test_user_not_found_is_not_cached(self):
"""Non-existent user should NOT cache DEFAULT_TIER.
Regression test: when ``get_user_tier`` is called before a user record
exists, the DEFAULT_TIER fallback must not be cached. Otherwise, a
newly created user with a higher tier (e.g. PRO) would receive the
stale cached FREE tier for up to 5 minutes.
"""
# First call: user does not exist yet
missing_prisma = AsyncMock()
missing_prisma.find_unique = AsyncMock(return_value=None)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=missing_prisma,
):
tier1 = await get_user_tier(_USER)
assert tier1 == DEFAULT_TIER
# Second call: user now exists with PRO tier
mock_user = MagicMock()
mock_user.subscriptionTier = "PRO"
ok_prisma = AsyncMock()
ok_prisma.find_unique = AsyncMock(return_value=mock_user)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=ok_prisma,
):
tier2 = await get_user_tier(_USER)
# Should get PRO — the not-found result was not cached
assert tier2 == SubscriptionTier.PRO
# ---------------------------------------------------------------------------
# set_user_tier
# ---------------------------------------------------------------------------
class TestSetUserTier:
@pytest.fixture(autouse=True)
def _clear_tier_cache(self):
"""Clear the get_user_tier cache before each test."""
get_user_tier.cache_clear() # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_updates_db_and_invalidates_cache(self):
"""set_user_tier should persist to DB and invalidate the tier cache."""
mock_prisma = AsyncMock()
mock_prisma.update = AsyncMock(return_value=None)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
await set_user_tier(_USER, SubscriptionTier.PRO)
mock_prisma.update.assert_awaited_once_with(
where={"id": _USER},
data={"subscriptionTier": "PRO"},
)
@pytest.mark.asyncio
async def test_record_not_found_propagates(self):
"""RecordNotFoundError from Prisma should propagate to callers."""
import prisma.errors
mock_prisma = AsyncMock()
mock_prisma.update = AsyncMock(
side_effect=prisma.errors.RecordNotFoundError(
{"error": "Record not found"}
),
)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma,
):
with pytest.raises(prisma.errors.RecordNotFoundError):
await set_user_tier(_USER, SubscriptionTier.ENTERPRISE)
@pytest.mark.asyncio
async def test_cache_invalidated_after_set(self):
"""After set_user_tier, get_user_tier should query DB again (not cache)."""
# First, populate the cache with BUSINESS
mock_user_biz = MagicMock()
mock_user_biz.subscriptionTier = "BUSINESS"
mock_prisma_get = AsyncMock()
mock_prisma_get.find_unique = AsyncMock(return_value=mock_user_biz)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma_get,
):
tier_before = await get_user_tier(_USER)
assert tier_before == SubscriptionTier.BUSINESS
# Now set tier to ENTERPRISE (this should invalidate the cache)
mock_prisma_set = AsyncMock()
mock_prisma_set.update = AsyncMock(return_value=None)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma_set,
):
await set_user_tier(_USER, SubscriptionTier.ENTERPRISE)
# Now get_user_tier should hit DB again (cache was invalidated)
mock_user_ent = MagicMock()
mock_user_ent.subscriptionTier = "ENTERPRISE"
mock_prisma_get2 = AsyncMock()
mock_prisma_get2.find_unique = AsyncMock(return_value=mock_user_ent)
with patch(
"backend.copilot.rate_limit.PrismaUser.prisma",
return_value=mock_prisma_get2,
):
tier_after = await get_user_tier(_USER)
assert tier_after == SubscriptionTier.ENTERPRISE
# ---------------------------------------------------------------------------
# get_global_rate_limits with tiers
# ---------------------------------------------------------------------------
class TestGetGlobalRateLimitsWithTiers:
@staticmethod
def _ld_side_effect(daily: int, weekly: int):
"""Return an async side_effect that dispatches by flag_key."""
async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
if "daily" in flag_key.lower():
return daily
if "weekly" in flag_key.lower():
return weekly
return default
return _side_effect
@pytest.mark.asyncio
async def test_free_tier_no_multiplier(self):
"""Free tier should not change limits."""
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, 2_500_000, 12_500_000
)
assert daily == 2_500_000
assert weekly == 12_500_000
assert tier == SubscriptionTier.FREE
@pytest.mark.asyncio
async def test_pro_tier_5x_multiplier(self):
"""Pro tier should multiply limits by 5."""
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, 2_500_000, 12_500_000
)
assert daily == 12_500_000
assert weekly == 62_500_000
assert tier == SubscriptionTier.PRO
@pytest.mark.asyncio
async def test_business_tier_20x_multiplier(self):
"""Business tier should multiply limits by 20."""
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.BUSINESS,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, 2_500_000, 12_500_000
)
assert daily == 50_000_000
assert weekly == 250_000_000
assert tier == SubscriptionTier.BUSINESS
@pytest.mark.asyncio
async def test_enterprise_tier_60x_multiplier(self):
"""Enterprise tier should multiply limits by 60."""
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.ENTERPRISE,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(2_500_000, 12_500_000),
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, 2_500_000, 12_500_000
)
assert daily == 150_000_000
assert weekly == 750_000_000
assert tier == SubscriptionTier.ENTERPRISE
# ---------------------------------------------------------------------------
# End-to-end: tier limits are respected by check_rate_limit
# ---------------------------------------------------------------------------
class TestTierLimitsRespected:
"""Verify that tier-adjusted limits from get_global_rate_limits flow
correctly into check_rate_limit, so higher tiers allow more usage and
lower tiers are blocked when they would exceed their allocation."""
_BASE_DAILY = 2_500_000
_BASE_WEEKLY = 12_500_000
@staticmethod
def _ld_side_effect(daily: int, weekly: int):
async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
if "daily" in flag_key.lower():
return daily
if "weekly" in flag_key.lower():
return weekly
return default
return _side_effect
@pytest.mark.asyncio
async def test_pro_user_allowed_above_free_limit(self):
"""A PRO user with usage above the FREE limit should be allowed."""
# Usage: 3M tokens (above FREE limit of 2.5M, below PRO limit of 12.5M)
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["3000000", "3000000"])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
# PRO: 5x multiplier
assert daily == 12_500_000
assert tier == SubscriptionTier.PRO
# Should NOT raise — 3M < 12.5M
await check_rate_limit(
_USER, daily_token_limit=daily, weekly_token_limit=weekly
)
@pytest.mark.asyncio
async def test_free_user_blocked_at_free_limit(self):
"""A FREE user at or above the base limit should be blocked."""
# Usage: 2.5M tokens (at FREE limit of 2.5M)
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["2500000", "2500000"])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
# FREE: 1x multiplier
assert daily == 2_500_000
assert tier == SubscriptionTier.FREE
# Should raise — 2.5M >= 2.5M
with pytest.raises(RateLimitExceeded):
await check_rate_limit(
_USER, daily_token_limit=daily, weekly_token_limit=weekly
)
@pytest.mark.asyncio
async def test_enterprise_user_has_highest_headroom(self):
"""An ENTERPRISE user should have 60x the base limit."""
# Usage: 100M tokens (huge, but below ENTERPRISE daily of 150M)
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=["100000000", "100000000"])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.ENTERPRISE,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
assert daily == 150_000_000
assert tier == SubscriptionTier.ENTERPRISE
# Should NOT raise — 100M < 150M
await check_rate_limit(
_USER, daily_token_limit=daily, weekly_token_limit=weekly
)
# ---------------------------------------------------------------------------
# reset_daily_usage
# ---------------------------------------------------------------------------
@@ -421,3 +946,267 @@ class TestResetDailyUsage:
result = await reset_daily_usage(_USER, daily_token_limit=10000)
assert result is False
# ---------------------------------------------------------------------------
# Tier-limit enforcement (integration-style)
# ---------------------------------------------------------------------------
class TestTierLimitsEnforced:
"""Verify that tier-multiplied limits are actually respected by
``check_rate_limit`` — i.e. that usage within the tier allowance passes
and usage at/above the tier allowance is rejected."""
_BASE_DAILY = 1_000_000
_BASE_WEEKLY = 5_000_000
@staticmethod
def _ld_side_effect(daily: int, weekly: int):
"""Mock LD flag lookup returning the given raw limits."""
async def _side_effect(flag_key: str, _uid: str, default: int) -> int:
if "daily" in flag_key.lower():
return daily
if "weekly" in flag_key.lower():
return weekly
return default
return _side_effect
@pytest.mark.asyncio
async def test_pro_within_limit_allowed(self):
"""Usage under PRO daily limit should not raise."""
pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
mock_redis = AsyncMock()
# Simulate usage just under the PRO daily limit
mock_redis.get = AsyncMock(side_effect=[str(pro_daily - 1), "0"])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
assert tier == SubscriptionTier.PRO
assert daily == pro_daily
# Should not raise — usage is under the limit
await check_rate_limit(_USER, daily, weekly)
@pytest.mark.asyncio
async def test_pro_at_limit_rejected(self):
"""Usage at exactly the PRO daily limit should raise."""
pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[str(pro_daily), "0"])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(_USER, daily, weekly)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_business_higher_limit_allows_pro_overflow(self):
"""Usage exceeding PRO but under BUSINESS should pass for BUSINESS."""
pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
biz_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.BUSINESS]
# Usage between PRO and BUSINESS limits
usage = pro_daily + 1_000_000
assert usage < biz_daily, "test sanity: usage must be under BUSINESS limit"
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.BUSINESS,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
assert tier == SubscriptionTier.BUSINESS
assert daily == biz_daily
# Should not raise — BUSINESS tier can handle this
await check_rate_limit(_USER, daily, weekly)
@pytest.mark.asyncio
async def test_weekly_limit_enforced_for_tier(self):
"""Weekly limit should also be tier-multiplied and enforced."""
pro_weekly = self._BASE_WEEKLY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
mock_redis = AsyncMock()
# Daily usage fine, weekly at limit
mock_redis.get = AsyncMock(side_effect=["0", str(pro_weekly)])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.PRO,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(_USER, daily, weekly)
assert exc_info.value.window == "weekly"
@pytest.mark.asyncio
async def test_free_tier_base_limit_enforced(self):
"""Free tier (1x multiplier) should enforce the base limit exactly."""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[str(self._BASE_DAILY), "0"])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
assert daily == self._BASE_DAILY # 1x multiplier
with pytest.raises(RateLimitExceeded):
await check_rate_limit(_USER, daily, weekly)
@pytest.mark.asyncio
async def test_free_tier_cannot_bypass_pro_limit(self):
"""A FREE-tier user whose usage is within PRO limits but over FREE
limits must still be rejected.
Negative test: ensures the tier multiplier is applied *before* the
rate-limit check, so a lower-tier user cannot 'bypass' limits that
would be acceptable for a higher tier.
"""
free_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.FREE]
pro_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.PRO]
# Usage above FREE limit but below PRO limit
usage = free_daily + 500_000
assert usage < pro_daily, "test sanity: usage must be under PRO limit"
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.FREE,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
assert tier == SubscriptionTier.FREE
assert daily == free_daily # 1x, not 5x
with pytest.raises(RateLimitExceeded) as exc_info:
await check_rate_limit(_USER, daily, weekly)
assert exc_info.value.window == "daily"
@pytest.mark.asyncio
async def test_tier_change_updates_effective_limits(self):
"""After upgrading from FREE to BUSINESS, the effective limits must
increase accordingly.
Verifies that the tier multiplier is correctly applied after a tier
change, and that usage that was over the FREE limit is within the new
BUSINESS limit.
"""
free_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.FREE]
biz_daily = self._BASE_DAILY * TIER_MULTIPLIERS[SubscriptionTier.BUSINESS]
# Usage above FREE limit but below BUSINESS limit
usage = free_daily + 500_000
assert usage < biz_daily, "test sanity: usage must be under BUSINESS limit"
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(side_effect=[str(usage), "0"])
# Simulate the user having been upgraded to BUSINESS
with (
patch(
"backend.copilot.rate_limit.get_user_tier",
new_callable=AsyncMock,
return_value=SubscriptionTier.BUSINESS,
),
patch(
"backend.util.feature_flag.get_feature_flag_value",
side_effect=self._ld_side_effect(self._BASE_DAILY, self._BASE_WEEKLY),
),
patch(
"backend.copilot.rate_limit.get_redis_async",
return_value=mock_redis,
),
):
daily, weekly, tier = await get_global_rate_limits(
_USER, self._BASE_DAILY, self._BASE_WEEKLY
)
assert tier == SubscriptionTier.BUSINESS
assert daily == biz_daily # 20x
# Should NOT raise — usage is within the BUSINESS tier allowance
await check_rate_limit(_USER, daily, weekly)

View File

@@ -9,7 +9,7 @@ import pytest
from fastapi import HTTPException
from backend.api.features.chat.routes import reset_copilot_usage
from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow
from backend.copilot.rate_limit import CoPilotUsageStatus, SubscriptionTier, UsageWindow
from backend.util.exceptions import InsufficientBalanceError
@@ -53,6 +53,18 @@ def _mock_settings(enable_credit: bool = True):
return mock
def _mock_rate_limits(
daily: int = 2_500_000,
weekly: int = 12_500_000,
tier: SubscriptionTier = SubscriptionTier.PRO,
):
"""Mock get_global_rate_limits to return fixed limits (no tier multiplier)."""
return patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(daily, weekly, tier)),
)
@pytest.mark.asyncio
class TestResetCopilotUsage:
async def test_feature_disabled_returns_400(self):
@@ -70,10 +82,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(0, 12_500_000)),
),
_mock_rate_limits(daily=0),
):
with pytest.raises(HTTPException) as exc_info:
await reset_copilot_usage(user_id="user-1")
@@ -87,10 +96,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -120,10 +126,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -153,10 +156,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -187,10 +187,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=3)),
):
with pytest.raises(HTTPException) as exc_info:
@@ -228,10 +225,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()) as mock_release,
@@ -252,10 +246,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", _make_config()),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=None)),
):
with pytest.raises(HTTPException) as exc_info:
@@ -273,10 +264,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),
@@ -307,10 +295,7 @@ class TestResetCopilotUsage:
with (
patch(f"{_MODULE}.config", cfg),
patch(f"{_MODULE}.settings", _mock_settings()),
patch(
f"{_MODULE}.get_global_rate_limits",
AsyncMock(return_value=(2_500_000, 12_500_000)),
),
_mock_rate_limits(),
patch(f"{_MODULE}.get_daily_reset_count", AsyncMock(return_value=0)),
patch(f"{_MODULE}.acquire_reset_lock", AsyncMock(return_value=True)),
patch(f"{_MODULE}.release_reset_lock", AsyncMock()),

View File

@@ -33,6 +33,7 @@ from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.copilot.permissions import apply_tool_permissions
from backend.copilot.rate_limit import get_user_tier
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
@@ -1946,15 +1947,20 @@ async def stream_chat_completion_sdk(
# langsmith tracing integration attaches them to every span. This
# is what Langfuse (or any OTEL backend) maps to its native
# user/session fields.
_user_tier = await get_user_tier(user_id) if user_id else None
_otel_metadata: dict[str, str] = {
"resume": str(use_resume),
"conversation_turn": str(turn),
}
if _user_tier:
_otel_metadata["subscription_tier"] = _user_tier.value
_otel_ctx = propagate_attributes(
user_id=user_id,
session_id=session_id,
trace_name="copilot-sdk",
tags=["sdk"],
metadata={
"resume": str(use_resume),
"conversation_turn": str(turn),
},
metadata=_otel_metadata,
)
_otel_ctx.__enter__()

View File

@@ -82,6 +82,28 @@ async def get_user_by_email(email: str) -> Optional[User]:
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
async def search_users(query: str, limit: int = 20) -> list[tuple[str, str | None]]:
"""Search users by partial email or name.
Returns a list of ``(user_id, email)`` tuples, up to *limit* results.
Searches the User table directly — no dependency on credit history.
"""
query = query.strip()
if not query or len(query) < 3:
return []
users = await prisma.user.find_many(
where={
"OR": [
{"email": {"contains": query, "mode": "insensitive"}},
{"name": {"contains": query, "mode": "insensitive"}},
],
},
take=limit,
order={"email": "asc"},
)
return [(u.id, u.email) for u in users]
async def update_user_email(user_id: str, email: str):
try:
# Get old email first for cache invalidation

View File

@@ -121,10 +121,16 @@ def _make_hashable_key(
def _make_redis_key(key: tuple[Any, ...], func_name: str) -> str:
"""Convert a hashable key tuple to a Redis key string."""
# Ensure key is already hashable
hashable_key = key if isinstance(key, tuple) else (key,)
return f"cache:{func_name}:{hash(hashable_key)}"
"""Convert a hashable key tuple to a Redis key string.
Uses SHA-256 instead of Python's built-in ``hash()`` because ``hash()``
is randomised per-process (``PYTHONHASHSEED``). In a multi-pod
deployment every pod must derive the **same** Redis key for the same
arguments, otherwise cache lookups and invalidations silently miss.
"""
key_bytes = repr(key).encode()
digest = hashlib.sha256(key_bytes).hexdigest()
return f"cache:{func_name}:{digest}"
@runtime_checkable

View File

@@ -0,0 +1,5 @@
-- CreateEnum
CREATE TYPE "SubscriptionTier" AS ENUM ('FREE', 'PRO', 'BUSINESS', 'ENTERPRISE');
-- AlterTable: add subscriptionTier column with default PRO (beta testing)
ALTER TABLE "User" ADD COLUMN "subscriptionTier" "SubscriptionTier" NOT NULL DEFAULT 'PRO';

View File

@@ -40,6 +40,15 @@ model User {
timezone String @default("not-set")
// CoPilot subscription tier — controls rate-limit multipliers.
// Multipliers applied in get_global_rate_limits(): FREE=1x, PRO=5x, BUSINESS=20x, ENTERPRISE=60x.
// NOTE: @default(PRO) is intentional for the beta period — all existing and new
// users receive PRO-level (5x) rate limits by default. The Python-level constant
// DEFAULT_TIER=FREE (in copilot/rate_limit.py) acts as a code-level fallback when
// the DB value is NULL or unrecognised. At GA, a migration will flip the column
// default to FREE and batch-update users to their billing-derived tiers.
subscriptionTier SubscriptionTier @default(PRO)
// Relations
AgentGraphs AgentGraph[]
@@ -73,6 +82,13 @@ model User {
OAuthRefreshTokens OAuthRefreshToken[]
}
enum SubscriptionTier {
FREE
PRO
BUSINESS
ENTERPRISE
}
enum OnboardingStep {
// Introductory onboarding (Library)
WELCOME

View File

@@ -1,6 +1,7 @@
{
"daily_token_limit": 2500000,
"daily_tokens_used": 500000,
"tier": "FREE",
"user_email": "target@example.com",
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
"weekly_token_limit": 12500000,

View File

@@ -1,6 +1,7 @@
{
"daily_token_limit": 2500000,
"daily_tokens_used": 0,
"tier": "FREE",
"user_email": "target@example.com",
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
"weekly_token_limit": 12500000,

View File

@@ -1,6 +1,7 @@
{
"daily_token_limit": 2500000,
"daily_tokens_used": 0,
"tier": "FREE",
"user_email": "target@example.com",
"user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c",
"weekly_token_limit": 12500000,

View File

@@ -3,18 +3,48 @@
import { useState } from "react";
import { Button } from "@/components/atoms/Button/Button";
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { UsageBar } from "../../components/UsageBar";
const TIERS = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"] as const;
type Tier = (typeof TIERS)[number];
const TIER_MULTIPLIERS: Record<Tier, string> = {
FREE: "1x base limits",
PRO: "5x base limits",
BUSINESS: "20x base limits",
ENTERPRISE: "60x base limits",
};
const TIER_COLORS: Record<Tier, string> = {
FREE: "bg-gray-100 text-gray-700",
PRO: "bg-blue-100 text-blue-700",
BUSINESS: "bg-purple-100 text-purple-700",
ENTERPRISE: "bg-amber-100 text-amber-700",
};
interface Props {
data: UserRateLimitResponse;
onReset: (resetWeekly: boolean) => Promise<void>;
onTierChange?: (newTier: string) => Promise<void>;
/** Override the outer container classes (default: bordered card). */
className?: string;
}
export function RateLimitDisplay({ data, onReset, className }: Props) {
export function RateLimitDisplay({
data,
onReset,
onTierChange,
className,
}: Props) {
const [isResetting, setIsResetting] = useState(false);
const [resetWeekly, setResetWeekly] = useState(false);
const [isChangingTier, setIsChangingTier] = useState(false);
const { toast } = useToast();
const currentTier = TIERS.includes(data.tier as Tier)
? (data.tier as Tier)
: "FREE";
async function handleReset() {
const msg = resetWeekly
@@ -30,19 +60,76 @@ export function RateLimitDisplay({ data, onReset, className }: Props) {
}
}
async function handleTierChange(newTier: string) {
if (newTier === currentTier || !onTierChange) return;
if (
!window.confirm(
`Change tier from ${currentTier} to ${newTier}? This will change the user's rate limits.`,
)
)
return;
setIsChangingTier(true);
try {
await onTierChange(newTier);
toast({
title: "Tier updated",
description: `Changed to ${newTier} (${TIER_MULTIPLIERS[newTier as Tier]}).`,
});
} catch {
toast({
title: "Error",
description: "Failed to update tier.",
variant: "destructive",
});
} finally {
setIsChangingTier(false);
}
}
const nothingToReset = resetWeekly
? data.daily_tokens_used === 0 && data.weekly_tokens_used === 0
: data.daily_tokens_used === 0;
return (
<div className={className ?? "rounded-md border bg-white p-6"}>
<h2 className="mb-1 text-lg font-semibold">
Rate Limits for {data.user_email ?? data.user_id}
</h2>
{data.user_email && (
<p className="mb-4 text-xs text-gray-500">User ID: {data.user_id}</p>
)}
{!data.user_email && <div className="mb-4" />}
<div className="mb-4 flex items-start justify-between">
<div>
<h2 className="mb-1 text-lg font-semibold">
Rate Limits for {data.user_email ?? data.user_id}
</h2>
{data.user_email && (
<p className="text-xs text-gray-500">User ID: {data.user_id}</p>
)}
</div>
<span
className={`rounded-full px-3 py-1 text-xs font-medium ${TIER_COLORS[currentTier] ?? "bg-gray-100 text-gray-700"}`}
>
{currentTier}
</span>
</div>
<div className="mb-4 flex items-center gap-3">
<label className="text-sm font-medium text-gray-700">
Subscription Tier
</label>
<select
aria-label="Subscription tier"
value={currentTier}
onChange={(e) => handleTierChange(e.target.value)}
className="rounded-md border bg-white px-3 py-1.5 text-sm"
disabled={isChangingTier || !onTierChange}
>
{TIERS.map((tier) => (
<option key={tier} value={tier}>
{tier} {TIER_MULTIPLIERS[tier]}
</option>
))}
</select>
{isChangingTier && (
<span className="text-xs text-gray-500">Updating...</span>
)}
</div>
<div className="grid grid-cols-2 gap-6">
<div className="space-y-2">

View File

@@ -14,6 +14,7 @@ export function RateLimitManager() {
handleSearch,
handleSelectUser,
handleReset,
handleTierChange,
} = useRateLimitManager();
return (
@@ -74,7 +75,11 @@ export function RateLimitManager() {
)}
{rateLimitData && (
<RateLimitDisplay data={rateLimitData} onReset={handleReset} />
<RateLimitDisplay
data={rateLimitData}
onReset={handleReset}
onTierChange={handleTierChange}
/>
)}
</div>
);

View File

@@ -0,0 +1,281 @@
import {
render,
screen,
fireEvent,
waitFor,
cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { RateLimitDisplay } from "../RateLimitDisplay";
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: vi.fn() }),
}));
const mockConfirm = vi.fn();
beforeEach(() => {
mockConfirm.mockReset();
window.confirm = mockConfirm;
});
afterEach(() => {
cleanup();
});
function makeData(
overrides: Partial<UserRateLimitResponse> = {},
): UserRateLimitResponse {
return {
user_id: "user-abc-123",
user_email: "alice@example.com",
daily_token_limit: 10000,
weekly_token_limit: 50000,
daily_tokens_used: 2500,
weekly_tokens_used: 10000,
tier: "FREE",
...overrides,
};
}
describe("RateLimitDisplay", () => {
it("renders the user email heading", () => {
render(<RateLimitDisplay data={makeData()} onReset={vi.fn()} />);
expect(
screen.getByText(/Rate Limits for alice@example\.com/),
).toBeDefined();
});
it("renders user ID when email is present", () => {
render(<RateLimitDisplay data={makeData()} onReset={vi.fn()} />);
expect(screen.getByText(/user-abc-123/)).toBeDefined();
});
it("falls back to user_id in heading when email is absent", () => {
render(
<RateLimitDisplay
data={makeData({ user_email: undefined })}
onReset={vi.fn()}
/>,
);
expect(screen.getByText(/Rate Limits for user-abc-123/)).toBeDefined();
});
it("displays the current tier badge", () => {
render(
<RateLimitDisplay data={makeData({ tier: "PRO" })} onReset={vi.fn()} />,
);
const badge = screen.getByText("PRO");
expect(badge).toBeDefined();
expect(badge.className).toContain("bg-blue-100");
});
it("defaults unknown tier to FREE", () => {
render(
<RateLimitDisplay
data={makeData({ tier: "UNKNOWN" as UserRateLimitResponse["tier"] })}
onReset={vi.fn()}
/>,
);
const badge = screen.getByText("FREE");
expect(badge).toBeDefined();
});
it("renders tier dropdown with all tiers", () => {
render(<RateLimitDisplay data={makeData()} onReset={vi.fn()} />);
const select = screen.getByLabelText("Subscription tier");
expect(select).toBeDefined();
expect(select.querySelectorAll("option").length).toBe(4);
});
it("disables tier dropdown when onTierChange is not provided", () => {
render(<RateLimitDisplay data={makeData()} onReset={vi.fn()} />);
const select = screen.getByLabelText(
"Subscription tier",
) as HTMLSelectElement;
expect(select.disabled).toBe(true);
});
it("enables tier dropdown when onTierChange is provided", () => {
render(
<RateLimitDisplay
data={makeData()}
onReset={vi.fn()}
onTierChange={vi.fn()}
/>,
);
const select = screen.getByLabelText(
"Subscription tier",
) as HTMLSelectElement;
expect(select.disabled).toBe(false);
});
it("renders daily and weekly usage sections", () => {
render(<RateLimitDisplay data={makeData()} onReset={vi.fn()} />);
expect(screen.getByText("Daily Usage")).toBeDefined();
expect(screen.getByText("Weekly Usage")).toBeDefined();
});
it("renders reset scope dropdown and reset button", () => {
render(<RateLimitDisplay data={makeData()} onReset={vi.fn()} />);
expect(screen.getByLabelText("Reset scope")).toBeDefined();
expect(screen.getByText("Reset Usage")).toBeDefined();
});
it("disables reset button when nothing to reset", () => {
render(
<RateLimitDisplay
data={makeData({ daily_tokens_used: 0 })}
onReset={vi.fn()}
/>,
);
const button = screen.getByText("Reset Usage").closest("button")!;
expect(button.disabled).toBe(true);
});
it("enables reset button when there is usage to reset", () => {
render(
<RateLimitDisplay
data={makeData({ daily_tokens_used: 100 })}
onReset={vi.fn()}
/>,
);
const button = screen.getByText("Reset Usage").closest("button")!;
expect(button.disabled).toBe(false);
});
it("calls onReset when reset button is clicked and confirmed", async () => {
const onReset = vi.fn().mockResolvedValue(undefined);
mockConfirm.mockReturnValue(true);
render(<RateLimitDisplay data={makeData()} onReset={onReset} />);
fireEvent.click(screen.getByText("Reset Usage"));
await waitFor(() => {
expect(onReset).toHaveBeenCalledWith(false);
});
});
it("does not call onReset when confirm is cancelled", () => {
const onReset = vi.fn();
mockConfirm.mockReturnValue(false);
render(<RateLimitDisplay data={makeData()} onReset={onReset} />);
fireEvent.click(screen.getByText("Reset Usage"));
expect(onReset).not.toHaveBeenCalled();
});
it("passes resetWeekly=true when 'both' is selected", async () => {
const onReset = vi.fn().mockResolvedValue(undefined);
mockConfirm.mockReturnValue(true);
render(
<RateLimitDisplay
data={makeData({ weekly_tokens_used: 100 })}
onReset={onReset}
/>,
);
fireEvent.change(screen.getByLabelText("Reset scope"), {
target: { value: "both" },
});
fireEvent.click(screen.getByText("Reset Usage"));
await waitFor(() => {
expect(onReset).toHaveBeenCalledWith(true);
});
});
it("calls onTierChange when tier is changed and confirmed", async () => {
const onTierChange = vi.fn().mockResolvedValue(undefined);
mockConfirm.mockReturnValue(true);
render(
<RateLimitDisplay
data={makeData({ tier: "FREE" })}
onReset={vi.fn()}
onTierChange={onTierChange}
/>,
);
fireEvent.change(screen.getByLabelText("Subscription tier"), {
target: { value: "PRO" },
});
await waitFor(() => {
expect(onTierChange).toHaveBeenCalledWith("PRO");
});
});
it("does not call onTierChange when selecting the same tier", () => {
const onTierChange = vi.fn();
render(
<RateLimitDisplay
data={makeData({ tier: "FREE" })}
onReset={vi.fn()}
onTierChange={onTierChange}
/>,
);
fireEvent.change(screen.getByLabelText("Subscription tier"), {
target: { value: "FREE" },
});
expect(onTierChange).not.toHaveBeenCalled();
});
it("does not call onTierChange when confirm is cancelled", () => {
const onTierChange = vi.fn();
mockConfirm.mockReturnValue(false);
render(
<RateLimitDisplay
data={makeData({ tier: "FREE" })}
onReset={vi.fn()}
onTierChange={onTierChange}
/>,
);
fireEvent.change(screen.getByLabelText("Subscription tier"), {
target: { value: "PRO" },
});
expect(onTierChange).not.toHaveBeenCalled();
});
it("catches error when onTierChange rejects", async () => {
const onTierChange = vi.fn().mockRejectedValue(new Error("fail"));
mockConfirm.mockReturnValue(true);
render(
<RateLimitDisplay
data={makeData({ tier: "FREE" })}
onReset={vi.fn()}
onTierChange={onTierChange}
/>,
);
fireEvent.change(screen.getByLabelText("Subscription tier"), {
target: { value: "PRO" },
});
await waitFor(() => {
expect(onTierChange).toHaveBeenCalledWith("PRO");
});
});
it("applies custom className when provided", () => {
const { container } = render(
<RateLimitDisplay
data={makeData()}
onReset={vi.fn()}
className="custom-class"
/>,
);
expect(container.firstElementChild?.className).toBe("custom-class");
});
});

View File

@@ -0,0 +1,216 @@
import {
render,
screen,
fireEvent,
cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { RateLimitManager } from "../RateLimitManager";
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
const mockHandleSearch = vi.fn();
const mockHandleSelectUser = vi.fn();
const mockHandleReset = vi.fn();
const mockHandleTierChange = vi.fn();
vi.mock("../useRateLimitManager", () => ({
useRateLimitManager: () => mockHookReturn,
}));
vi.mock("../../../components/AdminUserSearch", () => ({
AdminUserSearch: ({
onSearch,
placeholder,
isLoading,
}: {
onSearch: (q: string) => void;
placeholder: string;
isLoading: boolean;
}) => (
<div data-testid="admin-user-search">
<input
data-testid="search-input"
placeholder={placeholder}
disabled={isLoading}
onKeyDown={(e) => {
if (e.key === "Enter") onSearch((e.target as HTMLInputElement).value);
}}
/>
</div>
),
}));
vi.mock("../RateLimitDisplay", () => ({
RateLimitDisplay: ({
data,
onReset,
onTierChange,
}: {
data: UserRateLimitResponse;
onReset: (rw: boolean) => void;
onTierChange: (t: string) => void;
}) => (
<div data-testid="rate-limit-display">
<span>{data.user_email ?? data.user_id}</span>
<button onClick={() => onReset(false)}>mock-reset</button>
<button onClick={() => onTierChange("PRO")}>mock-tier</button>
</div>
),
}));
let mockHookReturn = buildHookReturn();
function buildHookReturn(overrides: Record<string, unknown> = {}) {
return {
isSearching: false,
isLoadingRateLimit: false,
searchResults: [] as Array<{ user_id: string; user_email: string }>,
selectedUser: null as { user_id: string; user_email: string } | null,
rateLimitData: null as UserRateLimitResponse | null,
handleSearch: mockHandleSearch,
handleSelectUser: mockHandleSelectUser,
handleReset: mockHandleReset,
handleTierChange: mockHandleTierChange,
...overrides,
};
}
afterEach(() => {
cleanup();
mockHandleSearch.mockClear();
mockHandleSelectUser.mockClear();
mockHandleReset.mockClear();
mockHandleTierChange.mockClear();
mockHookReturn = buildHookReturn();
});
describe("RateLimitManager", () => {
it("renders the search section", () => {
render(<RateLimitManager />);
expect(screen.getByText("Search User")).toBeDefined();
expect(screen.getByTestId("admin-user-search")).toBeDefined();
});
it("renders description text for search", () => {
render(<RateLimitManager />);
expect(
screen.getByText(/Exact email or user ID does a direct lookup/),
).toBeDefined();
});
it("does not show user list when searchResults is empty", () => {
render(<RateLimitManager />);
expect(screen.queryByText(/Select a user/)).toBeNull();
});
it("shows user selection list when results exist and no user selected", () => {
mockHookReturn = buildHookReturn({
searchResults: [
{ user_id: "u1", user_email: "alice@example.com" },
{ user_id: "u2", user_email: "bob@example.com" },
],
});
render(<RateLimitManager />);
expect(screen.getByText("Select a user (2 results)")).toBeDefined();
expect(screen.getByText("alice@example.com")).toBeDefined();
expect(screen.getByText("bob@example.com")).toBeDefined();
});
it("shows singular 'result' text for single result", () => {
mockHookReturn = buildHookReturn({
searchResults: [{ user_id: "u1", user_email: "alice@example.com" }],
});
render(<RateLimitManager />);
expect(screen.getByText("Select a user (1 result)")).toBeDefined();
});
it("calls handleSelectUser when a user in the list is clicked", () => {
const users = [
{ user_id: "u1", user_email: "alice@example.com" },
{ user_id: "u2", user_email: "bob@example.com" },
];
mockHookReturn = buildHookReturn({ searchResults: users });
render(<RateLimitManager />);
fireEvent.click(screen.getByText("bob@example.com"));
expect(mockHandleSelectUser).toHaveBeenCalledWith(users[1]);
});
it("hides selection list when a user is selected", () => {
const users = [{ user_id: "u1", user_email: "alice@example.com" }];
mockHookReturn = buildHookReturn({
searchResults: users,
selectedUser: users[0],
});
render(<RateLimitManager />);
expect(screen.queryByText(/Select a user/)).toBeNull();
});
it("shows selected user indicator", () => {
const users = [{ user_id: "u1", user_email: "alice@example.com" }];
mockHookReturn = buildHookReturn({
searchResults: users,
selectedUser: users[0],
});
render(<RateLimitManager />);
expect(screen.getByText("Selected:")).toBeDefined();
});
it("shows loading message when isLoadingRateLimit is true", () => {
mockHookReturn = buildHookReturn({ isLoadingRateLimit: true });
render(<RateLimitManager />);
expect(screen.getByText("Loading rate limits...")).toBeDefined();
});
it("renders RateLimitDisplay when rateLimitData is present", () => {
mockHookReturn = buildHookReturn({
rateLimitData: {
user_id: "user-123",
user_email: "alice@example.com",
daily_token_limit: 10000,
weekly_token_limit: 50000,
daily_tokens_used: 2500,
weekly_tokens_used: 10000,
tier: "FREE",
},
});
render(<RateLimitManager />);
expect(screen.getByTestId("rate-limit-display")).toBeDefined();
expect(screen.getByText("alice@example.com")).toBeDefined();
});
it("does not render RateLimitDisplay when rateLimitData is null", () => {
render(<RateLimitManager />);
expect(screen.queryByTestId("rate-limit-display")).toBeNull();
});
it("passes handleReset and handleTierChange to RateLimitDisplay", () => {
mockHookReturn = buildHookReturn({
rateLimitData: {
user_id: "user-123",
user_email: "alice@example.com",
daily_token_limit: 10000,
weekly_token_limit: 50000,
daily_tokens_used: 2500,
weekly_tokens_used: 10000,
tier: "FREE",
},
});
render(<RateLimitManager />);
fireEvent.click(screen.getByText("mock-reset"));
expect(mockHandleReset).toHaveBeenCalledWith(false);
fireEvent.click(screen.getByText("mock-tier"));
expect(mockHandleTierChange).toHaveBeenCalledWith("PRO");
});
});

View File

@@ -0,0 +1,387 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import { renderHook, act, cleanup } from "@testing-library/react";
const mockToast = vi.fn();
vi.mock("@/components/molecules/Toast/use-toast", () => ({
useToast: () => ({ toast: mockToast }),
}));
const mockGetV2GetUserRateLimit = vi.fn();
const mockGetV2SearchUsersByNameOrEmail = vi.fn();
const mockPostV2ResetUserRateLimitUsage = vi.fn();
const mockPostV2SetUserRateLimitTier = vi.fn();
vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({
getV2GetUserRateLimit: (...args: unknown[]) =>
mockGetV2GetUserRateLimit(...args),
getV2SearchUsersByNameOrEmail: (...args: unknown[]) =>
mockGetV2SearchUsersByNameOrEmail(...args),
postV2ResetUserRateLimitUsage: (...args: unknown[]) =>
mockPostV2ResetUserRateLimitUsage(...args),
postV2SetUserRateLimitTier: (...args: unknown[]) =>
mockPostV2SetUserRateLimitTier(...args),
}));
import { useRateLimitManager } from "../useRateLimitManager";
function makeRateLimitResponse(overrides = {}) {
return {
user_id: "user-123",
user_email: "alice@example.com",
daily_token_limit: 10000,
weekly_token_limit: 50000,
daily_tokens_used: 2500,
weekly_tokens_used: 10000,
tier: "FREE",
...overrides,
};
}
beforeEach(() => {
mockToast.mockClear();
mockGetV2GetUserRateLimit.mockReset();
mockGetV2SearchUsersByNameOrEmail.mockReset();
mockPostV2ResetUserRateLimitUsage.mockReset();
mockPostV2SetUserRateLimitTier.mockReset();
});
afterEach(() => {
cleanup();
});
describe("useRateLimitManager", () => {
it("returns initial state", () => {
const { result } = renderHook(() => useRateLimitManager());
expect(result.current.isSearching).toBe(false);
expect(result.current.isLoadingRateLimit).toBe(false);
expect(result.current.searchResults).toEqual([]);
expect(result.current.selectedUser).toBeNull();
expect(result.current.rateLimitData).toBeNull();
});
it("handleSearch does nothing for empty query", async () => {
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSearch(" ");
});
expect(mockGetV2GetUserRateLimit).not.toHaveBeenCalled();
expect(mockGetV2SearchUsersByNameOrEmail).not.toHaveBeenCalled();
});
it("handleSearch does direct lookup for email input", async () => {
const data = makeRateLimitResponse();
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSearch("alice@example.com");
});
expect(mockGetV2GetUserRateLimit).toHaveBeenCalledWith({
email: "alice@example.com",
});
expect(result.current.rateLimitData).toEqual(data);
expect(result.current.selectedUser).toEqual({
user_id: "user-123",
user_email: "alice@example.com",
});
});
it("handleSearch does direct lookup for UUID input", async () => {
const uuid = "550e8400-e29b-41d4-a716-446655440000";
const data = makeRateLimitResponse({ user_id: uuid });
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSearch(uuid);
});
expect(mockGetV2GetUserRateLimit).toHaveBeenCalledWith({
user_id: uuid,
});
expect(result.current.rateLimitData).toEqual(data);
});
it("handleSearch shows error toast on direct lookup failure", async () => {
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 404 });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSearch("alice@example.com");
});
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Error",
variant: "destructive",
}),
);
expect(result.current.rateLimitData).toBeNull();
});
it("handleSearch does fuzzy search for partial text", async () => {
const users = [
{ user_id: "u1", user_email: "alice@example.com" },
{ user_id: "u2", user_email: "bob@example.com" },
];
mockGetV2SearchUsersByNameOrEmail.mockResolvedValue({
status: 200,
data: users,
});
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSearch("alice");
});
expect(mockGetV2SearchUsersByNameOrEmail).toHaveBeenCalledWith({
query: "alice",
limit: 20,
});
expect(result.current.searchResults).toEqual(users);
});
it("handleSearch shows toast when fuzzy search returns no results", async () => {
mockGetV2SearchUsersByNameOrEmail.mockResolvedValue({
status: 200,
data: [],
});
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSearch("nonexistent");
});
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({ title: "No results" }),
);
expect(result.current.searchResults).toEqual([]);
});
it("handleSearch shows error toast on fuzzy search failure", async () => {
mockGetV2SearchUsersByNameOrEmail.mockResolvedValue({ status: 500 });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSearch("alice");
});
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Error",
variant: "destructive",
}),
);
});
it("handleSelectUser fetches rate limit for selected user", async () => {
const data = makeRateLimitResponse();
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSelectUser({
user_id: "user-123",
user_email: "alice@example.com",
});
});
expect(mockGetV2GetUserRateLimit).toHaveBeenCalledWith({
user_id: "user-123",
});
expect(result.current.selectedUser).toEqual({
user_id: "user-123",
user_email: "alice@example.com",
});
expect(result.current.rateLimitData).toEqual(data);
});
it("handleSelectUser shows error toast on fetch failure", async () => {
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 500 });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSelectUser({
user_id: "user-123",
user_email: "alice@example.com",
});
});
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Error",
variant: "destructive",
}),
);
expect(result.current.rateLimitData).toBeNull();
});
it("handleReset calls reset endpoint and updates data", async () => {
const initial = makeRateLimitResponse({ daily_tokens_used: 5000 });
const after = makeRateLimitResponse({ daily_tokens_used: 0 });
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data: initial });
mockPostV2ResetUserRateLimitUsage.mockResolvedValue({
status: 200,
data: after,
});
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSelectUser({
user_id: "user-123",
user_email: "alice@example.com",
});
});
await act(async () => {
await result.current.handleReset(false);
});
expect(mockPostV2ResetUserRateLimitUsage).toHaveBeenCalledWith({
user_id: "user-123",
reset_weekly: false,
});
expect(result.current.rateLimitData).toEqual(after);
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({ title: "Success" }),
);
});
it("handleReset does nothing when no rate limit data", async () => {
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleReset(false);
});
expect(mockPostV2ResetUserRateLimitUsage).not.toHaveBeenCalled();
});
it("handleReset shows error toast on failure", async () => {
const initial = makeRateLimitResponse();
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data: initial });
mockPostV2ResetUserRateLimitUsage.mockRejectedValue(
new Error("network error"),
);
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSelectUser({
user_id: "user-123",
user_email: "alice@example.com",
});
});
await act(async () => {
await result.current.handleReset(true);
});
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Error",
description: "Failed to reset rate limit usage.",
variant: "destructive",
}),
);
});
it("handleTierChange calls set tier and re-fetches", async () => {
const initial = makeRateLimitResponse({ tier: "FREE" });
const updated = makeRateLimitResponse({ tier: "PRO" });
mockGetV2GetUserRateLimit
.mockResolvedValueOnce({ status: 200, data: initial })
.mockResolvedValueOnce({ status: 200, data: updated });
mockPostV2SetUserRateLimitTier.mockResolvedValue({ status: 200 });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSelectUser({
user_id: "user-123",
user_email: "alice@example.com",
});
});
await act(async () => {
await result.current.handleTierChange("PRO");
});
expect(mockPostV2SetUserRateLimitTier).toHaveBeenCalledWith({
user_id: "user-123",
tier: "PRO",
});
expect(result.current.rateLimitData).toEqual(updated);
});
it("handleTierChange does nothing when no rate limit data", async () => {
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleTierChange("PRO");
});
expect(mockPostV2SetUserRateLimitTier).not.toHaveBeenCalled();
});
it("handleReset throws when endpoint returns non-200 status", async () => {
const initial = makeRateLimitResponse({ daily_tokens_used: 5000 });
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data: initial });
mockPostV2ResetUserRateLimitUsage.mockResolvedValue({ status: 500 });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSelectUser({
user_id: "user-123",
user_email: "alice@example.com",
});
});
await act(async () => {
await result.current.handleReset(false);
});
expect(mockToast).toHaveBeenCalledWith(
expect.objectContaining({
title: "Error",
description: "Failed to reset rate limit usage.",
variant: "destructive",
}),
);
});
it("handleTierChange throws when set-tier endpoint returns non-200", async () => {
const initial = makeRateLimitResponse({ tier: "FREE" });
mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data: initial });
mockPostV2SetUserRateLimitTier.mockResolvedValue({ status: 500 });
const { result } = renderHook(() => useRateLimitManager());
await act(async () => {
await result.current.handleSelectUser({
user_id: "user-123",
user_email: "alice@example.com",
});
});
await expect(
act(async () => {
await result.current.handleTierChange("PRO");
}),
).rejects.toThrow("Failed to update tier");
});
});

View File

@@ -2,11 +2,13 @@
import { useState } from "react";
import { useToast } from "@/components/molecules/Toast/use-toast";
import type { SetUserTierRequest } from "@/app/api/__generated__/models/setUserTierRequest";
import type { UserRateLimitResponse } from "@/app/api/__generated__/models/userRateLimitResponse";
import {
getV2GetUserRateLimit,
getV2GetAllUsersHistory,
getV2SearchUsersByNameOrEmail,
postV2ResetUserRateLimitUsage,
postV2SetUserRateLimitTier,
} from "@/app/api/__generated__/endpoints/admin/admin";
export interface UserOption {
@@ -14,18 +16,10 @@ export interface UserOption {
user_email: string;
}
/**
* Returns true when the input looks like a complete email address.
* Used to decide whether to call the direct email lookup endpoint
* vs. the broader user-history search.
*/
function looksLikeEmail(input: string): boolean {
return /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(input);
}
/**
* Returns true when the input looks like a UUID (user ID).
*/
function looksLikeUuid(input: string): boolean {
return /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i.test(
input,
@@ -41,7 +35,6 @@ export function useRateLimitManager() {
const [rateLimitData, setRateLimitData] =
useState<UserRateLimitResponse | null>(null);
/** Direct lookup by email or user ID via the rate-limit endpoint. */
async function handleDirectLookup(trimmed: string) {
setIsSearching(true);
setSearchResults([]);
@@ -77,7 +70,6 @@ export function useRateLimitManager() {
}
}
/** Fuzzy name/email search via the spending-history endpoint. */
async function handleFuzzySearch(trimmed: string) {
setIsSearching(true);
setSearchResults([]);
@@ -85,38 +77,21 @@ export function useRateLimitManager() {
setRateLimitData(null);
try {
const response = await getV2GetAllUsersHistory({
search: trimmed,
page: 1,
page_size: 50,
const response = await getV2SearchUsersByNameOrEmail({
query: trimmed,
limit: 20,
});
if (response.status !== 200) {
throw new Error("Failed to search users");
}
// Deduplicate by user_id to get unique users
const seen = new Set<string>();
const users: UserOption[] = [];
for (const tx of response.data.history) {
if (!seen.has(tx.user_id)) {
seen.add(tx.user_id);
users.push({
user_id: tx.user_id,
user_email: String(tx.user_email ?? tx.user_id),
});
}
}
const users = (response.data ?? []).map((u) => ({
user_id: u.user_id,
user_email: u.user_email ?? u.user_id,
}));
if (users.length === 0) {
toast({
title: "No results",
description: "No users found matching your search.",
});
toast({ title: "No results", description: "No users found." });
}
// Always show the result list so the user explicitly picks a match.
// The history endpoint paginates transactions, not users, so a single
// page may not be authoritative -- avoid auto-selecting.
setSearchResults(users);
} catch (error) {
console.error("Error searching users:", error);
@@ -199,6 +174,32 @@ export function useRateLimitManager() {
}
}
async function handleTierChange(newTier: string) {
if (!rateLimitData) return;
const response = await postV2SetUserRateLimitTier({
user_id: rateLimitData.user_id,
tier: newTier as SetUserTierRequest["tier"],
});
if (response.status !== 200) {
throw new Error("Failed to update tier");
}
// Re-fetch rate limit data to reflect new tier-adjusted limits.
try {
const refreshResponse = await getV2GetUserRateLimit({
user_id: rateLimitData.user_id,
});
if (refreshResponse.status === 200) {
setRateLimitData(refreshResponse.data);
}
} catch {
// Tier was changed server-side; UI will be stale but not incorrect.
// The caller's success toast is still valid — the tier change worked.
}
}
return {
isSearching,
isLoadingRateLimit,
@@ -208,5 +209,6 @@ export function useRateLimitManager() {
handleSearch,
handleSelectUser,
handleReset,
handleTierChange,
};
}

View File

@@ -124,9 +124,20 @@ export function UsagePanelContent({
);
}
const tierLabel = usage.tier
? usage.tier.charAt(0) + usage.tier.slice(1).toLowerCase()
: null;
return (
<div className="flex flex-col gap-3">
<div className="text-xs font-semibold text-neutral-800">Usage limits</div>
<div className="flex items-baseline justify-between">
<span className="text-xs font-semibold text-neutral-800">
Usage limits
</span>
{tierLabel && (
<span className="text-[11px] text-neutral-500">{tierLabel} plan</span>
)}
</div>
{hasDailyLimit && (
<UsageBar
label="Today"

View File

@@ -31,16 +31,19 @@ function makeUsage({
dailyLimit = 10000,
weeklyUsed = 2000,
weeklyLimit = 50000,
tier = "FREE",
}: {
dailyUsed?: number;
dailyLimit?: number;
weeklyUsed?: number;
weeklyLimit?: number;
tier?: string;
} = {}) {
const future = new Date(Date.now() + 3600 * 1000); // 1h from now
return {
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
tier,
};
}
@@ -110,6 +113,16 @@ describe("UsageLimits", () => {
expect(screen.getByText("100% used")).toBeDefined();
});
it("displays the user tier label", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage({ tier: "PRO" }),
isLoading: false,
});
render(<UsageLimits />);
expect(screen.getByText("Pro plan")).toBeDefined();
});
it("shows learn more link to credits page", () => {
mockUseGetV2GetCopilotUsage.mockReturnValue({
data: makeUsage(),

View File

@@ -0,0 +1,30 @@
import { describe, expect, it } from "vitest";
import { formatResetTime } from "../UsagePanelContent";
describe("formatResetTime", () => {
const now = new Date("2025-06-15T12:00:00Z");
it("returns 'now' when reset time is in the past", () => {
expect(formatResetTime("2025-06-15T11:00:00Z", now)).toBe("now");
});
it("returns minutes only when under 1 hour", () => {
const result = formatResetTime("2025-06-15T12:30:00Z", now);
expect(result).toBe("in 30m");
});
it("returns hours and minutes when under 24 hours", () => {
const result = formatResetTime("2025-06-15T16:45:00Z", now);
expect(result).toBe("in 4h 45m");
});
it("returns formatted date when over 24 hours away", () => {
const result = formatResetTime("2025-06-17T00:00:00Z", now);
expect(result).toMatch(/Tue/);
});
it("accepts a Date object for resetsAt", () => {
const resetDate = new Date("2025-06-15T14:00:00Z");
expect(formatResetTime(resetDate, now)).toBe("in 2h 0m");
});
});

View File

@@ -0,0 +1,114 @@
import {
render,
screen,
cleanup,
fireEvent,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { UsagePanelContent } from "../UsagePanelContent";
import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus";
const mockResetUsage = vi.fn();
vi.mock("../../../hooks/useResetRateLimit", () => ({
useResetRateLimit: () => ({ resetUsage: mockResetUsage, isPending: false }),
}));
afterEach(() => {
cleanup();
mockResetUsage.mockReset();
});
function makeUsage(
overrides: Partial<{
dailyUsed: number;
dailyLimit: number;
weeklyUsed: number;
weeklyLimit: number;
tier: string;
resetCost: number;
}> = {},
): CoPilotUsageStatus {
const {
dailyUsed = 500,
dailyLimit = 10000,
weeklyUsed = 2000,
weeklyLimit = 50000,
tier = "FREE",
resetCost = 100,
} = overrides;
const future = new Date(Date.now() + 3600 * 1000);
return {
daily: { used: dailyUsed, limit: dailyLimit, resets_at: future },
weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future },
tier,
reset_cost: resetCost,
} as CoPilotUsageStatus;
}
describe("UsagePanelContent", () => {
it("renders 'No usage limits configured' when both limits are zero", () => {
render(
<UsagePanelContent
usage={makeUsage({ dailyLimit: 0, weeklyLimit: 0 })}
/>,
);
expect(screen.getByText("No usage limits configured")).toBeDefined();
});
it("renders the reset button when daily limit is exhausted", () => {
render(
<UsagePanelContent
usage={makeUsage({
dailyUsed: 10000,
dailyLimit: 10000,
resetCost: 50,
})}
/>,
);
expect(screen.getByText(/Reset daily limit/)).toBeDefined();
});
it("does not render the reset button when weekly limit is also exhausted", () => {
render(
<UsagePanelContent
usage={makeUsage({
dailyUsed: 10000,
dailyLimit: 10000,
weeklyUsed: 50000,
weeklyLimit: 50000,
resetCost: 50,
})}
/>,
);
expect(screen.queryByText(/Reset daily limit/)).toBeNull();
});
it("calls resetUsage when the reset button is clicked", () => {
render(
<UsagePanelContent
usage={makeUsage({
dailyUsed: 10000,
dailyLimit: 10000,
resetCost: 50,
})}
/>,
);
fireEvent.click(screen.getByText(/Reset daily limit/));
expect(mockResetUsage).toHaveBeenCalled();
});
it("renders 'Add credits' link when insufficient credits", () => {
render(
<UsagePanelContent
usage={makeUsage({
dailyUsed: 10000,
dailyLimit: 10000,
resetCost: 50,
})}
hasInsufficientCredits={true}
isBillingEnabled={true}
/>,
);
expect(screen.getByText("Add credits to reset")).toBeDefined();
});
});

View File

@@ -0,0 +1,337 @@
import { describe, expect, it } from "vitest";
import type { ToolUIPart } from "ai";
import {
TOOL_AGENT,
TOOL_TASK,
TOOL_TASK_OUTPUT,
extractToolName,
formatToolName,
getToolCategory,
truncate,
humanizeFileName,
getAnimationText,
} from "../helpers";
describe("extractToolName", () => {
it("strips the tool- prefix from part.type", () => {
const part = { type: "tool-bash_exec" } as unknown as ToolUIPart;
expect(extractToolName(part)).toBe("bash_exec");
});
it("returns type unchanged when there is no tool- prefix", () => {
const part = { type: "Read" } as unknown as ToolUIPart;
expect(extractToolName(part)).toBe("Read");
});
});
describe("formatToolName", () => {
it("replaces underscores with spaces and capitalizes first letter", () => {
expect(formatToolName("bash_exec")).toBe("Bash exec");
});
it("capitalizes a single word", () => {
expect(formatToolName("read")).toBe("Read");
});
it("handles already capitalized names", () => {
expect(formatToolName("WebSearch")).toBe("WebSearch");
});
});
describe("getToolCategory", () => {
it("returns 'bash' for bash_exec", () => {
expect(getToolCategory("bash_exec")).toBe("bash");
});
it("returns 'web' for web_fetch, WebSearch, WebFetch", () => {
expect(getToolCategory("web_fetch")).toBe("web");
expect(getToolCategory("WebSearch")).toBe("web");
expect(getToolCategory("WebFetch")).toBe("web");
});
it("returns 'browser' for browser tools", () => {
expect(getToolCategory("browser_navigate")).toBe("browser");
expect(getToolCategory("browser_act")).toBe("browser");
expect(getToolCategory("browser_screenshot")).toBe("browser");
});
it("returns 'file-read' for read tools", () => {
expect(getToolCategory("read_workspace_file")).toBe("file-read");
expect(getToolCategory("read_file")).toBe("file-read");
expect(getToolCategory("Read")).toBe("file-read");
});
it("returns 'file-write' for write tools", () => {
expect(getToolCategory("write_workspace_file")).toBe("file-write");
expect(getToolCategory("write_file")).toBe("file-write");
expect(getToolCategory("Write")).toBe("file-write");
});
it("returns 'file-delete' for delete tool", () => {
expect(getToolCategory("delete_workspace_file")).toBe("file-delete");
});
it("returns 'file-list' for listing tools", () => {
expect(getToolCategory("list_workspace_files")).toBe("file-list");
expect(getToolCategory("glob")).toBe("file-list");
expect(getToolCategory("Glob")).toBe("file-list");
});
it("returns 'search' for grep tools", () => {
expect(getToolCategory("grep")).toBe("search");
expect(getToolCategory("Grep")).toBe("search");
});
it("returns 'edit' for edit tools", () => {
expect(getToolCategory("edit_file")).toBe("edit");
expect(getToolCategory("Edit")).toBe("edit");
});
it("returns 'todo' for TodoWrite", () => {
expect(getToolCategory("TodoWrite")).toBe("todo");
});
it("returns 'compaction' for context_compaction", () => {
expect(getToolCategory("context_compaction")).toBe("compaction");
});
it("returns 'agent' for agent tools", () => {
expect(getToolCategory(TOOL_AGENT)).toBe("agent");
expect(getToolCategory(TOOL_TASK)).toBe("agent");
expect(getToolCategory(TOOL_TASK_OUTPUT)).toBe("agent");
});
it("returns 'other' for unknown tools", () => {
expect(getToolCategory("unknown_tool")).toBe("other");
});
});
describe("truncate", () => {
it("returns text unchanged when shorter than maxLen", () => {
expect(truncate("short", 10)).toBe("short");
});
it("returns text unchanged when equal to maxLen", () => {
expect(truncate("12345", 5)).toBe("12345");
});
it("truncates and appends ellipsis when longer than maxLen", () => {
const result = truncate("this is a very long string", 10);
expect(result).toBe("this is a\u2026");
expect(result.length).toBeLessThanOrEqual(11);
});
});
describe("humanizeFileName", () => {
it("strips path and extension, titlecases words", () => {
expect(humanizeFileName("/path/to/my-file.ts")).toBe('"My File"');
});
it("handles underscores", () => {
expect(humanizeFileName("some_module_name.py")).toBe('"Some Module Name"');
});
it("preserves all-caps words", () => {
expect(humanizeFileName("README.md")).toBe('"README"');
});
it("handles file with no extension", () => {
expect(humanizeFileName("Makefile")).toBe('"Makefile"');
});
it("strips known extensions", () => {
expect(humanizeFileName("data.json")).toBe('"Data"');
expect(humanizeFileName("image.png")).toBe('"Image"');
expect(humanizeFileName("archive.tar")).toBe('"Archive"');
});
});
describe("getAnimationText", () => {
function makePart(
overrides: Partial<ToolUIPart> & { type: string },
): ToolUIPart {
return {
state: "input-streaming",
input: undefined,
output: undefined,
...overrides,
} as unknown as ToolUIPart;
}
it("shows streaming text for bash with command summary", () => {
const part = makePart({
type: "tool-bash_exec",
state: "input-available",
input: { command: "ls -la" },
});
expect(getAnimationText(part, "bash")).toBe("Running: ls -la");
});
it("shows generic streaming text for bash without input", () => {
const part = makePart({
type: "tool-bash_exec",
state: "input-streaming",
});
expect(getAnimationText(part, "bash")).toBe("Running command\u2026");
});
it("shows completed text for bash", () => {
const part = makePart({
type: "tool-bash_exec",
state: "output-available",
input: { command: "echo hello" },
output: { exit_code: 0 },
});
expect(getAnimationText(part, "bash")).toBe("Ran: echo hello");
});
it("shows exit code on non-zero exit", () => {
const part = makePart({
type: "tool-bash_exec",
state: "output-available",
input: { command: "false" },
output: { exit_code: 1 },
});
expect(getAnimationText(part, "bash")).toBe("Command exited with code 1");
});
it("shows error text for bash failure", () => {
const part = makePart({
type: "tool-bash_exec",
state: "output-error",
});
expect(getAnimationText(part, "bash")).toBe("Command failed");
});
it("shows searching text for WebSearch", () => {
const part = makePart({
type: "tool-WebSearch",
state: "input-available",
input: { query: "test query" },
});
expect(getAnimationText(part, "web")).toBe('Searching "test query"');
});
it("shows fetching text for web_fetch", () => {
const part = makePart({
type: "tool-web_fetch",
state: "input-available",
input: { url: "https://example.com" },
});
expect(getAnimationText(part, "web")).toBe("Fetching https://example.com");
});
it("shows reading text for file-read", () => {
const part = makePart({
type: "tool-Read",
state: "input-available",
input: { file_path: "/src/index.ts" },
});
expect(getAnimationText(part, "file-read")).toBe('Reading "Index"');
});
it("shows writing text for file-write", () => {
const part = makePart({
type: "tool-Write",
state: "input-available",
input: { file_path: "/src/output.json" },
});
expect(getAnimationText(part, "file-write")).toBe('Writing "Output"');
});
it("shows compaction text", () => {
const part = makePart({
type: "tool-context_compaction",
state: "input-streaming",
});
expect(getAnimationText(part, "compaction")).toBe(
"Summarizing earlier messages\u2026",
);
});
it("shows completed compaction text", () => {
const part = makePart({
type: "tool-context_compaction",
state: "output-available",
});
expect(getAnimationText(part, "compaction")).toBe(
"Earlier messages were summarized",
);
});
it("shows agent streaming text with description", () => {
const part = makePart({
type: `tool-${TOOL_AGENT}`,
state: "input-available",
input: { description: "analyze code" },
});
expect(getAnimationText(part, "agent")).toBe("Running agent: analyze code");
});
it("shows agent completed for async launch", () => {
const part = makePart({
type: `tool-${TOOL_AGENT}`,
state: "output-available",
output: { isAsync: true },
});
expect(getAnimationText(part, "agent")).toBe("Agent started in background");
});
it("shows default streaming text for unknown tools", () => {
const part = makePart({
type: "tool-custom_tool",
state: "input-streaming",
});
expect(getAnimationText(part, "other")).toBe("Running Custom tool\u2026");
});
it("shows default completed text for unknown tools", () => {
const part = makePart({
type: "tool-custom_tool",
state: "output-available",
});
expect(getAnimationText(part, "other")).toBe("Custom tool completed");
});
it("shows default error text for unknown tools", () => {
const part = makePart({
type: "tool-custom_tool",
state: "output-error",
});
expect(getAnimationText(part, "other")).toBe("Custom tool failed");
});
it("shows browser screenshot streaming", () => {
const part = makePart({
type: "tool-browser_screenshot",
state: "input-available",
});
expect(getAnimationText(part, "browser")).toBe("Taking screenshot\u2026");
});
it("shows todo streaming text", () => {
const part = makePart({
type: "tool-TodoWrite",
state: "input-available",
input: {
todos: [
{
content: "Fix bug",
status: "in_progress",
activeForm: "Fixing the bug",
},
],
},
});
expect(getAnimationText(part, "todo")).toBe("Fixing the bug");
});
it("shows TaskOutput timeout text", () => {
const part = makePart({
type: `tool-${TOOL_TASK_OUTPUT}`,
state: "output-available",
output: { retrieval_status: "timeout" },
});
expect(getAnimationText(part, "agent")).toBe("Agent still running\u2026");
});
});

View File

@@ -95,7 +95,8 @@ export function useChatSession() {
async function createSession() {
if (sessionId) return sessionId;
try {
const response = await createSessionMutation({ data: null });
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const response = await (createSessionMutation as any)({ data: null });
if (response.status !== 200 || !response.data?.id) {
const error = new Error("Failed to create session");
Sentry.captureException(error, {

View File

@@ -1407,7 +1407,7 @@
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Copilot Usage",
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.\nGlobal defaults sourced from LaunchDarkly (falling back to config).",
"description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.\nGlobal defaults sourced from LaunchDarkly (falling back to config).\nIncludes the user's rate-limit tier.",
"operationId": "getV2GetCopilotUsage",
"responses": {
"200": {
@@ -1553,6 +1553,128 @@
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/copilot/admin/rate_limit/search_users": {
"get": {
"tags": ["v2", "admin", "copilot", "admin"],
"summary": "Search Users by Name or Email",
"description": "Search users by partial email or name. Admin-only.\n\nQueries the User table directly — returns results even for users\nwithout credit transaction history.",
"operationId": "getV2Search users by name or email",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "query",
"in": "query",
"required": true,
"schema": { "type": "string", "title": "Query" }
},
{
"name": "limit",
"in": "query",
"required": false,
"schema": { "type": "integer", "default": 20, "title": "Limit" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": { "$ref": "#/components/schemas/UserSearchResult" },
"title": "Response Getv2Search Users By Name Or Email"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/copilot/admin/rate_limit/tier": {
"get": {
"tags": ["v2", "admin", "copilot", "admin"],
"summary": "Get User Rate Limit Tier",
"description": "Get a user's current rate-limit tier. Admin-only.\n\nReturns 404 if the user does not exist in the database.",
"operationId": "getV2Get user rate limit tier",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "user_id",
"in": "query",
"required": true,
"schema": { "type": "string", "title": "User Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/UserTierResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
},
"post": {
"tags": ["v2", "admin", "copilot", "admin"],
"summary": "Set User Rate Limit Tier",
"description": "Set a user's rate-limit tier. Admin-only.\n\nReturns 404 if the user does not exist in the database.",
"operationId": "postV2Set user rate limit tier",
"security": [{ "HTTPBearerJWT": [] }],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/SetUserTierRequest" }
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/UserTierResponse" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -8496,6 +8618,10 @@
"properties": {
"daily": { "$ref": "#/components/schemas/UsageWindow" },
"weekly": { "$ref": "#/components/schemas/UsageWindow" },
"tier": {
"$ref": "#/components/schemas/SubscriptionTier",
"default": "FREE"
},
"reset_cost": {
"type": "integer",
"title": "Reset Cost",
@@ -12283,6 +12409,15 @@
"required": ["active_graph_version"],
"title": "SetGraphActiveVersion"
},
"SetUserTierRequest": {
"properties": {
"user_id": { "type": "string", "title": "User Id" },
"tier": { "$ref": "#/components/schemas/SubscriptionTier" }
},
"type": "object",
"required": ["user_id", "tier"],
"title": "SetUserTierRequest"
},
"SetupInfo": {
"properties": {
"agent_id": { "type": "string", "title": "Agent Id" },
@@ -13052,6 +13187,12 @@
"enum": ["DRAFT", "PENDING", "APPROVED", "REJECTED"],
"title": "SubmissionStatus"
},
"SubscriptionTier": {
"type": "string",
"enum": ["FREE", "PRO", "BUSINESS", "ENTERPRISE"],
"title": "SubscriptionTier",
"description": "Subscription tiers with increasing token allowances.\n\nMirrors the ``SubscriptionTier`` enum in ``schema.prisma``.\nOnce ``prisma generate`` is run, this can be replaced with::\n\n from prisma.enums import SubscriptionTier"
},
"SuggestedGoalResponse": {
"properties": {
"type": {
@@ -14880,7 +15021,8 @@
"weekly_tokens_used": {
"type": "integer",
"title": "Weekly Tokens Used"
}
},
"tier": { "$ref": "#/components/schemas/SubscriptionTier" }
},
"type": "object",
"required": [
@@ -14888,7 +15030,8 @@
"daily_token_limit",
"weekly_token_limit",
"daily_tokens_used",
"weekly_tokens_used"
"weekly_tokens_used",
"tier"
],
"title": "UserRateLimitResponse"
},
@@ -14915,6 +15058,27 @@
"title": "UserReadiness",
"description": "User readiness status."
},
"UserSearchResult": {
"properties": {
"user_id": { "type": "string", "title": "User Id" },
"user_email": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "User Email"
}
},
"type": "object",
"required": ["user_id"],
"title": "UserSearchResult"
},
"UserTierResponse": {
"properties": {
"user_id": { "type": "string", "title": "User Id" },
"tier": { "$ref": "#/components/schemas/SubscriptionTier" }
},
"type": "object",
"required": ["user_id", "tier"],
"title": "UserTierResponse"
},
"UserTransaction": {
"properties": {
"transaction_key": {