feat(backend): Speed up graph create/update (#10025)

- Resolves #10024

Caching the repeated DB calls made by the graph lifecycle hooks significantly
speeds up graph create/update requests for graphs with many authenticated
blocks (~300ms saved per authenticated block).

### Changes 🏗️

- Add and use `IntegrationCredentialsManager.cached_getter(user_id)` in
lifecycle hooks
- Split `refresh_if_needed(..)` method out of
`IntegrationCredentialsManager.get(..)`
- Simplify interface of lifecycle hooks: change `get_credentials`
parameter to `user_id`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Save a graph with nodes with credentials
This commit is contained in:
Reinier van der Leer
2025-05-26 10:59:27 +01:00
committed by GitHub
parent 767d2f2c1e
commit 8e2fb2daa4
4 changed files with 62 additions and 76 deletions

View File

@@ -1,13 +1,13 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock
from backend.data import redis
from backend.data.model import Credentials
from backend.data.model import Credentials, OAuth2Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
@@ -78,25 +78,7 @@ class IntegrationCredentialsManager:
f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
f"current time is {datetime.now()}"
)
with self._locked(user_id, credentials_id, "refresh"):
oauth_handler = _get_provider_oauth_handler(credentials.provider)
if oauth_handler.needs_refresh(credentials):
logger.debug(
f"Refreshing '{credentials.provider}' "
f"credentials #{credentials.id}"
)
_lock = None
if lock:
# Wait until the credentials are no longer in use anywhere
_lock = self._acquire_lock(user_id, credentials_id)
fresh_credentials = oauth_handler.refresh_tokens(credentials)
self.store.update_creds(user_id, fresh_credentials)
if _lock and _lock.locked() and _lock.owned():
_lock.release()
credentials = fresh_credentials
credentials = self.refresh_if_needed(user_id, credentials, lock)
else:
logger.debug(f"Credentials #{credentials.id} never expire")
@@ -121,6 +103,50 @@ class IntegrationCredentialsManager:
)
return credentials, lock
def cached_getter(self, user_id: str) -> Callable[[str], "Credentials | None"]:
    """
    Build a credentials lookup function for `user_id` that queries the store
    at most once.

    The returned callable maps `credentials_id` -> Credentials (or `None` if
    the user has no credential with that ID). All of the user's credentials
    are fetched with a single `store.get_all_creds(user_id)` call on the
    first lookup and served from memory afterwards.

    OAuth2 credentials that carry an expiration timestamp are passed through
    `refresh_if_needed(..)` before being returned, and the result is written
    back to the cache so that later lookups see the refreshed credential.
    """
    # `None` means "not fetched yet". An empty dict is a valid cached result
    # (user has no credentials) and must NOT trigger another fetch.
    cache: "dict[str, Credentials] | None" = None

    def get_credentials(creds_id: str) -> "Credentials | None":
        nonlocal cache
        if cache is None:
            # Fetch credentials on first necessity
            cache = {}
            for item in self.store.get_all_creds(user_id):
                # First occurrence wins, like a front-to-back linear scan
                cache.setdefault(item.id, item)

        credential = cache.get(creds_id)
        if credential is None:
            return None

        if credential.type != "oauth2" or not credential.access_token_expires_at:
            # Credential doesn't expire
            return credential

        # Credential is OAuth2 credential and has expiration timestamp
        refreshed = self.refresh_if_needed(user_id, credential)
        # Keep the cache in sync, so a repeated lookup doesn't attempt to
        # refresh again based on the stale (pre-refresh) credential object
        cache[creds_id] = refreshed
        return refreshed

    return get_credentials
def refresh_if_needed(
    self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
) -> OAuth2Credentials:
    """
    Refresh the given OAuth2 credentials if the provider's OAuth handler
    reports that they need it, and persist the refreshed credentials.

    Params:
        user_id: owner of the credentials
        credentials: the OAuth2 credentials to check
        lock: if true, additionally acquire the lock for these credentials
            (via `_acquire_lock`) before refreshing

    Returns:
        The refreshed credentials if a refresh was performed, otherwise the
        `credentials` object that was passed in.
    """
    with self._locked(user_id, credentials.id, "refresh"):
        oauth_handler = _get_provider_oauth_handler(credentials.provider)
        if not oauth_handler.needs_refresh(credentials):
            return credentials

        logger.debug(
            f"Refreshing '{credentials.provider}' "
            f"credentials #{credentials.id}"
        )
        _lock = None
        if lock:
            # Wait until the credentials are no longer in use anywhere
            _lock = self._acquire_lock(user_id, credentials.id)

        try:
            fresh_credentials = oauth_handler.refresh_tokens(credentials)
            self.store.update_creds(user_id, fresh_credentials)
        finally:
            # Release even if refreshing or persisting fails, so a failed
            # refresh doesn't leave the credentials locked
            if _lock and _lock.locked() and _lock.owned():
                _lock.release()

        return fresh_credentials
