mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
added more sdk tests
This commit is contained in:
@@ -5,7 +5,7 @@ This test suite verifies that blocks can be created using only SDK imports
|
||||
and that they work correctly without decorators.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,9 +13,12 @@ from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
ProviderBuilder,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
@@ -434,5 +437,478 @@ class TestComplexBlockScenarios:
|
||||
pass
|
||||
|
||||
|
||||
class TestAuthenticationVariants:
|
||||
"""Test complex authentication scenarios including OAuth, API keys, and scopes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_block_with_scopes(self):
|
||||
"""Test creating a block that uses OAuth2 with scopes."""
|
||||
from backend.sdk import OAuth2Credentials, ProviderBuilder
|
||||
|
||||
# Create a test OAuth provider with scopes
|
||||
# For testing, we don't need an actual OAuth handler
|
||||
# In real usage, you would provide a proper OAuth handler class
|
||||
oauth_provider = (
|
||||
ProviderBuilder("test_oauth_provider")
|
||||
.with_api_key("TEST_OAUTH_API", "Test OAuth API")
|
||||
.with_base_cost(5, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
class OAuthScopedBlock(Block):
|
||||
"""Block requiring OAuth2 with specific scopes."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = oauth_provider.credentials_field(
|
||||
description="OAuth2 credentials with scopes",
|
||||
scopes=["read:user", "write:data"],
|
||||
)
|
||||
resource: str = SchemaField(description="Resource to access")
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: str = SchemaField(description="Retrieved data")
|
||||
scopes_used: list[str] = SchemaField(
|
||||
description="Scopes that were used"
|
||||
)
|
||||
token_info: dict[str, Any] = SchemaField(
|
||||
description="Token information"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="oauth-scoped-block",
|
||||
description="Test OAuth2 with scopes",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=OAuthScopedBlock.Input,
|
||||
output_schema=OAuthScopedBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Simulate OAuth API call with scopes
|
||||
token = credentials.access_token.get_secret_value()
|
||||
|
||||
yield "data", f"OAuth data for {input_data.resource}"
|
||||
yield "scopes_used", credentials.scopes or []
|
||||
yield "token_info", {
|
||||
"has_token": bool(token),
|
||||
"has_refresh": credentials.refresh_token is not None,
|
||||
"provider": credentials.provider,
|
||||
"expires_at": credentials.access_token_expires_at,
|
||||
}
|
||||
|
||||
# Create test OAuth credentials
|
||||
test_oauth_creds = OAuth2Credentials(
|
||||
id="test-oauth-creds",
|
||||
provider="test_oauth_provider",
|
||||
access_token=SecretStr("test-access-token"),
|
||||
refresh_token=SecretStr("test-refresh-token"),
|
||||
scopes=["read:user", "write:data"],
|
||||
title="Test OAuth Credentials",
|
||||
)
|
||||
|
||||
# Test the block
|
||||
block = OAuthScopedBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
OAuthScopedBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "test_oauth_provider",
|
||||
"id": "test-oauth-creds",
|
||||
"type": "oauth2",
|
||||
},
|
||||
resource="user/profile",
|
||||
),
|
||||
credentials=test_oauth_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["data"] == "OAuth data for user/profile"
|
||||
assert set(outputs["scopes_used"]) == {"read:user", "write:data"}
|
||||
assert outputs["token_info"]["has_token"] is True
|
||||
assert outputs["token_info"]["expires_at"] is None
|
||||
assert outputs["token_info"]["has_refresh"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_auth_block(self):
|
||||
"""Test block that supports both OAuth2 and API key authentication."""
|
||||
# No need to import these again, already imported at top
|
||||
|
||||
# Create provider supporting both auth types
|
||||
# Create provider supporting API key auth
|
||||
# In real usage, you would add OAuth support with .with_oauth()
|
||||
mixed_provider = (
|
||||
ProviderBuilder("mixed_auth_provider")
|
||||
.with_api_key("MIXED_API_KEY", "Mixed Provider API Key")
|
||||
.with_base_cost(8, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
|
||||
class MixedAuthBlock(Block):
|
||||
"""Block supporting multiple authentication methods."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = mixed_provider.credentials_field(
|
||||
description="API key or OAuth2 credentials",
|
||||
supported_credential_types=["api_key", "oauth2"],
|
||||
)
|
||||
operation: str = SchemaField(description="Operation to perform")
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Operation result")
|
||||
auth_type: str = SchemaField(description="Authentication type used")
|
||||
auth_details: dict[str, Any] = SchemaField(description="Auth details")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="mixed-auth-block",
|
||||
description="Block supporting OAuth2 and API key",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=MixedAuthBlock.Input,
|
||||
output_schema=MixedAuthBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: Union[APIKeyCredentials, OAuth2Credentials],
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Handle different credential types
|
||||
if isinstance(credentials, APIKeyCredentials):
|
||||
auth_type = "api_key"
|
||||
auth_details = {
|
||||
"has_key": bool(credentials.api_key.get_secret_value()),
|
||||
"key_prefix": credentials.api_key.get_secret_value()[:5]
|
||||
+ "...",
|
||||
}
|
||||
elif isinstance(credentials, OAuth2Credentials):
|
||||
auth_type = "oauth2"
|
||||
auth_details = {
|
||||
"has_token": bool(credentials.access_token.get_secret_value()),
|
||||
"scopes": credentials.scopes or [],
|
||||
}
|
||||
else:
|
||||
auth_type = "unknown"
|
||||
auth_details = {}
|
||||
|
||||
yield "result", f"Performed {input_data.operation} with {auth_type}"
|
||||
yield "auth_type", auth_type
|
||||
yield "auth_details", auth_details
|
||||
|
||||
# Test with API key
|
||||
api_creds = APIKeyCredentials(
|
||||
id="mixed-api-creds",
|
||||
provider="mixed_auth_provider",
|
||||
api_key=SecretStr("sk-1234567890"),
|
||||
title="Mixed API Key",
|
||||
)
|
||||
|
||||
block = MixedAuthBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
MixedAuthBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "mixed_auth_provider",
|
||||
"id": "mixed-api-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
operation="fetch_data",
|
||||
),
|
||||
credentials=api_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["auth_type"] == "api_key"
|
||||
assert outputs["result"] == "Performed fetch_data with api_key"
|
||||
assert outputs["auth_details"]["key_prefix"] == "sk-12..."
|
||||
|
||||
# Test with OAuth2
|
||||
oauth_creds = OAuth2Credentials(
|
||||
id="mixed-oauth-creds",
|
||||
provider="mixed_auth_provider",
|
||||
access_token=SecretStr("oauth-token-123"),
|
||||
scopes=["full_access"],
|
||||
title="Mixed OAuth",
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
MixedAuthBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "mixed_auth_provider",
|
||||
"id": "mixed-oauth-creds",
|
||||
"type": "oauth2",
|
||||
},
|
||||
operation="update_data",
|
||||
),
|
||||
credentials=oauth_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["auth_type"] == "oauth2"
|
||||
assert outputs["result"] == "Performed update_data with oauth2"
|
||||
assert outputs["auth_details"]["scopes"] == ["full_access"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_credentials_block(self):
|
||||
"""Test block requiring multiple different credentials."""
|
||||
from backend.sdk import ProviderBuilder
|
||||
|
||||
# Create multiple providers
|
||||
primary_provider = (
|
||||
ProviderBuilder("primary_service")
|
||||
.with_api_key("PRIMARY_API_KEY", "Primary Service Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
# For testing purposes, using API key instead of OAuth handler
|
||||
secondary_provider = (
|
||||
ProviderBuilder("secondary_service")
|
||||
.with_api_key("SECONDARY_API_KEY", "Secondary Service Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
class MultiCredentialBlock(Block):
|
||||
"""Block requiring credentials from multiple services."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
primary_credentials: CredentialsMetaInput = (
|
||||
primary_provider.credentials_field(
|
||||
description="Primary service API key"
|
||||
)
|
||||
)
|
||||
secondary_credentials: CredentialsMetaInput = (
|
||||
secondary_provider.credentials_field(
|
||||
description="Secondary service OAuth"
|
||||
)
|
||||
)
|
||||
merge_data: bool = SchemaField(
|
||||
description="Whether to merge data from both services",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
primary_data: str = SchemaField(description="Data from primary service")
|
||||
secondary_data: str = SchemaField(
|
||||
description="Data from secondary service"
|
||||
)
|
||||
merged_result: Optional[str] = SchemaField(
|
||||
description="Merged data if requested"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="multi-credential-block",
|
||||
description="Block using multiple credentials",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=MultiCredentialBlock.Input,
|
||||
output_schema=MultiCredentialBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
primary_credentials: APIKeyCredentials,
|
||||
secondary_credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Simulate fetching data with primary API key
|
||||
primary_data = f"Primary data using {primary_credentials.provider}"
|
||||
yield "primary_data", primary_data
|
||||
|
||||
# Simulate fetching data with secondary OAuth
|
||||
secondary_data = f"Secondary data with {len(secondary_credentials.scopes or [])} scopes"
|
||||
yield "secondary_data", secondary_data
|
||||
|
||||
# Merge if requested
|
||||
if input_data.merge_data:
|
||||
merged = f"{primary_data} + {secondary_data}"
|
||||
yield "merged_result", merged
|
||||
else:
|
||||
yield "merged_result", None
|
||||
|
||||
# Create test credentials
|
||||
primary_creds = APIKeyCredentials(
|
||||
id="primary-creds",
|
||||
provider="primary_service",
|
||||
api_key=SecretStr("primary-key-123"),
|
||||
title="Primary Key",
|
||||
)
|
||||
|
||||
secondary_creds = OAuth2Credentials(
|
||||
id="secondary-creds",
|
||||
provider="secondary_service",
|
||||
access_token=SecretStr("secondary-token"),
|
||||
scopes=["read", "write"],
|
||||
title="Secondary OAuth",
|
||||
)
|
||||
|
||||
# Test the block
|
||||
block = MultiCredentialBlock()
|
||||
outputs = {}
|
||||
|
||||
# Note: In real usage, the framework would inject the correct credentials
|
||||
# based on the field names. Here we simulate that behavior.
|
||||
async for name, value in block.run(
|
||||
MultiCredentialBlock.Input(
|
||||
primary_credentials={ # type: ignore
|
||||
"provider": "primary_service",
|
||||
"id": "primary-creds",
|
||||
"type": "api_key",
|
||||
},
|
||||
secondary_credentials={ # type: ignore
|
||||
"provider": "secondary_service",
|
||||
"id": "secondary-creds",
|
||||
"type": "oauth2",
|
||||
},
|
||||
merge_data=True,
|
||||
),
|
||||
primary_credentials=primary_creds,
|
||||
secondary_credentials=secondary_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["primary_data"] == "Primary data using primary_service"
|
||||
assert outputs["secondary_data"] == "Secondary data with 2 scopes"
|
||||
assert "Primary data" in outputs["merged_result"]
|
||||
assert "Secondary data" in outputs["merged_result"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_scope_validation(self):
|
||||
"""Test OAuth scope validation and handling."""
|
||||
from backend.sdk import OAuth2Credentials, ProviderBuilder
|
||||
|
||||
# Provider with specific required scopes
|
||||
# For testing OAuth scope validation
|
||||
scoped_provider = (
|
||||
ProviderBuilder("scoped_oauth_service")
|
||||
.with_api_key("SCOPED_OAUTH_KEY", "Scoped OAuth Service")
|
||||
.build()
|
||||
)
|
||||
|
||||
class ScopeValidationBlock(Block):
|
||||
"""Block that validates OAuth scopes."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = scoped_provider.credentials_field(
|
||||
description="OAuth credentials with specific scopes",
|
||||
scopes=["user:read", "user:write"], # Required scopes
|
||||
)
|
||||
require_admin: bool = SchemaField(
|
||||
description="Whether admin scopes are required",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
allowed_operations: list[str] = SchemaField(
|
||||
description="Operations allowed with current scopes"
|
||||
)
|
||||
missing_scopes: list[str] = SchemaField(
|
||||
description="Scopes that are missing for full access"
|
||||
)
|
||||
has_required_scopes: bool = SchemaField(
|
||||
description="Whether all required scopes are present"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="scope-validation-block",
|
||||
description="Block that validates OAuth scopes",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ScopeValidationBlock.Input,
|
||||
output_schema=ScopeValidationBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
current_scopes = set(credentials.scopes or [])
|
||||
required_scopes = {"user:read", "user:write"}
|
||||
|
||||
if input_data.require_admin:
|
||||
required_scopes.update({"admin:read", "admin:write"})
|
||||
|
||||
# Determine allowed operations based on scopes
|
||||
allowed_ops = []
|
||||
if "user:read" in current_scopes:
|
||||
allowed_ops.append("read_user_data")
|
||||
if "user:write" in current_scopes:
|
||||
allowed_ops.append("update_user_data")
|
||||
if "admin:read" in current_scopes:
|
||||
allowed_ops.append("read_admin_data")
|
||||
if "admin:write" in current_scopes:
|
||||
allowed_ops.append("update_admin_data")
|
||||
|
||||
missing = list(required_scopes - current_scopes)
|
||||
has_required = len(missing) == 0
|
||||
|
||||
yield "allowed_operations", allowed_ops
|
||||
yield "missing_scopes", missing
|
||||
yield "has_required_scopes", has_required
|
||||
|
||||
# Test with partial scopes
|
||||
partial_creds = OAuth2Credentials(
|
||||
id="partial-oauth",
|
||||
provider="scoped_oauth_service",
|
||||
access_token=SecretStr("partial-token"),
|
||||
scopes=["user:read"], # Only one of the required scopes
|
||||
title="Partial OAuth",
|
||||
)
|
||||
|
||||
block = ScopeValidationBlock()
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
ScopeValidationBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "scoped_oauth_service",
|
||||
"id": "partial-oauth",
|
||||
"type": "oauth2",
|
||||
},
|
||||
require_admin=False,
|
||||
),
|
||||
credentials=partial_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert outputs["allowed_operations"] == ["read_user_data"]
|
||||
assert "user:write" in outputs["missing_scopes"]
|
||||
assert outputs["has_required_scopes"] is False
|
||||
|
||||
# Test with all required scopes
|
||||
full_creds = OAuth2Credentials(
|
||||
id="full-oauth",
|
||||
provider="scoped_oauth_service",
|
||||
access_token=SecretStr("full-token"),
|
||||
scopes=["user:read", "user:write", "admin:read"],
|
||||
title="Full OAuth",
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
ScopeValidationBlock.Input(
|
||||
credentials={ # type: ignore
|
||||
"provider": "scoped_oauth_service",
|
||||
"id": "full-oauth",
|
||||
"type": "oauth2",
|
||||
},
|
||||
require_admin=False,
|
||||
),
|
||||
credentials=full_creds,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
assert set(outputs["allowed_operations"]) == {
|
||||
"read_user_data",
|
||||
"update_user_data",
|
||||
"read_admin_data",
|
||||
}
|
||||
assert outputs["missing_scopes"] == []
|
||||
assert outputs["has_required_scopes"] is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user