mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
improve code
This commit is contained in:
@@ -274,10 +274,16 @@ class UserMetadataRaw(TypedDict, total=False):
|
||||
|
||||
class UserIntegrations(BaseModel):
|
||||
|
||||
class ManagedCredentials(TypedDict, total=False):
|
||||
class ManagedCredentials(BaseModel):
|
||||
"""Integration credentials managed by us, rather than by the user"""
|
||||
|
||||
ayrshare: APIKeyCredentials
|
||||
ayrshare_profile_key: Optional[SecretStr] = None
|
||||
|
||||
@field_serializer("*")
|
||||
def dump_secret_strings(value: Any, _info):
|
||||
if isinstance(value, SecretStr):
|
||||
return value.get_secret_value()
|
||||
return value
|
||||
|
||||
managed_credentials: ManagedCredentials = Field(default_factory=ManagedCredentials)
|
||||
credentials: list[Credentials] = Field(default_factory=list)
|
||||
|
||||
@@ -124,7 +124,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
|
||||
|
||||
|
||||
async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||
encrypted_data = JSONCryptor().encrypt(data.model_dump())
|
||||
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
|
||||
await User.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"integrations": encrypted_data},
|
||||
|
||||
@@ -224,6 +224,8 @@ class IntegrationCredentialsStore:
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
# =============== USER-MANAGED CREDENTIALS =============== #
|
||||
|
||||
def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
with self.locked_user_integrations(user_id):
|
||||
if self.get_creds_by_id(user_id, credentials.id):
|
||||
@@ -337,18 +339,23 @@ class IntegrationCredentialsStore:
|
||||
]
|
||||
self._set_user_integration_creds(user_id, filtered_credentials)
|
||||
|
||||
def get_managed_creds(self, user_id: str, provider: str) -> Optional[Credentials]:
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
return user_integrations.managed_credentials.get(provider)
|
||||
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
|
||||
|
||||
def set_managed_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
def get_ayrshare_profile_key(self, user_id: str) -> SecretStr | None:
|
||||
managed_user_creds = self._get_user_integrations(user_id).managed_credentials
|
||||
return managed_user_creds.ayrshare_profile_key
|
||||
|
||||
def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
|
||||
_profile_key = SecretStr(profile_key)
|
||||
with self.locked_user_integrations(user_id):
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
user_integrations.managed_credentials[credentials.provider] = credentials
|
||||
user_integrations.managed_credentials.ayrshare_profile_key = _profile_key
|
||||
self.db_manager.update_user_integrations(
|
||||
user_id=user_id, data=user_integrations
|
||||
)
|
||||
|
||||
# ===================== OAUTH STATES ===================== #
|
||||
|
||||
def store_state_token(
|
||||
self, user_id: str, provider: str, scopes: list[str], use_pkce: bool = False
|
||||
) -> tuple[str, str]:
|
||||
@@ -416,6 +423,8 @@ class IntegrationCredentialsStore:
|
||||
|
||||
return None
|
||||
|
||||
# =================== GET/SET HELPERS =================== #
|
||||
|
||||
def _set_user_integration_creds(
|
||||
self, user_id: str, credentials: list[Credentials]
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user