mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
update oauth system to work with dyncamiclly registered classes
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from backend.sdk import APIKeyCredentials, OAuth2Credentials, Requests
|
||||
from backend.sdk import OAuth2Credentials, Requests
|
||||
|
||||
from .models import CreateCommentResponse, CreateIssueResponse, Issue, Project
|
||||
|
||||
@@ -24,7 +24,7 @@ class LinearClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credentials: Union[OAuth2Credentials, APIKeyCredentials, None] = None,
|
||||
credentials: Union[OAuth2Credentials, None] = None,
|
||||
custom_requests: Optional[Requests] = None,
|
||||
):
|
||||
if custom_requests:
|
||||
@@ -33,15 +33,10 @@ class LinearClient:
|
||||
headers: Dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if credentials:
|
||||
if isinstance(credentials, OAuth2Credentials):
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {credentials.access_token.get_secret_value()}"
|
||||
)
|
||||
elif isinstance(credentials, APIKeyCredentials):
|
||||
headers["Authorization"] = (
|
||||
f"{credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
if credentials and isinstance(credentials, OAuth2Credentials):
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {credentials.access_token.get_secret_value()}"
|
||||
)
|
||||
|
||||
self._requests = Requests(
|
||||
extra_headers=headers,
|
||||
|
||||
@@ -9,21 +9,18 @@ from backend.sdk import BlockCostType, ProviderBuilder
|
||||
from ._oauth import LinearOAuthHandler
|
||||
|
||||
# Check if Linear OAuth is configured
|
||||
LINEAR_OAUTH_IS_CONFIGURED = bool(
|
||||
os.getenv("LINEAR_CLIENT_ID") and os.getenv("LINEAR_CLIENT_SECRET")
|
||||
)
|
||||
client_id = os.getenv("LINEAR_CLIENT_ID")
|
||||
client_secret = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
LINEAR_OAUTH_IS_CONFIGURED = bool(client_id and client_secret)
|
||||
|
||||
# Build the Linear provider
|
||||
builder = ProviderBuilder("linear").with_base_cost(1, BlockCostType.RUN)
|
||||
|
||||
# Add OAuth support if configured
|
||||
# Linear only supports OAuth authentication
|
||||
if LINEAR_OAUTH_IS_CONFIGURED:
|
||||
builder = builder.with_oauth(
|
||||
LinearOAuthHandler, scopes=["read", "write", "issues:create", "comments:create"]
|
||||
)
|
||||
|
||||
# Add API key support as a fallback
|
||||
builder = builder.with_api_key("LINEAR_API_KEY", "Linear API Key")
|
||||
|
||||
# Build the provider
|
||||
linear = builder.build()
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseOAuthHandler,
|
||||
OAuth2Credentials,
|
||||
ProviderName,
|
||||
@@ -29,7 +28,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
|
||||
OAuth2 handler for Linear.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.LINEAR
|
||||
# Provider name will be set dynamically by the SDK when registered
|
||||
# We use a placeholder that will be replaced by AutoRegistry.register_provider()
|
||||
PROVIDER_NAME = ProviderName("linear")
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
@@ -149,14 +150,15 @@ class LinearOAuthHandler(BaseOAuthHandler):
|
||||
from ._api import LinearClient
|
||||
|
||||
try:
|
||||
linear_client = LinearClient(
|
||||
APIKeyCredentials(
|
||||
api_key=SecretStr(access_token),
|
||||
title="temp",
|
||||
provider=self.PROVIDER_NAME,
|
||||
expires_at=None,
|
||||
)
|
||||
# Create a temporary OAuth2Credentials object for the LinearClient
|
||||
temp_creds = OAuth2Credentials(
|
||||
id="temp",
|
||||
provider=self.PROVIDER_NAME,
|
||||
title="temp",
|
||||
access_token=SecretStr(access_token),
|
||||
scopes=[],
|
||||
)
|
||||
linear_client = LinearClient(credentials=temp_creds)
|
||||
|
||||
query = """
|
||||
query Viewer {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
@@ -46,7 +45,7 @@ class LinearCreateCommentBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def create_comment(
|
||||
credentials: OAuth2Credentials | APIKeyCredentials, issue_id: str, comment: str
|
||||
credentials: OAuth2Credentials, issue_id: str, comment: str
|
||||
) -> tuple[str, str]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
response: CreateCommentResponse = await client.try_create_comment(
|
||||
@@ -58,7 +57,7 @@ class LinearCreateCommentBlock(Block):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the comment creation"""
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
@@ -59,7 +58,7 @@ class LinearCreateIssueBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def create_issue(
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
credentials: OAuth2Credentials,
|
||||
team_name: str,
|
||||
title: str,
|
||||
description: str = "",
|
||||
@@ -88,7 +87,7 @@ class LinearCreateIssueBlock(Block):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the issue creation"""
|
||||
@@ -138,7 +137,7 @@ class LinearSearchIssuesBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def search_issues(
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
credentials: OAuth2Credentials,
|
||||
term: str,
|
||||
) -> list[Issue]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
@@ -149,7 +148,7 @@ class LinearSearchIssuesBlock(Block):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the issue search"""
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
@@ -42,7 +41,7 @@ class LinearSearchProjectsBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def search_projects(
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
credentials: OAuth2Credentials,
|
||||
term: str,
|
||||
) -> list[Project]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
@@ -53,7 +52,7 @@ class LinearSearchProjectsBlock(Block):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the project search"""
|
||||
|
||||
@@ -8,20 +8,104 @@ from .notion import NotionOAuthHandler
|
||||
from .twitter import TwitterOAuthHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
from .base import BaseOAuthHandler
|
||||
|
||||
# --8<-- [start:HANDLERS_BY_NAMEExample]
|
||||
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
|
||||
handler.PROVIDER_NAME: handler
|
||||
for handler in [
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
TwitterOAuthHandler,
|
||||
TodoistOAuthHandler,
|
||||
]
|
||||
# Build handlers dict with string keys for compatibility with SDK auto-registration
|
||||
_ORIGINAL_HANDLERS = [
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
TwitterOAuthHandler,
|
||||
TodoistOAuthHandler,
|
||||
]
|
||||
|
||||
# Start with original handlers
|
||||
_handlers_dict = {
|
||||
(
|
||||
handler.PROVIDER_NAME.value
|
||||
if hasattr(handler.PROVIDER_NAME, "value")
|
||||
else str(handler.PROVIDER_NAME)
|
||||
): handler
|
||||
for handler in _ORIGINAL_HANDLERS
|
||||
}
|
||||
|
||||
|
||||
# Create a custom dict class that includes SDK handlers
|
||||
class SDKAwareHandlersDict(dict):
|
||||
"""Dictionary that automatically includes SDK-registered OAuth handlers."""
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First try the original handlers
|
||||
if key in _handlers_dict:
|
||||
return _handlers_dict[key]
|
||||
|
||||
# Then try SDK handlers
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
if key in sdk_handlers:
|
||||
return sdk_handlers[key]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# If not found, raise KeyError
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
if key in _handlers_dict:
|
||||
return True
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
return key in sdk_handlers
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def keys(self):
|
||||
# Combine all keys into a single dict and return its keys view
|
||||
combined = dict(_handlers_dict)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
combined.update(sdk_handlers)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.keys()
|
||||
|
||||
def values(self):
|
||||
combined = dict(_handlers_dict)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
combined.update(sdk_handlers)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.values()
|
||||
|
||||
def items(self):
|
||||
combined = dict(_handlers_dict)
|
||||
try:
|
||||
from backend.sdk import AutoRegistry
|
||||
|
||||
sdk_handlers = AutoRegistry.get_oauth_handlers()
|
||||
combined.update(sdk_handlers)
|
||||
except ImportError:
|
||||
pass
|
||||
return combined.items()
|
||||
|
||||
|
||||
HANDLERS_BY_NAME: dict[str, type["BaseOAuthHandler"]] = SDKAwareHandlersDict()
|
||||
# --8<-- [end:HANDLERS_BY_NAMEExample]
|
||||
|
||||
__all__ = ["HANDLERS_BY_NAME"]
|
||||
|
||||
@@ -28,7 +28,6 @@ class ProviderName(str, Enum):
|
||||
HUBSPOT = "hubspot"
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
LINEAR = "linear"
|
||||
LLAMA_API = "llama_api"
|
||||
MEDIUM = "medium"
|
||||
MEM0 = "mem0"
|
||||
|
||||
@@ -129,10 +129,6 @@ CredentialsMetaInput = _CredentialsMetaInput[
|
||||
]
|
||||
|
||||
|
||||
# Initialize the registry's integration patches
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
|
||||
# === COMPREHENSIVE __all__ EXPORT ===
|
||||
__all__ = [
|
||||
# Core Block System
|
||||
|
||||
@@ -54,6 +54,16 @@ class AutoRegistry:
|
||||
|
||||
# Register OAuth handler if provided
|
||||
if provider.oauth_handler:
|
||||
# Dynamically set PROVIDER_NAME if not already set
|
||||
if (
|
||||
not hasattr(provider.oauth_handler, "PROVIDER_NAME")
|
||||
or provider.oauth_handler.PROVIDER_NAME is None
|
||||
):
|
||||
# Import ProviderName to create dynamic enum value
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# This works because ProviderName has _missing_ method
|
||||
provider.oauth_handler.PROVIDER_NAME = ProviderName(provider.name)
|
||||
cls._oauth_handlers[provider.name] = provider.oauth_handler
|
||||
|
||||
# Register webhook manager if provided
|
||||
@@ -136,57 +146,8 @@ class AutoRegistry:
|
||||
@classmethod
|
||||
def patch_integrations(cls) -> None:
|
||||
"""Patch existing integration points to use AutoRegistry."""
|
||||
# Patch oauth handlers
|
||||
try:
|
||||
import backend.integrations.oauth as oauth
|
||||
|
||||
if hasattr(oauth, "HANDLERS_BY_NAME"):
|
||||
# Create a new dict that includes both original and SDK handlers
|
||||
original_handlers = dict(oauth.HANDLERS_BY_NAME)
|
||||
|
||||
class PatchedHandlersDict(dict): # type: ignore
|
||||
def __getitem__(self, key):
|
||||
# First try SDK handlers
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
if key in sdk_handlers:
|
||||
return sdk_handlers[key]
|
||||
# Fall back to original
|
||||
return original_handlers[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
return key in sdk_handlers or key in original_handlers
|
||||
|
||||
def keys(self): # type: ignore[override]
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
all_keys = set(original_handlers.keys()) | set(
|
||||
sdk_handlers.keys()
|
||||
)
|
||||
return all_keys
|
||||
|
||||
def values(self):
|
||||
combined = dict(original_handlers)
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
if isinstance(sdk_handlers, dict):
|
||||
combined.update(sdk_handlers) # type: ignore
|
||||
return combined.values()
|
||||
|
||||
def items(self):
|
||||
combined = dict(original_handlers)
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
if isinstance(sdk_handlers, dict):
|
||||
combined.update(sdk_handlers) # type: ignore
|
||||
return combined.items()
|
||||
|
||||
oauth.HANDLERS_BY_NAME = PatchedHandlersDict()
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch oauth handlers: {e}")
|
||||
# OAuth handlers are now handled by SDKAwareHandlersDict in oauth/__init__.py
|
||||
# No patching needed for OAuth handlers
|
||||
|
||||
# Patch webhook managers
|
||||
try:
|
||||
|
||||
@@ -477,10 +477,23 @@ async def remove_all_webhooks_for_credentials(
|
||||
def _get_provider_oauth_handler(
|
||||
req: Request, provider_name: ProviderName
|
||||
) -> "BaseOAuthHandler":
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
# Ensure blocks are loaded so SDK providers are available
|
||||
try:
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
load_all_blocks() # This is cached, so it only runs once
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load blocks: {e}")
|
||||
|
||||
# Convert provider_name to string for lookup
|
||||
provider_key = (
|
||||
provider_name.value if hasattr(provider_name, "value") else str(provider_name)
|
||||
)
|
||||
|
||||
if provider_key not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_name.value}' does not support OAuth",
|
||||
detail=f"Provider '{provider_key}' does not support OAuth",
|
||||
)
|
||||
|
||||
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
|
||||
@@ -497,7 +510,7 @@ def _get_provider_oauth_handler(
|
||||
},
|
||||
)
|
||||
|
||||
handler_class = HANDLERS_BY_NAME[provider_name]
|
||||
handler_class = HANDLERS_BY_NAME[provider_key]
|
||||
frontend_base_url = (
|
||||
settings.config.frontend_base_url
|
||||
or settings.config.platform_base_url
|
||||
|
||||
Reference in New Issue
Block a user