mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(backend): clean up block search code, add tests
- Extract _get_enabled_blocks() to deduplicate disabled/broken block filtering between get_missing_items() and get_stats() - Reuse block instance from disabled check (no double instantiation) - Move split_camelcase import to module level in hybrid_search.py - Add tests: split_camelcase, tokenize CamelCase, disabled-block batch budget, get_stats with broken blocks
This commit is contained in:
@@ -155,6 +155,26 @@ class StoreAgentHandler(ContentHandler):
|
||||
}
|
||||
|
||||
|
||||
def _get_enabled_blocks() -> dict[str, Any]:
|
||||
"""Return ``{block_id: block_instance}`` for all enabled, instantiable blocks.
|
||||
|
||||
Disabled blocks and blocks that fail to instantiate are silently skipped
|
||||
(with a warning log), so callers never need their own try/except loop.
|
||||
"""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
enabled: dict[str, Any] = {}
|
||||
for block_id, block_cls in get_blocks().items():
|
||||
try:
|
||||
instance = block_cls()
|
||||
except Exception as e:
|
||||
logger.warning("Skipping block %s: init failed: %s", block_id, e)
|
||||
continue
|
||||
if not instance.disabled:
|
||||
enabled[block_id] = instance
|
||||
return enabled
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@@ -164,16 +184,11 @@ class BlockHandler(ContentHandler):
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
enabled = _get_enabled_blocks()
|
||||
if not enabled:
|
||||
return []
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
block_ids = list(enabled.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
@@ -188,58 +203,41 @@ class BlockHandler(ContentHandler):
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
# Filter disabled blocks before applying batch_size so that a large
|
||||
# number of disabled blocks can't exhaust the batch budget and prevent
|
||||
# enabled blocks from being indexed.
|
||||
missing_blocks: list[tuple[str, type]] = []
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
|
||||
# Convert to ContentItem — disabled filtering already done by
|
||||
# _get_enabled_blocks so batch_size won't be exhausted by disabled blocks.
|
||||
items = []
|
||||
for block_id, block in enabled.items():
|
||||
if block_id in existing_ids:
|
||||
continue
|
||||
try:
|
||||
if block_cls().disabled:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping block {block_id}: failed to init: {e}")
|
||||
continue
|
||||
missing_blocks.append((block_id, block_cls))
|
||||
if len(items) >= batch_size:
|
||||
break
|
||||
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if block_instance.name:
|
||||
parts.append(split_camelcase(block_instance.name))
|
||||
if block_instance.description:
|
||||
parts.append(block_instance.description)
|
||||
if block_instance.categories:
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
if block.name:
|
||||
parts.append(split_camelcase(block.name))
|
||||
if block.description:
|
||||
parts.append(block.description)
|
||||
if block.categories:
|
||||
parts.append(" ".join(str(cat.value) for cat in block.categories))
|
||||
|
||||
# Add input schema field descriptions
|
||||
block_input_fields = block_instance.input_schema.model_fields
|
||||
parts += [
|
||||
f"{field_name}: {field_info.description}"
|
||||
for field_name, field_info in block_input_fields.items()
|
||||
for field_name, field_info in block.input_schema.model_fields.items()
|
||||
if field_info.description
|
||||
]
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
categories_list = (
|
||||
[cat.value for cat in block_instance.categories]
|
||||
if block_instance.categories
|
||||
else []
|
||||
[cat.value for cat in block.categories] if block.categories else []
|
||||
)
|
||||
|
||||
# Extract provider names from credentials fields
|
||||
credentials_info = (
|
||||
block_instance.input_schema.get_credentials_fields_info()
|
||||
)
|
||||
credentials_info = block.input_schema.get_credentials_fields_info()
|
||||
is_integration = len(credentials_info) > 0
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
@@ -250,7 +248,7 @@ class BlockHandler(ContentHandler):
|
||||
# Check if block has LlmModel field in input schema
|
||||
has_llm_model_field = any(
|
||||
_contains_type(field.annotation, LlmModel)
|
||||
for field in block_instance.input_schema.model_fields.values()
|
||||
for field in block.input_schema.model_fields.values()
|
||||
)
|
||||
|
||||
items.append(
|
||||
@@ -259,13 +257,13 @@ class BlockHandler(ContentHandler):
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": block_instance.name,
|
||||
"name": block.name,
|
||||
"categories": categories_list,
|
||||
"providers": provider_names,
|
||||
"has_llm_model_field": has_llm_model_field,
|
||||
"is_integration": is_integration,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
user_id=None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -276,28 +274,13 @@ class BlockHandler(ContentHandler):
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
from backend.blocks import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Filter out disabled blocks - they're not indexed
|
||||
enabled_block_ids: list[str] = []
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
try:
|
||||
if block_cls().disabled:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Skipping block %s in stats: init failed: %s", block_id, e
|
||||
)
|
||||
continue
|
||||
enabled_block_ids.append(block_id)
|
||||
total_blocks = len(enabled_block_ids)
|
||||
enabled = _get_enabled_blocks()
|
||||
total_blocks = len(enabled)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = enabled_block_ids
|
||||
block_ids = list(enabled.keys())
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
Tests for content handlers (blocks, store agents, documentation).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -15,15 +13,80 @@ from backend.api.features.store.content_handlers import (
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
_get_enabled_blocks,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to build a mock block class that returns a pre-configured instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_block_class(
|
||||
*,
|
||||
name: str = "Block",
|
||||
description: str = "",
|
||||
disabled: bool = False,
|
||||
categories: list | None = None,
|
||||
fields: dict | None = None,
|
||||
raise_on_init: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
if raise_on_init:
|
||||
cls.side_effect = raise_on_init
|
||||
return cls
|
||||
inst = MagicMock()
|
||||
inst.name = name
|
||||
inst.disabled = disabled
|
||||
inst.description = description
|
||||
inst.categories = categories or []
|
||||
field_mocks = {}
|
||||
for fname, fdesc in (fields or {}).items():
|
||||
f = MagicMock()
|
||||
f.description = fdesc
|
||||
field_mocks[fname] = f
|
||||
inst.input_schema.model_fields = field_mocks
|
||||
inst.input_schema.get_credentials_fields_info.return_value = {}
|
||||
cls.return_value = inst
|
||||
return cls
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_enabled_blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_enabled_blocks_filters_disabled():
|
||||
"""Disabled blocks are excluded."""
|
||||
blocks = {
|
||||
"enabled": _make_block_class(name="E", disabled=False),
|
||||
"disabled": _make_block_class(name="D", disabled=True),
|
||||
}
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
result = _get_enabled_blocks()
|
||||
assert list(result.keys()) == ["enabled"]
|
||||
|
||||
|
||||
def test_get_enabled_blocks_skips_broken():
|
||||
"""Blocks that raise on init are skipped, not crash."""
|
||||
blocks = {
|
||||
"good": _make_block_class(name="Good"),
|
||||
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
|
||||
}
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
result = _get_enabled_blocks()
|
||||
assert list(result.keys()) == ["good"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StoreAgentHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
@@ -54,9 +117,7 @@ async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
@@ -70,42 +131,36 @@ async def test_store_agent_handler_get_stats(mocker):
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BlockHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
async def test_block_handler_get_missing_items():
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.disabled = False
|
||||
mock_field = MagicMock()
|
||||
mock_field.description = "Math expression to evaluate"
|
||||
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
blocks = {
|
||||
"block-uuid-1": _make_block_class(
|
||||
name="CalculatorBlock",
|
||||
description="Performs calculations",
|
||||
categories=[MagicMock(value="MATH")],
|
||||
fields={"expression": "Math expression to evaluate"},
|
||||
),
|
||||
}
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_existing,
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
# CamelCase should be split in searchable text
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
@@ -114,31 +169,63 @@ async def test_block_handler_get_missing_items(mocker):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
async def test_block_handler_get_missing_items_splits_camelcase():
|
||||
"""CamelCase block names are split for better search indexing."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"ai-block": _make_block_class(name="AITextGeneratorBlock"),
|
||||
}
|
||||
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert "AI Text Generator Block" in items[0].searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_disabled_dont_exhaust_batch():
|
||||
"""Disabled blocks don't consume batch budget, so enabled blocks get indexed."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# 5 disabled + 3 enabled, batch_size=2
|
||||
blocks = {
|
||||
**{
|
||||
f"dis-{i}": _make_block_class(name=f"D{i}", disabled=True) for i in range(5)
|
||||
},
|
||||
**{f"en-{i}": _make_block_class(name=f"E{i}") for i in range(3)},
|
||||
}
|
||||
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=2)
|
||||
|
||||
assert len(items) == 2
|
||||
assert all(item.content_id.startswith("en-") for item in items)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats():
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks - each block class returns an instance with disabled=False
|
||||
def make_mock_block_class():
|
||||
mock_class = MagicMock()
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.disabled = False
|
||||
mock_class.return_value = mock_instance
|
||||
return mock_class
|
||||
|
||||
mock_blocks = {
|
||||
"block-1": make_mock_block_class(),
|
||||
"block-2": make_mock_block_class(),
|
||||
"block-3": make_mock_block_class(),
|
||||
blocks = {
|
||||
"block-1": _make_block_class(name="B1"),
|
||||
"block-2": _make_block_class(name="B2"),
|
||||
"block-3": _make_block_class(name="B3"),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
@@ -150,21 +237,84 @@ async def test_block_handler_get_stats(mocker):
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats_skips_broken():
|
||||
"""get_stats skips broken blocks instead of crashing."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"good": _make_block_class(name="Good"),
|
||||
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
|
||||
}
|
||||
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 1 # only the good block
|
||||
assert stats["with_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {"block-minimal": _make_block_class(name="Minimal Block")}
|
||||
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
blocks = {
|
||||
"good-block": _make_block_class(name="Good Block", description="Works fine"),
|
||||
"bad-block": _make_block_class(raise_on_init=Exception("Instantiation failed")),
|
||||
}
|
||||
|
||||
with patch("backend.blocks.get_blocks", return_value=blocks):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentationHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
@@ -173,7 +323,6 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md (content_id format: doc_path::section_index)
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md::0"), None
|
||||
)
|
||||
@@ -184,7 +333,6 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
assert guide_item.metadata["doc_title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx (content_id format: doc_path::section_index)
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx::0"), None
|
||||
)
|
||||
@@ -197,14 +345,12 @@ async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
@@ -224,13 +370,11 @@ async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title = handler._extract_doc_title(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title = handler._extract_doc_title(doc_without_heading)
|
||||
@@ -242,7 +386,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
"""Test DocumentationHandler chunks markdown by headings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test document with multiple sections
|
||||
doc_with_sections = tmp_path / "sections.md"
|
||||
doc_with_sections.write_text(
|
||||
"# Document Title\n\n"
|
||||
@@ -254,7 +397,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
)
|
||||
sections = handler._chunk_markdown_by_headings(doc_with_sections)
|
||||
|
||||
# Should have 3 sections: intro (with doc title), section one, section two
|
||||
assert len(sections) == 3
|
||||
assert sections[0].title == "Document Title"
|
||||
assert sections[0].index == 0
|
||||
@@ -268,7 +410,6 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
assert sections[2].index == 2
|
||||
assert "Content for section two" in sections[2].content
|
||||
|
||||
# Test document without headings
|
||||
doc_no_sections = tmp_path / "no-sections.md"
|
||||
doc_no_sections.write_text("Just plain content without any headings.")
|
||||
sections = handler._chunk_markdown_by_headings(doc_no_sections)
|
||||
@@ -282,21 +423,39 @@ async def test_documentation_handler_section_content_ids():
|
||||
"""Test DocumentationHandler creates and parses section content IDs."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test making content ID
|
||||
content_id = handler._make_section_content_id("docs/guide.md", 2)
|
||||
assert content_id == "docs/guide.md::2"
|
||||
|
||||
# Test parsing content ID
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
|
||||
assert doc_path == "docs/guide.md"
|
||||
assert section_index == 2
|
||||
|
||||
# Test parsing legacy format (no section index)
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
|
||||
assert doc_path == "docs/old-format.md"
|
||||
assert section_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
@@ -307,88 +466,3 @@ async def test_content_handlers_registry():
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_empty_attributes():
|
||||
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with empty values (all attributes exist but are falsy)
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
mock_block_instance.disabled = False
|
||||
mock_block_instance.description = ""
|
||||
mock_block_instance.categories = set()
|
||||
mock_block_instance.input_schema.model_fields = {}
|
||||
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_instance.disabled = False
|
||||
good_instance.input_schema.model_fields = {}
|
||||
good_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
|
||||
@@ -20,6 +20,7 @@ from backend.api.features.store.embeddings import (
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
from backend.api.features.store.text_utils import split_camelcase
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,10 +39,7 @@ def tokenize(text: str) -> list[str]:
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
from backend.api.features.store.text_utils import split_camelcase
|
||||
|
||||
tokens = re.findall(r"\b\w+\b", split_camelcase(text).lower())
|
||||
return tokens
|
||||
return re.findall(r"\b\w+\b", split_camelcase(text).lower())
|
||||
|
||||
|
||||
def bm25_rerank(
|
||||
|
||||
@@ -14,8 +14,49 @@ from backend.api.features.store.hybrid_search import (
|
||||
HybridSearchWeights,
|
||||
UnifiedSearchWeights,
|
||||
hybrid_search,
|
||||
tokenize,
|
||||
unified_hybrid_search,
|
||||
)
|
||||
from backend.api.features.store.text_utils import split_camelcase
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# split_camelcase
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected",
|
||||
[
|
||||
("AITextGeneratorBlock", "AI Text Generator Block"),
|
||||
("HTTPRequestBlock", "HTTP Request Block"),
|
||||
("simpleWord", "simple Word"),
|
||||
("already spaced", "already spaced"),
|
||||
("XMLParser", "XML Parser"),
|
||||
("getHTTPResponse", "get HTTP Response"),
|
||||
("Block", "Block"),
|
||||
("", ""),
|
||||
],
|
||||
)
|
||||
def test_split_camelcase(input_text: str, expected: str):
|
||||
assert split_camelcase(input_text) == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tokenize (BM25)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_text, expected",
|
||||
[
|
||||
("AITextGeneratorBlock", ["ai", "text", "generator", "block"]),
|
||||
("hello world", ["hello", "world"]),
|
||||
("", []),
|
||||
("HTTPRequest", ["http", "request"]),
|
||||
],
|
||||
)
|
||||
def test_tokenize(input_text: str, expected: list[str]):
|
||||
assert tokenize(input_text) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
Reference in New Issue
Block a user