update sdk

This commit is contained in:
SwiftyOS
2025-06-13 19:37:31 +02:00
parent 864f76f904
commit a09ecab7f1
28 changed files with 2356 additions and 2614 deletions

View File

@@ -0,0 +1,13 @@
"""
Shared configuration for all Exa blocks using the new SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
# Configure the Exa provider once for all blocks
exa = (
ProviderBuilder("exa")
.with_api_key("EXA_API_KEY", "Exa API Key")
.with_base_cost(1, BlockCostType.RUN)
.build()
)

View File

@@ -5,38 +5,21 @@ from backend.sdk import (
BlockOutput,
BlockSchema,
Boolean,
CredentialsField,
CredentialsMetaInput,
Dict,
List,
SchemaField,
SecretStr,
Settings,
String,
default_credentials,
provider,
requests,
)
settings = Settings()
from ._config import exa
@provider("exa")
@default_credentials(
APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr(settings.secrets.exa_api_key),
title="Use Credits for Exa",
expires_at=None,
)
)
class ExaAnswerBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
query: String = SchemaField(
description="The question or query to answer",
@@ -61,19 +44,17 @@ class ExaAnswerBlock(Block):
class Output(BlockSchema):
answer: String = SchemaField(
description="The generated answer based on search results",
description="The generated answer based on search results"
)
citations: List[Dict] = SchemaField(
description="Search results used to generate the answer",
default_factory=list,
)
cost_dollars: Dict = SchemaField(
description="Cost breakdown of the request",
default_factory=dict,
description="Cost breakdown of the request", default_factory=dict
)
error: String = SchemaField(
description="Error message if the request failed",
default="",
description="Error message if the request failed", default=""
)
def __init__(self):
@@ -83,17 +64,6 @@ class ExaAnswerBlock(Block):
categories={BlockCategory.SEARCH, BlockCategory.AI},
input_schema=ExaAnswerBlock.Input,
output_schema=ExaAnswerBlock.Output,
test_input={
"query": "What is the capital of France?",
"text": False,
"stream": False,
"model": "exa",
},
test_output=[
("answer", "Paris"),
("citations", []),
("cost_dollars", {}),
],
)
def run(

View File

@@ -4,28 +4,24 @@ from backend.sdk import (
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsField,
CredentialsMetaInput,
List,
SchemaField,
String,
provider,
requests,
)
from ._config import exa
from .helpers import ContentSettings
@provider("exa")
class ExaContentsBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
ids: List[String] = SchemaField(
description="Array of document IDs obtained from searches",
description="Array of document IDs obtained from searches"
)
contents: ContentSettings = SchemaField(
description="Content retrieval settings",
@@ -35,8 +31,7 @@ class ExaContentsBlock(Block):
class Output(BlockSchema):
results: List = SchemaField(
description="List of document contents",
default_factory=list,
description="List of document contents", default_factory=list
)
error: String = SchemaField(
description="Error message if the request failed", default=""

View File

@@ -7,51 +7,38 @@ from backend.sdk import (
BlockOutput,
BlockSchema,
Boolean,
CredentialsField,
CredentialsMetaInput,
Integer,
List,
SchemaField,
String,
provider,
requests,
)
from ._config import exa
from .helpers import ContentSettings
@provider("exa")
class ExaSearchBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
query: String = SchemaField(description="The search query")
use_auto_prompt: Boolean = SchemaField(
description="Whether to use autoprompt",
default=True,
advanced=True,
description="Whether to use autoprompt", default=True, advanced=True
)
type: String = SchemaField(
description="Type of search",
default="",
advanced=True,
description="Type of search", default="", advanced=True
)
category: String = SchemaField(
description="Category to search within",
default="",
advanced=True,
description="Category to search within", default="", advanced=True
)
number_of_results: Integer = SchemaField(
description="Number of results to return",
default=10,
advanced=True,
description="Number of results to return", default=10, advanced=True
)
include_domains: List[String] = SchemaField(
description="Domains to include in search",
default_factory=list,
description="Domains to include in search", default_factory=list
)
exclude_domains: List[String] = SchemaField(
description="Domains to exclude from search",
@@ -59,26 +46,22 @@ class ExaSearchBlock(Block):
advanced=True,
)
start_crawl_date: datetime = SchemaField(
description="Start date for crawled content",
description="Start date for crawled content"
)
end_crawl_date: datetime = SchemaField(
description="End date for crawled content",
description="End date for crawled content"
)
start_published_date: datetime = SchemaField(
description="Start date for published content",
description="Start date for published content"
)
end_published_date: datetime = SchemaField(
description="End date for published content",
description="End date for published content"
)
include_text: List[String] = SchemaField(
description="Text patterns to include",
default_factory=list,
advanced=True,
description="Text patterns to include", default_factory=list, advanced=True
)
exclude_text: List[String] = SchemaField(
description="Text patterns to exclude",
default_factory=list,
advanced=True,
description="Text patterns to exclude", default_factory=list, advanced=True
)
contents: ContentSettings = SchemaField(
description="Content retrieval settings",
@@ -88,12 +71,10 @@ class ExaSearchBlock(Block):
class Output(BlockSchema):
results: List = SchemaField(
description="List of search results",
default_factory=list,
description="List of search results", default_factory=list
)
error: String = SchemaField(
description="Error message if the request failed",
default="",
description="Error message if the request failed", default=""
)
def __init__(self):

View File

@@ -7,34 +7,28 @@ from backend.sdk import (
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsField,
CredentialsMetaInput,
Integer,
List,
SchemaField,
String,
provider,
requests,
)
from ._config import exa
from .helpers import ContentSettings
@provider("exa")
class ExaFindSimilarBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
url: String = SchemaField(
description="The url for which you would like to find similar links"
)
number_of_results: Integer = SchemaField(
description="Number of results to return",
default=10,
advanced=True,
description="Number of results to return", default=10, advanced=True
)
include_domains: List[String] = SchemaField(
description="Domains to include in search",
@@ -47,16 +41,16 @@ class ExaFindSimilarBlock(Block):
advanced=True,
)
start_crawl_date: datetime = SchemaField(
description="Start date for crawled content",
description="Start date for crawled content"
)
end_crawl_date: datetime = SchemaField(
description="End date for crawled content",
description="End date for crawled content"
)
start_published_date: datetime = SchemaField(
description="Start date for published content",
description="Start date for published content"
)
end_published_date: datetime = SchemaField(
description="End date for published content",
description="End date for published content"
)
include_text: List[String] = SchemaField(
description="Text patterns to include (max 1 string, up to 5 words)",

View File

@@ -6,7 +6,6 @@ from backend.sdk import (
BlockOutput,
BlockSchema,
Boolean,
CredentialsField,
CredentialsMetaInput,
Dict,
Integer,
@@ -14,23 +13,20 @@ from backend.sdk import (
Optional,
SchemaField,
String,
provider,
requests,
)
from ._config import exa
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
@provider("exa")
class ExaCreateWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
search: WebsetSearchConfig = SchemaField(
description="Initial search configuration for the Webset",
description="Initial search configuration for the Webset"
)
enrichments: Optional[List[WebsetEnrichmentConfig]] = SchemaField(
default=None,
@@ -51,21 +47,17 @@ class ExaCreateWebsetBlock(Block):
class Output(BlockSchema):
webset_id: String = SchemaField(
description="The unique identifier for the created webset",
)
status: String = SchemaField(
description="The status of the webset",
description="The unique identifier for the created webset"
)
status: String = SchemaField(description="The status of the webset")
external_id: Optional[String] = SchemaField(
description="The external identifier for the webset",
default=None,
description="The external identifier for the webset", default=None
)
created_at: String = SchemaField(
description="The date and time the webset was created",
description="The date and time the webset was created"
)
error: String = SchemaField(
description="Error message if the request failed",
default="",
description="Error message if the request failed", default=""
)
def __init__(self):
@@ -121,13 +113,10 @@ class ExaCreateWebsetBlock(Block):
yield "created_at", ""
@provider("exa")
class ExaUpdateWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: String = SchemaField(
description="The ID or external ID of the Webset to update",
@@ -140,25 +129,20 @@ class ExaUpdateWebsetBlock(Block):
class Output(BlockSchema):
webset_id: String = SchemaField(
description="The unique identifier for the webset",
)
status: String = SchemaField(
description="The status of the webset",
description="The unique identifier for the webset"
)
status: String = SchemaField(description="The status of the webset")
external_id: Optional[String] = SchemaField(
description="The external identifier for the webset",
default=None,
description="The external identifier for the webset", default=None
)
metadata: Dict = SchemaField(
description="Updated metadata for the webset",
default_factory=dict,
description="Updated metadata for the webset", default_factory=dict
)
updated_at: String = SchemaField(
description="The date and time the webset was updated",
description="The date and time the webset was updated"
)
error: String = SchemaField(
description="Error message if the request failed",
default="",
description="Error message if the request failed", default=""
)
def __init__(self):
@@ -203,13 +187,10 @@ class ExaUpdateWebsetBlock(Block):
yield "updated_at", ""
@provider("exa")
class ExaListWebsetsBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
cursor: Optional[String] = SchemaField(
default=None,
@@ -225,21 +206,16 @@ class ExaListWebsetsBlock(Block):
)
class Output(BlockSchema):
websets: List = SchemaField(
description="List of websets",
default_factory=list,
)
websets: List = SchemaField(description="List of websets", default_factory=list)
has_more: Boolean = SchemaField(
description="Whether there are more results to paginate through",
default=False,
)
next_cursor: Optional[String] = SchemaField(
description="Cursor for the next page of results",
default=None,
description="Cursor for the next page of results", default=None
)
error: String = SchemaField(
description="Error message if the request failed",
default="",
description="Error message if the request failed", default=""
)
def __init__(self):
@@ -280,46 +256,35 @@ class ExaListWebsetsBlock(Block):
yield "has_more", False
@provider("exa")
class ExaGetWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: String = SchemaField(
description="The ID or external ID of the Webset to retrieve",
placeholder="webset-id-or-external-id",
)
expand_items: Boolean = SchemaField(
default=False,
description="Include items in the response",
advanced=True,
default=False, description="Include items in the response", advanced=True
)
class Output(BlockSchema):
webset_id: String = SchemaField(
description="The unique identifier for the webset",
)
status: String = SchemaField(
description="The status of the webset",
description="The unique identifier for the webset"
)
status: String = SchemaField(description="The status of the webset")
external_id: Optional[String] = SchemaField(
description="The external identifier for the webset",
default=None,
description="The external identifier for the webset", default=None
)
searches: List[Dict] = SchemaField(
description="The searches performed on the webset",
default_factory=list,
description="The searches performed on the webset", default_factory=list
)
enrichments: List[Dict] = SchemaField(
description="The enrichments applied to the webset",
default_factory=list,
description="The enrichments applied to the webset", default_factory=list
)
monitors: List[Dict] = SchemaField(
description="The monitors for the webset",
default_factory=list,
description="The monitors for the webset", default_factory=list
)
items: Optional[List[Dict]] = SchemaField(
description="The items in the webset (if expand_items is true)",
@@ -330,14 +295,13 @@ class ExaGetWebsetBlock(Block):
default_factory=dict,
)
created_at: String = SchemaField(
description="The date and time the webset was created",
description="The date and time the webset was created"
)
updated_at: String = SchemaField(
description="The date and time the webset was last updated",
description="The date and time the webset was last updated"
)
error: String = SchemaField(
description="Error message if the request failed",
default="",
description="Error message if the request failed", default=""
)
def __init__(self):
@@ -389,13 +353,10 @@ class ExaGetWebsetBlock(Block):
yield "updated_at", ""
@provider("exa")
class ExaDeleteWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: String = SchemaField(
description="The ID or external ID of the Webset to delete",
@@ -404,22 +365,17 @@ class ExaDeleteWebsetBlock(Block):
class Output(BlockSchema):
webset_id: String = SchemaField(
description="The unique identifier for the deleted webset",
description="The unique identifier for the deleted webset"
)
external_id: Optional[String] = SchemaField(
description="The external identifier for the deleted webset",
default=None,
)
status: String = SchemaField(
description="The status of the deleted webset",
description="The external identifier for the deleted webset", default=None
)
status: String = SchemaField(description="The status of the deleted webset")
success: String = SchemaField(
description="Whether the deletion was successful",
default="true",
description="Whether the deletion was successful", default="true"
)
error: String = SchemaField(
description="Error message if the request failed",
default="",
description="Error message if the request failed", default=""
)
def __init__(self):
@@ -456,13 +412,10 @@ class ExaDeleteWebsetBlock(Block):
yield "success", "false"
@provider("exa")
class ExaCancelWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: String = SchemaField(
description="The ID or external ID of the Webset to cancel",
@@ -471,22 +424,19 @@ class ExaCancelWebsetBlock(Block):
class Output(BlockSchema):
webset_id: String = SchemaField(
description="The unique identifier for the webset",
description="The unique identifier for the webset"
)
status: String = SchemaField(
description="The status of the webset after cancellation",
description="The status of the webset after cancellation"
)
external_id: Optional[String] = SchemaField(
description="The external identifier for the webset",
default=None,
description="The external identifier for the webset", default=None
)
success: String = SchemaField(
description="Whether the cancellation was successful",
default="true",
description="Whether the cancellation was successful", default="true"
)
error: String = SchemaField(
description="Error message if the request failed",
default="",
description="Error message if the request failed", default=""
)
def __init__(self):

View File

@@ -11,8 +11,6 @@ from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
Boolean,
@@ -22,9 +20,6 @@ from backend.sdk import (
SchemaField,
SecretStr,
String,
cost_config,
default_credentials,
provider,
)
# Define test credentials for testing
@@ -44,21 +39,7 @@ TEST_CREDENTIALS_INPUT = {
}
# Example of a simple service with auto-registration
@provider("example-service") # Custom provider demonstrating SDK flexibility
@cost_config(
BlockCost(cost_amount=2, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=1, cost_type=BlockCostType.BYTE),
)
@default_credentials(
APIKeyCredentials(
id="example-service-default",
provider="example-service", # Custom provider name
api_key=SecretStr("example-default-api-key"),
title="Example Service Default API Key",
expires_at=None,
)
)
# Example of a simple service
class ExampleSDKBlock(Block):
"""
Example block demonstrating the new SDK system.

View File

