improve code

This commit is contained in:
Reinier van der Leer
2025-05-15 14:06:31 +02:00
parent 863a9e98ec
commit 993e123f1b
3 changed files with 23 additions and 8 deletions

View File

@@ -274,10 +274,16 @@ class UserMetadataRaw(TypedDict, total=False):
class UserIntegrations(BaseModel):
class ManagedCredentials(TypedDict, total=False):
class ManagedCredentials(BaseModel):
"""Integration credentials managed by us, rather than by the user"""
ayrshare: APIKeyCredentials
ayrshare_profile_key: Optional[SecretStr] = None
@field_serializer("*")
def dump_secret_strings(value: Any, _info):
if isinstance(value, SecretStr):
return value.get_secret_value()
return value
managed_credentials: ManagedCredentials = Field(default_factory=ManagedCredentials)
credentials: list[Credentials] = Field(default_factory=list)

View File

@@ -124,7 +124,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
async def update_user_integrations(user_id: str, data: UserIntegrations):
encrypted_data = JSONCryptor().encrypt(data.model_dump())
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
await User.prisma().update(
where={"id": user_id},
data={"integrations": encrypted_data},

View File

@@ -224,6 +224,8 @@ class IntegrationCredentialsStore:
return get_service_client(DatabaseManagerClient)
# =============== USER-MANAGED CREDENTIALS =============== #
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_integrations(user_id):
if self.get_creds_by_id(user_id, credentials.id):
@@ -337,18 +339,23 @@ class IntegrationCredentialsStore:
]
self._set_user_integration_creds(user_id, filtered_credentials)
def get_managed_creds(self, user_id: str, provider: str) -> Optional[Credentials]:
user_integrations = self._get_user_integrations(user_id)
return user_integrations.managed_credentials.get(provider)
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
def set_managed_creds(self, user_id: str, credentials: Credentials) -> None:
def get_ayrshare_profile_key(self, user_id: str) -> SecretStr | None:
managed_user_creds = self._get_user_integrations(user_id).managed_credentials
return managed_user_creds.ayrshare_profile_key
def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
_profile_key = SecretStr(profile_key)
with self.locked_user_integrations(user_id):
user_integrations = self._get_user_integrations(user_id)
user_integrations.managed_credentials[credentials.provider] = credentials
user_integrations.managed_credentials.ayrshare_profile_key = _profile_key
self.db_manager.update_user_integrations(
user_id=user_id, data=user_integrations
)
# ===================== OAUTH STATES ===================== #
def store_state_token(
self, user_id: str, provider: str, scopes: list[str], use_pkce: bool = False
) -> tuple[str, str]:
@@ -416,6 +423,8 @@ class IntegrationCredentialsStore:
return None
# =================== GET/SET HELPERS =================== #
def _set_user_integration_creds(
self, user_id: str, credentials: list[Credentials]
) -> None: