mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
feat(backend): Speed up graph create/update (#10025)
- Resolves #10024 Caching the repeated DB calls by the graph lifecycle hooks significantly speeds up graph update/create calls with many authenticated blocks (~300ms saved per authenticated block) ### Changes 🏗️ - Add and use `IntegrationCredentialsManager.cached_getter(user_id)` in lifecycle hooks - Split `refresh_if_needed(..)` method out of `IntegrationCredentialsManager.get(..)` - Simplify interface of lifecycle hooks: change `get_credentials` parameter to `user_id` ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Save a graph with nodes with credentials
This commit is contained in:
committed by
GitHub
parent
767d2f2c1e
commit
8e2fb2daa4
@@ -1,13 +1,13 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from autogpt_libs.utils.synchronize import RedisKeyedMutex
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.model import Credentials
|
||||
from backend.data.model import Credentials, OAuth2Credentials
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -78,25 +78,7 @@ class IntegrationCredentialsManager:
|
||||
f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
|
||||
f"current time is {datetime.now()}"
|
||||
)
|
||||
|
||||
with self._locked(user_id, credentials_id, "refresh"):
|
||||
oauth_handler = _get_provider_oauth_handler(credentials.provider)
|
||||
if oauth_handler.needs_refresh(credentials):
|
||||
logger.debug(
|
||||
f"Refreshing '{credentials.provider}' "
|
||||
f"credentials #{credentials.id}"
|
||||
)
|
||||
_lock = None
|
||||
if lock:
|
||||
# Wait until the credentials are no longer in use anywhere
|
||||
_lock = self._acquire_lock(user_id, credentials_id)
|
||||
|
||||
fresh_credentials = oauth_handler.refresh_tokens(credentials)
|
||||
self.store.update_creds(user_id, fresh_credentials)
|
||||
if _lock and _lock.locked() and _lock.owned():
|
||||
_lock.release()
|
||||
|
||||
credentials = fresh_credentials
|
||||
credentials = self.refresh_if_needed(user_id, credentials, lock)
|
||||
else:
|
||||
logger.debug(f"Credentials #{credentials.id} never expire")
|
||||
|
||||
@@ -121,6 +103,50 @@ class IntegrationCredentialsManager:
|
||||
)
|
||||
return credentials, lock
|
||||
|
||||
def cached_getter(self, user_id: str) -> Callable[[str], "Credentials | None"]:
|
||||
all_credentials = None
|
||||
|
||||
def get_credentials(creds_id: str) -> "Credentials | None":
|
||||
nonlocal all_credentials
|
||||
if not all_credentials:
|
||||
# Fetch credentials on first necessity
|
||||
all_credentials = self.store.get_all_creds(user_id)
|
||||
|
||||
credential = next((c for c in all_credentials if c.id == creds_id), None)
|
||||
if not credential:
|
||||
return None
|
||||
if credential.type != "oauth2" or not credential.access_token_expires_at:
|
||||
# Credential doesn't expire
|
||||
return credential
|
||||
|
||||
# Credential is OAuth2 credential and has expiration timestamp
|
||||
return self.refresh_if_needed(user_id, credential)
|
||||
|
||||
return get_credentials
|
||||
|
||||
def refresh_if_needed(
|
||||
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
||||
) -> OAuth2Credentials:
|
||||
with self._locked(user_id, credentials.id, "refresh"):
|
||||
oauth_handler = _get_provider_oauth_handler(credentials.provider)
|
||||
if oauth_handler.needs_refresh(credentials):
|
||||
logger.debug(
|
||||
f"Refreshing '{credentials.provider}' "
|
||||
f"credentials #{credentials.id}"
|
||||
)
|
||||
_lock = None
|
||||
if lock:
|
||||
# Wait until the credentials are no longer in use anywhere
|
||||
_lock = self._acquire_lock(user_id, credentials.id)
|
||||
|
||||
fresh_credentials = oauth_handler.refresh_tokens(credentials)
|
||||
self.store.update_creds(user_id, fresh_credentials)
|
||||
if _lock and _lock.locked() and _lock.owned():
|
||||
_lock.release()
|
||||
|
||||
credentials = fresh_credentials
|
||||
return credentials
|
||||
|
||||
def update(self, user_id: str, updated: Credentials) -> None:
|
||||
with self._locked(user_id, updated.id):
|
||||
self.store.update_creds(user_id, updated)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable, Optional, cast
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from backend.data.block import BlockSchema, BlockWebhookConfig
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -12,21 +13,17 @@ if TYPE_CHECKING:
|
||||
from ._base import BaseWebhooksManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def on_graph_activate(
|
||||
graph: "GraphModel", get_credentials: Callable[[str], "Credentials | None"]
|
||||
):
|
||||
async def on_graph_activate(graph: "GraphModel", user_id: str):
|
||||
"""
|
||||
Hook to be called when a graph is activated/created.
|
||||
|
||||
⚠️ Assuming node entities are not re-used between graph versions, ⚠️
|
||||
this hook calls `on_node_activate` on all nodes in this graph.
|
||||
|
||||
Params:
|
||||
get_credentials: `credentials_id` -> Credentials
|
||||
"""
|
||||
# Compare nodes in new_graph_version with previous_graph_version
|
||||
get_credentials = credentials_manager.cached_getter(user_id)
|
||||
updated_nodes = []
|
||||
for new_node in graph.nodes:
|
||||
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
|
||||
@@ -56,18 +53,14 @@ async def on_graph_activate(
|
||||
return graph
|
||||
|
||||
|
||||
async def on_graph_deactivate(
|
||||
graph: "GraphModel", get_credentials: Callable[[str], "Credentials | None"]
|
||||
):
|
||||
async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
"""
|
||||
Hook to be called when a graph is deactivated/deleted.
|
||||
|
||||
⚠️ Assuming node entities are not re-used between graph versions, ⚠️
|
||||
this hook calls `on_node_deactivate` on all nodes in `graph`.
|
||||
|
||||
Params:
|
||||
get_credentials: `credentials_id` -> Credentials
|
||||
"""
|
||||
get_credentials = credentials_manager.cached_getter(user_id)
|
||||
updated_nodes = []
|
||||
for node in graph.nodes:
|
||||
block_input_schema = cast(BlockSchema, node.block.input_schema)
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Sequence
|
||||
from typing import Annotated, Any, Sequence
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
@@ -60,7 +60,6 @@ from backend.data.user import (
|
||||
from backend.executor import scheduler
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
on_graph_deactivate,
|
||||
@@ -78,9 +77,6 @@ from backend.server.utils import get_user_id
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.model import Credentials
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> scheduler.SchedulerClient:
|
||||
@@ -101,7 +97,6 @@ def execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
integration_creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
@@ -466,10 +461,7 @@ async def create_new_graph(
|
||||
library_db.add_generated_agent_image(graph, library_agent.id)
|
||||
)
|
||||
|
||||
graph = await on_graph_activate(
|
||||
graph,
|
||||
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
|
||||
)
|
||||
graph = await on_graph_activate(graph, user_id=user_id)
|
||||
return graph
|
||||
|
||||
|
||||
@@ -480,11 +472,7 @@ async def delete_graph(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> DeleteGraphResponse:
|
||||
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
|
||||
|
||||
def get_credentials(credentials_id: str) -> "Credentials | None":
|
||||
return integration_creds_manager.get(user_id, credentials_id)
|
||||
|
||||
await on_graph_deactivate(active_version, get_credentials)
|
||||
await on_graph_deactivate(active_version, user_id=user_id)
|
||||
|
||||
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
|
||||
|
||||
@@ -521,24 +509,15 @@ async def update_graph(
|
||||
user_id, graph.id, graph.version
|
||||
)
|
||||
|
||||
def get_credentials(credentials_id: str) -> "Credentials | None":
|
||||
return integration_creds_manager.get(user_id, credentials_id)
|
||||
|
||||
# Handle activation of the new graph first to ensure continuity
|
||||
new_graph_version = await on_graph_activate(
|
||||
new_graph_version,
|
||||
get_credentials=get_credentials,
|
||||
)
|
||||
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
||||
# Ensure new version is the only active version
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||
)
|
||||
if current_active_version:
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(
|
||||
current_active_version,
|
||||
get_credentials=get_credentials,
|
||||
)
|
||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||
|
||||
return new_graph_version
|
||||
|
||||
@@ -562,14 +541,8 @@ async def set_graph_active_version(
|
||||
|
||||
current_active_graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
|
||||
def get_credentials(credentials_id: str) -> "Credentials | None":
|
||||
return integration_creds_manager.get(user_id, credentials_id)
|
||||
|
||||
# Handle activation of the new graph first to ensure continuity
|
||||
await on_graph_activate(
|
||||
new_active_graph,
|
||||
get_credentials=get_credentials,
|
||||
)
|
||||
await on_graph_activate(new_active_graph, user_id=user_id)
|
||||
# Ensure new version is the only active version
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id,
|
||||
@@ -584,10 +557,7 @@ async def set_graph_active_version(
|
||||
|
||||
if current_active_graph and current_active_graph.version != new_active_version:
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(
|
||||
current_active_graph,
|
||||
get_credentials=get_credentials,
|
||||
)
|
||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
|
||||
@@ -736,10 +736,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
|
||||
new_graph = await graph_db.fork_graph(
|
||||
original_agent.graph_id, original_agent.graph_version, user_id
|
||||
)
|
||||
new_graph = await on_graph_activate(
|
||||
new_graph,
|
||||
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
|
||||
)
|
||||
new_graph = await on_graph_activate(new_graph, user_id=user_id)
|
||||
|
||||
# Create a library agent for the new graph
|
||||
return await create_library_agent(new_graph, user_id)
|
||||
|
||||
Reference in New Issue
Block a user