mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(copilot): move get_user_tier import to top-level and expose cache via public API
- sdk/service.py: Move `get_user_tier` import from local (inside function) to module-level — no circular dependency exists. - rate_limit.py: Expose `cache_clear`/`cache_delete` as attributes on the public `get_user_tier` function so callers never need to import the private `_fetch_user_tier`. - rate_limit_test.py: Remove `_fetch_user_tier` import; use `get_user_tier.cache_clear()` instead.
This commit is contained in:
@@ -407,10 +407,16 @@ async def get_user_tier(user_id: str) -> RateLimitTier:
|
||||
return DEFAULT_TIER
|
||||
|
||||
|
||||
# Expose cache management on the public function so callers (including tests)
|
||||
# never need to reach into the private ``_fetch_user_tier``.
|
||||
get_user_tier.cache_clear = _fetch_user_tier.cache_clear # type: ignore[attr-defined]
|
||||
get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def set_user_tier(user_id: str, tier: RateLimitTier) -> None:
|
||||
"""Persist the user's rate-limit tier to the database.
|
||||
|
||||
Also invalidates the ``_fetch_user_tier`` cache for this user so that
|
||||
Also invalidates the ``get_user_tier`` cache for this user so that
|
||||
subsequent rate-limit checks immediately see the new tier.
|
||||
|
||||
Raises:
|
||||
@@ -421,7 +427,7 @@ async def set_user_tier(user_id: str, tier: RateLimitTier) -> None:
|
||||
data={"rateLimitTier": tier.value},
|
||||
)
|
||||
# Invalidate cached tier so rate-limit checks pick up the change immediately.
|
||||
_fetch_user_tier.cache_delete(user_id)
|
||||
get_user_tier.cache_delete(user_id) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def get_global_rate_limits(
|
||||
|
||||
@@ -13,7 +13,6 @@ from .rate_limit import (
|
||||
RateLimitExceeded,
|
||||
RateLimitTier,
|
||||
UsageWindow,
|
||||
_fetch_user_tier,
|
||||
check_rate_limit,
|
||||
get_global_rate_limits,
|
||||
get_usage_status,
|
||||
@@ -388,8 +387,8 @@ class TestRateLimitTier:
|
||||
class TestGetUserTier:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_tier_cache(self):
|
||||
"""Clear the _fetch_user_tier cache before each test."""
|
||||
_fetch_user_tier.cache_clear()
|
||||
"""Clear the get_user_tier cache before each test."""
|
||||
get_user_tier.cache_clear() # type: ignore[attr-defined]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_tier_from_db(self):
|
||||
|
||||
@@ -33,6 +33,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.permissions import apply_tool_permissions
|
||||
from backend.copilot.rate_limit import get_user_tier
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -1834,8 +1835,6 @@ async def stream_chat_completion_sdk(
|
||||
# langsmith tracing integration attaches them to every span. This
|
||||
# is what Langfuse (or any OTEL backend) maps to its native
|
||||
# user/session fields.
|
||||
from backend.copilot.rate_limit import get_user_tier
|
||||
|
||||
_user_tier = await get_user_tier(user_id) if user_id else None
|
||||
_otel_metadata: dict[str, str] = {
|
||||
"resume": str(use_resume),
|
||||
|
||||
Reference in New Issue
Block a user