mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix: Fix type errors and enable custom providers in SDK
- Update CredentialsMetaInput.allowed_providers() to return None for unrestricted providers - Fix pyright type errors by using ProviderName() constructor instead of string literals - Update webhook manager signatures in tests to match abstract base class - Add comprehensive test suites for custom provider functionality - Configure ruff to ignore star import warnings in SDK and test files - Ensure all formatting tools (ruff, black, isort, pyright) pass successfully This enables SDK users to define custom providers without modifying core enums while maintaining strict type safety throughout the codebase.
This commit is contained in:
@@ -9,17 +9,33 @@ This demonstrates:
|
||||
|
||||
from backend.sdk import * # noqa: F403, F405
|
||||
|
||||
# Define test credentials for testing
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="example-service", # Custom provider name
|
||||
api_key=SecretStr("mock-example-api-key"),
|
||||
title="Mock Example Service API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
# Example of a simple service with auto-registration
|
||||
@provider("exampleservice")
|
||||
@provider("example-service") # Custom provider demonstrating SDK flexibility
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=2, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=1, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="exampleservice-default",
|
||||
provider="exampleservice",
|
||||
id="example-service-default",
|
||||
provider="example-service", # Custom provider name
|
||||
api_key=SecretStr("example-default-api-key"),
|
||||
title="Example Service Default API Key",
|
||||
expires_at=None,
|
||||
@@ -39,7 +55,7 @@ class ExampleSDKBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="exampleservice",
|
||||
provider="example-service", # Custom provider name
|
||||
supported_credential_types={"api_key"},
|
||||
description="Credentials for Example Service API",
|
||||
)
|
||||
@@ -63,12 +79,17 @@ class ExampleSDKBlock(Block):
|
||||
categories={BlockCategory.TEXT, BlockCategory.BASIC},
|
||||
input_schema=ExampleSDKBlock.Input,
|
||||
output_schema=ExampleSDKBlock.Output,
|
||||
test_input={"text": "Test input", "max_length": 50},
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"text": "Test input",
|
||||
"max_length": 50,
|
||||
},
|
||||
test_output=[
|
||||
("result", "PROCESSED: Test input"),
|
||||
("length", 20),
|
||||
("length", 21), # Length of "PROCESSED: Test input"
|
||||
("api_key_used", True),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
def run(
|
||||
|
||||
@@ -41,6 +41,9 @@ from pydantic_core import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
# Type alias for any provider name (including custom ones)
|
||||
AnyProviderName = str # Will be validated as ProviderName at runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
@@ -288,12 +291,20 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
type: CT
|
||||
|
||||
@classmethod
|
||||
def allowed_providers(cls) -> tuple[ProviderName, ...]:
|
||||
return get_args(cls.model_fields["provider"].annotation)
|
||||
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
||||
args = get_args(cls.model_fields["provider"].annotation)
|
||||
# If no type parameters are provided, allow any provider
|
||||
if not args:
|
||||
return None # None means no specific providers, allow any
|
||||
return args
|
||||
|
||||
@classmethod
|
||||
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
||||
return get_args(cls.model_fields["type"].annotation)
|
||||
args = get_args(cls.model_fields["type"].annotation)
|
||||
# If no type parameters are provided, allow any credential type
|
||||
if not args:
|
||||
return ("api_key", "oauth2", "user_password") # All credential types
|
||||
return args
|
||||
|
||||
@classmethod
|
||||
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
|
||||
@@ -313,7 +324,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
f"{field_schema}"
|
||||
) from e
|
||||
|
||||
if len(cls.allowed_providers()) > 1 and not schema_extra.discriminator:
|
||||
providers = cls.allowed_providers()
|
||||
if (
|
||||
providers is not None
|
||||
and len(providers) > 1
|
||||
and not schema_extra.discriminator
|
||||
):
|
||||
raise TypeError(
|
||||
f"Multi-provider CredentialsField '{field_name}' "
|
||||
"requires discriminator!"
|
||||
@@ -321,7 +337,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
|
||||
@staticmethod
|
||||
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
|
||||
schema["credentials_provider"] = cls.allowed_providers()
|
||||
allowed_providers = cls.allowed_providers()
|
||||
# If no specific providers (None), allow any string
|
||||
if allowed_providers is None:
|
||||
schema["credentials_provider"] = ["string"] # Allow any string provider
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = cls.allowed_cred_types()
|
||||
|
||||
model_config = ConfigDict(
|
||||
|
||||
@@ -65,4 +65,41 @@ class ProviderName(str, Enum):
|
||||
return pseudo_member
|
||||
return None # type: ignore
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, schema, handler):
|
||||
"""
|
||||
Custom JSON schema generation that allows any string value,
|
||||
not just the predefined enum values.
|
||||
"""
|
||||
# Get the default schema
|
||||
json_schema = handler(schema)
|
||||
|
||||
# Remove the enum constraint to allow any string
|
||||
if "enum" in json_schema:
|
||||
del json_schema["enum"]
|
||||
|
||||
# Keep the type as string
|
||||
json_schema["type"] = "string"
|
||||
|
||||
# Update description to indicate custom providers are allowed
|
||||
json_schema["description"] = (
|
||||
"Provider name for integrations. "
|
||||
"Can be any string value, including custom provider names."
|
||||
)
|
||||
|
||||
return json_schema
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
"""
|
||||
Pydantic v2 core schema that allows any string value.
|
||||
"""
|
||||
from pydantic_core import core_schema
|
||||
|
||||
# Create a string schema that validates any string
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.str_schema(),
|
||||
)
|
||||
|
||||
# --8<-- [end:ProviderName]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# noqa: E402
|
||||
"""
|
||||
AutoGPT Platform Block Development SDK
|
||||
|
||||
@@ -13,6 +14,10 @@ This module provides:
|
||||
- Auto-registration decorators
|
||||
"""
|
||||
|
||||
# Pre-configured CredentialsMetaInput that accepts any provider
|
||||
# Uses ProviderName which has _missing_ method to accept any string
|
||||
from typing import Literal as _Literal
|
||||
|
||||
# === CORE BLOCK SYSTEM ===
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -23,10 +28,9 @@ from backend.data.block import (
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField
|
||||
from backend.data.model import CredentialsMetaInput as _CredentialsMetaInput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
@@ -36,6 +40,10 @@ from backend.data.model import (
|
||||
# === INTEGRATIONS ===
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
CredentialsMetaInput = _CredentialsMetaInput[
|
||||
ProviderName, _Literal["api_key", "oauth2", "user_password"]
|
||||
]
|
||||
|
||||
# === WEBHOOKS ===
|
||||
try:
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
|
||||
@@ -116,5 +116,8 @@ target-version = "py310"
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
# Example and test files using SDK star imports
|
||||
"backend/blocks/examples/*.py" = ["F403", "F405"]
|
||||
"backend/sdk/__init__.py" = ["E402"] # Module level imports after try/except blocks
|
||||
"test/sdk/demo_sdk_block.py" = ["F403", "F405"]
|
||||
"test/sdk/test_sdk_integration.py" = ["F403", "F405", "F406", "E402"]
|
||||
"test/sdk/test_custom_provider.py" = ["F403", "F405"]
|
||||
"test/sdk/test_custom_provider_advanced.py" = ["F403", "F405"]
|
||||
|
||||
232
autogpt_platform/backend/test/sdk/test_custom_provider.py
Normal file
232
autogpt_platform/backend/test/sdk/test_custom_provider.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# noqa: F405
|
||||
"""
|
||||
Test custom provider functionality in the SDK.
|
||||
|
||||
This test suite verifies that the SDK properly supports dynamic provider
|
||||
registration and that custom providers work correctly with the system.
|
||||
"""
|
||||
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import * # noqa: F403
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
# Test credentials for custom providers
|
||||
CUSTOM_TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="custom-provider-test-creds",
|
||||
provider="my-custom-service",
|
||||
api_key=SecretStr("test-api-key-12345"),
|
||||
title="Custom Service Test Credentials",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
CUSTOM_TEST_CREDENTIALS_INPUT = {
|
||||
"provider": CUSTOM_TEST_CREDENTIALS.provider,
|
||||
"id": CUSTOM_TEST_CREDENTIALS.id,
|
||||
"type": CUSTOM_TEST_CREDENTIALS.type,
|
||||
"title": CUSTOM_TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
@provider("my-custom-service")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=10, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=2, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="my-custom-service-default",
|
||||
provider="my-custom-service",
|
||||
api_key=SecretStr("default-custom-api-key"),
|
||||
title="My Custom Service Default API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
)
|
||||
class CustomProviderBlock(Block):
|
||||
"""Test block with a completely custom provider."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="my-custom-service",
|
||||
supported_credential_types={"api_key"},
|
||||
description="Custom service credentials",
|
||||
)
|
||||
message: String = SchemaField(
|
||||
description="Message to process", default="Hello from custom provider!"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: String = SchemaField(description="Processed message")
|
||||
provider_used: String = SchemaField(description="Provider name used")
|
||||
credentials_valid: Boolean = SchemaField(
|
||||
description="Whether credentials were valid"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d1234567-89ab-cdef-0123-456789abcdef",
|
||||
description="Test block demonstrating custom provider support",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=CustomProviderBlock.Input,
|
||||
output_schema=CustomProviderBlock.Output,
|
||||
test_input={
|
||||
"credentials": CUSTOM_TEST_CREDENTIALS_INPUT,
|
||||
"message": "Test message",
|
||||
},
|
||||
test_output=[
|
||||
("result", "CUSTOM: Test message"),
|
||||
("provider_used", "my-custom-service"),
|
||||
("credentials_valid", True),
|
||||
],
|
||||
test_credentials=CUSTOM_TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Verify we got the right credentials
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
yield "result", f"CUSTOM: {input_data.message}"
|
||||
yield "provider_used", credentials.provider
|
||||
yield "credentials_valid", bool(api_key)
|
||||
|
||||
|
||||
@provider("another-custom-provider")
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=5, cost_type=BlockCostType.RUN),
|
||||
)
|
||||
class AnotherCustomProviderBlock(Block):
|
||||
"""Another test block to verify multiple custom providers work."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="another-custom-provider",
|
||||
supported_credential_types={"api_key"},
|
||||
)
|
||||
data: String = SchemaField(description="Input data")
|
||||
|
||||
class Output(BlockSchema):
|
||||
processed: String = SchemaField(description="Processed data")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e2345678-9abc-def0-1234-567890abcdef",
|
||||
description="Another custom provider test",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=AnotherCustomProviderBlock.Input,
|
||||
output_schema=AnotherCustomProviderBlock.Output,
|
||||
test_input={
|
||||
"credentials": {
|
||||
"provider": "another-custom-provider",
|
||||
"id": "test-creds-2",
|
||||
"type": "api_key",
|
||||
"title": "Test Creds 2",
|
||||
},
|
||||
"data": "test data",
|
||||
},
|
||||
test_output=[("processed", "ANOTHER: test data")],
|
||||
test_credentials=APIKeyCredentials(
|
||||
id="test-creds-2",
|
||||
provider="another-custom-provider",
|
||||
api_key=SecretStr("another-test-key"),
|
||||
title="Another Test Key",
|
||||
expires_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
yield "processed", f"ANOTHER: {input_data.data}"
|
||||
|
||||
|
||||
class TestCustomProvider:
|
||||
"""Test suite for custom provider functionality."""
|
||||
|
||||
def test_custom_provider_enum_accepts_any_string(self):
|
||||
"""Test that ProviderName enum accepts any string value."""
|
||||
# Test with a completely new provider name
|
||||
custom_provider = ProviderName("my-totally-new-provider")
|
||||
assert custom_provider.value == "my-totally-new-provider"
|
||||
|
||||
# Test with existing provider
|
||||
existing_provider = ProviderName.OPENAI
|
||||
assert existing_provider.value == "openai"
|
||||
|
||||
# Test comparison
|
||||
another_custom = ProviderName("my-totally-new-provider")
|
||||
assert custom_provider == another_custom
|
||||
|
||||
def test_custom_provider_block_executes(self):
|
||||
"""Test that blocks with custom providers can execute properly."""
|
||||
block = CustomProviderBlock()
|
||||
execute_block_test(block)
|
||||
|
||||
def test_multiple_custom_providers(self):
|
||||
"""Test that multiple custom providers can coexist."""
|
||||
block1 = CustomProviderBlock()
|
||||
block2 = AnotherCustomProviderBlock()
|
||||
|
||||
# Both blocks should execute successfully
|
||||
execute_block_test(block1)
|
||||
execute_block_test(block2)
|
||||
|
||||
def test_custom_provider_registration(self):
|
||||
"""Test that custom providers are registered in the auto-registry."""
|
||||
registry = get_registry()
|
||||
|
||||
# Check that our custom provider blocks have registered their costs
|
||||
block_costs = registry.get_block_costs_dict()
|
||||
assert CustomProviderBlock in block_costs
|
||||
assert AnotherCustomProviderBlock in block_costs
|
||||
|
||||
# Check the costs are correct
|
||||
custom_costs = block_costs[CustomProviderBlock]
|
||||
assert len(custom_costs) == 2
|
||||
assert any(
|
||||
cost.cost_amount == 10 and cost.cost_type == BlockCostType.RUN
|
||||
for cost in custom_costs
|
||||
)
|
||||
assert any(
|
||||
cost.cost_amount == 2 and cost.cost_type == BlockCostType.BYTE
|
||||
for cost in custom_costs
|
||||
)
|
||||
|
||||
def test_custom_provider_default_credentials(self):
|
||||
"""Test that default credentials are registered for custom providers."""
|
||||
registry = get_registry()
|
||||
default_creds = registry.get_default_credentials_list()
|
||||
|
||||
# Check that our custom provider's default credentials are registered
|
||||
custom_default_creds = [
|
||||
cred for cred in default_creds if cred.provider == "my-custom-service"
|
||||
]
|
||||
assert len(custom_default_creds) >= 1
|
||||
assert custom_default_creds[0].id == "my-custom-service-default"
|
||||
|
||||
def test_custom_provider_with_oauth(self):
|
||||
"""Test that custom providers can use OAuth handlers."""
|
||||
# This is a placeholder for OAuth testing
|
||||
# In a real implementation, you would create a custom OAuth handler
|
||||
pass
|
||||
|
||||
def test_custom_provider_with_webhooks(self):
|
||||
"""Test that custom providers can use webhook managers."""
|
||||
# This is a placeholder for webhook testing
|
||||
# In a real implementation, you would create a custom webhook manager
|
||||
pass
|
||||
|
||||
|
||||
# Test that runs as part of pytest
|
||||
def test_custom_provider_functionality():
|
||||
"""Run all custom provider tests."""
|
||||
test_instance = TestCustomProvider()
|
||||
|
||||
# Run each test method
|
||||
test_instance.test_custom_provider_enum_accepts_any_string()
|
||||
test_instance.test_custom_provider_block_executes()
|
||||
test_instance.test_multiple_custom_providers()
|
||||
test_instance.test_custom_provider_registration()
|
||||
test_instance.test_custom_provider_default_credentials()
|
||||
@@ -0,0 +1,389 @@
|
||||
# noqa: F405
|
||||
"""
|
||||
Advanced tests for custom provider functionality including OAuth and Webhooks.
|
||||
|
||||
This test suite demonstrates how custom providers can integrate with all
|
||||
aspects of the SDK including OAuth authentication and webhook handling.
|
||||
"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk import * # noqa: F403
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
|
||||
# Custom OAuth Handler for testing
|
||||
class CustomServiceOAuthHandler(BaseOAuthHandler):
|
||||
"""OAuth handler for our custom service."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("custom-oauth-service")
|
||||
DEFAULT_SCOPES = ["read", "write", "admin"]
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
"""Generate OAuth login URL."""
|
||||
scope_str = " ".join(scopes)
|
||||
return f"https://custom-oauth-service.com/oauth/authorize?client_id=test&scope={scope_str}&state={state}"
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
"""Exchange authorization code for tokens."""
|
||||
# Mock token exchange
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
access_token=SecretStr("mock-access-token"),
|
||||
refresh_token=SecretStr("mock-refresh-token"),
|
||||
scopes=scopes,
|
||||
access_token_expires_at=int(time.time() + 3600),
|
||||
title="Custom OAuth Service",
|
||||
id="custom-oauth-creds",
|
||||
)
|
||||
|
||||
|
||||
# Custom Webhook Manager for testing
|
||||
class CustomWebhookManager(BaseWebhooksManager):
|
||||
"""Webhook manager for our custom service."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("custom-webhook-service")
|
||||
|
||||
class WebhookType(str, Enum):
|
||||
DATA_RECEIVED = "data_received"
|
||||
STATUS_CHANGED = "status_changed"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook: Any, request: Any) -> tuple[dict, str]:
|
||||
"""Validate incoming webhook payload."""
|
||||
# Mock payload validation
|
||||
payload = {"data": "test data", "timestamp": time.time()}
|
||||
event_type = "data_received"
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials: Any,
|
||||
webhook_type: Any,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
"""Register webhook with external service."""
|
||||
# Mock webhook registration
|
||||
webhook_id = "custom-webhook-12345"
|
||||
config = {"url": ingress_url, "events": events, "resource": resource}
|
||||
return webhook_id, config
|
||||
|
||||
async def _deregister_webhook(self, webhook: Any, credentials: Any) -> None:
|
||||
"""Deregister webhook from external service."""
|
||||
# Mock webhook deregistration
|
||||
pass
|
||||
|
||||
|
||||
# Test OAuth-enabled block
|
||||
@provider("custom-oauth-service")
|
||||
@oauth_config("custom-oauth-service", CustomServiceOAuthHandler)
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=15, cost_type=BlockCostType.RUN),
|
||||
)
|
||||
class CustomOAuthBlock(Block):
|
||||
"""Block that uses OAuth authentication with a custom provider."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="custom-oauth-service",
|
||||
supported_credential_types={"oauth2"},
|
||||
required_scopes={"read", "write"},
|
||||
description="OAuth credentials for custom service",
|
||||
)
|
||||
action: String = SchemaField(
|
||||
description="Action to perform", default="fetch_data"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: Dict = SchemaField(description="Retrieved data")
|
||||
token_valid: Boolean = SchemaField(description="Whether OAuth token was valid")
|
||||
scopes: List[String] = SchemaField(description="Available scopes")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f3456789-abcd-ef01-2345-6789abcdef01",
|
||||
description="Custom OAuth provider test block",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=CustomOAuthBlock.Input,
|
||||
output_schema=CustomOAuthBlock.Output,
|
||||
test_input={
|
||||
"credentials": {
|
||||
"provider": "custom-oauth-service",
|
||||
"id": "oauth-test-creds",
|
||||
"type": "oauth2",
|
||||
"title": "Test OAuth Creds",
|
||||
},
|
||||
"action": "test_action",
|
||||
},
|
||||
test_output=[
|
||||
("data", {"status": "success", "action": "test_action"}),
|
||||
("token_valid", True),
|
||||
("scopes", ["read", "write"]),
|
||||
],
|
||||
test_credentials=OAuth2Credentials(
|
||||
id="oauth-test-creds",
|
||||
provider="custom-oauth-service",
|
||||
access_token=SecretStr("test-access-token"),
|
||||
refresh_token=SecretStr("test-refresh-token"),
|
||||
scopes=["read", "write"],
|
||||
access_token_expires_at=int(time.time() + 3600),
|
||||
title="Test OAuth Credentials",
|
||||
),
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Simulate OAuth API call
|
||||
token = credentials.access_token.get_secret_value()
|
||||
|
||||
yield "data", {"status": "success", "action": input_data.action}
|
||||
yield "token_valid", bool(token)
|
||||
yield "scopes", credentials.scopes
|
||||
|
||||
|
||||
# Event filter model for webhook
|
||||
class WebhookEventFilter(BaseModel):
|
||||
data_received: bool = True
|
||||
status_changed: bool = False
|
||||
|
||||
|
||||
# Test Webhook-enabled block
|
||||
@provider("custom-webhook-service")
|
||||
@webhook_config("custom-webhook-service", CustomWebhookManager)
|
||||
class CustomWebhookBlock(Block):
|
||||
"""Block that receives webhooks from a custom provider."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="custom-webhook-service",
|
||||
supported_credential_types={"api_key"},
|
||||
description="Credentials for webhook service",
|
||||
)
|
||||
events: WebhookEventFilter = SchemaField(
|
||||
description="Events to listen for", default_factory=WebhookEventFilter
|
||||
)
|
||||
payload: Dict = SchemaField(
|
||||
description="Webhook payload", default={}, hidden=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
event_type: String = SchemaField(description="Type of event received")
|
||||
event_data: Dict = SchemaField(description="Event data")
|
||||
timestamp: Float = SchemaField(description="Event timestamp")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a4567890-bcde-f012-3456-7890bcdef012",
|
||||
description="Custom webhook provider test block",
|
||||
categories={BlockCategory.INPUT},
|
||||
input_schema=CustomWebhookBlock.Input,
|
||||
output_schema=CustomWebhookBlock.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("custom-webhook-service"),
|
||||
webhook_type="data_received",
|
||||
event_filter_input="events",
|
||||
resource_format="webhook/{webhook_id}",
|
||||
),
|
||||
test_input={
|
||||
"credentials": {
|
||||
"provider": "custom-webhook-service",
|
||||
"id": "webhook-test-creds",
|
||||
"type": "api_key",
|
||||
"title": "Test Webhook Creds",
|
||||
},
|
||||
"events": {"data_received": True, "status_changed": False},
|
||||
"payload": {
|
||||
"type": "data_received",
|
||||
"data": "test",
|
||||
"timestamp": 1234567890.0,
|
||||
},
|
||||
},
|
||||
test_output=[
|
||||
("event_type", "data_received"),
|
||||
(
|
||||
"event_data",
|
||||
{
|
||||
"type": "data_received",
|
||||
"data": "test",
|
||||
"timestamp": 1234567890.0,
|
||||
},
|
||||
),
|
||||
("timestamp", 1234567890.0),
|
||||
],
|
||||
test_credentials=APIKeyCredentials(
|
||||
id="webhook-test-creds",
|
||||
provider="custom-webhook-service",
|
||||
api_key=SecretStr("webhook-api-key"),
|
||||
title="Webhook API Key",
|
||||
expires_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
|
||||
yield "event_type", payload.get("type", "unknown")
|
||||
yield "event_data", payload
|
||||
yield "timestamp", payload.get("timestamp", 0.0)
|
||||
|
||||
|
||||
# Combined block using multiple custom features
|
||||
@provider("custom-full-service")
|
||||
@oauth_config("custom-full-service", CustomServiceOAuthHandler)
|
||||
@webhook_config("custom-full-service", CustomWebhookManager)
|
||||
@cost_config(
|
||||
BlockCost(cost_amount=20, cost_type=BlockCostType.RUN),
|
||||
BlockCost(cost_amount=5, cost_type=BlockCostType.BYTE),
|
||||
)
|
||||
@default_credentials(
|
||||
APIKeyCredentials(
|
||||
id="custom-full-service-default",
|
||||
provider="custom-full-service",
|
||||
api_key=SecretStr("default-full-service-key"),
|
||||
title="Custom Full Service Default Key",
|
||||
expires_at=None,
|
||||
)
|
||||
)
|
||||
class CustomFullServiceBlock(Block):
|
||||
"""Block demonstrating all custom provider features."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = CredentialsField(
|
||||
provider="custom-full-service",
|
||||
supported_credential_types={"api_key", "oauth2"},
|
||||
description="Credentials for full service",
|
||||
)
|
||||
mode: String = SchemaField(description="Operation mode", default="standard")
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: String = SchemaField(description="Operation result")
|
||||
features_used: List[String] = SchemaField(description="Features utilized")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b5678901-cdef-0123-4567-8901cdef0123",
|
||||
description="Full-featured custom provider block",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
|
||||
input_schema=CustomFullServiceBlock.Input,
|
||||
output_schema=CustomFullServiceBlock.Output,
|
||||
test_input={
|
||||
"credentials": {
|
||||
"provider": "custom-full-service",
|
||||
"id": "full-test-creds",
|
||||
"type": "api_key",
|
||||
"title": "Full Service Test Creds",
|
||||
},
|
||||
"mode": "test",
|
||||
},
|
||||
test_output=[
|
||||
("result", "SUCCESS: test mode"),
|
||||
("features_used", ["provider", "cost_config", "default_credentials"]),
|
||||
],
|
||||
test_credentials=APIKeyCredentials(
|
||||
id="full-test-creds",
|
||||
provider="custom-full-service",
|
||||
api_key=SecretStr("full-service-test-key"),
|
||||
title="Full Service Test Key",
|
||||
expires_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *, credentials: Any, **kwargs) -> BlockOutput:
|
||||
features = ["provider", "cost_config", "default_credentials"]
|
||||
|
||||
if isinstance(credentials, OAuth2Credentials):
|
||||
features.append("oauth")
|
||||
|
||||
yield "result", f"SUCCESS: {input_data.mode} mode"
|
||||
yield "features_used", features
|
||||
|
||||
|
||||
class TestCustomProviderAdvanced:
|
||||
"""Advanced test suite for custom provider functionality."""
|
||||
|
||||
def test_oauth_handler_registration(self):
|
||||
"""Test that custom OAuth handlers are registered."""
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
oauth_handlers = registry.get_oauth_handlers_dict()
|
||||
|
||||
# Check if our custom OAuth handler is registered
|
||||
assert "custom-oauth-service" in oauth_handlers
|
||||
assert oauth_handlers["custom-oauth-service"] == CustomServiceOAuthHandler
|
||||
|
||||
def test_webhook_manager_registration(self):
|
||||
"""Test that custom webhook managers are registered."""
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
webhook_managers = registry.get_webhook_managers_dict()
|
||||
|
||||
# Check if our custom webhook manager is registered
|
||||
assert "custom-webhook-service" in webhook_managers
|
||||
assert webhook_managers["custom-webhook-service"] == CustomWebhookManager
|
||||
|
||||
def test_oauth_block_execution(self):
|
||||
"""Test OAuth-enabled block execution."""
|
||||
block = CustomOAuthBlock()
|
||||
execute_block_test(block)
|
||||
|
||||
def test_webhook_block_execution(self):
|
||||
"""Test webhook-enabled block execution."""
|
||||
block = CustomWebhookBlock()
|
||||
execute_block_test(block)
|
||||
|
||||
def test_full_service_block_execution(self):
|
||||
"""Test full-featured block execution."""
|
||||
block = CustomFullServiceBlock()
|
||||
execute_block_test(block)
|
||||
|
||||
def test_multiple_decorators_on_same_provider(self):
|
||||
"""Test that a single provider can have multiple features."""
|
||||
from backend.sdk.auto_registry import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
|
||||
# Check OAuth handler
|
||||
oauth_handlers = registry.get_oauth_handlers_dict()
|
||||
assert "custom-full-service" in oauth_handlers
|
||||
|
||||
# Check webhook manager
|
||||
webhook_managers = registry.get_webhook_managers_dict()
|
||||
assert "custom-full-service" in webhook_managers
|
||||
|
||||
# Check default credentials
|
||||
default_creds = registry.get_default_credentials_list()
|
||||
full_service_creds = [
|
||||
cred for cred in default_creds if cred.provider == "custom-full-service"
|
||||
]
|
||||
assert len(full_service_creds) >= 1
|
||||
|
||||
# Check cost config
|
||||
block_costs = registry.get_block_costs_dict()
|
||||
assert CustomFullServiceBlock in block_costs
|
||||
|
||||
|
||||
# Main test function
|
||||
def test_custom_provider_advanced_functionality():
|
||||
"""Run all advanced custom provider tests."""
|
||||
test_instance = TestCustomProviderAdvanced()
|
||||
|
||||
test_instance.test_oauth_handler_registration()
|
||||
test_instance.test_webhook_manager_registration()
|
||||
test_instance.test_oauth_block_execution()
|
||||
test_instance.test_webhook_block_execution()
|
||||
test_instance.test_full_service_block_execution()
|
||||
test_instance.test_multiple_decorators_on_same_provider()
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Test that custom providers work with validation."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class TestModel(BaseModel):
|
||||
provider: ProviderName
|
||||
|
||||
|
||||
def test_custom_provider_validation():
|
||||
"""Test that custom provider names are accepted."""
|
||||
# Test with existing provider
|
||||
model1 = TestModel(provider=ProviderName("openai"))
|
||||
assert model1.provider == ProviderName.OPENAI
|
||||
assert model1.provider.value == "openai"
|
||||
|
||||
# Test with custom provider
|
||||
model2 = TestModel(provider=ProviderName("my-custom-provider"))
|
||||
assert model2.provider.value == "my-custom-provider"
|
||||
|
||||
# Test JSON schema
|
||||
schema = TestModel.model_json_schema()
|
||||
provider_schema = schema["properties"]["provider"]
|
||||
|
||||
# Should not have enum constraint
|
||||
assert "enum" not in provider_schema
|
||||
assert provider_schema["type"] == "string"
|
||||
|
||||
print("✅ Custom provider validation works!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_custom_provider_validation()
|
||||
Reference in New Issue
Block a user