mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
5 Commits
fix/transc
...
feat/brows
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1c26799867 | ||
|
|
a2ca05611e | ||
|
|
9ae1bda93d | ||
|
|
3fbc695da7 | ||
|
|
a131737f57 |
@@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .browse_web import BrowseWebTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
@@ -50,6 +51,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
# Browser-based browsing for JS-rendered pages (Stagehand + Browserbase)
|
||||
"browse_web": BrowseWebTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
|
||||
227
autogpt_platform/backend/backend/copilot/tools/browse_web.py
Normal file
227
autogpt_platform/backend/backend/copilot/tools/browse_web.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Web browsing tool — navigate real browser sessions to extract page content.
|
||||
|
||||
Uses Stagehand + Browserbase for cloud-based browser execution. Handles
|
||||
JS-rendered pages, SPAs, and dynamic content that web_fetch cannot reach.
|
||||
|
||||
Requires environment variables:
|
||||
STAGEHAND_API_KEY — Browserbase API key
|
||||
STAGEHAND_PROJECT_ID — Browserbase project ID
|
||||
ANTHROPIC_API_KEY — LLM key used by Stagehand for extraction
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import BrowseWebResponse, ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Stagehand uses the LLM internally for natural-language extraction/actions.
|
||||
_STAGEHAND_MODEL = "anthropic/claude-sonnet-4-5-20250929"
|
||||
# Hard cap on extracted content returned to the LLM context.
|
||||
_MAX_CONTENT_CHARS = 50_000
|
||||
# Explicit timeouts for Stagehand browser operations (milliseconds).
|
||||
_GOTO_TIMEOUT_MS = 30_000 # page navigation
|
||||
_EXTRACT_TIMEOUT_MS = 60_000 # LLM extraction
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread-safety patch for Stagehand signal handlers (applied lazily, once).
|
||||
#
|
||||
# Stagehand calls signal.signal() during __init__, which raises ValueError
|
||||
# when called from a non-main thread (e.g. the CoPilot executor thread pool).
|
||||
# We patch _register_signal_handlers to be a no-op outside the main thread.
|
||||
# The patch is applied exactly once per process via double-checked locking.
|
||||
# ---------------------------------------------------------------------------
|
||||
_stagehand_patched = False
|
||||
_patch_lock = threading.Lock()
|
||||
|
||||
|
||||
def _patch_stagehand_once() -> None:
    """Replace Stagehand's signal-handler registration with a thread-aware one.

    The replacement forwards to the original ``_register_signal_handlers``
    only when it runs on the main thread and does nothing elsewhere.

    Call this only after ``import stagehand.main`` is known to succeed. The
    module-level flag is checked both before and after taking ``_patch_lock``,
    so concurrent callers install the replacement at most once.
    """
    global _stagehand_patched

    if not _stagehand_patched:
        with _patch_lock:
            # Re-check under the lock: another thread may have finished first.
            if not _stagehand_patched:
                import stagehand.main  # noqa: PLC0415

                stagehand_cls = stagehand.main.Stagehand
                original_register = stagehand_cls._register_signal_handlers

                def _main_thread_only_register(self: Any) -> None:
                    on_main = threading.current_thread() is threading.main_thread()
                    if on_main:
                        original_register(self)

                stagehand_cls._register_signal_handlers = _main_thread_only_register
                _stagehand_patched = True
