mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-include-graph-option
This commit is contained in:
@@ -123,6 +123,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
refresh_failed = False
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
@@ -141,6 +142,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
# Do NOT fall back to the stale token — it is likely expired
|
||||
# or revoked. Returning None forces the caller to re-auth,
|
||||
# preventing the LLM from receiving a non-functional token.
|
||||
refresh_failed = True
|
||||
continue
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
@@ -152,8 +154,12 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
# Only cache "not connected" when the user truly has no credentials for this
|
||||
# provider. If we had OAuth credentials but refresh failed (e.g. transient
|
||||
# network error, event-loop mismatch), do NOT cache the negative result —
|
||||
# the next call should retry the refresh instead of being blocked for 60 s.
|
||||
if not refresh_failed:
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -129,8 +129,15 @@ class TestGetProviderToken:
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_returns_none(self):
|
||||
"""On refresh failure, return None instead of caching a stale token."""
|
||||
async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
|
||||
"""On refresh failure, return None but do NOT cache in null_cache.
|
||||
|
||||
The user has credentials — they just couldn't be refreshed right now
|
||||
(e.g. transient network error or event-loop mismatch in the copilot
|
||||
executor). Caching a negative result would block all credential
|
||||
lookups for 60 s even though the creds exist and may refresh fine
|
||||
on the next attempt.
|
||||
"""
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
@@ -141,6 +148,8 @@ class TestGetProviderToken:
|
||||
|
||||
# Stale tokens must NOT be returned — forces re-auth.
|
||||
assert result is None
|
||||
# Must NOT cache negative result when refresh failed — next call retries.
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
@@ -176,6 +185,96 @@ class TestGetProviderToken:
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestThreadSafetyLocks:
|
||||
"""Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
|
||||
'Future attached to a different loop' when copilot workers accessed
|
||||
credentials from different event loops."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_locks_returns_per_thread_instance(self):
|
||||
"""IntegrationCredentialsStore.locks() must return different instances
|
||||
for different threads (via @thread_cached)."""
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
store = IntegrationCredentialsStore()
|
||||
|
||||
async def get_locks_id():
|
||||
mock_redis = AsyncMock()
|
||||
with patch(
|
||||
"backend.integrations.credentials_store.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
locks = await store.locks()
|
||||
return id(locks)
|
||||
|
||||
# Get locks from main thread
|
||||
main_id = await get_locks_id()
|
||||
|
||||
# Get locks from a worker thread
|
||||
def run_in_thread():
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(get_locks_id())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
worker_id = await asyncio.get_event_loop().run_in_executor(
|
||||
pool, run_in_thread
|
||||
)
|
||||
|
||||
assert main_id != worker_id, (
|
||||
"Store.locks() returned the same instance across threads. "
|
||||
"This would cause 'Future attached to a different loop' errors."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_manager_delegates_to_store_locks(self):
|
||||
"""IntegrationCredentialsManager.locks() should delegate to store."""
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
manager = IntegrationCredentialsManager()
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.integrations.credentials_store.get_redis_async",
|
||||
return_value=mock_redis,
|
||||
):
|
||||
locks = await manager.locks()
|
||||
|
||||
# Should have gotten it from the store
|
||||
assert locks is not None
|
||||
|
||||
|
||||
class TestRefreshUnlockedPath:
|
||||
"""Bug reproduction: copilot worker threads need lock-free refresh because
|
||||
Redis-backed asyncio.Lock created on one event loop can't be used on another."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_refresh_if_needed_lock_false_skips_redis(self):
|
||||
"""refresh_if_needed(lock=False) must not touch Redis locks at all."""
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
manager = IntegrationCredentialsManager()
|
||||
creds = _make_oauth2_creds()
|
||||
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.needs_refresh = MagicMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"backend.integrations.creds_manager._get_provider_oauth_handler",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_handler,
|
||||
):
|
||||
result = await manager.refresh_if_needed(_USER, creds, lock=False)
|
||||
|
||||
# Should return credentials without touching locks
|
||||
assert result.id == creds.id
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
|
||||
@@ -19,6 +19,7 @@ from backend.data.model import (
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.cache import thread_cached
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
@@ -304,15 +305,12 @@ def is_system_provider(provider: str) -> bool:
|
||||
|
||||
|
||||
class IntegrationCredentialsStore:
|
||||
def __init__(self):
|
||||
self._locks = None
|
||||
|
||||
@thread_cached
|
||||
async def locks(self) -> AsyncRedisKeyedMutex:
|
||||
if self._locks:
|
||||
return self._locks
|
||||
|
||||
self._locks = AsyncRedisKeyedMutex(await get_redis_async())
|
||||
return self._locks
|
||||
# Per-thread: copilot executor runs worker threads with separate event
|
||||
# loops; AsyncRedisKeyedMutex's internal asyncio.Lock is bound to the
|
||||
# loop it was created on.
|
||||
return AsyncRedisKeyedMutex(await get_redis_async())
|
||||
|
||||
@property
|
||||
def db_manager(self):
|
||||
|
||||
@@ -8,7 +8,6 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
from backend.data.model import Credentials, OAuth2Credentials
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.integrations.credentials_store import (
|
||||
IntegrationCredentialsStore,
|
||||
provider_matches,
|
||||
@@ -106,14 +105,13 @@ class IntegrationCredentialsManager:
|
||||
|
||||
def __init__(self):
|
||||
self.store = IntegrationCredentialsStore()
|
||||
self._locks = None
|
||||
|
||||
async def locks(self) -> AsyncRedisKeyedMutex:
|
||||
if self._locks:
|
||||
return self._locks
|
||||
|
||||
self._locks = AsyncRedisKeyedMutex(await get_redis_async())
|
||||
return self._locks
|
||||
# Delegate to store's @thread_cached locks. Manager uses these for
|
||||
# fine-grained per-credential locking (refresh, acquire); the store
|
||||
# uses its own for coarse per-user integrations locking. Same mutex
|
||||
# type, different key spaces — no collision.
|
||||
return await self.store.locks()
|
||||
|
||||
async def create(self, user_id: str, credentials: Credentials) -> None:
|
||||
result = await self.store.add_creds(user_id, credentials)
|
||||
@@ -188,35 +186,74 @@ class IntegrationCredentialsManager:
|
||||
|
||||
async def refresh_if_needed(
|
||||
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
||||
) -> OAuth2Credentials:
|
||||
# When lock=False, skip ALL Redis locking (both the outer "refresh" scope
|
||||
# lock and the inner credential lock). This is used by the copilot's
|
||||
# integration_creds module which runs across multiple threads with separate
|
||||
# event loops; acquiring a Redis lock whose asyncio.Lock() was created on
|
||||
# a different loop raises "Future attached to a different loop".
|
||||
if lock:
|
||||
return await self._refresh_locked(user_id, credentials)
|
||||
return await self._refresh_unlocked(user_id, credentials)
|
||||
|
||||
async def _get_oauth_handler(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> "BaseOAuthHandler":
|
||||
"""Resolve the appropriate OAuth handler for the given credentials."""
|
||||
if provider_matches(credentials.provider, ProviderName.MCP.value):
|
||||
return create_mcp_oauth_handler(credentials)
|
||||
return await _get_provider_oauth_handler(credentials.provider)
|
||||
|
||||
async def _refresh_locked(
|
||||
self, user_id: str, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
async with self._locked(user_id, credentials.id, "refresh"):
|
||||
if provider_matches(credentials.provider, ProviderName.MCP.value):
|
||||
oauth_handler = create_mcp_oauth_handler(credentials)
|
||||
else:
|
||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||
oauth_handler = await self._get_oauth_handler(credentials)
|
||||
if oauth_handler.needs_refresh(credentials):
|
||||
logger.debug(
|
||||
f"Refreshing '{credentials.provider}' credentials #{credentials.id}"
|
||||
"Refreshing '%s' credentials #%s",
|
||||
credentials.provider,
|
||||
credentials.id,
|
||||
)
|
||||
_lock = None
|
||||
if lock:
|
||||
# Wait until the credentials are no longer in use anywhere
|
||||
_lock = await self._acquire_lock(user_id, credentials.id)
|
||||
# Wait until the credentials are no longer in use anywhere
|
||||
_lock = await self._acquire_lock(user_id, credentials.id)
|
||||
try:
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
|
||||
credentials = fresh_credentials
|
||||
finally:
|
||||
if (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to release OAuth refresh lock",
|
||||
exc_info=True,
|
||||
)
|
||||
return credentials
|
||||
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
# Notify listeners so the refreshed token is picked up immediately.
|
||||
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to release OAuth refresh lock",
|
||||
exc_info=True,
|
||||
)
|
||||
async def _refresh_unlocked(
|
||||
self, user_id: str, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
"""Best-effort token refresh without any Redis locking.
|
||||
|
||||
credentials = fresh_credentials
|
||||
Safe for use from multi-threaded contexts (e.g. copilot workers) where
|
||||
each thread has its own event loop and sharing Redis-backed asyncio locks
|
||||
is not possible. Concurrent refreshes are tolerated: the last writer
|
||||
wins, and stale tokens are overwritten.
|
||||
"""
|
||||
oauth_handler = await self._get_oauth_handler(credentials)
|
||||
if oauth_handler.needs_refresh(credentials):
|
||||
logger.debug(
|
||||
"Refreshing '%s' credentials #%s (lock-free)",
|
||||
credentials.provider,
|
||||
credentials.id,
|
||||
)
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
|
||||
credentials = fresh_credentials
|
||||
return credentials
|
||||
|
||||
async def update(self, user_id: str, updated: Credentials) -> None:
|
||||
@@ -264,7 +301,6 @@ class IntegrationCredentialsManager:
|
||||
|
||||
async def release_all_locks(self):
|
||||
"""Call this on process termination to ensure all locks are released"""
|
||||
await (await self.locks()).release_all_locks()
|
||||
await (await self.store.locks()).release_all_locks()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user