revert(store): also revert content_handlers_test.py from PR #12400

This commit is contained in:
Zamil Majdy
2026-03-17 05:19:26 +07:00
parent 7ff8bc8c5e
commit c342290910
6 changed files with 233 additions and 345 deletions

View File

@@ -15,11 +15,20 @@ from prisma.enums import ContentType
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
from backend.util.text import split_camelcase
logger = logging.getLogger(__name__)
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
@dataclass
class ContentItem:
"""Represents a piece of content to be embedded."""
@@ -154,11 +163,16 @@ class BlockHandler(ContentHandler):
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings."""
enabled = _get_enabled_blocks()
if not enabled:
from backend.blocks import get_blocks
# Get all available blocks
all_blocks = get_blocks()
# Check which ones have embeddings
if not all_blocks:
return []
block_ids = list(enabled.keys())
block_ids = list(all_blocks.keys())
# Query for existing embeddings
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
@@ -173,41 +187,52 @@ class BlockHandler(ContentHandler):
)
existing_ids = {row["contentId"] for row in existing_result}
missing_blocks = [
(block_id, block_cls)
for block_id, block_cls in all_blocks.items()
if block_id not in existing_ids
]
# Convert to ContentItem — disabled filtering already done by
# _get_enabled_blocks so batch_size won't be exhausted by disabled blocks.
# Convert to ContentItem
items = []
for block_id, block in enabled.items():
if block_id in existing_ids:
continue
if len(items) >= batch_size:
break
for block_id, block_cls in missing_blocks[:batch_size]:
try:
block_instance = block_cls()
if block_instance.disabled:
continue
# Build searchable text from block metadata
parts = []
if block.name:
parts.append(split_camelcase(block.name))
if block.description:
parts.append(block.description)
if block.categories:
parts.append(" ".join(str(cat.value) for cat in block.categories))
if block_instance.name:
parts.append(block_instance.name)
if block_instance.description:
parts.append(block_instance.description)
if block_instance.categories:
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input schema field descriptions
block_input_fields = block_instance.input_schema.model_fields
parts += [
f"{field_name}: {field_info.description}"
for field_name, field_info in block.input_schema.model_fields.items()
for field_name, field_info in block_input_fields.items()
if field_info.description
]
searchable_text = " ".join(parts)
categories_list = (
[cat.value for cat in block.categories] if block.categories else []
[cat.value for cat in block_instance.categories]
if block_instance.categories
else []
)
# Extract provider names from credentials fields
credentials_info = block.input_schema.get_credentials_fields_info()
credentials_info = (
block_instance.input_schema.get_credentials_fields_info()
)
is_integration = len(credentials_info) > 0
provider_names = [
provider.value.lower()
@@ -218,7 +243,7 @@ class BlockHandler(ContentHandler):
# Check if block has LlmModel field in input schema
has_llm_model_field = any(
_contains_type(field.annotation, LlmModel)
for field in block.input_schema.model_fields.values()
for field in block_instance.input_schema.model_fields.values()
)
items.append(
@@ -227,30 +252,39 @@ class BlockHandler(ContentHandler):
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": block.name,
"name": block_instance.name,
"categories": categories_list,
"providers": provider_names,
"has_llm_model_field": has_llm_model_field,
"is_integration": is_integration,
},
user_id=None,
user_id=None, # Blocks are public
)
)
except Exception as e:
logger.warning("Failed to process block %s: %s", block_id, e)
logger.warning(f"Failed to process block {block_id}: {e}")
continue
return items
async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage."""
enabled = _get_enabled_blocks()
total_blocks = len(enabled)
from backend.blocks import get_blocks
all_blocks = get_blocks()
# Filter out disabled blocks - they're not indexed
enabled_block_ids = [
block_id
for block_id, block_cls in all_blocks.items()
if not block_cls().disabled
]
total_blocks = len(enabled_block_ids)
if total_blocks == 0:
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
block_ids = list(enabled.keys())
block_ids = enabled_block_ids
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
embedded_result = await query_raw_with_schema(
@@ -272,36 +306,6 @@ class BlockHandler(ContentHandler):
}
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
def _get_enabled_blocks() -> dict[str, Any]:
"""Return ``{block_id: block_instance}`` for all enabled, instantiable blocks.
Disabled blocks and blocks that fail to instantiate are silently skipped
(with a warning log), so callers never need their own try/except loop.
"""
from backend.blocks import get_blocks
enabled: dict[str, Any] = {}
for block_id, block_cls in get_blocks().items():
try:
instance = block_cls()
except Exception as e:
logger.warning("Skipping block %s: init failed: %s", block_id, e)
continue
if not instance.disabled:
enabled[block_id] = instance
return enabled
@dataclass
class MarkdownSection:
"""Represents a section of a markdown document."""

