mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(classic): use platform API for blocks instead of local loading
Simplify the platform_blocks component to fetch blocks from the platform API (/api/v1/blocks) instead of loading them locally from the monorepo. This removes the dependency on having the platform backend code available. - Remove loader.py (no longer needed) - Update client.py with list_blocks() method - Simplify component.py to use API for both search and execute - Remove user_id from config (not needed by API) - Update tests for API-based approach Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""HTTP client for platform API - used for block execution.
|
||||
"""HTTP client for platform API.
|
||||
|
||||
This client handles communication with the AutoGPT Platform API,
|
||||
which manages credentials and executes blocks with proper authentication.
|
||||
for listing blocks, executing them, and managing credentials.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -21,7 +21,7 @@ class PlatformClientError(Exception):
|
||||
|
||||
|
||||
class PlatformClient:
|
||||
"""Client for platform.agpt.co API - used for block execution."""
|
||||
"""Client for platform.agpt.co API."""
|
||||
|
||||
def __init__(self, base_url: str, api_key: str, timeout: int = 60):
|
||||
"""Initialize the platform client.
|
||||
@@ -42,91 +42,16 @@ class PlatformClient:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
return headers
|
||||
|
||||
async def execute_block(
|
||||
self,
|
||||
block_id: str,
|
||||
input_data: dict[str, Any],
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a block via platform API.
|
||||
|
||||
Args:
|
||||
block_id: The block ID to execute.
|
||||
input_data: Input data matching the block's input schema.
|
||||
user_id: User ID for credential resolution.
|
||||
async def list_blocks(self) -> list[dict[str, Any]]:
|
||||
"""List all available blocks from platform API.
|
||||
|
||||
Returns:
|
||||
Execution result with outputs.
|
||||
List of block dictionaries with id, name, description, schemas.
|
||||
|
||||
Raises:
|
||||
PlatformClientError: If the API request fails.
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/blocks/{block_id}/execute"
|
||||
payload = {"input_data": input_data, "user_id": user_id}
|
||||
|
||||
async with aiohttp.ClientSession(timeout=self.timeout) as session:
|
||||
try:
|
||||
async with session.post(
|
||||
url, headers=self._headers(), json=payload
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
error_text = await resp.text()
|
||||
raise PlatformClientError(
|
||||
f"Platform API error: {error_text}",
|
||||
status_code=resp.status,
|
||||
)
|
||||
return await resp.json()
|
||||
except aiohttp.ClientError as e:
|
||||
raise PlatformClientError(f"Connection error: {e}") from e
|
||||
|
||||
async def check_credentials(
|
||||
self,
|
||||
block_id: str,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Check if user has required credentials for a block.
|
||||
|
||||
Args:
|
||||
block_id: The block ID to check.
|
||||
user_id: User ID for credential lookup.
|
||||
|
||||
Returns:
|
||||
Credential check result with has_required_credentials and missing list.
|
||||
|
||||
Raises:
|
||||
PlatformClientError: If the API request fails.
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/blocks/{block_id}/credentials/check"
|
||||
params = {"user_id": user_id}
|
||||
|
||||
async with aiohttp.ClientSession(timeout=self.timeout) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
url, headers=self._headers(), params=params
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
error_text = await resp.text()
|
||||
raise PlatformClientError(
|
||||
f"Platform API error: {error_text}",
|
||||
status_code=resp.status,
|
||||
)
|
||||
return await resp.json()
|
||||
except aiohttp.ClientError as e:
|
||||
raise PlatformClientError(f"Connection error: {e}") from e
|
||||
|
||||
async def get_block_info(self, block_id: str) -> dict[str, Any]:
|
||||
"""Get block information from platform API.
|
||||
|
||||
Args:
|
||||
block_id: The block ID to get info for.
|
||||
|
||||
Returns:
|
||||
Block information including schema and description.
|
||||
|
||||
Raises:
|
||||
PlatformClientError: If the API request fails.
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/blocks/{block_id}"
|
||||
url = f"{self.base_url}/api/v1/blocks"
|
||||
|
||||
async with aiohttp.ClientSession(timeout=self.timeout) as session:
|
||||
try:
|
||||
@@ -140,3 +65,37 @@ class PlatformClient:
|
||||
return await resp.json()
|
||||
except aiohttp.ClientError as e:
|
||||
raise PlatformClientError(f"Connection error: {e}") from e
|
||||
|
||||
async def execute_block(
|
||||
self,
|
||||
block_id: str,
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a block via platform API.
|
||||
|
||||
Args:
|
||||
block_id: The block ID to execute.
|
||||
input_data: Input data matching the block's input schema.
|
||||
|
||||
Returns:
|
||||
Execution result with outputs.
|
||||
|
||||
Raises:
|
||||
PlatformClientError: If the API request fails.
|
||||
"""
|
||||
url = f"{self.base_url}/api/v1/blocks/{block_id}/execute"
|
||||
|
||||
async with aiohttp.ClientSession(timeout=self.timeout) as session:
|
||||
try:
|
||||
async with session.post(
|
||||
url, headers=self._headers(), json=input_data
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
error_text = await resp.text()
|
||||
raise PlatformClientError(
|
||||
f"Platform API error: {error_text}",
|
||||
status_code=resp.status,
|
||||
)
|
||||
return await resp.json()
|
||||
except aiohttp.ClientError as e:
|
||||
raise PlatformClientError(f"Connection error: {e}") from e
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Platform blocks component for classic agents.
|
||||
|
||||
Provides search_blocks and execute_block commands:
|
||||
- search_blocks: Uses local block registry (fast, offline)
|
||||
- execute_block: Uses platform API (handles credentials)
|
||||
Provides search_blocks and execute_block commands that call the platform API.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -14,7 +12,6 @@ from forge.agent.protocols import CommandProvider, DirectiveProvider
|
||||
from forge.command import Command, command
|
||||
from forge.models.json_schema import JSONSchema
|
||||
|
||||
from . import loader
|
||||
from .client import PlatformClient, PlatformClientError
|
||||
from .config import PlatformBlocksConfig
|
||||
|
||||
@@ -26,24 +23,14 @@ class PlatformBlocksComponent(
|
||||
CommandProvider,
|
||||
ConfigurableComponent[PlatformBlocksConfig],
|
||||
):
|
||||
"""Provides search_blocks and execute_block commands.
|
||||
|
||||
- search_blocks: Uses local block registry (fast, offline)
|
||||
- execute_block: Uses platform API (handles credentials)
|
||||
"""
|
||||
"""Provides search_blocks and execute_block commands via platform API."""
|
||||
|
||||
config_class = PlatformBlocksConfig
|
||||
|
||||
def __init__(self, config: PlatformBlocksConfig | None = None):
|
||||
ConfigurableComponent.__init__(self, config)
|
||||
self._client: PlatformClient | None = None
|
||||
self._platform_available = loader.is_platform_available()
|
||||
|
||||
if not self._platform_available:
|
||||
logger.warning(
|
||||
"Platform blocks not available - "
|
||||
"install autogpt_platform or add to PYTHONPATH"
|
||||
)
|
||||
self._blocks_cache: list[dict[str, Any]] | None = None
|
||||
|
||||
@property
|
||||
def client(self) -> PlatformClient:
|
||||
@@ -58,25 +45,32 @@ class PlatformBlocksComponent(
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
"""Describe available resources."""
|
||||
if self.config.enabled and self._platform_available:
|
||||
try:
|
||||
block_count = len(loader.load_blocks())
|
||||
yield (
|
||||
f"Access to {block_count} platform blocks via search_blocks "
|
||||
"and execute_block commands."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not count blocks: {e}")
|
||||
if self.config.enabled:
|
||||
yield (
|
||||
"Access to platform blocks via search_blocks and execute_block "
|
||||
"commands. Use search_blocks first to discover available blocks."
|
||||
)
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
"""Provide available commands."""
|
||||
if not self.config.enabled:
|
||||
return
|
||||
if not self._platform_available:
|
||||
return
|
||||
yield self.search_blocks
|
||||
yield self.execute_block
|
||||
|
||||
async def _get_blocks(self) -> list[dict[str, Any]]:
|
||||
"""Get blocks from API, with caching."""
|
||||
if self._blocks_cache is not None:
|
||||
return self._blocks_cache
|
||||
|
||||
try:
|
||||
self._blocks_cache = await self.client.list_blocks()
|
||||
logger.info(f"Loaded {len(self._blocks_cache)} blocks from platform API")
|
||||
return self._blocks_cache
|
||||
except PlatformClientError as e:
|
||||
logger.error(f"Failed to load blocks from API: {e}")
|
||||
return []
|
||||
|
||||
@command(
|
||||
names=["search_blocks", "find_block"],
|
||||
description=(
|
||||
@@ -92,8 +86,8 @@ class PlatformBlocksComponent(
|
||||
),
|
||||
},
|
||||
)
|
||||
def search_blocks(self, query: str) -> str:
|
||||
"""Search blocks locally (fast, no network call).
|
||||
async def search_blocks(self, query: str) -> str:
|
||||
"""Search blocks via platform API.
|
||||
|
||||
Args:
|
||||
query: Search query for finding blocks.
|
||||
@@ -102,7 +96,35 @@ class PlatformBlocksComponent(
|
||||
JSON string with search results.
|
||||
"""
|
||||
try:
|
||||
results = loader.search_blocks(query, limit=20)
|
||||
blocks = await self._get_blocks()
|
||||
query_lower = query.lower()
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for block in blocks:
|
||||
name = block.get("name", "")
|
||||
description = block.get("description", "")
|
||||
categories = [
|
||||
c.get("category", "") for c in block.get("categories", [])
|
||||
]
|
||||
|
||||
# Check for match
|
||||
name_match = query_lower in name.lower()
|
||||
desc_match = query_lower in description.lower()
|
||||
cat_match = any(query_lower in c.lower() for c in categories)
|
||||
|
||||
if name_match or desc_match or cat_match:
|
||||
results.append(
|
||||
{
|
||||
"id": block.get("id"),
|
||||
"name": name,
|
||||
"description": description,
|
||||
"categories": categories,
|
||||
"input_schema": block.get("inputSchema", {}),
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) >= 20:
|
||||
break
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
@@ -120,8 +142,7 @@ class PlatformBlocksComponent(
|
||||
names=["execute_block", "run_block"],
|
||||
description=(
|
||||
"Execute a platform block by ID with input data. "
|
||||
"IMPORTANT: Use search_blocks FIRST to get the block ID and schema. "
|
||||
"Credentials are automatically resolved via platform API."
|
||||
"IMPORTANT: Use search_blocks FIRST to get the block ID and schema."
|
||||
),
|
||||
parameters={
|
||||
"block_id": JSONSchema(
|
||||
@@ -146,43 +167,24 @@ class PlatformBlocksComponent(
|
||||
Returns:
|
||||
JSON string with execution result.
|
||||
"""
|
||||
user_id = self.config.user_id or "classic_agent"
|
||||
|
||||
try:
|
||||
# Get block info locally for better error messages
|
||||
block = loader.get_block(block_id)
|
||||
block_name = getattr(block, "name", block_id) if block else block_id
|
||||
|
||||
# Check credentials first
|
||||
try:
|
||||
cred_check = await self.client.check_credentials(block_id, user_id)
|
||||
if not cred_check.get("has_required_credentials", True):
|
||||
missing = cred_check.get("missing_credentials", [])
|
||||
return json.dumps(
|
||||
{
|
||||
"error": "Missing required credentials",
|
||||
"block": block_name,
|
||||
"missing_credentials": missing,
|
||||
"message": (
|
||||
"Please configure the required credentials at "
|
||||
f"{self.config.platform_url}/settings/credentials"
|
||||
),
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
except PlatformClientError as e:
|
||||
logger.warning(f"Could not check credentials: {e}")
|
||||
# Continue anyway - execution will fail if creds are missing
|
||||
# Get block name for better error messages
|
||||
blocks = await self._get_blocks()
|
||||
block_name = block_id
|
||||
for block in blocks:
|
||||
if block.get("id") == block_id:
|
||||
block_name = block.get("name", block_id)
|
||||
break
|
||||
|
||||
# Execute the block
|
||||
result = await self.client.execute_block(block_id, input_data, user_id)
|
||||
result = await self.client.execute_block(block_id, input_data)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"block": block_name,
|
||||
"block_id": block_id,
|
||||
"outputs": result.get("outputs", {}),
|
||||
"outputs": result,
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@@ -9,8 +9,7 @@ class PlatformBlocksConfig(BaseModel):
|
||||
)
|
||||
platform_url: str = Field(
|
||||
default="https://platform.agpt.co",
|
||||
description="Platform API URL for execution",
|
||||
description="Platform API base URL",
|
||||
)
|
||||
api_key: str = Field(default="", description="Platform API key for authentication")
|
||||
user_id: str = Field(default="", description="User ID for credential lookup")
|
||||
timeout: int = Field(default=60, description="Execution timeout in seconds")
|
||||
timeout: int = Field(default=60, description="Request timeout in seconds")
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
"""Local block loading from platform codebase.
|
||||
|
||||
This module provides functions to load and search platform blocks locally,
|
||||
without making network calls. The platform backend is automatically discovered
|
||||
from the monorepo structure.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_blocks_cache: dict[str, type] | None = None
|
||||
_platform_path_added = False
|
||||
|
||||
|
||||
def _get_platform_backend_path() -> Path | None:
|
||||
"""Find the platform backend path relative to this file.
|
||||
|
||||
The monorepo structure is:
|
||||
AutoGPT/main/
|
||||
├── classic/forge/forge/components/platform_blocks/loader.py (this file)
|
||||
└── autogpt_platform/backend/ (platform backend)
|
||||
"""
|
||||
# This file is at: classic/forge/forge/components/platform_blocks/loader.py
|
||||
# Go up to classic/, then up to main/, then into autogpt_platform/backend/
|
||||
this_file = Path(__file__).resolve()
|
||||
classic_dir = this_file.parent.parent.parent.parent.parent.parent # classic/
|
||||
main_dir = classic_dir.parent # main/
|
||||
platform_backend = main_dir / "autogpt_platform" / "backend"
|
||||
|
||||
if platform_backend.exists() and (platform_backend / "backend").exists():
|
||||
return platform_backend
|
||||
return None
|
||||
|
||||
|
||||
def _ensure_platform_path() -> bool:
|
||||
"""Add platform backend to sys.path if not already present."""
|
||||
global _platform_path_added
|
||||
if _platform_path_added:
|
||||
return True
|
||||
|
||||
platform_path = _get_platform_backend_path()
|
||||
if platform_path is None:
|
||||
logger.debug("Platform backend not found in monorepo structure")
|
||||
return False
|
||||
|
||||
path_str = str(platform_path)
|
||||
if path_str not in sys.path:
|
||||
sys.path.insert(0, path_str)
|
||||
logger.debug(f"Added platform backend to path: {path_str}")
|
||||
|
||||
_platform_path_added = True
|
||||
return True
|
||||
|
||||
|
||||
def is_platform_available() -> bool:
|
||||
"""Check if platform blocks can be imported."""
|
||||
_ensure_platform_path()
|
||||
try:
|
||||
from backend.blocks import ( # pyright: ignore[reportMissingImports]
|
||||
load_all_blocks,
|
||||
)
|
||||
|
||||
_ = load_all_blocks # Silence unused import warning
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def load_blocks() -> dict[str, type]:
|
||||
"""Load all blocks from platform codebase.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping block IDs to block classes.
|
||||
"""
|
||||
global _blocks_cache
|
||||
if _blocks_cache is not None:
|
||||
return _blocks_cache
|
||||
|
||||
_ensure_platform_path()
|
||||
|
||||
try:
|
||||
from backend.blocks import ( # pyright: ignore[reportMissingImports]
|
||||
load_all_blocks,
|
||||
)
|
||||
|
||||
loaded: dict[str, type] = load_all_blocks()
|
||||
_blocks_cache = loaded
|
||||
logger.info(f"Loaded {len(loaded)} platform blocks")
|
||||
return loaded
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import platform blocks: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading platform blocks: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def get_block(block_id: str) -> Any | None:
|
||||
"""Get a specific block instance by ID.
|
||||
|
||||
Args:
|
||||
block_id: The unique block ID (UUID format).
|
||||
|
||||
Returns:
|
||||
Block instance or None if not found.
|
||||
"""
|
||||
blocks = load_blocks()
|
||||
block_cls = blocks.get(block_id)
|
||||
if block_cls:
|
||||
return block_cls()
|
||||
return None
|
||||
|
||||
|
||||
def search_blocks(query: str, limit: int = 20) -> list[dict[str, Any]]:
|
||||
"""Search blocks by name or description.
|
||||
|
||||
Args:
|
||||
query: Search query (case-insensitive).
|
||||
limit: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of block info dictionaries.
|
||||
"""
|
||||
blocks = load_blocks()
|
||||
results: list[dict[str, Any]] = []
|
||||
query_lower = query.lower()
|
||||
|
||||
for block_id, block_cls in blocks.items():
|
||||
try:
|
||||
block = block_cls()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not instantiate block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
# Skip disabled blocks
|
||||
if getattr(block, "disabled", False):
|
||||
continue
|
||||
|
||||
# Get name and description
|
||||
name = getattr(block, "name", block_cls.__name__)
|
||||
description = getattr(block, "description", "")
|
||||
|
||||
# Check name and description for match
|
||||
name_match = query_lower in name.lower()
|
||||
desc_match = query_lower in description.lower()
|
||||
|
||||
# Check categories
|
||||
categories = []
|
||||
if hasattr(block, "categories"):
|
||||
categories = [c.value for c in block.categories]
|
||||
category_match = any(query_lower in c.lower() for c in categories)
|
||||
|
||||
if name_match or desc_match or category_match:
|
||||
# Get input schema
|
||||
input_schema: dict[str, Any] = {}
|
||||
if hasattr(block, "input_schema"):
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get output schema
|
||||
output_schema: dict[str, Any] = {}
|
||||
if hasattr(block, "output_schema"):
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
results.append(
|
||||
{
|
||||
"id": block_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"categories": categories,
|
||||
"input_schema": input_schema,
|
||||
"output_schema": output_schema,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_block_info(block_id: str) -> dict[str, Any] | None:
|
||||
"""Get detailed information about a specific block.
|
||||
|
||||
Args:
|
||||
block_id: The unique block ID (UUID format).
|
||||
|
||||
Returns:
|
||||
Block info dictionary or None if not found.
|
||||
"""
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return None
|
||||
|
||||
name = getattr(block, "name", block.__class__.__name__)
|
||||
description = getattr(block, "description", "")
|
||||
|
||||
categories = []
|
||||
if hasattr(block, "categories"):
|
||||
categories = [c.value for c in block.categories]
|
||||
|
||||
input_schema: dict[str, Any] = {}
|
||||
if hasattr(block, "input_schema"):
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
output_schema: dict[str, Any] = {}
|
||||
if hasattr(block, "output_schema"):
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"id": block_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"categories": categories,
|
||||
"input_schema": input_schema,
|
||||
"output_schema": output_schema,
|
||||
}
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
"""Clear the blocks cache. Useful for testing."""
|
||||
global _blocks_cache, _platform_path_added
|
||||
_blocks_cache = None
|
||||
_platform_path_added = False
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for PlatformBlocksComponent."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,302 +13,206 @@ from forge.components.platform_blocks.client import PlatformClient, PlatformClie
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_blocks():
|
||||
"""Create mock block classes for testing."""
|
||||
|
||||
class MockInputSchema:
|
||||
@classmethod
|
||||
def jsonschema(cls):
|
||||
return {
|
||||
def mock_blocks_response():
|
||||
"""Mock response from platform API /blocks endpoint."""
|
||||
return [
|
||||
{
|
||||
"id": "email-block-id",
|
||||
"name": "SendEmailBlock",
|
||||
"description": "Send an email message",
|
||||
"categories": [{"category": "COMMUNICATION"}],
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "Input text"},
|
||||
},
|
||||
"required": ["text"],
|
||||
}
|
||||
|
||||
class MockOutputSchema:
|
||||
@classmethod
|
||||
def jsonschema(cls):
|
||||
return {
|
||||
"properties": {"to": {"type": "string"}, "body": {"type": "string"}},
|
||||
},
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
class MockEmailBlock:
|
||||
def __init__(self):
|
||||
self.name = "SendEmailBlock"
|
||||
self.description = "Send an email message"
|
||||
self.categories = [MagicMock(value="Communication")]
|
||||
self.disabled = False
|
||||
self.input_schema = MockInputSchema
|
||||
self.output_schema = MockOutputSchema
|
||||
|
||||
class MockSearchBlock:
|
||||
def __init__(self):
|
||||
self.name = "WebSearchBlock"
|
||||
self.description = "Search the web for information"
|
||||
self.categories = [MagicMock(value="Search")]
|
||||
self.disabled = False
|
||||
self.input_schema = MockInputSchema
|
||||
self.output_schema = MockOutputSchema
|
||||
|
||||
class MockDisabledBlock:
|
||||
def __init__(self):
|
||||
self.name = "DisabledBlock"
|
||||
self.description = "A disabled block"
|
||||
self.categories = []
|
||||
self.disabled = True
|
||||
self.input_schema = MockInputSchema
|
||||
self.output_schema = MockOutputSchema
|
||||
|
||||
return {
|
||||
"email-block-id": MockEmailBlock,
|
||||
"search-block-id": MockSearchBlock,
|
||||
"disabled-block-id": MockDisabledBlock,
|
||||
}
|
||||
"properties": {"success": {"type": "boolean"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "search-block-id",
|
||||
"name": "WebSearchBlock",
|
||||
"description": "Search the web for information",
|
||||
"categories": [{"category": "SEARCH"}],
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
},
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"results": {"type": "array"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "ai-block-id",
|
||||
"name": "AITextGeneratorBlock",
|
||||
"description": "Generate text using AI",
|
||||
"categories": [{"category": "AI"}],
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"prompt": {"type": "string"}},
|
||||
},
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def component_with_mocks(mock_blocks):
|
||||
"""Create a PlatformBlocksComponent with mocked loader."""
|
||||
with (
|
||||
patch(
|
||||
"forge.components.platform_blocks.loader.is_platform_available",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"forge.components.platform_blocks.loader.load_blocks",
|
||||
return_value=mock_blocks,
|
||||
),
|
||||
):
|
||||
yield PlatformBlocksComponent()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def component_unavailable():
|
||||
"""Create a PlatformBlocksComponent when platform is unavailable."""
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.is_platform_available",
|
||||
return_value=False,
|
||||
):
|
||||
yield PlatformBlocksComponent()
|
||||
|
||||
|
||||
class TestPlatformAvailability:
|
||||
"""Tests for platform availability handling."""
|
||||
|
||||
def test_component_disabled_when_platform_unavailable(self, component_unavailable):
|
||||
"""Component should yield no commands when platform unavailable."""
|
||||
commands = list(component_unavailable.get_commands())
|
||||
assert len(commands) == 0
|
||||
|
||||
def test_component_enabled_when_platform_available(self, component_with_mocks):
|
||||
"""Component should yield commands when platform is available."""
|
||||
commands = list(component_with_mocks.get_commands())
|
||||
assert len(commands) == 2
|
||||
|
||||
def test_get_resources_when_unavailable(self, component_unavailable):
|
||||
"""Should not yield resources when platform unavailable."""
|
||||
resources = list(component_unavailable.get_resources())
|
||||
assert len(resources) == 0
|
||||
|
||||
def test_get_resources_when_available(self, component_with_mocks):
|
||||
"""Should yield resource info when platform available."""
|
||||
resources = list(component_with_mocks.get_resources())
|
||||
assert len(resources) == 1
|
||||
assert "3" in resources[0] # 3 blocks loaded
|
||||
assert "search_blocks" in resources[0]
|
||||
def component():
|
||||
"""Create a PlatformBlocksComponent for testing."""
|
||||
return PlatformBlocksComponent()
|
||||
|
||||
|
||||
class TestSearchBlocks:
|
||||
"""Tests for the search_blocks command."""
|
||||
|
||||
def test_search_by_name(self, component_with_mocks, mock_blocks):
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_name(self, component, mock_blocks_response):
|
||||
"""Search should find blocks by name."""
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.load_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
result_json = component_with_mocks.search_blocks("email")
|
||||
result = json.loads(result_json)
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
|
||||
assert result["count"] == 1
|
||||
assert result["blocks"][0]["name"] == "SendEmailBlock"
|
||||
assert result["blocks"][0]["id"] == "email-block-id"
|
||||
result_json = await component.search_blocks("email")
|
||||
result = json.loads(result_json)
|
||||
|
||||
def test_search_by_description(self, component_with_mocks, mock_blocks):
|
||||
assert result["count"] == 1
|
||||
assert result["blocks"][0]["name"] == "SendEmailBlock"
|
||||
assert result["blocks"][0]["id"] == "email-block-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_description(self, component, mock_blocks_response):
|
||||
"""Search should find blocks by description."""
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.load_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
result_json = component_with_mocks.search_blocks("web")
|
||||
result = json.loads(result_json)
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
|
||||
assert result["count"] == 1
|
||||
assert result["blocks"][0]["name"] == "WebSearchBlock"
|
||||
result_json = await component.search_blocks("web")
|
||||
result = json.loads(result_json)
|
||||
|
||||
def test_search_by_category(self, component_with_mocks, mock_blocks):
|
||||
assert result["count"] == 1
|
||||
assert result["blocks"][0]["name"] == "WebSearchBlock"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_category(self, component, mock_blocks_response):
|
||||
"""Search should find blocks by category."""
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.load_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
result_json = component_with_mocks.search_blocks("Communication")
|
||||
result = json.loads(result_json)
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
|
||||
assert result["count"] == 1
|
||||
assert result["blocks"][0]["name"] == "SendEmailBlock"
|
||||
result_json = await component.search_blocks("COMMUNICATION")
|
||||
result = json.loads(result_json)
|
||||
|
||||
def test_search_excludes_disabled_blocks(self, component_with_mocks, mock_blocks):
|
||||
"""Search should not return disabled blocks."""
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.load_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
result_json = component_with_mocks.search_blocks("disabled")
|
||||
result = json.loads(result_json)
|
||||
assert result["count"] == 1
|
||||
assert result["blocks"][0]["name"] == "SendEmailBlock"
|
||||
|
||||
assert result["count"] == 0
|
||||
|
||||
def test_search_no_results(self, component_with_mocks, mock_blocks):
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_no_results(self, component, mock_blocks_response):
|
||||
"""Search with no matches should return empty results."""
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.load_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
result_json = component_with_mocks.search_blocks("nonexistent")
|
||||
result = json.loads(result_json)
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
|
||||
assert result["count"] == 0
|
||||
assert result["blocks"] == []
|
||||
result_json = await component.search_blocks("nonexistent")
|
||||
result = json.loads(result_json)
|
||||
|
||||
def test_search_includes_schema(self, component_with_mocks, mock_blocks):
|
||||
assert result["count"] == 0
|
||||
assert result["blocks"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_includes_schema(self, component, mock_blocks_response):
|
||||
"""Search results should include input schema."""
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.load_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
result_json = component_with_mocks.search_blocks("email")
|
||||
result = json.loads(result_json)
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
|
||||
assert "input_schema" in result["blocks"][0]
|
||||
assert "properties" in result["blocks"][0]["input_schema"]
|
||||
result_json = await component.search_blocks("email")
|
||||
result = json.loads(result_json)
|
||||
|
||||
assert "input_schema" in result["blocks"][0]
|
||||
assert "properties" in result["blocks"][0]["input_schema"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_api_error(self, component):
|
||||
"""Search should handle API errors gracefully by returning empty results."""
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.side_effect = PlatformClientError(
|
||||
"Connection failed", status_code=500
|
||||
)
|
||||
|
||||
result_json = await component.search_blocks("test")
|
||||
result = json.loads(result_json)
|
||||
|
||||
# On API error, returns empty results (graceful degradation)
|
||||
assert result["count"] == 0
|
||||
assert result["blocks"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_caches_blocks(self, component, mock_blocks_response):
|
||||
"""Blocks should be cached after first fetch."""
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
|
||||
await component.search_blocks("email")
|
||||
await component.search_blocks("web")
|
||||
|
||||
# Should only call API once
|
||||
assert component._client.list_blocks.call_count == 1
|
||||
|
||||
|
||||
class TestExecuteBlock:
|
||||
"""Tests for the execute_block command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_block_success(self, component_with_mocks, mock_blocks):
|
||||
async def test_execute_block_success(self, component, mock_blocks_response):
|
||||
"""Execute should return success with outputs."""
|
||||
mock_client = AsyncMock(spec=PlatformClient)
|
||||
mock_client.check_credentials.return_value = {
|
||||
"has_required_credentials": True,
|
||||
}
|
||||
mock_client.execute_block.return_value = {
|
||||
"outputs": {"result": "Email sent successfully"},
|
||||
}
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
component._client.execute_block.return_value = {"success": True}
|
||||
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.get_block",
|
||||
return_value=mock_blocks["email-block-id"](),
|
||||
):
|
||||
component_with_mocks._client = mock_client
|
||||
result_json = await component.execute_block(
|
||||
block_id="email-block-id",
|
||||
input_data={"to": "test@example.com", "body": "Hello"},
|
||||
)
|
||||
result = json.loads(result_json)
|
||||
|
||||
result_json = await component_with_mocks.execute_block(
|
||||
block_id="email-block-id",
|
||||
input_data={"text": "Hello world"},
|
||||
)
|
||||
result = json.loads(result_json)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["block"] == "SendEmailBlock"
|
||||
assert result["outputs"]["result"] == "Email sent successfully"
|
||||
assert result["success"] is True
|
||||
assert result["block"] == "SendEmailBlock"
|
||||
assert result["outputs"]["success"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_block_missing_credentials(
|
||||
self, component_with_mocks, mock_blocks
|
||||
):
|
||||
"""Execute should return error when credentials are missing."""
|
||||
mock_client = AsyncMock(spec=PlatformClient)
|
||||
mock_client.check_credentials.return_value = {
|
||||
"has_required_credentials": False,
|
||||
"missing_credentials": ["gmail_oauth"],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.get_block",
|
||||
return_value=mock_blocks["email-block-id"](),
|
||||
):
|
||||
component_with_mocks._client = mock_client
|
||||
|
||||
result_json = await component_with_mocks.execute_block(
|
||||
block_id="email-block-id",
|
||||
input_data={"text": "Hello"},
|
||||
)
|
||||
result = json.loads(result_json)
|
||||
|
||||
assert "error" in result
|
||||
assert result["error"] == "Missing required credentials"
|
||||
assert "gmail_oauth" in result["missing_credentials"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_block_api_error(self, component_with_mocks, mock_blocks):
|
||||
async def test_execute_block_api_error(self, component, mock_blocks_response):
|
||||
"""Execute should handle API errors gracefully."""
|
||||
mock_client = AsyncMock(spec=PlatformClient)
|
||||
mock_client.check_credentials.return_value = {
|
||||
"has_required_credentials": True,
|
||||
}
|
||||
mock_client.execute_block.side_effect = PlatformClientError(
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
component._client.execute_block.side_effect = PlatformClientError(
|
||||
"Block execution failed", status_code=500
|
||||
)
|
||||
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.get_block",
|
||||
return_value=mock_blocks["email-block-id"](),
|
||||
):
|
||||
component_with_mocks._client = mock_client
|
||||
result_json = await component.execute_block(
|
||||
block_id="email-block-id",
|
||||
input_data={"to": "test@example.com"},
|
||||
)
|
||||
result = json.loads(result_json)
|
||||
|
||||
result_json = await component_with_mocks.execute_block(
|
||||
block_id="email-block-id",
|
||||
input_data={"text": "Hello"},
|
||||
)
|
||||
result = json.loads(result_json)
|
||||
|
||||
assert "error" in result
|
||||
assert result["status_code"] == 500
|
||||
assert "error" in result
|
||||
assert result["status_code"] == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_block_credential_check_fails(
|
||||
self, component_with_mocks, mock_blocks
|
||||
):
|
||||
"""Execute should continue when credential check fails."""
|
||||
mock_client = AsyncMock(spec=PlatformClient)
|
||||
mock_client.check_credentials.side_effect = PlatformClientError(
|
||||
"Connection error"
|
||||
async def test_execute_unknown_block(self, component, mock_blocks_response):
|
||||
"""Execute with unknown block ID should still attempt execution."""
|
||||
component._client = AsyncMock(spec=PlatformClient)
|
||||
component._client.list_blocks.return_value = mock_blocks_response
|
||||
component._client.execute_block.return_value = {"result": "ok"}
|
||||
|
||||
result_json = await component.execute_block(
|
||||
block_id="unknown-id",
|
||||
input_data={"test": "data"},
|
||||
)
|
||||
mock_client.execute_block.return_value = {
|
||||
"outputs": {"result": "Success"},
|
||||
}
|
||||
result = json.loads(result_json)
|
||||
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.get_block",
|
||||
return_value=mock_blocks["email-block-id"](),
|
||||
):
|
||||
component_with_mocks._client = mock_client
|
||||
|
||||
result_json = await component_with_mocks.execute_block(
|
||||
block_id="email-block-id",
|
||||
input_data={"text": "Hello"},
|
||||
)
|
||||
result = json.loads(result_json)
|
||||
|
||||
# Should still succeed since execution worked
|
||||
assert result["success"] is True
|
||||
# Should use block_id as name when not found
|
||||
assert result["block"] == "unknown-id"
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
class TestConfiguration:
|
||||
@@ -320,7 +224,6 @@ class TestConfiguration:
|
||||
assert config.enabled is True
|
||||
assert config.platform_url == "https://platform.agpt.co"
|
||||
assert config.api_key == ""
|
||||
assert config.user_id == ""
|
||||
assert config.timeout == 60
|
||||
|
||||
def test_custom_configuration(self):
|
||||
@@ -329,41 +232,33 @@ class TestConfiguration:
|
||||
enabled=False,
|
||||
platform_url="https://dev-builder.agpt.co",
|
||||
api_key="test-key",
|
||||
user_id="test-user",
|
||||
timeout=120,
|
||||
)
|
||||
assert config.enabled is False
|
||||
assert config.platform_url == "https://dev-builder.agpt.co"
|
||||
assert config.api_key == "test-key"
|
||||
assert config.user_id == "test-user"
|
||||
assert config.timeout == 120
|
||||
|
||||
def test_component_respects_disabled_config(self):
|
||||
"""Component should not yield commands when disabled."""
|
||||
with patch(
|
||||
"forge.components.platform_blocks.loader.is_platform_available",
|
||||
return_value=True,
|
||||
):
|
||||
component = PlatformBlocksComponent(
|
||||
config=PlatformBlocksConfig(enabled=False)
|
||||
)
|
||||
commands = list(component.get_commands())
|
||||
assert len(commands) == 0
|
||||
component = PlatformBlocksComponent(config=PlatformBlocksConfig(enabled=False))
|
||||
commands = list(component.get_commands())
|
||||
assert len(commands) == 0
|
||||
|
||||
|
||||
class TestProtocols:
|
||||
"""Tests for protocol implementations."""
|
||||
|
||||
def test_get_commands(self, component_with_mocks):
|
||||
def test_get_commands(self, component):
|
||||
"""CommandProvider.get_commands should yield commands."""
|
||||
commands = list(component_with_mocks.get_commands())
|
||||
commands = list(component.get_commands())
|
||||
command_names = [c.names[0] for c in commands]
|
||||
assert "search_blocks" in command_names
|
||||
assert "execute_block" in command_names
|
||||
|
||||
def test_command_aliases(self, component_with_mocks):
|
||||
def test_command_aliases(self, component):
|
||||
"""Commands should have proper aliases."""
|
||||
commands = list(component_with_mocks.get_commands())
|
||||
commands = list(component.get_commands())
|
||||
|
||||
for cmd in commands:
|
||||
if "search_blocks" in cmd.names:
|
||||
@@ -371,6 +266,12 @@ class TestProtocols:
|
||||
if "execute_block" in cmd.names:
|
||||
assert "run_block" in cmd.names
|
||||
|
||||
def test_get_resources(self, component):
|
||||
"""DirectiveProvider.get_resources should yield resource info."""
|
||||
resources = list(component.get_resources())
|
||||
assert len(resources) == 1
|
||||
assert "search_blocks" in resources[0]
|
||||
|
||||
|
||||
class TestPlatformClient:
|
||||
"""Tests for PlatformClient."""
|
||||
|
||||
Reference in New Issue
Block a user