Compare commits

...

3 Commits

Author SHA1 Message Date
Reinier van der Leer
7d90376eb2 simplify code 2025-05-15 14:57:56 +02:00
Reinier van der Leer
993e123f1b improve code 2025-05-15 14:06:31 +02:00
Reinier van der Leer
863a9e98ec feat(backend): Managed credentials store 2025-05-15 13:18:23 +02:00
3 changed files with 48 additions and 12 deletions

View File

@@ -15,7 +15,6 @@ from typing import (
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
get_args,
)
@@ -37,6 +36,7 @@ from pydantic_core import (
ValidationError,
core_schema,
)
from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
@@ -260,15 +260,32 @@ class OAuthState(BaseModel):
class UserMetadata(BaseModel):
integration_credentials: list[Credentials] = Field(default_factory=list)
"""⚠️ Deprecated; use `UserIntegrations.credentials` instead"""
integration_oauth_states: list[OAuthState] = Field(default_factory=list)
"""⚠️ Deprecated; use `UserIntegrations.oauth_states` instead"""
class UserMetadataRaw(TypedDict, total=False):
integration_credentials: list[dict]
"""⚠️ Deprecated; use `UserIntegrations.credentials` instead"""
integration_oauth_states: list[dict]
"""⚠️ Deprecated; use `UserIntegrations.oauth_states` instead"""
class UserIntegrations(BaseModel):
class ManagedCredentials(BaseModel):
"""Integration credentials managed by us, rather than by the user"""
ayrshare_profile_key: Optional[SecretStr] = None
@field_serializer("*")
def dump_secret_strings(value: Any, _info):
if isinstance(value, SecretStr):
return value.get_secret_value()
return value
managed_credentials: ManagedCredentials = Field(default_factory=ManagedCredentials)
credentials: list[Credentials] = Field(default_factory=list)
oauth_states: list[OAuthState] = Field(default_factory=list)

View File

@@ -124,7 +124,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
async def update_user_integrations(user_id: str, data: UserIntegrations):
encrypted_data = JSONCryptor().encrypt(data.model_dump())
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
await User.prisma().update(
where={"id": user_id},
data={"integrations": encrypted_data},

View File

@@ -1,6 +1,7 @@
import base64
import hashlib
import secrets
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Optional
@@ -224,6 +225,8 @@ class IntegrationCredentialsStore:
return get_service_client(DatabaseManagerClient)
# =============== USER-MANAGED CREDENTIALS =============== #
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_integrations(user_id):
if self.get_creds_by_id(user_id, credentials.id):
@@ -337,6 +340,19 @@ class IntegrationCredentialsStore:
]
self._set_user_integration_creds(user_id, filtered_credentials)
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
def get_ayrshare_profile_key(self, user_id: str) -> SecretStr | None:
managed_user_creds = self._get_user_integrations(user_id).managed_credentials
return managed_user_creds.ayrshare_profile_key
def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
_profile_key = SecretStr(profile_key)
with self.edit_user_integrations(user_id) as user_integrations:
user_integrations.managed_credentials.ayrshare_profile_key = _profile_key
# ===================== OAUTH STATES ===================== #
def store_state_token(
self, user_id: str, provider: str, scopes: list[str], use_pkce: bool = False
) -> tuple[str, str]:
@@ -353,16 +369,8 @@ class IntegrationCredentialsStore:
scopes=scopes,
)
with self.locked_user_integrations(user_id):
user_integrations = self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
oauth_states.append(state)
user_integrations.oauth_states = oauth_states
self.db_manager.update_user_integrations(
user_id=user_id, data=user_integrations
)
with self.edit_user_integrations(user_id) as user_integrations:
user_integrations.oauth_states.append(state)
return token, code_challenge
@@ -404,6 +412,17 @@ class IntegrationCredentialsStore:
return None
# =================== GET/SET HELPERS =================== #
@contextmanager
def edit_user_integrations(self, user_id: str):
with self.locked_user_integrations(user_id):
user_integrations = self._get_user_integrations(user_id)
yield user_integrations # yield to allow edits
self.db_manager.update_user_integrations(
user_id=user_id, data=user_integrations
)
def _set_user_integration_creds(
self, user_id: str, credentials: list[Credentials]
) -> None: