refactor(backend): Simplify CredentialsField usage + use ProviderName globally (#8725)

- Resolves #8931
- Follow-up to #8358

### Changes 🏗️
- Avoid double specifying provider and cred types on `credentials`
inputs
- Move `credentials` sub-schema validation from `CredentialsField` to
`CredentialsMetaInput.validate_credentials_field_schema(..)`, which is
called in `BlockSchema.__pydantic_init_subclass__`
- Use `ProviderName` enum globally
This commit is contained in:
Reinier van der Leer
2024-12-11 20:27:09 +01:00
committed by GitHub
parent b8a3ffc04a
commit 33b9eef376
38 changed files with 313 additions and 191 deletions

View File

@@ -12,6 +12,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
class ImageSize(str, Enum):
@@ -101,12 +102,10 @@ class ImageGenModel(str, Enum):
class AIImageGeneratorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@@ -13,6 +13,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
@@ -54,13 +55,11 @@ class NormalizationStrategy(str, Enum):
class AIMusicGeneratorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="A description of the music you want to generate",

View File

@@ -12,6 +12,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -140,13 +141,11 @@ logger = logging.getLogger(__name__)
class AIShortformVideoCreatorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["revid"], Literal["api_key"]] = (
CredentialsField(
provider="revid",
supported_credential_types={"api_key"},
description="The revid.ai integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(
description="The revid.ai integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
script: str = SchemaField(
description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""",

View File

@@ -11,6 +11,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -39,12 +40,10 @@ class CodeExecutionBlock(Block):
# TODO : Add support to upload and download files
# Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["e2b"], Literal["api_key"]] = (
CredentialsField(
provider="e2b",
supported_credential_types={"api_key"},
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
)
# Todo : Option to run commond in background

View File

@@ -12,16 +12,15 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
DiscordCredentials = CredentialsMetaInput[Literal["discord"], Literal["api_key"]]
DiscordCredentials = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["api_key"]
]
def DiscordCredentialsField() -> DiscordCredentials:
return CredentialsField(
description="Discord bot token",
provider="discord",
supported_credential_types={"api_key"},
)
return CredentialsField(description="Discord bot token")
TEST_CREDENTIALS = APIKeyCredentials(

View File

@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
ExaCredentials = APIKeyCredentials
ExaCredentialsInput = CredentialsMetaInput[
Literal["exa"],
Literal[ProviderName.EXA],
Literal["api_key"],
]
@@ -28,8 +29,4 @@ TEST_CREDENTIALS_INPUT = {
def ExaCredentialsField() -> ExaCredentialsInput:
"""Creates an Exa credentials input on a block."""
return CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
)
return CredentialsField(description="The Exa integration requires an API Key.")

View File

@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
FalCredentials = APIKeyCredentials
FalCredentialsInput = CredentialsMetaInput[
Literal["fal"],
Literal[ProviderName.FAL],
Literal["api_key"],
]
@@ -30,7 +31,5 @@ def FalCredentialsField() -> FalCredentialsInput:
Creates a FAL credentials input on a block.
"""
return CredentialsField(
provider="fal",
supported_credential_types={"api_key"},
description="The FAL integration can be used with an API Key.",
)

View File

@@ -8,6 +8,7 @@ from backend.data.model import (
CredentialsMetaInput,
OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
secrets = Secrets()
@@ -17,7 +18,7 @@ GITHUB_OAUTH_IS_CONFIGURED = bool(
GithubCredentials = APIKeyCredentials | OAuth2Credentials
GithubCredentialsInput = CredentialsMetaInput[
Literal["github"],
Literal[ProviderName.GITHUB],
Literal["api_key", "oauth2"] if GITHUB_OAUTH_IS_CONFIGURED else Literal["api_key"],
]
@@ -30,10 +31,6 @@ def GithubCredentialsField(scope: str) -> GithubCredentialsInput:
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
""" # noqa
return CredentialsField(
provider="github",
supported_credential_types=(
{"api_key", "oauth2"} if GITHUB_OAUTH_IS_CONFIGURED else {"api_key"}
),
required_scopes={scope},
description="The GitHub integration can be used with OAuth, "
"or any API key with sufficient permissions for the blocks it is used on.",

View File

@@ -3,6 +3,7 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
# --8<-- [start:GoogleOAuthIsConfigured]
@@ -12,7 +13,9 @@ GOOGLE_OAUTH_IS_CONFIGURED = bool(
)
# --8<-- [end:GoogleOAuthIsConfigured]
GoogleCredentials = OAuth2Credentials
GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]
GoogleCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.GOOGLE], Literal["oauth2"]
]
def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
@@ -23,8 +26,6 @@ def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
scopes: The authorization scopes needed for the block to work.
"""
return CredentialsField(
provider="google",
supported_credential_types={"oauth2"},
required_scopes=set(scopes),
description="The Google integration requires OAuth2 authentication.",
)

View File

@@ -10,6 +10,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -38,12 +39,8 @@ class Place(BaseModel):
class GoogleMapsSearchBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal["google_maps"], Literal["api_key"]
] = CredentialsField(
provider="google_maps",
supported_credential_types={"api_key"},
description="Google Maps API Key",
)
Literal[ProviderName.GOOGLE_MAPS], Literal["api_key"]
] = CredentialsField(description="Google Maps API Key")
query: str = SchemaField(
description="Search query for local businesses",
placeholder="e.g., 'restaurants in New York'",

View File

@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
HubSpotCredentials = APIKeyCredentials
HubSpotCredentialsInput = CredentialsMetaInput[
Literal["hubspot"],
Literal[ProviderName.HUBSPOT],
Literal["api_key"],
]
@@ -14,8 +15,6 @@ HubSpotCredentialsInput = CredentialsMetaInput[
def HubSpotCredentialsField() -> HubSpotCredentialsInput:
"""Creates a HubSpot credentials input on a block."""
return CredentialsField(
provider="hubspot",
supported_credential_types={"api_key"},
description="The HubSpot integration requires an API Key.",
)

View File

@@ -11,6 +11,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -83,13 +84,10 @@ class UpscaleOption(str, Enum):
class IdeogramModelBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["ideogram"], Literal["api_key"]] = (
CredentialsField(
provider="ideogram",
supported_credential_types={"api_key"},
description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.IDEOGRAM], Literal["api_key"]
] = CredentialsField(
description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@@ -3,27 +3,14 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
JinaCredentials = APIKeyCredentials
JinaCredentialsInput = CredentialsMetaInput[
Literal["jina"],
Literal[ProviderName.JINA],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="jina",
api_key=SecretStr("mock-jina-api-key"),
title="Mock Jina API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
def JinaCredentialsField() -> JinaCredentialsInput:
"""
@@ -31,8 +18,6 @@ def JinaCredentialsField() -> JinaCredentialsInput:
"""
return CredentialsField(
provider="jina",
supported_credential_types={"api_key"},
description="The Jina integration can be used with an API Key.",
)

View File

@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
from pydantic import SecretStr
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
from enum import _EnumMemberT
@@ -27,7 +29,13 @@ from backend.util.settings import BehaveAs, Settings
logger = logging.getLogger(__name__)
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama", "open_router"]
LLMProviderName = Literal[
ProviderName.ANTHROPIC,
ProviderName.GROQ,
ProviderName.OLLAMA,
ProviderName.OPENAI,
ProviderName.OPEN_ROUTER,
]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
TEST_CREDENTIALS = APIKeyCredentials(
@@ -48,8 +56,6 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
return CredentialsField(
description="API key for the LLM provider.",
provider=["anthropic", "groq", "openai", "ollama", "open_router"],
supported_credential_types={"api_key"},
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel

View File

@@ -12,6 +12,7 @@ from backend.data.model import (
SchemaField,
SecretField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -77,12 +78,10 @@ class PublishToMediumBlock(Block):
description="Whether to notify followers that the user has published",
placeholder="False",
)
credentials: CredentialsMetaInput[Literal["medium"], Literal["api_key"]] = (
CredentialsField(
provider="medium",
supported_credential_types={"api_key"},
description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.MEDIUM], Literal["api_key"]
] = CredentialsField(
description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
class Output(BlockSchema):

View File

@@ -10,22 +10,18 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
PineconeCredentials = APIKeyCredentials
PineconeCredentialsInput = CredentialsMetaInput[
Literal["pinecone"],
Literal[ProviderName.PINECONE],
Literal["api_key"],
]
def PineconeCredentialsField() -> PineconeCredentialsInput:
"""
Creates a Pinecone credentials input on a block.
"""
"""Creates a Pinecone credentials input on a block."""
return CredentialsField(
provider="pinecone",
supported_credential_types={"api_key"},
description="The Pinecone integration can be used with an API Key.",
)

View File

@@ -13,6 +13,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -54,13 +55,11 @@ class ImageType(str, Enum):
class ReplicateFluxAdvancedModelBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@@ -11,6 +11,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
class GetWikipediaSummaryBlock(Block, GetRequest):
@@ -65,10 +66,8 @@ class GetWeatherInformationBlock(Block, GetRequest):
description="Location to get weather information for"
)
credentials: CredentialsMetaInput[
Literal["openweathermap"], Literal["api_key"]
Literal[ProviderName.OPENWEATHERMAP], Literal["api_key"]
] = CredentialsField(
provider="openweathermap",
supported_credential_types={"api_key"},
description="The OpenWeatherMap integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)

View File

@@ -4,16 +4,15 @@ from typing import Literal
from pydantic import BaseModel, SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
Slant3DCredentialsInput = CredentialsMetaInput[Literal["slant3d"], Literal["api_key"]]
Slant3DCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.SLANT3D], Literal["api_key"]
]
def Slant3DCredentialsField() -> Slant3DCredentialsInput:
return CredentialsField(
provider="slant3d",
supported_credential_types={"api_key"},
description="Slant3D API key for authentication",
)
return CredentialsField(description="Slant3D API key for authentication")
TEST_CREDENTIALS = APIKeyCredentials(

View File

@@ -10,6 +10,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -29,13 +30,11 @@ TEST_CREDENTIALS_INPUT = {
class CreateTalkingAvatarVideoBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["d_id"], Literal["api_key"]] = (
CredentialsField(
provider="d_id",
supported_credential_types={"api_key"},
description="The D-ID integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.D_ID], Literal["api_key"]
] = CredentialsField(
description="The D-ID integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
script_input: str = SchemaField(
description="The text input for the script",

View File

@@ -9,6 +9,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@@ -38,10 +39,8 @@ class UnrealTextToSpeechBlock(Block):
default="Scarlett",
)
credentials: CredentialsMetaInput[
Literal["unreal_speech"], Literal["api_key"]
Literal[ProviderName.UNREAL_SPEECH], Literal["api_key"]
] = CredentialsField(
provider="unreal_speech",
supported_credential_types={"api_key"},
description="The Unreal Speech integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)

View File

@@ -65,7 +65,7 @@ class BlockCategory(Enum):
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]] = {}
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
@@ -145,6 +145,10 @@ class BlockSchema(BaseModel):
- A field that is called `credentials` MUST be a `CredentialsMetaInput`.
"""
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
credentials_fields = [
field_name
for field_name, info in cls.model_fields.items()
@@ -176,6 +180,11 @@ class BlockSchema(BaseModel):
f"Field 'credentials' on {cls.__qualname__} "
f"must be of type {CredentialsMetaInput.__name__}"
)
if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME):
credentials_input_type = cast(
CredentialsMetaInput, credentials_field.annotation
)
credentials_input_type.validate_credentials_field_schema(cls)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)

View File

@@ -7,6 +7,7 @@ from pydantic import Field
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from .db import BaseDbModel
@@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
class Webhook(BaseDbModel):
user_id: str
provider: str
provider: ProviderName
credentials_id: str
webhook_type: str
resource: str
@@ -37,7 +38,7 @@ class Webhook(BaseDbModel):
return Webhook(
id=webhook.id,
user_id=webhook.userId,
provider=webhook.provider,
provider=ProviderName(webhook.provider),
credentials_id=webhook.credentialsId,
webhook_type=webhook.webhookType,
resource=webhook.resource,
@@ -61,7 +62,7 @@ async def create_webhook(webhook: Webhook) -> Webhook:
data={
"id": webhook.id,
"userId": webhook.user_id,
"provider": webhook.provider,
"provider": webhook.provider.value,
"credentialsId": webhook.credentials_id,
"webhookType": webhook.webhook_type,
"resource": webhook.resource,

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
@@ -11,19 +12,32 @@ from typing import (
Optional,
TypedDict,
TypeVar,
get_args,
)
from uuid import uuid4
from pydantic import BaseModel, Field, GetCoreSchemaHandler, SecretStr, field_serializer
from pydantic import (
BaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
SecretStr,
field_serializer,
)
from pydantic_core import (
CoreSchema,
PydanticUndefined,
PydanticUndefinedType,
ValidationError,
core_schema,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
if TYPE_CHECKING:
from backend.data.block import BlockSchema
T = TypeVar("T")
logger = logging.getLogger(__name__)
@@ -220,7 +234,7 @@ class UserIntegrations(BaseModel):
oauth_states: list[OAuthState] = Field(default_factory=list)
CP = TypeVar("CP", bound=str)
CP = TypeVar("CP", bound=ProviderName)
CT = TypeVar("CT", bound=CredentialsType)
@@ -233,19 +247,51 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
provider: CP
type: CT
@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = get_args(
cls.model_fields["provider"].annotation
)
schema["credentials_types"] = get_args(cls.model_fields["type"].annotation)
class CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)
@classmethod
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
"""Validates the schema of a `credentials` field"""
field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
raise TypeError(
"Field 'credentials' JSON schema lacks required extra items: "
f"{field_schema}"
) from e
if (
len(schema_extra.credentials_provider) > 1
and not schema_extra.discriminator
):
raise TypeError("Multi-provider CredentialsField requires discriminator!")
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
credentials_provider: list[CP]
credentials_scopes: Optional[list[str]]
credentials_scopes: Optional[list[str]] = None
credentials_types: list[CT]
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
def CredentialsField(
provider: CP | list[CP],
supported_credential_types: set[CT],
required_scopes: set[str] = set(),
*,
discriminator: Optional[str] = None,
@@ -253,26 +299,26 @@ def CredentialsField(
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
) -> CredentialsMetaInput[CP, CT]:
) -> CredentialsMetaInput:
"""
`CredentialsField` must and can only be used on fields named `credentials`.
This is enforced by the `BlockSchema` base class.
"""
if not isinstance(provider, str) and len(provider) > 1 and not discriminator:
raise TypeError("Multi-provider CredentialsField requires discriminator!")
field_schema_extra = CredentialsFieldSchemaExtra[CP, CT](
credentials_provider=[provider] if isinstance(provider, str) else provider,
credentials_scopes=list(required_scopes) or None, # omit if empty
credentials_types=list(supported_credential_types),
discriminator=discriminator,
discriminator_mapping=discriminator_mapping,
)
field_schema_extra = {
k: v
for k, v in {
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
}.items()
if v is not None
}
return Field(
title=title,
description=description,
json_schema_extra=field_schema_extra.model_dump(exclude_none=True),
json_schema_extra=field_schema_extra, # validated on BlockSchema init
**kwargs,
)

View File

@@ -1,6 +1,7 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import TYPE_CHECKING
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock
@@ -8,10 +9,13 @@ from redis.lock import Lock as RedisLock
from backend.data import redis
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
logger = logging.getLogger(__name__)
settings = Settings()
@@ -148,7 +152,7 @@ class IntegrationCredentialsManager:
self.store.locks.release_all_locks()
def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise KeyError(f"Unknown provider '{provider_name}'")

View File

@@ -1,10 +1,15 @@
from .base import BaseOAuthHandler
from typing import TYPE_CHECKING
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
if TYPE_CHECKING:
from ..providers import ProviderName
from .base import BaseOAuthHandler
# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
handler.PROVIDER_NAME: handler
for handler in [
GitHubOAuthHandler,

View File

@@ -4,13 +4,14 @@ from abc import ABC, abstractmethod
from typing import ClassVar
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
class BaseOAuthHandler(ABC):
# --8<-- [start:BaseOAuthHandler1]
PROVIDER_NAME: ClassVar[str]
PROVIDER_NAME: ClassVar[ProviderName]
DEFAULT_SCOPES: ClassVar[list[str]] = []
# --8<-- [end:BaseOAuthHandler1]
@@ -76,6 +77,8 @@ class BaseOAuthHandler(ABC):
"""Handles the default scopes for the provider"""
# If scopes are empty, use the default scopes for the provider
if not scopes:
logger.debug(f"Using default scopes for provider {self.PROVIDER_NAME}")
logger.debug(
f"Using default scopes for provider {self.PROVIDER_NAME.value}"
)
scopes = self.DEFAULT_SCOPES
return scopes

View File

@@ -3,6 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests
from .base import BaseOAuthHandler
@@ -23,7 +24,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
access token *with no refresh token*.
""" # noqa
PROVIDER_NAME = "github"
PROVIDER_NAME = ProviderName.GITHUB
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id

View File

@@ -9,6 +9,7 @@ from google_auth_oauthlib.flow import Flow
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from .base import BaseOAuthHandler
@@ -21,7 +22,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
""" # noqa
PROVIDER_NAME = "google"
PROVIDER_NAME = ProviderName.GOOGLE
EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
DEFAULT_SCOPES = [
"https://www.googleapis.com/auth/userinfo.email",

View File

@@ -2,6 +2,7 @@ from base64 import b64encode
from urllib.parse import urlencode
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests
from .base import BaseOAuthHandler
@@ -16,7 +17,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
- Notion doesn't use scopes
"""
PROVIDER_NAME = "notion"
PROVIDER_NAME = ProviderName.NOTION
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id

View File

@@ -1,7 +1,30 @@
from enum import Enum
# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
ANTHROPIC = "anthropic"
DISCORD = "discord"
D_ID = "d_id"
E2B = "e2b"
EXA = "exa"
FAL = "fal"
GITHUB = "github"
GOOGLE = "google"
GOOGLE_MAPS = "google_maps"
GROQ = "groq"
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"
MEDIUM = "medium"
NOTION = "notion"
OLLAMA = "ollama"
OPENAI = "openai"
OPENWEATHERMAP = "openweathermap"
OPEN_ROUTER = "open_router"
PINECONE = "pinecone"
REPLICATE = "replicate"
REVID = "revid"
SLANT3D = "slant3d"
UNREAL_SPEECH = "unreal_speech"
# --8<-- [end:ProviderName]

View File

@@ -4,10 +4,11 @@ from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
if TYPE_CHECKING:
from ..providers import ProviderName
from .base import BaseWebhooksManager
# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict[str, type["BaseWebhooksManager"]] = {
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
handler.PROVIDER_NAME: handler
for handler in [
GithubWebhooksManager,

View File

@@ -9,6 +9,7 @@ from strenum import StrEnum
from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Config
@@ -20,7 +21,7 @@ WT = TypeVar("WT", bound=StrEnum)
class BaseWebhooksManager(ABC, Generic[WT]):
# --8<-- [start:BaseWebhooksManager1]
PROVIDER_NAME: ClassVar[str]
PROVIDER_NAME: ClassVar[ProviderName]
# --8<-- [end:BaseWebhooksManager1]
WebhookType: WT
@@ -143,7 +144,7 @@ class BaseWebhooksManager(ABC, Generic[WT]):
secret = secrets.token_hex(32)
provider_name = self.PROVIDER_NAME
ingress_url = (
f"{app_config.platform_base_url}/api/integrations/{provider_name}"
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
f"/webhooks/{id}/ingress"
)
provider_webhook_id, config = await self._register_webhook(

View File

@@ -8,6 +8,7 @@ from strenum import StrEnum
from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from .base import BaseWebhooksManager
@@ -20,7 +21,7 @@ class GithubWebhookType(StrEnum):
class GithubWebhooksManager(BaseWebhooksManager):
PROVIDER_NAME = "github"
PROVIDER_NAME = ProviderName.GITHUB
WebhookType = GithubWebhookType

View File

@@ -95,11 +95,18 @@ async def on_node_activate(
if not block.webhook_config:
return node
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
logger.debug(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
try:
resource = block.webhook_config.resource_format.format(**node.input_default)
@@ -167,7 +174,14 @@ async def on_node_deactivate(
if not block.webhook_config:
return node
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
if node.webhook_id:
logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")
@@ -189,7 +203,7 @@ async def on_node_deactivate(
logger.warning(
f"Cannot deregister webhook #{webhook.id}: credentials "
f"#{webhook.credentials_id} not available "
f"({webhook.provider} webhook ID: {webhook.provider_webhook_id})"
f"({webhook.provider.value} webhook ID: {webhook.provider_webhook_id})"
)
return updated_node

View File

@@ -1,11 +1,11 @@
import logging
from typing import ClassVar
import requests
from fastapi import Request
from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.base import BaseWebhooksManager
logger = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class Slant3DWebhooksManager(BaseWebhooksManager):
"""Manager for Slant3D webhooks"""
PROVIDER_NAME: ClassVar[str] = "slant3d"
PROVIDER_NAME = ProviderName.SLANT3D
BASE_URL = "https://www.slant3dapi.com/api"
async def _register_webhook(

View File

@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr
@@ -20,12 +20,16 @@ from backend.data.model import (
)
from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.util.exceptions import NeedConfirmation
from backend.util.service import get_service_client
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
from ..utils import get_user_id
logger = logging.getLogger(__name__)
@@ -42,7 +46,9 @@ class LoginResponse(BaseModel):
@router.get("/{provider}/login")
def login(
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
provider: Annotated[
ProviderName, Path(title="The provider to initiate an OAuth flow for")
],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
scopes: Annotated[
@@ -74,7 +80,9 @@ class CredentialsMetaResponse(BaseModel):
@router.post("/{provider}/callback")
def callback(
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
provider: Annotated[
ProviderName, Path(title="The target provider for this OAuth exchange")
],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
user_id: Annotated[str, Depends(get_user_id)],
@@ -103,11 +111,12 @@ def callback(
if not set(scopes).issubset(set(credentials.scopes)):
# For now, we'll just log the warning and continue
logger.warning(
f"Granted scopes {credentials.scopes} for {provider}do not include all requested scopes {scopes}"
f"Granted scopes {credentials.scopes} for provider {provider.value} "
f"do not include all requested scopes {scopes}"
)
except Exception as e:
logger.error(f"Code->Token exchange failed for provider {provider}: {e}")
logger.error(f"Code->Token exchange failed for provider {provider.value}: {e}")
raise HTTPException(
status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}"
)
@@ -116,7 +125,8 @@ def callback(
creds_manager.create(user_id, credentials)
logger.debug(
f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
f"Successfully processed OAuth callback for user {user_id} "
f"and provider {provider.value}"
)
return CredentialsMetaResponse(
id=credentials.id,
@@ -148,7 +158,9 @@ def list_credentials(
@router.get("/{provider}/credentials")
def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to list credentials for")
],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
@@ -167,7 +179,9 @@ def list_credentials_by_provider(
@router.get("/{provider}/credentials/{cred_id}")
def get_credential(
provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to retrieve credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Depends(get_user_id)],
) -> Credentials:
@@ -184,7 +198,9 @@ def get_credential(
@router.post("/{provider}/credentials", status_code=201)
def create_api_key_credentials(
user_id: Annotated[str, Depends(get_user_id)],
provider: Annotated[str, Path(title="The provider to create credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to create credentials for")
],
api_key: Annotated[str, Body(title="The API key to store")],
title: Annotated[str, Body(title="Optional title for the credentials")],
expires_at: Annotated[
@@ -225,7 +241,9 @@ class CredentialsDeletionNeedsConfirmationResponse(BaseModel):
@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
request: Request,
provider: Annotated[str, Path(title="The provider to delete credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to delete credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
user_id: Annotated[str, Depends(get_user_id)],
force: Annotated[
@@ -264,15 +282,20 @@ async def delete_credentials(
@router.post("/{provider}/webhooks/{webhook_id}/ingress")
async def webhook_ingress_generic(
request: Request,
provider: Annotated[str, Path(title="Provider where the webhook was registered")],
provider: Annotated[
ProviderName, Path(title="Provider where the webhook was registered")
],
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
logger.debug(f"Received {provider} webhook ingress for ID {webhook_id}")
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook = await get_webhook(webhook_id)
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
logger.debug(f"Validated {provider} {event_type} event with payload {payload}")
logger.debug(
f"Validated {provider.value} {webhook.webhook_type} {event_type} event "
f"with payload {payload}"
)
webhook_event = WebhookEvent(
provider=provider,
@@ -341,6 +364,14 @@ async def remove_all_webhooks_for_credentials(
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks(credentials.id)
if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
if webhooks:
logger.error(
f"Credentials #{credentials.id} for provider {credentials.provider} "
f"are attached to {len(webhooks)} webhooks, "
f"but there is no available WebhooksHandler for {credentials.provider}"
)
return
if any(w.attached_nodes for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
@@ -359,18 +390,23 @@ async def remove_all_webhooks_for_credentials(
logger.warning(f"Webhook #{webhook.id} failed to prune")
def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(
req: Request, provider_name: ProviderName
) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=404, detail=f"Unknown provider '{provider_name}'"
status_code=404,
detail=f"Provider '{provider_name.value}' does not support OAuth",
)
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
raise HTTPException(
status_code=501,
detail=f"Integration with provider '{provider_name}' is not configured",
detail=(
f"Integration with provider '{provider_name.value}' is not configured"
),
)
handler_class = HANDLERS_BY_NAME[provider_name]

View File

@@ -121,6 +121,7 @@ from backend.data.model import (
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import CredentialsField
from backend.integrations.providers import ProviderName
# API Key auth:
@@ -128,9 +129,9 @@ class BlockWithAPIKeyAuth(Block):
class Input(BlockSchema):
# Note that the type hint below is require or you will get a type error.
# The first argument is the provider name, the second is the credential type.
credentials: CredentialsMetaInput[Literal['github'], Literal['api_key']] = CredentialsField(
provider="github",
supported_credential_types={"api_key"},
credentials: CredentialsMetaInput[
Literal[ProviderName.GITHUB], Literal["api_key"]
] = CredentialsField(
description="The GitHub integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
@@ -151,9 +152,9 @@ class BlockWithOAuth(Block):
class Input(BlockSchema):
# Note that the type hint below is require or you will get a type error.
# The first argument is the provider name, the second is the credential type.
credentials: CredentialsMetaInput[Literal['github'], Literal['oauth2']] = CredentialsField(
provider="github",
supported_credential_types={"oauth2"},
credentials: CredentialsMetaInput[
Literal[ProviderName.GITHUB], Literal["oauth2"]
] = CredentialsField(
required_scopes={"repo"},
description="The GitHub integration can be used with OAuth.",
)
@@ -174,9 +175,9 @@ class BlockWithAPIKeyAndOAuth(Block):
class Input(BlockSchema):
# Note that the type hint below is require or you will get a type error.
# The first argument is the provider name, the second is the credential type.
credentials: CredentialsMetaInput[Literal['github'], Literal['api_key', 'oauth2']] = CredentialsField(
provider="github",
supported_credential_types={"api_key", "oauth2"},
credentials: CredentialsMetaInput[
Literal[ProviderName.GITHUB], Literal["api_key", "oauth2"]
] = CredentialsField(
required_scopes={"repo"},
description="The GitHub integration can be used with OAuth, "
"or any API key with sufficient permissions for the blocks it is used on.",
@@ -227,6 +228,16 @@ response = requests.post(
)
```
The `ProviderName` enum is the single source of truth for which providers exist in our system.
Naturally, to add an authenticated block for a new provider, you'll have to add it here too.
<details>
<summary><code>ProviderName</code> definition</summary>
```python title="backend/integrations/providers.py"
--8<-- "autogpt_platform/backend/backend/integrations/providers.py:ProviderName"
```
</details>
#### Adding an OAuth2 service integration
To add support for a new OAuth2-authenticated service, you'll need to add an `OAuthHandler`.