|
||||
|
||||
|
||||
class BrowseWebTool(BaseTool):
    """Navigate a URL with a real browser and extract its content.

    Use this instead of ``web_fetch`` when the page requires JavaScript
    to render (SPAs, dashboards, paywalled content with JS checks, etc.).
    """

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered and invoked."""
        return "browse_web"

    @property
    def description(self) -> str:
        """Human/LLM-readable summary of what the tool does."""
        return (
            "Navigate to a URL using a real browser and extract content. "
            "Handles JavaScript-rendered pages and dynamic content that "
            "web_fetch cannot reach. "
            "Specify exactly what to extract via the `instruction` parameter."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema style description of the accepted arguments.

        ``url`` is required; ``instruction`` is optional and falls back to a
        generic "main content" prompt.
        """
        return {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "The HTTP/HTTPS URL to navigate to.",
                },
                "instruction": {
                    "type": "string",
                    "description": (
                        "What to extract from the page. Be specific — e.g. "
                        "'Extract all pricing plans with features and prices', "
                        "'Get the main article text and author', "
                        "'List all navigation links'. "
                        "Defaults to extracting the main page content."
                    ),
                    "default": "Extract the main content of this page.",
                },
            },
            "required": ["url"],
        }

    @property
    def requires_auth(self) -> bool:
        """This tool is only available to authenticated users."""
        return True

    async def _execute(
        self,
        user_id: str | None,  # noqa: ARG002
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Navigate to a URL with a real browser and return extracted content.

        Args:
            user_id: Unused by this tool.
            session: Chat session whose ``session_id`` is echoed in the response.
            **kwargs: ``url`` (required) and ``instruction`` (optional).

        Returns:
            ``BrowseWebResponse`` with the (possibly truncated) extraction on
            success, otherwise an ``ErrorResponse`` whose ``error`` is one of
            ``missing_url``, ``invalid_url``, ``not_configured`` or
            ``browse_failed``. Internal exception text is logged, never
            returned to the caller.
        """
        url: str = (kwargs.get("url") or "").strip()
        instruction: str = (
            kwargs.get("instruction") or "Extract the main content of this page."
        )
        session_id = session.session_id if session else None

        if not url:
            return ErrorResponse(
                message="Please provide a URL to browse.",
                error="missing_url",
                session_id=session_id,
            )

        # URL schemes are case-insensitive (RFC 3986 §3.1), so "HTTPS://..." is
        # as valid as "https://...". Only the short prefix is lowercased.
        if not url[:8].lower().startswith(("http://", "https://")):
            return ErrorResponse(
                message="Only HTTP/HTTPS URLs are supported.",
                error="invalid_url",
                session_id=session_id,
            )

        api_key = os.environ.get("STAGEHAND_API_KEY")
        project_id = os.environ.get("STAGEHAND_PROJECT_ID")
        model_api_key = os.environ.get("ANTHROPIC_API_KEY")

        if not api_key or not project_id:
            return ErrorResponse(
                message=(
                    "Web browsing is not configured on this platform. "
                    "STAGEHAND_API_KEY and STAGEHAND_PROJECT_ID are required."
                ),
                error="not_configured",
                session_id=session_id,
            )

        if not model_api_key:
            return ErrorResponse(
                message=(
                    "Web browsing is not configured: ANTHROPIC_API_KEY is required "
                    "for Stagehand's extraction model."
                ),
                error="not_configured",
                session_id=session_id,
            )

        # Lazy import — Stagehand is an optional heavy dependency.
        # Importing here scopes any ImportError to this tool only, so other
        # tools continue to register and work normally if Stagehand is absent.
        try:
            from stagehand import Stagehand  # noqa: PLC0415
        except ImportError:
            return ErrorResponse(
                message="Web browsing is not available: Stagehand is not installed.",
                error="not_configured",
                session_id=session_id,
            )

        client: Any | None = None
        try:
            # Apply the signal handler patch now that we know stagehand is
            # present. It runs inside the ``try`` so that an incompatible
            # Stagehand layout (no ``stagehand.main`` or no
            # ``_register_signal_handlers``) is reported as a normal tool
            # failure rather than escaping as an unhandled exception.
            _patch_stagehand_once()

            client = Stagehand(
                api_key=api_key,
                project_id=project_id,
                model_name=_STAGEHAND_MODEL,
                model_api_key=model_api_key,
            )
            await client.init()

            page = client.page
            if page is None:
                # Explicit raise instead of ``assert``, which is stripped
                # when Python runs with -O.
                raise RuntimeError("Stagehand page is not initialized")

            await page.goto(url, timeoutMs=_GOTO_TIMEOUT_MS)
            result = await page.extract(instruction, timeoutMs=_EXTRACT_TIMEOUT_MS)

            # Extract the text content from the Pydantic result model.
            raw = result.model_dump().get("extraction", "")
            content = str(raw) if raw else ""

            truncated = len(content) > _MAX_CONTENT_CHARS
            if truncated:
                # Reserve room for the marker so the final string never
                # exceeds the cap.
                suffix = "\n\n[Content truncated]"
                keep = max(0, _MAX_CONTENT_CHARS - len(suffix))
                content = content[:keep] + suffix

            return BrowseWebResponse(
                message=f"Browsed {url}",
                url=url,
                content=content,
                truncated=truncated,
                session_id=session_id,
            )

        except Exception:
            # Boundary handler: log full details, return a generic message.
            logger.exception("[browse_web] Failed for %s", url)
            return ErrorResponse(
                message="Failed to browse URL.",
                error="browse_failed",
                session_id=session_id,
            )
        finally:
            if client is not None:
                try:
                    await client.close()
                except Exception:
                    # Best-effort cleanup: never mask the real result, but
                    # leave a trace because the remote session may linger.
                    logger.warning(
                        "[browse_web] Failed to close Stagehand client for %s",
                        url,
                        exc_info=True,
                    )
|
||||
@@ -0,0 +1,486 @@
|
||||
"""Unit tests for BrowseWebTool.
|
||||
|
||||
All tests run without a running server / database. External dependencies
|
||||
(Stagehand, Browserbase) are mocked via sys.modules injection so the suite
|
||||
stays fast and deterministic.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import backend.copilot.tools.browse_web as _browse_web_mod
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.browse_web import (
|
||||
_MAX_CONTENT_CHARS,
|
||||
BrowseWebTool,
|
||||
_patch_stagehand_once,
|
||||
)
|
||||
from backend.copilot.tools.models import BrowseWebResponse, ErrorResponse, ResponseType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_session(user_id: str = "test-user") -> ChatSession:
    """Build an empty ``ChatSession`` with a random session id for tests."""
    session_fields = {
        "session_id": str(uuid.uuid4()),
        "user_id": user_id,
        "messages": [],
        "usage": [],
        "started_at": datetime.now(UTC),
        "updated_at": datetime.now(UTC),
        "successful_agent_runs": {},
        "successful_agent_schedules": {},
    }
    return ChatSession(**session_fields)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_stagehand_patch():
    """Clear the module-level ``_stagehand_patched`` flag around every test."""

    def _clear_flag() -> None:
        _browse_web_mod._stagehand_patched = False

    _clear_flag()
    yield
    _clear_flag()
|
||||
|
||||
|
||||
@pytest.fixture()
def env_vars(monkeypatch):
    """Set the three environment variables BrowseWebTool reads."""
    required_env = {
        "STAGEHAND_API_KEY": "test-api-key",
        "STAGEHAND_PROJECT_ID": "test-project-id",
        "ANTHROPIC_API_KEY": "test-anthropic-key",
    }
    for var_name, var_value in required_env.items():
        monkeypatch.setenv(var_name, var_value)
|
||||
|
||||
|
||||
@pytest.fixture()
def stagehand_mocks(monkeypatch):
    """Inject mock stagehand + stagehand.main into sys.modules.

    Returns a dict with the mock objects so individual tests can
    assert on calls or inject side-effects.
    """
    # --- mock page ---
    mock_result = MagicMock()
    mock_result.model_dump.return_value = {"extraction": "Page content here"}

    mock_page = AsyncMock()
    mock_page.goto = AsyncMock(return_value=None)
    mock_page.extract = AsyncMock(return_value=mock_result)

    # --- mock client ---
    mock_client = AsyncMock()
    mock_client.page = mock_page
    mock_client.init = AsyncMock(return_value=None)
    mock_client.close = AsyncMock(return_value=None)

    MockStagehand = MagicMock(return_value=mock_client)

    # --- stagehand top-level module ---
    mock_stagehand = MagicMock()
    mock_stagehand.Stagehand = MockStagehand

    # --- stagehand.main (needed by _patch_stagehand_once) ---
    mock_main = MagicMock()
    mock_main.Stagehand = MagicMock()
    mock_main.Stagehand._register_signal_handlers = MagicMock()

    # ``import stagehand.main`` binds the *package* name, and the code under
    # test then reaches the submodule via attribute access. A submodule that
    # is already in sys.modules is not re-bound on its parent, so without
    # this link ``stagehand.main`` would be an auto-created child mock and
    # the patch would never land on ``mock_main``.
    mock_stagehand.main = mock_main

    monkeypatch.setitem(sys.modules, "stagehand", mock_stagehand)
    monkeypatch.setitem(sys.modules, "stagehand.main", mock_main)

    return {
        "client": mock_client,
        "page": mock_page,
        "result": mock_result,
        "MockStagehand": MockStagehand,
        "mock_main": mock_main,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Tool metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBrowseWebToolMetadata:
    """Static metadata of ``BrowseWebTool`` and its registry/enum wiring."""

    def test_name(self):
        tool = BrowseWebTool()
        assert tool.name == "browse_web"

    def test_requires_auth(self):
        tool = BrowseWebTool()
        assert tool.requires_auth is True

    def test_url_is_required_parameter(self):
        schema = BrowseWebTool().parameters
        assert "url" in schema["properties"]
        assert "url" in schema["required"]

    def test_instruction_is_optional(self):
        schema = BrowseWebTool().parameters
        assert "instruction" in schema["properties"]
        required_names = schema.get("required", [])
        assert "instruction" not in required_names

    def test_registered_in_tool_registry(self):
        from backend.copilot.tools import TOOL_REGISTRY

        assert "browse_web" in TOOL_REGISTRY
        registered_tool = TOOL_REGISTRY["browse_web"]
        assert isinstance(registered_tool, BrowseWebTool)

    def test_response_type_enum_value(self):
        assert ResponseType.BROWSE_WEB == "browse_web"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Input validation (no external deps)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInputValidation:
    """URL argument validation; these paths need no env vars or Stagehand."""

    @staticmethod
    async def _browse(**tool_kwargs):
        """Run the tool with a fresh session and the given keyword arguments."""
        return await BrowseWebTool()._execute(
            user_id="u1", session=make_session(), **tool_kwargs
        )

    async def test_missing_url_returns_error(self):
        outcome = await self._browse()
        assert isinstance(outcome, ErrorResponse)
        assert "url" in outcome.message.lower()

    async def test_empty_url_returns_error(self):
        outcome = await self._browse(url="")
        assert isinstance(outcome, ErrorResponse)

    async def test_ftp_url_rejected(self):
        outcome = await self._browse(url="ftp://example.com/file")
        assert isinstance(outcome, ErrorResponse)
        assert "http" in outcome.message.lower()

    async def test_file_url_rejected(self):
        outcome = await self._browse(url="file:///etc/passwd")
        assert isinstance(outcome, ErrorResponse)

    async def test_javascript_url_rejected(self):
        outcome = await self._browse(url="javascript:alert(1)")
        assert isinstance(outcome, ErrorResponse)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Environment variable checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnvVarChecks:
    """Each required environment variable, when absent, yields ``not_configured``."""

    # Values used for the variables that stay set in a given test.
    _ENV_VALUES = {
        "STAGEHAND_API_KEY": "key",
        "STAGEHAND_PROJECT_ID": "proj",
        "ANTHROPIC_API_KEY": "key",
    }

    @classmethod
    async def _browse_without(cls, monkeypatch, missing_var):
        """Set every variable except ``missing_var`` (which is removed), then browse."""
        for var_name, var_value in cls._ENV_VALUES.items():
            if var_name == missing_var:
                monkeypatch.delenv(var_name, raising=False)
            else:
                monkeypatch.setenv(var_name, var_value)
        return await BrowseWebTool()._execute(
            user_id="u1", session=make_session(), url="https://example.com"
        )

    async def test_missing_api_key(self, monkeypatch):
        outcome = await self._browse_without(monkeypatch, "STAGEHAND_API_KEY")
        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == "not_configured"

    async def test_missing_project_id(self, monkeypatch):
        outcome = await self._browse_without(monkeypatch, "STAGEHAND_PROJECT_ID")
        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == "not_configured"

    async def test_missing_anthropic_key(self, monkeypatch):
        outcome = await self._browse_without(monkeypatch, "ANTHROPIC_API_KEY")
        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == "not_configured"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Stagehand absent (ImportError path)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStagehandAbsent:
    """Behaviour when the optional ``stagehand`` package cannot be imported."""

    async def test_returns_not_configured_error(self, env_vars, monkeypatch):
        """Blocking the stagehand import must return a graceful ErrorResponse."""
        # A None entry in sys.modules makes the import statement raise ImportError.
        for blocked_module in ("stagehand", "stagehand.main"):
            monkeypatch.setitem(sys.modules, blocked_module, None)

        outcome = await BrowseWebTool()._execute(
            user_id="u1", session=make_session(), url="https://example.com"
        )

        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == "not_configured"
        accepted_phrases = ("not available", "not installed")
        assert any(phrase in outcome.message for phrase in accepted_phrases)

    async def test_other_tools_unaffected_when_stagehand_absent(
        self, env_vars, monkeypatch
    ):
        """Registry import must not raise even when stagehand is blocked."""
        monkeypatch.setitem(sys.modules, "stagehand", None)
        # The registry was already imported at module load; this only checks
        # that it is still importable and populated.
        from backend.copilot.tools import TOOL_REGISTRY

        assert "browse_web" in TOOL_REGISTRY
        assert "web_fetch" in TOOL_REGISTRY  # unrelated tool still present
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Successful browse
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSuccessfulBrowse:
    """Happy-path behaviour with env vars set and Stagehand mocked."""

    @staticmethod
    async def _browse(session=None, url="https://example.com", **extra_kwargs):
        """Run the tool against ``url`` using ``session`` or a fresh one."""
        active_session = session if session is not None else make_session()
        return await BrowseWebTool()._execute(
            user_id="u1", session=active_session, url=url, **extra_kwargs
        )

    async def test_returns_browse_web_response(self, env_vars, stagehand_mocks):
        outcome = await self._browse()
        assert isinstance(outcome, BrowseWebResponse)
        assert outcome.url == "https://example.com"
        assert outcome.content == "Page content here"
        assert outcome.truncated is False

    async def test_http_url_accepted(self, env_vars, stagehand_mocks):
        outcome = await self._browse(url="http://example.com")
        assert isinstance(outcome, BrowseWebResponse)

    async def test_session_id_propagated(self, env_vars, stagehand_mocks):
        chat_session = make_session()
        outcome = await self._browse(session=chat_session)
        assert isinstance(outcome, BrowseWebResponse)
        assert outcome.session_id == chat_session.session_id

    async def test_custom_instruction_forwarded_to_extract(
        self, env_vars, stagehand_mocks
    ):
        await self._browse(instruction="Extract all pricing plans")
        extract_mock = stagehand_mocks["page"].extract
        extract_mock.assert_awaited_once()
        sent_instruction = extract_mock.call_args.args[0]
        assert sent_instruction == "Extract all pricing plans"

    async def test_default_instruction_used_when_omitted(
        self, env_vars, stagehand_mocks
    ):
        await self._browse()
        sent_instruction = stagehand_mocks["page"].extract.call_args.args[0]
        assert "main content" in sent_instruction.lower()

    async def test_explicit_timeouts_passed_to_stagehand(
        self, env_vars, stagehand_mocks
    ):
        from backend.copilot.tools.browse_web import (
            _EXTRACT_TIMEOUT_MS,
            _GOTO_TIMEOUT_MS,
        )

        await self._browse()
        page_mock = stagehand_mocks["page"]
        goto_kwargs = page_mock.goto.call_args.kwargs
        extract_kwargs = page_mock.extract.call_args.kwargs
        assert goto_kwargs.get("timeoutMs") == _GOTO_TIMEOUT_MS
        assert extract_kwargs.get("timeoutMs") == _EXTRACT_TIMEOUT_MS

    async def test_client_closed_after_success(self, env_vars, stagehand_mocks):
        await self._browse()
        stagehand_mocks["client"].close.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Truncation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTruncation:
    """Content-length capping and empty-extraction handling."""

    @staticmethod
    async def _browse_with_extraction(stagehand_mocks, extraction):
        """Make the mocked extract() yield ``extraction`` and run the tool."""
        stagehand_mocks["result"].model_dump.return_value = {"extraction": extraction}
        return await BrowseWebTool()._execute(
            user_id="u1", session=make_session(), url="https://example.com"
        )

    async def test_short_content_not_truncated(self, env_vars, stagehand_mocks):
        outcome = await self._browse_with_extraction(stagehand_mocks, "short")
        assert isinstance(outcome, BrowseWebResponse)
        assert outcome.truncated is False
        assert outcome.content == "short"

    async def test_oversized_content_is_truncated(self, env_vars, stagehand_mocks):
        oversized = "a" * (_MAX_CONTENT_CHARS + 1000)
        outcome = await self._browse_with_extraction(stagehand_mocks, oversized)
        assert isinstance(outcome, BrowseWebResponse)
        assert outcome.truncated is True
        assert outcome.content.endswith("[Content truncated]")

    async def test_truncated_content_never_exceeds_cap(self, env_vars, stagehand_mocks):
        """The final string must be ≤ _MAX_CONTENT_CHARS regardless of input size."""
        oversized = "b" * (_MAX_CONTENT_CHARS * 3)
        outcome = await self._browse_with_extraction(stagehand_mocks, oversized)
        assert isinstance(outcome, BrowseWebResponse)
        assert len(outcome.content) == _MAX_CONTENT_CHARS

    async def test_content_exactly_at_limit_not_truncated(
        self, env_vars, stagehand_mocks
    ):
        at_limit = "c" * _MAX_CONTENT_CHARS
        outcome = await self._browse_with_extraction(stagehand_mocks, at_limit)
        assert isinstance(outcome, BrowseWebResponse)
        assert outcome.truncated is False
        assert len(outcome.content) == _MAX_CONTENT_CHARS

    async def test_empty_extraction_returns_empty_content(
        self, env_vars, stagehand_mocks
    ):
        outcome = await self._browse_with_extraction(stagehand_mocks, "")
        assert isinstance(outcome, BrowseWebResponse)
        assert outcome.content == ""
        assert outcome.truncated is False

    async def test_none_extraction_returns_empty_content(
        self, env_vars, stagehand_mocks
    ):
        outcome = await self._browse_with_extraction(stagehand_mocks, None)
        assert isinstance(outcome, BrowseWebResponse)
        assert outcome.content == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Error handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Failures inside Stagehand are converted into generic ``ErrorResponse``s."""

    @staticmethod
    async def _browse():
        """Run the tool against the example URL with a fresh session."""
        return await BrowseWebTool()._execute(
            user_id="u1", session=make_session(), url="https://example.com"
        )

    async def test_stagehand_init_exception_returns_generic_error(
        self, env_vars, stagehand_mocks
    ):
        client_mock = stagehand_mocks["client"]
        client_mock.init.side_effect = RuntimeError("Connection refused")
        outcome = await self._browse()
        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == "browse_failed"

    async def test_raw_exception_text_not_leaked_to_user(
        self, env_vars, stagehand_mocks
    ):
        """Internal error details must not appear in the user-facing message."""
        client_mock = stagehand_mocks["client"]
        client_mock.init.side_effect = RuntimeError("SECRET_TOKEN_abc123")
        outcome = await self._browse()
        assert isinstance(outcome, ErrorResponse)
        assert "SECRET_TOKEN_abc123" not in outcome.message
        assert outcome.message == "Failed to browse URL."

    async def test_goto_timeout_returns_error(self, env_vars, stagehand_mocks):
        page_mock = stagehand_mocks["page"]
        page_mock.goto.side_effect = TimeoutError("Navigation timed out")
        outcome = await self._browse()
        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == "browse_failed"

    async def test_client_closed_after_exception(self, env_vars, stagehand_mocks):
        stagehand_mocks["page"].goto.side_effect = RuntimeError("boom")
        await self._browse()
        stagehand_mocks["client"].close.assert_awaited_once()

    async def test_close_failure_does_not_propagate(self, env_vars, stagehand_mocks):
        """If close() itself raises, the tool must still return ErrorResponse."""
        client_mock = stagehand_mocks["client"]
        client_mock.init.side_effect = RuntimeError("init failed")
        client_mock.close.side_effect = RuntimeError("close also failed")
        outcome = await self._browse()
        assert isinstance(outcome, ErrorResponse)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Thread-safety of _patch_stagehand_once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPatchStagehandOnce:
    """Behaviour of the once-only Stagehand signal-handler patch.

    The patched class is looked up with ``import stagehand.main`` followed by
    attribute access, exactly as ``_patch_stagehand_once`` does, so the
    assertions always inspect the object that was actually patched.
    """

    def test_idempotent_double_call(self, stagehand_mocks):
        """_stagehand_patched transitions False→True exactly once."""
        import stagehand.main

        stagehand_cls = stagehand.main.Stagehand

        assert _browse_web_mod._stagehand_patched is False
        _patch_stagehand_once()
        assert _browse_web_mod._stagehand_patched is True
        installed = stagehand_cls._register_signal_handlers

        _patch_stagehand_once()  # second call — still True, not re-patched
        assert _browse_web_mod._stagehand_patched is True
        # The handler must not have been wrapped a second time.
        assert stagehand_cls._register_signal_handlers is installed

    def test_safe_register_is_noop_in_worker_thread(self, stagehand_mocks):
        """The patched handler must silently do nothing when called from a worker."""
        import stagehand.main

        stagehand_cls = stagehand.main.Stagehand
        original_register = stagehand_cls._register_signal_handlers

        _patch_stagehand_once()
        safe_register = stagehand_cls._register_signal_handlers
        # Guard against a vacuous pass: the handler must really be replaced.
        assert safe_register is not original_register

        errors: list[Exception] = []

        def run():
            try:
                safe_register(MagicMock())
            except Exception as exc:
                errors.append(exc)

        worker = threading.Thread(target=run)
        worker.start()
        worker.join()

        assert errors == [], f"Worker thread raised: {errors}"
        # Off the main thread the original registration must be skipped.
        original_register.assert_not_called()

    async def test_patched_flag_set_after_execution(self, env_vars, stagehand_mocks):
        """After a successful browse, _stagehand_patched must be True."""
        # Awaited directly, like the other async tests in this file, rather
        # than through the deprecated asyncio.get_event_loop() pattern.
        await BrowseWebTool()._execute(
            user_id="u1", session=make_session(), url="https://example.com"
        )
        assert _browse_web_mod._stagehand_patched is True
|
||||
@@ -41,6 +41,8 @@ class ResponseType(str, Enum):
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
# Web fetch
|
||||
WEB_FETCH = "web_fetch"
|
||||
# Browser-based web browsing (JS-rendered pages)
|
||||
BROWSE_WEB = "browse_web"
|
||||
# Code execution
|
||||
BASH_EXEC = "bash_exec"
|
||||
# Feature request types
|
||||
@@ -438,6 +440,15 @@ class WebFetchResponse(ToolResponseBase):
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class BrowseWebResponse(ToolResponseBase):
    """Response for browse_web tool."""

    # Response discriminator; fixed to the browse_web enum member.
    type: ResponseType = ResponseType.BROWSE_WEB
    # URL that was browsed.
    url: str
    # Extracted page content.
    content: str
    # True when the content was cut to fit a size cap; defaults to False.
    truncated: bool = False
|
||||
|
||||
|
||||
class BashExecResponse(ToolResponseBase):
|
||||
"""Response for bash_exec tool."""
|
||||
|
||||
|
||||
@@ -11167,6 +11167,7 @@
|
||||
"operation_in_progress",
|
||||
"input_validation_error",
|
||||
"web_fetch",
|
||||
"browse_web",
|
||||
"bash_exec",
|
||||
"feature_request_search",
|
||||
"feature_request_created",
|
||||
|
||||
Reference in New Issue
Block a user