Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-10 07:38:04 -05:00)
feat(backend): Ensure validity of OAuth credentials during graph execution (#8191)
- feat(backend/executor): Change credential injection mechanism to acquire credentials from `AgentServer` just before execution
  - Also locks the credentials for the duration of the execution
- feat(backend/server): Add thread-safe `IntegrationCredentialsManager` to handle and synchronize credentials-related operations
- feat(libs): Add mutexes to `SupabaseIntegrationCredentialsStore` to ensure thread-safety

Also:
- feat(backend): Added Pydantic model (de)serialization support to `@expose` decorator

Refactorings:
- refactor(backend, libs): Move `KeyedMutex` to `autogpt_libs.utils.synchronize`
- refactor(backend/server): Make `backend.server.integrations` module with `router`, `creds_manager`, and `utils` in it
Commit 992989ee71 (parent d8145c158c), committed via GitHub.
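A quick orientation sketch (ours, not part of the commit) of the credential flow this change introduces: the executor now fetches credentials and takes a system-wide lock just before a block runs, and releases the lock as soon as the block finishes. Names follow the diff below (`creds_manager.acquire`, `creds_lock`); the `block.execute(..)` call is a simplified stand-in.

```python
# Sketch of the new execution-time credential handling (see execute_node below).
# acquire() returns fresh credentials plus a Redis lock held for the execution.
credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
try:
    for output_name, output_data in block.execute(input_data, credentials=credentials):
        ...  # handle block output
finally:
    creds_lock.release()  # release ASAP; the lock blocks refreshes and updates
```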
@@ -1,8 +1,9 @@
 from .store import SupabaseIntegrationCredentialsStore
-from .types import APIKeyCredentials, OAuth2Credentials
+from .types import Credentials, APIKeyCredentials, OAuth2Credentials

 __all__ = [
     "SupabaseIntegrationCredentialsStore",
+    "Credentials",
     "APIKeyCredentials",
     "OAuth2Credentials",
 ]
@@ -1,8 +1,12 @@
 import secrets
 from datetime import datetime, timedelta, timezone
-from typing import cast
+from typing import TYPE_CHECKING, cast

-from supabase import Client
+if TYPE_CHECKING:
+    from redis import Redis
+    from supabase import Client
+
+from autogpt_libs.utils.synchronize import RedisKeyedMutex

 from .types import (
     Credentials,
@@ -14,26 +18,28 @@ from .types import (


 class SupabaseIntegrationCredentialsStore:
-    def __init__(self, supabase: Client):
+    def __init__(self, supabase: "Client", redis: "Redis"):
         self.supabase = supabase
+        self.locks = RedisKeyedMutex(redis)

     def add_creds(self, user_id: str, credentials: Credentials) -> None:
-        if self.get_creds_by_id(user_id, credentials.id):
-            raise ValueError(
-                f"Can not re-create existing credentials with ID {credentials.id} "
-                f"for user with ID {user_id}"
-            )
-        self._set_user_integration_creds(
-            user_id, [*self.get_all_creds(user_id), credentials]
-        )
+        with self.locked_user_metadata(user_id):
+            if self.get_creds_by_id(user_id, credentials.id):
+                raise ValueError(
+                    f"Can not re-create existing credentials #{credentials.id} "
+                    f"for user #{user_id}"
+                )
+            self._set_user_integration_creds(
+                user_id, [*self.get_all_creds(user_id), credentials]
+            )

     def get_all_creds(self, user_id: str) -> list[Credentials]:
         user_metadata = self._get_user_metadata(user_id)
         return UserMetadata.model_validate(user_metadata).integration_credentials

     def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
-        credentials = self.get_all_creds(user_id)
-        return next((c for c in credentials if c.id == credentials_id), None)
+        all_credentials = self.get_all_creds(user_id)
+        return next((c for c in all_credentials if c.id == credentials_id), None)

     def get_creds_by_provider(self, user_id: str, provider: str) -> list[Credentials]:
         credentials = self.get_all_creds(user_id)
@@ -44,42 +50,45 @@ class SupabaseIntegrationCredentialsStore:
         return list(set(c.provider for c in credentials))

     def update_creds(self, user_id: str, updated: Credentials) -> None:
-        current = self.get_creds_by_id(user_id, updated.id)
-        if not current:
-            raise ValueError(
-                f"Credentials with ID {updated.id} "
-                f"for user with ID {user_id} not found"
-            )
-        if type(current) is not type(updated):
-            raise TypeError(
-                f"Can not update credentials with ID {updated.id} "
-                f"from type {type(current)} "
-                f"to type {type(updated)}"
-            )
+        with self.locked_user_metadata(user_id):
+            current = self.get_creds_by_id(user_id, updated.id)
+            if not current:
+                raise ValueError(
+                    f"Credentials with ID {updated.id} "
+                    f"for user with ID {user_id} not found"
+                )
+            if type(current) is not type(updated):
+                raise TypeError(
+                    f"Can not update credentials with ID {updated.id} "
+                    f"from type {type(current)} "
+                    f"to type {type(updated)}"
+                )

-        # Ensure no scopes are removed when updating credentials
-        if (
-            isinstance(updated, OAuth2Credentials)
-            and isinstance(current, OAuth2Credentials)
-            and not set(updated.scopes).issuperset(current.scopes)
-        ):
-            raise ValueError(
-                f"Can not update credentials with ID {updated.id} "
-                f"and scopes {current.scopes} "
-                f"to more restrictive set of scopes {updated.scopes}"
-            )
+            # Ensure no scopes are removed when updating credentials
+            if (
+                isinstance(updated, OAuth2Credentials)
+                and isinstance(current, OAuth2Credentials)
+                and not set(updated.scopes).issuperset(current.scopes)
+            ):
+                raise ValueError(
+                    f"Can not update credentials with ID {updated.id} "
+                    f"and scopes {current.scopes} "
+                    f"to more restrictive set of scopes {updated.scopes}"
+                )

-        # Update the credentials
-        updated_credentials_list = [
-            updated if c.id == updated.id else c for c in self.get_all_creds(user_id)
-        ]
-        self._set_user_integration_creds(user_id, updated_credentials_list)
+            # Update the credentials
+            updated_credentials_list = [
+                updated if c.id == updated.id else c
+                for c in self.get_all_creds(user_id)
+            ]
+            self._set_user_integration_creds(user_id, updated_credentials_list)

     def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
-        filtered_credentials = [
-            c for c in self.get_all_creds(user_id) if c.id != credentials_id
-        ]
-        self._set_user_integration_creds(user_id, filtered_credentials)
+        with self.locked_user_metadata(user_id):
+            filtered_credentials = [
+                c for c in self.get_all_creds(user_id) if c.id != credentials_id
+            ]
+            self._set_user_integration_creds(user_id, filtered_credentials)

     async def store_state_token(
         self, user_id: str, provider: str, scopes: list[str]
@@ -94,14 +103,15 @@ class SupabaseIntegrationCredentialsStore:
             scopes=scopes,
         )

-        user_metadata = self._get_user_metadata(user_id)
-        oauth_states = user_metadata.get("integration_oauth_states", [])
-        oauth_states.append(state.model_dump())
-        user_metadata["integration_oauth_states"] = oauth_states
+        with self.locked_user_metadata(user_id):
+            user_metadata = self._get_user_metadata(user_id)
+            oauth_states = user_metadata.get("integration_oauth_states", [])
+            oauth_states.append(state.model_dump())
+            user_metadata["integration_oauth_states"] = oauth_states

-        self.supabase.auth.admin.update_user_by_id(
-            user_id, {"user_metadata": user_metadata}
-        )
+            self.supabase.auth.admin.update_user_by_id(
+                user_id, {"user_metadata": user_metadata}
+            )

         return token
@@ -136,29 +146,30 @@ class SupabaseIntegrationCredentialsStore:
         return []

     async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
-        user_metadata = self._get_user_metadata(user_id)
-        oauth_states = user_metadata.get("integration_oauth_states", [])
+        with self.locked_user_metadata(user_id):
+            user_metadata = self._get_user_metadata(user_id)
+            oauth_states = user_metadata.get("integration_oauth_states", [])

-        now = datetime.now(timezone.utc)
-        valid_state = next(
-            (
-                state
-                for state in oauth_states
-                if state["token"] == token
-                and state["provider"] == provider
-                and state["expires_at"] > now.timestamp()
-            ),
-            None,
-        )
-
-        if valid_state:
-            # Remove the used state
-            oauth_states.remove(valid_state)
-            user_metadata["integration_oauth_states"] = oauth_states
-            self.supabase.auth.admin.update_user_by_id(
-                user_id, {"user_metadata": user_metadata}
-            )
-            return True
+            now = datetime.now(timezone.utc)
+            valid_state = next(
+                (
+                    state
+                    for state in oauth_states
+                    if state["token"] == token
+                    and state["provider"] == provider
+                    and state["expires_at"] > now.timestamp()
+                ),
+                None,
+            )
+
+            if valid_state:
+                # Remove the used state
+                oauth_states.remove(valid_state)
+                user_metadata["integration_oauth_states"] = oauth_states
+                self.supabase.auth.admin.update_user_by_id(
+                    user_id, {"user_metadata": user_metadata}
+                )
+                return True

         return False
@@ -178,3 +189,7 @@ class SupabaseIntegrationCredentialsStore:
         if not response.user:
             raise ValueError(f"User with ID {user_id} not found")
         return cast(UserMetadataRaw, response.user.user_metadata)
+
+    def locked_user_metadata(self, user_id: str):
+        key = (self.supabase.supabase_url, f"user:{user_id}", "metadata")
+        return self.locks.locked(key)
autogpt_platform/autogpt_libs/autogpt_libs/utils/synchronize.py (new file)
@@ -0,0 +1,56 @@
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, Any
+
+from expiringdict import ExpiringDict
+
+if TYPE_CHECKING:
+    from redis import Redis
+    from redis.lock import Lock as RedisLock
+
+
+class RedisKeyedMutex:
+    """
+    This class provides a mutex that can be locked and unlocked by a specific key,
+    using Redis as a distributed locking provider.
+    It uses an ExpiringDict to automatically clear the mutex after a specified timeout,
+    in case the key is not unlocked for a specified duration, to prevent memory leaks.
+    """
+
+    def __init__(self, redis: "Redis", timeout: int | None = 60):
+        self.redis = redis
+        self.timeout = timeout
+        self.locks: dict[Any, "RedisLock"] = ExpiringDict(
+            max_len=6000, max_age_seconds=self.timeout
+        )
+        self.locks_lock = Lock()
+
+    @contextmanager
+    def locked(self, key: Any):
+        lock = self.acquire(key)
+        try:
+            yield
+        finally:
+            lock.release()
+
+    def acquire(self, key: Any) -> "RedisLock":
+        """Acquires and returns a lock with the given key"""
+        with self.locks_lock:
+            if key not in self.locks:
+                self.locks[key] = self.redis.lock(
+                    str(key), self.timeout, thread_local=False
+                )
+            lock = self.locks[key]
+        lock.acquire()
+        return lock
+
+    def release(self, key: Any):
+        if lock := self.locks.get(key):
+            lock.release()
+
+    def release_all_locks(self):
+        """Call this on process termination to ensure all locks are released"""
+        self.locks_lock.acquire(blocking=False)
+        for lock in self.locks.values():
+            if lock.locked() and lock.owned():
+                lock.release()
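A hedged usage sketch of `RedisKeyedMutex` as defined above; the Redis connection details are assumptions for illustration. Tuple keys are stringified by `acquire`, which is how the store builds namespaced lock keys.

```python
from redis import Redis

from autogpt_libs.utils.synchronize import RedisKeyedMutex

locks = RedisKeyedMutex(Redis(host="localhost", port=6379))  # assumed connection

# The store uses composite tuple keys, e.g. (supabase_url, "user:<id>", "metadata"):
key = ("https://example.supabase.co", "user:123", "metadata")
with locks.locked(key):
    ...  # read-modify-write shared state safely across processes and machines
```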
autogpt_platform/autogpt_libs/poetry.lock (generated; 36 lines changed)
@@ -377,6 +377,20 @@ files = [
 [package.extras]
 test = ["pytest (>=6)"]

+[[package]]
+name = "expiringdict"
+version = "1.2.2"
+description = "Dictionary with auto-expiring values for caching purposes"
+optional = false
+python-versions = "*"
+files = [
+    {file = "expiringdict-1.2.2-py3-none-any.whl", hash = "sha256:09a5d20bc361163e6432a874edd3179676e935eb81b925eccef48d409a8a45e8"},
+    {file = "expiringdict-1.2.2.tar.gz", hash = "sha256:300fb92a7e98f15b05cf9a856c1415b3bc4f2e132be07daa326da6414c23ee09"},
+]
+
+[package.extras]
+tests = ["coverage", "coveralls", "dill", "mock", "nose"]
+
 [[package]]
 name = "frozenlist"
 version = "1.4.1"
@@ -1031,6 +1045,7 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs
 optional = false
 python-versions = ">=3.8"
 files = [
     {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
     {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
 ]

@@ -1041,6 +1056,7 @@ description = "A collection of ASN.1-based protocols modules"
 optional = false
 python-versions = ">=3.8"
 files = [
     {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
     {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
 ]
@@ -1253,6 +1269,24 @@ python-dateutil = ">=2.8.1,<3.0.0"
 typing-extensions = ">=4.12.2,<5.0.0"
 websockets = ">=11,<13"

+[[package]]
+name = "redis"
+version = "5.1.1"
+description = "Python client for Redis database and key-value store"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"},
+    {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
+
+[package.extras]
+hiredis = ["hiredis (>=3.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
+
 [[package]]
 name = "requests"
 version = "2.32.3"
@@ -1690,4 +1724,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
-content-hash = "e9b6e5d877eeb9c9f1ebc69dead1985d749facc160afbe61f3bf37e9a6e35aa5"
+content-hash = "ad9a4c8b399f6480a9f70319d13df810f92f63b532d4e10503d283f0948bed6c"
@@ -8,6 +8,7 @@ packages = [{ include = "autogpt_libs" }]

 [tool.poetry.dependencies]
 colorama = "^0.4.6"
+expiringdict = "^1.2.2"
 google-cloud-logging = "^3.8.0"
 pydantic = "^2.8.2"
 pydantic-settings = "^2.5.2"
@@ -16,6 +17,9 @@ python = ">=3.10,<4.0"
 python-dotenv = "^1.0.1"
 supabase = "^2.7.2"

+[tool.poetry.group.dev.dependencies]
+redis = "^5.0.8"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
@@ -25,7 +25,8 @@ def main(**kwargs):
     """

     from backend.executor import ExecutionManager, ExecutionScheduler
-    from backend.server import AgentServer, WebsocketServer
+    from backend.server.rest_api import AgentServer
+    from backend.server.ws_api import WebsocketServer

     run_processes(
         ExecutionManager(),
@@ -3,7 +3,6 @@ from datetime import datetime, timezone
 from multiprocessing import Manager
 from typing import Any, Generic, TypeVar

-from autogpt_libs.supabase_integration_credentials_store.types import Credentials
 from prisma.enums import AgentExecutionStatus
 from prisma.models import (
     AgentGraphExecution,
@@ -26,7 +25,6 @@ class GraphExecution(BaseModel):
     graph_exec_id: str
     graph_id: str
     start_node_execs: list["NodeExecution"]
-    node_input_credentials: dict[str, Credentials]  # dict[node_id, Credentials]


 class NodeExecution(BaseModel):
@@ -11,8 +11,8 @@ from contextlib import contextmanager
 from multiprocessing.pool import AsyncResult, Pool
 from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar, cast

-from autogpt_libs.supabase_integration_credentials_store.types import Credentials
 from pydantic import BaseModel
+from redis.lock import Lock as RedisLock

 if TYPE_CHECKING:
     from backend.server.rest_api import AgentServer
@@ -40,14 +40,16 @@ from backend.data.execution import (
 )
 from backend.data.graph import Graph, Link, Node, get_graph, get_node
 from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
+from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util import json
 from backend.util.decorator import error_logged, time_measured
 from backend.util.logging import configure_logging
 from backend.util.service import AppService, expose, get_service_client
-from backend.util.settings import Config
+from backend.util.settings import Settings
 from backend.util.type import convert

 logger = logging.getLogger(__name__)
+settings = Settings()


 class LogMetadata:
@@ -102,8 +104,8 @@ ExecutionStream = Generator[NodeExecution, None, None]
 def execute_node(
     loop: asyncio.AbstractEventLoop,
     api_client: "AgentServer",
+    creds_manager: IntegrationCredentialsManager,
     data: NodeExecution,
-    input_credentials: Credentials | None = None,
     execution_stats: dict[str, Any] | None = None,
 ) -> ExecutionStream:
     """
@@ -164,8 +166,15 @@ def execute_node(
     user_credit = get_user_credit_model()

     extra_exec_kwargs = {}
-    if input_credentials:
-        extra_exec_kwargs["credentials"] = input_credentials
+    # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
+    # changes during execution. ⚠️ This means a set of credentials can only be used by
+    # one (running) block at a time; simultaneous execution of blocks using same
+    # credentials is not supported.
+    credentials = creds_lock = None
+    if CREDENTIALS_FIELD_NAME in input_data:
+        credentials_meta = CredentialsMetaInput(**input_data[CREDENTIALS_FIELD_NAME])
+        credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
+        extra_exec_kwargs["credentials"] = credentials

     output_size = 0
     try:
@@ -192,6 +201,10 @@ def execute_node(
         ):
             yield execution

+        # Release lock on credentials ASAP
+        if creds_lock:
+            creds_lock.release()
+
         r = update_execution(ExecutionStatus.COMPLETED)
         s = input_size + output_size
         t = (
@@ -210,21 +223,14 @@ def execute_node(
         raise e

     finally:
+        # Ensure credentials are released even if execution fails
+        if creds_lock:
+            creds_lock.release()
         if execution_stats is not None:
             execution_stats["input_size"] = input_size
             execution_stats["output_size"] = output_size


-@contextmanager
-def synchronized(key: str, timeout: int = 60):
-    lock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
-    try:
-        lock.acquire()
-        yield
-    finally:
-        lock.release()
-
-
 def _enqueue_next_nodes(
     api_client: "AgentServer",
     loop: asyncio.AbstractEventLoop,
@@ -400,12 +406,6 @@ def validate_exec(
     return data, node_block.name


-def get_agent_server_client() -> "AgentServer":
-    from backend.server.rest_api import AgentServer
-
-    return get_service_client(AgentServer, Config().agent_server_port)
-
-
 class Executor:
     """
     This class contains event handlers for the process pool executor events.
@@ -441,6 +441,7 @@ class Executor:
         redis.connect()
         cls.loop.run_until_complete(db.connect())
         cls.agent_server_client = get_agent_server_client()
+        cls.creds_manager = IntegrationCredentialsManager()

         # Set up shutdown handlers
         cls.shutdown_lock = threading.Lock()
@@ -454,6 +455,8 @@ class Executor:
         if not cls.shutdown_lock.acquire(blocking=False):
             return  # already shutting down

+        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
+        cls.creds_manager.release_all_locks()
         logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
         cls.loop.run_until_complete(db.disconnect())
         logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
@@ -466,8 +469,12 @@ class Executor:
         if not cls.shutdown_lock.acquire(blocking=False):
             return  # already shutting down, no need to self-terminate

+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Releasing locks...")
+        cls.creds_manager.release_all_locks()
         llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
         cls.loop.run_until_complete(db.disconnect())
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting Redis...")
+        redis.disconnect()
         llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
         sys.exit(0)
@@ -477,7 +484,6 @@ class Executor:
         cls,
         q: ExecutionQueue[NodeExecution],
         node_exec: NodeExecution,
-        input_credentials: Credentials | None,
     ):
         log_metadata = LogMetadata(
             user_id=node_exec.user_id,
@@ -490,7 +496,7 @@ class Executor:

         execution_stats = {}
         timing_info, _ = cls._on_node_execution(
-            q, node_exec, input_credentials, log_metadata, execution_stats
+            q, node_exec, log_metadata, execution_stats
         )
         execution_stats["walltime"] = timing_info.wall_time
         execution_stats["cputime"] = timing_info.cpu_time
@@ -505,14 +511,13 @@ class Executor:
         cls,
         q: ExecutionQueue[NodeExecution],
         node_exec: NodeExecution,
-        input_credentials: Credentials | None,
         log_metadata: LogMetadata,
         stats: dict[str, Any] | None = None,
     ):
         try:
             log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
             for execution in execute_node(
-                cls.loop, cls.agent_server_client, node_exec, input_credentials, stats
+                cls.loop, cls.agent_server_client, cls.creds_manager, node_exec, stats
             ):
                 q.add(execution)
             log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
@@ -525,7 +530,7 @@ class Executor:
     def on_graph_executor_start(cls):
         configure_logging()

-        cls.pool_size = Config().num_node_workers
+        cls.pool_size = settings.config.num_node_workers
         cls.loop = asyncio.new_event_loop()
         cls.pid = os.getpid()
@@ -648,11 +653,7 @@ class Executor:
             )
             running_executions[exec_data.node_id] = cls.executor.apply_async(
                 cls.on_node_execution,
-                (
-                    queue,
-                    exec_data,
-                    graph_exec.node_input_credentials.get(exec_data.node_id),
-                ),
+                (queue, exec_data),
                 callback=make_exec_callback(exec_data),
             )
@@ -688,10 +689,11 @@ class Executor:

 class ExecutionManager(AppService):
     def __init__(self):
-        super().__init__(port=Config().execution_manager_port)
+        super().__init__(port=settings.config.execution_manager_port)
         self.use_db = True
         self.use_queue = True  # we only need the Redis connection
         self.use_supabase = True
-        self.pool_size = Config().num_graph_workers
+        self.pool_size = settings.config.num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
         self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@@ -700,7 +702,9 @@ class ExecutionManager(AppService):
             SupabaseIntegrationCredentialsStore,
         )

-        self.credentials_store = SupabaseIntegrationCredentialsStore(self.supabase)
+        self.credentials_store = SupabaseIntegrationCredentialsStore(
+            self.supabase, redis.get_redis()
+        )
         self.executor = ProcessPoolExecutor(
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
@@ -732,6 +736,8 @@ class ExecutionManager(AppService):

     @property
     def agent_server_client(self) -> "AgentServer":
+        # Since every single usage of this property happens from a different thread,
+        # there is no value in caching it.
         return get_agent_server_client()

     @expose
@@ -743,7 +749,7 @@ class ExecutionManager(AppService):
             raise Exception(f"Graph #{graph_id} not found.")

         graph.validate_graph(for_run=True)
-        node_input_credentials = self._get_node_input_credentials(graph, user_id)
+        self._validate_node_input_credentials(graph, user_id)

         nodes_input = []
         for node in graph.starting_nodes:
@@ -799,7 +805,6 @@ class ExecutionManager(AppService):
             graph_id=graph_id,
             graph_exec_id=graph_exec_id,
             start_node_execs=starting_node_execs,
-            node_input_credentials=node_input_credentials,
         )
         self.queue.add(graph_exec)
@@ -846,12 +851,8 @@ class ExecutionManager(AppService):
         )
         self.agent_server_client.send_execution_update(exec_update.model_dump())

-    def _get_node_input_credentials(
-        self, graph: Graph, user_id: str
-    ) -> dict[str, Credentials]:
-        """Gets all credentials for all nodes of the graph"""
-
-        node_credentials: dict[str, Credentials] = {}
+    def _validate_node_input_credentials(self, graph: Graph, user_id: str):
+        """Checks all credentials for all nodes of the graph"""

         for node in graph.nodes:
             block = get_block(node.block_id)
@@ -894,9 +895,25 @@
                     f"Invalid credentials #{credentials.id} for node #{node.id}: "
                     "type/provider mismatch"
                 )
-            node_credentials[node.id] = credentials
-
-        return node_credentials


+# ------- UTILITIES ------- #
+
+
+def get_agent_server_client() -> "AgentServer":
+    from backend.server.rest_api import AgentServer
+
+    return get_service_client(AgentServer, settings.config.agent_server_port)
+
+
+@contextmanager
+def synchronized(key: str, timeout: int = 60):
+    lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
+    try:
+        lock.acquire()
+        yield
+    finally:
+        lock.release()
+
+
 def llprint(message: str):
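For reference, a minimal sketch of the relocated module-level `synchronized` helper in use; the lock key below is a made-up example.

```python
# At most one worker across the whole cluster enters this block per key;
# the underlying Redis lock auto-expires after `timeout` seconds if never released.
with synchronized(f"graph_exec:{graph_exec_id}", timeout=60):
    ...  # critical section spanning processes and machines
```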
autogpt_platform/backend/backend/integrations/creds_manager.py (new file; 172 lines)
@@ -0,0 +1,172 @@
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+
+from autogpt_libs.supabase_integration_credentials_store import (
+    Credentials,
+    SupabaseIntegrationCredentialsStore,
+)
+from autogpt_libs.utils.synchronize import RedisKeyedMutex
+from redis.lock import Lock as RedisLock
+
+from backend.data import redis
+from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
+from backend.util.settings import Settings
+
+from ..server.integrations.utils import get_supabase
+
+logger = logging.getLogger(__name__)
+settings = Settings()
+
+
+class IntegrationCredentialsManager:
+    """
+    Handles the lifecycle of integration credentials.
+    - Automatically refreshes requested credentials if needed.
+    - Uses locking mechanisms to ensure system-wide consistency and
+      prevent invalidation of in-use tokens.
+
+    ### ⚠️ Gotcha
+    With `acquire(..)`, credentials can only be in use in one place at a time (e.g. one
+    block execution).
+
+    ### Locking mechanism
+    - Because *getting* credentials can result in a refresh (= *invalidation* +
+      *replacement*) of the stored credentials, *getting* is an operation that
+      potentially requires read/write access.
+    - Checking whether a token has to be refreshed is subject to an additional
+      `refresh` scoped lock to prevent unnecessary sequential refreshes when
+      multiple executions try to access the same credentials simultaneously.
+    - We MUST lock credentials while in use to prevent them from being invalidated
+      while they are in use, e.g. because they are being refreshed by a different
+      part of the system.
+    - The `!time_sensitive` lock in `acquire(..)` is part of a two-tier locking
+      mechanism in which *updating* gets priority over *getting* credentials.
+      This is to prevent a long queue of waiting *get* requests from blocking
+      essential credential refreshes or user-initiated updates.
+
+    It is possible to implement a reader/writer locking system where either multiple
+    readers or a single writer can have simultaneous access, but this would add a lot
+    of complexity to the mechanism. I don't expect the current ("simple") mechanism to
+    cause so much latency that it's worth implementing.
+    """
+
+    def __init__(self):
+        redis_conn = redis.get_redis()
+        self._locks = RedisKeyedMutex(redis_conn)
+        self.store = SupabaseIntegrationCredentialsStore(get_supabase(), redis_conn)
+
+    def create(self, user_id: str, credentials: Credentials) -> None:
+        return self.store.add_creds(user_id, credentials)
+
+    def exists(self, user_id: str, credentials_id: str) -> bool:
+        return self.store.get_creds_by_id(user_id, credentials_id) is not None
+
+    def get(
+        self, user_id: str, credentials_id: str, lock: bool = True
+    ) -> Credentials | None:
+        credentials = self.store.get_creds_by_id(user_id, credentials_id)
+        if not credentials:
+            return None
+
+        # Refresh OAuth credentials if needed
+        if credentials.type == "oauth2" and credentials.access_token_expires_at:
+            logger.debug(
+                f"Credentials #{credentials.id} expire at "
+                f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
+                f"current time is {datetime.now()}"
+            )
+
+            with self._locked(user_id, credentials_id, "refresh"):
+                oauth_handler = _get_provider_oauth_handler(credentials.provider)
+                if oauth_handler.needs_refresh(credentials):
+                    logger.debug(
+                        f"Refreshing '{credentials.provider}' "
+                        f"credentials #{credentials.id}"
+                    )
+                    _lock = None
+                    if lock:
+                        # Wait until the credentials are no longer in use anywhere
+                        _lock = self._acquire_lock(user_id, credentials_id)
+
+                    fresh_credentials = oauth_handler.refresh_tokens(credentials)
+                    self.store.update_creds(user_id, fresh_credentials)
+                    if _lock:
+                        _lock.release()
+
+                    credentials = fresh_credentials
+        else:
+            logger.debug(f"Credentials #{credentials.id} never expire")
+
+        return credentials
+
+    def acquire(
+        self, user_id: str, credentials_id: str
+    ) -> tuple[Credentials, RedisLock]:
+        """
+        ⚠️ WARNING: this locks credentials system-wide and blocks both acquiring
+        and updating them elsewhere until the lock is released.
+        See the class docstring for more info.
+        """
+        # Use a low-priority (!time_sensitive) locking queue on top of the general
+        # lock to allow priority access for refreshing/updating the tokens.
+        with self._locked(user_id, credentials_id, "!time_sensitive"):
+            lock = self._acquire_lock(user_id, credentials_id)
+        credentials = self.get(user_id, credentials_id, lock=False)
+        if not credentials:
+            raise ValueError(
+                f"Credentials #{credentials_id} for user #{user_id} not found"
+            )
+        return credentials, lock
+
+    def update(self, user_id: str, updated: Credentials) -> None:
+        with self._locked(user_id, updated.id):
+            self.store.update_creds(user_id, updated)
+
+    def delete(self, user_id: str, credentials_id: str) -> None:
+        with self._locked(user_id, credentials_id):
+            self.store.delete_creds_by_id(user_id, credentials_id)
+
+    # -- Locking utilities -- #
+
+    def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
+        key = (
+            self.store.supabase.supabase_url,
+            f"user:{user_id}",
+            f"credentials:{credentials_id}",
+            *args,
+        )
+        return self._locks.acquire(key)
+
+    @contextmanager
+    def _locked(self, user_id: str, credentials_id: str, *args: str):
+        lock = self._acquire_lock(user_id, credentials_id, *args)
+        try:
+            yield
+        finally:
+            lock.release()
+
+    def release_all_locks(self):
+        """Call this on process termination to ensure all locks are released"""
+        self._locks.release_all_locks()
+        self.store.locks.release_all_locks()
+
+
+def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
+    if provider_name not in HANDLERS_BY_NAME:
+        raise KeyError(f"Unknown provider '{provider_name}'")
+
+    client_id = getattr(settings.secrets, f"{provider_name}_client_id")
+    client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
+    if not (client_id and client_secret):
+        raise Exception(  # TODO: ConfigError
+            f"Integration with provider '{provider_name}' is not configured",
+        )
+
+    handler_class = HANDLERS_BY_NAME[provider_name]
+    frontend_base_url = settings.config.frontend_base_url
+    return handler_class(
+        client_id=client_id,
+        client_secret=client_secret,
+        redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
+    )
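A usage sketch of the two-tier locking described in the class docstring above (the IDs are made up). `acquire(..)` first queues on the low-priority `!time_sensitive` lock and only then takes the general credentials lock, so a burst of waiting readers cannot starve refreshes or user-initiated updates, which take the general lock directly.

```python
manager = IntegrationCredentialsManager()

# Execution path: fetch-and-lock; nobody can refresh or update these
# credentials until the returned lock is released.
credentials, lock = manager.acquire(user_id="user-123", credentials_id="cred-456")
try:
    ...  # use the credentials, e.g. pass them into a block execution
finally:
    lock.release()

# Update path: takes the general lock directly (no "!time_sensitive" queue),
# so it gets priority over pending acquire() calls.
manager.update(user_id="user-123", updated=credentials)
```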
@@ -1,6 +1,6 @@
 from backend.app import run_processes
 from backend.executor import ExecutionScheduler
-from backend.server import AgentServer
+from backend.server.rest_api import AgentServer


 def main():
@@ -1,4 +0,0 @@
-from .rest_api import AgentServer
-from .ws_api import WebsocketServer
-
-__all__ = ["AgentServer", "WebsocketServer"]
@@ -1,9 +1,6 @@
 import logging
 from typing import Annotated

-from autogpt_libs.supabase_integration_credentials_store import (
-    SupabaseIntegrationCredentialsStore,
-)
 from autogpt_libs.supabase_integration_credentials_store.types import (
     APIKeyCredentials,
     Credentials,
@@ -21,20 +18,17 @@ from fastapi import (
     Response,
 )
 from pydantic import BaseModel, SecretStr
-from supabase import Client

+from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
 from backend.util.settings import Settings

-from ..utils import get_supabase, get_user_id
+from ..utils import get_user_id

 logger = logging.getLogger(__name__)
 settings = Settings()
 router = APIRouter()

-
-def get_store(supabase: Client = Depends(get_supabase)):
-    return SupabaseIntegrationCredentialsStore(supabase)
+creds_manager = IntegrationCredentialsManager()


 class LoginResponse(BaseModel):
@@ -47,7 +41,6 @@ async def login(
     provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
     user_id: Annotated[str, Depends(get_user_id)],
     request: Request,
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
     scopes: Annotated[
         str, Query(title="Comma-separated list of authorization scopes")
     ] = "",
@@ -57,7 +50,9 @@ async def login(
     requested_scopes = scopes.split(",") if scopes else []

     # Generate and store a secure random state token along with the scopes
-    state_token = await store.store_state_token(user_id, provider, requested_scopes)
+    state_token = await creds_manager.store.store_state_token(
+        user_id, provider, requested_scopes
+    )

     login_url = handler.get_login_url(requested_scopes, state_token)
@@ -77,7 +72,6 @@ async def callback(
     provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
     code: Annotated[str, Body(title="Authorization code acquired by user login")],
     state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
     user_id: Annotated[str, Depends(get_user_id)],
     request: Request,
 ) -> CredentialsMetaResponse:
@@ -85,12 +79,12 @@ async def callback(
     handler = _get_provider_oauth_handler(request, provider)

     # Verify the state token
-    if not await store.verify_state_token(user_id, state_token, provider):
+    if not await creds_manager.store.verify_state_token(user_id, state_token, provider):
         logger.warning(f"Invalid or expired state token for user {user_id}")
         raise HTTPException(status_code=400, detail="Invalid or expired state token")

     try:
-        scopes = await store.get_any_valid_scopes_from_state_token(
+        scopes = await creds_manager.store.get_any_valid_scopes_from_state_token(
             user_id, state_token, provider
         )
         logger.debug(f"Retrieved scopes from state token: {scopes}")
@@ -114,7 +108,7 @@ async def callback(
     )

     # TODO: Allow specifying `title` to set on `credentials`
-    store.add_creds(user_id, credentials)
+    creds_manager.create(user_id, credentials)

     logger.debug(
         f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
@@ -132,9 +126,8 @@ async def callback(
 async def list_credentials(
     provider: Annotated[str, Path(title="The provider to list credentials for")],
     user_id: Annotated[str, Depends(get_user_id)],
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
 ) -> list[CredentialsMetaResponse]:
-    credentials = store.get_creds_by_provider(user_id, provider)
+    credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
     return [
         CredentialsMetaResponse(
             id=cred.id,
@@ -152,9 +145,8 @@ async def get_credential(
     provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
     cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
     user_id: Annotated[str, Depends(get_user_id)],
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
 ) -> Credentials:
-    credential = store.get_creds_by_id(user_id, cred_id)
+    credential = creds_manager.get(user_id, cred_id)
     if not credential:
         raise HTTPException(status_code=404, detail="Credentials not found")
     if credential.provider != provider:
@@ -166,7 +158,6 @@ async def get_credential(

 @router.post("/{provider}/credentials", status_code=201)
 async def create_api_key_credentials(
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
     user_id: Annotated[str, Depends(get_user_id)],
     provider: Annotated[str, Path(title="The provider to create credentials for")],
     api_key: Annotated[str, Body(title="The API key to store")],
@@ -183,7 +174,7 @@ async def create_api_key_credentials(
     )

     try:
-        store.add_creds(user_id, new_credentials)
+        creds_manager.create(user_id, new_credentials)
     except Exception as e:
         raise HTTPException(
             status_code=500, detail=f"Failed to store credentials: {str(e)}"
@@ -196,9 +187,8 @@ async def delete_credential(
     provider: Annotated[str, Path(title="The provider to delete credentials for")],
     cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
     user_id: Annotated[str, Depends(get_user_id)],
-    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
 ):
-    creds = store.get_creds_by_id(user_id, cred_id)
+    creds = creds_manager.store.get_creds_by_id(user_id, cred_id)
     if not creds:
         raise HTTPException(status_code=404, detail="Credentials not found")
     if creds.provider != provider:
@@ -206,7 +196,7 @@ async def delete_credential(
             status_code=404, detail="Credentials do not match the specified provider"
         )

-    store.delete_creds_by_id(user_id, cred_id)
+    creds_manager.delete(user_id, cred_id)
     return Response(status_code=204)
autogpt_platform/backend/backend/server/integrations/utils.py (new file)
@@ -0,0 +1,11 @@
+from supabase import Client, create_client
+
+from backend.util.settings import Settings
+
+settings = Settings()
+
+
+def get_supabase() -> Client:
+    return create_client(
+        settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
+    )
@@ -20,6 +20,7 @@ from backend.data.credit import get_block_costs, get_user_credit_model
 from backend.data.queue import RedisEventQueue
 from backend.data.user import get_or_create_user
 from backend.executor import ExecutionManager, ExecutionScheduler
+from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.server.model import CreateGraph, SetGraphActiveVersion
 from backend.util.service import AppService, expose, get_service_client
 from backend.util.settings import AppEnvironment, Config, Settings
@@ -82,15 +83,16 @@ class AgentServer(AppService):
         api_router.dependencies.append(Depends(auth_middleware))

         # Import & Attach sub-routers
+        import backend.server.integrations.router
         import backend.server.routers.analytics
-        import backend.server.routers.integrations

         api_router.include_router(
-            backend.server.routers.integrations.router,
+            backend.server.integrations.router.router,
             prefix="/integrations",
             tags=["integrations"],
             dependencies=[Depends(auth_middleware)],
         )
+        self.integration_creds_manager = IntegrationCredentialsManager()

         api_router.include_router(
             backend.server.routers.analytics.router,
@@ -1,6 +1,5 @@
 from autogpt_libs.auth.middleware import auth_middleware
 from fastapi import Depends, HTTPException
-from supabase import Client, create_client

 from backend.data.user import DEFAULT_USER_ID
 from backend.util.settings import Settings
@@ -17,9 +16,3 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
     if not user_id:
         raise HTTPException(status_code=401, detail="User ID not found in token")
     return user_id
-
-
-def get_supabase() -> Client:
-    return create_client(
-        settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
-    )
@@ -3,10 +3,24 @@ import logging
 import os
 import threading
 import time
+import typing
 from abc import abstractmethod
-from typing import Any, Callable, Coroutine, Type, TypeVar, cast
+from types import UnionType
+from typing import (
+    Annotated,
+    Any,
+    Callable,
+    Coroutine,
+    Iterator,
+    Type,
+    TypeVar,
+    cast,
+    get_args,
+    get_origin,
+)

 import Pyro5.api
+from pydantic import BaseModel
+from Pyro5 import api as pyro

 from backend.data import db
@@ -27,9 +41,8 @@ def expose(func: C) -> C:
     Decorator to mark a method or class to be exposed for remote calls.

     ## ⚠️ Gotcha
-    The types on the exposed function signature are respected **as long as they are
-    fully picklable**. This is not the case for Pydantic models, so if you really need
-    to pass a model, try dumping the model and passing the resulting dict instead.
+    Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
+    Any other passed or returned class objects are converted to dictionaries by Pyro.
     """

     def wrapper(*args, **kwargs):
@@ -40,9 +53,35 @@ def expose(func: C) -> C:
             logger.exception(msg)
             raise Exception(msg, e)

+    # Register custom serializers and deserializers for annotated Pydantic models
+    for name, annotation in func.__annotations__.items():
+        for model in _pydantic_models_from_type_annotation(annotation):
+            logger.debug(
+                f"Registering Pyro (de)serializers for {func.__name__} annotation "
+                f"'{name}': {model.__qualname__}"
+            )
+            pyro.register_class_to_dict(
+                model,
+                lambda obj: {
+                    "__class__": obj.__class__.__qualname__,
+                    **obj.model_dump(),
+                },
+            )
+            pyro.register_dict_to_class(
+                model.__qualname__, _make_pyrodantic_parser(model)
+            )
+
     return pyro.expose(wrapper)  # type: ignore


+def _make_pyrodantic_parser(model: type[BaseModel]):
+    def parse_data_into_model(qualname, data: dict):
+        logger.debug(f"Parsing Pyroed {model.__qualname__} from data {data}")
+        return model(**data)
+
+    return parse_data_into_model
+
+
 class AppService(AppProcess):
     shared_event_loop: asyncio.AbstractEventLoop
     event_queue: AbstractEventQueue = RedisEventQueue()
@@ -134,3 +173,28 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
             return getattr(self.proxy, name)

     return cast(AS, DynamicClient())
+
+
+# --------- UTILITIES --------- #
+
+
+def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
+    # Peel Annotated parameters
+    if (origin := get_origin(annotation)) and origin is Annotated:
+        annotation = get_args(annotation)[0]
+
+    if origin := get_origin(annotation):
+        if origin is UnionType:
+            types = get_args(annotation)
+        else:
+            types = [origin]
+    else:
+        types = [annotation]
+
+    for annotype in types:
+        if (
+            annotype is not None
+            and not hasattr(typing, annotype.__name__)  # avoid generics and aliases
+            and issubclass(annotype, BaseModel)
+        ):
+            yield annotype
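A minimal sketch (not from the commit) of the registration pattern `@expose` now applies for each Pydantic model found in a function's annotations. `Point` is a hypothetical model; `register_class_to_dict` and `register_dict_to_class` are the Pyro5 APIs used in the diff above.

```python
from pydantic import BaseModel
from Pyro5 import api as pyro


class Point(BaseModel):  # hypothetical example model
    x: int
    y: int


# Outbound: Pyro calls this converter to put a Point on the wire as a dict.
pyro.register_class_to_dict(
    Point, lambda obj: {"__class__": obj.__class__.__qualname__, **obj.model_dump()}
)

# Inbound: Pyro matches the "__class__" key and rebuilds the model from the dict.
pyro.register_dict_to_class("Point", lambda qualname, data: Point(**data))
```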
@@ -6,8 +6,7 @@ from backend.data.execution import ExecutionStatus
 from backend.data.model import CREDENTIALS_FIELD_NAME
 from backend.data.user import create_default_user
 from backend.executor import ExecutionManager, ExecutionScheduler
-from backend.server import AgentServer
-from backend.server.rest_api import get_user_id
+from backend.server.rest_api import AgentServer, get_user_id

 log = print
autogpt_platform/backend/poetry.lock (generated; 3 lines changed)
@@ -293,6 +293,7 @@ develop = true

 [package.dependencies]
 colorama = "^0.4.6"
+expiringdict = "^1.2.2"
 google-cloud-logging = "^3.8.0"
 pydantic = "^2.8.2"
 pydantic-settings = "^2.5.2"
@@ -3667,4 +3668,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "3ab370b624b486517a2fbcdc17fb294fbd76b3ec6659c5b471c57bfd738e7277"
+content-hash = "0962d61ced1a8154c64c6bbdb3f72aca558831adfbfda68eb66f39b535466f77"
@@ -16,7 +16,6 @@ autogpt-libs = { path = "../autogpt_libs", develop = true }
 click = "^8.1.7"
 croniter = "^2.0.5"
 discord-py = "^2.4.0"
-expiringdict = "^1.2.2"
 fastapi = "^0.109.0"
 feedparser = "^6.0.11"
 flake8 = "^7.0.0"
@@ -4,8 +4,8 @@ from prisma.models import User
 from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
 from backend.blocks.maths import CalculatorBlock, Operation
 from backend.data import execution, graph
-from backend.server import AgentServer
 from backend.server.model import CreateGraph
+from backend.server.rest_api import AgentServer
 from backend.usecases.sample import create_test_graph, create_test_user
 from backend.util.test import SpinTestServer, wait_execution