fix(backend/credentials): Handle Python 3.13 str(StrEnum) bug in OAuth state verification

verify_state_token and get_creds_by_provider compared provider strings
with ==, which failed when OAuth states were stored with the buggy
"ProviderName.MCP" format from Python 3.13's str(Enum) behavior.

Also fix double-append in store_state_token where the state was written
once via edit_user_integrations and again via a redundant manual block.
This commit is contained in:
Zamil Majdy
2026-02-10 13:32:38 +04:00
parent 8a2f98b23c
commit c03fb170e0

View File

@@ -22,6 +22,27 @@ from backend.util.settings import Settings
settings = Settings()
def _provider_matches(stored: str, expected: str) -> bool:
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
instead of ``"mcp"``. OAuth states persisted with the buggy format need
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
"""
if stored == expected:
return True
if stored.startswith("ProviderName."):
member = stored.removeprefix("ProviderName.")
from backend.integrations.providers import ProviderName
try:
return ProviderName[member].value == expected
except KeyError:
pass
return False
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
ollama_credentials = APIKeyCredentials(
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
@@ -389,7 +410,7 @@ class IntegrationCredentialsStore:
self, user_id: str, provider: str
) -> list[Credentials]:
credentials = await self.get_all_creds(user_id)
return [c for c in credentials if c.provider == provider]
return [c for c in credentials if _provider_matches(c.provider, provider)]
async def get_authorized_providers(self, user_id: str) -> list[str]:
credentials = await self.get_all_creds(user_id)
@@ -485,17 +506,6 @@ class IntegrationCredentialsStore:
async with self.edit_user_integrations(user_id) as user_integrations:
user_integrations.oauth_states.append(state)
async with await self.locked_user_integrations(user_id):
user_integrations = await self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
oauth_states.append(state)
user_integrations.oauth_states = oauth_states
await self.db_manager.update_user_integrations(
user_id=user_id, data=user_integrations
)
return token, code_challenge
def _generate_code_challenge(self) -> tuple[str, str]:
@@ -521,7 +531,7 @@ class IntegrationCredentialsStore:
state
for state in oauth_states
if secrets.compare_digest(state.token, token)
and state.provider == provider
and _provider_matches(state.provider, provider)
and state.expires_at > now.timestamp()
),
None,