diff --git a/autogpt_platform/backend/backend/integrations/creds_manager.py b/autogpt_platform/backend/backend/integrations/creds_manager.py index 3abd90d90f..eb5e132503 100644 --- a/autogpt_platform/backend/backend/integrations/creds_manager.py +++ b/autogpt_platform/backend/backend/integrations/creds_manager.py @@ -1,13 +1,13 @@ import logging from contextlib import contextmanager from datetime import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from autogpt_libs.utils.synchronize import RedisKeyedMutex from redis.lock import Lock as RedisLock from backend.data import redis -from backend.data.model import Credentials +from backend.data.model import Credentials, OAuth2Credentials from backend.integrations.credentials_store import IntegrationCredentialsStore from backend.integrations.oauth import HANDLERS_BY_NAME from backend.integrations.providers import ProviderName @@ -78,25 +78,7 @@ class IntegrationCredentialsManager: f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; " f"current time is {datetime.now()}" ) - - with self._locked(user_id, credentials_id, "refresh"): - oauth_handler = _get_provider_oauth_handler(credentials.provider) - if oauth_handler.needs_refresh(credentials): - logger.debug( - f"Refreshing '{credentials.provider}' " - f"credentials #{credentials.id}" - ) - _lock = None - if lock: - # Wait until the credentials are no longer in use anywhere - _lock = self._acquire_lock(user_id, credentials_id) - - fresh_credentials = oauth_handler.refresh_tokens(credentials) - self.store.update_creds(user_id, fresh_credentials) - if _lock and _lock.locked() and _lock.owned(): - _lock.release() - - credentials = fresh_credentials + credentials = self.refresh_if_needed(user_id, credentials, lock) else: logger.debug(f"Credentials #{credentials.id} never expire") @@ -121,6 +103,50 @@ class IntegrationCredentialsManager: ) return credentials, lock + def cached_getter(self, user_id: str) -> Callable[[str], "Credentials | None"]: + all_credentials = None + + def get_credentials(creds_id: str) -> "Credentials | None": + nonlocal all_credentials + if not all_credentials: + # Fetch credentials on first necessity + all_credentials = self.store.get_all_creds(user_id) + + credential = next((c for c in all_credentials if c.id == creds_id), None) + if not credential: + return None + if credential.type != "oauth2" or not credential.access_token_expires_at: + # Credential doesn't expire + return credential + + # Credential is OAuth2 credential and has expiration timestamp + return self.refresh_if_needed(user_id, credential) + + return get_credentials + + def refresh_if_needed( + self, user_id: str, credentials: OAuth2Credentials, lock: bool = True + ) -> OAuth2Credentials: + with self._locked(user_id, credentials.id, "refresh"): + oauth_handler = _get_provider_oauth_handler(credentials.provider) + if oauth_handler.needs_refresh(credentials): + logger.debug( + f"Refreshing '{credentials.provider}' " + f"credentials #{credentials.id}" + ) + _lock = None + if lock: + # Wait until the credentials are no longer in use anywhere + _lock = self._acquire_lock(user_id, credentials.id) + + fresh_credentials = oauth_handler.refresh_tokens(credentials) + self.store.update_creds(user_id, fresh_credentials) + if _lock and _lock.locked() and _lock.owned(): + _lock.release() + + credentials = fresh_credentials + return credentials + def update(self, user_id: str, updated: Credentials) -> None: with self._locked(user_id, updated.id): self.store.update_creds(user_id, updated) diff --git a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py index 5c4d1285b2..898c6772bb 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py @@ -1,8 +1,9 @@ import logging -from typing import TYPE_CHECKING, Callable, Optional, cast +from typing import TYPE_CHECKING, Optional, cast from backend.data.block import BlockSchema, BlockWebhookConfig from backend.data.graph import set_node_webhook +from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.webhooks import get_webhook_manager, supports_webhooks if TYPE_CHECKING: @@ -12,21 +13,17 @@ if TYPE_CHECKING: from ._base import BaseWebhooksManager logger = logging.getLogger(__name__) +credentials_manager = IntegrationCredentialsManager() -async def on_graph_activate( - graph: "GraphModel", get_credentials: Callable[[str], "Credentials | None"] -): +async def on_graph_activate(graph: "GraphModel", user_id: str): """ Hook to be called when a graph is activated/created. ⚠️ Assuming node entities are not re-used between graph versions, ⚠️ this hook calls `on_node_activate` on all nodes in this graph. - - Params: - get_credentials: `credentials_id` -> Credentials """ - # Compare nodes in new_graph_version with previous_graph_version + get_credentials = credentials_manager.cached_getter(user_id) updated_nodes = [] for new_node in graph.nodes: block_input_schema = cast(BlockSchema, new_node.block.input_schema) @@ -56,18 +53,14 @@ async def on_graph_activate( return graph -async def on_graph_deactivate( - graph: "GraphModel", get_credentials: Callable[[str], "Credentials | None"] -): +async def on_graph_deactivate(graph: "GraphModel", user_id: str): """ Hook to be called when a graph is deactivated/deleted. ⚠️ Assuming node entities are not re-used between graph versions, ⚠️ this hook calls `on_node_deactivate` on all nodes in `graph`. - - Params: - get_credentials: `credentials_id` -> Credentials """ + get_credentials = credentials_manager.cached_getter(user_id) updated_nodes = [] for node in graph.nodes: block_input_schema = cast(BlockSchema, node.block.input_schema) diff --git a/autogpt_platform/backend/backend/server/routers/v1.py b/autogpt_platform/backend/backend/server/routers/v1.py index e403270aea..64815baa92 100644 --- a/autogpt_platform/backend/backend/server/routers/v1.py +++ b/autogpt_platform/backend/backend/server/routers/v1.py @@ -2,7 +2,7 @@ import asyncio import logging from collections import defaultdict from datetime import datetime -from typing import TYPE_CHECKING, Annotated, Any, Sequence +from typing import Annotated, Any, Sequence import pydantic import stripe @@ -60,7 +60,6 @@ from backend.data.user import ( from backend.executor import scheduler from backend.executor import utils as execution_utils from backend.executor.utils import create_execution_queue_config -from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.webhooks.graph_lifecycle_hooks import ( on_graph_activate, on_graph_deactivate, @@ -78,9 +77,6 @@ from backend.server.utils import get_user_id from backend.util.service import get_service_client from backend.util.settings import Settings -if TYPE_CHECKING: - from backend.data.model import Credentials - @thread_cached def execution_scheduler_client() -> scheduler.SchedulerClient: @@ -101,7 +97,6 @@ def execution_event_bus() -> AsyncRedisExecutionEventBus: settings = Settings() logger = logging.getLogger(__name__) -integration_creds_manager = IntegrationCredentialsManager() _user_credit_model = get_user_credit_model() @@ -466,10 +461,7 @@ async def create_new_graph( library_db.add_generated_agent_image(graph, library_agent.id) ) - graph = await on_graph_activate( - graph, - get_credentials=lambda id: integration_creds_manager.get(user_id, id), - ) + graph = await on_graph_activate(graph, user_id=user_id) return graph @@ -480,11 +472,7 @@ async def delete_graph( graph_id: str, user_id: Annotated[str, Depends(get_user_id)] ) -> DeleteGraphResponse: if active_version := await graph_db.get_graph(graph_id, user_id=user_id): - - def get_credentials(credentials_id: str) -> "Credentials | None": - return integration_creds_manager.get(user_id, credentials_id) - - await on_graph_deactivate(active_version, get_credentials) + await on_graph_deactivate(active_version, user_id=user_id) return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)} @@ -521,24 +509,15 @@ async def update_graph( user_id, graph.id, graph.version ) - def get_credentials(credentials_id: str) -> "Credentials | None": - return integration_creds_manager.get(user_id, credentials_id) - # Handle activation of the new graph first to ensure continuity - new_graph_version = await on_graph_activate( - new_graph_version, - get_credentials=get_credentials, - ) + new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id) # Ensure new version is the only active version await graph_db.set_graph_active_version( graph_id=graph_id, version=new_graph_version.version, user_id=user_id ) if current_active_version: # Handle deactivation of the previously active version - await on_graph_deactivate( - current_active_version, - get_credentials=get_credentials, - ) + await on_graph_deactivate(current_active_version, user_id=user_id) return new_graph_version @@ -562,14 +541,8 @@ async def set_graph_active_version( current_active_graph = await graph_db.get_graph(graph_id, user_id=user_id) - def get_credentials(credentials_id: str) -> "Credentials | None": - return integration_creds_manager.get(user_id, credentials_id) - # Handle activation of the new graph first to ensure continuity - await on_graph_activate( - new_active_graph, - get_credentials=get_credentials, - ) + await on_graph_activate(new_active_graph, user_id=user_id) # Ensure new version is the only active version await graph_db.set_graph_active_version( graph_id=graph_id, @@ -584,10 +557,7 @@ async def set_graph_active_version( if current_active_graph and current_active_graph.version != new_active_version: # Handle deactivation of the previously active version - await on_graph_deactivate( - current_active_graph, - get_credentials=get_credentials, - ) + await on_graph_deactivate(current_active_graph, user_id=user_id) @v1_router.post( diff --git a/autogpt_platform/backend/backend/server/v2/library/db.py b/autogpt_platform/backend/backend/server/v2/library/db.py index c5b240d3b7..417ae0c9e7 100644 --- a/autogpt_platform/backend/backend/server/v2/library/db.py +++ b/autogpt_platform/backend/backend/server/v2/library/db.py @@ -736,10 +736,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str): new_graph = await graph_db.fork_graph( original_agent.graph_id, original_agent.graph_version, user_id ) - new_graph = await on_graph_activate( - new_graph, - get_credentials=lambda id: integration_creds_manager.get(user_id, id), - ) + new_graph = await on_graph_activate(new_graph, user_id=user_id) # Create a library agent for the new graph return await create_library_agent(new_graph, user_id)