fix(copilot): move get_user_tier import to top-level and expose cache via public API

- sdk/service.py: Move `get_user_tier` import from local (inside function)
  to module-level — no circular dependency exists.
- rate_limit.py: Expose `cache_clear`/`cache_delete` as attributes on the
  public `get_user_tier` function so callers never need to import the
  private `_fetch_user_tier`.
- rate_limit_test.py: Remove `_fetch_user_tier` import; use
  `get_user_tier.cache_clear()` instead.
This commit is contained in:
Zamil Majdy
2026-03-27 09:52:59 +07:00
parent e1d5113051
commit e900ee615a
3 changed files with 11 additions and 7 deletions

View File

@@ -407,10 +407,16 @@ async def get_user_tier(user_id: str) -> RateLimitTier:
return DEFAULT_TIER
# Expose cache management on the public function so callers (including tests)
# never need to reach into the private ``_fetch_user_tier``.
get_user_tier.cache_clear = _fetch_user_tier.cache_clear # type: ignore[attr-defined]
get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-defined]
async def set_user_tier(user_id: str, tier: RateLimitTier) -> None:
"""Persist the user's rate-limit tier to the database.
Also invalidates the ``_fetch_user_tier`` cache for this user so that
Also invalidates the ``get_user_tier`` cache for this user so that
subsequent rate-limit checks immediately see the new tier.
Raises:
@@ -421,7 +427,7 @@ async def set_user_tier(user_id: str, tier: RateLimitTier) -> None:
data={"rateLimitTier": tier.value},
)
# Invalidate cached tier so rate-limit checks pick up the change immediately.
_fetch_user_tier.cache_delete(user_id)
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
async def get_global_rate_limits(

View File

@@ -13,7 +13,6 @@ from .rate_limit import (
RateLimitExceeded,
RateLimitTier,
UsageWindow,
_fetch_user_tier,
check_rate_limit,
get_global_rate_limits,
get_usage_status,
@@ -388,8 +387,8 @@ class TestRateLimitTier:
class TestGetUserTier:
@pytest.fixture(autouse=True)
def _clear_tier_cache(self):
"""Clear the _fetch_user_tier cache before each test."""
_fetch_user_tier.cache_clear()
"""Clear the get_user_tier cache before each test."""
get_user_tier.cache_clear() # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_returns_tier_from_db(self):

View File

@@ -33,6 +33,7 @@ from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.copilot.permissions import apply_tool_permissions
from backend.copilot.rate_limit import get_user_tier
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
@@ -1834,8 +1835,6 @@ async def stream_chat_completion_sdk(
# langsmith tracing integration attaches them to every span. This
# is what Langfuse (or any OTEL backend) maps to its native
# user/session fields.
from backend.copilot.rate_limit import get_user_tier
_user_tier = await get_user_tier(user_id) if user_id else None
_otel_metadata: dict[str, str] = {
"resume": str(use_resume),