Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-include-graph-option

This commit is contained in:
Zamil Majdy
2026-04-01 20:51:46 +02:00
4 changed files with 181 additions and 42 deletions

View File

@@ -123,6 +123,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
refresh_failed = False
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
@@ -141,6 +142,7 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
# Do NOT fall back to the stale token — it is likely expired
# or revoked. Returning None forces the caller to re-auth,
# preventing the LLM from receiving a non-functional token.
refresh_failed = True
continue
_token_cache[cache_key] = token
return token
@@ -152,8 +154,12 @@ async def get_provider_token(user_id: str, provider: str) -> str | None:
_token_cache[cache_key] = token
return token
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
# Only cache "not connected" when the user truly has no credentials for this
# provider. If we had OAuth credentials but refresh failed (e.g. transient
# network error, event-loop mismatch), do NOT cache the negative result —
# the next call should retry the refresh instead of being blocked for 60 s.
if not refresh_failed:
_null_cache[cache_key] = True
return None

View File

@@ -129,8 +129,15 @@ class TestGetProviderToken:
assert result == "oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_returns_none(self):
"""On refresh failure, return None instead of caching a stale token."""
async def test_oauth2_refresh_failure_returns_none_without_null_cache(self):
"""On refresh failure, return None but do NOT cache in null_cache.
The user has credentials — they just couldn't be refreshed right now
(e.g. transient network error or event-loop mismatch in the copilot
executor). Caching a negative result would block all credential
lookups for 60 s even though the creds exist and may refresh fine
on the next attempt.
"""
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
@@ -141,6 +148,8 @@ class TestGetProviderToken:
# Stale tokens must NOT be returned — forces re-auth.
assert result is None
# Must NOT cache negative result when refresh failed — next call retries.
assert (_USER, _PROVIDER) not in _null_cache
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
@@ -176,6 +185,96 @@ class TestGetProviderToken:
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
class TestThreadSafetyLocks:
"""Bug reproduction: shared AsyncRedisKeyedMutex across threads caused
'Future attached to a different loop' when copilot workers accessed
credentials from different event loops."""
@pytest.mark.asyncio(loop_scope="session")
async def test_store_locks_returns_per_thread_instance(self):
"""IntegrationCredentialsStore.locks() must return different instances
for different threads (via @thread_cached)."""
import asyncio
import concurrent.futures
from backend.integrations.credentials_store import IntegrationCredentialsStore
store = IntegrationCredentialsStore()
async def get_locks_id():
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await store.locks()
return id(locks)
# Get locks from main thread
main_id = await get_locks_id()
# Get locks from a worker thread
def run_in_thread():
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(get_locks_id())
finally:
loop.close()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
worker_id = await asyncio.get_event_loop().run_in_executor(
pool, run_in_thread
)
assert main_id != worker_id, (
"Store.locks() returned the same instance across threads. "
"This would cause 'Future attached to a different loop' errors."
)
@pytest.mark.asyncio(loop_scope="session")
async def test_manager_delegates_to_store_locks(self):
"""IntegrationCredentialsManager.locks() should delegate to store."""
from backend.integrations.creds_manager import IntegrationCredentialsManager
manager = IntegrationCredentialsManager()
mock_redis = AsyncMock()
with patch(
"backend.integrations.credentials_store.get_redis_async",
return_value=mock_redis,
):
locks = await manager.locks()
# Should have gotten it from the store
assert locks is not None
class TestRefreshUnlockedPath:
"""Bug reproduction: copilot worker threads need lock-free refresh because
Redis-backed asyncio.Lock created on one event loop can't be used on another."""
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_if_needed_lock_false_skips_redis(self):
"""refresh_if_needed(lock=False) must not touch Redis locks at all."""
from backend.integrations.creds_manager import IntegrationCredentialsManager
manager = IntegrationCredentialsManager()
creds = _make_oauth2_creds()
mock_handler = MagicMock()
mock_handler.needs_refresh = MagicMock(return_value=False)
with patch(
"backend.integrations.creds_manager._get_provider_oauth_handler",
new_callable=AsyncMock,
return_value=mock_handler,
):
result = await manager.refresh_if_needed(_USER, creds, lock=False)
# Should return credentials without touching locks
assert result.id == creds.id
class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):

View File

