update oauth system to work with dyncamiclly registered classes

This commit is contained in:
SwiftyOS
2025-07-02 16:21:46 +02:00
parent e564e15701
commit b478ae51c1
11 changed files with 151 additions and 107 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
from typing import Any, Dict, Optional, Union
from backend.sdk import APIKeyCredentials, OAuth2Credentials, Requests
from backend.sdk import OAuth2Credentials, Requests
from .models import CreateCommentResponse, CreateIssueResponse, Issue, Project
@@ -24,7 +24,7 @@ class LinearClient:
def __init__(
self,
credentials: Union[OAuth2Credentials, APIKeyCredentials, None] = None,
credentials: Union[OAuth2Credentials, None] = None,
custom_requests: Optional[Requests] = None,
):
if custom_requests:
@@ -33,15 +33,10 @@ class LinearClient:
headers: Dict[str, str] = {
"Content-Type": "application/json",
}
if credentials:
if isinstance(credentials, OAuth2Credentials):
headers["Authorization"] = (
f"Bearer {credentials.access_token.get_secret_value()}"
)
elif isinstance(credentials, APIKeyCredentials):
headers["Authorization"] = (
f"{credentials.api_key.get_secret_value()}"
)
if credentials and isinstance(credentials, OAuth2Credentials):
headers["Authorization"] = (
f"Bearer {credentials.access_token.get_secret_value()}"
)
self._requests = Requests(
extra_headers=headers,

View File

@@ -9,21 +9,18 @@ from backend.sdk import BlockCostType, ProviderBuilder
from ._oauth import LinearOAuthHandler
# Check if Linear OAuth is configured
LINEAR_OAUTH_IS_CONFIGURED = bool(
os.getenv("LINEAR_CLIENT_ID") and os.getenv("LINEAR_CLIENT_SECRET")
)
client_id = os.getenv("LINEAR_CLIENT_ID")
client_secret = os.getenv("LINEAR_CLIENT_SECRET")
LINEAR_OAUTH_IS_CONFIGURED = bool(client_id and client_secret)
# Build the Linear provider
builder = ProviderBuilder("linear").with_base_cost(1, BlockCostType.RUN)
# Add OAuth support if configured
# Linear only supports OAuth authentication
if LINEAR_OAUTH_IS_CONFIGURED:
builder = builder.with_oauth(
LinearOAuthHandler, scopes=["read", "write", "issues:create", "comments:create"]
)
# Add API key support as a fallback
builder = builder.with_api_key("LINEAR_API_KEY", "Linear API Key")
# Build the provider
linear = builder.build()

View File

@@ -7,7 +7,6 @@ from typing import Optional
from urllib.parse import urlencode
from backend.sdk import (
APIKeyCredentials,
BaseOAuthHandler,
OAuth2Credentials,
ProviderName,
@@ -29,7 +28,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
OAuth2 handler for Linear.
"""
PROVIDER_NAME = ProviderName.LINEAR
# Provider name will be set dynamically by the SDK when registered
# We use a placeholder that will be replaced by AutoRegistry.register_provider()
PROVIDER_NAME = ProviderName("linear")
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
@@ -149,14 +150,15 @@ class LinearOAuthHandler(BaseOAuthHandler):
from ._api import LinearClient
try:
linear_client = LinearClient(
APIKeyCredentials(
api_key=SecretStr(access_token),
title="temp",
provider=self.PROVIDER_NAME,
expires_at=None,
)
# Create a temporary OAuth2Credentials object for the LinearClient
temp_creds = OAuth2Credentials(
id="temp",
provider=self.PROVIDER_NAME,
title="temp",
access_token=SecretStr(access_token),
scopes=[],
)
linear_client = LinearClient(credentials=temp_creds)
query = """
query Viewer {

View File

@@ -1,5 +1,4 @@
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
@@ -46,7 +45,7 @@ class LinearCreateCommentBlock(Block):
@staticmethod
async def create_comment(
credentials: OAuth2Credentials | APIKeyCredentials, issue_id: str, comment: str
credentials: OAuth2Credentials, issue_id: str, comment: str
) -> tuple[str, str]:
client = LinearClient(credentials=credentials)
response: CreateCommentResponse = await client.try_create_comment(
@@ -58,7 +57,7 @@ class LinearCreateCommentBlock(Block):
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
"""Execute the comment creation"""

View File

@@ -1,5 +1,4 @@
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
@@ -59,7 +58,7 @@ class LinearCreateIssueBlock(Block):
@staticmethod
async def create_issue(
credentials: OAuth2Credentials | APIKeyCredentials,
credentials: OAuth2Credentials,
team_name: str,
title: str,
description: str = "",
@@ -88,7 +87,7 @@ class LinearCreateIssueBlock(Block):
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
"""Execute the issue creation"""
@@ -138,7 +137,7 @@ class LinearSearchIssuesBlock(Block):
@staticmethod
async def search_issues(
credentials: OAuth2Credentials | APIKeyCredentials,
credentials: OAuth2Credentials,
term: str,
) -> list[Issue]:
client = LinearClient(credentials=credentials)
@@ -149,7 +148,7 @@ class LinearSearchIssuesBlock(Block):
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
"""Execute the issue search"""

View File

@@ -1,5 +1,4 @@
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
@@ -42,7 +41,7 @@ class LinearSearchProjectsBlock(Block):
@staticmethod
async def search_projects(
credentials: OAuth2Credentials | APIKeyCredentials,
credentials: OAuth2Credentials,
term: str,
) -> list[Project]:
client = LinearClient(credentials=credentials)
@@ -53,7 +52,7 @@ class LinearSearchProjectsBlock(Block):
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
"""Execute the project search"""

View File

@@ -8,20 +8,104 @@ from .notion import NotionOAuthHandler
from .twitter import TwitterOAuthHandler
if TYPE_CHECKING:
from ..providers import ProviderName
from .base import BaseOAuthHandler
# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
handler.PROVIDER_NAME: handler
for handler in [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]
# Build handlers dict with string keys for compatibility with SDK auto-registration
_ORIGINAL_HANDLERS = [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]
# Start with original handlers
_handlers_dict = {
(
handler.PROVIDER_NAME.value
if hasattr(handler.PROVIDER_NAME, "value")
else str(handler.PROVIDER_NAME)
): handler
for handler in _ORIGINAL_HANDLERS
}
# Create a custom dict class that includes SDK handlers
class SDKAwareHandlersDict(dict):
"""Dictionary that automatically includes SDK-registered OAuth handlers."""
def __getitem__(self, key):
# First try the original handlers
if key in _handlers_dict:
return _handlers_dict[key]
# Then try SDK handlers
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
if key in sdk_handlers:
return sdk_handlers[key]
except ImportError:
pass
# If not found, raise KeyError
raise KeyError(key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
if key in _handlers_dict:
return True
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
return key in sdk_handlers
except ImportError:
return False
def keys(self):
# Combine all keys into a single dict and return its keys view
combined = dict(_handlers_dict)
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
combined.update(sdk_handlers)
except ImportError:
pass
return combined.keys()
def values(self):
combined = dict(_handlers_dict)
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
combined.update(sdk_handlers)
except ImportError:
pass
return combined.values()
def items(self):
combined = dict(_handlers_dict)
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
combined.update(sdk_handlers)
except ImportError:
pass
return combined.items()
HANDLERS_BY_NAME: dict[str, type["BaseOAuthHandler"]] = SDKAwareHandlersDict()
# --8<-- [end:HANDLERS_BY_NAMEExample]
__all__ = ["HANDLERS_BY_NAME"]

View File

@@ -28,7 +28,6 @@ class ProviderName(str, Enum):
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"
LINEAR = "linear"
LLAMA_API = "llama_api"
MEDIUM = "medium"
MEM0 = "mem0"

View File

@@ -129,10 +129,6 @@ CredentialsMetaInput = _CredentialsMetaInput[
]
# Initialize the registry's integration patches
AutoRegistry.patch_integrations()
# === COMPREHENSIVE __all__ EXPORT ===
__all__ = [
# Core Block System

View File

@@ -54,6 +54,16 @@ class AutoRegistry:
# Register OAuth handler if provided
if provider.oauth_handler:
# Dynamically set PROVIDER_NAME if not already set
if (
not hasattr(provider.oauth_handler, "PROVIDER_NAME")
or provider.oauth_handler.PROVIDER_NAME is None
):
# Import ProviderName to create dynamic enum value
from backend.integrations.providers import ProviderName
# This works because ProviderName has _missing_ method
provider.oauth_handler.PROVIDER_NAME = ProviderName(provider.name)
cls._oauth_handlers[provider.name] = provider.oauth_handler
# Register webhook manager if provided
@@ -136,57 +146,8 @@ class AutoRegistry:
@classmethod
def patch_integrations(cls) -> None:
"""Patch existing integration points to use AutoRegistry."""
# Patch oauth handlers
try:
import backend.integrations.oauth as oauth
if hasattr(oauth, "HANDLERS_BY_NAME"):
# Create a new dict that includes both original and SDK handlers
original_handlers = dict(oauth.HANDLERS_BY_NAME)
class PatchedHandlersDict(dict): # type: ignore
def __getitem__(self, key):
# First try SDK handlers
sdk_handlers = cls.get_oauth_handlers()
if key in sdk_handlers:
return sdk_handlers[key]
# Fall back to original
return original_handlers[key]
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
sdk_handlers = cls.get_oauth_handlers()
return key in sdk_handlers or key in original_handlers
def keys(self): # type: ignore[override]
sdk_handlers = cls.get_oauth_handlers()
all_keys = set(original_handlers.keys()) | set(
sdk_handlers.keys()
)
return all_keys
def values(self):
combined = dict(original_handlers)
sdk_handlers = cls.get_oauth_handlers()
if isinstance(sdk_handlers, dict):
combined.update(sdk_handlers) # type: ignore
return combined.values()
def items(self):
combined = dict(original_handlers)
sdk_handlers = cls.get_oauth_handlers()
if isinstance(sdk_handlers, dict):
combined.update(sdk_handlers) # type: ignore
return combined.items()
oauth.HANDLERS_BY_NAME = PatchedHandlersDict()
except Exception as e:
logging.warning(f"Failed to patch oauth handlers: {e}")
# OAuth handlers are now handled by SDKAwareHandlersDict in oauth/__init__.py
# No patching needed for OAuth handlers
# Patch webhook managers
try:

View File

@@ -477,10 +477,23 @@ async def remove_all_webhooks_for_credentials(
def _get_provider_oauth_handler(
req: Request, provider_name: ProviderName
) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
# Ensure blocks are loaded so SDK providers are available
try:
from backend.blocks import load_all_blocks
load_all_blocks() # This is cached, so it only runs once
except Exception as e:
logger.warning(f"Failed to load blocks: {e}")
# Convert provider_name to string for lookup
provider_key = (
provider_name.value if hasattr(provider_name, "value") else str(provider_name)
)
if provider_key not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_name.value}' does not support OAuth",
detail=f"Provider '{provider_key}' does not support OAuth",
)
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
@@ -497,7 +510,7 @@ def _get_provider_oauth_handler(
},
)
handler_class = HANDLERS_BY_NAME[provider_name]
handler_class = HANDLERS_BY_NAME[provider_key]
frontend_base_url = (
settings.config.frontend_base_url
or settings.config.platform_base_url