mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
update sdk
This commit is contained in:
13
autogpt_platform/backend/backend/blocks/exa/_config.py
Normal file
13
autogpt_platform/backend/backend/blocks/exa/_config.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Shared configuration for all Exa blocks using the new SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
# Configure the Exa provider once for all blocks
|
||||
exa = (
|
||||
ProviderBuilder("exa")
|
||||
.with_api_key("EXA_API_KEY", "Exa API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -5,38 +5,21 @@ from backend.sdk import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
Boolean,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Dict,
|
||||
List,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
Settings,
|
||||
String,
|
||||
default_credentials,
|
||||
provider,
|
||||
requests,
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
from ._config import exa
|
||||
|
||||
|
||||
@provider("exa")
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="exa",
|
||||
api_key=SecretStr(settings.secrets.exa_api_key),
|
||||
title="Use Credits for Exa",
|
||||
expires_at=None,
|
||||
)
|
||||
)
|
||||
class ExaAnswerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
query: String = SchemaField(
|
||||
description="The question or query to answer",
|
||||
@@ -61,19 +44,17 @@ class ExaAnswerBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
answer: String = SchemaField(
|
||||
description="The generated answer based on search results",
|
||||
description="The generated answer based on search results"
|
||||
)
|
||||
citations: List[Dict] = SchemaField(
|
||||
description="Search results used to generate the answer",
|
||||
default_factory=list,
|
||||
)
|
||||
cost_dollars: Dict = SchemaField(
|
||||
description="Cost breakdown of the request",
|
||||
default_factory=dict,
|
||||
description="Cost breakdown of the request", default_factory=dict
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
@@ -83,17 +64,6 @@ class ExaAnswerBlock(Block):
|
||||
categories={BlockCategory.SEARCH, BlockCategory.AI},
|
||||
input_schema=ExaAnswerBlock.Input,
|
||||
output_schema=ExaAnswerBlock.Output,
|
||||
test_input={
|
||||
"query": "What is the capital of France?",
|
||||
"text": False,
|
||||
"stream": False,
|
||||
"model": "exa",
|
||||
},
|
||||
test_output=[
|
||||
("answer", "Paris"),
|
||||
("citations", []),
|
||||
("cost_dollars", {}),
|
||||
],
|
||||
)
|
||||
|
||||
def run(
|
||||
|
||||
@@ -4,28 +4,24 @@ from backend.sdk import (
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
List,
|
||||
SchemaField,
|
||||
String,
|
||||
provider,
|
||||
requests,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import ContentSettings
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaContentsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
ids: List[String] = SchemaField(
|
||||
description="Array of document IDs obtained from searches",
|
||||
description="Array of document IDs obtained from searches"
|
||||
)
|
||||
contents: ContentSettings = SchemaField(
|
||||
description="Content retrieval settings",
|
||||
@@ -35,8 +31,7 @@ class ExaContentsBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
results: List = SchemaField(
|
||||
description="List of document contents",
|
||||
default_factory=list,
|
||||
description="List of document contents", default_factory=list
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
|
||||
@@ -7,51 +7,38 @@ from backend.sdk import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
Boolean,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Integer,
|
||||
List,
|
||||
SchemaField,
|
||||
String,
|
||||
provider,
|
||||
requests,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import ContentSettings
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaSearchBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
query: String = SchemaField(description="The search query")
|
||||
use_auto_prompt: Boolean = SchemaField(
|
||||
description="Whether to use autoprompt",
|
||||
default=True,
|
||||
advanced=True,
|
||||
description="Whether to use autoprompt", default=True, advanced=True
|
||||
)
|
||||
type: String = SchemaField(
|
||||
description="Type of search",
|
||||
default="",
|
||||
advanced=True,
|
||||
description="Type of search", default="", advanced=True
|
||||
)
|
||||
category: String = SchemaField(
|
||||
description="Category to search within",
|
||||
default="",
|
||||
advanced=True,
|
||||
description="Category to search within", default="", advanced=True
|
||||
)
|
||||
number_of_results: Integer = SchemaField(
|
||||
description="Number of results to return",
|
||||
default=10,
|
||||
advanced=True,
|
||||
description="Number of results to return", default=10, advanced=True
|
||||
)
|
||||
include_domains: List[String] = SchemaField(
|
||||
description="Domains to include in search",
|
||||
default_factory=list,
|
||||
description="Domains to include in search", default_factory=list
|
||||
)
|
||||
exclude_domains: List[String] = SchemaField(
|
||||
description="Domains to exclude from search",
|
||||
@@ -59,26 +46,22 @@ class ExaSearchBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
start_crawl_date: datetime = SchemaField(
|
||||
description="Start date for crawled content",
|
||||
description="Start date for crawled content"
|
||||
)
|
||||
end_crawl_date: datetime = SchemaField(
|
||||
description="End date for crawled content",
|
||||
description="End date for crawled content"
|
||||
)
|
||||
start_published_date: datetime = SchemaField(
|
||||
description="Start date for published content",
|
||||
description="Start date for published content"
|
||||
)
|
||||
end_published_date: datetime = SchemaField(
|
||||
description="End date for published content",
|
||||
description="End date for published content"
|
||||
)
|
||||
include_text: List[String] = SchemaField(
|
||||
description="Text patterns to include",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
description="Text patterns to include", default_factory=list, advanced=True
|
||||
)
|
||||
exclude_text: List[String] = SchemaField(
|
||||
description="Text patterns to exclude",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
description="Text patterns to exclude", default_factory=list, advanced=True
|
||||
)
|
||||
contents: ContentSettings = SchemaField(
|
||||
description="Content retrieval settings",
|
||||
@@ -88,12 +71,10 @@ class ExaSearchBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
results: List = SchemaField(
|
||||
description="List of search results",
|
||||
default_factory=list,
|
||||
description="List of search results", default_factory=list
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -7,34 +7,28 @@ from backend.sdk import (
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Integer,
|
||||
List,
|
||||
SchemaField,
|
||||
String,
|
||||
provider,
|
||||
requests,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import ContentSettings
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaFindSimilarBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
url: String = SchemaField(
|
||||
description="The url for which you would like to find similar links"
|
||||
)
|
||||
number_of_results: Integer = SchemaField(
|
||||
description="Number of results to return",
|
||||
default=10,
|
||||
advanced=True,
|
||||
description="Number of results to return", default=10, advanced=True
|
||||
)
|
||||
include_domains: List[String] = SchemaField(
|
||||
description="Domains to include in search",
|
||||
@@ -47,16 +41,16 @@ class ExaFindSimilarBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
start_crawl_date: datetime = SchemaField(
|
||||
description="Start date for crawled content",
|
||||
description="Start date for crawled content"
|
||||
)
|
||||
end_crawl_date: datetime = SchemaField(
|
||||
description="End date for crawled content",
|
||||
description="End date for crawled content"
|
||||
)
|
||||
start_published_date: datetime = SchemaField(
|
||||
description="Start date for published content",
|
||||
description="Start date for published content"
|
||||
)
|
||||
end_published_date: datetime = SchemaField(
|
||||
description="End date for published content",
|
||||
description="End date for published content"
|
||||
)
|
||||
include_text: List[String] = SchemaField(
|
||||
description="Text patterns to include (max 1 string, up to 5 words)",
|
||||
|
||||
@@ -6,7 +6,6 @@ from backend.sdk import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
Boolean,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Dict,
|
||||
Integer,
|
||||
@@ -14,23 +13,20 @@ from backend.sdk import (
|
||||
Optional,
|
||||
SchemaField,
|
||||
String,
|
||||
provider,
|
||||
requests,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaCreateWebsetBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
search: WebsetSearchConfig = SchemaField(
|
||||
description="Initial search configuration for the Webset",
|
||||
description="Initial search configuration for the Webset"
|
||||
)
|
||||
enrichments: Optional[List[WebsetEnrichmentConfig]] = SchemaField(
|
||||
default=None,
|
||||
@@ -51,21 +47,17 @@ class ExaCreateWebsetBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset_id: String = SchemaField(
|
||||
description="The unique identifier for the created webset",
|
||||
)
|
||||
status: String = SchemaField(
|
||||
description="The status of the webset",
|
||||
description="The unique identifier for the created webset"
|
||||
)
|
||||
status: String = SchemaField(description="The status of the webset")
|
||||
external_id: Optional[String] = SchemaField(
|
||||
description="The external identifier for the webset",
|
||||
default=None,
|
||||
description="The external identifier for the webset", default=None
|
||||
)
|
||||
created_at: String = SchemaField(
|
||||
description="The date and time the webset was created",
|
||||
description="The date and time the webset was created"
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
@@ -121,13 +113,10 @@ class ExaCreateWebsetBlock(Block):
|
||||
yield "created_at", ""
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaUpdateWebsetBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: String = SchemaField(
|
||||
description="The ID or external ID of the Webset to update",
|
||||
@@ -140,25 +129,20 @@ class ExaUpdateWebsetBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset_id: String = SchemaField(
|
||||
description="The unique identifier for the webset",
|
||||
)
|
||||
status: String = SchemaField(
|
||||
description="The status of the webset",
|
||||
description="The unique identifier for the webset"
|
||||
)
|
||||
status: String = SchemaField(description="The status of the webset")
|
||||
external_id: Optional[String] = SchemaField(
|
||||
description="The external identifier for the webset",
|
||||
default=None,
|
||||
description="The external identifier for the webset", default=None
|
||||
)
|
||||
metadata: Dict = SchemaField(
|
||||
description="Updated metadata for the webset",
|
||||
default_factory=dict,
|
||||
description="Updated metadata for the webset", default_factory=dict
|
||||
)
|
||||
updated_at: String = SchemaField(
|
||||
description="The date and time the webset was updated",
|
||||
description="The date and time the webset was updated"
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
@@ -203,13 +187,10 @@ class ExaUpdateWebsetBlock(Block):
|
||||
yield "updated_at", ""
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaListWebsetsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
cursor: Optional[String] = SchemaField(
|
||||
default=None,
|
||||
@@ -225,21 +206,16 @@ class ExaListWebsetsBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
websets: List = SchemaField(
|
||||
description="List of websets",
|
||||
default_factory=list,
|
||||
)
|
||||
websets: List = SchemaField(description="List of websets", default_factory=list)
|
||||
has_more: Boolean = SchemaField(
|
||||
description="Whether there are more results to paginate through",
|
||||
default=False,
|
||||
)
|
||||
next_cursor: Optional[String] = SchemaField(
|
||||
description="Cursor for the next page of results",
|
||||
default=None,
|
||||
description="Cursor for the next page of results", default=None
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
@@ -280,46 +256,35 @@ class ExaListWebsetsBlock(Block):
|
||||
yield "has_more", False
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaGetWebsetBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: String = SchemaField(
|
||||
description="The ID or external ID of the Webset to retrieve",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
expand_items: Boolean = SchemaField(
|
||||
default=False,
|
||||
description="Include items in the response",
|
||||
advanced=True,
|
||||
default=False, description="Include items in the response", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset_id: String = SchemaField(
|
||||
description="The unique identifier for the webset",
|
||||
)
|
||||
status: String = SchemaField(
|
||||
description="The status of the webset",
|
||||
description="The unique identifier for the webset"
|
||||
)
|
||||
status: String = SchemaField(description="The status of the webset")
|
||||
external_id: Optional[String] = SchemaField(
|
||||
description="The external identifier for the webset",
|
||||
default=None,
|
||||
description="The external identifier for the webset", default=None
|
||||
)
|
||||
searches: List[Dict] = SchemaField(
|
||||
description="The searches performed on the webset",
|
||||
default_factory=list,
|
||||
description="The searches performed on the webset", default_factory=list
|
||||
)
|
||||
enrichments: List[Dict] = SchemaField(
|
||||
description="The enrichments applied to the webset",
|
||||
default_factory=list,
|
||||
description="The enrichments applied to the webset", default_factory=list
|
||||
)
|
||||
monitors: List[Dict] = SchemaField(
|
||||
description="The monitors for the webset",
|
||||
default_factory=list,
|
||||
description="The monitors for the webset", default_factory=list
|
||||
)
|
||||
items: Optional[List[Dict]] = SchemaField(
|
||||
description="The items in the webset (if expand_items is true)",
|
||||
@@ -330,14 +295,13 @@ class ExaGetWebsetBlock(Block):
|
||||
default_factory=dict,
|
||||
)
|
||||
created_at: String = SchemaField(
|
||||
description="The date and time the webset was created",
|
||||
description="The date and time the webset was created"
|
||||
)
|
||||
updated_at: String = SchemaField(
|
||||
description="The date and time the webset was last updated",
|
||||
description="The date and time the webset was last updated"
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
@@ -389,13 +353,10 @@ class ExaGetWebsetBlock(Block):
|
||||
yield "updated_at", ""
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaDeleteWebsetBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: String = SchemaField(
|
||||
description="The ID or external ID of the Webset to delete",
|
||||
@@ -404,22 +365,17 @@ class ExaDeleteWebsetBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset_id: String = SchemaField(
|
||||
description="The unique identifier for the deleted webset",
|
||||
description="The unique identifier for the deleted webset"
|
||||
)
|
||||
external_id: Optional[String] = SchemaField(
|
||||
description="The external identifier for the deleted webset",
|
||||
default=None,
|
||||
)
|
||||
status: String = SchemaField(
|
||||
description="The status of the deleted webset",
|
||||
description="The external identifier for the deleted webset", default=None
|
||||
)
|
||||
status: String = SchemaField(description="The status of the deleted webset")
|
||||
success: String = SchemaField(
|
||||
description="Whether the deletion was successful",
|
||||
default="true",
|
||||
description="Whether the deletion was successful", default="true"
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
@@ -456,13 +412,10 @@ class ExaDeleteWebsetBlock(Block):
|
||||
yield "success", "false"
|
||||
|
||||
|
||||
@provider("exa")
|
||||
class ExaCancelWebsetBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exa",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Exa integration requires an API Key.",
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: String = SchemaField(
|
||||
description="The ID or external ID of the Webset to cancel",
|
||||
@@ -471,22 +424,19 @@ class ExaCancelWebsetBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset_id: String = SchemaField(
|
||||
description="The unique identifier for the webset",
|
||||
description="The unique identifier for the webset"
|
||||
)
|
||||
status: String = SchemaField(
|
||||
description="The status of the webset after cancellation",
|
||||
description="The status of the webset after cancellation"
|
||||
)
|
||||
external_id: Optional[String] = SchemaField(
|
||||
description="The external identifier for the webset",
|
||||
default=None,
|
||||
description="The external identifier for the webset", default=None
|
||||
)
|
||||
success: String = SchemaField(
|
||||
description="Whether the cancellation was successful",
|
||||
default="true",
|
||||
description="Whether the cancellation was successful", default="true"
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
default="",
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -11,8 +11,6 @@ from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
Boolean,
|
||||
@@ -22,9 +20,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
provider,
|
||||
)
|
||||
|
||||
# Define test credentials for testing
|
||||
@@ -44,21 +39,7 @@ TEST_CREDENTIALS_INPUT = {
|
||||
}
|
||||
|
||||
|
||||
# Example of a simple service with auto-registration
|
||||
@provider("example-service") # Custom provider demonstrating SDK flexibility
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=2, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=1, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="example-service-default",
|
||||
provider="example-service", # Custom provider name
|
||||
api_key=SecretStr("example-default-api-key"),
|
||||
title="Example Service Default API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
)
|
||||
# Example of a simple service
|
||||
class ExampleSDKBlock(Block):
|
||||
"""
|
||||
Example block demonstrating the new SDK system.
|
||||
|
||||
@@ -5,13 +5,13 @@ This demonstrates webhook auto-registration without modifying
|
||||
files outside the blocks folder.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from backend.sdk import (
|
||||
BaseModel,
|
||||
BaseWebhooksManager,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
@@ -19,14 +19,10 @@ from backend.sdk import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Dict,
|
||||
Enum,
|
||||
Field,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
String,
|
||||
cost_config,
|
||||
provider,
|
||||
webhook_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -72,22 +68,10 @@ class ExampleWebhookManager(BaseWebhooksManager):
|
||||
pass
|
||||
|
||||
|
||||
# Now create the webhook block with auto-registration
|
||||
@provider("examplewebhook")
|
||||
@webhook_config("examplewebhook", ExampleWebhookManager)
|
||||
@cost_config(
|
||||
BlockCost(
|
||||
cost_amount=0, cost_type=BlockCostType.RUN
|
||||
) # Webhooks typically free to receive
|
||||
)
|
||||
# Now create the webhook block
|
||||
class ExampleWebhookSDKBlock(Block):
|
||||
"""
|
||||
Example webhook block demonstrating SDK webhook capabilities.
|
||||
|
||||
With the new SDK:
|
||||
- Webhook manager registered via @webhook_config decorator
|
||||
- No need to modify webhooks/__init__.py
|
||||
- Fully self-contained webhook implementation
|
||||
Example webhook block demonstrating webhook capabilities.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
|
||||
@@ -18,28 +18,21 @@ After SDK: Single import statement
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
Integer,
|
||||
SchemaField,
|
||||
String,
|
||||
cost_config,
|
||||
provider,
|
||||
)
|
||||
|
||||
|
||||
@provider("simple_service")
|
||||
@cost_config(BlockCost(cost_amount=1, cost_type=BlockCostType.RUN))
|
||||
class SimpleExampleBlock(Block):
|
||||
"""
|
||||
A simple example block showing the power of the SDK.
|
||||
|
||||
Key benefits:
|
||||
1. Single import: from backend.sdk import *
|
||||
2. Auto-registration via decorators
|
||||
3. No manual config file updates needed
|
||||
2. Clean, simple block structure
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
|
||||
@@ -14,14 +14,14 @@ This module provides:
|
||||
"""
|
||||
|
||||
# Standard library imports
|
||||
import asyncio
|
||||
import logging
|
||||
from enum import Enum
|
||||
from logging import getLogger as TruncatedLogger
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
from typing import Dict as DictType
|
||||
from typing import List as ListType
|
||||
from typing import Literal
|
||||
from typing import Literal as _Literal
|
||||
from typing import Optional, Set, Tuple, Type, TypeVar, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
# Third-party imports
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
@@ -47,23 +47,15 @@ from backend.data.model import (
|
||||
|
||||
# === INTEGRATIONS ===
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk.builder import ProviderBuilder
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
# === NEW SDK COMPONENTS (imported early for patches) ===
|
||||
from backend.sdk.registry import AutoRegistry, BlockConfiguration
|
||||
|
||||
# === UTILITIES ===
|
||||
from backend.util import json
|
||||
|
||||
# === AUTO-REGISTRATION DECORATORS ===
|
||||
from .decorators import (
|
||||
cost_config,
|
||||
default_credentials,
|
||||
oauth_config,
|
||||
provider,
|
||||
register_cost,
|
||||
register_credentials,
|
||||
register_oauth,
|
||||
register_webhook_manager,
|
||||
webhook_config,
|
||||
)
|
||||
|
||||
# === OPTIONAL IMPORTS WITH TRY/EXCEPT ===
|
||||
# Webhooks
|
||||
try:
|
||||
@@ -112,31 +104,8 @@ except ImportError:
|
||||
try:
|
||||
from backend.util.logging import TruncatedLogger
|
||||
except ImportError:
|
||||
TruncatedLogger = TruncatedLogger # Use the one imported at top
|
||||
TruncatedLogger = None
|
||||
|
||||
# GitHub components
|
||||
try:
|
||||
from backend.blocks.github._auth import (
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
except ImportError:
|
||||
GithubCredentials = None
|
||||
GithubCredentialsInput = None
|
||||
GithubCredentialsField = None
|
||||
|
||||
# Google components
|
||||
try:
|
||||
from backend.blocks.google._auth import (
|
||||
GoogleCredentials,
|
||||
GoogleCredentialsField,
|
||||
GoogleCredentialsInput,
|
||||
)
|
||||
except ImportError:
|
||||
GoogleCredentials = None
|
||||
GoogleCredentialsInput = None
|
||||
GoogleCredentialsField = None
|
||||
|
||||
# OAuth handlers
|
||||
try:
|
||||
@@ -144,49 +113,25 @@ try:
|
||||
except ImportError:
|
||||
BaseOAuthHandler = None
|
||||
|
||||
try:
|
||||
from backend.integrations.oauth.github import GitHubOAuthHandler
|
||||
except ImportError:
|
||||
GitHubOAuthHandler = None
|
||||
|
||||
try:
|
||||
from backend.integrations.oauth.google import GoogleOAuthHandler
|
||||
except ImportError:
|
||||
GoogleOAuthHandler = None
|
||||
|
||||
# Webhook managers
|
||||
try:
|
||||
from backend.integrations.webhooks.github import GithubWebhooksManager
|
||||
except ImportError:
|
||||
GithubWebhooksManager = None
|
||||
|
||||
try:
|
||||
from backend.integrations.webhooks.generic import GenericWebhooksManager
|
||||
except ImportError:
|
||||
GenericWebhooksManager = None
|
||||
|
||||
# === VARIABLE ASSIGNMENTS AND TYPE ALIASES ===
|
||||
# Type aliases
|
||||
# Type aliases for block development
|
||||
String = str
|
||||
Integer = int
|
||||
Float = float
|
||||
Boolean = bool
|
||||
List = ListType
|
||||
Dict = DictType
|
||||
|
||||
# Credential type with proper provider name
|
||||
CredentialsMetaInput = _CredentialsMetaInput[
|
||||
ProviderName, _Literal["api_key", "oauth2", "user_password"]
|
||||
]
|
||||
|
||||
# Webhook manager aliases
|
||||
if GithubWebhooksManager is not None:
|
||||
GitHubWebhooksManager = GithubWebhooksManager # Alias for consistency
|
||||
else:
|
||||
GitHubWebhooksManager = None
|
||||
|
||||
if GenericWebhooksManager is not None:
|
||||
GenericWebhookManager = GenericWebhooksManager # Alias for consistency
|
||||
else:
|
||||
GenericWebhookManager = None
|
||||
# Initialize the registry's integration patches
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
|
||||
# === COMPREHENSIVE __all__ EXPORT ===
|
||||
__all__ = [
|
||||
@@ -216,19 +161,7 @@ __all__ = [
|
||||
"BaseWebhooksManager",
|
||||
"ManualWebhookManagerBase",
|
||||
# Provider-Specific (when available)
|
||||
"GithubCredentials",
|
||||
"GithubCredentialsInput",
|
||||
"GithubCredentialsField",
|
||||
"GoogleCredentials",
|
||||
"GoogleCredentialsInput",
|
||||
"GoogleCredentialsField",
|
||||
"BaseOAuthHandler",
|
||||
"GitHubOAuthHandler",
|
||||
"GoogleOAuthHandler",
|
||||
"GitHubWebhooksManager",
|
||||
"GithubWebhooksManager",
|
||||
"GenericWebhookManager",
|
||||
"GenericWebhooksManager",
|
||||
# Utilities
|
||||
"json",
|
||||
"store_media_file",
|
||||
@@ -236,9 +169,7 @@ __all__ = [
|
||||
"convert",
|
||||
"TextFormatter",
|
||||
"TruncatedLogger",
|
||||
"logging",
|
||||
"asyncio",
|
||||
# Types
|
||||
# Type aliases for blocks
|
||||
"String",
|
||||
"Integer",
|
||||
"Float",
|
||||
@@ -247,26 +178,17 @@ __all__ = [
|
||||
"Dict",
|
||||
"Optional",
|
||||
"Any",
|
||||
"Literal",
|
||||
"Union",
|
||||
"TypeVar",
|
||||
"Type",
|
||||
"Tuple",
|
||||
"Set",
|
||||
"Literal",
|
||||
"BaseModel",
|
||||
"SecretStr",
|
||||
"Field",
|
||||
"Enum",
|
||||
# Auto-Registration Decorators
|
||||
"register_credentials",
|
||||
"register_cost",
|
||||
"register_oauth",
|
||||
"register_webhook_manager",
|
||||
"provider",
|
||||
"cost_config",
|
||||
"webhook_config",
|
||||
"default_credentials",
|
||||
"oauth_config",
|
||||
"SecretStr",
|
||||
"requests",
|
||||
# SDK Components
|
||||
"AutoRegistry",
|
||||
"BlockConfiguration",
|
||||
"Provider",
|
||||
"ProviderBuilder",
|
||||
]
|
||||
|
||||
# Remove None values from __all__
|
||||
|
||||
131
autogpt_platform/backend/backend/sdk/builder.py
Normal file
131
autogpt_platform/backend/backend/sdk/builder.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Builder class for creating provider configurations with a fluent API.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable, List, Optional, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
from backend.sdk.provider import Provider
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
class ProviderBuilder:
|
||||
"""Builder for creating provider configurations."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self._oauth_handler: Optional[Type[BaseOAuthHandler]] = None
|
||||
self._webhook_manager: Optional[Type[BaseWebhooksManager]] = None
|
||||
self._default_credentials: List[Credentials] = []
|
||||
self._base_costs: List[BlockCost] = []
|
||||
self._supported_auth_types: set = set()
|
||||
self._api_client_factory: Optional[Callable] = None
|
||||
self._error_handler: Optional[Callable[[Exception], str]] = None
|
||||
self._default_scopes: Optional[List[str]] = None
|
||||
self._extra_config: dict = {}
|
||||
|
||||
def with_oauth(
|
||||
self, handler_class: Type[BaseOAuthHandler], scopes: Optional[List[str]] = None
|
||||
) -> "ProviderBuilder":
|
||||
"""Add OAuth support."""
|
||||
self._oauth_handler = handler_class
|
||||
self._supported_auth_types.add("oauth2")
|
||||
if scopes:
|
||||
self._default_scopes = scopes
|
||||
return self
|
||||
|
||||
def with_api_key(self, env_var_name: str, title: str) -> "ProviderBuilder":
|
||||
"""Add API key support with environment variable name."""
|
||||
self._supported_auth_types.add("api_key")
|
||||
|
||||
# Register the API key mapping
|
||||
AutoRegistry.register_api_key(self.name, env_var_name)
|
||||
|
||||
# Check if API key exists in environment
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
self._default_credentials.append(
|
||||
APIKeyCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
api_key=SecretStr(api_key),
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_api_key_from_settings(
|
||||
self, settings_attr: str, title: str
|
||||
) -> "ProviderBuilder":
|
||||
"""Use existing API key from settings."""
|
||||
self._supported_auth_types.add("api_key")
|
||||
|
||||
# Try to get the API key from settings
|
||||
settings = Settings()
|
||||
api_key = getattr(settings.secrets, settings_attr, None)
|
||||
if api_key:
|
||||
self._default_credentials.append(
|
||||
APIKeyCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
api_key=api_key,
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_webhook_manager(
|
||||
self, manager_class: Type[BaseWebhooksManager]
|
||||
) -> "ProviderBuilder":
|
||||
"""Register webhook manager for this provider."""
|
||||
self._webhook_manager = manager_class
|
||||
return self
|
||||
|
||||
def with_base_cost(
|
||||
self, amount: int, cost_type: BlockCostType
|
||||
) -> "ProviderBuilder":
|
||||
"""Set base cost for all blocks using this provider."""
|
||||
self._base_costs.append(BlockCost(cost_amount=amount, cost_type=cost_type))
|
||||
return self
|
||||
|
||||
def with_api_client(self, factory: Callable) -> "ProviderBuilder":
|
||||
"""Register API client factory."""
|
||||
self._api_client_factory = factory
|
||||
return self
|
||||
|
||||
def with_error_handler(
|
||||
self, handler: Callable[[Exception], str]
|
||||
) -> "ProviderBuilder":
|
||||
"""Register error handler for provider-specific errors."""
|
||||
self._error_handler = handler
|
||||
return self
|
||||
|
||||
def with_config(self, **kwargs) -> "ProviderBuilder":
|
||||
"""Add additional configuration options."""
|
||||
self._extra_config.update(kwargs)
|
||||
return self
|
||||
|
||||
def build(self) -> Provider:
|
||||
"""Build and register the provider configuration."""
|
||||
provider = Provider(
|
||||
name=self.name,
|
||||
oauth_handler=self._oauth_handler,
|
||||
webhook_manager=self._webhook_manager,
|
||||
default_credentials=self._default_credentials,
|
||||
base_costs=self._base_costs,
|
||||
supported_auth_types=self._supported_auth_types,
|
||||
api_client_factory=self._api_client_factory,
|
||||
error_handler=self._error_handler,
|
||||
**self._extra_config,
|
||||
)
|
||||
|
||||
# Auto-registration happens here
|
||||
AutoRegistry.register_provider(provider)
|
||||
return provider
|
||||
66
autogpt_platform/backend/backend/sdk/provider.py
Normal file
66
autogpt_platform/backend/backend/sdk/provider.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Provider configuration class that holds all provider-related settings.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, List, Optional, Set, Type
|
||||
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import Credentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
|
||||
|
||||
class Provider:
|
||||
"""A configured provider that blocks can use."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
oauth_handler: Optional[Type[BaseOAuthHandler]] = None,
|
||||
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
|
||||
default_credentials: Optional[List[Credentials]] = None,
|
||||
base_costs: Optional[List[BlockCost]] = None,
|
||||
supported_auth_types: Optional[Set[str]] = None,
|
||||
api_client_factory: Optional[Callable] = None,
|
||||
error_handler: Optional[Callable[[Exception], str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.name = name
|
||||
self.oauth_handler = oauth_handler
|
||||
self.webhook_manager = webhook_manager
|
||||
self.default_credentials = default_credentials or []
|
||||
self.base_costs = base_costs or []
|
||||
self.supported_auth_types = supported_auth_types or set()
|
||||
self._api_client_factory = api_client_factory
|
||||
self._error_handler = error_handler
|
||||
|
||||
# Store any additional configuration
|
||||
self._extra_config = kwargs
|
||||
|
||||
def credentials_field(self, **kwargs) -> CredentialsMetaInput:
|
||||
"""Return a CredentialsField configured for this provider."""
|
||||
# Merge provider defaults with user overrides
|
||||
field_kwargs = {
|
||||
"provider": self.name,
|
||||
"supported_credential_types": self.supported_auth_types,
|
||||
"description": f"{self.name.title()} credentials",
|
||||
}
|
||||
field_kwargs.update(kwargs)
|
||||
|
||||
return CredentialsField(**field_kwargs)
|
||||
|
||||
def get_api(self, credentials: Credentials) -> Any:
|
||||
"""Get API client instance for the given credentials."""
|
||||
if self._api_client_factory:
|
||||
return self._api_client_factory(credentials)
|
||||
raise NotImplementedError(f"No API client factory registered for {self.name}")
|
||||
|
||||
def handle_error(self, error: Exception) -> str:
|
||||
"""Handle provider-specific errors."""
|
||||
if self._error_handler:
|
||||
return self._error_handler(error)
|
||||
return str(error)
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get additional configuration value."""
|
||||
return self._extra_config.get(key, default)
|
||||
203
autogpt_platform/backend/backend/sdk/registry.py
Normal file
203
autogpt_platform/backend/backend/sdk/registry.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Auto-registration system for blocks, providers, and their configurations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.basic import Block
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
|
||||
class BlockConfiguration:
|
||||
"""Configuration associated with a block."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str,
|
||||
costs: List[Any],
|
||||
default_credentials: List[Credentials],
|
||||
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
|
||||
oauth_handler: Optional[Type[BaseOAuthHandler]] = None,
|
||||
):
|
||||
self.provider = provider
|
||||
self.costs = costs
|
||||
self.default_credentials = default_credentials
|
||||
self.webhook_manager = webhook_manager
|
||||
self.oauth_handler = oauth_handler
|
||||
|
||||
|
||||
class AutoRegistry:
|
||||
"""Central registry for all block-related configurations."""
|
||||
|
||||
_lock = threading.Lock()
|
||||
_providers: Dict[str, "Provider"] = {}
|
||||
_default_credentials: List[Credentials] = []
|
||||
_oauth_handlers: Dict[str, Type[BaseOAuthHandler]] = {}
|
||||
_webhook_managers: Dict[str, Type[BaseWebhooksManager]] = {}
|
||||
_block_configurations: Dict[Type[Block], BlockConfiguration] = {}
|
||||
_api_key_mappings: Dict[str, str] = {} # provider -> env_var_name
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider: "Provider") -> None:
|
||||
"""Auto-register provider and all its configurations."""
|
||||
with cls._lock:
|
||||
cls._providers[provider.name] = provider
|
||||
|
||||
# Register OAuth handler if provided
|
||||
if provider.oauth_handler:
|
||||
cls._oauth_handlers[provider.name] = provider.oauth_handler
|
||||
|
||||
# Register webhook manager if provided
|
||||
if provider.webhook_manager:
|
||||
cls._webhook_managers[provider.name] = provider.webhook_manager
|
||||
|
||||
# Register default credentials
|
||||
cls._default_credentials.extend(provider.default_credentials)
|
||||
|
||||
@classmethod
|
||||
def register_api_key(cls, provider: str, env_var_name: str) -> None:
|
||||
"""Register an environment variable as an API key for a provider."""
|
||||
with cls._lock:
|
||||
cls._api_key_mappings[provider] = env_var_name
|
||||
|
||||
# Dynamically check if the env var exists and create credential
|
||||
import os
|
||||
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
credential = APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"Default {provider} credentials",
|
||||
)
|
||||
# Check if credential already exists to avoid duplicates
|
||||
if not any(c.id == credential.id for c in cls._default_credentials):
|
||||
cls._default_credentials.append(credential)
|
||||
|
||||
@classmethod
|
||||
def get_all_credentials(cls) -> List[Credentials]:
|
||||
"""Replace hardcoded get_all_creds() in credentials_store.py."""
|
||||
with cls._lock:
|
||||
return cls._default_credentials.copy()
|
||||
|
||||
@classmethod
|
||||
def get_oauth_handlers(cls) -> Dict[str, Type[BaseOAuthHandler]]:
|
||||
"""Replace HANDLERS_BY_NAME in oauth/__init__.py."""
|
||||
with cls._lock:
|
||||
return cls._oauth_handlers.copy()
|
||||
|
||||
@classmethod
|
||||
def get_webhook_managers(cls) -> Dict[str, Type[BaseWebhooksManager]]:
|
||||
"""Replace load_webhook_managers() in webhooks/__init__.py."""
|
||||
with cls._lock:
|
||||
return cls._webhook_managers.copy()
|
||||
|
||||
@classmethod
|
||||
def register_block_configuration(
|
||||
cls, block_class: Type[Block], config: BlockConfiguration
|
||||
) -> None:
|
||||
"""Register configuration for a specific block class."""
|
||||
with cls._lock:
|
||||
cls._block_configurations[block_class] = config
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, name: str) -> Optional["Provider"]:
|
||||
"""Get a registered provider by name."""
|
||||
with cls._lock:
|
||||
return cls._providers.get(name)
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""Clear all registrations (useful for testing)."""
|
||||
with cls._lock:
|
||||
cls._providers.clear()
|
||||
cls._default_credentials.clear()
|
||||
cls._oauth_handlers.clear()
|
||||
cls._webhook_managers.clear()
|
||||
cls._block_configurations.clear()
|
||||
cls._api_key_mappings.clear()
|
||||
|
||||
@classmethod
|
||||
def patch_integrations(cls) -> None:
|
||||
"""Patch existing integration points to use AutoRegistry."""
|
||||
# Patch oauth handlers
|
||||
try:
|
||||
import backend.integrations.oauth as oauth
|
||||
|
||||
if hasattr(oauth, "HANDLERS_BY_NAME"):
|
||||
# Create a new dict that includes both original and SDK handlers
|
||||
original_handlers = dict(oauth.HANDLERS_BY_NAME)
|
||||
|
||||
class PatchedHandlersDict(dict): # type: ignore
|
||||
def __getitem__(self, key):
|
||||
# First try SDK handlers
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
if key in sdk_handlers:
|
||||
return sdk_handlers[key]
|
||||
# Fall back to original
|
||||
return original_handlers[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
return key in sdk_handlers or key in original_handlers
|
||||
|
||||
def keys(self): # type: ignore[override]
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
all_keys = set(original_handlers.keys()) | set(
|
||||
sdk_handlers.keys()
|
||||
)
|
||||
return all_keys
|
||||
|
||||
def values(self):
|
||||
combined = dict(original_handlers)
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
if isinstance(sdk_handlers, dict):
|
||||
combined.update(sdk_handlers) # type: ignore
|
||||
return combined.values()
|
||||
|
||||
def items(self):
|
||||
combined = dict(original_handlers)
|
||||
sdk_handlers = cls.get_oauth_handlers()
|
||||
if isinstance(sdk_handlers, dict):
|
||||
combined.update(sdk_handlers) # type: ignore
|
||||
return combined.items()
|
||||
|
||||
oauth.HANDLERS_BY_NAME = PatchedHandlersDict()
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch oauth handlers: {e}")
|
||||
|
||||
# Patch webhook managers
|
||||
try:
|
||||
import backend.integrations.webhooks as webhooks
|
||||
|
||||
if hasattr(webhooks, "load_webhook_managers"):
|
||||
original_load = webhooks.load_webhook_managers
|
||||
|
||||
def patched_load():
|
||||
# Get original managers
|
||||
managers = original_load()
|
||||
# Add SDK-registered managers
|
||||
sdk_managers = cls.get_webhook_managers()
|
||||
if isinstance(sdk_managers, dict):
|
||||
managers.update(sdk_managers) # type: ignore
|
||||
return managers
|
||||
|
||||
webhooks.load_webhook_managers = patched_load
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch webhook managers: {e}")
|
||||
@@ -20,7 +20,7 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -434,8 +434,7 @@ async def list_providers() -> List[str]:
|
||||
static_providers = [member.value for member in ProviderName]
|
||||
|
||||
# Get dynamic providers from registry
|
||||
registry = get_registry()
|
||||
dynamic_providers = list(registry.providers)
|
||||
dynamic_providers = list(AutoRegistry._providers.keys())
|
||||
|
||||
# Combine and deduplicate
|
||||
all_providers = list(set(static_providers + dynamic_providers))
|
||||
@@ -461,7 +460,7 @@ async def get_providers_details() -> Dict[str, ProviderDetails]:
|
||||
Returns a dictionary mapping provider names to their details,
|
||||
including supported credential types and other metadata.
|
||||
"""
|
||||
registry = get_registry()
|
||||
# AutoRegistry is used directly as a class with class methods
|
||||
|
||||
# Build provider details
|
||||
provider_details: Dict[str, ProviderDetails] = {}
|
||||
@@ -471,24 +470,26 @@ async def get_providers_details() -> Dict[str, ProviderDetails]:
|
||||
provider_details[member.value] = ProviderDetails(
|
||||
name=member.value,
|
||||
source="static",
|
||||
has_oauth=member.value in registry.oauth_handlers,
|
||||
has_webhooks=member.value in registry.webhook_managers,
|
||||
has_oauth=member.value in AutoRegistry._oauth_handlers,
|
||||
has_webhooks=member.value in AutoRegistry._webhook_managers,
|
||||
)
|
||||
|
||||
# Add/update with dynamic providers
|
||||
for provider in registry.providers:
|
||||
for provider in AutoRegistry._providers:
|
||||
if provider not in provider_details:
|
||||
provider_details[provider] = ProviderDetails(
|
||||
name=provider,
|
||||
source="dynamic",
|
||||
has_oauth=provider in registry.oauth_handlers,
|
||||
has_webhooks=provider in registry.webhook_managers,
|
||||
has_oauth=provider in AutoRegistry._oauth_handlers,
|
||||
has_webhooks=provider in AutoRegistry._webhook_managers,
|
||||
)
|
||||
else:
|
||||
provider_details[provider].source = "both"
|
||||
provider_details[provider].has_oauth = provider in registry.oauth_handlers
|
||||
provider_details[provider].has_oauth = (
|
||||
provider in AutoRegistry._oauth_handlers
|
||||
)
|
||||
provider_details[provider].has_webhooks = (
|
||||
provider in registry.webhook_managers
|
||||
provider in AutoRegistry._webhook_managers
|
||||
)
|
||||
|
||||
# Determine supported credential types for each provider
|
||||
|
||||
@@ -58,40 +58,8 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.db.connect()
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
# Set up auto-registration system for SDK
|
||||
try:
|
||||
from backend.sdk.auto_registry import setup_auto_registration
|
||||
|
||||
logger.info("Starting SDK auto-registration system...")
|
||||
registry = setup_auto_registration()
|
||||
|
||||
# Log successful registration
|
||||
logger.info("Auto-registration completed successfully:")
|
||||
logger.info(f" - {len(registry.block_costs)} block costs registered")
|
||||
logger.info(
|
||||
f" - {len(registry.default_credentials)} default credentials registered"
|
||||
)
|
||||
logger.info(f" - {len(registry.oauth_handlers)} OAuth handlers registered")
|
||||
logger.info(f" - {len(registry.webhook_managers)} webhook managers registered")
|
||||
logger.info(f" - {len(registry.providers)} providers registered")
|
||||
|
||||
# Log specific credential providers for debugging
|
||||
credential_providers = [
|
||||
getattr(cred, "provider", "unknown")
|
||||
for cred in registry.default_credentials
|
||||
]
|
||||
if credential_providers:
|
||||
logger.info(
|
||||
f" - Default credential providers: {', '.join(credential_providers)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auto-registration setup failed: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
# Don't let this failure prevent startup, but make it very visible
|
||||
raise
|
||||
# SDK auto-registration is now handled by AutoRegistry.patch_integrations()
|
||||
# which is called when the SDK module is imported
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
|
||||
29
autogpt_platform/backend/test/sdk/conftest.py
Normal file
29
autogpt_platform/backend/test/sdk/conftest.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Configuration for SDK tests.
|
||||
|
||||
This conftest.py file provides basic test setup for SDK unit tests
|
||||
without requiring the full server infrastructure.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
"""Mock server fixture for SDK tests."""
|
||||
mock_server = MagicMock()
|
||||
mock_server.agent_server = MagicMock()
|
||||
mock_server.agent_server.test_create_graph = MagicMock()
|
||||
return mock_server
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_registry():
|
||||
"""Reset the AutoRegistry before each test."""
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
AutoRegistry.clear()
|
||||
yield
|
||||
AutoRegistry.clear()
|
||||
@@ -1,231 +0,0 @@
|
||||
"""
|
||||
Demo: Creating a new block with the SDK using 'from backend.sdk import *'
|
||||
|
||||
This file demonstrates the simplified block creation process.
|
||||
"""
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Float,
|
||||
List,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
provider,
|
||||
)
|
||||
|
||||
|
||||
# Create a custom translation service block with full auto-registration
|
||||
@provider("ultra-translate-ai")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=3, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=1, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="ultra-translate-default",
|
||||
provider="ultra-translate-ai",
|
||||
api_key=SecretStr("ultra-translate-default-api-key"),
|
||||
title="Ultra Translate AI Default API Key",
|
||||
)
|
||||
)
|
||||
class UltraTranslateBlock(Block):
|
||||
"""
|
||||
Ultra Translate AI - Advanced Translation Service
|
||||
|
||||
This block demonstrates how easy it is to create a new service integration
|
||||
with the SDK. No external configuration files need to be modified!
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="ultra-translate-ai",
|
||||
supported_credential_types={"api_key"},
|
||||
description="API credentials for Ultra Translate AI",
|
||||
)
|
||||
text: String = SchemaField(
|
||||
description="Text to translate", placeholder="Enter text to translate..."
|
||||
)
|
||||
source_language: String = SchemaField(
|
||||
description="Source language code (auto-detect if empty)",
|
||||
default="",
|
||||
placeholder="en, es, fr, de, ja, zh",
|
||||
)
|
||||
target_language: String = SchemaField(
|
||||
description="Target language code",
|
||||
default="es",
|
||||
placeholder="en, es, fr, de, ja, zh",
|
||||
)
|
||||
formality: String = SchemaField(
|
||||
description="Translation formality level (formal, neutral, informal)",
|
||||
default="neutral",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
translated_text: String = SchemaField(description="The translated text")
|
||||
detected_language: String = SchemaField(
|
||||
description="Auto-detected source language (if applicable)"
|
||||
)
|
||||
confidence: Float = SchemaField(
|
||||
description="Translation confidence score (0-1)"
|
||||
)
|
||||
alternatives: List[String] = SchemaField(
|
||||
description="Alternative translations", default=[]
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if translation failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="abb8abc4-c968-45fe-815a-c5521ad67f32",
|
||||
description="Translate text between languages using Ultra Translate AI",
|
||||
categories={BlockCategory.TEXT, BlockCategory.AI},
|
||||
input_schema=UltraTranslateBlock.Input,
|
||||
output_schema=UltraTranslateBlock.Output,
|
||||
test_input={
|
||||
"text": "Hello, how are you?",
|
||||
"target_language": "es",
|
||||
"formality": "informal",
|
||||
},
|
||||
test_output=[
|
||||
("translated_text", "¡Hola! ¿Cómo estás?"),
|
||||
("detected_language", "en"),
|
||||
("confidence", 0.98),
|
||||
("alternatives", ["¡Hola! ¿Qué tal?", "¡Hola! ¿Cómo te va?"]),
|
||||
("error", ""),
|
||||
],
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Get API key
|
||||
api_key = credentials.api_key.get_secret_value() # noqa: F841
|
||||
|
||||
# Simulate translation based on input
|
||||
translations = {
|
||||
("Hello, how are you?", "es", "informal"): {
|
||||
"text": "¡Hola! ¿Cómo estás?",
|
||||
"alternatives": ["¡Hola! ¿Qué tal?", "¡Hola! ¿Cómo te va?"],
|
||||
},
|
||||
("Hello, how are you?", "es", "formal"): {
|
||||
"text": "Hola, ¿cómo está usted?",
|
||||
"alternatives": ["Buenos días, ¿cómo se encuentra?"],
|
||||
},
|
||||
("Hello, how are you?", "fr", "neutral"): {
|
||||
"text": "Bonjour, comment allez-vous ?",
|
||||
"alternatives": ["Salut, comment ça va ?"],
|
||||
},
|
||||
("Hello, how are you?", "de", "neutral"): {
|
||||
"text": "Hallo, wie geht es dir?",
|
||||
"alternatives": ["Hallo, wie geht's?"],
|
||||
},
|
||||
}
|
||||
|
||||
# Get translation
|
||||
key = (input_data.text, input_data.target_language, input_data.formality)
|
||||
result = translations.get(
|
||||
key,
|
||||
{
|
||||
"text": f"[{input_data.target_language}] {input_data.text}",
|
||||
"alternatives": [],
|
||||
},
|
||||
)
|
||||
|
||||
# Detect source language if not provided
|
||||
detected_lang = input_data.source_language or "en"
|
||||
|
||||
yield "translated_text", result["text"]
|
||||
yield "detected_language", detected_lang
|
||||
yield "confidence", 0.95
|
||||
yield "alternatives", result["alternatives"]
|
||||
yield "error", ""
|
||||
|
||||
except Exception as e:
|
||||
yield "translated_text", ""
|
||||
yield "detected_language", ""
|
||||
yield "confidence", 0.0
|
||||
yield "alternatives", []
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
# This function demonstrates testing the block
|
||||
def demo_block_usage():
|
||||
"""Demonstrate using the block"""
|
||||
print("=" * 60)
|
||||
print("🌐 Ultra Translate AI Block Demo")
|
||||
print("=" * 60)
|
||||
|
||||
# Create block instance
|
||||
block = UltraTranslateBlock()
|
||||
print(f"\n✅ Created block: {block.name}")
|
||||
print(f" ID: {block.id}")
|
||||
print(f" Categories: {block.categories}")
|
||||
|
||||
# Check auto-registration
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
|
||||
print("\n📋 Auto-Registration Status:")
|
||||
print(f" ✅ Provider registered: {'ultra-translate-ai' in registry.providers}")
|
||||
print(f" ✅ Costs registered: {UltraTranslateBlock in registry.block_costs}")
|
||||
if UltraTranslateBlock in registry.block_costs:
|
||||
costs = registry.block_costs[UltraTranslateBlock]
|
||||
print(f" - Per run: {costs[0].cost_amount} credits")
|
||||
print(f" - Per byte: {costs[1].cost_amount} credits")
|
||||
|
||||
creds = registry.get_default_credentials_list()
|
||||
has_default_cred = any(c.id == "ultra-translate-default" for c in creds)
|
||||
print(f" ✅ Default credentials: {has_default_cred}")
|
||||
|
||||
# Test dynamic provider enum
|
||||
print("\n🔧 Dynamic Provider Test:")
|
||||
provider = ProviderName("ultra-translate-ai")
|
||||
print(f" ✅ Custom provider accepted: {provider.value}")
|
||||
print(f" ✅ Is ProviderName instance: {isinstance(provider, ProviderName)}")
|
||||
|
||||
# Test block execution
|
||||
print("\n🚀 Test Block Execution:")
|
||||
test_creds = APIKeyCredentials(
|
||||
id="test",
|
||||
provider="ultra-translate-ai",
|
||||
api_key=SecretStr("test-api-key"),
|
||||
title="Test",
|
||||
)
|
||||
|
||||
# Create test input with credentials meta
|
||||
test_input = UltraTranslateBlock.Input(
|
||||
credentials={"provider": "ultra-translate-ai", "id": "test", "type": "api_key"},
|
||||
text="Hello, how are you?",
|
||||
target_language="es",
|
||||
formality="informal",
|
||||
)
|
||||
|
||||
results = list(block.run(test_input, credentials=test_creds))
|
||||
output = {k: v for k, v in results}
|
||||
|
||||
print(f" Input: '{test_input.text}'")
|
||||
print(f" Target: {test_input.target_language} ({test_input.formality})")
|
||||
print(f" Output: '{output['translated_text']}'")
|
||||
print(f" Confidence: {output['confidence']}")
|
||||
print(f" Alternatives: {output['alternatives']}")
|
||||
|
||||
print("\n✨ Block works perfectly with zero external configuration!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_block_usage()
|
||||
@@ -1,247 +0,0 @@
|
||||
"""
|
||||
Test custom provider functionality in the SDK.
|
||||
|
||||
This test suite verifies that the SDK properly supports dynamic provider
|
||||
registration and that custom providers work correctly with the system.
|
||||
"""
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
Boolean,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
provider,
|
||||
)
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
# Test credentials for custom providers
|
||||
CUSTOM_TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="custom-provider-test-creds",
|
||||
provider="my-custom-service",
|
||||
api_key=SecretStr("test-api-key-12345"),
|
||||
title="Custom Service Test Credentials",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
CUSTOM_TEST_CREDENTIALS_INPUT = {
|
||||
"provider": CUSTOM_TEST_CREDENTIALS.provider,
|
||||
"id": CUSTOM_TEST_CREDENTIALS.id,
|
||||
"type": CUSTOM_TEST_CREDENTIALS.type,
|
||||
"title": CUSTOM_TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
@provider("my-custom-service")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=2, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="my-custom-service-default",
|
||||
provider="my-custom-service",
|
||||
api_key=SecretStr("default-custom-api-key"),
|
||||
title="My Custom Service Default API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
)
|
||||
class CustomProviderBlock(Block):
|
||||
"""Test block with a completely custom provider."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="my-custom-service",
|
||||
supported_credential_types={"api_key"},
|
||||
description="Custom service credentials",
|
||||
)
|
||||
message: String = SchemaField(
|
||||
description="Message to process", default="Hello from custom provider!"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: String = SchemaField(description="Processed message")
|
||||
provider_used: String = SchemaField(description="Provider name used")
|
||||
credentials_valid: Boolean = SchemaField(
|
||||
description="Whether credentials were valid"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d1234567-89ab-cdef-0123-456789abcdef",
|
||||
description="Test block demonstrating custom provider support",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=CustomProviderBlock.Input,
|
||||
output_schema=CustomProviderBlock.Output,
|
||||
test_input={
|
||||
"credentials": CUSTOM_TEST_CREDENTIALS_INPUT,
|
||||
"message": "Test message",
|
||||
},
|
||||
test_output=[
|
||||
("result", "CUSTOM: Test message"),
|
||||
("provider_used", "my-custom-service"),
|
||||
("credentials_valid", True),
|
||||
],
|
||||
test_credentials=CUSTOM_TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Verify we got the right credentials
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
yield "result", f"CUSTOM: {input_data.message}"
|
||||
yield "provider_used", credentials.provider
|
||||
yield "credentials_valid", bool(api_key)
|
||||
|
||||
|
||||
@provider("another-custom-provider")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=5, cost_type=BlockCostType.RUN),
|
||||
)
|
||||
class AnotherCustomProviderBlock(Block):
|
||||
"""Another test block to verify multiple custom providers work."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="another-custom-provider",
|
||||
supported_credential_types={"api_key"},
|
||||
)
|
||||
data: String = SchemaField(description="Input data")
|
||||
|
||||
class Output(BlockSchema):
|
||||
processed: String = SchemaField(description="Processed data")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e2345678-9abc-def0-1234-567890abcdef",
|
||||
description="Another custom provider test",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=AnotherCustomProviderBlock.Input,
|
||||
output_schema=AnotherCustomProviderBlock.Output,
|
||||
test_input={
|
||||
"credentials": {
|
||||
"provider": "another-custom-provider",
|
||||
"id": "test-creds-2",
|
||||
"type": "api_key",
|
||||
"title": "Test Creds 2",
|
||||
},
|
||||
"data": "test data",
|
||||
},
|
||||
test_output=[("processed", "ANOTHER: test data")],
|
||||
test_credentials=APIKeyCredentials(
|
||||
id="test-creds-2",
|
||||
provider="another-custom-provider",
|
||||
api_key=SecretStr("another-test-key"),
|
||||
title="Another Test Key",
|
||||
expires_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
yield "processed", f"ANOTHER: {input_data.data}"
|
||||
|
||||
|
||||
class TestCustomProvider:
|
||||
"""Test suite for custom provider functionality."""
|
||||
|
||||
def test_custom_provider_enum_accepts_any_string(self):
|
||||
"""Test that ProviderName enum accepts any string value."""
|
||||
# Test with a completely new provider name
|
||||
custom_provider = ProviderName("my-totally-new-provider")
|
||||
assert custom_provider.value == "my-totally-new-provider"
|
||||
|
||||
# Test with existing provider
|
||||
existing_provider = ProviderName.OPENAI
|
||||
assert existing_provider.value == "openai"
|
||||
|
||||
# Test comparison
|
||||
another_custom = ProviderName("my-totally-new-provider")
|
||||
assert custom_provider == another_custom
|
||||
|
||||
def test_custom_provider_block_executes(self):
|
||||
"""Test that blocks with custom providers can execute properly."""
|
||||
block = CustomProviderBlock()
|
||||
execute_block_test(block)
|
||||
|
||||
def test_multiple_custom_providers(self):
|
||||
"""Test that multiple custom providers can coexist."""
|
||||
block1 = CustomProviderBlock()
|
||||
block2 = AnotherCustomProviderBlock()
|
||||
|
||||
# Both blocks should execute successfully
|
||||
execute_block_test(block1)
|
||||
execute_block_test(block2)
|
||||
|
||||
def test_custom_provider_registration(self):
|
||||
"""Test that custom providers are registered in the auto-registry."""
|
||||
registry = get_registry()
|
||||
|
||||
# Check that our custom provider blocks have registered their costs
|
||||
block_costs = registry.get_block_costs_dict()
|
||||
assert CustomProviderBlock in block_costs
|
||||
assert AnotherCustomProviderBlock in block_costs
|
||||
|
||||
# Check the costs are correct
|
||||
custom_costs = block_costs[CustomProviderBlock]
|
||||
assert len(custom_costs) == 2
|
||||
assert any(
|
||||
cost.cost_amount == 10 and cost.cost_type == BlockCostType.RUN
|
||||
for cost in custom_costs
|
||||
)
|
||||
assert any(
|
||||
cost.cost_amount == 2 and cost.cost_type == BlockCostType.BYTE
|
||||
for cost in custom_costs
|
||||
)
|
||||
|
||||
def test_custom_provider_default_credentials(self):
|
||||
"""Test that default credentials are registered for custom providers."""
|
||||
registry = get_registry()
|
||||
default_creds = registry.get_default_credentials_list()
|
||||
|
||||
# Check that our custom provider's default credentials are registered
|
||||
custom_default_creds = [
|
||||
cred for cred in default_creds if cred.provider == "my-custom-service"
|
||||
]
|
||||
assert len(custom_default_creds) >= 1
|
||||
assert custom_default_creds[0].id == "my-custom-service-default"
|
||||
|
||||
def test_custom_provider_with_oauth(self):
|
||||
"""Test that custom providers can use OAuth handlers."""
|
||||
# This is a placeholder for OAuth testing
|
||||
# In a real implementation, you would create a custom OAuth handler
|
||||
pass
|
||||
|
||||
def test_custom_provider_with_webhooks(self):
|
||||
"""Test that custom providers can use webhook managers."""
|
||||
# This is a placeholder for webhook testing
|
||||
# In a real implementation, you would create a custom webhook manager
|
||||
pass
|
||||
|
||||
|
||||
# Test that runs as part of pytest
|
||||
def test_custom_provider_functionality():
|
||||
"""Run all custom provider tests."""
|
||||
test_instance = TestCustomProvider()
|
||||
|
||||
# Run each test method
|
||||
test_instance.test_custom_provider_enum_accepts_any_string()
|
||||
test_instance.test_custom_provider_block_executes()
|
||||
test_instance.test_multiple_custom_providers()
|
||||
test_instance.test_custom_provider_registration()
|
||||
test_instance.test_custom_provider_default_credentials()
|
||||
@@ -1,416 +0,0 @@
|
||||
"""
|
||||
Advanced tests for custom provider functionality including OAuth and Webhooks.
|
||||
|
||||
This test suite demonstrates how custom providers can integrate with all
|
||||
aspects of the SDK including OAuth authentication and webhook handling.
|
||||
"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseModel,
|
||||
BaseOAuthHandler,
|
||||
BaseWebhooksManager,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
Boolean,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Dict,
|
||||
Float,
|
||||
List,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
oauth_config,
|
||||
provider,
|
||||
webhook_config,
|
||||
)
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
|
||||
# Custom OAuth Handler for testing
|
||||
class CustomServiceOAuthHandler(BaseOAuthHandler):
|
||||
"""OAuth handler for our custom service."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("custom-oauth-service")
|
||||
DEFAULT_SCOPES = ["read", "write", "admin"]
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
"""Generate OAuth login URL."""
|
||||
scope_str = " ".join(scopes)
|
||||
return f"https://custom-oauth-service.com/oauth/authorize?client_id=test&scope={scope_str}&state={state}"
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
"""Exchange authorization code for tokens."""
|
||||
# Mock token exchange
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
access_token=SecretStr("mock-access-token"),
|
||||
refresh_token=SecretStr("mock-refresh-token"),
|
||||
scopes=scopes,
|
||||
access_token_expires_at=int(time.time() + 3600),
|
||||
title="Custom OAuth Service",
|
||||
id="custom-oauth-creds",
|
||||
)
|
||||
|
||||
|
||||
# Custom Webhook Manager for testing
|
||||
class CustomWebhookManager(BaseWebhooksManager):
|
||||
"""Webhook manager for our custom service."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("custom-webhook-service")
|
||||
|
||||
class WebhookType(str, Enum):
|
||||
DATA_RECEIVED = "data_received"
|
||||
STATUS_CHANGED = "status_changed"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook: Any, request: Any) -> tuple[dict, str]:
|
||||
"""Validate incoming webhook payload."""
|
||||
# Mock payload validation
|
||||
payload = {"data": "test data", "timestamp": time.time()}
|
||||
event_type = "data_received"
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials: Any,
|
||||
webhook_type: Any,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
"""Register webhook with external service."""
|
||||
# Mock webhook registration
|
||||
webhook_id = "custom-webhook-12345"
|
||||
config = {"url": ingress_url, "events": events, "resource": resource}
|
||||
return webhook_id, config
|
||||
|
||||
async def _deregister_webhook(self, webhook: Any, credentials: Any) -> None:
|
||||
"""Deregister webhook from external service."""
|
||||
# Mock webhook deregistration
|
||||
pass
|
||||
|
||||
|
||||
# Test OAuth-enabled block
|
||||
@provider("custom-oauth-service")
|
||||
@oauth_config("custom-oauth-service", CustomServiceOAuthHandler)
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=15, cost_type=BlockCostType.RUN),
|
||||
)
|
||||
class CustomOAuthBlock(Block):
|
||||
"""Block that uses OAuth authentication with a custom provider."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="custom-oauth-service",
|
||||
supported_credential_types={"oauth2"},
|
||||
required_scopes={"read", "write"},
|
||||
description="OAuth credentials for custom service",
|
||||
)
|
||||
action: String = SchemaField(
|
||||
description="Action to perform", default="fetch_data"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: Dict = SchemaField(description="Retrieved data")
|
||||
token_valid: Boolean = SchemaField(description="Whether OAuth token was valid")
|
||||
scopes: List[String] = SchemaField(description="Available scopes")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f3456789-abcd-ef01-2345-6789abcdef01",
|
||||
description="Custom OAuth provider test block",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=CustomOAuthBlock.Input,
|
||||
output_schema=CustomOAuthBlock.Output,
|
||||
test_input={
|
||||
"credentials": {
|
||||
"provider": "custom-oauth-service",
|
||||
"id": "oauth-test-creds",
|
||||
"type": "oauth2",
|
||||
"title": "Test OAuth Creds",
|
||||
},
|
||||
"action": "test_action",
|
||||
},
|
||||
test_output=[
|
||||
("data", {"status": "success", "action": "test_action"}),
|
||||
("token_valid", True),
|
||||
("scopes", ["read", "write"]),
|
||||
],
|
||||
test_credentials=OAuth2Credentials(
|
||||
id="oauth-test-creds",
|
||||
provider="custom-oauth-service",
|
||||
access_token=SecretStr("test-access-token"),
|
||||
refresh_token=SecretStr("test-refresh-token"),
|
||||
scopes=["read", "write"],
|
||||
access_token_expires_at=int(time.time() + 3600),
|
||||
title="Test OAuth Credentials",
|
||||
),
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Simulate OAuth API call
|
||||
token = credentials.access_token.get_secret_value()
|
||||
|
||||
yield "data", {"status": "success", "action": input_data.action}
|
||||
yield "token_valid", bool(token)
|
||||
yield "scopes", credentials.scopes
|
||||
|
||||
|
||||
# Event filter model for webhook
|
||||
class WebhookEventFilter(BaseModel):
|
||||
data_received: bool = True
|
||||
status_changed: bool = False
|
||||
|
||||
|
||||
# Test Webhook-enabled block
|
||||
@provider("custom-webhook-service")
|
||||
@webhook_config("custom-webhook-service", CustomWebhookManager)
|
||||
class CustomWebhookBlock(Block):
|
||||
"""Block that receives webhooks from a custom provider."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="custom-webhook-service",
|
||||
supported_credential_types={"api_key"},
|
||||
description="Credentials for webhook service",
|
||||
)
|
||||
events: WebhookEventFilter = SchemaField(
|
||||
description="Events to listen for", default_factory=WebhookEventFilter
|
||||
)
|
||||
payload: Dict = SchemaField(
|
||||
description="Webhook payload", default={}, hidden=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
event_type: String = SchemaField(description="Type of event received")
|
||||
event_data: Dict = SchemaField(description="Event data")
|
||||
timestamp: Float = SchemaField(description="Event timestamp")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a4567890-bcde-f012-3456-7890bcdef012",
|
||||
description="Custom webhook provider test block",
|
||||
categories={BlockCategory.INPUT},
|
||||
input_schema=CustomWebhookBlock.Input,
|
||||
output_schema=CustomWebhookBlock.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("custom-webhook-service"),
|
||||
webhook_type="data_received",
|
||||
event_filter_input="events",
|
||||
resource_format="webhook/{webhook_id}",
|
||||
),
|
||||
test_input={
|
||||
"credentials": {
|
||||
"provider": "custom-webhook-service",
|
||||
"id": "webhook-test-creds",
|
||||
"type": "api_key",
|
||||
"title": "Test Webhook Creds",
|
||||
},
|
||||
"events": {"data_received": True, "status_changed": False},
|
||||
"payload": {
|
||||
"type": "data_received",
|
||||
"data": "test",
|
||||
"timestamp": 1234567890.0,
|
||||
},
|
||||
},
|
||||
test_output=[
|
||||
("event_type", "data_received"),
|
||||
(
|
||||
"event_data",
|
||||
{
|
||||
"type": "data_received",
|
||||
"data": "test",
|
||||
"timestamp": 1234567890.0,
|
||||
},
|
||||
),
|
||||
("timestamp", 1234567890.0),
|
||||
],
|
||||
test_credentials=APIKeyCredentials(
|
||||
id="webhook-test-creds",
|
||||
provider="custom-webhook-service",
|
||||
api_key=SecretStr("webhook-api-key"),
|
||||
title="Webhook API Key",
|
||||
expires_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
|
||||
yield "event_type", payload.get("type", "unknown")
|
||||
yield "event_data", payload
|
||||
yield "timestamp", payload.get("timestamp", 0.0)
|
||||
|
||||
|
||||
# Combined block using multiple custom features
|
||||
@provider("custom-full-service")
|
||||
@oauth_config("custom-full-service", CustomServiceOAuthHandler)
|
||||
@webhook_config("custom-full-service", CustomWebhookManager)
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=20, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=5, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="custom-full-service-default",
|
||||
provider="custom-full-service",
|
||||
api_key=SecretStr("default-full-service-key"),
|
||||
title="Custom Full Service Default Key",
|
||||
expires_at=None,
|
||||
)
|
||||
)
|
||||
class CustomFullServiceBlock(Block):
|
||||
"""Block demonstrating all custom provider features."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="custom-full-service",
|
||||
supported_credential_types={"api_key", "oauth2"},
|
||||
description="Credentials for full service",
|
||||
)
|
||||
mode: String = SchemaField(description="Operation mode", default="standard")
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: String = SchemaField(description="Operation result")
|
||||
features_used: List[String] = SchemaField(description="Features utilized")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b5678901-cdef-0123-4567-8901cdef0123",
|
||||
description="Full-featured custom provider block",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
|
||||
input_schema=CustomFullServiceBlock.Input,
|
||||
output_schema=CustomFullServiceBlock.Output,
|
||||
test_input={
|
||||
"credentials": {
|
||||
"provider": "custom-full-service",
|
||||
"id": "full-test-creds",
|
||||
"type": "api_key",
|
||||
"title": "Full Service Test Creds",
|
||||
},
|
||||
"mode": "test",
|
||||
},
|
||||
test_output=[
|
||||
("result", "SUCCESS: test mode"),
|
||||
("features_used", ["provider", "cost_config", "default_credentials"]),
|
||||
],
|
||||
test_credentials=APIKeyCredentials(
|
||||
id="full-test-creds",
|
||||
provider="custom-full-service",
|
||||
api_key=SecretStr("full-service-test-key"),
|
||||
title="Full Service Test Key",
|
||||
expires_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: Any, **kwargs) -> BlockOutput:
|
||||
features = ["provider", "cost_config", "default_credentials"]
|
||||
|
||||
if isinstance(credentials, OAuth2Credentials):
|
||||
features.append("oauth")
|
||||
|
||||
yield "result", f"SUCCESS: {input_data.mode} mode"
|
||||
yield "features_used", features
|
||||
|
||||
|
||||
class TestCustomProviderAdvanced:
|
||||
"""Advanced test suite for custom provider functionality."""
|
||||
|
||||
def test_oauth_handler_registration(self):
|
||||
"""Test that custom OAuth handlers are registered."""
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
oauth_handlers = registry.get_oauth_handlers_dict()
|
||||
|
||||
# Check if our custom OAuth handler is registered
|
||||
assert "custom-oauth-service" in oauth_handlers
|
||||
assert oauth_handlers["custom-oauth-service"] == CustomServiceOAuthHandler
|
||||
|
||||
def test_webhook_manager_registration(self):
|
||||
"""Test that custom webhook managers are registered."""
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
webhook_managers = registry.get_webhook_managers_dict()
|
||||
|
||||
# Check if our custom webhook manager is registered
|
||||
assert "custom-webhook-service" in webhook_managers
|
||||
assert webhook_managers["custom-webhook-service"] == CustomWebhookManager
|
||||
|
||||
def test_oauth_block_execution(self):
|
||||
"""Test OAuth-enabled block execution."""
|
||||
block = CustomOAuthBlock()
|
||||
execute_block_test(block)
|
||||
|
||||
def test_webhook_block_execution(self):
|
||||
"""Test webhook-enabled block execution."""
|
||||
block = CustomWebhookBlock()
|
||||
execute_block_test(block)
|
||||
|
||||
def test_full_service_block_execution(self):
|
||||
"""Test full-featured block execution."""
|
||||
block = CustomFullServiceBlock()
|
||||
execute_block_test(block)
|
||||
|
||||
def test_multiple_decorators_on_same_provider(self):
|
||||
"""Test that a single provider can have multiple features."""
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
|
||||
# Check OAuth handler
|
||||
oauth_handlers = registry.get_oauth_handlers_dict()
|
||||
assert "custom-full-service" in oauth_handlers
|
||||
|
||||
# Check webhook manager
|
||||
webhook_managers = registry.get_webhook_managers_dict()
|
||||
assert "custom-full-service" in webhook_managers
|
||||
|
||||
# Check default credentials
|
||||
default_creds = registry.get_default_credentials_list()
|
||||
full_service_creds = [
|
||||
cred for cred in default_creds if cred.provider == "custom-full-service"
|
||||
]
|
||||
assert len(full_service_creds) >= 1
|
||||
|
||||
# Check cost config
|
||||
block_costs = registry.get_block_costs_dict()
|
||||
assert CustomFullServiceBlock in block_costs
|
||||
|
||||
|
||||
# Main test function
|
||||
def test_custom_provider_advanced_functionality():
|
||||
"""Run all advanced custom provider tests."""
|
||||
test_instance = TestCustomProviderAdvanced()
|
||||
|
||||
test_instance.test_oauth_handler_registration()
|
||||
test_instance.test_webhook_manager_registration()
|
||||
test_instance.test_oauth_block_execution()
|
||||
test_instance.test_webhook_block_execution()
|
||||
test_instance.test_full_service_block_execution()
|
||||
test_instance.test_multiple_decorators_on_same_provider()
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Test that custom providers work with validation."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class TestModel(BaseModel):
|
||||
provider: ProviderName
|
||||
|
||||
|
||||
def test_custom_provider_validation():
|
||||
"""Test that custom provider names are accepted."""
|
||||
# Test with existing provider
|
||||
model1 = TestModel(provider=ProviderName("openai"))
|
||||
assert model1.provider == ProviderName.OPENAI
|
||||
assert model1.provider.value == "openai"
|
||||
|
||||
# Test with custom provider
|
||||
model2 = TestModel(provider=ProviderName("my-custom-provider"))
|
||||
assert model2.provider.value == "my-custom-provider"
|
||||
|
||||
# Test JSON schema
|
||||
schema = TestModel.model_json_schema()
|
||||
provider_schema = schema["properties"]["provider"]
|
||||
|
||||
# Should not have enum constraint
|
||||
assert "enum" not in provider_schema
|
||||
assert provider_schema["type"] == "string"
|
||||
|
||||
print("✅ Custom provider validation works!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_custom_provider_validation()
|
||||
442
autogpt_platform/backend/test/sdk/test_sdk_block_creation.py
Normal file
442
autogpt_platform/backend/test/sdk/test_sdk_block_creation.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
Tests for creating blocks using the SDK.
|
||||
|
||||
This test suite verifies that blocks can be created using only SDK imports
|
||||
and that they work correctly without decorators.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
Boolean,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Integer,
|
||||
Optional,
|
||||
ProviderBuilder,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
)
|
||||
|
||||
|
||||
class TestBasicBlockCreation:
|
||||
"""Test creating basic blocks using the SDK."""
|
||||
|
||||
def test_simple_block(self):
|
||||
"""Test creating a simple block without any decorators."""
|
||||
|
||||
class SimpleBlock(Block):
|
||||
"""A simple test block."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
text: String = SchemaField(description="Input text")
|
||||
count: Integer = SchemaField(description="Repeat count", default=1)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: String = SchemaField(description="Output result")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="simple-test-block",
|
||||
description="A simple test block",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=SimpleBlock.Input,
|
||||
output_schema=SimpleBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
result = input_data.text * input_data.count
|
||||
yield "result", result
|
||||
|
||||
# Create and test the block
|
||||
block = SimpleBlock()
|
||||
assert block.id == "simple-test-block"
|
||||
assert BlockCategory.TEXT in block.categories
|
||||
|
||||
# Test execution
|
||||
outputs = list(
|
||||
block.run(
|
||||
SimpleBlock.Input(text="Hello ", count=3),
|
||||
)
|
||||
)
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0] == ("result", "Hello Hello Hello ")
|
||||
|
||||
def test_block_with_credentials(self):
|
||||
"""Test creating a block that requires credentials."""
|
||||
|
||||
class APIBlock(Block):
|
||||
"""A block that requires API credentials."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="test_api",
|
||||
supported_credential_types={"api_key"},
|
||||
description="API credentials",
|
||||
)
|
||||
query: String = SchemaField(description="API query")
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: String = SchemaField(description="API response")
|
||||
authenticated: Boolean = SchemaField(description="Was authenticated")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="api-test-block",
|
||||
description="Test block with API credentials",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=APIBlock.Input,
|
||||
output_schema=APIBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Simulate API call
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
authenticated = bool(api_key)
|
||||
|
||||
yield "response", f"API response for: {input_data.query}"
|
||||
yield "authenticated", authenticated
|
||||
|
||||
# Create test credentials
|
||||
test_creds = APIKeyCredentials(
|
||||
id="test-creds",
|
||||
provider="test_api",
|
||||
api_key=SecretStr("test-api-key"),
|
||||
title="Test API Key",
|
||||
)
|
||||
|
||||
# Create and test the block
|
||||
block = APIBlock()
|
||||
outputs = list(
|
||||
block.run(
|
||||
APIBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_api",
|
||||
"id": "test-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
query="test query",
|
||||
),
|
||||
credentials=test_creds,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0] == ("response", "API response for: test query")
|
||||
assert outputs[1] == ("authenticated", True)
|
||||
|
||||
def test_block_with_multiple_outputs(self):
|
||||
"""Test block that yields multiple outputs."""
|
||||
|
||||
class MultiOutputBlock(Block):
|
||||
"""Block with multiple outputs."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
text: String = SchemaField(description="Input text")
|
||||
|
||||
class Output(BlockSchema):
|
||||
uppercase: String = SchemaField(description="Uppercase version")
|
||||
lowercase: String = SchemaField(description="Lowercase version")
|
||||
length: Integer = SchemaField(description="Text length")
|
||||
is_empty: Boolean = SchemaField(description="Is text empty")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="multi-output-block",
|
||||
description="Block with multiple outputs",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=MultiOutputBlock.Input,
|
||||
output_schema=MultiOutputBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
text = input_data.text
|
||||
yield "uppercase", text.upper()
|
||||
yield "lowercase", text.lower()
|
||||
yield "length", len(text)
|
||||
yield "is_empty", len(text) == 0
|
||||
|
||||
# Test the block
|
||||
block = MultiOutputBlock()
|
||||
outputs = list(block.run(MultiOutputBlock.Input(text="Hello World")))
|
||||
|
||||
assert len(outputs) == 4
|
||||
assert ("uppercase", "HELLO WORLD") in outputs
|
||||
assert ("lowercase", "hello world") in outputs
|
||||
assert ("length", 11) in outputs
|
||||
assert ("is_empty", False) in outputs
|
||||
|
||||
|
||||
class TestBlockWithProvider:
|
||||
"""Test creating blocks associated with providers."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test provider."""
|
||||
# Create a provider using ProviderBuilder
|
||||
self.provider = (
|
||||
ProviderBuilder("test_service")
|
||||
.with_api_key("TEST_SERVICE_API_KEY", "Test Service API Key")
|
||||
.with_base_cost(10, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
def test_block_using_provider(self):
|
||||
"""Test block that uses a registered provider."""
|
||||
|
||||
class TestServiceBlock(Block):
|
||||
"""Block for test service."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="test_service", # Matches our provider
|
||||
supported_credential_types={"api_key"},
|
||||
description="Test service credentials",
|
||||
)
|
||||
action: String = SchemaField(description="Action to perform")
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: String = SchemaField(description="Action result")
|
||||
provider_name: String = SchemaField(description="Provider used")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="test-service-block",
|
||||
description="Block using test service provider",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=TestServiceBlock.Input,
|
||||
output_schema=TestServiceBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The provider name should match
|
||||
yield "result", f"Performed: {input_data.action}"
|
||||
yield "provider_name", credentials.provider
|
||||
|
||||
# Create credentials for our provider
|
||||
creds = APIKeyCredentials(
|
||||
id="test-service-creds",
|
||||
provider="test_service",
|
||||
api_key=SecretStr("test-key"),
|
||||
title="Test Service Key",
|
||||
)
|
||||
|
||||
# Test the block
|
||||
block = TestServiceBlock()
|
||||
outputs = dict(
|
||||
block.run(
|
||||
TestServiceBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_service",
|
||||
"id": "test-service-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
action="test action",
|
||||
),
|
||||
credentials=creds,
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["result"] == "Performed: test action"
|
||||
assert outputs["provider_name"] == "test_service"
|
||||
|
||||
|
||||
class TestComplexBlockScenarios:
|
||||
"""Test more complex block scenarios."""
|
||||
|
||||
def test_block_with_optional_fields(self):
|
||||
"""Test block with optional input fields."""
|
||||
# Optional is already imported at the module level
|
||||
|
||||
class OptionalFieldBlock(Block):
|
||||
"""Block with optional fields."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
required_field: String = SchemaField(description="Required field")
|
||||
optional_field: Optional[String] = SchemaField(
|
||||
description="Optional field",
|
||||
default=None,
|
||||
)
|
||||
optional_with_default: String = SchemaField(
|
||||
description="Optional with default",
|
||||
default="default value",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
has_optional: Boolean = SchemaField(description="Has optional value")
|
||||
optional_value: Optional[String] = SchemaField(
|
||||
description="Optional value"
|
||||
)
|
||||
default_value: String = SchemaField(description="Default value")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="optional-field-block",
|
||||
description="Block with optional fields",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=OptionalFieldBlock.Input,
|
||||
output_schema=OptionalFieldBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "has_optional", input_data.optional_field is not None
|
||||
yield "optional_value", input_data.optional_field
|
||||
yield "default_value", input_data.optional_with_default
|
||||
|
||||
# Test with optional field provided
|
||||
block = OptionalFieldBlock()
|
||||
outputs = dict(
|
||||
block.run(
|
||||
OptionalFieldBlock.Input(
|
||||
required_field="test",
|
||||
optional_field="provided",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["has_optional"] is True
|
||||
assert outputs["optional_value"] == "provided"
|
||||
assert outputs["default_value"] == "default value"
|
||||
|
||||
# Test without optional field
|
||||
outputs = dict(
|
||||
block.run(
|
||||
OptionalFieldBlock.Input(
|
||||
required_field="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["has_optional"] is False
|
||||
assert outputs["optional_value"] is None
|
||||
assert outputs["default_value"] == "default value"
|
||||
|
||||
def test_block_with_complex_types(self):
|
||||
"""Test block with complex input/output types."""
|
||||
from backend.sdk import BaseModel, Dict, List
|
||||
|
||||
class ItemModel(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
class ComplexBlock(Block):
|
||||
"""Block with complex types."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
items: List[String] = SchemaField(description="List of items")
|
||||
mapping: Dict[String, Integer] = SchemaField(
|
||||
description="String to int mapping"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
item_count: Integer = SchemaField(description="Number of items")
|
||||
total_value: Integer = SchemaField(description="Sum of mapping values")
|
||||
combined: List[String] = SchemaField(description="Combined results")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="complex-types-block",
|
||||
description="Block with complex types",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ComplexBlock.Input,
|
||||
output_schema=ComplexBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "item_count", len(input_data.items)
|
||||
yield "total_value", sum(input_data.mapping.values())
|
||||
|
||||
# Combine items with their mapping values
|
||||
combined = []
|
||||
for item in input_data.items:
|
||||
value = input_data.mapping.get(item, 0)
|
||||
combined.append(f"{item}: {value}")
|
||||
|
||||
yield "combined", combined
|
||||
|
||||
# Test the block
|
||||
block = ComplexBlock()
|
||||
outputs = dict(
|
||||
block.run(
|
||||
ComplexBlock.Input(
|
||||
items=["apple", "banana", "orange"],
|
||||
mapping={"apple": 5, "banana": 3, "orange": 4},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["item_count"] == 3
|
||||
assert outputs["total_value"] == 12
|
||||
assert outputs["combined"] == ["apple: 5", "banana: 3", "orange: 4"]
|
||||
|
||||
def test_block_error_handling(self):
|
||||
"""Test block error handling."""
|
||||
|
||||
class ErrorHandlingBlock(Block):
|
||||
"""Block that demonstrates error handling."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Integer = SchemaField(description="Input value")
|
||||
should_error: Boolean = SchemaField(
|
||||
description="Whether to trigger an error",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Integer = SchemaField(description="Result")
|
||||
error_message: Optional[String] = SchemaField(
|
||||
description="Error if any", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="error-handling-block",
|
||||
description="Block with error handling",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ErrorHandlingBlock.Input,
|
||||
output_schema=ErrorHandlingBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.should_error:
|
||||
raise ValueError("Intentional error triggered")
|
||||
|
||||
if input_data.value < 0:
|
||||
yield "error_message", "Value must be non-negative"
|
||||
yield "result", 0
|
||||
else:
|
||||
yield "result", input_data.value * 2
|
||||
yield "error_message", None
|
||||
|
||||
# Test normal operation
|
||||
block = ErrorHandlingBlock()
|
||||
outputs = dict(block.run(ErrorHandlingBlock.Input(value=5, should_error=False)))
|
||||
|
||||
assert outputs["result"] == 10
|
||||
assert outputs["error_message"] is None
|
||||
|
||||
# Test with negative value
|
||||
outputs = dict(
|
||||
block.run(ErrorHandlingBlock.Input(value=-5, should_error=False))
|
||||
)
|
||||
|
||||
assert outputs["result"] == 0
|
||||
assert outputs["error_message"] == "Value must be non-negative"
|
||||
|
||||
# Test with error
|
||||
with pytest.raises(ValueError, match="Intentional error triggered"):
|
||||
list(block.run(ErrorHandlingBlock.Input(value=5, should_error=True)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -1,483 +0,0 @@
|
||||
"""
|
||||
Comprehensive test suite for the AutoGPT SDK implementation.
|
||||
Tests all aspects of the SDK including imports, decorators, and auto-registration.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add backend to path
|
||||
backend_path = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(backend_path))
|
||||
|
||||
|
||||
class TestSDKImplementation:
|
||||
"""Comprehensive SDK tests"""
|
||||
|
||||
def test_sdk_imports_all_components(self):
|
||||
"""Test that all expected components are available from backend.sdk import *"""
|
||||
# Import SDK
|
||||
import backend.sdk as sdk
|
||||
|
||||
# Core block components
|
||||
assert hasattr(sdk, "Block")
|
||||
assert hasattr(sdk, "BlockCategory")
|
||||
assert hasattr(sdk, "BlockOutput")
|
||||
assert hasattr(sdk, "BlockSchema")
|
||||
assert hasattr(sdk, "BlockType")
|
||||
assert hasattr(sdk, "SchemaField")
|
||||
|
||||
# Credential components
|
||||
assert hasattr(sdk, "CredentialsField")
|
||||
assert hasattr(sdk, "CredentialsMetaInput")
|
||||
assert hasattr(sdk, "APIKeyCredentials")
|
||||
assert hasattr(sdk, "OAuth2Credentials")
|
||||
assert hasattr(sdk, "UserPasswordCredentials")
|
||||
|
||||
# Cost components
|
||||
assert hasattr(sdk, "BlockCost")
|
||||
assert hasattr(sdk, "BlockCostType")
|
||||
assert hasattr(sdk, "NodeExecutionStats")
|
||||
|
||||
# Provider component
|
||||
assert hasattr(sdk, "ProviderName")
|
||||
|
||||
# Type aliases
|
||||
assert sdk.String is str
|
||||
assert sdk.Integer is int
|
||||
assert sdk.Float is float
|
||||
assert sdk.Boolean is bool
|
||||
|
||||
# Decorators
|
||||
assert hasattr(sdk, "provider")
|
||||
assert hasattr(sdk, "cost_config")
|
||||
assert hasattr(sdk, "default_credentials")
|
||||
assert hasattr(sdk, "webhook_config")
|
||||
assert hasattr(sdk, "oauth_config")
|
||||
|
||||
# Common types
|
||||
assert hasattr(sdk, "List")
|
||||
assert hasattr(sdk, "Dict")
|
||||
assert hasattr(sdk, "Optional")
|
||||
assert hasattr(sdk, "Any")
|
||||
assert hasattr(sdk, "Union")
|
||||
assert hasattr(sdk, "BaseModel")
|
||||
assert hasattr(sdk, "SecretStr")
|
||||
assert hasattr(sdk, "Enum")
|
||||
|
||||
# Utilities
|
||||
assert hasattr(sdk, "json")
|
||||
assert hasattr(sdk, "logging")
|
||||
|
||||
print("✅ All SDK imports verified")
|
||||
|
||||
def test_auto_registry_system(self):
|
||||
"""Test the auto-registration system"""
|
||||
from backend.sdk import APIKeyCredentials, BlockCost, BlockCostType, SecretStr
|
||||
from backend.sdk.auto_registry import AutoRegistry, get_registry
|
||||
|
||||
# Get registry instance
|
||||
registry = get_registry()
|
||||
assert isinstance(registry, AutoRegistry)
|
||||
|
||||
# Test provider registration
|
||||
initial_providers = len(registry.providers)
|
||||
registry.register_provider("test-provider-123")
|
||||
assert "test-provider-123" in registry.providers
|
||||
assert len(registry.providers) == initial_providers + 1
|
||||
|
||||
# Test cost registration
|
||||
class TestBlock:
|
||||
pass
|
||||
|
||||
test_costs = [
|
||||
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=2, cost_type=BlockCostType.BYTE),
|
||||
]
|
||||
registry.register_block_cost(TestBlock, test_costs)
|
||||
assert TestBlock in registry.block_costs
|
||||
assert len(registry.block_costs[TestBlock]) == 2
|
||||
assert registry.block_costs[TestBlock][0].cost_amount == 10
|
||||
|
||||
# Test credential registration
|
||||
test_cred = APIKeyCredentials(
|
||||
id="test-cred-123",
|
||||
provider="test-provider-123",
|
||||
api_key=SecretStr("test-api-key"),
|
||||
title="Test Credential",
|
||||
)
|
||||
registry.register_default_credential(test_cred)
|
||||
|
||||
# Check credential was added
|
||||
creds = registry.get_default_credentials_list()
|
||||
assert any(c.id == "test-cred-123" for c in creds)
|
||||
|
||||
# Test duplicate prevention
|
||||
initial_cred_count = len(registry.default_credentials)
|
||||
registry.register_default_credential(test_cred) # Add again
|
||||
assert (
|
||||
len(registry.default_credentials) == initial_cred_count
|
||||
) # Should not increase
|
||||
|
||||
print("✅ Auto-registry system verified")
|
||||
|
||||
def test_decorators_functionality(self):
|
||||
"""Test that all decorators work correctly"""
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
oauth_config,
|
||||
provider,
|
||||
webhook_config,
|
||||
)
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
|
||||
# Clear registry state for clean test
|
||||
# initial_provider_count = len(registry.providers)
|
||||
|
||||
# Test combined decorators on a block
|
||||
@provider("test-service-xyz")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=15, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=3, cost_type=BlockCostType.SECOND),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="test-service-xyz-default",
|
||||
provider="test-service-xyz",
|
||||
api_key=SecretStr("default-test-key"),
|
||||
title="Test Service Default Key",
|
||||
)
|
||||
)
|
||||
class TestServiceBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: String = SchemaField(description="Test input")
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: String = SchemaField(description="Test output")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f0421f19-53da-4824-97cc-4d2bccd1399f",
|
||||
description="Test service block",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=TestServiceBlock.Input,
|
||||
output_schema=TestServiceBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "result", f"Processed: {input_data.text}"
|
||||
|
||||
# Verify decorators worked
|
||||
assert "test-service-xyz" in registry.providers
|
||||
assert TestServiceBlock in registry.block_costs
|
||||
assert len(registry.block_costs[TestServiceBlock]) == 2
|
||||
|
||||
# Check credentials
|
||||
creds = registry.get_default_credentials_list()
|
||||
assert any(c.id == "test-service-xyz-default" for c in creds)
|
||||
|
||||
# Test webhook decorator (mock classes for testing)
|
||||
class MockWebhookManager:
|
||||
pass
|
||||
|
||||
@webhook_config("test-webhook-provider", MockWebhookManager)
|
||||
class TestWebhookBlock:
|
||||
pass
|
||||
|
||||
assert "test-webhook-provider" in registry.webhook_managers
|
||||
assert registry.webhook_managers["test-webhook-provider"] == MockWebhookManager
|
||||
|
||||
# Test oauth decorator
|
||||
class MockOAuthHandler:
|
||||
pass
|
||||
|
||||
@oauth_config("test-oauth-provider", MockOAuthHandler)
|
||||
class TestOAuthBlock:
|
||||
pass
|
||||
|
||||
assert "test-oauth-provider" in registry.oauth_handlers
|
||||
assert registry.oauth_handlers["test-oauth-provider"] == MockOAuthHandler
|
||||
|
||||
print("✅ All decorators verified")
|
||||
|
||||
def test_provider_enum_dynamic_support(self):
|
||||
"""Test that ProviderName enum supports dynamic providers"""
|
||||
from backend.sdk import ProviderName
|
||||
|
||||
# Test existing provider
|
||||
existing = ProviderName.GITHUB
|
||||
assert existing.value == "github"
|
||||
assert isinstance(existing, ProviderName)
|
||||
|
||||
# Test dynamic provider
|
||||
dynamic = ProviderName("my-custom-provider-abc")
|
||||
assert dynamic.value == "my-custom-provider-abc"
|
||||
assert isinstance(dynamic, ProviderName)
|
||||
assert dynamic._name_ == "MY-CUSTOM-PROVIDER-ABC"
|
||||
|
||||
# Test that same dynamic provider returns same instance
|
||||
dynamic2 = ProviderName("my-custom-provider-abc")
|
||||
assert dynamic.value == dynamic2.value
|
||||
|
||||
# Test invalid input
|
||||
try:
|
||||
ProviderName(123) # Should not work with non-string
|
||||
assert False, "Should have failed with non-string"
|
||||
except ValueError:
|
||||
pass # Expected
|
||||
|
||||
print("✅ Dynamic provider enum verified")
|
||||
|
||||
def test_complete_block_example(self):
|
||||
"""Test a complete block using all SDK features"""
|
||||
# This simulates what a block developer would write
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Float,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
provider,
|
||||
)
|
||||
|
||||
@provider("ai-translator-service")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=5, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=1, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="ai-translator-default",
|
||||
provider="ai-translator-service",
|
||||
api_key=SecretStr("translator-default-key"),
|
||||
title="AI Translator Default API Key",
|
||||
)
|
||||
)
|
||||
class AITranslatorBlock(Block):
|
||||
"""AI-powered translation block using the SDK"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="ai-translator-service",
|
||||
supported_credential_types={"api_key"},
|
||||
description="API credentials for AI Translator",
|
||||
)
|
||||
text: String = SchemaField(
|
||||
description="Text to translate", default="Hello, world!"
|
||||
)
|
||||
target_language: String = SchemaField(
|
||||
description="Target language code", default="es"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
translated_text: String = SchemaField(description="Translated text")
|
||||
source_language: String = SchemaField(
|
||||
description="Detected source language"
|
||||
)
|
||||
confidence: Float = SchemaField(
|
||||
description="Translation confidence score"
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if any", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="dc832afe-902a-4520-8512-d3b85428d4ec",
|
||||
description="Translate text using AI Translator Service",
|
||||
categories={BlockCategory.TEXT, BlockCategory.AI},
|
||||
input_schema=AITranslatorBlock.Input,
|
||||
output_schema=AITranslatorBlock.Output,
|
||||
test_input={"text": "Hello, world!", "target_language": "es"},
|
||||
test_output=[
|
||||
("translated_text", "¡Hola, mundo!"),
|
||||
("source_language", "en"),
|
||||
("confidence", 0.95),
|
||||
],
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Simulate translation
|
||||
credentials.api_key.get_secret_value() # Verify we can access the key
|
||||
|
||||
# Mock translation logic
|
||||
translations = {
|
||||
("Hello, world!", "es"): "¡Hola, mundo!",
|
||||
("Hello, world!", "fr"): "Bonjour le monde!",
|
||||
("Hello, world!", "de"): "Hallo Welt!",
|
||||
}
|
||||
|
||||
key = (input_data.text, input_data.target_language)
|
||||
translated = translations.get(
|
||||
key, f"[{input_data.target_language}] {input_data.text}"
|
||||
)
|
||||
|
||||
yield "translated_text", translated
|
||||
yield "source_language", "en"
|
||||
yield "confidence", 0.95
|
||||
yield "error", ""
|
||||
|
||||
except Exception as e:
|
||||
yield "translated_text", ""
|
||||
yield "source_language", ""
|
||||
yield "confidence", 0.0
|
||||
yield "error", str(e)
|
||||
|
||||
# Verify the block was created correctly
|
||||
block = AITranslatorBlock()
|
||||
assert block.id == "dc832afe-902a-4520-8512-d3b85428d4ec"
|
||||
assert block.description == "Translate text using AI Translator Service"
|
||||
assert BlockCategory.TEXT in block.categories
|
||||
assert BlockCategory.AI in block.categories
|
||||
|
||||
# Verify decorators registered everything
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
|
||||
assert "ai-translator-service" in registry.providers
|
||||
assert AITranslatorBlock in registry.block_costs
|
||||
assert len(registry.block_costs[AITranslatorBlock]) == 2
|
||||
|
||||
creds = registry.get_default_credentials_list()
|
||||
assert any(c.id == "ai-translator-default" for c in creds)
|
||||
|
||||
print("✅ Complete block example verified")
|
||||
|
||||
def test_backward_compatibility(self):
|
||||
"""Test that old-style imports still work"""
|
||||
# Test that we can still import from original locations
|
||||
try:
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
assert Block is not None
|
||||
assert BlockCategory is not None
|
||||
assert BlockOutput is not None
|
||||
assert BlockSchema is not None
|
||||
assert SchemaField is not None
|
||||
print("✅ Backward compatibility verified")
|
||||
except ImportError as e:
|
||||
print(f"❌ Backward compatibility issue: {e}")
|
||||
raise
|
||||
|
||||
def test_auto_registration_patching(self):
|
||||
"""Test that auto-registration correctly patches existing systems"""
|
||||
from backend.sdk.auto_registry import patch_existing_systems
|
||||
|
||||
# This would normally be called during app startup
|
||||
# For testing, we'll verify the patching logic works
|
||||
try:
|
||||
patch_existing_systems()
|
||||
print("✅ Auto-registration patching verified")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Patching had issues (expected in test environment): {e}")
|
||||
# This is expected in test environment where not all systems are loaded
|
||||
|
||||
def test_import_star_works(self):
|
||||
"""Test that 'from backend.sdk import *' actually works"""
|
||||
# Create a temporary module to test import *
|
||||
test_code = """
|
||||
from backend.sdk import *
|
||||
|
||||
# Test that common items are available
|
||||
assert Block is not None
|
||||
assert BlockSchema is not None
|
||||
assert SchemaField is not None
|
||||
assert String == str
|
||||
assert provider is not None
|
||||
assert cost_config is not None
|
||||
print("Import * works correctly")
|
||||
"""
|
||||
|
||||
# Execute in a clean namespace
|
||||
namespace = {"__name__": "__main__"}
|
||||
try:
|
||||
exec(test_code, namespace)
|
||||
print("✅ Import * functionality verified")
|
||||
except Exception as e:
|
||||
print(f"❌ Import * failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all SDK tests"""
|
||||
print("\n" + "=" * 60)
|
||||
print("🧪 Running Comprehensive SDK Tests")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
test_suite = TestSDKImplementation()
|
||||
|
||||
tests = [
|
||||
("SDK Imports", test_suite.test_sdk_imports_all_components),
|
||||
("Auto-Registry System", test_suite.test_auto_registry_system),
|
||||
("Decorators", test_suite.test_decorators_functionality),
|
||||
("Dynamic Provider Enum", test_suite.test_provider_enum_dynamic_support),
|
||||
("Complete Block Example", test_suite.test_complete_block_example),
|
||||
("Backward Compatibility", test_suite.test_backward_compatibility),
|
||||
("Auto-Registration Patching", test_suite.test_auto_registration_patching),
|
||||
("Import * Syntax", test_suite.test_import_star_works),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for test_name, test_func in tests:
|
||||
print(f"\n📋 Testing: {test_name}")
|
||||
print("-" * 40)
|
||||
try:
|
||||
test_func()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"❌ Test failed: {e}")
|
||||
failed += 1
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"📊 Test Results: {passed} passed, {failed} failed")
|
||||
print("=" * 60)
|
||||
|
||||
if failed == 0:
|
||||
print("\n🎉 All SDK tests passed! The implementation is working correctly.")
|
||||
else:
|
||||
print(f"\n⚠️ {failed} tests failed. Please review the errors above.")
|
||||
|
||||
return failed == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -1,172 +0,0 @@
|
||||
"""
|
||||
Test the SDK import system and auto-registration
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add backend to path
|
||||
backend_path = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(backend_path))
|
||||
|
||||
|
||||
def test_sdk_imports():
|
||||
"""Test that all expected imports are available from backend.sdk"""
|
||||
|
||||
# Import the module and check its contents
|
||||
import backend.sdk as sdk
|
||||
|
||||
# Core block components should be available
|
||||
assert hasattr(sdk, "Block")
|
||||
assert hasattr(sdk, "BlockCategory")
|
||||
assert hasattr(sdk, "BlockOutput")
|
||||
assert hasattr(sdk, "BlockSchema")
|
||||
assert hasattr(sdk, "SchemaField")
|
||||
|
||||
# Credential types should be available
|
||||
assert hasattr(sdk, "CredentialsField")
|
||||
assert hasattr(sdk, "CredentialsMetaInput")
|
||||
assert hasattr(sdk, "APIKeyCredentials")
|
||||
assert hasattr(sdk, "OAuth2Credentials")
|
||||
|
||||
# Cost system should be available
|
||||
assert hasattr(sdk, "BlockCost")
|
||||
assert hasattr(sdk, "BlockCostType")
|
||||
|
||||
# Providers should be available
|
||||
assert hasattr(sdk, "ProviderName")
|
||||
|
||||
# Type aliases should work
|
||||
assert sdk.String is str
|
||||
assert sdk.Integer is int
|
||||
assert sdk.Float is float
|
||||
assert sdk.Boolean is bool
|
||||
|
||||
# Decorators should be available
|
||||
assert hasattr(sdk, "provider")
|
||||
assert hasattr(sdk, "cost_config")
|
||||
assert hasattr(sdk, "default_credentials")
|
||||
assert hasattr(sdk, "webhook_config")
|
||||
assert hasattr(sdk, "oauth_config")
|
||||
|
||||
# Common types should be available
|
||||
assert hasattr(sdk, "List")
|
||||
assert hasattr(sdk, "Dict")
|
||||
assert hasattr(sdk, "Optional")
|
||||
assert hasattr(sdk, "Any")
|
||||
assert hasattr(sdk, "Union")
|
||||
assert hasattr(sdk, "BaseModel")
|
||||
assert hasattr(sdk, "SecretStr")
|
||||
|
||||
# Utilities should be available
|
||||
assert hasattr(sdk, "json")
|
||||
assert hasattr(sdk, "logging")
|
||||
|
||||
|
||||
def test_auto_registry():
|
||||
"""Test the auto-registration system"""
|
||||
|
||||
from backend.sdk import APIKeyCredentials, BlockCost, BlockCostType, SecretStr
|
||||
from backend.sdk.auto_registry import AutoRegistry, get_registry
|
||||
|
||||
# Get the registry
|
||||
registry = get_registry()
|
||||
assert isinstance(registry, AutoRegistry)
|
||||
|
||||
# Test registering a provider
|
||||
registry.register_provider("test-provider")
|
||||
assert "test-provider" in registry.providers
|
||||
|
||||
# Test registering block costs
|
||||
test_costs = [BlockCost(cost_amount=5, cost_type=BlockCostType.RUN)]
|
||||
|
||||
class TestBlock:
|
||||
pass
|
||||
|
||||
registry.register_block_cost(TestBlock, test_costs)
|
||||
assert TestBlock in registry.block_costs
|
||||
assert registry.block_costs[TestBlock] == test_costs
|
||||
|
||||
# Test registering credentials
|
||||
test_cred = APIKeyCredentials(
|
||||
id="test-cred",
|
||||
provider="test-provider",
|
||||
api_key=SecretStr("test-key"),
|
||||
title="Test Credential",
|
||||
)
|
||||
registry.register_default_credential(test_cred)
|
||||
assert test_cred in registry.default_credentials
|
||||
|
||||
|
||||
def test_decorators():
|
||||
"""Test that decorators work correctly"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
SecretStr,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
provider,
|
||||
)
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
# Clear registry for test
|
||||
registry = get_registry()
|
||||
|
||||
# Test provider decorator
|
||||
@provider("decorator-test")
|
||||
class DecoratorTestBlock:
|
||||
pass
|
||||
|
||||
assert "decorator-test" in registry.providers
|
||||
|
||||
# Test cost_config decorator
|
||||
@cost_config(BlockCost(cost_amount=10, cost_type=BlockCostType.RUN))
|
||||
class CostTestBlock:
|
||||
pass
|
||||
|
||||
assert CostTestBlock in registry.block_costs
|
||||
assert len(registry.block_costs[CostTestBlock]) == 1
|
||||
assert registry.block_costs[CostTestBlock][0].cost_amount == 10
|
||||
|
||||
# Test default_credentials decorator
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="decorator-test-cred",
|
||||
provider="decorator-test",
|
||||
api_key=SecretStr("test-api-key"),
|
||||
title="Decorator Test Credential",
|
||||
)
|
||||
)
|
||||
class CredTestBlock:
|
||||
pass
|
||||
|
||||
# Check if credential was registered
|
||||
creds = registry.get_default_credentials_list()
|
||||
assert any(c.id == "decorator-test-cred" for c in creds)
|
||||
|
||||
|
||||
def test_example_block_imports():
|
||||
"""Test that example blocks can use SDK imports"""
|
||||
# Skip this test since example blocks were moved to examples directory
|
||||
# to avoid interfering with main tests
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests
|
||||
test_sdk_imports()
|
||||
print("✅ SDK imports test passed")
|
||||
|
||||
test_auto_registry()
|
||||
print("✅ Auto-registry test passed")
|
||||
|
||||
test_decorators()
|
||||
print("✅ Decorators test passed")
|
||||
|
||||
test_example_block_imports()
|
||||
print("✅ Example block test passed")
|
||||
|
||||
print("\n🎉 All SDK tests passed!")
|
||||
@@ -1,442 +0,0 @@
|
||||
"""
|
||||
Integration test demonstrating the complete SDK workflow.
|
||||
This shows how a developer would create a new block with zero external configuration.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
# Add backend to path
|
||||
backend_path = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(backend_path))
|
||||
|
||||
# ruff: noqa: E402
|
||||
# Import SDK at module level for testing (after sys.path modification)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseWebhooksManager,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Dict,
|
||||
Float,
|
||||
Integer,
|
||||
List,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
provider,
|
||||
webhook_config,
|
||||
)
|
||||
|
||||
|
||||
def test_complete_sdk_workflow():
|
||||
"""
|
||||
Demonstrate the complete workflow of creating a new block with the SDK.
|
||||
This test shows:
|
||||
1. Single import statement
|
||||
2. Custom provider registration
|
||||
3. Cost configuration
|
||||
4. Default credentials
|
||||
5. Zero external configuration needed
|
||||
"""
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🚀 SDK Integration Test - Complete Workflow")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Step 1: Import everything needed with a single statement
|
||||
print("Step 1: Import SDK")
|
||||
# SDK already imported at module level
|
||||
print("✅ Imported all components with 'from backend.sdk import *'")
|
||||
|
||||
# Step 2: Create a custom AI service block
|
||||
print("\nStep 2: Create a custom AI service block")
|
||||
|
||||
@provider("custom-ai-vision-service")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=5, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="custom-ai-vision-default",
|
||||
provider="custom-ai-vision-service",
|
||||
api_key=SecretStr("vision-service-default-api-key"),
|
||||
title="Custom AI Vision Service Default API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
)
|
||||
class CustomAIVisionBlock(Block):
|
||||
"""
|
||||
Custom AI Vision Analysis Block
|
||||
|
||||
This block demonstrates:
|
||||
- Custom provider name (not in the original enum)
|
||||
- Automatic cost registration
|
||||
- Default credentials setup
|
||||
- Complex input/output schemas
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="custom-ai-vision-service",
|
||||
supported_credential_types={"api_key"},
|
||||
description="API credentials for Custom AI Vision Service",
|
||||
)
|
||||
image_url: String = SchemaField(
|
||||
description="URL of the image to analyze",
|
||||
placeholder="https://example.com/image.jpg",
|
||||
)
|
||||
analysis_type: String = SchemaField(
|
||||
description="Type of analysis to perform",
|
||||
default="general",
|
||||
)
|
||||
confidence_threshold: Float = SchemaField(
|
||||
description="Minimum confidence threshold for detections",
|
||||
default=0.7,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
max_results: Integer = SchemaField(
|
||||
description="Maximum number of results to return",
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
detections: List[Dict] = SchemaField(
|
||||
description="List of detected items with confidence scores", default=[]
|
||||
)
|
||||
analysis_type: String = SchemaField(
|
||||
description="Type of analysis performed"
|
||||
)
|
||||
processing_time: Float = SchemaField(
|
||||
description="Time taken to process the image in seconds"
|
||||
)
|
||||
total_detections: Integer = SchemaField(
|
||||
description="Total number of detections found"
|
||||
)
|
||||
error: String = SchemaField(
|
||||
description="Error message if analysis failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="303d9bd3-f2a5-41ca-bb9c-e347af8ef72f",
|
||||
description="Analyze images using Custom AI Vision Service with configurable detection types",
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
input_schema=CustomAIVisionBlock.Input,
|
||||
output_schema=CustomAIVisionBlock.Output,
|
||||
test_input={
|
||||
"image_url": "https://example.com/test-image.jpg",
|
||||
"analysis_type": "objects",
|
||||
"confidence_threshold": 0.8,
|
||||
"max_results": 5,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"detections",
|
||||
[
|
||||
{"object": "car", "confidence": 0.95},
|
||||
{"object": "person", "confidence": 0.87},
|
||||
],
|
||||
),
|
||||
("analysis_type", "objects"),
|
||||
("processing_time", 1.23),
|
||||
("total_detections", 2),
|
||||
("error", ""),
|
||||
],
|
||||
static_output=False,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Get API key
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Simulate API call to vision service
|
||||
print(f" - Using API key: {api_key[:10]}...")
|
||||
print(f" - Analyzing image: {input_data.image_url}")
|
||||
print(f" - Analysis type: {input_data.analysis_type}")
|
||||
|
||||
# Mock detection results based on analysis type
|
||||
mock_results = {
|
||||
"general": [
|
||||
{"category": "indoor", "confidence": 0.92},
|
||||
{"category": "office", "confidence": 0.88},
|
||||
],
|
||||
"faces": [
|
||||
{"face_id": 1, "confidence": 0.95, "age": "25-35"},
|
||||
{"face_id": 2, "confidence": 0.91, "age": "40-50"},
|
||||
],
|
||||
"objects": [
|
||||
{"object": "laptop", "confidence": 0.94},
|
||||
{"object": "coffee_cup", "confidence": 0.89},
|
||||
{"object": "notebook", "confidence": 0.85},
|
||||
],
|
||||
"text": [
|
||||
{"text": "Hello World", "confidence": 0.97},
|
||||
{"text": "SDK Demo", "confidence": 0.93},
|
||||
],
|
||||
"scene": [
|
||||
{"scene": "office_workspace", "confidence": 0.91},
|
||||
{"scene": "indoor_lighting", "confidence": 0.87},
|
||||
],
|
||||
}
|
||||
|
||||
# Get results for the requested analysis type
|
||||
detections = mock_results.get(
|
||||
input_data.analysis_type,
|
||||
[{"error": "Unknown analysis type", "confidence": 0.0}],
|
||||
)
|
||||
|
||||
# Filter by confidence threshold
|
||||
filtered_detections = [
|
||||
d
|
||||
for d in detections
|
||||
if d.get("confidence", 0) >= input_data.confidence_threshold
|
||||
]
|
||||
|
||||
# Limit results
|
||||
final_detections = filtered_detections[: input_data.max_results]
|
||||
|
||||
# Calculate processing time
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# Yield results
|
||||
yield "detections", final_detections
|
||||
yield "analysis_type", input_data.analysis_type
|
||||
yield "processing_time", round(processing_time, 3)
|
||||
yield "total_detections", len(final_detections)
|
||||
yield "error", ""
|
||||
|
||||
except Exception as e:
|
||||
yield "detections", []
|
||||
yield "analysis_type", input_data.analysis_type
|
||||
yield "processing_time", time.time() - start_time
|
||||
yield "total_detections", 0
|
||||
yield "error", str(e)
|
||||
|
||||
print("✅ Block class created with all decorators")
|
||||
|
||||
# Step 3: Verify auto-registration worked
|
||||
print("\nStep 3: Verify auto-registration")
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
|
||||
# Check provider registration
|
||||
assert "custom-ai-vision-service" in registry.providers
|
||||
print("✅ Custom provider 'custom-ai-vision-service' auto-registered")
|
||||
|
||||
# Check cost registration
|
||||
assert CustomAIVisionBlock in registry.block_costs
|
||||
costs = registry.block_costs[CustomAIVisionBlock]
|
||||
assert len(costs) == 2
|
||||
assert costs[0].cost_amount == 10
|
||||
assert costs[0].cost_type == BlockCostType.RUN
|
||||
print("✅ Block costs auto-registered (10 credits per run, 5 per byte)")
|
||||
|
||||
# Check credential registration
|
||||
creds = registry.get_default_credentials_list()
|
||||
vision_cred = next((c for c in creds if c.id == "custom-ai-vision-default"), None)
|
||||
assert vision_cred is not None
|
||||
assert vision_cred.provider == "custom-ai-vision-service"
|
||||
print("✅ Default credentials auto-registered")
|
||||
|
||||
# Step 4: Test dynamic provider enum
|
||||
print("\nStep 4: Test dynamic provider support")
|
||||
provider_instance = ProviderName("custom-ai-vision-service")
|
||||
assert provider_instance.value == "custom-ai-vision-service"
|
||||
assert isinstance(provider_instance, ProviderName)
|
||||
print("✅ ProviderName enum accepts custom provider dynamically")
|
||||
|
||||
# Step 5: Instantiate and test the block
|
||||
print("\nStep 5: Test block instantiation and execution")
|
||||
block = CustomAIVisionBlock()
|
||||
|
||||
# Verify block properties
|
||||
assert block.id == "303d9bd3-f2a5-41ca-bb9c-e347af8ef72f"
|
||||
assert BlockCategory.AI in block.categories
|
||||
assert BlockCategory.MULTIMEDIA in block.categories
|
||||
print("✅ Block instantiated successfully")
|
||||
|
||||
# Test block execution
|
||||
test_credentials = APIKeyCredentials(
|
||||
id="test-cred",
|
||||
provider="custom-ai-vision-service",
|
||||
api_key=SecretStr("test-api-key-12345"),
|
||||
title="Test API Key",
|
||||
)
|
||||
|
||||
test_input = CustomAIVisionBlock.Input(
|
||||
credentials={
|
||||
"provider": "custom-ai-vision-service",
|
||||
"id": "test",
|
||||
"type": "api_key",
|
||||
},
|
||||
image_url="https://example.com/test.jpg",
|
||||
analysis_type="objects",
|
||||
confidence_threshold=0.8,
|
||||
max_results=3,
|
||||
)
|
||||
|
||||
print("\n Running block with test data...")
|
||||
results = list(block.run(test_input, credentials=test_credentials))
|
||||
|
||||
# Verify outputs
|
||||
output_dict = {key: value for key, value in results}
|
||||
assert "detections" in output_dict
|
||||
assert "analysis_type" in output_dict
|
||||
assert output_dict["analysis_type"] == "objects"
|
||||
assert "total_detections" in output_dict
|
||||
assert output_dict["error"] == ""
|
||||
print("✅ Block execution successful")
|
||||
|
||||
# Step 6: Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 SDK Integration Test Complete!")
|
||||
print("=" * 60)
|
||||
print("\nKey achievements demonstrated:")
|
||||
print("✅ Single import: from backend.sdk import *")
|
||||
print("✅ Custom provider registered automatically")
|
||||
print("✅ Costs configured via decorator")
|
||||
print("✅ Default credentials set via decorator")
|
||||
print("✅ Block works without ANY external configuration")
|
||||
print("✅ Dynamic provider name accepted by enum")
|
||||
print("\nThe SDK successfully enables zero-configuration block development!")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_webhook_block_workflow():
|
||||
"""Test creating a webhook block with the SDK"""
|
||||
|
||||
print("\n\n" + "=" * 60)
|
||||
print("🔔 Webhook Block Integration Test")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# SDK already imported at module level
|
||||
|
||||
# Create a simple webhook manager
|
||||
class CustomWebhookManager(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName("custom-webhook-service")
|
||||
|
||||
class WebhookType(str, Enum):
|
||||
DATA_UPDATE = "data_update"
|
||||
STATUS_CHANGE = "status_change"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook, request) -> tuple[dict, str]:
|
||||
payload = await request.json()
|
||||
event_type = request.headers.get("X-Custom-Event", "unknown")
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
# Mock registration
|
||||
return "webhook-12345", {"status": "registered"}
|
||||
|
||||
async def _deregister_webhook(self, webhook, credentials) -> None:
|
||||
pass
|
||||
|
||||
# Create webhook block
|
||||
@provider("custom-webhook-service")
|
||||
@webhook_config("custom-webhook-service", CustomWebhookManager)
|
||||
class CustomWebhookBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
webhook_events: Dict = SchemaField(
|
||||
description="Events to listen for",
|
||||
default={"data_update": True, "status_change": False},
|
||||
)
|
||||
payload: Dict = SchemaField(
|
||||
description="Webhook payload", default={}, hidden=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
event_type: String = SchemaField(description="Type of event")
|
||||
event_data: Dict = SchemaField(description="Event data")
|
||||
timestamp: String = SchemaField(description="Event timestamp")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3e730ed4-6eb2-4b89-b5ae-001860c88aef",
|
||||
description="Listen for custom webhook events",
|
||||
categories={BlockCategory.INPUT},
|
||||
input_schema=CustomWebhookBlock.Input,
|
||||
output_schema=CustomWebhookBlock.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("custom-webhook-service"),
|
||||
webhook_type="data_update",
|
||||
event_filter_input="webhook_events",
|
||||
resource_format="{resource}",
|
||||
),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
yield "event_type", payload.get("type", "unknown")
|
||||
yield "event_data", payload
|
||||
yield "timestamp", payload.get("timestamp", "")
|
||||
|
||||
# Verify registration
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
|
||||
assert "custom-webhook-service" in registry.webhook_managers
|
||||
assert registry.webhook_managers["custom-webhook-service"] == CustomWebhookManager
|
||||
print("✅ Webhook manager auto-registered")
|
||||
print("✅ Webhook block created successfully")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# Run main integration test
|
||||
success1 = test_complete_sdk_workflow()
|
||||
|
||||
# Run webhook integration test
|
||||
success2 = test_webhook_block_workflow()
|
||||
|
||||
if success1 and success2:
|
||||
print("\n\n🌟 All integration tests passed successfully!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\n\n❌ Some integration tests failed")
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\n❌ Integration test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
326
autogpt_platform/backend/test/sdk/test_sdk_patching.py
Normal file
326
autogpt_platform/backend/test/sdk/test_sdk_patching.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
Tests for the SDK's integration patching mechanism.
|
||||
|
||||
This test suite verifies that the AutoRegistry correctly patches
|
||||
existing integration points to include SDK-registered components.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
AutoRegistry,
|
||||
BaseOAuthHandler,
|
||||
BaseWebhooksManager,
|
||||
ProviderBuilder,
|
||||
)
|
||||
|
||||
|
||||
class MockOAuthHandler(BaseOAuthHandler):
|
||||
"""Mock OAuth handler for testing."""
|
||||
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
@classmethod
|
||||
async def authorize(cls, *args, **kwargs):
|
||||
return "mock_auth"
|
||||
|
||||
|
||||
class MockWebhookManager(BaseWebhooksManager):
|
||||
"""Mock webhook manager for testing."""
|
||||
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook, request):
|
||||
return {}, "test_event"
|
||||
|
||||
async def _register_webhook(self, *args, **kwargs):
|
||||
return "mock_webhook_id", {}
|
||||
|
||||
async def _deregister_webhook(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class TestOAuthPatching:
|
||||
"""Test OAuth handler patching functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry and set up mocks."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_oauth_handler_dictionary_patching(self):
|
||||
"""Test that OAuth handlers are correctly patched into HANDLERS_BY_NAME."""
|
||||
# Create original handlers
|
||||
original_handlers = {
|
||||
"existing_provider": Mock(spec=BaseOAuthHandler),
|
||||
}
|
||||
|
||||
# Create a mock oauth module
|
||||
mock_oauth_module = MagicMock()
|
||||
mock_oauth_module.HANDLERS_BY_NAME = original_handlers.copy()
|
||||
|
||||
# Register a new provider with OAuth
|
||||
(ProviderBuilder("new_oauth_provider").with_oauth(MockOAuthHandler).build())
|
||||
|
||||
# Apply patches
|
||||
with patch("backend.integrations.oauth", mock_oauth_module):
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
patched_dict = mock_oauth_module.HANDLERS_BY_NAME
|
||||
|
||||
# Test that original handler still exists
|
||||
assert "existing_provider" in patched_dict
|
||||
assert (
|
||||
patched_dict["existing_provider"]
|
||||
== original_handlers["existing_provider"]
|
||||
)
|
||||
|
||||
# Test that new handler is accessible
|
||||
assert "new_oauth_provider" in patched_dict
|
||||
assert patched_dict["new_oauth_provider"] == MockOAuthHandler
|
||||
|
||||
# Test dict methods
|
||||
assert "existing_provider" in patched_dict.keys()
|
||||
assert "new_oauth_provider" in patched_dict.keys()
|
||||
|
||||
# Test .get() method
|
||||
assert patched_dict.get("new_oauth_provider") == MockOAuthHandler
|
||||
assert patched_dict.get("nonexistent", "default") == "default"
|
||||
|
||||
# Test __contains__
|
||||
assert "new_oauth_provider" in patched_dict
|
||||
assert "nonexistent" not in patched_dict
|
||||
|
||||
def test_oauth_patching_with_multiple_providers(self):
|
||||
"""Test patching with multiple OAuth providers."""
|
||||
|
||||
# Create another OAuth handler
|
||||
class AnotherOAuthHandler(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GOOGLE
|
||||
|
||||
# Register multiple providers
|
||||
(ProviderBuilder("oauth_provider_1").with_oauth(MockOAuthHandler).build())
|
||||
|
||||
(ProviderBuilder("oauth_provider_2").with_oauth(AnotherOAuthHandler).build())
|
||||
|
||||
# Mock the oauth module
|
||||
mock_oauth_module = MagicMock()
|
||||
mock_oauth_module.HANDLERS_BY_NAME = {}
|
||||
|
||||
with patch("backend.integrations.oauth", mock_oauth_module):
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
patched_dict = mock_oauth_module.HANDLERS_BY_NAME
|
||||
|
||||
# Both providers should be accessible
|
||||
assert patched_dict["oauth_provider_1"] == MockOAuthHandler
|
||||
assert patched_dict["oauth_provider_2"] == AnotherOAuthHandler
|
||||
|
||||
# Check values() method
|
||||
values = list(patched_dict.values())
|
||||
assert MockOAuthHandler in values
|
||||
assert AnotherOAuthHandler in values
|
||||
|
||||
# Check items() method
|
||||
items = dict(patched_dict.items())
|
||||
assert items["oauth_provider_1"] == MockOAuthHandler
|
||||
assert items["oauth_provider_2"] == AnotherOAuthHandler
|
||||
|
||||
|
||||
class TestWebhookPatching:
|
||||
"""Test webhook manager patching functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_webhook_manager_patching(self):
|
||||
"""Test that webhook managers are correctly patched."""
|
||||
|
||||
# Mock the original load_webhook_managers function
|
||||
def mock_load_webhook_managers():
|
||||
return {
|
||||
"existing_webhook": Mock(spec=BaseWebhooksManager),
|
||||
}
|
||||
|
||||
# Register a provider with webhooks
|
||||
(
|
||||
ProviderBuilder("webhook_provider")
|
||||
.with_webhook_manager(MockWebhookManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Mock the webhooks module
|
||||
mock_webhooks_module = MagicMock()
|
||||
mock_webhooks_module.load_webhook_managers = mock_load_webhook_managers
|
||||
|
||||
with patch("backend.integrations.webhooks", mock_webhooks_module):
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Call the patched function
|
||||
result = mock_webhooks_module.load_webhook_managers()
|
||||
|
||||
# Original webhook should still exist
|
||||
assert "existing_webhook" in result
|
||||
|
||||
# New webhook should be added
|
||||
assert "webhook_provider" in result
|
||||
assert result["webhook_provider"] == MockWebhookManager
|
||||
|
||||
def test_webhook_patching_no_original_function(self):
|
||||
"""Test webhook patching when load_webhook_managers doesn't exist."""
|
||||
# Mock webhooks module without load_webhook_managers
|
||||
mock_webhooks_module = MagicMock(spec=[])
|
||||
|
||||
# Register a provider
|
||||
(
|
||||
ProviderBuilder("test_provider")
|
||||
.with_webhook_manager(MockWebhookManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
with patch("backend.integrations.webhooks", mock_webhooks_module):
|
||||
# Should not raise an error
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Function should not be added if it didn't exist
|
||||
assert not hasattr(mock_webhooks_module, "load_webhook_managers")
|
||||
|
||||
|
||||
class TestPatchingEdgeCases:
|
||||
"""Test edge cases and error handling in patching."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_patching_with_import_errors(self):
|
||||
"""Test that patching handles import errors gracefully."""
|
||||
# Register a provider
|
||||
(ProviderBuilder("test_provider").with_oauth(MockOAuthHandler).build())
|
||||
|
||||
# Make the oauth module import fail
|
||||
with patch("builtins.__import__", side_effect=ImportError("Mock import error")):
|
||||
# Should not raise an error
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
def test_patching_with_attribute_errors(self):
|
||||
"""Test handling of missing attributes."""
|
||||
# Mock oauth module without HANDLERS_BY_NAME
|
||||
mock_oauth_module = MagicMock(spec=[])
|
||||
|
||||
(ProviderBuilder("test_provider").with_oauth(MockOAuthHandler).build())
|
||||
|
||||
with patch("backend.integrations.oauth", mock_oauth_module):
|
||||
# Should not raise an error
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
def test_patching_preserves_thread_safety(self):
|
||||
"""Test that patching maintains thread safety."""
|
||||
import threading
|
||||
import time
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def register_provider(name, delay=0):
|
||||
try:
|
||||
time.sleep(delay)
|
||||
(ProviderBuilder(name).with_oauth(MockOAuthHandler).build())
|
||||
results.append(name)
|
||||
except Exception as e:
|
||||
errors.append((name, str(e)))
|
||||
|
||||
# Create multiple threads
|
||||
threads = []
|
||||
for i in range(5):
|
||||
t = threading.Thread(
|
||||
target=register_provider, args=(f"provider_{i}", i * 0.01)
|
||||
)
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
# Wait for all threads
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Check results
|
||||
assert len(errors) == 0
|
||||
assert len(results) == 5
|
||||
assert len(AutoRegistry._providers) == 5
|
||||
|
||||
# Verify all providers are registered
|
||||
for i in range(5):
|
||||
assert f"provider_{i}" in AutoRegistry._providers
|
||||
|
||||
|
||||
class TestPatchingIntegration:
|
||||
"""Test the complete patching integration flow."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_complete_provider_registration_and_patching(self):
|
||||
"""Test the complete flow from provider registration to patching."""
|
||||
# Mock both oauth and webhooks modules
|
||||
mock_oauth = MagicMock()
|
||||
mock_oauth.HANDLERS_BY_NAME = {"original": Mock()}
|
||||
|
||||
mock_webhooks = MagicMock()
|
||||
mock_webhooks.load_webhook_managers = lambda: {"original": Mock()}
|
||||
|
||||
# Create a fully featured provider
|
||||
(
|
||||
ProviderBuilder("complete_provider")
|
||||
.with_api_key("COMPLETE_KEY", "Complete API Key")
|
||||
.with_oauth(MockOAuthHandler, scopes=["read", "write"])
|
||||
.with_webhook_manager(MockWebhookManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Apply patches
|
||||
with patch("backend.integrations.oauth", mock_oauth):
|
||||
with patch("backend.integrations.webhooks", mock_webhooks):
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Verify OAuth patching
|
||||
oauth_dict = mock_oauth.HANDLERS_BY_NAME
|
||||
assert "complete_provider" in oauth_dict
|
||||
assert oauth_dict["complete_provider"] == MockOAuthHandler
|
||||
assert "original" in oauth_dict # Original preserved
|
||||
|
||||
# Verify webhook patching
|
||||
webhook_result = mock_webhooks.load_webhook_managers()
|
||||
assert "complete_provider" in webhook_result
|
||||
assert webhook_result["complete_provider"] == MockWebhookManager
|
||||
assert "original" in webhook_result # Original preserved
|
||||
|
||||
def test_patching_is_idempotent(self):
|
||||
"""Test that calling patch_integrations multiple times is safe."""
|
||||
mock_oauth = MagicMock()
|
||||
mock_oauth.HANDLERS_BY_NAME = {}
|
||||
|
||||
# Register a provider
|
||||
(ProviderBuilder("test_provider").with_oauth(MockOAuthHandler).build())
|
||||
|
||||
with patch("backend.integrations.oauth", mock_oauth):
|
||||
# Patch multiple times
|
||||
AutoRegistry.patch_integrations()
|
||||
AutoRegistry.patch_integrations()
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Should still work correctly
|
||||
oauth_dict = mock_oauth.HANDLERS_BY_NAME
|
||||
assert "test_provider" in oauth_dict
|
||||
assert oauth_dict["test_provider"] == MockOAuthHandler
|
||||
|
||||
# Should not have duplicates or errors
|
||||
assert len([k for k in oauth_dict.keys() if k == "test_provider"]) == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
511
autogpt_platform/backend/test/sdk/test_sdk_registry.py
Normal file
511
autogpt_platform/backend/test/sdk/test_sdk_registry.py
Normal file
@@ -0,0 +1,511 @@
|
||||
"""
|
||||
Tests for the SDK auto-registration system via AutoRegistry.
|
||||
|
||||
This test suite verifies:
|
||||
1. Provider registration and retrieval
|
||||
2. OAuth handler registration via patches
|
||||
3. Webhook manager registration via patches
|
||||
4. Credential registration and management
|
||||
5. Block configuration association
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
AutoRegistry,
|
||||
BaseOAuthHandler,
|
||||
BaseWebhooksManager,
|
||||
Block,
|
||||
BlockConfiguration,
|
||||
Provider,
|
||||
ProviderBuilder,
|
||||
)
|
||||
|
||||
|
||||
class TestAutoRegistry:
|
||||
"""Test the AutoRegistry functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry before each test."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_provider_registration(self):
|
||||
"""Test that providers can be registered and retrieved."""
|
||||
# Create a test provider
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
|
||||
# Register it
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Verify it's registered
|
||||
assert "test_provider" in AutoRegistry._providers
|
||||
assert AutoRegistry.get_provider("test_provider") == provider
|
||||
|
||||
def test_provider_with_oauth(self):
|
||||
"""Test provider registration with OAuth handler."""
|
||||
|
||||
# Create a mock OAuth handler
|
||||
class TestOAuthHandler(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
provider = Provider(
|
||||
name="oauth_provider",
|
||||
oauth_handler=TestOAuthHandler,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"oauth2"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Verify OAuth handler is registered
|
||||
assert "oauth_provider" in AutoRegistry._oauth_handlers
|
||||
assert AutoRegistry._oauth_handlers["oauth_provider"] == TestOAuthHandler
|
||||
|
||||
def test_provider_with_webhook_manager(self):
|
||||
"""Test provider registration with webhook manager."""
|
||||
|
||||
# Create a mock webhook manager
|
||||
class TestWebhookManager(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
provider = Provider(
|
||||
name="webhook_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=TestWebhookManager,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Verify webhook manager is registered
|
||||
assert "webhook_provider" in AutoRegistry._webhook_managers
|
||||
assert AutoRegistry._webhook_managers["webhook_provider"] == TestWebhookManager
|
||||
|
||||
def test_default_credentials_registration(self):
|
||||
"""Test that default credentials are registered."""
|
||||
# Create test credentials
|
||||
from backend.sdk import SecretStr
|
||||
|
||||
cred1 = APIKeyCredentials(
|
||||
id="test-cred-1",
|
||||
provider="test_provider",
|
||||
api_key=SecretStr("test-key-1"),
|
||||
title="Test Credential 1",
|
||||
)
|
||||
cred2 = APIKeyCredentials(
|
||||
id="test-cred-2",
|
||||
provider="test_provider",
|
||||
api_key=SecretStr("test-key-2"),
|
||||
title="Test Credential 2",
|
||||
)
|
||||
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[cred1, cred2],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Verify credentials are registered
|
||||
all_creds = AutoRegistry.get_all_credentials()
|
||||
assert cred1 in all_creds
|
||||
assert cred2 in all_creds
|
||||
|
||||
def test_api_key_registration(self):
|
||||
"""Test API key environment variable registration."""
|
||||
import os
|
||||
|
||||
# Set up a test environment variable
|
||||
os.environ["TEST_API_KEY"] = "test-api-key-value"
|
||||
|
||||
try:
|
||||
AutoRegistry.register_api_key("test_provider", "TEST_API_KEY")
|
||||
|
||||
# Verify the mapping is stored
|
||||
assert AutoRegistry._api_key_mappings["test_provider"] == "TEST_API_KEY"
|
||||
|
||||
# Verify a credential was created
|
||||
all_creds = AutoRegistry.get_all_credentials()
|
||||
test_cred = next(
|
||||
(c for c in all_creds if c.id == "test_provider-default"), None
|
||||
)
|
||||
assert test_cred is not None
|
||||
assert test_cred.provider == "test_provider"
|
||||
assert test_cred.api_key.get_secret_value() == "test-api-key-value" # type: ignore
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
del os.environ["TEST_API_KEY"]
|
||||
|
||||
def test_get_oauth_handlers(self):
|
||||
"""Test retrieving all OAuth handlers."""
|
||||
|
||||
# Register multiple providers with OAuth
|
||||
class TestOAuth1(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
class TestOAuth2(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GOOGLE
|
||||
|
||||
provider1 = Provider(
|
||||
name="provider1",
|
||||
oauth_handler=TestOAuth1,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"oauth2"},
|
||||
)
|
||||
|
||||
provider2 = Provider(
|
||||
name="provider2",
|
||||
oauth_handler=TestOAuth2,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"oauth2"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider1)
|
||||
AutoRegistry.register_provider(provider2)
|
||||
|
||||
handlers = AutoRegistry.get_oauth_handlers()
|
||||
assert "provider1" in handlers
|
||||
assert "provider2" in handlers
|
||||
assert handlers["provider1"] == TestOAuth1
|
||||
assert handlers["provider2"] == TestOAuth2
|
||||
|
||||
def test_block_configuration_registration(self):
|
||||
"""Test registering block configuration."""
|
||||
|
||||
# Create a test block class
|
||||
class TestBlock(Block):
|
||||
pass
|
||||
|
||||
config = BlockConfiguration(
|
||||
provider="test_provider",
|
||||
costs=[],
|
||||
default_credentials=[],
|
||||
webhook_manager=None,
|
||||
oauth_handler=None,
|
||||
)
|
||||
|
||||
AutoRegistry.register_block_configuration(TestBlock, config)
|
||||
|
||||
# Verify it's registered
|
||||
assert TestBlock in AutoRegistry._block_configurations
|
||||
assert AutoRegistry._block_configurations[TestBlock] == config
|
||||
|
||||
def test_clear_registry(self):
|
||||
"""Test clearing all registrations."""
|
||||
# Add some registrations
|
||||
provider = Provider(
|
||||
name="test_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
AutoRegistry.register_provider(provider)
|
||||
AutoRegistry.register_api_key("test", "TEST_KEY")
|
||||
|
||||
# Clear everything
|
||||
AutoRegistry.clear()
|
||||
|
||||
# Verify everything is cleared
|
||||
assert len(AutoRegistry._providers) == 0
|
||||
assert len(AutoRegistry._default_credentials) == 0
|
||||
assert len(AutoRegistry._oauth_handlers) == 0
|
||||
assert len(AutoRegistry._webhook_managers) == 0
|
||||
assert len(AutoRegistry._block_configurations) == 0
|
||||
assert len(AutoRegistry._api_key_mappings) == 0
|
||||
|
||||
|
||||
class TestAutoRegistryPatching:
|
||||
"""Test the integration patching functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry before each test."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
@patch("backend.integrations.oauth.HANDLERS_BY_NAME", {})
|
||||
def test_oauth_handler_patching(self):
|
||||
"""Test that OAuth handlers are patched into the system."""
|
||||
|
||||
# Create a test OAuth handler
|
||||
class TestOAuthHandler(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
# Register a provider with OAuth
|
||||
provider = Provider(
|
||||
name="patched_provider",
|
||||
oauth_handler=TestOAuthHandler,
|
||||
webhook_manager=None,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"oauth2"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Mock the oauth module
|
||||
mock_oauth = MagicMock()
|
||||
mock_oauth.HANDLERS_BY_NAME = {}
|
||||
|
||||
with patch("backend.integrations.oauth", mock_oauth):
|
||||
# Apply patches
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Verify the patched dict works
|
||||
patched_dict = mock_oauth.HANDLERS_BY_NAME
|
||||
assert "patched_provider" in patched_dict
|
||||
assert patched_dict["patched_provider"] == TestOAuthHandler
|
||||
|
||||
@patch("backend.integrations.webhooks.load_webhook_managers")
|
||||
def test_webhook_manager_patching(self, mock_load_managers):
|
||||
"""Test that webhook managers are patched into the system."""
|
||||
# Set up the mock to return an empty dict
|
||||
mock_load_managers.return_value = {}
|
||||
|
||||
# Create a test webhook manager
|
||||
class TestWebhookManager(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
# Register a provider with webhooks
|
||||
provider = Provider(
|
||||
name="webhook_provider",
|
||||
oauth_handler=None,
|
||||
webhook_manager=TestWebhookManager,
|
||||
default_credentials=[],
|
||||
base_costs=[],
|
||||
supported_auth_types={"api_key"},
|
||||
)
|
||||
|
||||
AutoRegistry.register_provider(provider)
|
||||
|
||||
# Mock the webhooks module
|
||||
mock_webhooks = MagicMock()
|
||||
mock_webhooks.load_webhook_managers = mock_load_managers
|
||||
|
||||
with patch("backend.integrations.webhooks", mock_webhooks):
|
||||
# Apply patches
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Call the patched function
|
||||
result = mock_webhooks.load_webhook_managers()
|
||||
|
||||
# Verify our webhook manager is included
|
||||
assert "webhook_provider" in result
|
||||
assert result["webhook_provider"] == TestWebhookManager
|
||||
|
||||
|
||||
class TestProviderBuilder:
|
||||
"""Test the ProviderBuilder fluent API."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry before each test."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_basic_provider_builder(self):
|
||||
"""Test building a basic provider."""
|
||||
provider = (
|
||||
ProviderBuilder("test_provider")
|
||||
.with_api_key("TEST_API_KEY", "Test API Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
assert provider.name == "test_provider"
|
||||
assert "api_key" in provider.supported_auth_types
|
||||
assert AutoRegistry.get_provider("test_provider") == provider
|
||||
|
||||
def test_provider_builder_with_oauth(self):
|
||||
"""Test building a provider with OAuth."""
|
||||
|
||||
class TestOAuth(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("oauth_test")
|
||||
.with_oauth(TestOAuth, scopes=["read", "write"])
|
||||
.build()
|
||||
)
|
||||
|
||||
assert provider.oauth_handler == TestOAuth
|
||||
assert "oauth2" in provider.supported_auth_types
|
||||
|
||||
def test_provider_builder_with_webhook(self):
|
||||
"""Test building a provider with webhook manager."""
|
||||
|
||||
class TestWebhook(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("webhook_test").with_webhook_manager(TestWebhook).build()
|
||||
)
|
||||
|
||||
assert provider.webhook_manager == TestWebhook
|
||||
|
||||
def test_provider_builder_with_base_cost(self):
|
||||
"""Test building a provider with base costs."""
|
||||
from backend.data.cost import BlockCostType
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("cost_test")
|
||||
.with_base_cost(10, BlockCostType.RUN)
|
||||
.with_base_cost(5, BlockCostType.BYTE)
|
||||
.build()
|
||||
)
|
||||
|
||||
assert len(provider.base_costs) == 2
|
||||
assert provider.base_costs[0].cost_amount == 10
|
||||
assert provider.base_costs[0].cost_type == BlockCostType.RUN
|
||||
assert provider.base_costs[1].cost_amount == 5
|
||||
assert provider.base_costs[1].cost_type == BlockCostType.BYTE
|
||||
|
||||
def test_provider_builder_with_api_client(self):
|
||||
"""Test building a provider with API client factory."""
|
||||
|
||||
def mock_client_factory():
|
||||
return Mock()
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("client_test").with_api_client(mock_client_factory).build()
|
||||
)
|
||||
|
||||
assert provider._api_client_factory == mock_client_factory
|
||||
|
||||
def test_provider_builder_with_error_handler(self):
|
||||
"""Test building a provider with error handler."""
|
||||
|
||||
def mock_error_handler(exc: Exception) -> str:
|
||||
return f"Error: {str(exc)}"
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("error_test").with_error_handler(mock_error_handler).build()
|
||||
)
|
||||
|
||||
assert provider._error_handler == mock_error_handler
|
||||
|
||||
def test_provider_builder_complete_example(self):
|
||||
"""Test building a complete provider with all features."""
|
||||
from backend.data.cost import BlockCostType
|
||||
|
||||
class TestOAuth(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
class TestWebhook(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
def client_factory():
|
||||
return Mock()
|
||||
|
||||
def error_handler(exc):
|
||||
return str(exc)
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("complete_test")
|
||||
.with_api_key("COMPLETE_API_KEY", "Complete API Key")
|
||||
.with_oauth(TestOAuth, scopes=["read"])
|
||||
.with_webhook_manager(TestWebhook)
|
||||
.with_base_cost(100, BlockCostType.RUN)
|
||||
.with_api_client(client_factory)
|
||||
.with_error_handler(error_handler)
|
||||
.with_config(custom_setting="value")
|
||||
.build()
|
||||
)
|
||||
|
||||
# Verify all settings
|
||||
assert provider.name == "complete_test"
|
||||
assert "api_key" in provider.supported_auth_types
|
||||
assert "oauth2" in provider.supported_auth_types
|
||||
assert provider.oauth_handler == TestOAuth
|
||||
assert provider.webhook_manager == TestWebhook
|
||||
assert len(provider.base_costs) == 1
|
||||
assert provider._api_client_factory == client_factory
|
||||
assert provider._error_handler == error_handler
|
||||
assert provider.get_config("custom_setting") == "value" # from with_config
|
||||
|
||||
# Verify it's registered
|
||||
assert AutoRegistry.get_provider("complete_test") == provider
|
||||
assert "complete_test" in AutoRegistry._oauth_handlers
|
||||
assert "complete_test" in AutoRegistry._webhook_managers
|
||||
|
||||
|
||||
class TestSDKImports:
|
||||
"""Test that all expected exports are available from the SDK."""
|
||||
|
||||
def test_core_block_imports(self):
|
||||
"""Test core block system imports."""
|
||||
from backend.sdk import Block, BlockCategory
|
||||
|
||||
# Just verify they're importable
|
||||
assert Block is not None
|
||||
assert BlockCategory is not None
|
||||
|
||||
def test_schema_imports(self):
|
||||
"""Test schema and model imports."""
|
||||
from backend.sdk import APIKeyCredentials, SchemaField
|
||||
|
||||
assert SchemaField is not None
|
||||
assert APIKeyCredentials is not None
|
||||
|
||||
def test_type_alias_imports(self):
|
||||
"""Test type alias imports."""
|
||||
from backend.sdk import Boolean, Float, Integer, String
|
||||
|
||||
# Verify they're the correct types
|
||||
assert String is str
|
||||
assert Integer is int
|
||||
assert Float is float
|
||||
assert Boolean is bool
|
||||
|
||||
def test_cost_system_imports(self):
|
||||
"""Test cost system imports."""
|
||||
from backend.sdk import BlockCost, BlockCostType
|
||||
|
||||
assert BlockCost is not None
|
||||
assert BlockCostType is not None
|
||||
|
||||
def test_utility_imports(self):
|
||||
"""Test utility imports."""
|
||||
from backend.sdk import BaseModel, json, requests
|
||||
|
||||
assert json is not None
|
||||
assert BaseModel is not None
|
||||
assert requests is not None
|
||||
|
||||
def test_integration_imports(self):
|
||||
"""Test integration imports."""
|
||||
from backend.sdk import ProviderName
|
||||
|
||||
assert ProviderName is not None
|
||||
|
||||
def test_sdk_component_imports(self):
|
||||
"""Test SDK-specific component imports."""
|
||||
from backend.sdk import AutoRegistry, ProviderBuilder
|
||||
|
||||
assert AutoRegistry is not None
|
||||
assert ProviderBuilder is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -1,197 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone SDK tests that can run without Redis/PostgreSQL/RabbitMQ.
|
||||
Run with: python test/sdk/test_sdk_standalone.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add backend to path
|
||||
backend_path = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(backend_path))
|
||||
|
||||
|
||||
def test_sdk_imports():
|
||||
"""Test that SDK imports work correctly"""
|
||||
print("\n=== Testing SDK Imports ===")
|
||||
|
||||
# Import SDK
|
||||
import backend.sdk as sdk
|
||||
|
||||
# Verify imports
|
||||
assert hasattr(sdk, "Block")
|
||||
assert hasattr(sdk, "BlockSchema")
|
||||
assert hasattr(sdk, "SchemaField")
|
||||
assert hasattr(sdk, "provider")
|
||||
assert hasattr(sdk, "cost_config")
|
||||
print("✅ SDK imports work correctly")
|
||||
|
||||
|
||||
def test_dynamic_provider():
|
||||
"""Test dynamic provider enum"""
|
||||
print("\n=== Testing Dynamic Provider ===")
|
||||
|
||||
from backend.sdk import ProviderName
|
||||
|
||||
# Test existing provider
|
||||
github = ProviderName.GITHUB
|
||||
assert github.value == "github"
|
||||
|
||||
# Test dynamic provider
|
||||
custom = ProviderName("my-custom-provider")
|
||||
assert custom.value == "my-custom-provider"
|
||||
assert isinstance(custom, ProviderName)
|
||||
print("✅ Dynamic provider enum works")
|
||||
|
||||
|
||||
def test_auto_registry():
|
||||
"""Test auto-registration system"""
|
||||
print("\n=== Testing Auto-Registry ===")
|
||||
|
||||
from backend.sdk import BlockCost, BlockCostType, cost_config, provider
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
initial_count = len(registry.providers)
|
||||
|
||||
# Register a test provider
|
||||
@provider("test-provider-xyz")
|
||||
class TestBlock:
|
||||
pass
|
||||
|
||||
assert "test-provider-xyz" in registry.providers
|
||||
assert len(registry.providers) == initial_count + 1
|
||||
|
||||
# Register costs
|
||||
@cost_config(BlockCost(cost_amount=5, cost_type=BlockCostType.RUN))
|
||||
class TestBlock2:
|
||||
pass
|
||||
|
||||
assert TestBlock2 in registry.block_costs
|
||||
assert registry.block_costs[TestBlock2][0].cost_amount == 5
|
||||
|
||||
print("✅ Auto-registry works correctly")
|
||||
|
||||
|
||||
def test_complete_block_creation():
|
||||
"""Test creating a complete block with SDK"""
|
||||
print("\n=== Testing Complete Block Creation ===")
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Integer,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
cost_config,
|
||||
default_credentials,
|
||||
provider,
|
||||
)
|
||||
|
||||
@provider("test-ai-service")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=2, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="test-ai-default",
|
||||
provider="test-ai-service",
|
||||
api_key=SecretStr("test-default-key"),
|
||||
title="Test AI Service Default Key",
|
||||
)
|
||||
)
|
||||
class TestAIBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="test-ai-service",
|
||||
supported_credential_types={"api_key"},
|
||||
description="API credentials",
|
||||
)
|
||||
prompt: String = SchemaField(description="AI prompt")
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: String = SchemaField(description="AI response")
|
||||
tokens: Integer = SchemaField(description="Token count")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||
description="Test AI Service Block",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=TestAIBlock.Input,
|
||||
output_schema=TestAIBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "response", f"AI says: {input_data.prompt}"
|
||||
yield "tokens", len(input_data.prompt.split())
|
||||
|
||||
# Verify block creation
|
||||
block = TestAIBlock()
|
||||
assert block.id == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
assert BlockCategory.AI in block.categories
|
||||
|
||||
# Verify auto-registration
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
|
||||
assert "test-ai-service" in registry.providers
|
||||
assert TestAIBlock in registry.block_costs
|
||||
assert len(registry.block_costs[TestAIBlock]) == 2
|
||||
|
||||
print("✅ Complete block creation works")
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all standalone tests"""
|
||||
print("\n" + "=" * 60)
|
||||
print("🧪 Running Standalone SDK Tests")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
test_sdk_imports,
|
||||
test_dynamic_provider,
|
||||
test_auto_registry,
|
||||
test_complete_block_creation,
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for test in tests:
|
||||
try:
|
||||
test()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"❌ {test.__name__} failed: {e}")
|
||||
failed += 1
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"📊 Results: {passed} passed, {failed} failed")
|
||||
print("=" * 60)
|
||||
|
||||
if failed == 0:
|
||||
print("\n🎉 All standalone SDK tests passed!")
|
||||
return True
|
||||
else:
|
||||
print(f"\n⚠️ {failed} tests failed")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
505
autogpt_platform/backend/test/sdk/test_sdk_webhooks.py
Normal file
505
autogpt_platform/backend/test/sdk/test_sdk_webhooks.py
Normal file
@@ -0,0 +1,505 @@
|
||||
"""
|
||||
Tests for SDK webhook functionality.
|
||||
|
||||
This test suite verifies webhook blocks and webhook manager integration.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
AutoRegistry,
|
||||
BaseModel,
|
||||
BaseWebhooksManager,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockWebhookConfig,
|
||||
Boolean,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
Field,
|
||||
Integer,
|
||||
ProviderBuilder,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
String,
|
||||
)
|
||||
|
||||
|
||||
class TestWebhookTypes(str, Enum):
|
||||
"""Test webhook event types."""
|
||||
|
||||
CREATED = "created"
|
||||
UPDATED = "updated"
|
||||
DELETED = "deleted"
|
||||
|
||||
|
||||
class TestWebhooksManager(BaseWebhooksManager):
|
||||
"""Test webhook manager implementation."""
|
||||
|
||||
PROVIDER_NAME = ProviderName.GITHUB # Reuse for testing
|
||||
|
||||
class WebhookType(str, Enum):
|
||||
TEST = "test"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook, request):
|
||||
"""Validate incoming webhook payload."""
|
||||
# Mock implementation
|
||||
payload = {"test": "data"}
|
||||
event_type = "test_event"
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
"""Register webhook with external service."""
|
||||
# Mock implementation
|
||||
webhook_id = f"test_webhook_{resource}"
|
||||
config = {
|
||||
"webhook_type": webhook_type,
|
||||
"resource": resource,
|
||||
"events": events,
|
||||
"url": ingress_url,
|
||||
}
|
||||
return webhook_id, config
|
||||
|
||||
async def _deregister_webhook(self, webhook, credentials) -> None:
|
||||
"""Deregister webhook from external service."""
|
||||
# Mock implementation
|
||||
pass
|
||||
|
||||
|
||||
class TestWebhookBlock(Block):
|
||||
"""Test webhook block implementation."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="test_webhooks",
|
||||
supported_credential_types={"api_key"},
|
||||
description="Webhook service credentials",
|
||||
)
|
||||
webhook_url: String = SchemaField(
|
||||
description="URL to receive webhooks",
|
||||
)
|
||||
resource_id: String = SchemaField(
|
||||
description="Resource to monitor",
|
||||
)
|
||||
events: list[TestWebhookTypes] = SchemaField(
|
||||
description="Events to listen for",
|
||||
default=[TestWebhookTypes.CREATED],
|
||||
)
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload",
|
||||
default={},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webhook_id: String = SchemaField(description="Registered webhook ID")
|
||||
is_active: Boolean = SchemaField(description="Webhook is active")
|
||||
event_count: Integer = SchemaField(description="Number of events configured")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="test-webhook-block",
|
||||
description="Test webhook block",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=TestWebhookBlock.Input,
|
||||
output_schema=TestWebhookBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider="test_webhooks", # type: ignore
|
||||
webhook_type="test",
|
||||
resource_format="{resource_id}",
|
||||
),
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Simulate webhook registration
|
||||
webhook_id = f"webhook_{input_data.resource_id}"
|
||||
|
||||
yield "webhook_id", webhook_id
|
||||
yield "is_active", True
|
||||
yield "event_count", len(input_data.events)
|
||||
|
||||
|
||||
class TestWebhookBlockCreation:
|
||||
"""Test creating webhook blocks with the SDK."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
# Register a provider with webhook support
|
||||
self.provider = (
|
||||
ProviderBuilder("test_webhooks")
|
||||
.with_api_key("TEST_WEBHOOK_KEY", "Test Webhook API Key")
|
||||
.with_webhook_manager(TestWebhooksManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
def test_basic_webhook_block(self):
|
||||
"""Test creating a basic webhook block."""
|
||||
block = TestWebhookBlock()
|
||||
|
||||
# Verify block configuration
|
||||
assert block.webhook_config is not None
|
||||
assert block.webhook_config.provider == "test_webhooks"
|
||||
assert block.webhook_config.webhook_type == "test"
|
||||
assert "{resource_id}" in block.webhook_config.resource_format # type: ignore
|
||||
|
||||
# Test block execution
|
||||
test_creds = APIKeyCredentials(
|
||||
id="test-webhook-creds",
|
||||
provider="test_webhooks",
|
||||
api_key=SecretStr("test-key"),
|
||||
title="Test Webhook Key",
|
||||
)
|
||||
|
||||
outputs = dict(
|
||||
block.run(
|
||||
TestWebhookBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_webhooks",
|
||||
"id": "test-webhook-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
webhook_url="https://example.com/webhook",
|
||||
resource_id="resource_123",
|
||||
events=[TestWebhookTypes.CREATED, TestWebhookTypes.UPDATED],
|
||||
),
|
||||
credentials=test_creds,
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["webhook_id"] == "webhook_resource_123"
|
||||
assert outputs["is_active"] is True
|
||||
assert outputs["event_count"] == 2
|
||||
|
||||
def test_webhook_block_with_filters(self):
|
||||
"""Test webhook block with event filters."""
|
||||
|
||||
class EventFilterModel(BaseModel):
|
||||
include_system: bool = Field(default=False)
|
||||
severity_levels: list[str] = Field(
|
||||
default_factory=lambda: ["info", "warning"]
|
||||
)
|
||||
|
||||
class FilteredWebhookBlock(Block):
|
||||
"""Webhook block with filtering."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="test_webhooks",
|
||||
supported_credential_types={"api_key"},
|
||||
)
|
||||
resource: String = SchemaField(description="Resource to monitor")
|
||||
filters: EventFilterModel = SchemaField(
|
||||
description="Event filters",
|
||||
default_factory=EventFilterModel,
|
||||
)
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload",
|
||||
default={},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webhook_active: Boolean = SchemaField(description="Webhook active")
|
||||
filter_summary: String = SchemaField(description="Active filters")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="filtered-webhook-block",
|
||||
description="Webhook with filters",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=FilteredWebhookBlock.Input,
|
||||
output_schema=FilteredWebhookBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider="test_webhooks", # type: ignore
|
||||
webhook_type="filtered",
|
||||
resource_format="{resource}",
|
||||
),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
filters = input_data.filters
|
||||
filter_parts = []
|
||||
|
||||
if filters.include_system:
|
||||
filter_parts.append("system events")
|
||||
|
||||
filter_parts.append(f"{len(filters.severity_levels)} severity levels")
|
||||
|
||||
yield "webhook_active", True
|
||||
yield "filter_summary", ", ".join(filter_parts)
|
||||
|
||||
# Test the block
|
||||
block = FilteredWebhookBlock()
|
||||
|
||||
test_creds = APIKeyCredentials(
|
||||
id="test-creds",
|
||||
provider="test_webhooks",
|
||||
api_key=SecretStr("key"),
|
||||
title="Test Key",
|
||||
)
|
||||
|
||||
# Test with default filters
|
||||
outputs = dict(
|
||||
block.run(
|
||||
FilteredWebhookBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_webhooks",
|
||||
"id": "test-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
resource="test_resource",
|
||||
),
|
||||
credentials=test_creds,
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["webhook_active"] is True
|
||||
assert "2 severity levels" in outputs["filter_summary"]
|
||||
|
||||
# Test with custom filters
|
||||
custom_filters = EventFilterModel(
|
||||
include_system=True,
|
||||
severity_levels=["error", "critical"],
|
||||
)
|
||||
|
||||
outputs = dict(
|
||||
block.run(
|
||||
FilteredWebhookBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_webhooks",
|
||||
"id": "test-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
resource="test_resource",
|
||||
filters=custom_filters,
|
||||
),
|
||||
credentials=test_creds,
|
||||
)
|
||||
)
|
||||
|
||||
assert "system events" in outputs["filter_summary"]
|
||||
assert "2 severity levels" in outputs["filter_summary"]
|
||||
|
||||
|
||||
class TestWebhookManagerIntegration:
|
||||
"""Test webhook manager integration with AutoRegistry."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear registry."""
|
||||
AutoRegistry.clear()
|
||||
|
||||
def test_webhook_manager_registration(self):
|
||||
"""Test that webhook managers are properly registered."""
|
||||
|
||||
# Create multiple webhook managers
|
||||
class WebhookManager1(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
class WebhookManager2(BaseWebhooksManager):
|
||||
PROVIDER_NAME = ProviderName.GOOGLE
|
||||
|
||||
# Register providers with webhook managers
|
||||
(
|
||||
ProviderBuilder("webhook_service_1")
|
||||
.with_webhook_manager(WebhookManager1)
|
||||
.build()
|
||||
)
|
||||
|
||||
(
|
||||
ProviderBuilder("webhook_service_2")
|
||||
.with_webhook_manager(WebhookManager2)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Verify registration
|
||||
managers = AutoRegistry.get_webhook_managers()
|
||||
assert "webhook_service_1" in managers
|
||||
assert "webhook_service_2" in managers
|
||||
assert managers["webhook_service_1"] == WebhookManager1
|
||||
assert managers["webhook_service_2"] == WebhookManager2
|
||||
|
||||
def test_webhook_block_with_provider_manager(self):
|
||||
"""Test webhook block using a provider's webhook manager."""
|
||||
# Register provider with webhook manager
|
||||
(
|
||||
ProviderBuilder("integrated_webhooks")
|
||||
.with_api_key("INTEGRATED_KEY", "Integrated Webhook Key")
|
||||
.with_webhook_manager(TestWebhooksManager)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Create a block that uses this provider
|
||||
class IntegratedWebhookBlock(Block):
|
||||
"""Block using integrated webhook manager."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="integrated_webhooks",
|
||||
supported_credential_types={"api_key"},
|
||||
)
|
||||
target: String = SchemaField(description="Webhook target")
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload",
|
||||
default={},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: String = SchemaField(description="Webhook status")
|
||||
manager_type: String = SchemaField(description="Manager type used")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="integrated-webhook-block",
|
||||
description="Uses integrated webhook manager",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=IntegratedWebhookBlock.Input,
|
||||
output_schema=IntegratedWebhookBlock.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider="integrated_webhooks", # type: ignore
|
||||
webhook_type=TestWebhooksManager.WebhookType.TEST,
|
||||
resource_format="{target}",
|
||||
),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Get the webhook manager for this provider
|
||||
managers = AutoRegistry.get_webhook_managers()
|
||||
manager_class = managers.get("integrated_webhooks")
|
||||
|
||||
yield "status", "configured"
|
||||
yield "manager_type", (
|
||||
manager_class.__name__ if manager_class else "none"
|
||||
)
|
||||
|
||||
# Test the block
|
||||
block = IntegratedWebhookBlock()
|
||||
|
||||
test_creds = APIKeyCredentials(
|
||||
id="integrated-creds",
|
||||
provider="integrated_webhooks",
|
||||
api_key=SecretStr("key"),
|
||||
title="Integrated Key",
|
||||
)
|
||||
|
||||
outputs = dict(
|
||||
block.run(
|
||||
IntegratedWebhookBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "integrated_webhooks",
|
||||
"id": "integrated-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
target="test_target",
|
||||
),
|
||||
credentials=test_creds,
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["status"] == "configured"
|
||||
assert outputs["manager_type"] == "TestWebhooksManager"
|
||||
|
||||
|
||||
class TestWebhookEventHandling:
|
||||
"""Test webhook event handling in blocks."""
|
||||
|
||||
def test_webhook_event_processing_block(self):
|
||||
"""Test a block that processes webhook events."""
|
||||
|
||||
class WebhookEventBlock(Block):
|
||||
"""Block that processes webhook events."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
event_type: String = SchemaField(description="Type of webhook event")
|
||||
payload: dict = SchemaField(description="Webhook payload")
|
||||
verify_signature: Boolean = SchemaField(
|
||||
description="Whether to verify webhook signature",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
processed: Boolean = SchemaField(description="Event was processed")
|
||||
event_summary: String = SchemaField(description="Summary of event")
|
||||
action_required: Boolean = SchemaField(description="Action required")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="webhook-event-processor",
|
||||
description="Processes incoming webhook events",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=WebhookEventBlock.Input,
|
||||
output_schema=WebhookEventBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Process based on event type
|
||||
event_type = input_data.event_type
|
||||
payload = input_data.payload
|
||||
|
||||
if event_type == "created":
|
||||
summary = f"New item created: {payload.get('id', 'unknown')}"
|
||||
action_required = True
|
||||
elif event_type == "updated":
|
||||
summary = f"Item updated: {payload.get('id', 'unknown')}"
|
||||
action_required = False
|
||||
elif event_type == "deleted":
|
||||
summary = f"Item deleted: {payload.get('id', 'unknown')}"
|
||||
action_required = True
|
||||
else:
|
||||
summary = f"Unknown event: {event_type}"
|
||||
action_required = False
|
||||
|
||||
yield "processed", True
|
||||
yield "event_summary", summary
|
||||
yield "action_required", action_required
|
||||
|
||||
# Test the block with different events
|
||||
block = WebhookEventBlock()
|
||||
|
||||
# Test created event
|
||||
outputs = dict(
|
||||
block.run(
|
||||
WebhookEventBlock.Input(
|
||||
event_type="created",
|
||||
payload={"id": "123", "name": "Test Item"},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["processed"] is True
|
||||
assert "New item created: 123" in outputs["event_summary"]
|
||||
assert outputs["action_required"] is True
|
||||
|
||||
# Test updated event
|
||||
outputs = dict(
|
||||
block.run(
|
||||
WebhookEventBlock.Input(
|
||||
event_type="updated",
|
||||
payload={"id": "456", "changes": ["name", "status"]},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assert outputs["processed"] is True
|
||||
assert "Item updated: 456" in outputs["event_summary"]
|
||||
assert outputs["action_required"] is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user