diff --git a/autogpt_platform/backend/backend/data/user.py b/autogpt_platform/backend/backend/data/user.py index cb8d6093f2..bb7458dc7b 100644 --- a/autogpt_platform/backend/backend/data/user.py +++ b/autogpt_platform/backend/backend/data/user.py @@ -94,3 +94,43 @@ async def update_user_integrations(user_id: str, data: UserIntegrations): where={"id": user_id}, data={"integrations": encrypted_data}, ) + + +async def migrate_and_encrypt_user_integrations(): + """Migrate integration credentials and OAuth states from metadata to integrations column.""" + users = await User.prisma().find_many( + where={ + "metadata": { + "path": ["integration_credentials"], + "not": Json({"a": "yolo"}), # bogus value works to check if key exists + } # type: ignore + } + ) + logger.info(f"Migrating integration credentials for {len(users)} users") + + for user in users: + raw_metadata = cast(UserMetadataRaw, user.metadata) + metadata = UserMetadata.model_validate(raw_metadata) + + # Get existing integrations data + integrations = await get_user_integrations(user_id=user.id) + + # Copy credentials and oauth states from metadata if they exist + if metadata.integration_credentials and not integrations.credentials: + integrations.credentials = metadata.integration_credentials + if metadata.integration_oauth_states: + integrations.oauth_states = metadata.integration_oauth_states + + # Save to integrations column + await update_user_integrations(user_id=user.id, data=integrations) + + # Remove from metadata + raw_metadata = dict(raw_metadata) + raw_metadata.pop("integration_credentials", None) + raw_metadata.pop("integration_oauth_states", None) + + # Update metadata without integration data + await User.prisma().update( + where={"id": user.id}, + data={"metadata": Json(raw_metadata)}, + ) diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py index 951da83839..27c6679cc4 100644 --- a/autogpt_platform/backend/backend/server/rest_api.py +++ b/autogpt_platform/backend/backend/server/rest_api.py @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) async def lifespan_context(app: fastapi.FastAPI): await backend.data.db.connect() await backend.data.block.initialize_blocks() + await backend.data.user.migrate_and_encrypt_user_integrations() yield await backend.data.db.disconnect()