mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-19 20:18:22 -05:00
Compare commits
3 Commits
fix/undefi
...
pwuts/ayrs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d90376eb2 | ||
|
|
993e123f1b | ||
|
|
863a9e98ec |
@@ -15,7 +15,6 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
get_args,
|
||||
)
|
||||
@@ -37,6 +36,7 @@ from pydantic_core import (
|
||||
ValidationError,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
@@ -260,15 +260,32 @@ class OAuthState(BaseModel):
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
integration_credentials: list[Credentials] = Field(default_factory=list)
|
||||
"""⚠️ Deprecated; use `UserIntegrations.credentials` instead"""
|
||||
integration_oauth_states: list[OAuthState] = Field(default_factory=list)
|
||||
"""⚠️ Deprecated; use `UserIntegrations.oauth_states` instead"""
|
||||
|
||||
|
||||
class UserMetadataRaw(TypedDict, total=False):
|
||||
integration_credentials: list[dict]
|
||||
"""⚠️ Deprecated; use `UserIntegrations.credentials` instead"""
|
||||
integration_oauth_states: list[dict]
|
||||
"""⚠️ Deprecated; use `UserIntegrations.oauth_states` instead"""
|
||||
|
||||
|
||||
class UserIntegrations(BaseModel):
|
||||
|
||||
class ManagedCredentials(BaseModel):
|
||||
"""Integration credentials managed by us, rather than by the user"""
|
||||
|
||||
ayrshare_profile_key: Optional[SecretStr] = None
|
||||
|
||||
@field_serializer("*")
|
||||
def dump_secret_strings(value: Any, _info):
|
||||
if isinstance(value, SecretStr):
|
||||
return value.get_secret_value()
|
||||
return value
|
||||
|
||||
managed_credentials: ManagedCredentials = Field(default_factory=ManagedCredentials)
|
||||
credentials: list[Credentials] = Field(default_factory=list)
|
||||
oauth_states: list[OAuthState] = Field(default_factory=list)
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
|
||||
|
||||
|
||||
async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||
encrypted_data = JSONCryptor().encrypt(data.model_dump())
|
||||
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
|
||||
await User.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"integrations": encrypted_data},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
@@ -224,6 +225,8 @@ class IntegrationCredentialsStore:
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
# =============== USER-MANAGED CREDENTIALS =============== #
|
||||
|
||||
def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
with self.locked_user_integrations(user_id):
|
||||
if self.get_creds_by_id(user_id, credentials.id):
|
||||
@@ -337,6 +340,19 @@ class IntegrationCredentialsStore:
|
||||
]
|
||||
self._set_user_integration_creds(user_id, filtered_credentials)
|
||||
|
||||
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
|
||||
|
||||
def get_ayrshare_profile_key(self, user_id: str) -> SecretStr | None:
|
||||
managed_user_creds = self._get_user_integrations(user_id).managed_credentials
|
||||
return managed_user_creds.ayrshare_profile_key
|
||||
|
||||
def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
|
||||
_profile_key = SecretStr(profile_key)
|
||||
with self.edit_user_integrations(user_id) as user_integrations:
|
||||
user_integrations.managed_credentials.ayrshare_profile_key = _profile_key
|
||||
|
||||
# ===================== OAUTH STATES ===================== #
|
||||
|
||||
def store_state_token(
|
||||
self, user_id: str, provider: str, scopes: list[str], use_pkce: bool = False
|
||||
) -> tuple[str, str]:
|
||||
@@ -353,16 +369,8 @@ class IntegrationCredentialsStore:
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
with self.locked_user_integrations(user_id):
|
||||
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
oauth_states = user_integrations.oauth_states
|
||||
oauth_states.append(state)
|
||||
user_integrations.oauth_states = oauth_states
|
||||
|
||||
self.db_manager.update_user_integrations(
|
||||
user_id=user_id, data=user_integrations
|
||||
)
|
||||
with self.edit_user_integrations(user_id) as user_integrations:
|
||||
user_integrations.oauth_states.append(state)
|
||||
|
||||
return token, code_challenge
|
||||
|
||||
@@ -404,6 +412,17 @@ class IntegrationCredentialsStore:
|
||||
|
||||
return None
|
||||
|
||||
# =================== GET/SET HELPERS =================== #
|
||||
|
||||
@contextmanager
|
||||
def edit_user_integrations(self, user_id: str):
|
||||
with self.locked_user_integrations(user_id):
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
yield user_integrations # yield to allow edits
|
||||
self.db_manager.update_user_integrations(
|
||||
user_id=user_id, data=user_integrations
|
||||
)
|
||||
|
||||
def _set_user_integration_creds(
|
||||
self, user_id: str, credentials: list[Credentials]
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user