changed linear to use sdk to test oauth flow

This commit is contained in:
SwiftyOS
2025-07-02 09:40:31 +02:00
parent 75309047cf
commit 4f057c5b72
14 changed files with 228 additions and 333 deletions

View File

@@ -26,15 +26,15 @@ def load_all_blocks() -> dict[str, type["Block"]]:
for f in current_dir.rglob("*.py"):
if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
continue
# Skip examples directory if not enabled
relative_path = f.relative_to(current_dir)
if not load_examples and relative_path.parts[0] == "examples":
continue
module_path = str(relative_path)[:-3].replace(os.path.sep, ".")
modules.append(module_path)
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(

View File

@@ -0,0 +1,14 @@
"""
Linear integration blocks for AutoGPT Platform.
"""
from .comment import LinearCreateCommentBlock
from .issues import LinearCreateIssueBlock, LinearSearchIssuesBlock
from .projects import LinearSearchProjectsBlock
__all__ = [
"LinearCreateCommentBlock",
"LinearCreateIssueBlock",
"LinearSearchIssuesBlock",
"LinearSearchProjectsBlock",
]

View File

@@ -1,16 +1,15 @@
from __future__ import annotations
import json
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from backend.blocks.linear._auth import LinearCredentials
from backend.blocks.linear.models import (
from backend.sdk import OAuth2Credentials, APIKeyCredentials, Requests
from .models import (
CreateCommentResponse,
CreateIssueResponse,
Issue,
Project,
)
from backend.util.request import Requests
class LinearAPIException(Exception):
@@ -29,18 +28,20 @@ class LinearClient:
def __init__(
self,
credentials: LinearCredentials | None = None,
credentials: Union[OAuth2Credentials, APIKeyCredentials, None] = None,
custom_requests: Optional[Requests] = None,
):
if custom_requests:
self._requests = custom_requests
else:
headers: Dict[str, str] = {
"Content-Type": "application/json",
}
if credentials:
headers["Authorization"] = credentials.auth_header()
if isinstance(credentials, OAuth2Credentials):
headers["Authorization"] = f"Bearer {credentials.access_token.get_secret_value()}"
elif isinstance(credentials, APIKeyCredentials):
headers["Authorization"] = f"{credentials.api_key.get_secret_value()}"
self._requests = Requests(
extra_headers=headers,

View File

@@ -1,101 +0,0 @@
from enum import Enum
from typing import Literal
from pydantic import SecretStr
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
secrets = Secrets()
LINEAR_OAUTH_IS_CONFIGURED = bool(
secrets.linear_client_id and secrets.linear_client_secret
)
LinearCredentials = OAuth2Credentials | APIKeyCredentials
# LinearCredentialsInput = CredentialsMetaInput[
# Literal[ProviderName.LINEAR],
# Literal["oauth2", "api_key"] if LINEAR_OAUTH_IS_CONFIGURED else Literal["oauth2"],
# ]
LinearCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.LINEAR], Literal["oauth2"]
]
# (required) Comma separated list of scopes:
# read - (Default) Read access for the user's account. This scope will always be present.
# write - Write access for the user's account. If your application only needs to create comments, use a more targeted scope
# issues:create - Allows creating new issues and their attachments
# comments:create - Allows creating new issue comments
# timeSchedule:write - Allows creating and modifying time schedules
# admin - Full access to admin level endpoints. You should never ask for this permission unless it's absolutely needed
class LinearScope(str, Enum):
READ = "read"
WRITE = "write"
ISSUES_CREATE = "issues:create"
COMMENTS_CREATE = "comments:create"
TIME_SCHEDULE_WRITE = "timeSchedule:write"
ADMIN = "admin"
def LinearCredentialsField(scopes: list[LinearScope]) -> LinearCredentialsInput:
"""
Creates a Linear credentials input on a block.
Params:
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
""" # noqa
return CredentialsField(
required_scopes=set([LinearScope.READ.value]).union(
set([scope.value for scope in scopes])
),
description="The Linear integration can be used with OAuth, "
"or any API key with sufficient permissions for the blocks it is used on.",
)
TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="linear",
title="Mock Linear API key",
username="mock-linear-username",
access_token=SecretStr("mock-linear-access-token"),
access_token_expires_at=None,
refresh_token=SecretStr("mock-linear-refresh-token"),
refresh_token_expires_at=None,
scopes=["mock-linear-scopes"],
)
TEST_CREDENTIALS_API_KEY = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="linear",
title="Mock Linear API key",
api_key=SecretStr("mock-linear-api-key"),
expires_at=None,
)
TEST_CREDENTIALS_INPUT_OAUTH = {
"provider": TEST_CREDENTIALS_OAUTH.provider,
"id": TEST_CREDENTIALS_OAUTH.id,
"type": TEST_CREDENTIALS_OAUTH.type,
"title": TEST_CREDENTIALS_OAUTH.type,
}
TEST_CREDENTIALS_INPUT_API_KEY = {
"provider": TEST_CREDENTIALS_API_KEY.provider,
"id": TEST_CREDENTIALS_API_KEY.id,
"type": TEST_CREDENTIALS_API_KEY.type,
"title": TEST_CREDENTIALS_API_KEY.type,
}

