From 33b9eef37663c659f7411f2d26e31a761bdbb856 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Wed, 11 Dec 2024 20:27:09 +0100 Subject: [PATCH] refactor(backend): Simplify `CredentialsField` usage + use `ProviderName` globally (#8725) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Resolves #8931 - Follow-up to #8358 ### Changes 🏗️ - Avoid double specifying provider and cred types on `credentials` inputs - Move `credentials` sub-schema validation from `CredentialsField` to `CredentialsMetaInput.validate_credentials_field_schema(..)`, which is called in `BlockSchema.__pydantic_init_subclass__` - Use `ProviderName` enum globally --- .../blocks/ai_image_generator_block.py | 11 ++- .../backend/blocks/ai_music_generator.py | 13 ++- .../blocks/ai_shortform_video_block.py | 13 ++- .../backend/backend/blocks/code_executor.py | 11 ++- .../backend/backend/blocks/discord.py | 11 ++- .../backend/backend/blocks/exa/_auth.py | 9 +-- .../backend/backend/blocks/fal/_auth.py | 5 +- .../backend/backend/blocks/github/_auth.py | 7 +- .../backend/backend/blocks/google/_auth.py | 7 +- .../backend/backend/blocks/google_maps.py | 9 +-- .../backend/backend/blocks/hubspot/_auth.py | 5 +- .../backend/backend/blocks/ideogram.py | 12 ++- .../backend/backend/blocks/jina/_auth.py | 19 +---- .../backend/backend/blocks/llm.py | 12 ++- .../backend/backend/blocks/medium.py | 11 ++- .../backend/backend/blocks/pinecone.py | 10 +-- .../backend/blocks/replicate_flux_advanced.py | 13 ++- .../backend/backend/blocks/search.py | 5 +- .../backend/backend/blocks/slant3d/_api.py | 11 ++- .../backend/backend/blocks/talking_head.py | 13 ++- .../backend/blocks/text_to_speech_block.py | 5 +- .../backend/backend/data/block.py | 11 ++- .../backend/backend/data/integrations.py | 7 +- .../backend/backend/data/model.py | 80 +++++++++++++++---- .../backend/integrations/creds_manager.py | 8 +- .../backend/integrations/oauth/__init__.py | 9 ++- .../backend/integrations/oauth/base.py | 7 +- .../backend/integrations/oauth/github.py | 3 +- .../backend/integrations/oauth/google.py | 3 +- .../backend/integrations/oauth/notion.py | 3 +- .../backend/backend/integrations/providers.py | 23 ++++++ .../backend/integrations/webhooks/__init__.py | 3 +- .../backend/integrations/webhooks/base.py | 5 +- .../backend/integrations/webhooks/github.py | 3 +- .../webhooks/graph_lifecycle_hooks.py | 20 ++++- .../backend/integrations/webhooks/slant3d.py | 4 +- .../backend/server/integrations/router.py | 74 ++++++++++++----- docs/content/platform/new_blocks.md | 29 ++++--- 38 files changed, 313 insertions(+), 191 deletions(-) diff --git a/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py b/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py index ee94f016aa..ebd79dda9a 100644 --- a/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py +++ b/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py @@ -12,6 +12,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName class ImageSize(str, Enum): @@ -101,12 +102,10 @@ class ImageGenModel(str, Enum): class AIImageGeneratorBlock(Block): class Input(BlockSchema): - credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = ( - CredentialsField( - provider="replicate", - supported_credential_types={"api_key"}, - description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.", - ) + credentials: CredentialsMetaInput[ + Literal[ProviderName.REPLICATE], Literal["api_key"] + ] = CredentialsField( + description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.", ) prompt: str = SchemaField( description="Text prompt for image generation", diff --git a/autogpt_platform/backend/backend/blocks/ai_music_generator.py b/autogpt_platform/backend/backend/blocks/ai_music_generator.py index f70d43ce37..7082035108 100644 --- a/autogpt_platform/backend/backend/blocks/ai_music_generator.py +++ b/autogpt_platform/backend/backend/blocks/ai_music_generator.py @@ -13,6 +13,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName logger = logging.getLogger(__name__) @@ -54,13 +55,11 @@ class NormalizationStrategy(str, Enum): class AIMusicGeneratorBlock(Block): class Input(BlockSchema): - credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = ( - CredentialsField( - provider="replicate", - supported_credential_types={"api_key"}, - description="The Replicate integration can be used with " - "any API key with sufficient permissions for the blocks it is used on.", - ) + credentials: CredentialsMetaInput[ + Literal[ProviderName.REPLICATE], Literal["api_key"] + ] = CredentialsField( + description="The Replicate integration can be used with " + "any API key with sufficient permissions for the blocks it is used on.", ) prompt: str = SchemaField( description="A description of the music you want to generate", diff --git a/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py b/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py index 08023b8771..df2b3a2726 100644 --- a/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py +++ b/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py @@ -12,6 +12,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName from backend.util.request import requests TEST_CREDENTIALS = APIKeyCredentials( @@ -140,13 +141,11 @@ logger = logging.getLogger(__name__) class AIShortformVideoCreatorBlock(Block): class Input(BlockSchema): - credentials: CredentialsMetaInput[Literal["revid"], Literal["api_key"]] = ( - CredentialsField( - provider="revid", - supported_credential_types={"api_key"}, - description="The revid.ai integration can be used with " - "any API key with sufficient permissions for the blocks it is used on.", - ) + credentials: CredentialsMetaInput[ + Literal[ProviderName.REVID], Literal["api_key"] + ] = CredentialsField( + description="The revid.ai integration can be used with " + "any API key with sufficient permissions for the blocks it is used on.", ) script: str = SchemaField( description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""", diff --git a/autogpt_platform/backend/backend/blocks/code_executor.py b/autogpt_platform/backend/backend/blocks/code_executor.py index 3f513b30ca..4c92f6fb42 100644 --- a/autogpt_platform/backend/backend/blocks/code_executor.py +++ b/autogpt_platform/backend/backend/blocks/code_executor.py @@ -11,6 +11,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName TEST_CREDENTIALS = APIKeyCredentials( id="01234567-89ab-cdef-0123-456789abcdef", @@ -39,12 +40,10 @@ class CodeExecutionBlock(Block): # TODO : Add support to upload and download files # Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template class Input(BlockSchema): - credentials: CredentialsMetaInput[Literal["e2b"], Literal["api_key"]] = ( - CredentialsField( - provider="e2b", - supported_credential_types={"api_key"}, - description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs", - ) + credentials: CredentialsMetaInput[ + Literal[ProviderName.E2B], Literal["api_key"] + ] = CredentialsField( + description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs", ) # Todo : Option to run commond in background diff --git a/autogpt_platform/backend/backend/blocks/discord.py b/autogpt_platform/backend/backend/blocks/discord.py index c638e40250..08ba8af074 100644 --- a/autogpt_platform/backend/backend/blocks/discord.py +++ b/autogpt_platform/backend/backend/blocks/discord.py @@ -12,16 +12,15 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName -DiscordCredentials = CredentialsMetaInput[Literal["discord"], Literal["api_key"]] +DiscordCredentials = CredentialsMetaInput[ + Literal[ProviderName.DISCORD], Literal["api_key"] +] def DiscordCredentialsField() -> DiscordCredentials: - return CredentialsField( - description="Discord bot token", - provider="discord", - supported_credential_types={"api_key"}, - ) + return CredentialsField(description="Discord bot token") TEST_CREDENTIALS = APIKeyCredentials( diff --git a/autogpt_platform/backend/backend/blocks/exa/_auth.py b/autogpt_platform/backend/backend/blocks/exa/_auth.py index 412143a5c2..7b826ef408 100644 --- a/autogpt_platform/backend/backend/blocks/exa/_auth.py +++ b/autogpt_platform/backend/backend/blocks/exa/_auth.py @@ -3,10 +3,11 @@ from typing import Literal from pydantic import SecretStr from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput +from backend.integrations.providers import ProviderName ExaCredentials = APIKeyCredentials ExaCredentialsInput = CredentialsMetaInput[ - Literal["exa"], + Literal[ProviderName.EXA], Literal["api_key"], ] @@ -28,8 +29,4 @@ TEST_CREDENTIALS_INPUT = { def ExaCredentialsField() -> ExaCredentialsInput: """Creates an Exa credentials input on a block.""" - return CredentialsField( - provider="exa", - supported_credential_types={"api_key"}, - description="The Exa integration requires an API Key.", - ) + return CredentialsField(description="The Exa integration requires an API Key.") diff --git a/autogpt_platform/backend/backend/blocks/fal/_auth.py b/autogpt_platform/backend/backend/blocks/fal/_auth.py index 3271ee0950..5d02186e57 100644 --- a/autogpt_platform/backend/backend/blocks/fal/_auth.py +++ b/autogpt_platform/backend/backend/blocks/fal/_auth.py @@ -3,10 +3,11 @@ from typing import Literal from pydantic import SecretStr from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput +from backend.integrations.providers import ProviderName FalCredentials = APIKeyCredentials FalCredentialsInput = CredentialsMetaInput[ - Literal["fal"], + Literal[ProviderName.FAL], Literal["api_key"], ] @@ -30,7 +31,5 @@ def FalCredentialsField() -> FalCredentialsInput: Creates a FAL credentials input on a block. """ return CredentialsField( - provider="fal", - supported_credential_types={"api_key"}, description="The FAL integration can be used with an API Key.", ) diff --git a/autogpt_platform/backend/backend/blocks/github/_auth.py b/autogpt_platform/backend/backend/blocks/github/_auth.py index 72aa8f6480..df7eed90f7 100644 --- a/autogpt_platform/backend/backend/blocks/github/_auth.py +++ b/autogpt_platform/backend/backend/blocks/github/_auth.py @@ -8,6 +8,7 @@ from backend.data.model import ( CredentialsMetaInput, OAuth2Credentials, ) +from backend.integrations.providers import ProviderName from backend.util.settings import Secrets secrets = Secrets() @@ -17,7 +18,7 @@ GITHUB_OAUTH_IS_CONFIGURED = bool( GithubCredentials = APIKeyCredentials | OAuth2Credentials GithubCredentialsInput = CredentialsMetaInput[ - Literal["github"], + Literal[ProviderName.GITHUB], Literal["api_key", "oauth2"] if GITHUB_OAUTH_IS_CONFIGURED else Literal["api_key"], ] @@ -30,10 +31,6 @@ def GithubCredentialsField(scope: str) -> GithubCredentialsInput: scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes)) """ # noqa return CredentialsField( - provider="github", - supported_credential_types=( - {"api_key", "oauth2"} if GITHUB_OAUTH_IS_CONFIGURED else {"api_key"} - ), required_scopes={scope}, description="The GitHub integration can be used with OAuth, " "or any API key with sufficient permissions for the blocks it is used on.", diff --git a/autogpt_platform/backend/backend/blocks/google/_auth.py b/autogpt_platform/backend/backend/blocks/google/_auth.py index ccae2e4622..2b364dbd40 100644 --- a/autogpt_platform/backend/backend/blocks/google/_auth.py +++ b/autogpt_platform/backend/backend/blocks/google/_auth.py @@ -3,6 +3,7 @@ from typing import Literal from pydantic import SecretStr from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials +from backend.integrations.providers import ProviderName from backend.util.settings import Secrets # --8<-- [start:GoogleOAuthIsConfigured] @@ -12,7 +13,9 @@ GOOGLE_OAUTH_IS_CONFIGURED = bool( ) # --8<-- [end:GoogleOAuthIsConfigured] GoogleCredentials = OAuth2Credentials -GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]] +GoogleCredentialsInput = CredentialsMetaInput[ + Literal[ProviderName.GOOGLE], Literal["oauth2"] +] def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput: @@ -23,8 +26,6 @@ def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput: scopes: The authorization scopes needed for the block to work. """ return CredentialsField( - provider="google", - supported_credential_types={"oauth2"}, required_scopes=set(scopes), description="The Google integration requires OAuth2 authentication.", ) diff --git a/autogpt_platform/backend/backend/blocks/google_maps.py b/autogpt_platform/backend/backend/blocks/google_maps.py index d211fe8ff3..9e7f793531 100644 --- a/autogpt_platform/backend/backend/blocks/google_maps.py +++ b/autogpt_platform/backend/backend/blocks/google_maps.py @@ -10,6 +10,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName TEST_CREDENTIALS = APIKeyCredentials( id="01234567-89ab-cdef-0123-456789abcdef", @@ -38,12 +39,8 @@ class Place(BaseModel): class GoogleMapsSearchBlock(Block): class Input(BlockSchema): credentials: CredentialsMetaInput[ - Literal["google_maps"], Literal["api_key"] - ] = CredentialsField( - provider="google_maps", - supported_credential_types={"api_key"}, - description="Google Maps API Key", - ) + Literal[ProviderName.GOOGLE_MAPS], Literal["api_key"] + ] = CredentialsField(description="Google Maps API Key") query: str = SchemaField( description="Search query for local businesses", placeholder="e.g., 'restaurants in New York'", diff --git a/autogpt_platform/backend/backend/blocks/hubspot/_auth.py b/autogpt_platform/backend/backend/blocks/hubspot/_auth.py index c32af8c38c..b32456d5d5 100644 --- a/autogpt_platform/backend/backend/blocks/hubspot/_auth.py +++ b/autogpt_platform/backend/backend/blocks/hubspot/_auth.py @@ -3,10 +3,11 @@ from typing import Literal from pydantic import SecretStr from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput +from backend.integrations.providers import ProviderName HubSpotCredentials = APIKeyCredentials HubSpotCredentialsInput = CredentialsMetaInput[ - Literal["hubspot"], + Literal[ProviderName.HUBSPOT], Literal["api_key"], ] @@ -14,8 +15,6 @@ HubSpotCredentialsInput = CredentialsMetaInput[ def HubSpotCredentialsField() -> HubSpotCredentialsInput: """Creates a HubSpot credentials input on a block.""" return CredentialsField( - provider="hubspot", - supported_credential_types={"api_key"}, description="The HubSpot integration requires an API Key.", ) diff --git a/autogpt_platform/backend/backend/blocks/ideogram.py b/autogpt_platform/backend/backend/blocks/ideogram.py index b6a21e7ace..82eb91238b 100644 --- a/autogpt_platform/backend/backend/blocks/ideogram.py +++ b/autogpt_platform/backend/backend/blocks/ideogram.py @@ -11,6 +11,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName from backend.util.request import requests TEST_CREDENTIALS = APIKeyCredentials( @@ -83,13 +84,10 @@ class UpscaleOption(str, Enum): class IdeogramModelBlock(Block): class Input(BlockSchema): - - credentials: CredentialsMetaInput[Literal["ideogram"], Literal["api_key"]] = ( - CredentialsField( - provider="ideogram", - supported_credential_types={"api_key"}, - description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.", - ) + credentials: CredentialsMetaInput[ + Literal[ProviderName.IDEOGRAM], Literal["api_key"] + ] = CredentialsField( + description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.", ) prompt: str = SchemaField( description="Text prompt for image generation", diff --git a/autogpt_platform/backend/backend/blocks/jina/_auth.py b/autogpt_platform/backend/backend/blocks/jina/_auth.py index 2bffeecce7..5bf0ddd5cf 100644 --- a/autogpt_platform/backend/backend/blocks/jina/_auth.py +++ b/autogpt_platform/backend/backend/blocks/jina/_auth.py @@ -3,27 +3,14 @@ from typing import Literal from pydantic import SecretStr from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput +from backend.integrations.providers import ProviderName JinaCredentials = APIKeyCredentials JinaCredentialsInput = CredentialsMetaInput[ - Literal["jina"], + Literal[ProviderName.JINA], Literal["api_key"], ] -TEST_CREDENTIALS = APIKeyCredentials( - id="01234567-89ab-cdef-0123-456789abcdef", - provider="jina", - api_key=SecretStr("mock-jina-api-key"), - title="Mock Jina API key", - expires_at=None, -) -TEST_CREDENTIALS_INPUT = { - "provider": TEST_CREDENTIALS.provider, - "id": TEST_CREDENTIALS.id, - "type": TEST_CREDENTIALS.type, - "title": TEST_CREDENTIALS.type, -} - def JinaCredentialsField() -> JinaCredentialsInput: """ @@ -31,8 +18,6 @@ def JinaCredentialsField() -> JinaCredentialsInput: """ return CredentialsField( - provider="jina", - supported_credential_types={"api_key"}, description="The Jina integration can be used with an API Key.", ) diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index 8390164463..e913e88e6d 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple from pydantic import SecretStr +from backend.integrations.providers import ProviderName + if TYPE_CHECKING: from enum import _EnumMemberT @@ -27,7 +29,13 @@ from backend.util.settings import BehaveAs, Settings logger = logging.getLogger(__name__) -LLMProviderName = Literal["anthropic", "groq", "openai", "ollama", "open_router"] +LLMProviderName = Literal[ + ProviderName.ANTHROPIC, + ProviderName.GROQ, + ProviderName.OLLAMA, + ProviderName.OPENAI, + ProviderName.OPEN_ROUTER, +] AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]] TEST_CREDENTIALS = APIKeyCredentials( @@ -48,8 +56,6 @@ TEST_CREDENTIALS_INPUT = { def AICredentialsField() -> AICredentials: return CredentialsField( description="API key for the LLM provider.", - provider=["anthropic", "groq", "openai", "ollama", "open_router"], - supported_credential_types={"api_key"}, discriminator="model", discriminator_mapping={ model.value: model.metadata.provider for model in LlmModel diff --git a/autogpt_platform/backend/backend/blocks/medium.py b/autogpt_platform/backend/backend/blocks/medium.py index da8c367c78..6d871b4caa 100644 --- a/autogpt_platform/backend/backend/blocks/medium.py +++ b/autogpt_platform/backend/backend/blocks/medium.py @@ -12,6 +12,7 @@ from backend.data.model import ( SchemaField, SecretField, ) +from backend.integrations.providers import ProviderName from backend.util.request import requests TEST_CREDENTIALS = APIKeyCredentials( @@ -77,12 +78,10 @@ class PublishToMediumBlock(Block): description="Whether to notify followers that the user has published", placeholder="False", ) - credentials: CredentialsMetaInput[Literal["medium"], Literal["api_key"]] = ( - CredentialsField( - provider="medium", - supported_credential_types={"api_key"}, - description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.", - ) + credentials: CredentialsMetaInput[ + Literal[ProviderName.MEDIUM], Literal["api_key"] + ] = CredentialsField( + description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.", ) class Output(BlockSchema): diff --git a/autogpt_platform/backend/backend/blocks/pinecone.py b/autogpt_platform/backend/backend/blocks/pinecone.py index a62c1fa777..f1b79f4431 100644 --- a/autogpt_platform/backend/backend/blocks/pinecone.py +++ b/autogpt_platform/backend/backend/blocks/pinecone.py @@ -10,22 +10,18 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName PineconeCredentials = APIKeyCredentials PineconeCredentialsInput = CredentialsMetaInput[ - Literal["pinecone"], + Literal[ProviderName.PINECONE], Literal["api_key"], ] def PineconeCredentialsField() -> PineconeCredentialsInput: - """ - Creates a Pinecone credentials input on a block. - - """ + """Creates a Pinecone credentials input on a block.""" return CredentialsField( - provider="pinecone", - supported_credential_types={"api_key"}, description="The Pinecone integration can be used with an API Key.", ) diff --git a/autogpt_platform/backend/backend/blocks/replicate_flux_advanced.py b/autogpt_platform/backend/backend/blocks/replicate_flux_advanced.py index b346c87a38..c913ead132 100644 --- a/autogpt_platform/backend/backend/blocks/replicate_flux_advanced.py +++ b/autogpt_platform/backend/backend/blocks/replicate_flux_advanced.py @@ -13,6 +13,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName TEST_CREDENTIALS = APIKeyCredentials( id="01234567-89ab-cdef-0123-456789abcdef", @@ -54,13 +55,11 @@ class ImageType(str, Enum): class ReplicateFluxAdvancedModelBlock(Block): class Input(BlockSchema): - credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = ( - CredentialsField( - provider="replicate", - supported_credential_types={"api_key"}, - description="The Replicate integration can be used with " - "any API key with sufficient permissions for the blocks it is used on.", - ) + credentials: CredentialsMetaInput[ + Literal[ProviderName.REPLICATE], Literal["api_key"] + ] = CredentialsField( + description="The Replicate integration can be used with " + "any API key with sufficient permissions for the blocks it is used on.", ) prompt: str = SchemaField( description="Text prompt for image generation", diff --git a/autogpt_platform/backend/backend/blocks/search.py b/autogpt_platform/backend/backend/blocks/search.py index f89232703a..633ad31091 100644 --- a/autogpt_platform/backend/backend/blocks/search.py +++ b/autogpt_platform/backend/backend/blocks/search.py @@ -11,6 +11,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName class GetWikipediaSummaryBlock(Block, GetRequest): @@ -65,10 +66,8 @@ class GetWeatherInformationBlock(Block, GetRequest): description="Location to get weather information for" ) credentials: CredentialsMetaInput[ - Literal["openweathermap"], Literal["api_key"] + Literal[ProviderName.OPENWEATHERMAP], Literal["api_key"] ] = CredentialsField( - provider="openweathermap", - supported_credential_types={"api_key"}, description="The OpenWeatherMap integration can be used with " "any API key with sufficient permissions for the blocks it is used on.", ) diff --git a/autogpt_platform/backend/backend/blocks/slant3d/_api.py b/autogpt_platform/backend/backend/blocks/slant3d/_api.py index f8f3211047..d952662e78 100644 --- a/autogpt_platform/backend/backend/blocks/slant3d/_api.py +++ b/autogpt_platform/backend/backend/blocks/slant3d/_api.py @@ -4,16 +4,15 @@ from typing import Literal from pydantic import BaseModel, SecretStr from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput +from backend.integrations.providers import ProviderName -Slant3DCredentialsInput = CredentialsMetaInput[Literal["slant3d"], Literal["api_key"]] +Slant3DCredentialsInput = CredentialsMetaInput[ + Literal[ProviderName.SLANT3D], Literal["api_key"] +] def Slant3DCredentialsField() -> Slant3DCredentialsInput: - return CredentialsField( - provider="slant3d", - supported_credential_types={"api_key"}, - description="Slant3D API key for authentication", - ) + return CredentialsField(description="Slant3D API key for authentication") TEST_CREDENTIALS = APIKeyCredentials( diff --git a/autogpt_platform/backend/backend/blocks/talking_head.py b/autogpt_platform/backend/backend/blocks/talking_head.py index 20aadcd214..e1965dbc64 100644 --- a/autogpt_platform/backend/backend/blocks/talking_head.py +++ b/autogpt_platform/backend/backend/blocks/talking_head.py @@ -10,6 +10,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName from backend.util.request import requests TEST_CREDENTIALS = APIKeyCredentials( @@ -29,13 +30,11 @@ TEST_CREDENTIALS_INPUT = { class CreateTalkingAvatarVideoBlock(Block): class Input(BlockSchema): - credentials: CredentialsMetaInput[Literal["d_id"], Literal["api_key"]] = ( - CredentialsField( - provider="d_id", - supported_credential_types={"api_key"}, - description="The D-ID integration can be used with " - "any API key with sufficient permissions for the blocks it is used on.", - ) + credentials: CredentialsMetaInput[ + Literal[ProviderName.D_ID], Literal["api_key"] + ] = CredentialsField( + description="The D-ID integration can be used with " + "any API key with sufficient permissions for the blocks it is used on.", ) script_input: str = SchemaField( description="The text input for the script", diff --git a/autogpt_platform/backend/backend/blocks/text_to_speech_block.py b/autogpt_platform/backend/backend/blocks/text_to_speech_block.py index b92cc9fa44..dc9e0bf26e 100644 --- a/autogpt_platform/backend/backend/blocks/text_to_speech_block.py +++ b/autogpt_platform/backend/backend/blocks/text_to_speech_block.py @@ -9,6 +9,7 @@ from backend.data.model import ( CredentialsMetaInput, SchemaField, ) +from backend.integrations.providers import ProviderName from backend.util.request import requests TEST_CREDENTIALS = APIKeyCredentials( @@ -38,10 +39,8 @@ class UnrealTextToSpeechBlock(Block): default="Scarlett", ) credentials: CredentialsMetaInput[ - Literal["unreal_speech"], Literal["api_key"] + Literal[ProviderName.UNREAL_SPEECH], Literal["api_key"] ] = CredentialsField( - provider="unreal_speech", - supported_credential_types={"api_key"}, description="The Unreal Speech integration can be used with " "any API key with sufficient permissions for the blocks it is used on.", ) diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index e1035e84fe..096195e2c9 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -65,7 +65,7 @@ class BlockCategory(Enum): class BlockSchema(BaseModel): - cached_jsonschema: ClassVar[dict[str, Any]] = {} + cached_jsonschema: ClassVar[dict[str, Any]] @classmethod def jsonschema(cls) -> dict[str, Any]: @@ -145,6 +145,10 @@ class BlockSchema(BaseModel): - A field that is called `credentials` MUST be a `CredentialsMetaInput`. """ super().__pydantic_init_subclass__(**kwargs) + + # Reset cached JSON schema to prevent inheriting it from parent class + cls.cached_jsonschema = {} + credentials_fields = [ field_name for field_name, info in cls.model_fields.items() @@ -176,6 +180,11 @@ class BlockSchema(BaseModel): f"Field 'credentials' on {cls.__qualname__} " f"must be of type {CredentialsMetaInput.__name__}" ) + if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME): + credentials_input_type = cast( + CredentialsMetaInput, credentials_field.annotation + ) + credentials_input_type.validate_credentials_field_schema(cls) BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema) diff --git a/autogpt_platform/backend/backend/data/integrations.py b/autogpt_platform/backend/backend/data/integrations.py index 4d5197f693..4f444fb035 100644 --- a/autogpt_platform/backend/backend/data/integrations.py +++ b/autogpt_platform/backend/backend/data/integrations.py @@ -7,6 +7,7 @@ from pydantic import Field from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE from backend.data.queue import AsyncRedisEventBus +from backend.integrations.providers import ProviderName from .db import BaseDbModel @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) class Webhook(BaseDbModel): user_id: str - provider: str + provider: ProviderName credentials_id: str webhook_type: str resource: str @@ -37,7 +38,7 @@ class Webhook(BaseDbModel): return Webhook( id=webhook.id, user_id=webhook.userId, - provider=webhook.provider, + provider=ProviderName(webhook.provider), credentials_id=webhook.credentialsId, webhook_type=webhook.webhookType, resource=webhook.resource, @@ -61,7 +62,7 @@ async def create_webhook(webhook: Webhook) -> Webhook: data={ "id": webhook.id, "userId": webhook.user_id, - "provider": webhook.provider, + "provider": webhook.provider.value, "credentialsId": webhook.credentials_id, "webhookType": webhook.webhook_type, "resource": webhook.resource, diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index 5940bf7ecc..9e79fd7da3 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging from typing import ( + TYPE_CHECKING, Annotated, Any, Callable, @@ -11,19 +12,32 @@ from typing import ( Optional, TypedDict, TypeVar, + get_args, ) from uuid import uuid4 -from pydantic import BaseModel, Field, GetCoreSchemaHandler, SecretStr, field_serializer +from pydantic import ( + BaseModel, + ConfigDict, + Field, + GetCoreSchemaHandler, + SecretStr, + field_serializer, +) from pydantic_core import ( CoreSchema, PydanticUndefined, PydanticUndefinedType, + ValidationError, core_schema, ) +from backend.integrations.providers import ProviderName from backend.util.settings import Secrets +if TYPE_CHECKING: + from backend.data.block import BlockSchema + T = TypeVar("T") logger = logging.getLogger(__name__) @@ -220,7 +234,7 @@ class UserIntegrations(BaseModel): oauth_states: list[OAuthState] = Field(default_factory=list) -CP = TypeVar("CP", bound=str) +CP = TypeVar("CP", bound=ProviderName) CT = TypeVar("CT", bound=CredentialsType) @@ -233,19 +247,51 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): provider: CP type: CT + @staticmethod + def _add_json_schema_extra(schema, cls: CredentialsMetaInput): + schema["credentials_provider"] = get_args( + cls.model_fields["provider"].annotation + ) + schema["credentials_types"] = get_args(cls.model_fields["type"].annotation) -class CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]): + model_config = ConfigDict( + json_schema_extra=_add_json_schema_extra, # type: ignore + ) + + @classmethod + def validate_credentials_field_schema(cls, model: type["BlockSchema"]): + """Validates the schema of a `credentials` field""" + field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME] + try: + schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate( + field_schema + ) + except ValidationError as e: + if "Field required [type=missing" not in str(e): + raise + + raise TypeError( + "Field 'credentials' JSON schema lacks required extra items: " + f"{field_schema}" + ) from e + + if ( + len(schema_extra.credentials_provider) > 1 + and not schema_extra.discriminator + ): + raise TypeError("Multi-provider CredentialsField requires discriminator!") + + +class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]): # TODO: move discrimination mechanism out of CredentialsField (frontend + backend) credentials_provider: list[CP] - credentials_scopes: Optional[list[str]] + credentials_scopes: Optional[list[str]] = None credentials_types: list[CT] discriminator: Optional[str] = None discriminator_mapping: Optional[dict[str, CP]] = None def CredentialsField( - provider: CP | list[CP], - supported_credential_types: set[CT], required_scopes: set[str] = set(), *, discriminator: Optional[str] = None, @@ -253,26 +299,26 @@ def CredentialsField( title: Optional[str] = None, description: Optional[str] = None, **kwargs, -) -> CredentialsMetaInput[CP, CT]: +) -> CredentialsMetaInput: """ `CredentialsField` must and can only be used on fields named `credentials`. This is enforced by the `BlockSchema` base class. """ - if not isinstance(provider, str) and len(provider) > 1 and not discriminator: - raise TypeError("Multi-provider CredentialsField requires discriminator!") - field_schema_extra = CredentialsFieldSchemaExtra[CP, CT]( - credentials_provider=[provider] if isinstance(provider, str) else provider, - credentials_scopes=list(required_scopes) or None, # omit if empty - credentials_types=list(supported_credential_types), - discriminator=discriminator, - discriminator_mapping=discriminator_mapping, - ) + field_schema_extra = { + k: v + for k, v in { + "credentials_scopes": list(required_scopes) or None, + "discriminator": discriminator, + "discriminator_mapping": discriminator_mapping, + }.items() + if v is not None + } return Field( title=title, description=description, - json_schema_extra=field_schema_extra.model_dump(exclude_none=True), + json_schema_extra=field_schema_extra, # validated on BlockSchema init **kwargs, ) diff --git a/autogpt_platform/backend/backend/integrations/creds_manager.py b/autogpt_platform/backend/backend/integrations/creds_manager.py index 7cbf8f4af7..17633bd587 100644 --- a/autogpt_platform/backend/backend/integrations/creds_manager.py +++ b/autogpt_platform/backend/backend/integrations/creds_manager.py @@ -1,6 +1,7 @@ import logging from contextlib import contextmanager from datetime import datetime +from typing import TYPE_CHECKING from autogpt_libs.utils.synchronize import RedisKeyedMutex from redis.lock import Lock as RedisLock @@ -8,10 +9,13 @@ from redis.lock import Lock as RedisLock from backend.data import redis from backend.data.model import Credentials from backend.integrations.credentials_store import IntegrationCredentialsStore -from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler +from backend.integrations.oauth import HANDLERS_BY_NAME from backend.util.exceptions import MissingConfigError from backend.util.settings import Settings +if TYPE_CHECKING: + from backend.integrations.oauth import BaseOAuthHandler + logger = logging.getLogger(__name__) settings = Settings() @@ -148,7 +152,7 @@ class IntegrationCredentialsManager: self.store.locks.release_all_locks() -def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler: +def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler": if provider_name not in HANDLERS_BY_NAME: raise KeyError(f"Unknown provider '{provider_name}'") diff --git a/autogpt_platform/backend/backend/integrations/oauth/__init__.py b/autogpt_platform/backend/backend/integrations/oauth/__init__.py index 834293da92..f5888f07a8 100644 --- a/autogpt_platform/backend/backend/integrations/oauth/__init__.py +++ b/autogpt_platform/backend/backend/integrations/oauth/__init__.py @@ -1,10 +1,15 @@ -from .base import BaseOAuthHandler +from typing import TYPE_CHECKING + from .github import GitHubOAuthHandler from .google import GoogleOAuthHandler from .notion import NotionOAuthHandler +if TYPE_CHECKING: + from ..providers import ProviderName + from .base import BaseOAuthHandler + # --8<-- [start:HANDLERS_BY_NAMEExample] -HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = { +HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = { handler.PROVIDER_NAME: handler for handler in [ GitHubOAuthHandler, diff --git a/autogpt_platform/backend/backend/integrations/oauth/base.py b/autogpt_platform/backend/backend/integrations/oauth/base.py index ad5433d734..54786ba1aa 100644 --- a/autogpt_platform/backend/backend/integrations/oauth/base.py +++ b/autogpt_platform/backend/backend/integrations/oauth/base.py @@ -4,13 +4,14 @@ from abc import ABC, abstractmethod from typing import ClassVar from backend.data.model import OAuth2Credentials +from backend.integrations.providers import ProviderName logger = logging.getLogger(__name__) class BaseOAuthHandler(ABC): # --8<-- [start:BaseOAuthHandler1] - PROVIDER_NAME: ClassVar[str] + PROVIDER_NAME: ClassVar[ProviderName] DEFAULT_SCOPES: ClassVar[list[str]] = [] # --8<-- [end:BaseOAuthHandler1] @@ -76,6 +77,8 @@ class BaseOAuthHandler(ABC): """Handles the default scopes for the provider""" # If scopes are empty, use the default scopes for the provider if not scopes: - logger.debug(f"Using default scopes for provider {self.PROVIDER_NAME}") + logger.debug( + f"Using default scopes for provider {self.PROVIDER_NAME.value}" + ) scopes = self.DEFAULT_SCOPES return scopes diff --git a/autogpt_platform/backend/backend/integrations/oauth/github.py b/autogpt_platform/backend/backend/integrations/oauth/github.py index ed883b6320..d83c9b2093 100644 --- a/autogpt_platform/backend/backend/integrations/oauth/github.py +++ b/autogpt_platform/backend/backend/integrations/oauth/github.py @@ -3,6 +3,7 @@ from typing import Optional from urllib.parse import urlencode from backend.data.model import OAuth2Credentials +from backend.integrations.providers import ProviderName from backend.util.request import requests from .base import BaseOAuthHandler @@ -23,7 +24,7 @@ class GitHubOAuthHandler(BaseOAuthHandler): access token *with no refresh token*. """ # noqa - PROVIDER_NAME = "github" + PROVIDER_NAME = ProviderName.GITHUB def __init__(self, client_id: str, client_secret: str, redirect_uri: str): self.client_id = client_id diff --git a/autogpt_platform/backend/backend/integrations/oauth/google.py b/autogpt_platform/backend/backend/integrations/oauth/google.py index 13175b0b46..5a03e615a4 100644 --- a/autogpt_platform/backend/backend/integrations/oauth/google.py +++ b/autogpt_platform/backend/backend/integrations/oauth/google.py @@ -9,6 +9,7 @@ from google_auth_oauthlib.flow import Flow from pydantic import SecretStr from backend.data.model import OAuth2Credentials +from backend.integrations.providers import ProviderName from .base import BaseOAuthHandler @@ -21,7 +22,7 @@ class GoogleOAuthHandler(BaseOAuthHandler): Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server """ # noqa - PROVIDER_NAME = "google" + PROVIDER_NAME = ProviderName.GOOGLE EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo" DEFAULT_SCOPES = [ "https://www.googleapis.com/auth/userinfo.email", diff --git a/autogpt_platform/backend/backend/integrations/oauth/notion.py b/autogpt_platform/backend/backend/integrations/oauth/notion.py index 7f5458ae9a..e71bae2956 100644 --- a/autogpt_platform/backend/backend/integrations/oauth/notion.py +++ b/autogpt_platform/backend/backend/integrations/oauth/notion.py @@ -2,6 +2,7 @@ from base64 import b64encode from urllib.parse import urlencode from backend.data.model import OAuth2Credentials +from backend.integrations.providers import ProviderName from backend.util.request import requests from .base import BaseOAuthHandler @@ -16,7 +17,7 @@ class NotionOAuthHandler(BaseOAuthHandler): - Notion doesn't use scopes """ - PROVIDER_NAME = "notion" + PROVIDER_NAME = ProviderName.NOTION def __init__(self, client_id: str, client_secret: str, redirect_uri: str): self.client_id = client_id diff --git a/autogpt_platform/backend/backend/integrations/providers.py b/autogpt_platform/backend/backend/integrations/providers.py index b38becf5ca..c2a20c4ed6 100644 --- a/autogpt_platform/backend/backend/integrations/providers.py +++ b/autogpt_platform/backend/backend/integrations/providers.py @@ -1,7 +1,30 @@ from enum import Enum +# --8<-- [start:ProviderName] class ProviderName(str, Enum): + ANTHROPIC = "anthropic" + DISCORD = "discord" + D_ID = "d_id" + E2B = "e2b" + EXA = "exa" + FAL = "fal" GITHUB = "github" GOOGLE = "google" + GOOGLE_MAPS = "google_maps" + GROQ = "groq" + HUBSPOT = "hubspot" + IDEOGRAM = "ideogram" + JINA = "jina" + MEDIUM = "medium" NOTION = "notion" + OLLAMA = "ollama" + OPENAI = "openai" + OPENWEATHERMAP = "openweathermap" + OPEN_ROUTER = "open_router" + PINECONE = "pinecone" + REPLICATE = "replicate" + REVID = "revid" + SLANT3D = "slant3d" + UNREAL_SPEECH = "unreal_speech" + # --8<-- [end:ProviderName] diff --git a/autogpt_platform/backend/backend/integrations/webhooks/__init__.py b/autogpt_platform/backend/backend/integrations/webhooks/__init__.py index ebfeea7637..c3992f3652 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/__init__.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/__init__.py @@ -4,10 +4,11 @@ from .github import GithubWebhooksManager from .slant3d import Slant3DWebhooksManager if TYPE_CHECKING: + from ..providers import ProviderName from .base import BaseWebhooksManager # --8<-- [start:WEBHOOK_MANAGERS_BY_NAME] -WEBHOOK_MANAGERS_BY_NAME: dict[str, type["BaseWebhooksManager"]] = { +WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = { handler.PROVIDER_NAME: handler for handler in [ GithubWebhooksManager, diff --git a/autogpt_platform/backend/backend/integrations/webhooks/base.py b/autogpt_platform/backend/backend/integrations/webhooks/base.py index b1c8a62ddc..999c0cd8bc 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/base.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/base.py @@ -9,6 +9,7 @@ from strenum import StrEnum from backend.data import integrations from backend.data.model import Credentials +from backend.integrations.providers import ProviderName from backend.util.exceptions import MissingConfigError from backend.util.settings import Config @@ -20,7 +21,7 @@ WT = TypeVar("WT", bound=StrEnum) class BaseWebhooksManager(ABC, Generic[WT]): # --8<-- [start:BaseWebhooksManager1] - PROVIDER_NAME: ClassVar[str] + PROVIDER_NAME: ClassVar[ProviderName] # --8<-- [end:BaseWebhooksManager1] WebhookType: WT @@ -143,7 +144,7 @@ class BaseWebhooksManager(ABC, Generic[WT]): secret = secrets.token_hex(32) provider_name = self.PROVIDER_NAME ingress_url = ( - f"{app_config.platform_base_url}/api/integrations/{provider_name}" + f"{app_config.platform_base_url}/api/integrations/{provider_name.value}" f"/webhooks/{id}/ingress" ) provider_webhook_id, config = await self._register_webhook( diff --git a/autogpt_platform/backend/backend/integrations/webhooks/github.py b/autogpt_platform/backend/backend/integrations/webhooks/github.py index 679e771246..b3ef4f780f 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/github.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/github.py @@ -8,6 +8,7 @@ from strenum import StrEnum from backend.data import integrations from backend.data.model import Credentials +from backend.integrations.providers import ProviderName from .base import BaseWebhooksManager @@ -20,7 +21,7 @@ class GithubWebhookType(StrEnum): class GithubWebhooksManager(BaseWebhooksManager): - PROVIDER_NAME = "github" + PROVIDER_NAME = ProviderName.GITHUB WebhookType = GithubWebhookType diff --git a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py index 363bf535c5..c241bc3a4a 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py @@ -95,11 +95,18 @@ async def on_node_activate( if not block.webhook_config: return node + provider = block.webhook_config.provider + if provider not in WEBHOOK_MANAGERS_BY_NAME: + raise ValueError( + f"Block #{block.id} has webhook_config for provider {provider} " + "which does not support webhooks" + ) + logger.debug( f"Activating webhook node #{node.id} with config {block.webhook_config}" ) - webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]() + webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]() try: resource = block.webhook_config.resource_format.format(**node.input_default) @@ -167,7 +174,14 @@ async def on_node_deactivate( if not block.webhook_config: return node - webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]() + provider = block.webhook_config.provider + if provider not in WEBHOOK_MANAGERS_BY_NAME: + raise ValueError( + f"Block #{block.id} has webhook_config for provider {provider} " + "which does not support webhooks" + ) + + webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]() if node.webhook_id: logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}") @@ -189,7 +203,7 @@ async def on_node_deactivate( logger.warning( f"Cannot deregister webhook #{webhook.id}: credentials " f"#{webhook.credentials_id} not available " - f"({webhook.provider} webhook ID: {webhook.provider_webhook_id})" + f"({webhook.provider.value} webhook ID: {webhook.provider_webhook_id})" ) return updated_node diff --git a/autogpt_platform/backend/backend/integrations/webhooks/slant3d.py b/autogpt_platform/backend/backend/integrations/webhooks/slant3d.py index 316bca0166..405c94d9d0 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/slant3d.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/slant3d.py @@ -1,11 +1,11 @@ import logging -from typing import ClassVar import requests from fastapi import Request from backend.data import integrations from backend.data.model import APIKeyCredentials, Credentials +from backend.integrations.providers import ProviderName from backend.integrations.webhooks.base import BaseWebhooksManager logger = logging.getLogger(__name__) @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) class Slant3DWebhooksManager(BaseWebhooksManager): """Manager for Slant3D webhooks""" - PROVIDER_NAME: ClassVar[str] = "slant3d" + PROVIDER_NAME = ProviderName.SLANT3D BASE_URL = "https://www.slant3dapi.com/api" async def _register_webhook( diff --git a/autogpt_platform/backend/backend/server/integrations/router.py b/autogpt_platform/backend/backend/server/integrations/router.py index 34770b677d..3e0b0a3a03 100644 --- a/autogpt_platform/backend/backend/server/integrations/router.py +++ b/autogpt_platform/backend/backend/server/integrations/router.py @@ -1,5 +1,5 @@ import logging -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request from pydantic import BaseModel, Field, SecretStr @@ -20,12 +20,16 @@ from backend.data.model import ( ) from backend.executor.manager import ExecutionManager from backend.integrations.creds_manager import IntegrationCredentialsManager -from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler +from backend.integrations.oauth import HANDLERS_BY_NAME +from backend.integrations.providers import ProviderName from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME from backend.util.exceptions import NeedConfirmation from backend.util.service import get_service_client from backend.util.settings import Settings +if TYPE_CHECKING: + from backend.integrations.oauth import BaseOAuthHandler + from ..utils import get_user_id logger = logging.getLogger(__name__) @@ -42,7 +46,9 @@ class LoginResponse(BaseModel): @router.get("/{provider}/login") def login( - provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")], + provider: Annotated[ + ProviderName, Path(title="The provider to initiate an OAuth flow for") + ], user_id: Annotated[str, Depends(get_user_id)], request: Request, scopes: Annotated[ @@ -74,7 +80,9 @@ class CredentialsMetaResponse(BaseModel): @router.post("/{provider}/callback") def callback( - provider: Annotated[str, Path(title="The target provider for this OAuth exchange")], + provider: Annotated[ + ProviderName, Path(title="The target provider for this OAuth exchange") + ], code: Annotated[str, Body(title="Authorization code acquired by user login")], state_token: Annotated[str, Body(title="Anti-CSRF nonce")], user_id: Annotated[str, Depends(get_user_id)], @@ -103,11 +111,12 @@ def callback( if not set(scopes).issubset(set(credentials.scopes)): # For now, we'll just log the warning and continue logger.warning( - f"Granted scopes {credentials.scopes} for {provider}do not include all requested scopes {scopes}" + f"Granted scopes {credentials.scopes} for provider {provider.value} " + f"do not include all requested scopes {scopes}" ) except Exception as e: - logger.error(f"Code->Token exchange failed for provider {provider}: {e}") + logger.error(f"Code->Token exchange failed for provider {provider.value}: {e}") raise HTTPException( status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}" ) @@ -116,7 +125,8 @@ def callback( creds_manager.create(user_id, credentials) logger.debug( - f"Successfully processed OAuth callback for user {user_id} and provider {provider}" + f"Successfully processed OAuth callback for user {user_id} " + f"and provider {provider.value}" ) return CredentialsMetaResponse( id=credentials.id, @@ -148,7 +158,9 @@ def list_credentials( @router.get("/{provider}/credentials") def list_credentials_by_provider( - provider: Annotated[str, Path(title="The provider to list credentials for")], + provider: Annotated[ + ProviderName, Path(title="The provider to list credentials for") + ], user_id: Annotated[str, Depends(get_user_id)], ) -> list[CredentialsMetaResponse]: credentials = creds_manager.store.get_creds_by_provider(user_id, provider) @@ -167,7 +179,9 @@ def list_credentials_by_provider( @router.get("/{provider}/credentials/{cred_id}") def get_credential( - provider: Annotated[str, Path(title="The provider to retrieve credentials for")], + provider: Annotated[ + ProviderName, Path(title="The provider to retrieve credentials for") + ], cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")], user_id: Annotated[str, Depends(get_user_id)], ) -> Credentials: @@ -184,7 +198,9 @@ def get_credential( @router.post("/{provider}/credentials", status_code=201) def create_api_key_credentials( user_id: Annotated[str, Depends(get_user_id)], - provider: Annotated[str, Path(title="The provider to create credentials for")], + provider: Annotated[ + ProviderName, Path(title="The provider to create credentials for") + ], api_key: Annotated[str, Body(title="The API key to store")], title: Annotated[str, Body(title="Optional title for the credentials")], expires_at: Annotated[ @@ -225,7 +241,9 @@ class CredentialsDeletionNeedsConfirmationResponse(BaseModel): @router.delete("/{provider}/credentials/{cred_id}") async def delete_credentials( request: Request, - provider: Annotated[str, Path(title="The provider to delete credentials for")], + provider: Annotated[ + ProviderName, Path(title="The provider to delete credentials for") + ], cred_id: Annotated[str, Path(title="The ID of the credentials to delete")], user_id: Annotated[str, Depends(get_user_id)], force: Annotated[ @@ -264,15 +282,20 @@ async def delete_credentials( @router.post("/{provider}/webhooks/{webhook_id}/ingress") async def webhook_ingress_generic( request: Request, - provider: Annotated[str, Path(title="Provider where the webhook was registered")], + provider: Annotated[ + ProviderName, Path(title="Provider where the webhook was registered") + ], webhook_id: Annotated[str, Path(title="Our ID for the webhook")], ): - logger.debug(f"Received {provider} webhook ingress for ID {webhook_id}") + logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}") webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]() webhook = await get_webhook(webhook_id) logger.debug(f"Webhook #{webhook_id}: {webhook}") payload, event_type = await webhook_manager.validate_payload(webhook, request) - logger.debug(f"Validated {provider} {event_type} event with payload {payload}") + logger.debug( + f"Validated {provider.value} {webhook.webhook_type} {event_type} event " + f"with payload {payload}" + ) webhook_event = WebhookEvent( provider=provider, @@ -341,6 +364,14 @@ async def remove_all_webhooks_for_credentials( NeedConfirmation: If any of the webhooks are still in use and `force` is `False` """ webhooks = await get_all_webhooks(credentials.id) + if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME: + if webhooks: + logger.error( + f"Credentials #{credentials.id} for provider {credentials.provider} " + f"are attached to {len(webhooks)} webhooks, " + f"but there is no available WebhooksHandler for {credentials.provider}" + ) + return if any(w.attached_nodes for w in webhooks) and not force: raise NeedConfirmation( "Some webhooks linked to these credentials are still in use by an agent" @@ -359,18 +390,23 @@ async def remove_all_webhooks_for_credentials( logger.warning(f"Webhook #{webhook.id} failed to prune") -def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler: +def _get_provider_oauth_handler( + req: Request, provider_name: ProviderName +) -> "BaseOAuthHandler": if provider_name not in HANDLERS_BY_NAME: raise HTTPException( - status_code=404, detail=f"Unknown provider '{provider_name}'" + status_code=404, + detail=f"Provider '{provider_name.value}' does not support OAuth", ) - client_id = getattr(settings.secrets, f"{provider_name}_client_id") - client_secret = getattr(settings.secrets, f"{provider_name}_client_secret") + client_id = getattr(settings.secrets, f"{provider_name.value}_client_id") + client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret") if not (client_id and client_secret): raise HTTPException( status_code=501, - detail=f"Integration with provider '{provider_name}' is not configured", + detail=( + f"Integration with provider '{provider_name.value}' is not configured" + ), ) handler_class = HANDLERS_BY_NAME[provider_name] diff --git a/docs/content/platform/new_blocks.md b/docs/content/platform/new_blocks.md index 623b1d9b52..9fb22db5a9 100644 --- a/docs/content/platform/new_blocks.md +++ b/docs/content/platform/new_blocks.md @@ -121,6 +121,7 @@ from backend.data.model import ( from backend.data.block import Block, BlockOutput, BlockSchema from backend.data.model import CredentialsField +from backend.integrations.providers import ProviderName # API Key auth: @@ -128,9 +129,9 @@ class BlockWithAPIKeyAuth(Block): class Input(BlockSchema): # Note that the type hint below is require or you will get a type error. # The first argument is the provider name, the second is the credential type. - credentials: CredentialsMetaInput[Literal['github'], Literal['api_key']] = CredentialsField( - provider="github", - supported_credential_types={"api_key"}, + credentials: CredentialsMetaInput[ + Literal[ProviderName.GITHUB], Literal["api_key"] + ] = CredentialsField( description="The GitHub integration can be used with " "any API key with sufficient permissions for the blocks it is used on.", ) @@ -151,9 +152,9 @@ class BlockWithOAuth(Block): class Input(BlockSchema): # Note that the type hint below is require or you will get a type error. # The first argument is the provider name, the second is the credential type. - credentials: CredentialsMetaInput[Literal['github'], Literal['oauth2']] = CredentialsField( - provider="github", - supported_credential_types={"oauth2"}, + credentials: CredentialsMetaInput[ + Literal[ProviderName.GITHUB], Literal["oauth2"] + ] = CredentialsField( required_scopes={"repo"}, description="The GitHub integration can be used with OAuth.", ) @@ -174,9 +175,9 @@ class BlockWithAPIKeyAndOAuth(Block): class Input(BlockSchema): # Note that the type hint below is require or you will get a type error. # The first argument is the provider name, the second is the credential type. - credentials: CredentialsMetaInput[Literal['github'], Literal['api_key', 'oauth2']] = CredentialsField( - provider="github", - supported_credential_types={"api_key", "oauth2"}, + credentials: CredentialsMetaInput[ + Literal[ProviderName.GITHUB], Literal["api_key", "oauth2"] + ] = CredentialsField( required_scopes={"repo"}, description="The GitHub integration can be used with OAuth, " "or any API key with sufficient permissions for the blocks it is used on.", @@ -227,6 +228,16 @@ response = requests.post( ) ``` +The `ProviderName` enum is the single source of truth for which providers exist in our system. +Naturally, to add an authenticated block for a new provider, you'll have to add it here too. +
+ProviderName definition + +```python title="backend/integrations/providers.py" +--8<-- "autogpt_platform/backend/backend/integrations/providers.py:ProviderName" +``` +
+ #### Adding an OAuth2 service integration To add support for a new OAuth2-authenticated service, you'll need to add an `OAuthHandler`.