@@ -5,13 +5,13 @@ This demonstrates webhook auto-registration without modifying
files outside the blocks folder.
"""
from enum import Enum
from backend.sdk import (
BaseModel,
BaseWebhooksManager,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
BlockType,
@@ -19,14 +19,10 @@ from backend.sdk import (
CredentialsField,
CredentialsMetaInput,
Dict,
Enum,
Field,
ProviderName,
SchemaField,
String,
cost_config,
provider,
webhook_config,
)
@@ -72,22 +68,10 @@ class ExampleWebhookManager(BaseWebhooksManager):
pass
# Now create the webhook block with auto-registration
@provider("examplewebhook")
@webhook_config("examplewebhook", ExampleWebhookManager)
@cost_config(
BlockCost(
cost_amount=0, cost_type=BlockCostType.RUN
) # Webhooks typically free to receive
)
# Now create the webhook block
class ExampleWebhookSDKBlock(Block):
"""
Example webhook block demonstrating SDK webhook capabilities.
With the new SDK:
- Webhook manager registered via @webhook_config decorator
- No need to modify webhooks/__init__.py
- Fully self-contained webhook implementation
Example webhook block demonstrating webhook capabilities.
"""
class Input(BlockSchema):

View File

@@ -18,28 +18,21 @@ After SDK: Single import statement
from backend.sdk import (
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
Integer,
SchemaField,
String,
cost_config,
provider,
)
@provider("simple_service")
@cost_config(BlockCost(cost_amount=1, cost_type=BlockCostType.RUN))
class SimpleExampleBlock(Block):
"""
A simple example block showing the power of the SDK.
Key benefits:
1. Single import: from backend.sdk import *
2. Auto-registration via decorators
3. No manual config file updates needed
2. Clean, simple block structure
"""
class Input(BlockSchema):

View File

@@ -14,14 +14,14 @@ This module provides:
"""
# Standard library imports
import asyncio
import logging
from enum import Enum
from logging import getLogger as TruncatedLogger
from typing import Any, Dict, List
from typing import Any
from typing import Dict as DictType
from typing import List as ListType
from typing import Literal
from typing import Literal as _Literal
from typing import Optional, Set, Tuple, Type, TypeVar, Union
from typing import Optional, Union
import requests
# Third-party imports
from pydantic import BaseModel, Field, SecretStr
@@ -47,23 +47,15 @@ from backend.data.model import (
# === INTEGRATIONS ===
from backend.integrations.providers import ProviderName
from backend.sdk.builder import ProviderBuilder
from backend.sdk.provider import Provider
# === NEW SDK COMPONENTS (imported early for patches) ===
from backend.sdk.registry import AutoRegistry, BlockConfiguration
# === UTILITIES ===
from backend.util import json
# === AUTO-REGISTRATION DECORATORS ===
from .decorators import (
cost_config,
default_credentials,
oauth_config,
provider,
register_cost,
register_credentials,
register_oauth,
register_webhook_manager,
webhook_config,
)
# === OPTIONAL IMPORTS WITH TRY/EXCEPT ===
# Webhooks
try:
@@ -112,31 +104,8 @@ except ImportError:
try:
from backend.util.logging import TruncatedLogger
except ImportError:
TruncatedLogger = TruncatedLogger # Use the one imported at top
TruncatedLogger = None
# GitHub components
try:
from backend.blocks.github._auth import (
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
except ImportError:
GithubCredentials = None
GithubCredentialsInput = None
GithubCredentialsField = None
# Google components
try:
from backend.blocks.google._auth import (
GoogleCredentials,
GoogleCredentialsField,
GoogleCredentialsInput,
)
except ImportError:
GoogleCredentials = None
GoogleCredentialsInput = None
GoogleCredentialsField = None
# OAuth handlers
try:
@@ -144,49 +113,25 @@ try:
except ImportError:
BaseOAuthHandler = None
try:
from backend.integrations.oauth.github import GitHubOAuthHandler
except ImportError:
GitHubOAuthHandler = None
try:
from backend.integrations.oauth.google import GoogleOAuthHandler
except ImportError:
GoogleOAuthHandler = None
# Webhook managers
try:
from backend.integrations.webhooks.github import GithubWebhooksManager
except ImportError:
GithubWebhooksManager = None
try:
from backend.integrations.webhooks.generic import GenericWebhooksManager
except ImportError:
GenericWebhooksManager = None
# === VARIABLE ASSIGNMENTS AND TYPE ALIASES ===
# Type aliases
# Type aliases for block development
String = str
Integer = int
Float = float
Boolean = bool
List = ListType
Dict = DictType
# Credential type with proper provider name
CredentialsMetaInput = _CredentialsMetaInput[
ProviderName, _Literal["api_key", "oauth2", "user_password"]
]
# Webhook manager aliases
if GithubWebhooksManager is not None:
GitHubWebhooksManager = GithubWebhooksManager # Alias for consistency
else:
GitHubWebhooksManager = None
if GenericWebhooksManager is not None:
GenericWebhookManager = GenericWebhooksManager # Alias for consistency
else:
GenericWebhookManager = None
# Initialize the registry's integration patches
AutoRegistry.patch_integrations()
# === COMPREHENSIVE __all__ EXPORT ===
__all__ = [
@@ -216,19 +161,7 @@ __all__ = [
"BaseWebhooksManager",
"ManualWebhookManagerBase",
# Provider-Specific (when available)
"GithubCredentials",
"GithubCredentialsInput",
"GithubCredentialsField",
"GoogleCredentials",
"GoogleCredentialsInput",
"GoogleCredentialsField",
"BaseOAuthHandler",
"GitHubOAuthHandler",
"GoogleOAuthHandler",
"GitHubWebhooksManager",
"GithubWebhooksManager",
"GenericWebhookManager",
"GenericWebhooksManager",
# Utilities
"json",
"store_media_file",
@@ -236,9 +169,7 @@ __all__ = [
"convert",
"TextFormatter",
"TruncatedLogger",
"logging",
"asyncio",
# Types
# Type aliases for blocks
"String",
"Integer",
"Float",
@@ -247,26 +178,17 @@ __all__ = [
"Dict",
"Optional",
"Any",
"Literal",
"Union",
"TypeVar",
"Type",
"Tuple",
"Set",
"Literal",
"BaseModel",
"SecretStr",
"Field",
"Enum",
# Auto-Registration Decorators
"register_credentials",
"register_cost",
"register_oauth",
"register_webhook_manager",
"provider",
"cost_config",
"webhook_config",
"default_credentials",
"oauth_config",
"SecretStr",
"requests",
# SDK Components
"AutoRegistry",
"BlockConfiguration",
"Provider",
"ProviderBuilder",
]
# Remove None values from __all__

View File

@@ -0,0 +1,131 @@
"""
Builder class for creating provider configurations with a fluent API.
"""
import os
from typing import Callable, List, Optional, Type
from pydantic import SecretStr
from backend.data.cost import BlockCost, BlockCostType
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.webhooks._base import BaseWebhooksManager
from backend.sdk.provider import Provider
from backend.sdk.registry import AutoRegistry
from backend.util.settings import Settings
class ProviderBuilder:
"""Builder for creating provider configurations."""
def __init__(self, name: str):
self.name = name
self._oauth_handler: Optional[Type[BaseOAuthHandler]] = None
self._webhook_manager: Optional[Type[BaseWebhooksManager]] = None
self._default_credentials: List[Credentials] = []
self._base_costs: List[BlockCost] = []
self._supported_auth_types: set = set()
self._api_client_factory: Optional[Callable] = None
self._error_handler: Optional[Callable[[Exception], str]] = None
self._default_scopes: Optional[List[str]] = None
self._extra_config: dict = {}
def with_oauth(
self, handler_class: Type[BaseOAuthHandler], scopes: Optional[List[str]] = None
) -> "ProviderBuilder":
"""Add OAuth support."""
self._oauth_handler = handler_class
self._supported_auth_types.add("oauth2")
if scopes:
self._default_scopes = scopes
return self
def with_api_key(self, env_var_name: str, title: str) -> "ProviderBuilder":
"""Add API key support with environment variable name."""
self._supported_auth_types.add("api_key")
# Register the API key mapping
AutoRegistry.register_api_key(self.name, env_var_name)
# Check if API key exists in environment
api_key = os.getenv(env_var_name)
if api_key:
self._default_credentials.append(
APIKeyCredentials(
id=f"{self.name}-default",
provider=self.name,
api_key=SecretStr(api_key),
title=title,
)
)
return self
def with_api_key_from_settings(
self, settings_attr: str, title: str
) -> "ProviderBuilder":
"""Use existing API key from settings."""
self._supported_auth_types.add("api_key")
# Try to get the API key from settings
settings = Settings()
api_key = getattr(settings.secrets, settings_attr, None)
if api_key:
self._default_credentials.append(
APIKeyCredentials(
id=f"{self.name}-default",
provider=self.name,
api_key=api_key,
title=title,
)
)
return self
def with_webhook_manager(
self, manager_class: Type[BaseWebhooksManager]
) -> "ProviderBuilder":
"""Register webhook manager for this provider."""
self._webhook_manager = manager_class
return self
def with_base_cost(
self, amount: int, cost_type: BlockCostType
) -> "ProviderBuilder":
"""Set base cost for all blocks using this provider."""
self._base_costs.append(BlockCost(cost_amount=amount, cost_type=cost_type))
return self
def with_api_client(self, factory: Callable) -> "ProviderBuilder":
"""Register API client factory."""
self._api_client_factory = factory
return self
def with_error_handler(
self, handler: Callable[[Exception], str]
) -> "ProviderBuilder":
"""Register error handler for provider-specific errors."""
self._error_handler = handler
return self
def with_config(self, **kwargs) -> "ProviderBuilder":
"""Add additional configuration options."""
self._extra_config.update(kwargs)
return self
def build(self) -> Provider:
"""Build and register the provider configuration."""
provider = Provider(
name=self.name,
oauth_handler=self._oauth_handler,
webhook_manager=self._webhook_manager,
default_credentials=self._default_credentials,
base_costs=self._base_costs,
supported_auth_types=self._supported_auth_types,
api_client_factory=self._api_client_factory,
error_handler=self._error_handler,
**self._extra_config,
)
# Auto-registration happens here
AutoRegistry.register_provider(provider)
return provider

View File

@@ -0,0 +1,66 @@
"""
Provider configuration class that holds all provider-related settings.
"""
from typing import Any, Callable, List, Optional, Set, Type
from backend.data.cost import BlockCost
from backend.data.model import Credentials, CredentialsField, CredentialsMetaInput
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.webhooks._base import BaseWebhooksManager
class Provider:
"""A configured provider that blocks can use."""
def __init__(
self,
name: str,
oauth_handler: Optional[Type[BaseOAuthHandler]] = None,
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
default_credentials: Optional[List[Credentials]] = None,
base_costs: Optional[List[BlockCost]] = None,
supported_auth_types: Optional[Set[str]] = None,
api_client_factory: Optional[Callable] = None,
error_handler: Optional[Callable[[Exception], str]] = None,
**kwargs,
):
self.name = name
self.oauth_handler = oauth_handler
self.webhook_manager = webhook_manager
self.default_credentials = default_credentials or []
self.base_costs = base_costs or []
self.supported_auth_types = supported_auth_types or set()
self._api_client_factory = api_client_factory
self._error_handler = error_handler
# Store any additional configuration
self._extra_config = kwargs
def credentials_field(self, **kwargs) -> CredentialsMetaInput:
"""Return a CredentialsField configured for this provider."""
# Merge provider defaults with user overrides
field_kwargs = {
"provider": self.name,
"supported_credential_types": self.supported_auth_types,
"description": f"{self.name.title()} credentials",
}
field_kwargs.update(kwargs)
return CredentialsField(**field_kwargs)
def get_api(self, credentials: Credentials) -> Any:
"""Get API client instance for the given credentials."""
if self._api_client_factory:
return self._api_client_factory(credentials)
raise NotImplementedError(f"No API client factory registered for {self.name}")
def handle_error(self, error: Exception) -> str:
"""Handle provider-specific errors."""
if self._error_handler:
return self._error_handler(error)
return str(error)
def get_config(self, key: str, default: Any = None) -> Any:
"""Get additional configuration value."""
return self._extra_config.get(key, default)

View File

@@ -0,0 +1,203 @@
"""
Auto-registration system for blocks, providers, and their configurations.
"""
import logging
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from pydantic import SecretStr
from backend.blocks.basic import Block
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.webhooks._base import BaseWebhooksManager
if TYPE_CHECKING:
from backend.sdk.provider import Provider
class BlockConfiguration:
"""Configuration associated with a block."""
def __init__(
self,
provider: str,
costs: List[Any],
default_credentials: List[Credentials],
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
oauth_handler: Optional[Type[BaseOAuthHandler]] = None,
):
self.provider = provider
self.costs = costs
self.default_credentials = default_credentials
self.webhook_manager = webhook_manager
self.oauth_handler = oauth_handler
class AutoRegistry:
"""Central registry for all block-related configurations."""
_lock = threading.Lock()
_providers: Dict[str, "Provider"] = {}
_default_credentials: List[Credentials] = []
_oauth_handlers: Dict[str, Type[BaseOAuthHandler]] = {}
_webhook_managers: Dict[str, Type[BaseWebhooksManager]] = {}
_block_configurations: Dict[Type[Block], BlockConfiguration] = {}
_api_key_mappings: Dict[str, str] = {} # provider -> env_var_name
@classmethod
def register_provider(cls, provider: "Provider") -> None:
"""Auto-register provider and all its configurations."""
with cls._lock:
cls._providers[provider.name] = provider
# Register OAuth handler if provided
if provider.oauth_handler:
cls._oauth_handlers[provider.name] = provider.oauth_handler
# Register webhook manager if provided
if provider.webhook_manager:
cls._webhook_managers[provider.name] = provider.webhook_manager
# Register default credentials
cls._default_credentials.extend(provider.default_credentials)
@classmethod
def register_api_key(cls, provider: str, env_var_name: str) -> None:
"""Register an environment variable as an API key for a provider."""
with cls._lock:
cls._api_key_mappings[provider] = env_var_name
# Dynamically check if the env var exists and create credential
import os
api_key = os.getenv(env_var_name)
if api_key:
credential = APIKeyCredentials(
id=f"{provider}-default",
provider=provider,
api_key=SecretStr(api_key),
title=f"Default {provider} credentials",
)
# Check if credential already exists to avoid duplicates
if not any(c.id == credential.id for c in cls._default_credentials):
cls._default_credentials.append(credential)
@classmethod
def get_all_credentials(cls) -> List[Credentials]:
"""Replace hardcoded get_all_creds() in credentials_store.py."""
with cls._lock:
return cls._default_credentials.copy()
@classmethod
def get_oauth_handlers(cls) -> Dict[str, Type[BaseOAuthHandler]]:
"""Replace HANDLERS_BY_NAME in oauth/__init__.py."""
with cls._lock:
return cls._oauth_handlers.copy()
@classmethod
def get_webhook_managers(cls) -> Dict[str, Type[BaseWebhooksManager]]:
"""Replace load_webhook_managers() in webhooks/__init__.py."""
with cls._lock:
return cls._webhook_managers.copy()
@classmethod
def register_block_configuration(
cls, block_class: Type[Block], config: BlockConfiguration
) -> None:
"""Register configuration for a specific block class."""
with cls._lock:
cls._block_configurations[block_class] = config
@classmethod
def get_provider(cls, name: str) -> Optional["Provider"]:
"""Get a registered provider by name."""
with cls._lock:
return cls._providers.get(name)
@classmethod
def clear(cls) -> None:
"""Clear all registrations (useful for testing)."""
with cls._lock:
cls._providers.clear()
cls._default_credentials.clear()
cls._oauth_handlers.clear()
cls._webhook_managers.clear()
cls._block_configurations.clear()
cls._api_key_mappings.clear()
@classmethod
def patch_integrations(cls) -> None:
"""Patch existing integration points to use AutoRegistry."""
# Patch oauth handlers
try:
import backend.integrations.oauth as oauth
if hasattr(oauth, "HANDLERS_BY_NAME"):
# Create a new dict that includes both original and SDK handlers
original_handlers = dict(oauth.HANDLERS_BY_NAME)
class PatchedHandlersDict(dict): # type: ignore
def __getitem__(self, key):
# First try SDK handlers
sdk_handlers = cls.get_oauth_handlers()
if key in sdk_handlers:
return sdk_handlers[key]
# Fall back to original
return original_handlers[key]
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
sdk_handlers = cls.get_oauth_handlers()
return key in sdk_handlers or key in original_handlers
def keys(self): # type: ignore[override]
sdk_handlers = cls.get_oauth_handlers()
all_keys = set(original_handlers.keys()) | set(
sdk_handlers.keys()
)
return all_keys
def values(self):
combined = dict(original_handlers)
sdk_handlers = cls.get_oauth_handlers()
if isinstance(sdk_handlers, dict):
combined.update(sdk_handlers) # type: ignore
return combined.values()
def items(self):
combined = dict(original_handlers)
sdk_handlers = cls.get_oauth_handlers()
if isinstance(sdk_handlers, dict):
combined.update(sdk_handlers) # type: ignore
return combined.items()
oauth.HANDLERS_BY_NAME = PatchedHandlersDict()
except Exception as e:
logging.warning(f"Failed to patch oauth handlers: {e}")
# Patch webhook managers
try:
import backend.integrations.webhooks as webhooks
if hasattr(webhooks, "load_webhook_managers"):
original_load = webhooks.load_webhook_managers
def patched_load():
# Get original managers
managers = original_load()
# Add SDK-registered managers
sdk_managers = cls.get_webhook_managers()
if isinstance(sdk_managers, dict):
managers.update(sdk_managers) # type: ignore
return managers
webhooks.load_webhook_managers = patched_load
except Exception as e:
logging.warning(f"Failed to patch webhook managers: {e}")

View File

@@ -20,7 +20,7 @@ from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.sdk.auto_registry import get_registry
from backend.sdk.registry import AutoRegistry
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.settings import Settings
@@ -434,8 +434,7 @@ async def list_providers() -> List[str]:
static_providers = [member.value for member in ProviderName]
# Get dynamic providers from registry
registry = get_registry()
dynamic_providers = list(registry.providers)
dynamic_providers = list(AutoRegistry._providers.keys())
# Combine and deduplicate
all_providers = list(set(static_providers + dynamic_providers))
@@ -461,7 +460,7 @@ async def get_providers_details() -> Dict[str, ProviderDetails]:
Returns a dictionary mapping provider names to their details,
including supported credential types and other metadata.
"""
registry = get_registry()
# AutoRegistry is used directly as a class with class methods
# Build provider details
provider_details: Dict[str, ProviderDetails] = {}
@@ -471,24 +470,26 @@ async def get_providers_details() -> Dict[str, ProviderDetails]:
provider_details[member.value] = ProviderDetails(
name=member.value,
source="static",
has_oauth=member.value in registry.oauth_handlers,
has_webhooks=member.value in registry.webhook_managers,
has_oauth=member.value in AutoRegistry._oauth_handlers,
has_webhooks=member.value in AutoRegistry._webhook_managers,
)
# Add/update with dynamic providers
for provider in registry.providers:
for provider in AutoRegistry._providers:
if provider not in provider_details:
provider_details[provider] = ProviderDetails(
name=provider,
source="dynamic",
has_oauth=provider in registry.oauth_handlers,
has_webhooks=provider in registry.webhook_managers,
has_oauth=provider in AutoRegistry._oauth_handlers,
has_webhooks=provider in AutoRegistry._webhook_managers,
)
else:
provider_details[provider].source = "both"
provider_details[provider].has_oauth = provider in registry.oauth_handlers
provider_details[provider].has_oauth = (
provider in AutoRegistry._oauth_handlers
)
provider_details[provider].has_webhooks = (
provider in registry.webhook_managers
provider in AutoRegistry._webhook_managers
)
# Determine supported credential types for each provider

View File

@@ -58,40 +58,8 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
await backend.data.block.initialize_blocks()
# Set up auto-registration system for SDK
try:
from backend.sdk.auto_registry import setup_auto_registration
logger.info("Starting SDK auto-registration system...")
registry = setup_auto_registration()
# Log successful registration
logger.info("Auto-registration completed successfully:")
logger.info(f" - {len(registry.block_costs)} block costs registered")
logger.info(
f" - {len(registry.default_credentials)} default credentials registered"
)
logger.info(f" - {len(registry.oauth_handlers)} OAuth handlers registered")
logger.info(f" - {len(registry.webhook_managers)} webhook managers registered")
logger.info(f" - {len(registry.providers)} providers registered")
# Log specific credential providers for debugging
credential_providers = [
getattr(cred, "provider", "unknown")
for cred in registry.default_credentials
]
if credential_providers:
logger.info(
f" - Default credential providers: {', '.join(credential_providers)}"
)
except Exception as e:
logger.error(f"Auto-registration setup failed: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
# Don't let this failure prevent startup, but make it very visible
raise
# SDK auto-registration is now handled by AutoRegistry.patch_integrations()
# which is called when the SDK module is imported
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()

View File

@@ -0,0 +1,29 @@
"""
Configuration for SDK tests.
This conftest.py file provides basic test setup for SDK unit tests
without requiring the full server infrastructure.
"""
from unittest.mock import MagicMock
import pytest
@pytest.fixture(scope="session")
def server():
"""Mock server fixture for SDK tests."""
mock_server = MagicMock()
mock_server.agent_server = MagicMock()
mock_server.agent_server.test_create_graph = MagicMock()
return mock_server
@pytest.fixture(autouse=True)
def reset_registry():
"""Reset the AutoRegistry before each test."""
from backend.sdk.registry import AutoRegistry
AutoRegistry.clear()
yield
AutoRegistry.clear()

View File

@@ -1,231 +0,0 @@
"""
Demo: Creating a new block with the SDK using 'from backend.sdk import *'
This file demonstrates the simplified block creation process.
"""
from backend.integrations.providers import ProviderName
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
CredentialsField,
CredentialsMetaInput,
Float,
List,
SchemaField,
SecretStr,
String,
cost_config,
default_credentials,
provider,
)
# Create a custom translation service block with full auto-registration
@provider("ultra-translate-ai")
@cost_config(
BlockCost(cost_amount=3, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=1, cost_type=BlockCostType.BYTE),
)
@default_credentials(
APIKeyCredentials(
id="ultra-translate-default",
provider="ultra-translate-ai",
api_key=SecretStr("ultra-translate-default-api-key"),
title="Ultra Translate AI Default API Key",
)
)
class UltraTranslateBlock(Block):
"""
Ultra Translate AI - Advanced Translation Service
This block demonstrates how easy it is to create a new service integration
with the SDK. No external configuration files need to be modified!
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="ultra-translate-ai",
supported_credential_types={"api_key"},
description="API credentials for Ultra Translate AI",
)
text: String = SchemaField(
description="Text to translate", placeholder="Enter text to translate..."
)
source_language: String = SchemaField(
description="Source language code (auto-detect if empty)",
default="",
placeholder="en, es, fr, de, ja, zh",
)
target_language: String = SchemaField(
description="Target language code",
default="es",
placeholder="en, es, fr, de, ja, zh",
)
formality: String = SchemaField(
description="Translation formality level (formal, neutral, informal)",
default="neutral",
)
class Output(BlockSchema):
translated_text: String = SchemaField(description="The translated text")
detected_language: String = SchemaField(
description="Auto-detected source language (if applicable)"
)
confidence: Float = SchemaField(
description="Translation confidence score (0-1)"
)
alternatives: List[String] = SchemaField(
description="Alternative translations", default=[]
)
error: String = SchemaField(
description="Error message if translation failed", default=""
)
def __init__(self):
super().__init__(
id="abb8abc4-c968-45fe-815a-c5521ad67f32",
description="Translate text between languages using Ultra Translate AI",
categories={BlockCategory.TEXT, BlockCategory.AI},
input_schema=UltraTranslateBlock.Input,
output_schema=UltraTranslateBlock.Output,
test_input={
"text": "Hello, how are you?",
"target_language": "es",
"formality": "informal",
},
test_output=[
("translated_text", "¡Hola! ¿Cómo estás?"),
("detected_language", "en"),
("confidence", 0.98),
("alternatives", ["¡Hola! ¿Qué tal?", "¡Hola! ¿Cómo te va?"]),
("error", ""),
],
)
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
# Get API key
api_key = credentials.api_key.get_secret_value() # noqa: F841
# Simulate translation based on input
translations = {
("Hello, how are you?", "es", "informal"): {
"text": "¡Hola! ¿Cómo estás?",
"alternatives": ["¡Hola! ¿Qué tal?", "¡Hola! ¿Cómo te va?"],
},
("Hello, how are you?", "es", "formal"): {
"text": "Hola, ¿cómo está usted?",
"alternatives": ["Buenos días, ¿cómo se encuentra?"],
},
("Hello, how are you?", "fr", "neutral"): {
"text": "Bonjour, comment allez-vous ?",
"alternatives": ["Salut, comment ça va ?"],
},
("Hello, how are you?", "de", "neutral"): {
"text": "Hallo, wie geht es dir?",
"alternatives": ["Hallo, wie geht's?"],
},
}
# Get translation
key = (input_data.text, input_data.target_language, input_data.formality)
result = translations.get(
key,
{
"text": f"[{input_data.target_language}] {input_data.text}",
"alternatives": [],
},
)
# Detect source language if not provided
detected_lang = input_data.source_language or "en"
yield "translated_text", result["text"]
yield "detected_language", detected_lang
yield "confidence", 0.95
yield "alternatives", result["alternatives"]
yield "error", ""
except Exception as e:
yield "translated_text", ""
yield "detected_language", ""
yield "confidence", 0.0
yield "alternatives", []
yield "error", str(e)
# This function demonstrates testing the block
def demo_block_usage():
"""Demonstrate using the block"""
print("=" * 60)
print("🌐 Ultra Translate AI Block Demo")
print("=" * 60)
# Create block instance
block = UltraTranslateBlock()
print(f"\n✅ Created block: {block.name}")
print(f" ID: {block.id}")
print(f" Categories: {block.categories}")
# Check auto-registration
from backend.sdk.auto_registry import get_registry
registry = get_registry()
print("\n📋 Auto-Registration Status:")
print(f" ✅ Provider registered: {'ultra-translate-ai' in registry.providers}")
print(f" ✅ Costs registered: {UltraTranslateBlock in registry.block_costs}")
if UltraTranslateBlock in registry.block_costs:
costs = registry.block_costs[UltraTranslateBlock]
print(f" - Per run: {costs[0].cost_amount} credits")
print(f" - Per byte: {costs[1].cost_amount} credits")
creds = registry.get_default_credentials_list()
has_default_cred = any(c.id == "ultra-translate-default" for c in creds)
print(f" ✅ Default credentials: {has_default_cred}")
# Test dynamic provider enum
print("\n🔧 Dynamic Provider Test:")
provider = ProviderName("ultra-translate-ai")
print(f" ✅ Custom provider accepted: {provider.value}")
print(f" ✅ Is ProviderName instance: {isinstance(provider, ProviderName)}")
# Test block execution
print("\n🚀 Test Block Execution:")
test_creds = APIKeyCredentials(
id="test",
provider="ultra-translate-ai",
api_key=SecretStr("test-api-key"),
title="Test",
)
# Create test input with credentials meta
test_input = UltraTranslateBlock.Input(
credentials={"provider": "ultra-translate-ai", "id": "test", "type": "api_key"},
text="Hello, how are you?",
target_language="es",
formality="informal",
)
results = list(block.run(test_input, credentials=test_creds))
output = {k: v for k, v in results}
print(f" Input: '{test_input.text}'")
print(f" Target: {test_input.target_language} ({test_input.formality})")
print(f" Output: '{output['translated_text']}'")
print(f" Confidence: {output['confidence']}")
print(f" Alternatives: {output['alternatives']}")
print("\n✨ Block works perfectly with zero external configuration!")
if __name__ == "__main__":
demo_block_usage()

View File

@@ -1,247 +0,0 @@
"""
Test custom provider functionality in the SDK.
This test suite verifies that the SDK properly supports dynamic provider
registration and that custom providers work correctly with the system.
"""
from backend.integrations.providers import ProviderName
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
Boolean,
CredentialsField,
CredentialsMetaInput,
SchemaField,
SecretStr,
String,
cost_config,
default_credentials,
provider,
)
from backend.sdk.auto_registry import get_registry
from backend.util.test import execute_block_test
# Test credentials for custom providers
CUSTOM_TEST_CREDENTIALS = APIKeyCredentials(
id="custom-provider-test-creds",
provider="my-custom-service",
api_key=SecretStr("test-api-key-12345"),
title="Custom Service Test Credentials",
expires_at=None,
)
CUSTOM_TEST_CREDENTIALS_INPUT = {
"provider": CUSTOM_TEST_CREDENTIALS.provider,
"id": CUSTOM_TEST_CREDENTIALS.id,
"type": CUSTOM_TEST_CREDENTIALS.type,
"title": CUSTOM_TEST_CREDENTIALS.title,
}
@provider("my-custom-service")
@cost_config(
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=2, cost_type=BlockCostType.BYTE),
)
@default_credentials(
APIKeyCredentials(
id="my-custom-service-default",
provider="my-custom-service",
api_key=SecretStr("default-custom-api-key"),
title="My Custom Service Default API Key",
expires_at=None,
)
)
class CustomProviderBlock(Block):
"""Test block with a completely custom provider."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="my-custom-service",
supported_credential_types={"api_key"},
description="Custom service credentials",
)
message: String = SchemaField(
description="Message to process", default="Hello from custom provider!"
)
class Output(BlockSchema):
result: String = SchemaField(description="Processed message")
provider_used: String = SchemaField(description="Provider name used")
credentials_valid: Boolean = SchemaField(
description="Whether credentials were valid"
)
def __init__(self):
super().__init__(
id="d1234567-89ab-cdef-0123-456789abcdef",
description="Test block demonstrating custom provider support",
categories={BlockCategory.TEXT},
input_schema=CustomProviderBlock.Input,
output_schema=CustomProviderBlock.Output,
test_input={
"credentials": CUSTOM_TEST_CREDENTIALS_INPUT,
"message": "Test message",
},
test_output=[
("result", "CUSTOM: Test message"),
("provider_used", "my-custom-service"),
("credentials_valid", True),
],
test_credentials=CUSTOM_TEST_CREDENTIALS,
)
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Verify we got the right credentials
api_key = credentials.api_key.get_secret_value()
yield "result", f"CUSTOM: {input_data.message}"
yield "provider_used", credentials.provider
yield "credentials_valid", bool(api_key)
@provider("another-custom-provider")
@cost_config(
BlockCost(cost_amount=5, cost_type=BlockCostType.RUN),
)
class AnotherCustomProviderBlock(Block):
"""Another test block to verify multiple custom providers work."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="another-custom-provider",
supported_credential_types={"api_key"},
)
data: String = SchemaField(description="Input data")
class Output(BlockSchema):
processed: String = SchemaField(description="Processed data")
def __init__(self):
super().__init__(
id="e2345678-9abc-def0-1234-567890abcdef",
description="Another custom provider test",
categories={BlockCategory.TEXT},
input_schema=AnotherCustomProviderBlock.Input,
output_schema=AnotherCustomProviderBlock.Output,
test_input={
"credentials": {
"provider": "another-custom-provider",
"id": "test-creds-2",
"type": "api_key",
"title": "Test Creds 2",
},
"data": "test data",
},
test_output=[("processed", "ANOTHER: test data")],
test_credentials=APIKeyCredentials(
id="test-creds-2",
provider="another-custom-provider",
api_key=SecretStr("another-test-key"),
title="Another Test Key",
expires_at=None,
),
)
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
yield "processed", f"ANOTHER: {input_data.data}"
class TestCustomProvider:
"""Test suite for custom provider functionality."""
def test_custom_provider_enum_accepts_any_string(self):
"""Test that ProviderName enum accepts any string value."""
# Test with a completely new provider name
custom_provider = ProviderName("my-totally-new-provider")
assert custom_provider.value == "my-totally-new-provider"
# Test with existing provider
existing_provider = ProviderName.OPENAI
assert existing_provider.value == "openai"
# Test comparison
another_custom = ProviderName("my-totally-new-provider")
assert custom_provider == another_custom
def test_custom_provider_block_executes(self):
"""Test that blocks with custom providers can execute properly."""
block = CustomProviderBlock()
execute_block_test(block)
def test_multiple_custom_providers(self):
"""Test that multiple custom providers can coexist."""
block1 = CustomProviderBlock()
block2 = AnotherCustomProviderBlock()
# Both blocks should execute successfully
execute_block_test(block1)
execute_block_test(block2)
def test_custom_provider_registration(self):
"""Test that custom providers are registered in the auto-registry."""
registry = get_registry()
# Check that our custom provider blocks have registered their costs
block_costs = registry.get_block_costs_dict()
assert CustomProviderBlock in block_costs
assert AnotherCustomProviderBlock in block_costs
# Check the costs are correct
custom_costs = block_costs[CustomProviderBlock]
assert len(custom_costs) == 2
assert any(
cost.cost_amount == 10 and cost.cost_type == BlockCostType.RUN
for cost in custom_costs
)
assert any(
cost.cost_amount == 2 and cost.cost_type == BlockCostType.BYTE
for cost in custom_costs
)
def test_custom_provider_default_credentials(self):
"""Test that default credentials are registered for custom providers."""
registry = get_registry()
default_creds = registry.get_default_credentials_list()
# Check that our custom provider's default credentials are registered
custom_default_creds = [
cred for cred in default_creds if cred.provider == "my-custom-service"
]
assert len(custom_default_creds) >= 1
assert custom_default_creds[0].id == "my-custom-service-default"
def test_custom_provider_with_oauth(self):
"""Test that custom providers can use OAuth handlers."""
# This is a placeholder for OAuth testing
# In a real implementation, you would create a custom OAuth handler
pass
def test_custom_provider_with_webhooks(self):
"""Test that custom providers can use webhook managers."""
# This is a placeholder for webhook testing
# In a real implementation, you would create a custom webhook manager
pass
# Test that runs as part of pytest
def test_custom_provider_functionality():
"""Run all custom provider tests."""
test_instance = TestCustomProvider()
# Run each test method
test_instance.test_custom_provider_enum_accepts_any_string()
test_instance.test_custom_provider_block_executes()
test_instance.test_multiple_custom_providers()
test_instance.test_custom_provider_registration()
test_instance.test_custom_provider_default_credentials()

View File

@@ -1,416 +0,0 @@
"""
Advanced tests for custom provider functionality including OAuth and Webhooks.
This test suite demonstrates how custom providers can integrate with all
aspects of the SDK including OAuth authentication and webhook handling.
"""
import time
from enum import Enum
from typing import Any, Optional
from backend.integrations.providers import ProviderName
from backend.sdk import (
APIKeyCredentials,
BaseModel,
BaseOAuthHandler,
BaseWebhooksManager,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
BlockType,
BlockWebhookConfig,
Boolean,
CredentialsField,
CredentialsMetaInput,
Dict,
Float,
List,
OAuth2Credentials,
SchemaField,
SecretStr,
String,
cost_config,
default_credentials,
oauth_config,
provider,
webhook_config,
)
from backend.util.test import execute_block_test
# Custom OAuth Handler for testing
class CustomServiceOAuthHandler(BaseOAuthHandler):
"""OAuth handler for our custom service."""
PROVIDER_NAME = ProviderName("custom-oauth-service")
DEFAULT_SCOPES = ["read", "write", "admin"]
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
"""Generate OAuth login URL."""
scope_str = " ".join(scopes)
return f"https://custom-oauth-service.com/oauth/authorize?client_id=test&scope={scope_str}&state={state}"
def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
"""Exchange authorization code for tokens."""
# Mock token exchange
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
access_token=SecretStr("mock-access-token"),
refresh_token=SecretStr("mock-refresh-token"),
scopes=scopes,
access_token_expires_at=int(time.time() + 3600),
title="Custom OAuth Service",
id="custom-oauth-creds",
)
# Custom Webhook Manager for testing
class CustomWebhookManager(BaseWebhooksManager):
"""Webhook manager for our custom service."""
PROVIDER_NAME = ProviderName("custom-webhook-service")
class WebhookType(str, Enum):
DATA_RECEIVED = "data_received"
STATUS_CHANGED = "status_changed"
@classmethod
async def validate_payload(cls, webhook: Any, request: Any) -> tuple[dict, str]:
"""Validate incoming webhook payload."""
# Mock payload validation
payload = {"data": "test data", "timestamp": time.time()}
event_type = "data_received"
return payload, event_type
async def _register_webhook(
self,
credentials: Any,
webhook_type: Any,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""Register webhook with external service."""
# Mock webhook registration
webhook_id = "custom-webhook-12345"
config = {"url": ingress_url, "events": events, "resource": resource}
return webhook_id, config
async def _deregister_webhook(self, webhook: Any, credentials: Any) -> None:
"""Deregister webhook from external service."""
# Mock webhook deregistration
pass
# Test OAuth-enabled block
@provider("custom-oauth-service")
@oauth_config("custom-oauth-service", CustomServiceOAuthHandler)
@cost_config(
BlockCost(cost_amount=15, cost_type=BlockCostType.RUN),
)
class CustomOAuthBlock(Block):
"""Block that uses OAuth authentication with a custom provider."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="custom-oauth-service",
supported_credential_types={"oauth2"},
required_scopes={"read", "write"},
description="OAuth credentials for custom service",
)
action: String = SchemaField(
description="Action to perform", default="fetch_data"
)
class Output(BlockSchema):
data: Dict = SchemaField(description="Retrieved data")
token_valid: Boolean = SchemaField(description="Whether OAuth token was valid")
scopes: List[String] = SchemaField(description="Available scopes")
def __init__(self):
super().__init__(
id="f3456789-abcd-ef01-2345-6789abcdef01",
description="Custom OAuth provider test block",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=CustomOAuthBlock.Input,
output_schema=CustomOAuthBlock.Output,
test_input={
"credentials": {
"provider": "custom-oauth-service",
"id": "oauth-test-creds",
"type": "oauth2",
"title": "Test OAuth Creds",
},
"action": "test_action",
},
test_output=[
("data", {"status": "success", "action": "test_action"}),
("token_valid", True),
("scopes", ["read", "write"]),
],
test_credentials=OAuth2Credentials(
id="oauth-test-creds",
provider="custom-oauth-service",
access_token=SecretStr("test-access-token"),
refresh_token=SecretStr("test-refresh-token"),
scopes=["read", "write"],
access_token_expires_at=int(time.time() + 3600),
title="Test OAuth Credentials",
),
)
def run(
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
) -> BlockOutput:
# Simulate OAuth API call
token = credentials.access_token.get_secret_value()
yield "data", {"status": "success", "action": input_data.action}
yield "token_valid", bool(token)
yield "scopes", credentials.scopes
# Event filter model for webhook
class WebhookEventFilter(BaseModel):
data_received: bool = True
status_changed: bool = False
# Test Webhook-enabled block
@provider("custom-webhook-service")
@webhook_config("custom-webhook-service", CustomWebhookManager)
class CustomWebhookBlock(Block):
"""Block that receives webhooks from a custom provider."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="custom-webhook-service",
supported_credential_types={"api_key"},
description="Credentials for webhook service",
)
events: WebhookEventFilter = SchemaField(
description="Events to listen for", default_factory=WebhookEventFilter
)
payload: Dict = SchemaField(
description="Webhook payload", default={}, hidden=True
)
class Output(BlockSchema):
event_type: String = SchemaField(description="Type of event received")
event_data: Dict = SchemaField(description="Event data")
timestamp: Float = SchemaField(description="Event timestamp")
def __init__(self):
super().__init__(
id="a4567890-bcde-f012-3456-7890bcdef012",
description="Custom webhook provider test block",
categories={BlockCategory.INPUT},
input_schema=CustomWebhookBlock.Input,
output_schema=CustomWebhookBlock.Output,
block_type=BlockType.WEBHOOK,
webhook_config=BlockWebhookConfig(
provider=ProviderName("custom-webhook-service"),
webhook_type="data_received",
event_filter_input="events",
resource_format="webhook/{webhook_id}",
),
test_input={
"credentials": {
"provider": "custom-webhook-service",
"id": "webhook-test-creds",
"type": "api_key",
"title": "Test Webhook Creds",
},
"events": {"data_received": True, "status_changed": False},
"payload": {
"type": "data_received",
"data": "test",
"timestamp": 1234567890.0,
},
},
test_output=[
("event_type", "data_received"),
(
"event_data",
{
"type": "data_received",
"data": "test",
"timestamp": 1234567890.0,
},
),
("timestamp", 1234567890.0),
],
test_credentials=APIKeyCredentials(
id="webhook-test-creds",
provider="custom-webhook-service",
api_key=SecretStr("webhook-api-key"),
title="Webhook API Key",
expires_at=None,
),
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
yield "event_type", payload.get("type", "unknown")
yield "event_data", payload
yield "timestamp", payload.get("timestamp", 0.0)
# Combined block using multiple custom features
@provider("custom-full-service")
@oauth_config("custom-full-service", CustomServiceOAuthHandler)
@webhook_config("custom-full-service", CustomWebhookManager)
@cost_config(
BlockCost(cost_amount=20, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=5, cost_type=BlockCostType.BYTE),
)
@default_credentials(
APIKeyCredentials(
id="custom-full-service-default",
provider="custom-full-service",
api_key=SecretStr("default-full-service-key"),
title="Custom Full Service Default Key",
expires_at=None,
)
)
class CustomFullServiceBlock(Block):
"""Block demonstrating all custom provider features."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="custom-full-service",
supported_credential_types={"api_key", "oauth2"},
description="Credentials for full service",
)
mode: String = SchemaField(description="Operation mode", default="standard")
class Output(BlockSchema):
result: String = SchemaField(description="Operation result")
features_used: List[String] = SchemaField(description="Features utilized")
def __init__(self):
super().__init__(
id="b5678901-cdef-0123-4567-8901cdef0123",
description="Full-featured custom provider block",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=CustomFullServiceBlock.Input,
output_schema=CustomFullServiceBlock.Output,
test_input={
"credentials": {
"provider": "custom-full-service",
"id": "full-test-creds",
"type": "api_key",
"title": "Full Service Test Creds",
},
"mode": "test",
},
test_output=[
("result", "SUCCESS: test mode"),
("features_used", ["provider", "cost_config", "default_credentials"]),
],
test_credentials=APIKeyCredentials(
id="full-test-creds",
provider="custom-full-service",
api_key=SecretStr("full-service-test-key"),
title="Full Service Test Key",
expires_at=None,
),
)
def run(self, input_data: Input, *, credentials: Any, **kwargs) -> BlockOutput:
features = ["provider", "cost_config", "default_credentials"]
if isinstance(credentials, OAuth2Credentials):
features.append("oauth")
yield "result", f"SUCCESS: {input_data.mode} mode"
yield "features_used", features
class TestCustomProviderAdvanced:
"""Advanced test suite for custom provider functionality."""
def test_oauth_handler_registration(self):
"""Test that custom OAuth handlers are registered."""
from backend.sdk.auto_registry import get_registry
registry = get_registry()
oauth_handlers = registry.get_oauth_handlers_dict()
# Check if our custom OAuth handler is registered
assert "custom-oauth-service" in oauth_handlers
assert oauth_handlers["custom-oauth-service"] == CustomServiceOAuthHandler
def test_webhook_manager_registration(self):
"""Test that custom webhook managers are registered."""
from backend.sdk.auto_registry import get_registry
registry = get_registry()
webhook_managers = registry.get_webhook_managers_dict()
# Check if our custom webhook manager is registered
assert "custom-webhook-service" in webhook_managers
assert webhook_managers["custom-webhook-service"] == CustomWebhookManager
def test_oauth_block_execution(self):
"""Test OAuth-enabled block execution."""
block = CustomOAuthBlock()
execute_block_test(block)
def test_webhook_block_execution(self):
"""Test webhook-enabled block execution."""
block = CustomWebhookBlock()
execute_block_test(block)
def test_full_service_block_execution(self):
"""Test full-featured block execution."""
block = CustomFullServiceBlock()
execute_block_test(block)
def test_multiple_decorators_on_same_provider(self):
"""Test that a single provider can have multiple features."""
from backend.sdk.auto_registry import get_registry
registry = get_registry()
# Check OAuth handler
oauth_handlers = registry.get_oauth_handlers_dict()
assert "custom-full-service" in oauth_handlers
# Check webhook manager
webhook_managers = registry.get_webhook_managers_dict()
assert "custom-full-service" in webhook_managers
# Check default credentials
default_creds = registry.get_default_credentials_list()
full_service_creds = [
cred for cred in default_creds if cred.provider == "custom-full-service"
]
assert len(full_service_creds) >= 1
# Check cost config
block_costs = registry.get_block_costs_dict()
assert CustomFullServiceBlock in block_costs
# Main test function
def test_custom_provider_advanced_functionality():
"""Run all advanced custom provider tests."""
test_instance = TestCustomProviderAdvanced()
test_instance.test_oauth_handler_registration()
test_instance.test_webhook_manager_registration()
test_instance.test_oauth_block_execution()
test_instance.test_webhook_block_execution()
test_instance.test_full_service_block_execution()
test_instance.test_multiple_decorators_on_same_provider()