@@ -19,6 +19,7 @@ from backend.data.model import (
UserPasswordCredentials,
)
from backend.data.redis_client import get_redis_async
from backend.util.cache import thread_cached
from backend.util.settings import Settings
settings = Settings()
@@ -304,15 +305,12 @@ def is_system_provider(provider: str) -> bool:
class IntegrationCredentialsStore:
def __init__(self):
self._locks = None
@thread_cached
async def locks(self) -> AsyncRedisKeyedMutex:
if self._locks:
return self._locks
self._locks = AsyncRedisKeyedMutex(await get_redis_async())
return self._locks
# Per-thread: copilot executor runs worker threads with separate event
# loops; AsyncRedisKeyedMutex's internal asyncio.Lock is bound to the
# loop it was created on.
return AsyncRedisKeyedMutex(await get_redis_async())
@property
def db_manager(self):

View File

@@ -8,7 +8,6 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis_client import get_redis_async
from backend.integrations.credentials_store import (
IntegrationCredentialsStore,
provider_matches,
@@ -106,14 +105,13 @@ class IntegrationCredentialsManager:
def __init__(self):
self.store = IntegrationCredentialsStore()
self._locks = None
async def locks(self) -> AsyncRedisKeyedMutex:
if self._locks:
return self._locks
self._locks = AsyncRedisKeyedMutex(await get_redis_async())
return self._locks
# Delegate to store's @thread_cached locks. Manager uses these for
# fine-grained per-credential locking (refresh, acquire); the store
# uses its own for coarse per-user integrations locking. Same mutex
# type, different key spaces — no collision.
return await self.store.locks()
async def create(self, user_id: str, credentials: Credentials) -> None:
result = await self.store.add_creds(user_id, credentials)
@@ -188,35 +186,74 @@ class IntegrationCredentialsManager:
async def refresh_if_needed(
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
) -> OAuth2Credentials:
# When lock=False, skip ALL Redis locking (both the outer "refresh" scope
# lock and the inner credential lock). This is used by the copilot's
# integration_creds module which runs across multiple threads with separate
# event loops; acquiring a Redis lock whose asyncio.Lock() was created on
# a different loop raises "Future attached to a different loop".
if lock:
return await self._refresh_locked(user_id, credentials)
return await self._refresh_unlocked(user_id, credentials)
async def _get_oauth_handler(
self, credentials: OAuth2Credentials
) -> "BaseOAuthHandler":
"""Resolve the appropriate OAuth handler for the given credentials."""
if provider_matches(credentials.provider, ProviderName.MCP.value):
return create_mcp_oauth_handler(credentials)
return await _get_provider_oauth_handler(credentials.provider)
async def _refresh_locked(
self, user_id: str, credentials: OAuth2Credentials
) -> OAuth2Credentials:
async with self._locked(user_id, credentials.id, "refresh"):
if provider_matches(credentials.provider, ProviderName.MCP.value):
oauth_handler = create_mcp_oauth_handler(credentials)
else:
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
oauth_handler = await self._get_oauth_handler(credentials)
if oauth_handler.needs_refresh(credentials):
logger.debug(
f"Refreshing '{credentials.provider}' credentials #{credentials.id}"
"Refreshing '%s' credentials #%s",
credentials.provider,
credentials.id,
)
_lock = None
if lock:
# Wait until the credentials are no longer in use anywhere
_lock = await self._acquire_lock(user_id, credentials.id)
# Wait until the credentials are no longer in use anywhere
_lock = await self._acquire_lock(user_id, credentials.id)
try:
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
credentials = fresh_credentials
finally:
if (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
except Exception:
logger.warning(
"Failed to release OAuth refresh lock",
exc_info=True,
)
return credentials
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
# Notify listeners so the refreshed token is picked up immediately.
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
if _lock and (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
except Exception:
logger.warning(
"Failed to release OAuth refresh lock",
exc_info=True,
)
async def _refresh_unlocked(
self, user_id: str, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""Best-effort token refresh without any Redis locking.
credentials = fresh_credentials
Safe for use from multi-threaded contexts (e.g. copilot workers) where
each thread has its own event loop and sharing Redis-backed asyncio locks
is not possible. Concurrent refreshes are tolerated: the last writer
wins, and stale tokens are overwritten.
"""
oauth_handler = await self._get_oauth_handler(credentials)
if oauth_handler.needs_refresh(credentials):
logger.debug(
"Refreshing '%s' credentials #%s (lock-free)",
credentials.provider,
credentials.id,
)
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
_invoke_creds_changed_hook(user_id, fresh_credentials.provider)
credentials = fresh_credentials
return credentials
async def update(self, user_id: str, updated: Credentials) -> None:
@@ -264,7 +301,6 @@ class IntegrationCredentialsManager:
async def release_all_locks(self):
"""Call this on process termination to ensure all locks are released"""
await (await self.locks()).release_all_locks()
await (await self.store.locks()).release_all_locks()