def update(self, user_id: str, updated: Credentials) -> None:
    """
    Persist `updated` through `store.update_creds(..)` for the given user.

    The write happens inside the `_locked(user_id, updated.id)` context.
    """
    with self._locked(user_id, updated.id):
        self.store.update_creds(user_id, updated)

View File

@@ -1,8 +1,9 @@
import logging
from typing import TYPE_CHECKING, Callable, Optional, cast
from typing import TYPE_CHECKING, Optional, cast
from backend.data.block import BlockSchema, BlockWebhookConfig
from backend.data.graph import set_node_webhook
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
if TYPE_CHECKING:
@@ -12,21 +13,17 @@ if TYPE_CHECKING:
from ._base import BaseWebhooksManager
logger = logging.getLogger(__name__)
credentials_manager = IntegrationCredentialsManager()
async def on_graph_activate(
graph: "GraphModel", get_credentials: Callable[[str], "Credentials | None"]
):
async def on_graph_activate(graph: "GraphModel", user_id: str):
"""
Hook to be called when a graph is activated/created.
⚠️ Assuming node entities are not re-used between graph versions, ⚠️
this hook calls `on_node_activate` on all nodes in this graph.
Params:
get_credentials: `credentials_id` -> Credentials
"""
# Compare nodes in new_graph_version with previous_graph_version
get_credentials = credentials_manager.cached_getter(user_id)
updated_nodes = []
for new_node in graph.nodes:
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
@@ -56,18 +53,14 @@ async def on_graph_activate(
return graph
async def on_graph_deactivate(
graph: "GraphModel", get_credentials: Callable[[str], "Credentials | None"]
):
async def on_graph_deactivate(graph: "GraphModel", user_id: str):
"""
Hook to be called when a graph is deactivated/deleted.
⚠️ Assuming node entities are not re-used between graph versions, ⚠️
this hook calls `on_node_deactivate` on all nodes in `graph`.
Params:
get_credentials: `credentials_id` -> Credentials
"""
get_credentials = credentials_manager.cached_getter(user_id)
updated_nodes = []
for node in graph.nodes:
block_input_schema = cast(BlockSchema, node.block.input_schema)

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any, Sequence
from typing import Annotated, Any, Sequence
import pydantic
import stripe
@@ -60,7 +60,6 @@ from backend.data.user import (
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
on_graph_deactivate,
@@ -78,9 +77,6 @@ from backend.server.utils import get_user_id
from backend.util.service import get_service_client
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.data.model import Credentials
@thread_cached
def execution_scheduler_client() -> scheduler.SchedulerClient:
@@ -101,7 +97,6 @@ def execution_event_bus() -> AsyncRedisExecutionEventBus:
settings = Settings()
logger = logging.getLogger(__name__)
integration_creds_manager = IntegrationCredentialsManager()
_user_credit_model = get_user_credit_model()
@@ -466,10 +461,7 @@ async def create_new_graph(
library_db.add_generated_agent_image(graph, library_agent.id)
)
graph = await on_graph_activate(
graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
)
graph = await on_graph_activate(graph, user_id=user_id)
return graph
@@ -480,11 +472,7 @@ async def delete_graph(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)
await on_graph_deactivate(active_version, get_credentials)
await on_graph_deactivate(active_version, user_id=user_id)
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
@@ -521,24 +509,15 @@ async def update_graph(
user_id, graph.id, graph.version
)
def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)
# Handle activation of the new graph first to ensure continuity
new_graph_version = await on_graph_activate(
new_graph_version,
get_credentials=get_credentials,
)
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
# Ensure new version is the only active version
await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
)
if current_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(
current_active_version,
get_credentials=get_credentials,
)
await on_graph_deactivate(current_active_version, user_id=user_id)
return new_graph_version
@@ -562,14 +541,8 @@ async def set_graph_active_version(
current_active_graph = await graph_db.get_graph(graph_id, user_id=user_id)
def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)
# Handle activation of the new graph first to ensure continuity
await on_graph_activate(
new_active_graph,
get_credentials=get_credentials,
)
await on_graph_activate(new_active_graph, user_id=user_id)
# Ensure new version is the only active version
await graph_db.set_graph_active_version(
graph_id=graph_id,
@@ -584,10 +557,7 @@ async def set_graph_active_version(
if current_active_graph and current_active_graph.version != new_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(
current_active_graph,
get_credentials=get_credentials,
)
await on_graph_deactivate(current_active_graph, user_id=user_id)
@v1_router.post(

View File

@@ -736,10 +736,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
new_graph = await graph_db.fork_graph(
original_agent.graph_id, original_agent.graph_version, user_id
)
new_graph = await on_graph_activate(
new_graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
)
new_graph = await on_graph_activate(new_graph, user_id=user_id)
# Create a library agent for the new graph
return await create_library_agent(new_graph, user_id)