View File

@@ -1,35 +0,0 @@
"""Test that custom providers work with validation."""
from pydantic import BaseModel
from backend.integrations.providers import ProviderName
class TestModel(BaseModel):
provider: ProviderName
def test_custom_provider_validation():
"""Test that custom provider names are accepted."""
# Test with existing provider
model1 = TestModel(provider=ProviderName("openai"))
assert model1.provider == ProviderName.OPENAI
assert model1.provider.value == "openai"
# Test with custom provider
model2 = TestModel(provider=ProviderName("my-custom-provider"))
assert model2.provider.value == "my-custom-provider"
# Test JSON schema
schema = TestModel.model_json_schema()
provider_schema = schema["properties"]["provider"]
# Should not have enum constraint
assert "enum" not in provider_schema
assert provider_schema["type"] == "string"
print("✅ Custom provider validation works!")
if __name__ == "__main__":
test_custom_provider_validation()

View File

@@ -0,0 +1,442 @@
"""
Tests for creating blocks using the SDK.
This test suite verifies that blocks can be created using only SDK imports
and that they work correctly without decorators.
"""
import pytest
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCostType,
BlockOutput,
BlockSchema,
Boolean,
CredentialsField,
CredentialsMetaInput,
Integer,
Optional,
ProviderBuilder,
SchemaField,
SecretStr,
String,
)
class TestBasicBlockCreation:
"""Test creating basic blocks using the SDK."""
def test_simple_block(self):
"""Test creating a simple block without any decorators."""
class SimpleBlock(Block):
"""A simple test block."""
class Input(BlockSchema):
text: String = SchemaField(description="Input text")
count: Integer = SchemaField(description="Repeat count", default=1)
class Output(BlockSchema):
result: String = SchemaField(description="Output result")
def __init__(self):
super().__init__(
id="simple-test-block",
description="A simple test block",
categories={BlockCategory.TEXT},
input_schema=SimpleBlock.Input,
output_schema=SimpleBlock.Output,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
result = input_data.text * input_data.count
yield "result", result
# Create and test the block
block = SimpleBlock()
assert block.id == "simple-test-block"
assert BlockCategory.TEXT in block.categories
# Test execution
outputs = list(
block.run(
SimpleBlock.Input(text="Hello ", count=3),
)
)
assert len(outputs) == 1
assert outputs[0] == ("result", "Hello Hello Hello ")
def test_block_with_credentials(self):
"""Test creating a block that requires credentials."""
class APIBlock(Block):
"""A block that requires API credentials."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="test_api",
supported_credential_types={"api_key"},
description="API credentials",
)
query: String = SchemaField(description="API query")
class Output(BlockSchema):
response: String = SchemaField(description="API response")
authenticated: Boolean = SchemaField(description="Was authenticated")
def __init__(self):
super().__init__(
id="api-test-block",
description="Test block with API credentials",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=APIBlock.Input,
output_schema=APIBlock.Output,
)
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Simulate API call
api_key = credentials.api_key.get_secret_value()
authenticated = bool(api_key)
yield "response", f"API response for: {input_data.query}"
yield "authenticated", authenticated
# Create test credentials
test_creds = APIKeyCredentials(
id="test-creds",
provider="test_api",
api_key=SecretStr("test-api-key"),
title="Test API Key",
)
# Create and test the block
block = APIBlock()
outputs = list(
block.run(
APIBlock.Input(
credentials={ # type: ignore
"provider": "test_api",
"id": "test-creds",
"type": "api_key",
},
query="test query",
),
credentials=test_creds,
)
)
assert len(outputs) == 2
assert outputs[0] == ("response", "API response for: test query")
assert outputs[1] == ("authenticated", True)
def test_block_with_multiple_outputs(self):
"""Test block that yields multiple outputs."""
class MultiOutputBlock(Block):
"""Block with multiple outputs."""
class Input(BlockSchema):
text: String = SchemaField(description="Input text")
class Output(BlockSchema):
uppercase: String = SchemaField(description="Uppercase version")
lowercase: String = SchemaField(description="Lowercase version")
length: Integer = SchemaField(description="Text length")
is_empty: Boolean = SchemaField(description="Is text empty")
def __init__(self):
super().__init__(
id="multi-output-block",
description="Block with multiple outputs",
categories={BlockCategory.TEXT},
input_schema=MultiOutputBlock.Input,
output_schema=MultiOutputBlock.Output,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
text = input_data.text
yield "uppercase", text.upper()
yield "lowercase", text.lower()
yield "length", len(text)
yield "is_empty", len(text) == 0
# Test the block
block = MultiOutputBlock()
outputs = list(block.run(MultiOutputBlock.Input(text="Hello World")))
assert len(outputs) == 4
assert ("uppercase", "HELLO WORLD") in outputs
assert ("lowercase", "hello world") in outputs
assert ("length", 11) in outputs
assert ("is_empty", False) in outputs
class TestBlockWithProvider:
"""Test creating blocks associated with providers."""
def setup_method(self):
"""Set up test provider."""
# Create a provider using ProviderBuilder
self.provider = (
ProviderBuilder("test_service")
.with_api_key("TEST_SERVICE_API_KEY", "Test Service API Key")
.with_base_cost(10, BlockCostType.RUN)
.build()
)
def test_block_using_provider(self):
"""Test block that uses a registered provider."""
class TestServiceBlock(Block):
"""Block for test service."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="test_service", # Matches our provider
supported_credential_types={"api_key"},
description="Test service credentials",
)
action: String = SchemaField(description="Action to perform")
class Output(BlockSchema):
result: String = SchemaField(description="Action result")
provider_name: String = SchemaField(description="Provider used")
def __init__(self):
super().__init__(
id="test-service-block",
description="Block using test service provider",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=TestServiceBlock.Input,
output_schema=TestServiceBlock.Output,
)
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# The provider name should match
yield "result", f"Performed: {input_data.action}"
yield "provider_name", credentials.provider
# Create credentials for our provider
creds = APIKeyCredentials(
id="test-service-creds",
provider="test_service",
api_key=SecretStr("test-key"),
title="Test Service Key",
)
# Test the block
block = TestServiceBlock()
outputs = dict(
block.run(
TestServiceBlock.Input(
credentials={ # type: ignore
"provider": "test_service",
"id": "test-service-creds",
"type": "api_key",
},
action="test action",
),
credentials=creds,
)
)
assert outputs["result"] == "Performed: test action"
assert outputs["provider_name"] == "test_service"
class TestComplexBlockScenarios:
"""Test more complex block scenarios."""
def test_block_with_optional_fields(self):
"""Test block with optional input fields."""
# Optional is already imported at the module level
class OptionalFieldBlock(Block):
"""Block with optional fields."""
class Input(BlockSchema):
required_field: String = SchemaField(description="Required field")
optional_field: Optional[String] = SchemaField(
description="Optional field",
default=None,
)
optional_with_default: String = SchemaField(
description="Optional with default",
default="default value",
)
class Output(BlockSchema):
has_optional: Boolean = SchemaField(description="Has optional value")
optional_value: Optional[String] = SchemaField(
description="Optional value"
)
default_value: String = SchemaField(description="Default value")
def __init__(self):
super().__init__(
id="optional-field-block",
description="Block with optional fields",
categories={BlockCategory.TEXT},
input_schema=OptionalFieldBlock.Input,
output_schema=OptionalFieldBlock.Output,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "has_optional", input_data.optional_field is not None
yield "optional_value", input_data.optional_field
yield "default_value", input_data.optional_with_default
# Test with optional field provided
block = OptionalFieldBlock()
outputs = dict(
block.run(
OptionalFieldBlock.Input(
required_field="test",
optional_field="provided",
)
)
)
assert outputs["has_optional"] is True
assert outputs["optional_value"] == "provided"
assert outputs["default_value"] == "default value"
# Test without optional field
outputs = dict(
block.run(
OptionalFieldBlock.Input(
required_field="test",
)
)
)
assert outputs["has_optional"] is False
assert outputs["optional_value"] is None
assert outputs["default_value"] == "default value"
def test_block_with_complex_types(self):
"""Test block with complex input/output types."""
from backend.sdk import BaseModel, Dict, List
class ItemModel(BaseModel):
name: str
value: int
class ComplexBlock(Block):
"""Block with complex types."""
class Input(BlockSchema):
items: List[String] = SchemaField(description="List of items")
mapping: Dict[String, Integer] = SchemaField(
description="String to int mapping"
)
class Output(BlockSchema):
item_count: Integer = SchemaField(description="Number of items")
total_value: Integer = SchemaField(description="Sum of mapping values")
combined: List[String] = SchemaField(description="Combined results")
def __init__(self):
super().__init__(
id="complex-types-block",
description="Block with complex types",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=ComplexBlock.Input,
output_schema=ComplexBlock.Output,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "item_count", len(input_data.items)
yield "total_value", sum(input_data.mapping.values())
# Combine items with their mapping values
combined = []
for item in input_data.items:
value = input_data.mapping.get(item, 0)
combined.append(f"{item}: {value}")
yield "combined", combined
# Test the block
block = ComplexBlock()
outputs = dict(
block.run(
ComplexBlock.Input(
items=["apple", "banana", "orange"],
mapping={"apple": 5, "banana": 3, "orange": 4},
)
)
)
assert outputs["item_count"] == 3
assert outputs["total_value"] == 12
assert outputs["combined"] == ["apple: 5", "banana: 3", "orange: 4"]
def test_block_error_handling(self):
"""Test block error handling."""
class ErrorHandlingBlock(Block):
"""Block that demonstrates error handling."""
class Input(BlockSchema):
value: Integer = SchemaField(description="Input value")
should_error: Boolean = SchemaField(
description="Whether to trigger an error",
default=False,
)
class Output(BlockSchema):
result: Integer = SchemaField(description="Result")
error_message: Optional[String] = SchemaField(
description="Error if any", default=None
)
def __init__(self):
super().__init__(
id="error-handling-block",
description="Block with error handling",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=ErrorHandlingBlock.Input,
output_schema=ErrorHandlingBlock.Output,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
if input_data.should_error:
raise ValueError("Intentional error triggered")
if input_data.value < 0:
yield "error_message", "Value must be non-negative"
yield "result", 0
else:
yield "result", input_data.value * 2
yield "error_message", None
# Test normal operation
block = ErrorHandlingBlock()
outputs = dict(block.run(ErrorHandlingBlock.Input(value=5, should_error=False)))
assert outputs["result"] == 10
assert outputs["error_message"] is None
# Test with negative value
outputs = dict(
block.run(ErrorHandlingBlock.Input(value=-5, should_error=False))
)
assert outputs["result"] == 0
assert outputs["error_message"] == "Value must be non-negative"
# Test with error
with pytest.raises(ValueError, match="Intentional error triggered"):
list(block.run(ErrorHandlingBlock.Input(value=5, should_error=True)))
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,483 +0,0 @@
"""
Comprehensive test suite for the AutoGPT SDK implementation.
Tests all aspects of the SDK including imports, decorators, and auto-registration.
"""
import sys
from pathlib import Path
# Add backend to path
backend_path = Path(__file__).parent.parent.parent
sys.path.insert(0, str(backend_path))
class TestSDKImplementation:
"""Comprehensive SDK tests"""
def test_sdk_imports_all_components(self):
"""Test that all expected components are available from backend.sdk import *"""
# Import SDK
import backend.sdk as sdk
# Core block components
assert hasattr(sdk, "Block")
assert hasattr(sdk, "BlockCategory")
assert hasattr(sdk, "BlockOutput")
assert hasattr(sdk, "BlockSchema")
assert hasattr(sdk, "BlockType")
assert hasattr(sdk, "SchemaField")
# Credential components
assert hasattr(sdk, "CredentialsField")
assert hasattr(sdk, "CredentialsMetaInput")
assert hasattr(sdk, "APIKeyCredentials")
assert hasattr(sdk, "OAuth2Credentials")
assert hasattr(sdk, "UserPasswordCredentials")
# Cost components
assert hasattr(sdk, "BlockCost")
assert hasattr(sdk, "BlockCostType")
assert hasattr(sdk, "NodeExecutionStats")
# Provider component
assert hasattr(sdk, "ProviderName")
# Type aliases
assert sdk.String is str
assert sdk.Integer is int
assert sdk.Float is float
assert sdk.Boolean is bool
# Decorators
assert hasattr(sdk, "provider")
assert hasattr(sdk, "cost_config")
assert hasattr(sdk, "default_credentials")
assert hasattr(sdk, "webhook_config")
assert hasattr(sdk, "oauth_config")
# Common types
assert hasattr(sdk, "List")
assert hasattr(sdk, "Dict")
assert hasattr(sdk, "Optional")
assert hasattr(sdk, "Any")
assert hasattr(sdk, "Union")
assert hasattr(sdk, "BaseModel")
assert hasattr(sdk, "SecretStr")
assert hasattr(sdk, "Enum")
# Utilities
assert hasattr(sdk, "json")
assert hasattr(sdk, "logging")
print("✅ All SDK imports verified")
def test_auto_registry_system(self):
"""Test the auto-registration system"""
from backend.sdk import APIKeyCredentials, BlockCost, BlockCostType, SecretStr
from backend.sdk.auto_registry import AutoRegistry, get_registry
# Get registry instance
registry = get_registry()
assert isinstance(registry, AutoRegistry)
# Test provider registration
initial_providers = len(registry.providers)
registry.register_provider("test-provider-123")
assert "test-provider-123" in registry.providers
assert len(registry.providers) == initial_providers + 1
# Test cost registration
class TestBlock:
pass
test_costs = [
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=2, cost_type=BlockCostType.BYTE),
]
registry.register_block_cost(TestBlock, test_costs)
assert TestBlock in registry.block_costs
assert len(registry.block_costs[TestBlock]) == 2
assert registry.block_costs[TestBlock][0].cost_amount == 10
# Test credential registration
test_cred = APIKeyCredentials(
id="test-cred-123",
provider="test-provider-123",
api_key=SecretStr("test-api-key"),
title="Test Credential",
)
registry.register_default_credential(test_cred)
# Check credential was added
creds = registry.get_default_credentials_list()
assert any(c.id == "test-cred-123" for c in creds)
# Test duplicate prevention
initial_cred_count = len(registry.default_credentials)
registry.register_default_credential(test_cred) # Add again
assert (
len(registry.default_credentials) == initial_cred_count
) # Should not increase
print("✅ Auto-registry system verified")
def test_decorators_functionality(self):
"""Test that all decorators work correctly"""
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
SchemaField,
SecretStr,
String,
cost_config,
default_credentials,
oauth_config,
provider,
webhook_config,
)
from backend.sdk.auto_registry import get_registry
registry = get_registry()
# Clear registry state for clean test
# initial_provider_count = len(registry.providers)
# Test combined decorators on a block
@provider("test-service-xyz")
@cost_config(
BlockCost(cost_amount=15, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=3, cost_type=BlockCostType.SECOND),
)
@default_credentials(
APIKeyCredentials(
id="test-service-xyz-default",
provider="test-service-xyz",
api_key=SecretStr("default-test-key"),
title="Test Service Default Key",
)
)
class TestServiceBlock(Block):
class Input(BlockSchema):
text: String = SchemaField(description="Test input")
class Output(BlockSchema):
result: String = SchemaField(description="Test output")
def __init__(self):
super().__init__(
id="f0421f19-53da-4824-97cc-4d2bccd1399f",
description="Test service block",
categories={BlockCategory.TEXT},
input_schema=TestServiceBlock.Input,
output_schema=TestServiceBlock.Output,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "result", f"Processed: {input_data.text}"
# Verify decorators worked
assert "test-service-xyz" in registry.providers
assert TestServiceBlock in registry.block_costs
assert len(registry.block_costs[TestServiceBlock]) == 2
# Check credentials
creds = registry.get_default_credentials_list()
assert any(c.id == "test-service-xyz-default" for c in creds)
# Test webhook decorator (mock classes for testing)
class MockWebhookManager:
pass
@webhook_config("test-webhook-provider", MockWebhookManager)
class TestWebhookBlock:
pass
assert "test-webhook-provider" in registry.webhook_managers
assert registry.webhook_managers["test-webhook-provider"] == MockWebhookManager
# Test oauth decorator
class MockOAuthHandler:
pass
@oauth_config("test-oauth-provider", MockOAuthHandler)
class TestOAuthBlock:
pass
assert "test-oauth-provider" in registry.oauth_handlers
assert registry.oauth_handlers["test-oauth-provider"] == MockOAuthHandler
print("✅ All decorators verified")
def test_provider_enum_dynamic_support(self):
"""Test that ProviderName enum supports dynamic providers"""
from backend.sdk import ProviderName
# Test existing provider
existing = ProviderName.GITHUB
assert existing.value == "github"
assert isinstance(existing, ProviderName)
# Test dynamic provider
dynamic = ProviderName("my-custom-provider-abc")
assert dynamic.value == "my-custom-provider-abc"
assert isinstance(dynamic, ProviderName)
assert dynamic._name_ == "MY-CUSTOM-PROVIDER-ABC"
# Test that same dynamic provider returns same instance
dynamic2 = ProviderName("my-custom-provider-abc")
assert dynamic.value == dynamic2.value
# Test invalid input
try:
ProviderName(123) # Should not work with non-string
assert False, "Should have failed with non-string"
except ValueError:
pass # Expected
print("✅ Dynamic provider enum verified")
def test_complete_block_example(self):
"""Test a complete block using all SDK features"""
# This simulates what a block developer would write
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
CredentialsField,
CredentialsMetaInput,
Float,
SchemaField,
SecretStr,
String,
cost_config,
default_credentials,
provider,
)
@provider("ai-translator-service")
@cost_config(
BlockCost(cost_amount=5, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=1, cost_type=BlockCostType.BYTE),
)
@default_credentials(
APIKeyCredentials(
id="ai-translator-default",
provider="ai-translator-service",
api_key=SecretStr("translator-default-key"),
title="AI Translator Default API Key",
)
)
class AITranslatorBlock(Block):
"""AI-powered translation block using the SDK"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="ai-translator-service",
supported_credential_types={"api_key"},
description="API credentials for AI Translator",
)
text: String = SchemaField(
description="Text to translate", default="Hello, world!"
)
target_language: String = SchemaField(
description="Target language code", default="es"
)
class Output(BlockSchema):
translated_text: String = SchemaField(description="Translated text")
source_language: String = SchemaField(
description="Detected source language"
)
confidence: Float = SchemaField(
description="Translation confidence score"
)
error: String = SchemaField(
description="Error message if any", default=""
)
def __init__(self):
super().__init__(
id="dc832afe-902a-4520-8512-d3b85428d4ec",
description="Translate text using AI Translator Service",
categories={BlockCategory.TEXT, BlockCategory.AI},
input_schema=AITranslatorBlock.Input,
output_schema=AITranslatorBlock.Output,
test_input={"text": "Hello, world!", "target_language": "es"},
test_output=[
("translated_text", "¡Hola, mundo!"),
("source_language", "en"),
("confidence", 0.95),
],
)
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
# Simulate translation
credentials.api_key.get_secret_value() # Verify we can access the key
# Mock translation logic
translations = {
("Hello, world!", "es"): "¡Hola, mundo!",
("Hello, world!", "fr"): "Bonjour le monde!",
("Hello, world!", "de"): "Hallo Welt!",
}
key = (input_data.text, input_data.target_language)
translated = translations.get(
key, f"[{input_data.target_language}] {input_data.text}"
)
yield "translated_text", translated
yield "source_language", "en"
yield "confidence", 0.95
yield "error", ""
except Exception as e:
yield "translated_text", ""
yield "source_language", ""
yield "confidence", 0.0
yield "error", str(e)
# Verify the block was created correctly
block = AITranslatorBlock()
assert block.id == "dc832afe-902a-4520-8512-d3b85428d4ec"
assert block.description == "Translate text using AI Translator Service"
assert BlockCategory.TEXT in block.categories
assert BlockCategory.AI in block.categories
# Verify decorators registered everything
from backend.sdk.auto_registry import get_registry
registry = get_registry()
assert "ai-translator-service" in registry.providers
assert AITranslatorBlock in registry.block_costs
assert len(registry.block_costs[AITranslatorBlock]) == 2
creds = registry.get_default_credentials_list()
assert any(c.id == "ai-translator-default" for c in creds)
print("✅ Complete block example verified")
def test_backward_compatibility(self):
"""Test that old-style imports still work"""
# Test that we can still import from original locations
try:
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
)
from backend.data.model import SchemaField
assert Block is not None
assert BlockCategory is not None
assert BlockOutput is not None
assert BlockSchema is not None
assert SchemaField is not None
print("✅ Backward compatibility verified")
except ImportError as e:
print(f"❌ Backward compatibility issue: {e}")
raise
def test_auto_registration_patching(self):
"""Test that auto-registration correctly patches existing systems"""
from backend.sdk.auto_registry import patch_existing_systems
# This would normally be called during app startup
# For testing, we'll verify the patching logic works
try:
patch_existing_systems()
print("✅ Auto-registration patching verified")
except Exception as e:
print(f"⚠️ Patching had issues (expected in test environment): {e}")
# This is expected in test environment where not all systems are loaded
def test_import_star_works(self):
"""Test that 'from backend.sdk import *' actually works"""
# Create a temporary module to test import *
test_code = """
from backend.sdk import *
# Test that common items are available
assert Block is not None
assert BlockSchema is not None
assert SchemaField is not None
assert String == str
assert provider is not None
assert cost_config is not None
print("Import * works correctly")
"""
# Execute in a clean namespace
namespace = {"__name__": "__main__"}
try:
exec(test_code, namespace)
print("✅ Import * functionality verified")
except Exception as e:
print(f"❌ Import * failed: {e}")
raise
def run_all_tests():
"""Run all SDK tests"""
print("\n" + "=" * 60)
print("🧪 Running Comprehensive SDK Tests")
print("=" * 60 + "\n")
test_suite = TestSDKImplementation()
tests = [
("SDK Imports", test_suite.test_sdk_imports_all_components),
("Auto-Registry System", test_suite.test_auto_registry_system),
("Decorators", test_suite.test_decorators_functionality),
("Dynamic Provider Enum", test_suite.test_provider_enum_dynamic_support),
("Complete Block Example", test_suite.test_complete_block_example),
("Backward Compatibility", test_suite.test_backward_compatibility),
("Auto-Registration Patching", test_suite.test_auto_registration_patching),
("Import * Syntax", test_suite.test_import_star_works),
]
passed = 0
failed = 0
for test_name, test_func in tests:
print(f"\n📋 Testing: {test_name}")
print("-" * 40)
try:
test_func()
passed += 1
except Exception as e:
print(f"❌ Test failed: {e}")
failed += 1
import traceback
traceback.print_exc()
print("\n" + "=" * 60)
print(f"📊 Test Results: {passed} passed, {failed} failed")
print("=" * 60)
if failed == 0:
print("\n🎉 All SDK tests passed! The implementation is working correctly.")
else:
print(f"\n⚠️ {failed} tests failed. Please review the errors above.")
return failed == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)

View File

@@ -1,172 +0,0 @@
"""
Test the SDK import system and auto-registration
"""
import sys
from pathlib import Path
# Add backend to path
backend_path = Path(__file__).parent.parent.parent
sys.path.insert(0, str(backend_path))
def test_sdk_imports():
"""Test that all expected imports are available from backend.sdk"""
# Import the module and check its contents
import backend.sdk as sdk
# Core block components should be available
assert hasattr(sdk, "Block")
assert hasattr(sdk, "BlockCategory")
assert hasattr(sdk, "BlockOutput")
assert hasattr(sdk, "BlockSchema")
assert hasattr(sdk, "SchemaField")
# Credential types should be available
assert hasattr(sdk, "CredentialsField")
assert hasattr(sdk, "CredentialsMetaInput")
assert hasattr(sdk, "APIKeyCredentials")
assert hasattr(sdk, "OAuth2Credentials")
# Cost system should be available
assert hasattr(sdk, "BlockCost")
assert hasattr(sdk, "BlockCostType")
# Providers should be available
assert hasattr(sdk, "ProviderName")
# Type aliases should work
assert sdk.String is str
assert sdk.Integer is int
assert sdk.Float is float
assert sdk.Boolean is bool
# Decorators should be available
assert hasattr(sdk, "provider")
assert hasattr(sdk, "cost_config")
assert hasattr(sdk, "default_credentials")
assert hasattr(sdk, "webhook_config")
assert hasattr(sdk, "oauth_config")
# Common types should be available
assert hasattr(sdk, "List")
assert hasattr(sdk, "Dict")
assert hasattr(sdk, "Optional")
assert hasattr(sdk, "Any")
assert hasattr(sdk, "Union")
assert hasattr(sdk, "BaseModel")
assert hasattr(sdk, "SecretStr")
# Utilities should be available
assert hasattr(sdk, "json")
assert hasattr(sdk, "logging")
def test_auto_registry():
"""Test the auto-registration system"""
from backend.sdk import APIKeyCredentials, BlockCost, BlockCostType, SecretStr
from backend.sdk.auto_registry import AutoRegistry, get_registry
# Get the registry
registry = get_registry()
assert isinstance(registry, AutoRegistry)
# Test registering a provider
registry.register_provider("test-provider")
assert "test-provider" in registry.providers
# Test registering block costs
test_costs = [BlockCost(cost_amount=5, cost_type=BlockCostType.RUN)]
class TestBlock:
pass
registry.register_block_cost(TestBlock, test_costs)
assert TestBlock in registry.block_costs
assert registry.block_costs[TestBlock] == test_costs
# Test registering credentials
test_cred = APIKeyCredentials(
id="test-cred",
provider="test-provider",
api_key=SecretStr("test-key"),
title="Test Credential",
)
registry.register_default_credential(test_cred)
assert test_cred in registry.default_credentials
def test_decorators():
"""Test that decorators work correctly"""
from backend.sdk import (
APIKeyCredentials,
BlockCost,
BlockCostType,
SecretStr,
cost_config,
default_credentials,
provider,
)
from backend.sdk.auto_registry import get_registry
# Clear registry for test
registry = get_registry()
# Test provider decorator
@provider("decorator-test")
class DecoratorTestBlock:
pass
assert "decorator-test" in registry.providers
# Test cost_config decorator
@cost_config(BlockCost(cost_amount=10, cost_type=BlockCostType.RUN))
class CostTestBlock:
pass
assert CostTestBlock in registry.block_costs
assert len(registry.block_costs[CostTestBlock]) == 1
assert registry.block_costs[CostTestBlock][0].cost_amount == 10
# Test default_credentials decorator
@default_credentials(
APIKeyCredentials(
id="decorator-test-cred",
provider="decorator-test",
api_key=SecretStr("test-api-key"),
title="Decorator Test Credential",
)
)
class CredTestBlock:
pass
# Check if credential was registered
creds = registry.get_default_credentials_list()
assert any(c.id == "decorator-test-cred" for c in creds)
def test_example_block_imports():
"""Test that example blocks can use SDK imports"""
# Skip this test since example blocks were moved to examples directory
# to avoid interfering with main tests
pass
if __name__ == "__main__":
# Run tests
test_sdk_imports()
print("✅ SDK imports test passed")
test_auto_registry()
print("✅ Auto-registry test passed")
test_decorators()
print("✅ Decorators test passed")
test_example_block_imports()
print("✅ Example block test passed")
print("\n🎉 All SDK tests passed!")

View File

@@ -1,442 +0,0 @@
"""
Integration test demonstrating the complete SDK workflow.
This shows how a developer would create a new block with zero external configuration.
"""
import sys
from enum import Enum
from pathlib import Path
# Add backend to path
backend_path = Path(__file__).parent.parent.parent
sys.path.insert(0, str(backend_path))
# ruff: noqa: E402
# Import SDK at module level for testing (after sys.path modification)
from backend.integrations.providers import ProviderName
from backend.sdk import (
APIKeyCredentials,
BaseWebhooksManager,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
BlockType,
BlockWebhookConfig,
CredentialsField,
CredentialsMetaInput,
Dict,
Float,
Integer,
List,
SchemaField,
SecretStr,
String,
cost_config,
default_credentials,
provider,
webhook_config,
)
def test_complete_sdk_workflow():
"""
Demonstrate the complete workflow of creating a new block with the SDK.
This test shows:
1. Single import statement
2. Custom provider registration
3. Cost configuration
4. Default credentials
5. Zero external configuration needed
"""
print("\n" + "=" * 60)
print("🚀 SDK Integration Test - Complete Workflow")
print("=" * 60 + "\n")
# Step 1: Import everything needed with a single statement
print("Step 1: Import SDK")
# SDK already imported at module level
print("✅ Imported all components with 'from backend.sdk import *'")
# Step 2: Create a custom AI service block
print("\nStep 2: Create a custom AI service block")
@provider("custom-ai-vision-service")
@cost_config(
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=5, cost_type=BlockCostType.BYTE),
)
@default_credentials(
APIKeyCredentials(
id="custom-ai-vision-default",
provider="custom-ai-vision-service",
api_key=SecretStr("vision-service-default-api-key"),
title="Custom AI Vision Service Default API Key",
expires_at=None,
)
)
class CustomAIVisionBlock(Block):
"""
Custom AI Vision Analysis Block
This block demonstrates:
- Custom provider name (not in the original enum)
- Automatic cost registration
- Default credentials setup
- Complex input/output schemas
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="custom-ai-vision-service",
supported_credential_types={"api_key"},
description="API credentials for Custom AI Vision Service",
)
image_url: String = SchemaField(
description="URL of the image to analyze",
placeholder="https://example.com/image.jpg",
)
analysis_type: String = SchemaField(
description="Type of analysis to perform",
default="general",
)
confidence_threshold: Float = SchemaField(
description="Minimum confidence threshold for detections",
default=0.7,
ge=0.0,
le=1.0,
)
max_results: Integer = SchemaField(
description="Maximum number of results to return",
default=10,
ge=1,
le=100,
)
class Output(BlockSchema):
detections: List[Dict] = SchemaField(
description="List of detected items with confidence scores", default=[]
)
analysis_type: String = SchemaField(
description="Type of analysis performed"
)
processing_time: Float = SchemaField(
description="Time taken to process the image in seconds"
)
total_detections: Integer = SchemaField(
description="Total number of detections found"
)
error: String = SchemaField(
description="Error message if analysis failed", default=""
)
def __init__(self):
super().__init__(
id="303d9bd3-f2a5-41ca-bb9c-e347af8ef72f",
description="Analyze images using Custom AI Vision Service with configurable detection types",
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
input_schema=CustomAIVisionBlock.Input,
output_schema=CustomAIVisionBlock.Output,
test_input={
"image_url": "https://example.com/test-image.jpg",
"analysis_type": "objects",
"confidence_threshold": 0.8,
"max_results": 5,
},
test_output=[
(
"detections",
[
{"object": "car", "confidence": 0.95},
{"object": "person", "confidence": 0.87},
],
),
("analysis_type", "objects"),
("processing_time", 1.23),
("total_detections", 2),
("error", ""),
],
static_output=False,
)
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
import time
start_time = time.time()
try:
# Get API key
api_key = credentials.api_key.get_secret_value()
# Simulate API call to vision service
print(f" - Using API key: {api_key[:10]}...")
print(f" - Analyzing image: {input_data.image_url}")
print(f" - Analysis type: {input_data.analysis_type}")
# Mock detection results based on analysis type
mock_results = {
"general": [
{"category": "indoor", "confidence": 0.92},
{"category": "office", "confidence": 0.88},
],
"faces": [
{"face_id": 1, "confidence": 0.95, "age": "25-35"},
{"face_id": 2, "confidence": 0.91, "age": "40-50"},
],
"objects": [
{"object": "laptop", "confidence": 0.94},
{"object": "coffee_cup", "confidence": 0.89},
{"object": "notebook", "confidence": 0.85},
],
"text": [
{"text": "Hello World", "confidence": 0.97},
{"text": "SDK Demo", "confidence": 0.93},
],
"scene": [
{"scene": "office_workspace", "confidence": 0.91},
{"scene": "indoor_lighting", "confidence": 0.87},
],
}
# Get results for the requested analysis type
detections = mock_results.get(
input_data.analysis_type,
[{"error": "Unknown analysis type", "confidence": 0.0}],
)
# Filter by confidence threshold
filtered_detections = [
d
for d in detections
if d.get("confidence", 0) >= input_data.confidence_threshold
]
# Limit results
final_detections = filtered_detections[: input_data.max_results]
# Calculate processing time
processing_time = time.time() - start_time
# Yield results
yield "detections", final_detections
yield "analysis_type", input_data.analysis_type
yield "processing_time", round(processing_time, 3)
yield "total_detections", len(final_detections)
yield "error", ""
except Exception as e:
yield "detections", []
yield "analysis_type", input_data.analysis_type
yield "processing_time", time.time() - start_time
yield "total_detections", 0
yield "error", str(e)
print("✅ Block class created with all decorators")
# Step 3: Verify auto-registration worked
print("\nStep 3: Verify auto-registration")
from backend.sdk.auto_registry import get_registry
registry = get_registry()
# Check provider registration
assert "custom-ai-vision-service" in registry.providers
print("✅ Custom provider 'custom-ai-vision-service' auto-registered")
# Check cost registration
assert CustomAIVisionBlock in registry.block_costs
costs = registry.block_costs[CustomAIVisionBlock]
assert len(costs) == 2
assert costs[0].cost_amount == 10
assert costs[0].cost_type == BlockCostType.RUN
print("✅ Block costs auto-registered (10 credits per run, 5 per byte)")
# Check credential registration
creds = registry.get_default_credentials_list()
vision_cred = next((c for c in creds if c.id == "custom-ai-vision-default"), None)
assert vision_cred is not None
assert vision_cred.provider == "custom-ai-vision-service"
print("✅ Default credentials auto-registered")
# Step 4: Test dynamic provider enum
print("\nStep 4: Test dynamic provider support")
provider_instance = ProviderName("custom-ai-vision-service")
assert provider_instance.value == "custom-ai-vision-service"
assert isinstance(provider_instance, ProviderName)
print("✅ ProviderName enum accepts custom provider dynamically")
# Step 5: Instantiate and test the block
print("\nStep 5: Test block instantiation and execution")
block = CustomAIVisionBlock()
# Verify block properties
assert block.id == "303d9bd3-f2a5-41ca-bb9c-e347af8ef72f"
assert BlockCategory.AI in block.categories
assert BlockCategory.MULTIMEDIA in block.categories
print("✅ Block instantiated successfully")
# Test block execution
test_credentials = APIKeyCredentials(
id="test-cred",
provider="custom-ai-vision-service",
api_key=SecretStr("test-api-key-12345"),
title="Test API Key",
)
test_input = CustomAIVisionBlock.Input(
credentials={
"provider": "custom-ai-vision-service",
"id": "test",
"type": "api_key",
},
image_url="https://example.com/test.jpg",
analysis_type="objects",
confidence_threshold=0.8,
max_results=3,
)
print("\n Running block with test data...")
results = list(block.run(test_input, credentials=test_credentials))
# Verify outputs
output_dict = {key: value for key, value in results}
assert "detections" in output_dict
assert "analysis_type" in output_dict
assert output_dict["analysis_type"] == "objects"
assert "total_detections" in output_dict
assert output_dict["error"] == ""
print("✅ Block execution successful")
# Step 6: Summary
print("\n" + "=" * 60)
print("🎉 SDK Integration Test Complete!")
print("=" * 60)
print("\nKey achievements demonstrated:")
print("✅ Single import: from backend.sdk import *")
print("✅ Custom provider registered automatically")
print("✅ Costs configured via decorator")
print("✅ Default credentials set via decorator")
print("✅ Block works without ANY external configuration")
print("✅ Dynamic provider name accepted by enum")
print("\nThe SDK successfully enables zero-configuration block development!")
return True
def test_webhook_block_workflow():
"""Test creating a webhook block with the SDK"""
print("\n\n" + "=" * 60)
print("🔔 Webhook Block Integration Test")
print("=" * 60 + "\n")
# SDK already imported at module level
# Create a simple webhook manager
class CustomWebhookManager(BaseWebhooksManager):
PROVIDER_NAME = ProviderName("custom-webhook-service")
class WebhookType(str, Enum):
DATA_UPDATE = "data_update"
STATUS_CHANGE = "status_change"
@classmethod
async def validate_payload(cls, webhook, request) -> tuple[dict, str]:
payload = await request.json()
event_type = request.headers.get("X-Custom-Event", "unknown")
return payload, event_type
async def _register_webhook(
self,
credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
# Mock registration
return "webhook-12345", {"status": "registered"}
async def _deregister_webhook(self, webhook, credentials) -> None:
pass
# Create webhook block
@provider("custom-webhook-service")
@webhook_config("custom-webhook-service", CustomWebhookManager)
class CustomWebhookBlock(Block):
class Input(BlockSchema):
webhook_events: Dict = SchemaField(
description="Events to listen for",
default={"data_update": True, "status_change": False},
)
payload: Dict = SchemaField(
description="Webhook payload", default={}, hidden=True
)
class Output(BlockSchema):
event_type: String = SchemaField(description="Type of event")
event_data: Dict = SchemaField(description="Event data")
timestamp: String = SchemaField(description="Event timestamp")
def __init__(self):
super().__init__(
id="3e730ed4-6eb2-4b89-b5ae-001860c88aef",
description="Listen for custom webhook events",
categories={BlockCategory.INPUT},
input_schema=CustomWebhookBlock.Input,
output_schema=CustomWebhookBlock.Output,
block_type=BlockType.WEBHOOK,
webhook_config=BlockWebhookConfig(
provider=ProviderName("custom-webhook-service"),
webhook_type="data_update",
event_filter_input="webhook_events",
resource_format="{resource}",
),
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
yield "event_type", payload.get("type", "unknown")
yield "event_data", payload
yield "timestamp", payload.get("timestamp", "")
# Verify registration
from backend.sdk.auto_registry import get_registry
registry = get_registry()
assert "custom-webhook-service" in registry.webhook_managers
assert registry.webhook_managers["custom-webhook-service"] == CustomWebhookManager
print("✅ Webhook manager auto-registered")
print("✅ Webhook block created successfully")
return True
if __name__ == "__main__":
try:
# Run main integration test
success1 = test_complete_sdk_workflow()
# Run webhook integration test
success2 = test_webhook_block_workflow()
if success1 and success2:
print("\n\n🌟 All integration tests passed successfully!")
sys.exit(0)
else:
print("\n\n❌ Some integration tests failed")
sys.exit(1)
except Exception as e:
print(f"\n\n❌ Integration test failed with error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -0,0 +1,326 @@
"""
Tests for the SDK's integration patching mechanism.
This test suite verifies that the AutoRegistry correctly patches
existing integration points to include SDK-registered components.
"""
from unittest.mock import MagicMock, Mock, patch
import pytest
from backend.integrations.providers import ProviderName
from backend.sdk import (
AutoRegistry,
BaseOAuthHandler,
BaseWebhooksManager,
ProviderBuilder,
)
class MockOAuthHandler(BaseOAuthHandler):
"""Mock OAuth handler for testing."""
PROVIDER_NAME = ProviderName.GITHUB
@classmethod
async def authorize(cls, *args, **kwargs):
return "mock_auth"
class MockWebhookManager(BaseWebhooksManager):
"""Mock webhook manager for testing."""
PROVIDER_NAME = ProviderName.GITHUB
@classmethod
async def validate_payload(cls, webhook, request):
return {}, "test_event"
async def _register_webhook(self, *args, **kwargs):
return "mock_webhook_id", {}
async def _deregister_webhook(self, *args, **kwargs):
pass
class TestOAuthPatching:
"""Test OAuth handler patching functionality."""
def setup_method(self):
"""Clear registry and set up mocks."""
AutoRegistry.clear()
def test_oauth_handler_dictionary_patching(self):
"""Test that OAuth handlers are correctly patched into HANDLERS_BY_NAME."""
# Create original handlers
original_handlers = {
"existing_provider": Mock(spec=BaseOAuthHandler),
}
# Create a mock oauth module
mock_oauth_module = MagicMock()
mock_oauth_module.HANDLERS_BY_NAME = original_handlers.copy()
# Register a new provider with OAuth
(ProviderBuilder("new_oauth_provider").with_oauth(MockOAuthHandler).build())
# Apply patches
with patch("backend.integrations.oauth", mock_oauth_module):
AutoRegistry.patch_integrations()
patched_dict = mock_oauth_module.HANDLERS_BY_NAME
# Test that original handler still exists
assert "existing_provider" in patched_dict
assert (
patched_dict["existing_provider"]
== original_handlers["existing_provider"]
)
# Test that new handler is accessible
assert "new_oauth_provider" in patched_dict
assert patched_dict["new_oauth_provider"] == MockOAuthHandler
# Test dict methods
assert "existing_provider" in patched_dict.keys()
assert "new_oauth_provider" in patched_dict.keys()
# Test .get() method
assert patched_dict.get("new_oauth_provider") == MockOAuthHandler
assert patched_dict.get("nonexistent", "default") == "default"
# Test __contains__
assert "new_oauth_provider" in patched_dict
assert "nonexistent" not in patched_dict
def test_oauth_patching_with_multiple_providers(self):
"""Test patching with multiple OAuth providers."""
# Create another OAuth handler
class AnotherOAuthHandler(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GOOGLE
# Register multiple providers
(ProviderBuilder("oauth_provider_1").with_oauth(MockOAuthHandler).build())
(ProviderBuilder("oauth_provider_2").with_oauth(AnotherOAuthHandler).build())
# Mock the oauth module
mock_oauth_module = MagicMock()
mock_oauth_module.HANDLERS_BY_NAME = {}
with patch("backend.integrations.oauth", mock_oauth_module):
AutoRegistry.patch_integrations()
patched_dict = mock_oauth_module.HANDLERS_BY_NAME
# Both providers should be accessible
assert patched_dict["oauth_provider_1"] == MockOAuthHandler
assert patched_dict["oauth_provider_2"] == AnotherOAuthHandler
# Check values() method
values = list(patched_dict.values())
assert MockOAuthHandler in values
assert AnotherOAuthHandler in values
# Check items() method
items = dict(patched_dict.items())
assert items["oauth_provider_1"] == MockOAuthHandler
assert items["oauth_provider_2"] == AnotherOAuthHandler
class TestWebhookPatching:
"""Test webhook manager patching functionality."""
def setup_method(self):
"""Clear registry."""
AutoRegistry.clear()
def test_webhook_manager_patching(self):
"""Test that webhook managers are correctly patched."""
# Mock the original load_webhook_managers function
def mock_load_webhook_managers():
return {
"existing_webhook": Mock(spec=BaseWebhooksManager),
}
# Register a provider with webhooks
(
ProviderBuilder("webhook_provider")
.with_webhook_manager(MockWebhookManager)
.build()
)
# Mock the webhooks module
mock_webhooks_module = MagicMock()
mock_webhooks_module.load_webhook_managers = mock_load_webhook_managers
with patch("backend.integrations.webhooks", mock_webhooks_module):
AutoRegistry.patch_integrations()
# Call the patched function
result = mock_webhooks_module.load_webhook_managers()
# Original webhook should still exist
assert "existing_webhook" in result
# New webhook should be added
assert "webhook_provider" in result
assert result["webhook_provider"] == MockWebhookManager
def test_webhook_patching_no_original_function(self):
"""Test webhook patching when load_webhook_managers doesn't exist."""
# Mock webhooks module without load_webhook_managers
mock_webhooks_module = MagicMock(spec=[])
# Register a provider
(
ProviderBuilder("test_provider")
.with_webhook_manager(MockWebhookManager)
.build()
)
with patch("backend.integrations.webhooks", mock_webhooks_module):
# Should not raise an error
AutoRegistry.patch_integrations()
# Function should not be added if it didn't exist
assert not hasattr(mock_webhooks_module, "load_webhook_managers")
class TestPatchingEdgeCases:
"""Test edge cases and error handling in patching."""
def setup_method(self):
"""Clear registry."""
AutoRegistry.clear()
def test_patching_with_import_errors(self):
"""Test that patching handles import errors gracefully."""
# Register a provider
(ProviderBuilder("test_provider").with_oauth(MockOAuthHandler).build())
# Make the oauth module import fail
with patch("builtins.__import__", side_effect=ImportError("Mock import error")):
# Should not raise an error
AutoRegistry.patch_integrations()
def test_patching_with_attribute_errors(self):
"""Test handling of missing attributes."""
# Mock oauth module without HANDLERS_BY_NAME
mock_oauth_module = MagicMock(spec=[])
(ProviderBuilder("test_provider").with_oauth(MockOAuthHandler).build())
with patch("backend.integrations.oauth", mock_oauth_module):
# Should not raise an error
AutoRegistry.patch_integrations()
def test_patching_preserves_thread_safety(self):
"""Test that patching maintains thread safety."""
import threading
import time
results = []
errors = []
def register_provider(name, delay=0):
try:
time.sleep(delay)
(ProviderBuilder(name).with_oauth(MockOAuthHandler).build())
results.append(name)
except Exception as e:
errors.append((name, str(e)))
# Create multiple threads
threads = []
for i in range(5):
t = threading.Thread(
target=register_provider, args=(f"provider_{i}", i * 0.01)
)
threads.append(t)
t.start()
# Wait for all threads
for t in threads:
t.join()
# Check results
assert len(errors) == 0
assert len(results) == 5
assert len(AutoRegistry._providers) == 5
# Verify all providers are registered
for i in range(5):
assert f"provider_{i}" in AutoRegistry._providers
class TestPatchingIntegration:
"""Test the complete patching integration flow."""
def setup_method(self):
"""Clear registry."""
AutoRegistry.clear()
def test_complete_provider_registration_and_patching(self):
"""Test the complete flow from provider registration to patching."""
# Mock both oauth and webhooks modules
mock_oauth = MagicMock()
mock_oauth.HANDLERS_BY_NAME = {"original": Mock()}
mock_webhooks = MagicMock()
mock_webhooks.load_webhook_managers = lambda: {"original": Mock()}
# Create a fully featured provider
(
ProviderBuilder("complete_provider")
.with_api_key("COMPLETE_KEY", "Complete API Key")
.with_oauth(MockOAuthHandler, scopes=["read", "write"])
.with_webhook_manager(MockWebhookManager)
.build()
)
# Apply patches
with patch("backend.integrations.oauth", mock_oauth):
with patch("backend.integrations.webhooks", mock_webhooks):
AutoRegistry.patch_integrations()
# Verify OAuth patching
oauth_dict = mock_oauth.HANDLERS_BY_NAME
assert "complete_provider" in oauth_dict
assert oauth_dict["complete_provider"] == MockOAuthHandler
assert "original" in oauth_dict # Original preserved
# Verify webhook patching
webhook_result = mock_webhooks.load_webhook_managers()
assert "complete_provider" in webhook_result
assert webhook_result["complete_provider"] == MockWebhookManager
assert "original" in webhook_result # Original preserved
def test_patching_is_idempotent(self):
"""Test that calling patch_integrations multiple times is safe."""
mock_oauth = MagicMock()
mock_oauth.HANDLERS_BY_NAME = {}
# Register a provider
(ProviderBuilder("test_provider").with_oauth(MockOAuthHandler).build())
with patch("backend.integrations.oauth", mock_oauth):
# Patch multiple times
AutoRegistry.patch_integrations()
AutoRegistry.patch_integrations()
AutoRegistry.patch_integrations()
# Should still work correctly
oauth_dict = mock_oauth.HANDLERS_BY_NAME
assert "test_provider" in oauth_dict
assert oauth_dict["test_provider"] == MockOAuthHandler
# Should not have duplicates or errors
assert len([k for k in oauth_dict.keys() if k == "test_provider"]) == 1
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,511 @@
"""
Tests for the SDK auto-registration system via AutoRegistry.
This test suite verifies:
1. Provider registration and retrieval
2. OAuth handler registration via patches
3. Webhook manager registration via patches
4. Credential registration and management
5. Block configuration association
"""
from unittest.mock import MagicMock, Mock, patch
import pytest
from backend.integrations.providers import ProviderName
from backend.sdk import (
APIKeyCredentials,
AutoRegistry,
BaseOAuthHandler,
BaseWebhooksManager,
Block,
BlockConfiguration,
Provider,
ProviderBuilder,
)
class TestAutoRegistry:
"""Test the AutoRegistry functionality."""
def setup_method(self):
"""Clear registry before each test."""
AutoRegistry.clear()
def test_provider_registration(self):
"""Test that providers can be registered and retrieved."""
# Create a test provider
provider = Provider(
name="test_provider",
oauth_handler=None,
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"api_key"},
)
# Register it
AutoRegistry.register_provider(provider)
# Verify it's registered
assert "test_provider" in AutoRegistry._providers
assert AutoRegistry.get_provider("test_provider") == provider
def test_provider_with_oauth(self):
"""Test provider registration with OAuth handler."""
# Create a mock OAuth handler
class TestOAuthHandler(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
provider = Provider(
name="oauth_provider",
oauth_handler=TestOAuthHandler,
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"oauth2"},
)
AutoRegistry.register_provider(provider)
# Verify OAuth handler is registered
assert "oauth_provider" in AutoRegistry._oauth_handlers
assert AutoRegistry._oauth_handlers["oauth_provider"] == TestOAuthHandler
def test_provider_with_webhook_manager(self):
"""Test provider registration with webhook manager."""
# Create a mock webhook manager
class TestWebhookManager(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
provider = Provider(
name="webhook_provider",
oauth_handler=None,
webhook_manager=TestWebhookManager,
default_credentials=[],
base_costs=[],
supported_auth_types={"api_key"},
)
AutoRegistry.register_provider(provider)
# Verify webhook manager is registered
assert "webhook_provider" in AutoRegistry._webhook_managers
assert AutoRegistry._webhook_managers["webhook_provider"] == TestWebhookManager
def test_default_credentials_registration(self):
"""Test that default credentials are registered."""
# Create test credentials
from backend.sdk import SecretStr
cred1 = APIKeyCredentials(
id="test-cred-1",
provider="test_provider",
api_key=SecretStr("test-key-1"),
title="Test Credential 1",
)
cred2 = APIKeyCredentials(
id="test-cred-2",
provider="test_provider",
api_key=SecretStr("test-key-2"),
title="Test Credential 2",
)
provider = Provider(
name="test_provider",
oauth_handler=None,
webhook_manager=None,
default_credentials=[cred1, cred2],
base_costs=[],
supported_auth_types={"api_key"},
)
AutoRegistry.register_provider(provider)
# Verify credentials are registered
all_creds = AutoRegistry.get_all_credentials()
assert cred1 in all_creds
assert cred2 in all_creds
def test_api_key_registration(self):
"""Test API key environment variable registration."""
import os
# Set up a test environment variable
os.environ["TEST_API_KEY"] = "test-api-key-value"
try:
AutoRegistry.register_api_key("test_provider", "TEST_API_KEY")
# Verify the mapping is stored
assert AutoRegistry._api_key_mappings["test_provider"] == "TEST_API_KEY"
# Verify a credential was created
all_creds = AutoRegistry.get_all_credentials()
test_cred = next(
(c for c in all_creds if c.id == "test_provider-default"), None
)
assert test_cred is not None
assert test_cred.provider == "test_provider"
assert test_cred.api_key.get_secret_value() == "test-api-key-value" # type: ignore
finally:
# Clean up
del os.environ["TEST_API_KEY"]
def test_get_oauth_handlers(self):
"""Test retrieving all OAuth handlers."""
# Register multiple providers with OAuth
class TestOAuth1(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
class TestOAuth2(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GOOGLE
provider1 = Provider(
name="provider1",
oauth_handler=TestOAuth1,
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"oauth2"},
)
provider2 = Provider(
name="provider2",
oauth_handler=TestOAuth2,
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"oauth2"},
)
AutoRegistry.register_provider(provider1)
AutoRegistry.register_provider(provider2)
handlers = AutoRegistry.get_oauth_handlers()
assert "provider1" in handlers
assert "provider2" in handlers
assert handlers["provider1"] == TestOAuth1
assert handlers["provider2"] == TestOAuth2
def test_block_configuration_registration(self):
"""Test registering block configuration."""
# Create a test block class
class TestBlock(Block):
pass
config = BlockConfiguration(
provider="test_provider",
costs=[],
default_credentials=[],
webhook_manager=None,
oauth_handler=None,
)
AutoRegistry.register_block_configuration(TestBlock, config)
# Verify it's registered
assert TestBlock in AutoRegistry._block_configurations
assert AutoRegistry._block_configurations[TestBlock] == config
def test_clear_registry(self):
"""Test clearing all registrations."""
# Add some registrations
provider = Provider(
name="test_provider",
oauth_handler=None,
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"api_key"},
)
AutoRegistry.register_provider(provider)
AutoRegistry.register_api_key("test", "TEST_KEY")
# Clear everything
AutoRegistry.clear()
# Verify everything is cleared
assert len(AutoRegistry._providers) == 0
assert len(AutoRegistry._default_credentials) == 0
assert len(AutoRegistry._oauth_handlers) == 0
assert len(AutoRegistry._webhook_managers) == 0
assert len(AutoRegistry._block_configurations) == 0
assert len(AutoRegistry._api_key_mappings) == 0
class TestAutoRegistryPatching:
"""Test the integration patching functionality."""
def setup_method(self):
"""Clear registry before each test."""
AutoRegistry.clear()
@patch("backend.integrations.oauth.HANDLERS_BY_NAME", {})
def test_oauth_handler_patching(self):
"""Test that OAuth handlers are patched into the system."""
# Create a test OAuth handler
class TestOAuthHandler(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
# Register a provider with OAuth
provider = Provider(
name="patched_provider",
oauth_handler=TestOAuthHandler,
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"oauth2"},
)
AutoRegistry.register_provider(provider)
# Mock the oauth module
mock_oauth = MagicMock()
mock_oauth.HANDLERS_BY_NAME = {}
with patch("backend.integrations.oauth", mock_oauth):
# Apply patches
AutoRegistry.patch_integrations()
# Verify the patched dict works
patched_dict = mock_oauth.HANDLERS_BY_NAME
assert "patched_provider" in patched_dict
assert patched_dict["patched_provider"] == TestOAuthHandler
@patch("backend.integrations.webhooks.load_webhook_managers")
def test_webhook_manager_patching(self, mock_load_managers):
"""Test that webhook managers are patched into the system."""
# Set up the mock to return an empty dict
mock_load_managers.return_value = {}
# Create a test webhook manager
class TestWebhookManager(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
# Register a provider with webhooks
provider = Provider(
name="webhook_provider",
oauth_handler=None,
webhook_manager=TestWebhookManager,
default_credentials=[],
base_costs=[],
supported_auth_types={"api_key"},
)
AutoRegistry.register_provider(provider)
# Mock the webhooks module
mock_webhooks = MagicMock()
mock_webhooks.load_webhook_managers = mock_load_managers
with patch("backend.integrations.webhooks", mock_webhooks):
# Apply patches
AutoRegistry.patch_integrations()
# Call the patched function
result = mock_webhooks.load_webhook_managers()
# Verify our webhook manager is included
assert "webhook_provider" in result
assert result["webhook_provider"] == TestWebhookManager
class TestProviderBuilder:
"""Test the ProviderBuilder fluent API."""
def setup_method(self):
"""Clear registry before each test."""
AutoRegistry.clear()
def test_basic_provider_builder(self):
"""Test building a basic provider."""
provider = (
ProviderBuilder("test_provider")
.with_api_key("TEST_API_KEY", "Test API Key")
.build()
)
assert provider.name == "test_provider"
assert "api_key" in provider.supported_auth_types
assert AutoRegistry.get_provider("test_provider") == provider
def test_provider_builder_with_oauth(self):
"""Test building a provider with OAuth."""
class TestOAuth(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
provider = (
ProviderBuilder("oauth_test")
.with_oauth(TestOAuth, scopes=["read", "write"])
.build()
)
assert provider.oauth_handler == TestOAuth
assert "oauth2" in provider.supported_auth_types
def test_provider_builder_with_webhook(self):
"""Test building a provider with webhook manager."""
class TestWebhook(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
provider = (
ProviderBuilder("webhook_test").with_webhook_manager(TestWebhook).build()
)
assert provider.webhook_manager == TestWebhook
def test_provider_builder_with_base_cost(self):
"""Test building a provider with base costs."""
from backend.data.cost import BlockCostType
provider = (
ProviderBuilder("cost_test")
.with_base_cost(10, BlockCostType.RUN)
.with_base_cost(5, BlockCostType.BYTE)
.build()
)
assert len(provider.base_costs) == 2
assert provider.base_costs[0].cost_amount == 10
assert provider.base_costs[0].cost_type == BlockCostType.RUN
assert provider.base_costs[1].cost_amount == 5
assert provider.base_costs[1].cost_type == BlockCostType.BYTE
def test_provider_builder_with_api_client(self):
"""Test building a provider with API client factory."""
def mock_client_factory():
return Mock()
provider = (
ProviderBuilder("client_test").with_api_client(mock_client_factory).build()
)
assert provider._api_client_factory == mock_client_factory
def test_provider_builder_with_error_handler(self):
"""Test building a provider with error handler."""
def mock_error_handler(exc: Exception) -> str:
return f"Error: {str(exc)}"
provider = (
ProviderBuilder("error_test").with_error_handler(mock_error_handler).build()
)
assert provider._error_handler == mock_error_handler
def test_provider_builder_complete_example(self):
"""Test building a complete provider with all features."""
from backend.data.cost import BlockCostType
class TestOAuth(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
class TestWebhook(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
def client_factory():
return Mock()
def error_handler(exc):
return str(exc)
provider = (
ProviderBuilder("complete_test")
.with_api_key("COMPLETE_API_KEY", "Complete API Key")
.with_oauth(TestOAuth, scopes=["read"])
.with_webhook_manager(TestWebhook)
.with_base_cost(100, BlockCostType.RUN)
.with_api_client(client_factory)
.with_error_handler(error_handler)
.with_config(custom_setting="value")
.build()
)
# Verify all settings
assert provider.name == "complete_test"
assert "api_key" in provider.supported_auth_types
assert "oauth2" in provider.supported_auth_types
assert provider.oauth_handler == TestOAuth
assert provider.webhook_manager == TestWebhook
assert len(provider.base_costs) == 1
assert provider._api_client_factory == client_factory
assert provider._error_handler == error_handler
assert provider.get_config("custom_setting") == "value" # from with_config
# Verify it's registered
assert AutoRegistry.get_provider("complete_test") == provider
assert "complete_test" in AutoRegistry._oauth_handlers
assert "complete_test" in AutoRegistry._webhook_managers
class TestSDKImports:
"""Test that all expected exports are available from the SDK."""
def test_core_block_imports(self):
"""Test core block system imports."""
from backend.sdk import Block, BlockCategory
# Just verify they're importable
assert Block is not None
assert BlockCategory is not None
def test_schema_imports(self):
"""Test schema and model imports."""
from backend.sdk import APIKeyCredentials, SchemaField
assert SchemaField is not None
assert APIKeyCredentials is not None
def test_type_alias_imports(self):
"""Test type alias imports."""
from backend.sdk import Boolean, Float, Integer, String
# Verify they're the correct types
assert String is str
assert Integer is int
assert Float is float
assert Boolean is bool
def test_cost_system_imports(self):
"""Test cost system imports."""
from backend.sdk import BlockCost, BlockCostType
assert BlockCost is not None
assert BlockCostType is not None
def test_utility_imports(self):
"""Test utility imports."""
from backend.sdk import BaseModel, json, requests
assert json is not None
assert BaseModel is not None
assert requests is not None
def test_integration_imports(self):
"""Test integration imports."""
from backend.sdk import ProviderName
assert ProviderName is not None
def test_sdk_component_imports(self):
"""Test SDK-specific component imports."""
from backend.sdk import AutoRegistry, ProviderBuilder
assert AutoRegistry is not None
assert ProviderBuilder is not None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,197 +0,0 @@
#!/usr/bin/env python3
"""
Standalone SDK tests that can run without Redis/PostgreSQL/RabbitMQ.
Run with: python test/sdk/test_sdk_standalone.py
"""
import sys
from pathlib import Path
# Add backend to path
backend_path = Path(__file__).parent.parent.parent
sys.path.insert(0, str(backend_path))
def test_sdk_imports():
"""Test that SDK imports work correctly"""
print("\n=== Testing SDK Imports ===")
# Import SDK
import backend.sdk as sdk
# Verify imports
assert hasattr(sdk, "Block")
assert hasattr(sdk, "BlockSchema")
assert hasattr(sdk, "SchemaField")
assert hasattr(sdk, "provider")
assert hasattr(sdk, "cost_config")
print("✅ SDK imports work correctly")
def test_dynamic_provider():
"""Test dynamic provider enum"""
print("\n=== Testing Dynamic Provider ===")
from backend.sdk import ProviderName
# Test existing provider
github = ProviderName.GITHUB
assert github.value == "github"
# Test dynamic provider
custom = ProviderName("my-custom-provider")
assert custom.value == "my-custom-provider"
assert isinstance(custom, ProviderName)
print("✅ Dynamic provider enum works")
def test_auto_registry():
"""Test auto-registration system"""
print("\n=== Testing Auto-Registry ===")
from backend.sdk import BlockCost, BlockCostType, cost_config, provider
from backend.sdk.auto_registry import get_registry
registry = get_registry()
initial_count = len(registry.providers)
# Register a test provider
@provider("test-provider-xyz")
class TestBlock:
pass
assert "test-provider-xyz" in registry.providers
assert len(registry.providers) == initial_count + 1
# Register costs
@cost_config(BlockCost(cost_amount=5, cost_type=BlockCostType.RUN))
class TestBlock2:
pass
assert TestBlock2 in registry.block_costs
assert registry.block_costs[TestBlock2][0].cost_amount == 5
print("✅ Auto-registry works correctly")
def test_complete_block_creation():
"""Test creating a complete block with SDK"""
print("\n=== Testing Complete Block Creation ===")
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
CredentialsField,
CredentialsMetaInput,
Integer,
SchemaField,
SecretStr,
String,
cost_config,
default_credentials,
provider,
)
@provider("test-ai-service")
@cost_config(
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
BlockCost(cost_amount=2, cost_type=BlockCostType.BYTE),
)
@default_credentials(
APIKeyCredentials(
id="test-ai-default",
provider="test-ai-service",
api_key=SecretStr("test-default-key"),
title="Test AI Service Default Key",
)
)
class TestAIBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="test-ai-service",
supported_credential_types={"api_key"},
description="API credentials",
)
prompt: String = SchemaField(description="AI prompt")
class Output(BlockSchema):
response: String = SchemaField(description="AI response")
tokens: Integer = SchemaField(description="Token count")
def __init__(self):
super().__init__(
id="a1b2c3d4-e5f6-7890-abcd-ef1234567890",
description="Test AI Service Block",
categories={BlockCategory.AI, BlockCategory.TEXT},
input_schema=TestAIBlock.Input,
output_schema=TestAIBlock.Output,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "response", f"AI says: {input_data.prompt}"
yield "tokens", len(input_data.prompt.split())
# Verify block creation
block = TestAIBlock()
assert block.id == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert BlockCategory.AI in block.categories
# Verify auto-registration
from backend.sdk.auto_registry import get_registry
registry = get_registry()
assert "test-ai-service" in registry.providers
assert TestAIBlock in registry.block_costs
assert len(registry.block_costs[TestAIBlock]) == 2
print("✅ Complete block creation works")
def run_all_tests():
"""Run all standalone tests"""
print("\n" + "=" * 60)
print("🧪 Running Standalone SDK Tests")
print("=" * 60)
tests = [
test_sdk_imports,
test_dynamic_provider,
test_auto_registry,
test_complete_block_creation,
]
passed = 0
failed = 0
for test in tests:
try:
test()
passed += 1
except Exception as e:
print(f"{test.__name__} failed: {e}")
failed += 1
import traceback
traceback.print_exc()
print("\n" + "=" * 60)
print(f"📊 Results: {passed} passed, {failed} failed")
print("=" * 60)
if failed == 0:
print("\n🎉 All standalone SDK tests passed!")
return True
else:
print(f"\n⚠️ {failed} tests failed")
return False
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,505 @@
"""
Tests for SDK webhook functionality.
This test suite verifies webhook blocks and webhook manager integration.
"""
from enum import Enum
import pytest
from backend.integrations.providers import ProviderName
from backend.sdk import (
APIKeyCredentials,
AutoRegistry,
BaseModel,
BaseWebhooksManager,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockWebhookConfig,
Boolean,
CredentialsField,
CredentialsMetaInput,
Field,
Integer,
ProviderBuilder,
SchemaField,
SecretStr,
String,
)
class TestWebhookTypes(str, Enum):
"""Test webhook event types."""
CREATED = "created"
UPDATED = "updated"
DELETED = "deleted"
class TestWebhooksManager(BaseWebhooksManager):
"""Test webhook manager implementation."""
PROVIDER_NAME = ProviderName.GITHUB # Reuse for testing
class WebhookType(str, Enum):
TEST = "test"
@classmethod
async def validate_payload(cls, webhook, request):
"""Validate incoming webhook payload."""
# Mock implementation
payload = {"test": "data"}
event_type = "test_event"
return payload, event_type
async def _register_webhook(
self,
credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""Register webhook with external service."""
# Mock implementation
webhook_id = f"test_webhook_{resource}"
config = {
"webhook_type": webhook_type,
"resource": resource,
"events": events,
"url": ingress_url,
}
return webhook_id, config
async def _deregister_webhook(self, webhook, credentials) -> None:
"""Deregister webhook from external service."""
# Mock implementation
pass
class TestWebhookBlock(Block):
"""Test webhook block implementation."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="test_webhooks",
supported_credential_types={"api_key"},
description="Webhook service credentials",
)
webhook_url: String = SchemaField(
description="URL to receive webhooks",
)
resource_id: String = SchemaField(
description="Resource to monitor",
)
events: list[TestWebhookTypes] = SchemaField(
description="Events to listen for",
default=[TestWebhookTypes.CREATED],
)
payload: dict = SchemaField(
description="Webhook payload",
default={},
)
class Output(BlockSchema):
webhook_id: String = SchemaField(description="Registered webhook ID")
is_active: Boolean = SchemaField(description="Webhook is active")
event_count: Integer = SchemaField(description="Number of events configured")
def __init__(self):
super().__init__(
id="test-webhook-block",
description="Test webhook block",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=TestWebhookBlock.Input,
output_schema=TestWebhookBlock.Output,
webhook_config=BlockWebhookConfig(
provider="test_webhooks", # type: ignore
webhook_type="test",
resource_format="{resource_id}",
),
)
def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Simulate webhook registration
webhook_id = f"webhook_{input_data.resource_id}"
yield "webhook_id", webhook_id
yield "is_active", True
yield "event_count", len(input_data.events)
class TestWebhookBlockCreation:
"""Test creating webhook blocks with the SDK."""
def setup_method(self):
"""Set up test environment."""
AutoRegistry.clear()
# Register a provider with webhook support
self.provider = (
ProviderBuilder("test_webhooks")
.with_api_key("TEST_WEBHOOK_KEY", "Test Webhook API Key")
.with_webhook_manager(TestWebhooksManager)
.build()
)
def test_basic_webhook_block(self):
"""Test creating a basic webhook block."""
block = TestWebhookBlock()
# Verify block configuration
assert block.webhook_config is not None
assert block.webhook_config.provider == "test_webhooks"
assert block.webhook_config.webhook_type == "test"
assert "{resource_id}" in block.webhook_config.resource_format # type: ignore
# Test block execution
test_creds = APIKeyCredentials(
id="test-webhook-creds",
provider="test_webhooks",
api_key=SecretStr("test-key"),
title="Test Webhook Key",
)
outputs = dict(
block.run(
TestWebhookBlock.Input(
credentials={ # type: ignore
"provider": "test_webhooks",
"id": "test-webhook-creds",
"type": "api_key",
},
webhook_url="https://example.com/webhook",
resource_id="resource_123",
events=[TestWebhookTypes.CREATED, TestWebhookTypes.UPDATED],
),
credentials=test_creds,
)
)
assert outputs["webhook_id"] == "webhook_resource_123"
assert outputs["is_active"] is True
assert outputs["event_count"] == 2
def test_webhook_block_with_filters(self):
"""Test webhook block with event filters."""
class EventFilterModel(BaseModel):
include_system: bool = Field(default=False)
severity_levels: list[str] = Field(
default_factory=lambda: ["info", "warning"]
)
class FilteredWebhookBlock(Block):
"""Webhook block with filtering."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="test_webhooks",
supported_credential_types={"api_key"},
)
resource: String = SchemaField(description="Resource to monitor")
filters: EventFilterModel = SchemaField(
description="Event filters",
default_factory=EventFilterModel,
)
payload: dict = SchemaField(
description="Webhook payload",
default={},
)
class Output(BlockSchema):
webhook_active: Boolean = SchemaField(description="Webhook active")
filter_summary: String = SchemaField(description="Active filters")
def __init__(self):
super().__init__(
id="filtered-webhook-block",
description="Webhook with filters",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=FilteredWebhookBlock.Input,
output_schema=FilteredWebhookBlock.Output,
webhook_config=BlockWebhookConfig(
provider="test_webhooks", # type: ignore
webhook_type="filtered",
resource_format="{resource}",
),
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
filters = input_data.filters
filter_parts = []
if filters.include_system:
filter_parts.append("system events")
filter_parts.append(f"{len(filters.severity_levels)} severity levels")
yield "webhook_active", True
yield "filter_summary", ", ".join(filter_parts)
# Test the block
block = FilteredWebhookBlock()
test_creds = APIKeyCredentials(
id="test-creds",
provider="test_webhooks",
api_key=SecretStr("key"),
title="Test Key",
)
# Test with default filters
outputs = dict(
block.run(
FilteredWebhookBlock.Input(
credentials={ # type: ignore
"provider": "test_webhooks",
"id": "test-creds",
"type": "api_key",
},
resource="test_resource",
),
credentials=test_creds,
)
)
assert outputs["webhook_active"] is True
assert "2 severity levels" in outputs["filter_summary"]
# Test with custom filters
custom_filters = EventFilterModel(
include_system=True,
severity_levels=["error", "critical"],
)
outputs = dict(
block.run(
FilteredWebhookBlock.Input(
credentials={ # type: ignore
"provider": "test_webhooks",
"id": "test-creds",
"type": "api_key",
},
resource="test_resource",
filters=custom_filters,
),
credentials=test_creds,
)
)
assert "system events" in outputs["filter_summary"]
assert "2 severity levels" in outputs["filter_summary"]
class TestWebhookManagerIntegration:
"""Test webhook manager integration with AutoRegistry."""
def setup_method(self):
"""Clear registry."""
AutoRegistry.clear()
def test_webhook_manager_registration(self):
"""Test that webhook managers are properly registered."""
# Create multiple webhook managers
class WebhookManager1(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
class WebhookManager2(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GOOGLE
# Register providers with webhook managers
(
ProviderBuilder("webhook_service_1")
.with_webhook_manager(WebhookManager1)
.build()
)
(
ProviderBuilder("webhook_service_2")
.with_webhook_manager(WebhookManager2)
.build()
)
# Verify registration
managers = AutoRegistry.get_webhook_managers()
assert "webhook_service_1" in managers
assert "webhook_service_2" in managers
assert managers["webhook_service_1"] == WebhookManager1
assert managers["webhook_service_2"] == WebhookManager2
def test_webhook_block_with_provider_manager(self):
"""Test webhook block using a provider's webhook manager."""
# Register provider with webhook manager
(
ProviderBuilder("integrated_webhooks")
.with_api_key("INTEGRATED_KEY", "Integrated Webhook Key")
.with_webhook_manager(TestWebhooksManager)
.build()
)
# Create a block that uses this provider
class IntegratedWebhookBlock(Block):
"""Block using integrated webhook manager."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="integrated_webhooks",
supported_credential_types={"api_key"},
)
target: String = SchemaField(description="Webhook target")
payload: dict = SchemaField(
description="Webhook payload",
default={},
)
class Output(BlockSchema):
status: String = SchemaField(description="Webhook status")
manager_type: String = SchemaField(description="Manager type used")
def __init__(self):
super().__init__(
id="integrated-webhook-block",
description="Uses integrated webhook manager",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=IntegratedWebhookBlock.Input,
output_schema=IntegratedWebhookBlock.Output,
webhook_config=BlockWebhookConfig(
provider="integrated_webhooks", # type: ignore
webhook_type=TestWebhooksManager.WebhookType.TEST,
resource_format="{target}",
),
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Get the webhook manager for this provider
managers = AutoRegistry.get_webhook_managers()
manager_class = managers.get("integrated_webhooks")
yield "status", "configured"
yield "manager_type", (
manager_class.__name__ if manager_class else "none"
)
# Test the block
block = IntegratedWebhookBlock()
test_creds = APIKeyCredentials(
id="integrated-creds",
provider="integrated_webhooks",
api_key=SecretStr("key"),
title="Integrated Key",
)
outputs = dict(
block.run(
IntegratedWebhookBlock.Input(
credentials={ # type: ignore
"provider": "integrated_webhooks",
"id": "integrated-creds",
"type": "api_key",
},
target="test_target",
),
credentials=test_creds,
)
)
assert outputs["status"] == "configured"
assert outputs["manager_type"] == "TestWebhooksManager"
class TestWebhookEventHandling:
"""Test webhook event handling in blocks."""
def test_webhook_event_processing_block(self):
"""Test a block that processes webhook events."""
class WebhookEventBlock(Block):
"""Block that processes webhook events."""
class Input(BlockSchema):
event_type: String = SchemaField(description="Type of webhook event")
payload: dict = SchemaField(description="Webhook payload")
verify_signature: Boolean = SchemaField(
description="Whether to verify webhook signature",
default=True,
)
class Output(BlockSchema):
processed: Boolean = SchemaField(description="Event was processed")
event_summary: String = SchemaField(description="Summary of event")
action_required: Boolean = SchemaField(description="Action required")
def __init__(self):
super().__init__(
id="webhook-event-processor",
description="Processes incoming webhook events",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=WebhookEventBlock.Input,
output_schema=WebhookEventBlock.Output,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Process based on event type
event_type = input_data.event_type
payload = input_data.payload
if event_type == "created":
summary = f"New item created: {payload.get('id', 'unknown')}"
action_required = True
elif event_type == "updated":
summary = f"Item updated: {payload.get('id', 'unknown')}"
action_required = False
elif event_type == "deleted":
summary = f"Item deleted: {payload.get('id', 'unknown')}"
action_required = True
else:
summary = f"Unknown event: {event_type}"
action_required = False
yield "processed", True
yield "event_summary", summary
yield "action_required", action_required
# Test the block with different events
block = WebhookEventBlock()
# Test created event
outputs = dict(
block.run(
WebhookEventBlock.Input(
event_type="created",
payload={"id": "123", "name": "Test Item"},
)
)
)
assert outputs["processed"] is True
assert "New item created: 123" in outputs["event_summary"]
assert outputs["action_required"] is True
# Test updated event
outputs = dict(
block.run(
WebhookEventBlock.Input(
event_type="updated",
payload={"id": "456", "changes": ["name", "status"]},
)
)
)
assert outputs["processed"] is True
assert "Item updated: 456" in outputs["event_summary"]
assert outputs["action_required"] is False
if __name__ == "__main__":
pytest.main([__file__, "-v"])