View File

@@ -1,5 +1,7 @@
"""
Tests for content handlers (blocks, store agents, documentation).
E2E tests for content handlers (blocks, store agents, documentation).
Tests the full flow: discovering content → generating embeddings → storing.
"""
from pathlib import Path
@@ -13,80 +15,15 @@ from backend.api.features.store.content_handlers import (
BlockHandler,
DocumentationHandler,
StoreAgentHandler,
_get_enabled_blocks,
)
# ---------------------------------------------------------------------------
# Helper to build a mock block class that returns a pre-configured instance
# ---------------------------------------------------------------------------
def _make_block_class(
*,
name: str = "Block",
description: str = "",
disabled: bool = False,
categories: list | None = None,
fields: dict | None = None,
raise_on_init: Exception | None = None,
) -> MagicMock:
cls = MagicMock()
if raise_on_init:
cls.side_effect = raise_on_init
return cls
inst = MagicMock()
inst.name = name
inst.disabled = disabled
inst.description = description
inst.categories = categories or []
field_mocks = {}
for fname, fdesc in (fields or {}).items():
f = MagicMock()
f.description = fdesc
field_mocks[fname] = f
inst.input_schema.model_fields = field_mocks
inst.input_schema.get_credentials_fields_info.return_value = {}
cls.return_value = inst
return cls
# ---------------------------------------------------------------------------
# _get_enabled_blocks
# ---------------------------------------------------------------------------
def test_get_enabled_blocks_filters_disabled():
"""Disabled blocks are excluded."""
blocks = {
"enabled": _make_block_class(name="E", disabled=False),
"disabled": _make_block_class(name="D", disabled=True),
}
with patch("backend.blocks.get_blocks", return_value=blocks):
result = _get_enabled_blocks()
assert list(result.keys()) == ["enabled"]
def test_get_enabled_blocks_skips_broken():
"""Blocks that raise on init are skipped, not crash."""
blocks = {
"good": _make_block_class(name="Good"),
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
}
with patch("backend.blocks.get_blocks", return_value=blocks):
result = _get_enabled_blocks()
assert list(result.keys()) == ["good"]
# ---------------------------------------------------------------------------
# StoreAgentHandler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_missing_items(mocker):
"""Test StoreAgentHandler fetches approved agents without embeddings."""
handler = StoreAgentHandler()
# Mock database query
mock_missing = [
{
"id": "agent-1",
@@ -117,7 +54,9 @@ async def test_store_agent_handler_get_stats(mocker):
"""Test StoreAgentHandler returns correct stats."""
handler = StoreAgentHandler()
# Mock approved count query
mock_approved = [{"count": 50}]
# Mock embedded count query
mock_embedded = [{"count": 30}]
with patch(
@@ -131,36 +70,42 @@ async def test_store_agent_handler_get_stats(mocker):
assert stats["without_embeddings"] == 20
# ---------------------------------------------------------------------------
# BlockHandler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items():
async def test_block_handler_get_missing_items(mocker):
"""Test BlockHandler discovers blocks without embeddings."""
handler = BlockHandler()
blocks = {
"block-uuid-1": _make_block_class(
name="CalculatorBlock",
description="Performs calculations",
categories=[MagicMock(value="MATH")],
fields={"expression": "Math expression to evaluate"},
),
}
# Mock get_blocks to return test blocks
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Calculator Block"
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_field = MagicMock()
mock_field.description = "Math expression to evaluate"
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_class.return_value = mock_block_instance
with patch("backend.blocks.get_blocks", return_value=blocks):
mock_blocks = {"block-uuid-1": mock_block_class}
# Mock existing embeddings query (no embeddings exist)
mock_existing = []
with patch(
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
return_value=mock_existing,
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].content_id == "block-uuid-1"
assert items[0].content_type == ContentType.BLOCK
# CamelCase should be split in searchable text
assert "Calculator Block" in items[0].searchable_text
assert "Performs calculations" in items[0].searchable_text
assert "MATH" in items[0].searchable_text
@@ -169,63 +114,31 @@ async def test_block_handler_get_missing_items():
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items_splits_camelcase():
"""CamelCase block names are split for better search indexing."""
handler = BlockHandler()
blocks = {
"ai-block": _make_block_class(name="AITextGeneratorBlock"),
}
with patch("backend.blocks.get_blocks", return_value=blocks):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert "AI Text Generator Block" in items[0].searchable_text
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_disabled_dont_exhaust_batch():
"""Disabled blocks don't consume batch budget, so enabled blocks get indexed."""
handler = BlockHandler()
# 5 disabled + 3 enabled, batch_size=2
blocks = {
**{
f"dis-{i}": _make_block_class(name=f"D{i}", disabled=True) for i in range(5)
},
**{f"en-{i}": _make_block_class(name=f"E{i}") for i in range(3)},
}
with patch("backend.blocks.get_blocks", return_value=blocks):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=2)
assert len(items) == 2
assert all(item.content_id.startswith("en-") for item in items)
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_stats():
async def test_block_handler_get_stats(mocker):
"""Test BlockHandler returns correct stats."""
handler = BlockHandler()
blocks = {
"block-1": _make_block_class(name="B1"),
"block-2": _make_block_class(name="B2"),
"block-3": _make_block_class(name="B3"),
# Mock get_blocks - each block class returns an instance with disabled=False
def make_mock_block_class():
mock_class = MagicMock()
mock_instance = MagicMock()
mock_instance.disabled = False
mock_class.return_value = mock_instance
return mock_class
mock_blocks = {
"block-1": make_mock_block_class(),
"block-2": make_mock_block_class(),
"block-3": make_mock_block_class(),
}
# Mock embedded count query (2 blocks have embeddings)
mock_embedded = [{"count": 2}]
with patch("backend.blocks.get_blocks", return_value=blocks):
with patch(
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_embedded,
@@ -237,84 +150,21 @@ async def test_block_handler_get_stats():
assert stats["without_embeddings"] == 1
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_stats_skips_broken():
"""get_stats skips broken blocks instead of crashing."""
handler = BlockHandler()
blocks = {
"good": _make_block_class(name="Good"),
"bad": _make_block_class(raise_on_init=RuntimeError("boom")),
}
mock_embedded = [{"count": 1}]
with patch("backend.blocks.get_blocks", return_value=blocks):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_embedded,
):
stats = await handler.get_stats()
assert stats["total"] == 1 # only the good block
assert stats["with_embeddings"] == 1
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
handler = BlockHandler()
blocks = {"block-minimal": _make_block_class(name="Minimal Block")}
with patch("backend.blocks.get_blocks", return_value=blocks):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].searchable_text == "Minimal Block"
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_skips_failed_blocks():
"""Test BlockHandler skips blocks that fail to instantiate."""
handler = BlockHandler()
blocks = {
"good-block": _make_block_class(name="Good Block", description="Works fine"),
"bad-block": _make_block_class(raise_on_init=Exception("Instantiation failed")),
}
with patch("backend.blocks.get_blocks", return_value=blocks):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].content_id == "good-block"
# ---------------------------------------------------------------------------
# DocumentationHandler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
"""Test DocumentationHandler discovers docs without embeddings."""
handler = DocumentationHandler()
# Create temporary docs directory with test files
docs_root = tmp_path / "docs"
docs_root.mkdir()
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
# Mock _get_docs_root to return temp dir
with patch.object(handler, "_get_docs_root", return_value=docs_root):
# Mock existing embeddings query (no embeddings exist)
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
@@ -323,6 +173,7 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
assert len(items) == 2
# Check guide.md (content_id format: doc_path::section_index)
guide_item = next(
(item for item in items if item.content_id == "guide.md::0"), None
)
@@ -333,6 +184,7 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
assert guide_item.metadata["doc_title"] == "Getting Started"
assert guide_item.user_id is None
# Check api.mdx (content_id format: doc_path::section_index)
api_item = next(
(item for item in items if item.content_id == "api.mdx::0"), None
)
@@ -345,12 +197,14 @@ async def test_documentation_handler_get_stats(tmp_path, mocker):
"""Test DocumentationHandler returns correct stats."""
handler = DocumentationHandler()
# Create temporary docs directory
docs_root = tmp_path / "docs"
docs_root.mkdir()
(docs_root / "doc1.md").write_text("# Doc 1")
(docs_root / "doc2.md").write_text("# Doc 2")
(docs_root / "doc3.mdx").write_text("# Doc 3")
# Mock embedded count query (1 doc has embedding)
mock_embedded = [{"count": 1}]
with patch.object(handler, "_get_docs_root", return_value=docs_root):
@@ -370,11 +224,13 @@ async def test_documentation_handler_title_extraction(tmp_path):
"""Test DocumentationHandler extracts title from markdown heading."""
handler = DocumentationHandler()
# Test with heading
doc_with_heading = tmp_path / "with_heading.md"
doc_with_heading.write_text("# My Title\n\nContent here")
title = handler._extract_doc_title(doc_with_heading)
assert title == "My Title"
# Test without heading
doc_without_heading = tmp_path / "no-heading.md"
doc_without_heading.write_text("Just content, no heading")
title = handler._extract_doc_title(doc_without_heading)
@@ -386,6 +242,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
"""Test DocumentationHandler chunks markdown by headings."""
handler = DocumentationHandler()
# Test document with multiple sections
doc_with_sections = tmp_path / "sections.md"
doc_with_sections.write_text(
"# Document Title\n\n"
@@ -397,6 +254,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
)
sections = handler._chunk_markdown_by_headings(doc_with_sections)
# Should have 3 sections: intro (with doc title), section one, section two
assert len(sections) == 3
assert sections[0].title == "Document Title"
assert sections[0].index == 0
@@ -410,6 +268,7 @@ async def test_documentation_handler_markdown_chunking(tmp_path):
assert sections[2].index == 2
assert "Content for section two" in sections[2].content
# Test document without headings
doc_no_sections = tmp_path / "no-sections.md"
doc_no_sections.write_text("Just plain content without any headings.")
sections = handler._chunk_markdown_by_headings(doc_no_sections)
@@ -423,39 +282,21 @@ async def test_documentation_handler_section_content_ids():
"""Test DocumentationHandler creates and parses section content IDs."""
handler = DocumentationHandler()
# Test making content ID
content_id = handler._make_section_content_id("docs/guide.md", 2)
assert content_id == "docs/guide.md::2"
# Test parsing content ID
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
assert doc_path == "docs/guide.md"
assert section_index == 2
# Test parsing legacy format (no section index)
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
assert doc_path == "docs/old-format.md"
assert section_index == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_missing_docs_directory():
"""Test DocumentationHandler handles missing docs directory gracefully."""
handler = DocumentationHandler()
fake_path = Path("/nonexistent/docs")
with patch.object(handler, "_get_docs_root", return_value=fake_path):
items = await handler.get_missing_items(batch_size=10)
assert items == []
stats = await handler.get_stats()
assert stats["total"] == 0
assert stats["with_embeddings"] == 0
assert stats["without_embeddings"] == 0
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_content_handlers_registry():
"""Test all content types are registered."""
@@ -466,3 +307,88 @@ async def test_content_handlers_registry():
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
handler = BlockHandler()
# Mock block with empty values (all attributes exist but are falsy)
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
mock_block_instance.description = ""
mock_block_instance.categories = set()
mock_block_instance.input_schema.model_fields = {}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
with patch(
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].searchable_text == "Minimal Block"
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_skips_failed_blocks():
"""Test BlockHandler skips blocks that fail to instantiate."""
handler = BlockHandler()
# Mock one good block and one bad block
good_block = MagicMock()
good_instance = MagicMock()
good_instance.name = "Good Block"
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_instance.input_schema.model_fields = {}
good_instance.input_schema.get_credentials_fields_info.return_value = {}
good_block.return_value = good_instance
bad_block = MagicMock()
bad_block.side_effect = Exception("Instantiation failed")
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch(
"backend.blocks.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
# Should only get the good block
assert len(items) == 1
assert items[0].content_id == "good-block"
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_missing_docs_directory():
"""Test DocumentationHandler handles missing docs directory gracefully."""
handler = DocumentationHandler()
# Mock _get_docs_root to return non-existent path
fake_path = Path("/nonexistent/docs")
with patch.object(handler, "_get_docs_root", return_value=fake_path):
items = await handler.get_missing_items(batch_size=10)
assert items == []
stats = await handler.get_stats()
assert stats["total"] == 0
assert stats["with_embeddings"] == 0
assert stats["without_embeddings"] == 0

View File

@@ -21,7 +21,6 @@ from backend.api.features.store.embeddings import (
embedding_to_vector_string,
)
from backend.data.db import query_raw_with_schema
from backend.util.text import split_camelcase
logger = logging.getLogger(__name__)
@@ -32,14 +31,12 @@ logger = logging.getLogger(__name__)
def tokenize(text: str) -> list[str]:
"""Simple tokenizer for BM25 lowercase and split on word boundaries.
CamelCase is split first so "AITextGeneratorBlock" becomes
``["ai", "text", "generator", "block"]``.
"""
"""Simple tokenizer for BM25 - lowercase and split on non-alphanumeric."""
if not text:
return []
return re.findall(r"\b\w+\b", split_camelcase(text).lower())
# Lowercase and split on non-alphanumeric characters
tokens = re.findall(r"\b\w+\b", text.lower())
return tokens
def bm25_rerank(

View File

@@ -14,49 +14,8 @@ from backend.api.features.store.hybrid_search import (
HybridSearchWeights,
UnifiedSearchWeights,
hybrid_search,
tokenize,
unified_hybrid_search,
)
from backend.util.text import split_camelcase
# ---------------------------------------------------------------------------
# split_camelcase
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"input_text, expected",
[
("AITextGeneratorBlock", "AI Text Generator Block"),
("HTTPRequestBlock", "HTTP Request Block"),
("simpleWord", "simple Word"),
("already spaced", "already spaced"),
("XMLParser", "XML Parser"),
("getHTTPResponse", "get HTTP Response"),
("Block", "Block"),
("", ""),
],
)
def test_split_camelcase(input_text: str, expected: str):
assert split_camelcase(input_text) == expected
# ---------------------------------------------------------------------------
# tokenize (BM25)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"input_text, expected",
[
("AITextGeneratorBlock", ["ai", "text", "generator", "block"]),
("hello world", ["hello", "world"]),
("", []),
("HTTPRequest", ["http", "request"]),
],
)
def test_tokenize(input_text: str, expected: list[str]):
assert tokenize(input_text) == expected
@pytest.mark.asyncio(loop_scope="session")

View File

@@ -0,0 +1,19 @@
"""Small text helpers shared across store search modules."""
import re
def split_camelcase(text: str) -> str:
"""Split CamelCase into separate words.
Examples::
>>> split_camelcase("AITextGeneratorBlock")
'AI Text Generator Block'
>>> split_camelcase("HTTPRequestBlock")
'HTTP Request Block'
"""
text = text[:500] # Bound input length to prevent regex DoS
text = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", r" ", text)
text = re.sub(r"(?<=[A-Z])(?=[A-Z][a-z])", r" ", text)
return text

View File

@@ -1,5 +1,4 @@
import logging
import re
import bleach
from bleach.css_sanitizer import CSSSanitizer
@@ -155,19 +154,3 @@ class TextFormatter:
)
return rendered_subject_template, rendered_base_template
def split_camelcase(text: str) -> str:
"""Split CamelCase into separate words.
Examples::
>>> split_camelcase("AITextGeneratorBlock")
'AI Text Generator Block'
>>> split_camelcase("HTTPRequestBlock")
'HTTP Request Block'
"""
text = text[:500] # Bound input length to prevent regex DoS
text = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", r" ", text)
text = re.sub(r"(?<=[A-Z])(?=[A-Z][a-z])", r" ", text)
return text