mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-10 06:45:28 -05:00
fix(backend/credentials): Handle Python 3.13 str(StrEnum) bug in OAuth state verification
verify_state_token and get_creds_by_provider compared provider strings with ==, which failed when OAuth states were stored with the buggy "ProviderName.MCP" format from Python 3.13's str(Enum) behavior. Also fix double-append in store_state_token where the state was written once via edit_user_integrations and again via a redundant manual block.
This commit is contained in:
@@ -22,6 +22,27 @@ from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def _provider_matches(stored: str, expected: str) -> bool:
|
||||
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
|
||||
|
||||
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
|
||||
instead of ``"mcp"``. OAuth states persisted with the buggy format need
|
||||
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
|
||||
"""
|
||||
if stored == expected:
|
||||
return True
|
||||
if stored.startswith("ProviderName."):
|
||||
member = stored.removeprefix("ProviderName.")
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
try:
|
||||
return ProviderName[member].value == expected
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
||||
ollama_credentials = APIKeyCredentials(
|
||||
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
||||
@@ -389,7 +410,7 @@ class IntegrationCredentialsStore:
|
||||
self, user_id: str, provider: str
|
||||
) -> list[Credentials]:
|
||||
credentials = await self.get_all_creds(user_id)
|
||||
return [c for c in credentials if c.provider == provider]
|
||||
return [c for c in credentials if _provider_matches(c.provider, provider)]
|
||||
|
||||
async def get_authorized_providers(self, user_id: str) -> list[str]:
|
||||
credentials = await self.get_all_creds(user_id)
|
||||
@@ -485,17 +506,6 @@ class IntegrationCredentialsStore:
|
||||
async with self.edit_user_integrations(user_id) as user_integrations:
|
||||
user_integrations.oauth_states.append(state)
|
||||
|
||||
async with await self.locked_user_integrations(user_id):
|
||||
|
||||
user_integrations = await self._get_user_integrations(user_id)
|
||||
oauth_states = user_integrations.oauth_states
|
||||
oauth_states.append(state)
|
||||
user_integrations.oauth_states = oauth_states
|
||||
|
||||
await self.db_manager.update_user_integrations(
|
||||
user_id=user_id, data=user_integrations
|
||||
)
|
||||
|
||||
return token, code_challenge
|
||||
|
||||
def _generate_code_challenge(self) -> tuple[str, str]:
|
||||
@@ -521,7 +531,7 @@ class IntegrationCredentialsStore:
|
||||
state
|
||||
for state in oauth_states
|
||||
if secrets.compare_digest(state.token, token)
|
||||
and state.provider == provider
|
||||
and _provider_matches(state.provider, provider)
|
||||
and state.expires_at > now.timestamp()
|
||||
),
|
||||
None,
|
||||
|
||||
Reference in New Issue
Block a user