feat(backend): Ensure validity of OAuth credentials during graph execution (#8191)

- feat(backend/executor): Change credential injection mechanism to acquire credentials from `AgentServer` just before execution
  - Also locks the credentials for the duration of the execution

- feat(backend/server): Add thread-safe `IntegrationCredentialsManager` to handle and synchronize credentials-related operations

- feat(libs): Add mutexes to `SupabaseIntegrationCredentialsStore` to ensure thread-safety

Also:
- feat(backend): Add Pydantic model (de)serialization support to the `@expose` decorator

Refactorings:
- refactor(backend, libs): Move `KeyedMutex` to `autogpt_libs.utils.synchronize`
- refactor(backend/server): Make `backend.server.integrations` module with `router`, `creds_manager`, and `utils` in it
Reinier van der Leer
2024-10-10 18:45:43 +02:00
committed by GitHub
parent d8145c158c
commit 992989ee71
20 changed files with 521 additions and 168 deletions
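The central change in this diff: graph executions no longer carry pre-fetched credentials in their payload. Each node execution instead fetches its credentials just before running and holds a Redis-backed lock on them until the block completes. A minimal sketch of that pattern, assuming the `IntegrationCredentialsManager` API introduced below (`acquire()` returns a `(credentials, lock)` pair):

```python
# Sketch of the new execution-time credential flow (see executor/manager.py below).
# `acquire()` blocks until the credentials are free, refreshes expiring OAuth
# tokens, and returns them together with a system-wide Redis lock.

def run_block(creds_manager, user_id: str, input_data: dict) -> None:
    credentials = creds_lock = None
    try:
        if "credentials" in input_data:  # CREDENTIALS_FIELD_NAME in the real code
            meta = input_data["credentials"]
            credentials, creds_lock = creds_manager.acquire(user_id, meta["id"])
        ...  # execute the block, passing `credentials` as an extra kwarg
    finally:
        if creds_lock:  # release ASAP so other executions aren't blocked
            creds_lock.release()
```

Holding the lock for the duration of the run is what guarantees a token can't be refreshed or edited out from under a running block.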

View File

@@ -1,8 +1,9 @@
 from .store import SupabaseIntegrationCredentialsStore
-from .types import APIKeyCredentials, OAuth2Credentials
+from .types import Credentials, APIKeyCredentials, OAuth2Credentials

 __all__ = [
     "SupabaseIntegrationCredentialsStore",
+    "Credentials",
     "APIKeyCredentials",
     "OAuth2Credentials",
 ]

View File

@@ -1,8 +1,12 @@
 import secrets
 from datetime import datetime, timedelta, timezone
-from typing import cast
+from typing import TYPE_CHECKING, cast

-from supabase import Client
+if TYPE_CHECKING:
+    from redis import Redis
+    from supabase import Client
+
+from autogpt_libs.utils.synchronize import RedisKeyedMutex

 from .types import (
     Credentials,
@@ -14,26 +18,28 @@ from .types import (

 class SupabaseIntegrationCredentialsStore:
-    def __init__(self, supabase: Client):
+    def __init__(self, supabase: "Client", redis: "Redis"):
         self.supabase = supabase
+        self.locks = RedisKeyedMutex(redis)

     def add_creds(self, user_id: str, credentials: Credentials) -> None:
-        if self.get_creds_by_id(user_id, credentials.id):
-            raise ValueError(
-                f"Can not re-create existing credentials with ID {credentials.id} "
-                f"for user with ID {user_id}"
-            )
-        self._set_user_integration_creds(
-            user_id, [*self.get_all_creds(user_id), credentials]
-        )
+        with self.locked_user_metadata(user_id):
+            if self.get_creds_by_id(user_id, credentials.id):
+                raise ValueError(
+                    f"Can not re-create existing credentials #{credentials.id} "
+                    f"for user #{user_id}"
+                )
+            self._set_user_integration_creds(
+                user_id, [*self.get_all_creds(user_id), credentials]
+            )

     def get_all_creds(self, user_id: str) -> list[Credentials]:
         user_metadata = self._get_user_metadata(user_id)
         return UserMetadata.model_validate(user_metadata).integration_credentials

     def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
-        credentials = self.get_all_creds(user_id)
-        return next((c for c in credentials if c.id == credentials_id), None)
+        all_credentials = self.get_all_creds(user_id)
+        return next((c for c in all_credentials if c.id == credentials_id), None)

     def get_creds_by_provider(self, user_id: str, provider: str) -> list[Credentials]:
         credentials = self.get_all_creds(user_id)
@@ -44,42 +50,45 @@ class SupabaseIntegrationCredentialsStore:
         return list(set(c.provider for c in credentials))

     def update_creds(self, user_id: str, updated: Credentials) -> None:
-        current = self.get_creds_by_id(user_id, updated.id)
-        if not current:
-            raise ValueError(
-                f"Credentials with ID {updated.id} "
-                f"for user with ID {user_id} not found"
-            )
-        if type(current) is not type(updated):
-            raise TypeError(
-                f"Can not update credentials with ID {updated.id} "
-                f"from type {type(current)} "
-                f"to type {type(updated)}"
-            )
+        with self.locked_user_metadata(user_id):
+            current = self.get_creds_by_id(user_id, updated.id)
+            if not current:
+                raise ValueError(
+                    f"Credentials with ID {updated.id} "
+                    f"for user with ID {user_id} not found"
+                )
+            if type(current) is not type(updated):
+                raise TypeError(
+                    f"Can not update credentials with ID {updated.id} "
+                    f"from type {type(current)} "
+                    f"to type {type(updated)}"
+                )

-        # Ensure no scopes are removed when updating credentials
-        if (
-            isinstance(updated, OAuth2Credentials)
-            and isinstance(current, OAuth2Credentials)
-            and not set(updated.scopes).issuperset(current.scopes)
-        ):
-            raise ValueError(
-                f"Can not update credentials with ID {updated.id} "
-                f"and scopes {current.scopes} "
-                f"to more restrictive set of scopes {updated.scopes}"
-            )
+            # Ensure no scopes are removed when updating credentials
+            if (
+                isinstance(updated, OAuth2Credentials)
+                and isinstance(current, OAuth2Credentials)
+                and not set(updated.scopes).issuperset(current.scopes)
+            ):
+                raise ValueError(
+                    f"Can not update credentials with ID {updated.id} "
+                    f"and scopes {current.scopes} "
+                    f"to more restrictive set of scopes {updated.scopes}"
+                )

-        # Update the credentials
-        updated_credentials_list = [
-            updated if c.id == updated.id else c for c in self.get_all_creds(user_id)
-        ]
-        self._set_user_integration_creds(user_id, updated_credentials_list)
+            # Update the credentials
+            updated_credentials_list = [
+                updated if c.id == updated.id else c
+                for c in self.get_all_creds(user_id)
+            ]
+            self._set_user_integration_creds(user_id, updated_credentials_list)

     def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
-        filtered_credentials = [
-            c for c in self.get_all_creds(user_id) if c.id != credentials_id
-        ]
-        self._set_user_integration_creds(user_id, filtered_credentials)
+        with self.locked_user_metadata(user_id):
+            filtered_credentials = [
+                c for c in self.get_all_creds(user_id) if c.id != credentials_id
+            ]
+            self._set_user_integration_creds(user_id, filtered_credentials)

     async def store_state_token(
         self, user_id: str, provider: str, scopes: list[str]
@@ -94,14 +103,15 @@ class SupabaseIntegrationCredentialsStore:
             scopes=scopes,
         )

-        user_metadata = self._get_user_metadata(user_id)
-        oauth_states = user_metadata.get("integration_oauth_states", [])
-        oauth_states.append(state.model_dump())
-        user_metadata["integration_oauth_states"] = oauth_states
+        with self.locked_user_metadata(user_id):
+            user_metadata = self._get_user_metadata(user_id)
+            oauth_states = user_metadata.get("integration_oauth_states", [])
+            oauth_states.append(state.model_dump())
+            user_metadata["integration_oauth_states"] = oauth_states

-        self.supabase.auth.admin.update_user_by_id(
-            user_id, {"user_metadata": user_metadata}
-        )
+            self.supabase.auth.admin.update_user_by_id(
+                user_id, {"user_metadata": user_metadata}
+            )

         return token
@@ -136,29 +146,30 @@ class SupabaseIntegrationCredentialsStore:
         return []

     async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
-        user_metadata = self._get_user_metadata(user_id)
-        oauth_states = user_metadata.get("integration_oauth_states", [])
+        with self.locked_user_metadata(user_id):
+            user_metadata = self._get_user_metadata(user_id)
+            oauth_states = user_metadata.get("integration_oauth_states", [])

-        now = datetime.now(timezone.utc)
-        valid_state = next(
-            (
-                state
-                for state in oauth_states
-                if state["token"] == token
-                and state["provider"] == provider
-                and state["expires_at"] > now.timestamp()
-            ),
-            None,
-        )
+            now = datetime.now(timezone.utc)
+            valid_state = next(
+                (
+                    state
+                    for state in oauth_states
+                    if state["token"] == token
+                    and state["provider"] == provider
+                    and state["expires_at"] > now.timestamp()
+                ),
+                None,
+            )

-        if valid_state:
-            # Remove the used state
-            oauth_states.remove(valid_state)
-            user_metadata["integration_oauth_states"] = oauth_states
-            self.supabase.auth.admin.update_user_by_id(
-                user_id, {"user_metadata": user_metadata}
-            )
-            return True
+            if valid_state:
+                # Remove the used state
+                oauth_states.remove(valid_state)
+                user_metadata["integration_oauth_states"] = oauth_states
+                self.supabase.auth.admin.update_user_by_id(
+                    user_id, {"user_metadata": user_metadata}
+                )
+                return True

         return False
@@ -178,3 +189,7 @@ class SupabaseIntegrationCredentialsStore:
         if not response.user:
             raise ValueError(f"User with ID {user_id} not found")
         return cast(UserMetadataRaw, response.user.user_metadata)
+
+    def locked_user_metadata(self, user_id: str):
+        key = (self.supabase.supabase_url, f"user:{user_id}", "metadata")
+        return self.locks.locked(key)
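Note the constructor change above: the store now requires a Redis connection, and every read-modify-write of a user's metadata funnels through `locked_user_metadata(user_id)`. Roughly, usage now looks like this (placeholder URL and key; the real app reads these from `Settings`):

```python
from redis import Redis
from supabase import create_client

from autogpt_libs.supabase_integration_credentials_store import (
    SupabaseIntegrationCredentialsStore,
)

# Placeholder connection values, for illustration only
supabase = create_client("https://<project>.supabase.co", "<service-role-key>")
store = SupabaseIntegrationCredentialsStore(supabase, Redis(host="localhost"))

# This now runs under a Redis lock keyed by
# (supabase_url, "user:<user_id>", "metadata"), so concurrent writers in the
# API server and executor processes serialize instead of clobbering each other.
store.delete_creds_by_id("user-123", "credentials-abc")
```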

View File

@@ -0,0 +1,56 @@
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, Any
+
+from expiringdict import ExpiringDict
+
+if TYPE_CHECKING:
+    from redis import Redis
+    from redis.lock import Lock as RedisLock
+
+
+class RedisKeyedMutex:
+    """
+    This class provides a mutex that can be locked and unlocked by a specific key,
+    using Redis as a distributed locking provider.
+
+    It uses an ExpiringDict to automatically clear the mutex after a specified timeout,
+    in case the key is not unlocked for a specified duration, to prevent memory leaks.
+    """
+
+    def __init__(self, redis: "Redis", timeout: int | None = 60):
+        self.redis = redis
+        self.timeout = timeout
+        self.locks: dict[Any, "RedisLock"] = ExpiringDict(
+            max_len=6000, max_age_seconds=self.timeout
+        )
+        self.locks_lock = Lock()
+
+    @contextmanager
+    def locked(self, key: Any):
+        lock = self.acquire(key)
+        try:
+            yield
+        finally:
+            lock.release()
+
+    def acquire(self, key: Any) -> "RedisLock":
+        """Acquires and returns a lock with the given key"""
+        with self.locks_lock:
+            if key not in self.locks:
+                self.locks[key] = self.redis.lock(
+                    str(key), self.timeout, thread_local=False
+                )
+            lock = self.locks[key]
+        lock.acquire()
+        return lock
+
+    def release(self, key: Any):
+        if lock := self.locks.get(key):
+            lock.release()
+
+    def release_all_locks(self):
+        """Call this on process termination to ensure all locks are released"""
+        self.locks_lock.acquire(blocking=False)
+        for lock in self.locks.values():
+            if lock.locked() and lock.owned():
+                lock.release()
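A usage sketch of the new class (assuming a reachable Redis instance). Note the trade-off encoded in `timeout`: a lock auto-expires after 60 seconds by default, which bounds how long a crashed holder can wedge the rest of the system:

```python
from redis import Redis

from autogpt_libs.utils.synchronize import RedisKeyedMutex

mutex = RedisKeyedMutex(Redis(host="localhost"), timeout=60)

# Context-manager style: hold the lock for the duration of a critical section
with mutex.locked(("user:123", "metadata")):
    ...  # read-modify-write the state identified by the key tuple

# Manual style, as the credentials manager uses it: keep the lock across a
# longer-lived operation and release it explicitly
lock = mutex.acquire(("user:123", "credentials:abc"))
try:
    ...  # use the locked resource
finally:
    lock.release()
```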

View File

@@ -377,6 +377,20 @@
 [package.extras]
 test = ["pytest (>=6)"]

+[[package]]
+name = "expiringdict"
+version = "1.2.2"
+description = "Dictionary with auto-expiring values for caching purposes"
+optional = false
+python-versions = "*"
+files = [
+    {file = "expiringdict-1.2.2-py3-none-any.whl", hash = "sha256:09a5d20bc361163e6432a874edd3179676e935eb81b925eccef48d409a8a45e8"},
+    {file = "expiringdict-1.2.2.tar.gz", hash = "sha256:300fb92a7e98f15b05cf9a856c1415b3bc4f2e132be07daa326da6414c23ee09"},
+]
+
+[package.extras]
+tests = ["coverage", "coveralls", "dill", "mock", "nose"]
+
 [[package]]
 name = "frozenlist"
 version = "1.4.1"
@@ -1031,6 +1045,7 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs
 optional = false
 python-versions = ">=3.8"
 files = [
+    {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
     {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
 ]
@@ -1041,6 +1056,7 @@ description = "A collection of ASN.1-based protocols modules"
 optional = false
 python-versions = ">=3.8"
 files = [
+    {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
     {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
 ]
@@ -1253,6 +1269,24 @@ python-dateutil = ">=2.8.1,<3.0.0"
 typing-extensions = ">=4.12.2,<5.0.0"
 websockets = ">=11,<13"

+[[package]]
+name = "redis"
+version = "5.1.1"
+description = "Python client for Redis database and key-value store"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"},
+    {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
+
+[package.extras]
+hiredis = ["hiredis (>=3.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
+
 [[package]]
 name = "requests"
 version = "2.32.3"
@@ -1690,4 +1724,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
-content-hash = "e9b6e5d877eeb9c9f1ebc69dead1985d749facc160afbe61f3bf37e9a6e35aa5"
+content-hash = "ad9a4c8b399f6480a9f70319d13df810f92f63b532d4e10503d283f0948bed6c"

View File

@@ -8,6 +8,7 @@ packages = [{ include = "autogpt_libs" }]

 [tool.poetry.dependencies]
 colorama = "^0.4.6"
+expiringdict = "^1.2.2"
 google-cloud-logging = "^3.8.0"
 pydantic = "^2.8.2"
 pydantic-settings = "^2.5.2"
@@ -16,6 +17,9 @@ python = ">=3.10,<4.0"
 python-dotenv = "^1.0.1"
 supabase = "^2.7.2"

+[tool.poetry.group.dev.dependencies]
+redis = "^5.0.8"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"

View File

@@ -25,7 +25,8 @@ def main(**kwargs):
     """

     from backend.executor import ExecutionManager, ExecutionScheduler
-    from backend.server import AgentServer, WebsocketServer
+    from backend.server.rest_api import AgentServer
+    from backend.server.ws_api import WebsocketServer

     run_processes(
         ExecutionManager(),

View File

@@ -3,7 +3,6 @@ from datetime import datetime, timezone
 from multiprocessing import Manager
 from typing import Any, Generic, TypeVar

-from autogpt_libs.supabase_integration_credentials_store.types import Credentials
 from prisma.enums import AgentExecutionStatus
 from prisma.models import (
     AgentGraphExecution,
@@ -26,7 +25,6 @@ class GraphExecution(BaseModel):
     graph_exec_id: str
     graph_id: str
     start_node_execs: list["NodeExecution"]
-    node_input_credentials: dict[str, Credentials]  # dict[node_id, Credentials]


 class NodeExecution(BaseModel):

View File

@@ -11,8 +11,8 @@ from contextlib import contextmanager
 from multiprocessing.pool import AsyncResult, Pool
 from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar, cast

-from autogpt_libs.supabase_integration_credentials_store.types import Credentials
 from pydantic import BaseModel
+from redis.lock import Lock as RedisLock

 if TYPE_CHECKING:
     from backend.server.rest_api import AgentServer
@@ -40,14 +40,16 @@ from backend.data.execution import (
 )
 from backend.data.graph import Graph, Link, Node, get_graph, get_node
 from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
+from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util import json
 from backend.util.decorator import error_logged, time_measured
 from backend.util.logging import configure_logging
 from backend.util.service import AppService, expose, get_service_client
-from backend.util.settings import Config
+from backend.util.settings import Settings
 from backend.util.type import convert

 logger = logging.getLogger(__name__)
+settings = Settings()


 class LogMetadata:
@@ -102,8 +104,8 @@ ExecutionStream = Generator[NodeExecution, None, None]
 def execute_node(
     loop: asyncio.AbstractEventLoop,
     api_client: "AgentServer",
+    creds_manager: IntegrationCredentialsManager,
     data: NodeExecution,
-    input_credentials: Credentials | None = None,
     execution_stats: dict[str, Any] | None = None,
 ) -> ExecutionStream:
     """
@@ -164,8 +166,15 @@ def execute_node(
     user_credit = get_user_credit_model()

     extra_exec_kwargs = {}
-    if input_credentials:
-        extra_exec_kwargs["credentials"] = input_credentials
+    # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
+    # changes during execution. ⚠️ This means a set of credentials can only be used by
+    # one (running) block at a time; simultaneous execution of blocks using same
+    # credentials is not supported.
+    credentials = creds_lock = None
+    if CREDENTIALS_FIELD_NAME in input_data:
+        credentials_meta = CredentialsMetaInput(**input_data[CREDENTIALS_FIELD_NAME])
+        credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
+        extra_exec_kwargs["credentials"] = credentials

     output_size = 0
     try:
@@ -192,6 +201,10 @@
         ):
             yield execution

+        # Release lock on credentials ASAP
+        if creds_lock:
+            creds_lock.release()
+
         r = update_execution(ExecutionStatus.COMPLETED)
         s = input_size + output_size
         t = (
@@ -210,21 +223,14 @@
         raise e

     finally:
+        # Ensure credentials are released even if execution fails
+        if creds_lock:
+            creds_lock.release()
         if execution_stats is not None:
             execution_stats["input_size"] = input_size
             execution_stats["output_size"] = output_size


-@contextmanager
-def synchronized(key: str, timeout: int = 60):
-    lock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
-    try:
-        lock.acquire()
-        yield
-    finally:
-        lock.release()
-
-
 def _enqueue_next_nodes(
     api_client: "AgentServer",
     loop: asyncio.AbstractEventLoop,
@@ -400,12 +406,6 @@ def validate_exec(
     return data, node_block.name


-def get_agent_server_client() -> "AgentServer":
-    from backend.server.rest_api import AgentServer
-
-    return get_service_client(AgentServer, Config().agent_server_port)
-
-
 class Executor:
     """
     This class contains event handlers for the process pool executor events.
@@ -441,6 +441,7 @@ class Executor:
         redis.connect()
         cls.loop.run_until_complete(db.connect())
         cls.agent_server_client = get_agent_server_client()
+        cls.creds_manager = IntegrationCredentialsManager()

         # Set up shutdown handlers
         cls.shutdown_lock = threading.Lock()
@@ -454,6 +455,8 @@
         if not cls.shutdown_lock.acquire(blocking=False):
             return  # already shutting down

+        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
+        cls.creds_manager.release_all_locks()
         logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
         cls.loop.run_until_complete(db.disconnect())
         logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
@@ -466,8 +469,12 @@
         if not cls.shutdown_lock.acquire(blocking=False):
             return  # already shutting down, no need to self-terminate

+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Releasing locks...")
+        cls.creds_manager.release_all_locks()
         llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
         cls.loop.run_until_complete(db.disconnect())
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting Redis...")
+        redis.disconnect()
         llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
         sys.exit(0)
@@ -477,7 +484,6 @@ class Executor:
         cls,
         q: ExecutionQueue[NodeExecution],
         node_exec: NodeExecution,
-        input_credentials: Credentials | None,
     ):
         log_metadata = LogMetadata(
             user_id=node_exec.user_id,
@@ -490,7 +496,7 @@
         execution_stats = {}
         timing_info, _ = cls._on_node_execution(
-            q, node_exec, input_credentials, log_metadata, execution_stats
+            q, node_exec, log_metadata, execution_stats
         )
         execution_stats["walltime"] = timing_info.wall_time
         execution_stats["cputime"] = timing_info.cpu_time
@@ -505,14 +511,13 @@
         cls,
         q: ExecutionQueue[NodeExecution],
         node_exec: NodeExecution,
-        input_credentials: Credentials | None,
         log_metadata: LogMetadata,
         stats: dict[str, Any] | None = None,
     ):
         try:
             log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
             for execution in execute_node(
-                cls.loop, cls.agent_server_client, node_exec, input_credentials, stats
+                cls.loop, cls.agent_server_client, cls.creds_manager, node_exec, stats
             ):
                 q.add(execution)
             log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
@@ -525,7 +530,7 @@
     def on_graph_executor_start(cls):
         configure_logging()

-        cls.pool_size = Config().num_node_workers
+        cls.pool_size = settings.config.num_node_workers
         cls.loop = asyncio.new_event_loop()
         cls.pid = os.getpid()
@@ -648,11 +653,7 @@
             )
             running_executions[exec_data.node_id] = cls.executor.apply_async(
                 cls.on_node_execution,
-                (
-                    queue,
-                    exec_data,
-                    graph_exec.node_input_credentials.get(exec_data.node_id),
-                ),
+                (queue, exec_data),
                 callback=make_exec_callback(exec_data),
             )
@@ -688,10 +689,11 @@
 class ExecutionManager(AppService):
     def __init__(self):
-        super().__init__(port=Config().execution_manager_port)
+        super().__init__(port=settings.config.execution_manager_port)
         self.use_db = True
         self.use_queue = True  # we only need the Redis connection
+        self.use_supabase = True
-        self.pool_size = Config().num_graph_workers
+        self.pool_size = settings.config.num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
         self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@@ -700,7 +702,9 @@
             SupabaseIntegrationCredentialsStore,
         )

-        self.credentials_store = SupabaseIntegrationCredentialsStore(self.supabase)
+        self.credentials_store = SupabaseIntegrationCredentialsStore(
+            self.supabase, redis.get_redis()
+        )
         self.executor = ProcessPoolExecutor(
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
@@ -732,6 +736,8 @@
     @property
     def agent_server_client(self) -> "AgentServer":
+        # Since every single usage of this property happens from a different thread,
+        # there is no value in caching it.
         return get_agent_server_client()

     @expose
@@ -743,7 +749,7 @@
             raise Exception(f"Graph #{graph_id} not found.")

         graph.validate_graph(for_run=True)
-        node_input_credentials = self._get_node_input_credentials(graph, user_id)
+        self._validate_node_input_credentials(graph, user_id)

         nodes_input = []
         for node in graph.starting_nodes:
@@ -799,7 +805,6 @@
             graph_id=graph_id,
             graph_exec_id=graph_exec_id,
             start_node_execs=starting_node_execs,
-            node_input_credentials=node_input_credentials,
         )
         self.queue.add(graph_exec)
@@ -846,12 +851,8 @@
         )
         self.agent_server_client.send_execution_update(exec_update.model_dump())

-    def _get_node_input_credentials(
-        self, graph: Graph, user_id: str
-    ) -> dict[str, Credentials]:
-        """Gets all credentials for all nodes of the graph"""
-        node_credentials: dict[str, Credentials] = {}
-
+    def _validate_node_input_credentials(self, graph: Graph, user_id: str):
+        """Checks all credentials for all nodes of the graph"""
         for node in graph.nodes:
             block = get_block(node.block_id)
@@ -894,9 +895,25 @@
                     f"Invalid credentials #{credentials.id} for node #{node.id}: "
                     "type/provider mismatch"
                 )
-            node_credentials[node.id] = credentials

-        return node_credentials

+
+# ------- UTILITIES ------- #
+
+
+def get_agent_server_client() -> "AgentServer":
+    from backend.server.rest_api import AgentServer
+
+    return get_service_client(AgentServer, settings.config.agent_server_port)
+
+
+@contextmanager
+def synchronized(key: str, timeout: int = 60):
+    lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
+    try:
+        lock.acquire()
+        yield
+    finally:
+        lock.release()


 def llprint(message: str):
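The relocated `synchronized` helper gives other parts of the executor the same Redis-backed mutual exclusion in one line. A usage sketch with an illustrative key (not from this diff):

```python
# Any two calls with the same key serialize across all executor processes;
# the timeout bounds how long a crashed holder can keep the lock.
with synchronized("some-shared-resource", timeout=60):
    ...  # critical section
```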

View File

@@ -0,0 +1,172 @@
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+
+from autogpt_libs.supabase_integration_credentials_store import (
+    Credentials,
+    SupabaseIntegrationCredentialsStore,
+)
+from autogpt_libs.utils.synchronize import RedisKeyedMutex
+from redis.lock import Lock as RedisLock
+
+from backend.data import redis
+from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
+from backend.util.settings import Settings
+
+from ..server.integrations.utils import get_supabase
+
+logger = logging.getLogger(__name__)
+settings = Settings()
+
+
+class IntegrationCredentialsManager:
+    """
+    Handles the lifecycle of integration credentials.
+    - Automatically refreshes requested credentials if needed.
+    - Uses locking mechanisms to ensure system-wide consistency and
+      prevent invalidation of in-use tokens.
+
+    ### ⚠️ Gotcha
+    With `acquire(..)`, credentials can only be in use in one place at a time (e.g. one
+    block execution).
+
+    ### Locking mechanism
+    - Because *getting* credentials can result in a refresh (= *invalidation* +
+      *replacement*) of the stored credentials, *getting* is an operation that
+      potentially requires read/write access.
+    - Checking whether a token has to be refreshed is subject to an additional
+      `refresh` scoped lock to prevent unnecessary sequential refreshes when
+      multiple executions try to access the same credentials simultaneously.
+    - We MUST lock credentials while in use to prevent them from being invalidated
+      while they are in use, e.g. because they are being refreshed by a different
+      part of the system.
+    - The `!time_sensitive` lock in `acquire(..)` is part of a two-tier locking
+      mechanism in which *updating* gets priority over *getting* credentials.
+      This is to prevent a long queue of waiting *get* requests from blocking
+      essential credential refreshes or user-initiated updates.
+
+    It is possible to implement a reader/writer locking system where either multiple
+    readers or a single writer can have simultaneous access, but this would add a lot
+    of complexity to the mechanism. I don't expect the current ("simple") mechanism to
+    cause so much latency that it's worth implementing.
+    """
+
+    def __init__(self):
+        redis_conn = redis.get_redis()
+        self._locks = RedisKeyedMutex(redis_conn)
+        self.store = SupabaseIntegrationCredentialsStore(get_supabase(), redis_conn)
+
+    def create(self, user_id: str, credentials: Credentials) -> None:
+        return self.store.add_creds(user_id, credentials)
+
+    def exists(self, user_id: str, credentials_id: str) -> bool:
+        return self.store.get_creds_by_id(user_id, credentials_id) is not None
+
+    def get(
+        self, user_id: str, credentials_id: str, lock: bool = True
+    ) -> Credentials | None:
+        credentials = self.store.get_creds_by_id(user_id, credentials_id)
+        if not credentials:
+            return None
+
+        # Refresh OAuth credentials if needed
+        if credentials.type == "oauth2" and credentials.access_token_expires_at:
+            logger.debug(
+                f"Credentials #{credentials.id} expire at "
+                f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
+                f"current time is {datetime.now()}"
+            )
+
+            with self._locked(user_id, credentials_id, "refresh"):
+                oauth_handler = _get_provider_oauth_handler(credentials.provider)
+                if oauth_handler.needs_refresh(credentials):
+                    logger.debug(
+                        f"Refreshing '{credentials.provider}' "
+                        f"credentials #{credentials.id}"
+                    )
+                    _lock = None
+                    if lock:
+                        # Wait until the credentials are no longer in use anywhere
+                        _lock = self._acquire_lock(user_id, credentials_id)
+
+                    fresh_credentials = oauth_handler.refresh_tokens(credentials)
+                    self.store.update_creds(user_id, fresh_credentials)
+                    if _lock:
+                        _lock.release()
+
+                    credentials = fresh_credentials
+        else:
+            logger.debug(f"Credentials #{credentials.id} never expire")
+
+        return credentials
+
+    def acquire(
+        self, user_id: str, credentials_id: str
+    ) -> tuple[Credentials, RedisLock]:
+        """
+        ⚠️ WARNING: this locks credentials system-wide and blocks both acquiring
+        and updating them elsewhere until the lock is released.
+        See the class docstring for more info.
+        """
+        # Use a low-priority (!time_sensitive) locking queue on top of the general
+        # lock to allow priority access for refreshing/updating the tokens.
+        with self._locked(user_id, credentials_id, "!time_sensitive"):
+            lock = self._acquire_lock(user_id, credentials_id)
+        credentials = self.get(user_id, credentials_id, lock=False)
+        if not credentials:
+            raise ValueError(
+                f"Credentials #{credentials_id} for user #{user_id} not found"
+            )
+        return credentials, lock
+
+    def update(self, user_id: str, updated: Credentials) -> None:
+        with self._locked(user_id, updated.id):
+            self.store.update_creds(user_id, updated)
+
+    def delete(self, user_id: str, credentials_id: str) -> None:
+        with self._locked(user_id, credentials_id):
+            self.store.delete_creds_by_id(user_id, credentials_id)
+
+    # -- Locking utilities -- #
+
+    def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
+        key = (
+            self.store.supabase.supabase_url,
+            f"user:{user_id}",
+            f"credentials:{credentials_id}",
+            *args,
+        )
+        return self._locks.acquire(key)
+
+    @contextmanager
+    def _locked(self, user_id: str, credentials_id: str, *args: str):
+        lock = self._acquire_lock(user_id, credentials_id, *args)
+        try:
+            yield
+        finally:
+            lock.release()
+
+    def release_all_locks(self):
+        """Call this on process termination to ensure all locks are released"""
+        self._locks.release_all_locks()
+        self.store.locks.release_all_locks()
+
+
+def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
+    if provider_name not in HANDLERS_BY_NAME:
+        raise KeyError(f"Unknown provider '{provider_name}'")
+
+    client_id = getattr(settings.secrets, f"{provider_name}_client_id")
+    client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
+    if not (client_id and client_secret):
+        raise Exception(  # TODO: ConfigError
+            f"Integration with provider '{provider_name}' is not configured",
+        )
+
+    handler_class = HANDLERS_BY_NAME[provider_name]
+    frontend_base_url = settings.config.frontend_base_url
+    return handler_class(
+        client_id=client_id,
+        client_secret=client_secret,
+        redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
+    )

View File

@@ -1,6 +1,6 @@
 from backend.app import run_processes
 from backend.executor import ExecutionScheduler
-from backend.server import AgentServer
+from backend.server.rest_api import AgentServer


 def main():

View File

@@ -1,4 +0,0 @@
-from .rest_api import AgentServer
-from .ws_api import WebsocketServer
-
-__all__ = ["AgentServer", "WebsocketServer"]

View File

@@ -1,9 +1,6 @@
 import logging
 from typing import Annotated

-from autogpt_libs.supabase_integration_credentials_store import (
-    SupabaseIntegrationCredentialsStore,
-)
 from autogpt_libs.supabase_integration_credentials_store.types import (
     APIKeyCredentials,
     Credentials,
@@ -21,20 +18,17 @@ from fastapi import (
     Response,
 )
 from pydantic import BaseModel, SecretStr
-from supabase import Client

+from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
 from backend.util.settings import Settings

-from ..utils import get_supabase, get_user_id
+from ..utils import get_user_id

 logger = logging.getLogger(__name__)
 settings = Settings()
 router = APIRouter()

-
-def get_store(supabase: Client = Depends(get_supabase)):
-    return SupabaseIntegrationCredentialsStore(supabase)
+creds_manager = IntegrationCredentialsManager()


 class LoginResponse(BaseModel):
@@ -47,7 +41,6 @@ async def login(
     provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
     user_id: Annotated[str, Depends(get_user_id)],
     request: Request,
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
     scopes: Annotated[
         str, Query(title="Comma-separated list of authorization scopes")
     ] = "",
@@ -57,7 +50,9 @@ async def login(
     requested_scopes = scopes.split(",") if scopes else []

     # Generate and store a secure random state token along with the scopes
-    state_token = await store.store_state_token(user_id, provider, requested_scopes)
+    state_token = await creds_manager.store.store_state_token(
+        user_id, provider, requested_scopes
+    )
     login_url = handler.get_login_url(requested_scopes, state_token)
@@ -77,7 +72,6 @@ async def callback(
     provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
     code: Annotated[str, Body(title="Authorization code acquired by user login")],
     state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
     user_id: Annotated[str, Depends(get_user_id)],
     request: Request,
 ) -> CredentialsMetaResponse:
@@ -85,12 +79,12 @@ async def callback(
     handler = _get_provider_oauth_handler(request, provider)

     # Verify the state token
-    if not await store.verify_state_token(user_id, state_token, provider):
+    if not await creds_manager.store.verify_state_token(user_id, state_token, provider):
         logger.warning(f"Invalid or expired state token for user {user_id}")
         raise HTTPException(status_code=400, detail="Invalid or expired state token")

     try:
-        scopes = await store.get_any_valid_scopes_from_state_token(
+        scopes = await creds_manager.store.get_any_valid_scopes_from_state_token(
             user_id, state_token, provider
         )
         logger.debug(f"Retrieved scopes from state token: {scopes}")
@@ -114,7 +108,7 @@ async def callback(
     )

     # TODO: Allow specifying `title` to set on `credentials`
-    store.add_creds(user_id, credentials)
+    creds_manager.create(user_id, credentials)

     logger.debug(
         f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
@@ -132,9 +126,8 @@ async def callback(
 async def list_credentials(
     provider: Annotated[str, Path(title="The provider to list credentials for")],
     user_id: Annotated[str, Depends(get_user_id)],
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
 ) -> list[CredentialsMetaResponse]:
-    credentials = store.get_creds_by_provider(user_id, provider)
+    credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
     return [
         CredentialsMetaResponse(
             id=cred.id,
@@ -152,9 +145,8 @@ async def get_credential(
     provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
     cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
     user_id: Annotated[str, Depends(get_user_id)],
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
 ) -> Credentials:
-    credential = store.get_creds_by_id(user_id, cred_id)
+    credential = creds_manager.get(user_id, cred_id)
     if not credential:
         raise HTTPException(status_code=404, detail="Credentials not found")
     if credential.provider != provider:
@@ -166,7 +158,6 @@ async def get_credential(
 @router.post("/{provider}/credentials", status_code=201)
 async def create_api_key_credentials(
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
     user_id: Annotated[str, Depends(get_user_id)],
     provider: Annotated[str, Path(title="The provider to create credentials for")],
     api_key: Annotated[str, Body(title="The API key to store")],
@@ -183,7 +174,7 @@ async def create_api_key_credentials(
     )

     try:
-        store.add_creds(user_id, new_credentials)
+        creds_manager.create(user_id, new_credentials)
     except Exception as e:
         raise HTTPException(
             status_code=500, detail=f"Failed to store credentials: {str(e)}"
@@ -196,9 +187,8 @@ async def delete_credential(
     provider: Annotated[str, Path(title="The provider to delete credentials for")],
     cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
     user_id: Annotated[str, Depends(get_user_id)],
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
 ):
-    creds = store.get_creds_by_id(user_id, cred_id)
+    creds = creds_manager.store.get_creds_by_id(user_id, cred_id)
     if not creds:
         raise HTTPException(status_code=404, detail="Credentials not found")
     if creds.provider != provider:
@@ -206,7 +196,7 @@ async def delete_credential(
             status_code=404, detail="Credentials do not match the specified provider"
         )

-    store.delete_creds_by_id(user_id, cred_id)
+    creds_manager.delete(user_id, cred_id)

     return Response(status_code=204)

View File

@@ -0,0 +1,11 @@
+from supabase import Client, create_client
+
+from backend.util.settings import Settings
+
+settings = Settings()
+
+
+def get_supabase() -> Client:
+    return create_client(
+        settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
+    )

View File

@@ -20,6 +20,7 @@ from backend.data.credit import get_block_costs, get_user_credit_model
 from backend.data.queue import RedisEventQueue
 from backend.data.user import get_or_create_user
 from backend.executor import ExecutionManager, ExecutionScheduler
+from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.server.model import CreateGraph, SetGraphActiveVersion
 from backend.util.service import AppService, expose, get_service_client
 from backend.util.settings import AppEnvironment, Config, Settings
@@ -82,15 +83,16 @@ class AgentServer(AppService):
         api_router.dependencies.append(Depends(auth_middleware))

         # Import & Attach sub-routers
+        import backend.server.integrations.router
         import backend.server.routers.analytics
-        import backend.server.routers.integrations

         api_router.include_router(
-            backend.server.routers.integrations.router,
+            backend.server.integrations.router.router,
             prefix="/integrations",
             tags=["integrations"],
             dependencies=[Depends(auth_middleware)],
         )
+        self.integration_creds_manager = IntegrationCredentialsManager()

         api_router.include_router(
             backend.server.routers.analytics.router,

View File

@@ -1,6 +1,5 @@
 from autogpt_libs.auth.middleware import auth_middleware
 from fastapi import Depends, HTTPException
-from supabase import Client, create_client

 from backend.data.user import DEFAULT_USER_ID
 from backend.util.settings import Settings
@@ -17,9 +16,3 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
     if not user_id:
         raise HTTPException(status_code=401, detail="User ID not found in token")
     return user_id
-
-
-def get_supabase() -> Client:
-    return create_client(
-        settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
-    )

View File

@@ -3,10 +3,24 @@ import logging
 import os
 import threading
 import time
+import typing
 from abc import abstractmethod
-from typing import Any, Callable, Coroutine, Type, TypeVar, cast
+from types import UnionType
+from typing import (
+    Annotated,
+    Any,
+    Callable,
+    Coroutine,
+    Iterator,
+    Type,
+    TypeVar,
+    cast,
+    get_args,
+    get_origin,
+)

-import Pyro5.api
+from pydantic import BaseModel
+from Pyro5 import api as pyro

 from backend.data import db
@@ -27,9 +41,8 @@ def expose(func: C) -> C:
     Decorator to mark a method or class to be exposed for remote calls.

     ## ⚠️ Gotcha
-    The types on the exposed function signature are respected **as long as they are
-    fully picklable**. This is not the case for Pydantic models, so if you really need
-    to pass a model, try dumping the model and passing the resulting dict instead.
+    Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
+    Any other passed or returned class objects are converted to dictionaries by Pyro.
     """

     def wrapper(*args, **kwargs):
@@ -40,9 +53,35 @@ def expose(func: C) -> C:
             logger.exception(msg)
             raise Exception(msg, e)

-    return Pyro5.api.expose(wrapper)  # type: ignore
+    # Register custom serializers and deserializers for annotated Pydantic models
+    for name, annotation in func.__annotations__.items():
+        for model in _pydantic_models_from_type_annotation(annotation):
+            logger.debug(
+                f"Registering Pyro (de)serializers for {func.__name__} annotation "
+                f"'{name}': {model.__qualname__}"
+            )
+            pyro.register_class_to_dict(
+                model,
+                lambda obj: {
+                    "__class__": obj.__class__.__qualname__,
+                    **obj.model_dump(),
+                },
+            )
+            pyro.register_dict_to_class(
+                model.__qualname__, _make_pyrodantic_parser(model)
+            )
+
+    return pyro.expose(wrapper)  # type: ignore
+
+
+def _make_pyrodantic_parser(model: type[BaseModel]):
+    def parse_data_into_model(qualname, data: dict):
+        logger.debug(f"Parsing Pyroed {model.__qualname__} from data {data}")
+        return model(**data)
+
+    return parse_data_into_model


 class AppService(AppProcess):
     shared_event_loop: asyncio.AbstractEventLoop
     event_queue: AbstractEventQueue = RedisEventQueue()
@@ -134,3 +173,28 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
             return getattr(self.proxy, name)

     return cast(AS, DynamicClient())
+
+
+# --------- UTILITIES --------- #
+
+
+def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
+    # Peel Annotated parameters
+    if (origin := get_origin(annotation)) and origin is Annotated:
+        annotation = get_args(annotation)[0]
+
+    if origin := get_origin(annotation):
+        if origin is UnionType:
+            types = get_args(annotation)
+        else:
+            types = [origin]
+    else:
+        types = [annotation]
+
+    for annotype in types:
+        if (
+            annotype is not None
+            and not hasattr(typing, annotype.__name__)  # avoid generics and aliases
+            and issubclass(annotype, BaseModel)
+        ):
+            yield annotype
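What the new `@expose` support enables, sketched with an illustrative service and model (neither name is from this diff): a Pydantic model in the signature is round-tripped through Pyro automatically instead of arriving as a bare dict.

```python
from pydantic import BaseModel

from backend.util.service import AppService, expose


class ExecRequest(BaseModel):  # hypothetical model, for illustration
    graph_id: str
    user_id: str


class MyService(AppService):  # hypothetical service
    @expose
    def start(self, request: ExecRequest) -> None:
        # Because `request` is annotated with a Pydantic model, @expose registers
        # a Pyro class-to-dict serializer (model_dump() plus a "__class__" tag)
        # and a matching dict-to-class parser, so callers can pass the model
        # object directly across the RPC boundary.
        ...
```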

View File

@@ -6,8 +6,7 @@ from backend.data.execution import ExecutionStatus
 from backend.data.model import CREDENTIALS_FIELD_NAME
 from backend.data.user import create_default_user
 from backend.executor import ExecutionManager, ExecutionScheduler
-from backend.server import AgentServer
-from backend.server.rest_api import get_user_id
+from backend.server.rest_api import AgentServer, get_user_id

 log = print

View File

@@ -293,6 +293,7 @@ develop = true
 [package.dependencies]
 colorama = "^0.4.6"
+expiringdict = "^1.2.2"
 google-cloud-logging = "^3.8.0"
 pydantic = "^2.8.2"
 pydantic-settings = "^2.5.2"
@@ -3667,4 +3668,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "3ab370b624b486517a2fbcdc17fb294fbd76b3ec6659c5b471c57bfd738e7277"
+content-hash = "0962d61ced1a8154c64c6bbdb3f72aca558831adfbfda68eb66f39b535466f77"

View File

@@ -16,7 +16,6 @@ autogpt-libs = { path = "../autogpt_libs", develop = true }
 click = "^8.1.7"
 croniter = "^2.0.5"
 discord-py = "^2.4.0"
-expiringdict = "^1.2.2"
 fastapi = "^0.109.0"
 feedparser = "^6.0.11"
 flake8 = "^7.0.0"

View File

@@ -4,8 +4,8 @@ from prisma.models import User
 from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
 from backend.blocks.maths import CalculatorBlock, Operation
 from backend.data import execution, graph
-from backend.server import AgentServer
 from backend.server.model import CreateGraph
+from backend.server.rest_api import AgentServer
 from backend.usecases.sample import create_test_graph, create_test_user
 from backend.util.test import SpinTestServer, wait_execution