View File

@@ -0,0 +1,28 @@
"""
Shared configuration for all Linear blocks using the new SDK pattern.
"""
import os
from backend.sdk import BlockCostType, ProviderBuilder
from ._oauth import LinearOAuthHandler
# Check if Linear OAuth is configured
LINEAR_OAUTH_IS_CONFIGURED = bool(
os.getenv("LINEAR_CLIENT_ID") and os.getenv("LINEAR_CLIENT_SECRET")
)
# Build the Linear provider
builder = ProviderBuilder("linear").with_base_cost(1, BlockCostType.RUN)
# Add OAuth support if configured
if LINEAR_OAUTH_IS_CONFIGURED:
builder = builder.with_oauth(
LinearOAuthHandler,
scopes=["read", "write", "issues:create", "comments:create"]
)
# Add API key support as a fallback
builder = builder.with_api_key("LINEAR_API_KEY", "Linear API Key")
# Build the provider
linear = builder.build()

View File

@@ -1,15 +1,26 @@
"""
Linear OAuth handler implementation.
"""
import json
from typing import Optional
from urllib.parse import urlencode
from pydantic import SecretStr
from backend.sdk import (
BaseOAuthHandler,
OAuth2Credentials,
APIKeyCredentials,
ProviderName,
Requests,
SecretStr,
)
from backend.blocks.linear._api import LinearAPIException
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from .base import BaseOAuthHandler
class LinearAPIException(Exception):
"""Exception for Linear API errors."""
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class LinearOAuthHandler(BaseOAuthHandler):
@@ -24,17 +35,16 @@ class LinearOAuthHandler(BaseOAuthHandler):
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.auth_base_url = "https://linear.app/oauth/authorize"
self.token_url = "https://api.linear.app/oauth/token" # Correct token URL
self.token_url = "https://api.linear.app/oauth/token"
self.revoke_url = "https://api.linear.app/oauth/revoke"
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code", # Important: include "response_type"
"response_type": "code",
"scope": ",".join(scopes), # Comma-separated, not space-separated
"state": state,
}
@@ -92,13 +102,13 @@ class LinearOAuthHandler(BaseOAuthHandler):
request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "authorization_code", # Ensure grant_type is correct
"grant_type": "authorization_code",
**params,
}
headers = {
"Content-Type": "application/x-www-form-urlencoded"
} # Correct header for token request
}
response = await Requests().post(
self.token_url, data=request_body, headers=headers
)
@@ -122,7 +132,7 @@ class LinearOAuthHandler(BaseOAuthHandler):
title=current_credentials.title if current_credentials else None,
username=token_data.get("user", {}).get(
"name", "Unknown User"
), # extract name or set appropriate
),
access_token=token_data["access_token"],
scopes=token_data["scope"].split(
","
@@ -139,7 +149,7 @@ class LinearOAuthHandler(BaseOAuthHandler):
async def _request_username(self, access_token: str) -> Optional[str]:
# Use the LinearClient to fetch user details using GraphQL
from backend.blocks.linear._api import LinearClient
from ._api import LinearClient
try:
linear_client = LinearClient(
@@ -149,7 +159,7 @@ class LinearOAuthHandler(BaseOAuthHandler):
provider=self.PROVIDER_NAME,
expires_at=None,
)
) # Temporary credentials for this request
)
query = """
query Viewer {
@@ -162,6 +172,6 @@ class LinearOAuthHandler(BaseOAuthHandler):
response = await linear_client.query(query)
return response["viewer"]["name"]
except Exception as e: # Handle any errors
except Exception as e:
print(f"Error fetching username: {e}")
return None
return None

View File

@@ -1,61 +1,51 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearCredentials,
LinearCredentialsField,
LinearCredentialsInput,
LinearScope,
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
SchemaField,
String,
OAuth2Credentials,
APIKeyCredentials,
CredentialsMetaInput,
)
from backend.blocks.linear.models import CreateCommentResponse
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import LinearAPIException, LinearClient
from .models import CreateCommentResponse
from ._config import linear
class LinearCreateCommentBlock(Block):
"""Block for creating comments on Linear issues"""
class Input(BlockSchema):
credentials: LinearCredentialsInput = LinearCredentialsField(
scopes=[LinearScope.COMMENTS_CREATE],
credentials: CredentialsMetaInput = linear.credentials_field(
description="Linear credentials with comment creation permissions",
required_scopes={"read", "comments:create"},
)
issue_id: str = SchemaField(description="ID of the issue to comment on")
comment: str = SchemaField(description="Comment text to add to the issue")
issue_id: String = SchemaField(description="ID of the issue to comment on")
comment: String = SchemaField(description="Comment text to add to the issue")
class Output(BlockSchema):
comment_id: str = SchemaField(description="ID of the created comment")
comment_body: str = SchemaField(
comment_id: String = SchemaField(description="ID of the created comment")
comment_body: String = SchemaField(
description="Text content of the created comment"
)
error: str = SchemaField(description="Error message if comment creation failed")
error: String = SchemaField(
description="Error message if comment creation failed", default=""
)
def __init__(self):
super().__init__(
id="8f7d3a2e-9b5c-4c6a-8f1d-7c8b3e4a5d6c",
description="Creates a new comment on a Linear issue",
input_schema=self.Input,
output_schema=self.Output,
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={
"issue_id": "TEST-123",
"comment": "Test comment",
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
test_credentials=TEST_CREDENTIALS_OAUTH,
test_output=[("comment_id", "abc123"), ("comment_body", "Test comment")],
test_mock={
"create_comment": lambda *args, **kwargs: (
"abc123",
"Test comment",
)
},
input_schema=LinearCreateCommentBlock.Input,
output_schema=LinearCreateCommentBlock.Output,
categories={BlockCategory.PRODUCTIVITY},
)
@staticmethod
async def create_comment(
credentials: LinearCredentials, issue_id: str, comment: str
credentials: OAuth2Credentials | APIKeyCredentials, issue_id: str, comment: str
) -> tuple[str, str]:
client = LinearClient(credentials=credentials)
response: CreateCommentResponse = await client.try_create_comment(
@@ -64,7 +54,11 @@ class LinearCreateCommentBlock(Block):
return response.comment.id, response.comment.body
async def run(
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
**kwargs
) -> BlockOutput:
"""Execute the comment creation"""
try:
@@ -80,4 +74,4 @@ class LinearCreateCommentBlock(Block):
except LinearAPIException as e:
yield "error", str(e)
except Exception as e:
yield "error", f"Unexpected error: {str(e)}"
yield "error", f"Unexpected error: {str(e)}"

View File

@@ -1,79 +1,67 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearCredentials,
LinearCredentialsField,
LinearCredentialsInput,
LinearScope,
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
SchemaField,
String,
OAuth2Credentials,
APIKeyCredentials,
CredentialsMetaInput,
)
from backend.blocks.linear.models import CreateIssueResponse, Issue
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import LinearAPIException, LinearClient
from .models import CreateIssueResponse, Issue
from ._config import linear
class LinearCreateIssueBlock(Block):
"""Block for creating issues on Linear"""
class Input(BlockSchema):
credentials: LinearCredentialsInput = LinearCredentialsField(
scopes=[LinearScope.ISSUES_CREATE],
credentials: CredentialsMetaInput = linear.credentials_field(
description="Linear credentials with issue creation permissions",
required_scopes={"read", "issues:create"},
)
title: str = SchemaField(description="Title of the issue")
description: str | None = SchemaField(description="Description of the issue")
team_name: str = SchemaField(
title: String = SchemaField(description="Title of the issue")
description: String = SchemaField(description="Description of the issue", default="")
team_name: String = SchemaField(
description="Name of the team to create the issue on"
)
priority: int | None = SchemaField(
description="Priority of the issue",
default=None,
priority: int = SchemaField(
description="Priority of the issue (0-4, where 0 is no priority, 1 is urgent, 2 is high, 3 is normal, 4 is low)",
default=3,
ge=0,
le=4,
)
project_name: str | None = SchemaField(
project_name: String = SchemaField(
description="Name of the project to create the issue on",
default=None,
default="",
)
class Output(BlockSchema):
issue_id: str = SchemaField(description="ID of the created issue")
issue_title: str = SchemaField(description="Title of the created issue")
error: str = SchemaField(description="Error message if issue creation failed")
issue_id: String = SchemaField(description="ID of the created issue")
issue_title: String = SchemaField(description="Title of the created issue")
error: String = SchemaField(
description="Error message if issue creation failed", default=""
)
def __init__(self):
super().__init__(
id="f9c68f55-dcca-40a8-8771-abf9601680aa",
description="Creates a new issue on Linear",
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
input_schema=self.Input,
output_schema=self.Output,
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={
"title": "Test issue",
"description": "Test description",
"team_name": "Test team",
"project_name": "Test project",
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
test_credentials=TEST_CREDENTIALS_OAUTH,
test_output=[("issue_id", "abc123"), ("issue_title", "Test issue")],
test_mock={
"create_issue": lambda *args, **kwargs: (
"abc123",
"Test issue",
)
},
categories={BlockCategory.PRODUCTIVITY},
)
@staticmethod
async def create_issue(
credentials: LinearCredentials,
credentials: OAuth2Credentials | APIKeyCredentials,
team_name: str,
title: str,
description: str | None = None,
priority: int | None = None,
project_name: str | None = None,
description: str = "",
priority: int = 3,
project_name: str = "",
) -> tuple[str, str]:
client = LinearClient(credentials=credentials)
team_id = await client.try_get_team_by_name(team_name=team_name)
@@ -87,14 +75,18 @@ class LinearCreateIssueBlock(Block):
response: CreateIssueResponse = await client.try_create_issue(
team_id=team_id,
title=title,
description=description,
priority=priority,
description=description if description else None,
priority=priority if priority != 3 else None,
project_id=project_id,
)
return response.issue.identifier, response.issue.title
async def run(
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
**kwargs
) -> BlockOutput:
"""Execute the issue creation"""
try:
@@ -120,13 +112,17 @@ class LinearSearchIssuesBlock(Block):
"""Block for searching issues on Linear"""
class Input(BlockSchema):
term: str = SchemaField(description="Term to search for issues")
credentials: LinearCredentialsInput = LinearCredentialsField(
scopes=[LinearScope.READ],
credentials: CredentialsMetaInput = linear.credentials_field(
description="Linear credentials with read permissions",
required_scopes={"read"},
)
term: String = SchemaField(description="Term to search for issues")
class Output(BlockSchema):
issues: list[Issue] = SchemaField(description="List of issues")
error: String = SchemaField(
description="Error message if search failed", default=""
)
def __init__(self):
super().__init__(
@@ -134,42 +130,12 @@ class LinearSearchIssuesBlock(Block):
description="Searches for issues on Linear",
input_schema=self.Input,
output_schema=self.Output,
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
test_input={
"term": "Test issue",
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
test_credentials=TEST_CREDENTIALS_OAUTH,
test_output=[
(
"issues",
[
Issue(
id="abc123",
identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
)
],
)
],
test_mock={
"search_issues": lambda *args, **kwargs: [
Issue(
id="abc123",
identifier="abc123",
title="Test issue",
description="Test description",
priority=1,
)
]
},
categories={BlockCategory.PRODUCTIVITY},
)
@staticmethod
async def search_issues(
credentials: LinearCredentials,
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
) -> list[Issue]:
client = LinearClient(credentials=credentials)
@@ -177,7 +143,11 @@ class LinearSearchIssuesBlock(Block):
return response
async def run(
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
**kwargs
) -> BlockOutput:
"""Execute the issue search"""
try:

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel
from backend.sdk import BaseModel
class Comment(BaseModel):

View File

@@ -1,30 +1,34 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearCredentials,
LinearCredentialsField,
LinearCredentialsInput,
LinearScope,
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
SchemaField,
String,
OAuth2Credentials,
APIKeyCredentials,
CredentialsMetaInput,
)
from backend.blocks.linear.models import Project
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import LinearAPIException, LinearClient
from .models import Project
from ._config import linear
class LinearSearchProjectsBlock(Block):
"""Block for searching projects on Linear"""
class Input(BlockSchema):
credentials: LinearCredentialsInput = LinearCredentialsField(
scopes=[LinearScope.READ],
credentials: CredentialsMetaInput = linear.credentials_field(
description="Linear credentials with read permissions",
required_scopes={"read"},
)
term: str = SchemaField(description="Term to search for projects")
term: String = SchemaField(description="Term to search for projects")
class Output(BlockSchema):
projects: list[Project] = SchemaField(description="List of projects")
error: str = SchemaField(description="Error message if issue creation failed")
error: String = SchemaField(
description="Error message if search failed", default=""
)
def __init__(self):
super().__init__(
@@ -32,45 +36,12 @@ class LinearSearchProjectsBlock(Block):
description="Searches for projects on Linear",
input_schema=self.Input,
output_schema=self.Output,
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={
"term": "Test project",
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
test_credentials=TEST_CREDENTIALS_OAUTH,
test_output=[
(
"projects",
[
Project(
id="abc123",
name="Test project",
description="Test description",
priority=1,
progress=1,
content="Test content",
)
],
)
],
test_mock={
"search_projects": lambda *args, **kwargs: [
Project(
id="abc123",
name="Test project",
description="Test description",
priority=1,
progress=1,
content="Test content",
)
]
},
categories={BlockCategory.PRODUCTIVITY},
)
@staticmethod
async def search_projects(
credentials: LinearCredentials,
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
) -> list[Project]:
client = LinearClient(credentials=credentials)
@@ -78,7 +49,11 @@ class LinearSearchProjectsBlock(Block):
return response
async def run(
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
**kwargs
) -> BlockOutput:
"""Execute the project search"""
try:

View File

@@ -4,7 +4,6 @@ from backend.integrations.oauth.todoist import TodoistOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .linear import LinearOAuthHandler
from .notion import NotionOAuthHandler
from .twitter import TwitterOAuthHandler
@@ -20,7 +19,6 @@ HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
GoogleOAuthHandler,
NotionOAuthHandler,
TwitterOAuthHandler,
LinearOAuthHandler,
TodoistOAuthHandler,
]
}

View File

@@ -115,7 +115,7 @@ class AutoRegistry:
"""Get a registered provider by name."""
with cls._lock:
return cls._providers.get(name)
@classmethod
def get_all_provider_names(cls) -> List[str]:
"""Get all registered provider names."""

View File

@@ -16,23 +16,23 @@ from backend.sdk.registry import AutoRegistry
def get_all_provider_names() -> List[str]:
"""
Collect all provider names from both ProviderName enum and AutoRegistry.
This function should be called at runtime to ensure we get all
dynamically registered providers.
Returns:
A sorted list of unique provider names.
"""
# Get static providers from enum
static_providers = [member.value for member in ProviderName]
# Get dynamic providers from registry
dynamic_providers = AutoRegistry.get_all_provider_names()
# Combine and deduplicate
all_providers = list(set(static_providers + dynamic_providers))
all_providers.sort()
return all_providers
@@ -42,9 +42,10 @@ def get_all_provider_names() -> List[str]:
class ProviderNamesResponse(BaseModel):
"""Response containing list of all provider names."""
providers: List[str] = Field(
description="List of all available provider names",
default_factory=get_all_provider_names
default_factory=get_all_provider_names,
)
@@ -54,22 +55,25 @@ def create_provider_enum_model():
This ensures the OpenAPI schema includes all provider names.
"""
all_providers = get_all_provider_names()
if not all_providers:
# Fallback if no providers are registered yet
all_providers = ["unknown"]
# Create a Literal type with all provider names
# This will be included in the OpenAPI schema
ProviderNameLiteral = Literal[tuple(all_providers)] # type: ignore
# Create a dynamic model that uses this Literal
DynamicProviderModel = create_model(
'AllProviderNames',
provider=(ProviderNameLiteral, Field(description="A provider name from the complete list")),
__module__=__name__
"AllProviderNames",
provider=(
ProviderNameLiteral,
Field(description="A provider name from the complete list"),
),
__module__=__name__,
)
return DynamicProviderModel
@@ -78,14 +82,14 @@ class ProviderConstants(BaseModel):
Model that exposes all provider names as a constant in the OpenAPI schema.
This is designed to be converted by Orval into a TypeScript constant.
"""
PROVIDER_NAMES: Dict[str, str] = Field(
description="All available provider names as a constant mapping",
default_factory=lambda: {
name.upper().replace('-', '_'): name
for name in get_all_provider_names()
}
name.upper().replace("-", "_"): name for name in get_all_provider_names()
},
)
class Config:
schema_extra = {
"example": {
@@ -94,7 +98,7 @@ class ProviderConstants(BaseModel):
"ANTHROPIC": "anthropic",
"EXA": "exa",
"GEM": "gem",
"EXAMPLE_SERVICE": "example-service"
"EXAMPLE_SERVICE": "example-service",
}
}
}
}

View File

@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Awaitable, Dict, List, Literal
from typing import TYPE_CHECKING, Annotated, Awaitable, List, Literal
from fastapi import (
APIRouter,
@@ -33,7 +33,6 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.sdk.registry import AutoRegistry
from backend.server.integrations.models import (
ProviderConstants,
ProviderNamesResponse,
@@ -521,7 +520,7 @@ async def list_providers() -> List[str]:
Returns both statically defined providers (from ProviderName enum)
and dynamically registered providers (from SDK decorators).
Note: The complete list of provider names is also available as a constant
in the generated TypeScript client via PROVIDER_NAMES.
"""
@@ -534,7 +533,7 @@ async def list_providers() -> List[str]:
async def get_provider_names() -> ProviderNamesResponse:
"""
Get all provider names in a structured format.
This endpoint is specifically designed to expose the provider names
in the OpenAPI schema so that code generators like Orval can create
appropriate TypeScript constants.
@@ -546,7 +545,7 @@ async def get_provider_names() -> ProviderNamesResponse:
async def get_provider_constants() -> ProviderConstants:
"""
Get provider names as constants.
This endpoint returns a model with provider names as constants,
specifically designed for OpenAPI code generation tools to create
TypeScript constants.
@@ -556,6 +555,7 @@ async def get_provider_constants() -> ProviderConstants:
class ProviderEnumResponse(BaseModel):
"""Response containing a provider from the enum."""
provider: str = Field(
description="A provider name from the complete list of providers"
)
@@ -565,11 +565,13 @@ class ProviderEnumResponse(BaseModel):
async def get_provider_enum_example() -> ProviderEnumResponse:
"""
Example endpoint that uses the CompleteProviderNames enum.
This endpoint exists to ensure that the CompleteProviderNames enum is included
in the OpenAPI schema, which will cause Orval to generate it as a
TypeScript enum/constant.
"""
# Return the first provider as an example
all_providers = get_all_provider_names()
return ProviderEnumResponse(provider=all_providers[0] if all_providers else "openai")
return ProviderEnumResponse(
provider=all_providers[0] if all_providers else "openai"
)