Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-09 07:08:09 -05:00)
added indexer and search example
docs/index.py (868 lines added, new file)
@@ -0,0 +1,868 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Documentation Indexer
|
||||
|
||||
Creates a hybrid search index from markdown documentation files:
|
||||
- Local embeddings via sentence-transformers (all-MiniLM-L6-v2)
|
||||
- BM25 index for lexical search
|
||||
- PageRank scores based on internal link graph
|
||||
- Title index for fast title matching
|
||||
|
||||
Based on ZIM-Plus indexing architecture.
|
||||
|
||||
Usage:
|
||||
python index.py [--docs-dir ./content] [--output index.bin] [--json]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Optional imports with graceful fallback
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
HAS_SENTENCE_TRANSFORMERS = True
|
||||
except ImportError:
|
||||
HAS_SENTENCE_TRANSFORMERS = False
|
||||
print("Warning: sentence-transformers not installed. Run: pip install sentence-transformers")
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
HAS_OPENAI = True
|
||||
except ImportError:
|
||||
HAS_OPENAI = False
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
HAS_BM25 = True
|
||||
except ImportError:
|
||||
HAS_BM25 = False
|
||||
print("Warning: rank_bm25 not installed. Run: pip install rank-bm25")
|
||||
|
||||
# Default embedding model (compatible with Transformers.js)
|
||||
DEFAULT_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
||||
DEFAULT_EMBEDDING_DIM = 384
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Data Structures
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
"""A chunk of text from a document."""
|
||||
doc_path: str # Relative path to source document
|
||||
doc_title: str # Document title (from first H1 or filename)
|
||||
chunk_id: int # Chunk index within document
|
||||
text: str # Chunk text content
|
||||
heading: str # Current heading context
|
||||
start_char: int # Start position in original doc
|
||||
end_char: int # End position in original doc
|
||||
embedding: Optional[np.ndarray] = None # Embedding vector (local sentence-transformers or OpenAI)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
"""A markdown document."""
|
||||
path: str # Relative path from docs root
|
||||
title: str # Document title
|
||||
content: str # Raw markdown content
|
||||
headings: list[str] = field(default_factory=list) # All headings
|
||||
outgoing_links: list[str] = field(default_factory=list) # Links to other docs
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchIndex:
|
||||
"""Complete search index structure."""
|
||||
# Metadata
|
||||
version: str = "1.0.0"
|
||||
docs_hash: str = "" # Hash of all docs for cache invalidation
|
||||
embedding_model: str = DEFAULT_EMBEDDING_MODEL
embedding_dim: int = DEFAULT_EMBEDDING_DIM
|
||||
|
||||
# Document data
|
||||
documents: list[Document] = field(default_factory=list)
|
||||
chunks: list[Chunk] = field(default_factory=list)
|
||||
|
||||
# Embeddings matrix (num_chunks x embedding_dim)
|
||||
embeddings: Optional[np.ndarray] = None
|
||||
|
||||
# BM25 index (serialized)
|
||||
bm25_corpus: list[list[str]] = field(default_factory=list)
|
||||
|
||||
# PageRank scores per document
|
||||
pagerank: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
# Title inverted index: word -> list of (doc_idx, score)
|
||||
title_index: dict[str, list[tuple[int, float]]] = field(default_factory=dict)
|
||||
|
||||
# Path to doc index mapping
|
||||
path_to_idx: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Markdown Parsing
|
||||
# ============================================================================
|
||||
|
||||
def extract_title(content: str, filename: str) -> str:
|
||||
"""Extract document title from first H1 heading or filename."""
|
||||
match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return filename.replace('.md', '').replace('-', ' ').replace('_', ' ').title()
|
||||
|
||||
|
||||
def extract_headings(content: str) -> list[str]:
|
||||
"""Extract all headings from markdown."""
|
||||
headings = []
|
||||
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
level = len(match.group(1))
|
||||
text = match.group(2).strip()
|
||||
headings.append(f"{'#' * level} {text}")
|
||||
return headings
|
||||
|
||||
|
||||
def extract_links(content: str, current_path: str) -> list[str]:
|
||||
"""Extract internal markdown links, normalized to relative paths."""
|
||||
links = []
|
||||
# Match [text](path) but not external URLs
|
||||
for match in re.finditer(r'\[([^\]]+)\]\(([^)]+)\)', content):
|
||||
link_path = match.group(2)
|
||||
# Skip external links, anchors, and images
|
||||
if link_path.startswith(('http://', 'https://', '#', 'mailto:')):
|
||||
continue
|
||||
if link_path.endswith(('.png', '.jpg', '.gif', '.svg')):
|
||||
continue
|
||||
|
||||
# Normalize the path relative to docs root
|
||||
# Handle relative paths like ../foo.md or ./bar.md
|
||||
current_dir = Path(current_path).parent
|
||||
normalized = (current_dir / link_path).as_posix()
|
||||
# Remove ./ prefix and normalize
|
||||
normalized = re.sub(r'^\./', '', normalized)
|
||||
# Ensure .md extension
|
||||
if not normalized.endswith('.md') and '.' not in Path(normalized).name:
    normalized += '.md'
|
||||
links.append(normalized)
|
||||
|
||||
return links
|
||||
|
||||
|
||||
def parse_document(path: Path, docs_root: Path) -> Document:
|
||||
"""Parse a markdown document."""
|
||||
content = path.read_text(encoding='utf-8')
|
||||
rel_path = path.relative_to(docs_root).as_posix()
|
||||
|
||||
return Document(
|
||||
path=rel_path,
|
||||
title=extract_title(content, path.name),
|
||||
content=content,
|
||||
headings=extract_headings(content),
|
||||
outgoing_links=extract_links(content, rel_path)
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Chunking
|
||||
# ============================================================================
|
||||
|
||||
def chunk_markdown(
|
||||
content: str,
|
||||
doc_path: str,
|
||||
doc_title: str,
|
||||
chunk_size: int = 6000,
|
||||
chunk_overlap: int = 200
|
||||
) -> list[Chunk]:
|
||||
"""
|
||||
Chunk markdown content with heading awareness.
|
||||
|
||||
Strategy:
|
||||
1. Split by headings to preserve semantic boundaries
|
||||
2. Further split large sections by paragraphs
|
||||
3. Merge small sections to reach target chunk size
|
||||
4. Add overlap between chunks for context continuity
|
||||
"""
|
||||
chunks = []
|
||||
|
||||
# Split content into sections by headings
|
||||
sections = []
|
||||
current_heading = doc_title
|
||||
current_text = []
|
||||
current_start = 0
|
||||
|
||||
lines = content.split('\n')
|
||||
char_pos = 0
|
||||
|
||||
for line in lines:
|
||||
# Check if this is a heading
|
||||
heading_match = re.match(r'^(#{1,6})\s+(.+)$', line)
|
||||
|
||||
if heading_match:
|
||||
# Save previous section if not empty
|
||||
if current_text:
|
||||
section_text = '\n'.join(current_text)
|
||||
sections.append({
|
||||
'heading': current_heading,
|
||||
'text': section_text,
|
||||
'start': current_start,
|
||||
'end': char_pos
|
||||
})
|
||||
|
||||
# Start new section
|
||||
current_heading = heading_match.group(2).strip()
|
||||
current_text = [line]
|
||||
current_start = char_pos
|
||||
else:
|
||||
current_text.append(line)
|
||||
|
||||
char_pos += len(line) + 1 # +1 for newline
|
||||
|
||||
# Don't forget the last section
|
||||
if current_text:
|
||||
section_text = '\n'.join(current_text)
|
||||
sections.append({
|
||||
'heading': current_heading,
|
||||
'text': section_text,
|
||||
'start': current_start,
|
||||
'end': char_pos
|
||||
})
|
||||
|
||||
# Now merge small sections and split large ones
|
||||
chunk_id = 0
|
||||
buffer_text = ""
|
||||
buffer_heading = doc_title
|
||||
buffer_start = 0
|
||||
|
||||
for section in sections:
|
||||
section_text = section['text'].strip()
|
||||
if not section_text:
|
||||
continue
|
||||
|
||||
# If adding this section would exceed chunk size
|
||||
if len(buffer_text) + len(section_text) > chunk_size:
|
||||
# Save current buffer as chunk if not empty
|
||||
if buffer_text.strip():
|
||||
chunks.append(Chunk(
|
||||
doc_path=doc_path,
|
||||
doc_title=doc_title,
|
||||
chunk_id=chunk_id,
|
||||
text=buffer_text.strip(),
|
||||
heading=buffer_heading,
|
||||
start_char=buffer_start,
|
||||
end_char=section['start']
|
||||
))
|
||||
chunk_id += 1
|
||||
|
||||
# If section itself is too large, split it
|
||||
if len(section_text) > chunk_size:
|
||||
# Split by paragraphs
|
||||
paragraphs = re.split(r'\n\n+', section_text)
|
||||
para_buffer = ""
|
||||
para_start = section['start']
|
||||
|
||||
for para in paragraphs:
|
||||
if len(para_buffer) + len(para) > chunk_size:
|
||||
if para_buffer.strip():
|
||||
chunks.append(Chunk(
|
||||
doc_path=doc_path,
|
||||
doc_title=doc_title,
|
||||
chunk_id=chunk_id,
|
||||
text=para_buffer.strip(),
|
||||
heading=section['heading'],
|
||||
start_char=para_start,
|
||||
end_char=para_start + len(para_buffer)
|
||||
))
|
||||
chunk_id += 1
|
||||
# Advance the start offset past the flushed buffer before replacing it
para_start = para_start + len(para_buffer)
para_buffer = para
|
||||
else:
|
||||
para_buffer += "\n\n" + para if para_buffer else para
|
||||
|
||||
# Remaining paragraph buffer becomes new buffer
|
||||
buffer_text = para_buffer
|
||||
buffer_heading = section['heading']
|
||||
buffer_start = para_start
|
||||
else:
|
||||
# Start new buffer with this section
|
||||
buffer_text = section_text
|
||||
buffer_heading = section['heading']
|
||||
buffer_start = section['start']
|
||||
else:
|
||||
# Add section to buffer
|
||||
buffer_text += "\n\n" + section_text if buffer_text else section_text
|
||||
if not buffer_heading or buffer_heading == doc_title:
|
||||
buffer_heading = section['heading']
|
||||
|
||||
# Don't forget the last buffer
|
||||
if buffer_text.strip():
|
||||
chunks.append(Chunk(
|
||||
doc_path=doc_path,
|
||||
doc_title=doc_title,
|
||||
chunk_id=chunk_id,
|
||||
text=buffer_text.strip(),
|
||||
heading=buffer_heading,
|
||||
start_char=buffer_start,
|
||||
end_char=len(content)
|
||||
))
|
||||
|
||||
# Add overlap by prepending context from previous chunk
|
||||
if chunk_overlap > 0 and len(chunks) > 1:
|
||||
for i in range(1, len(chunks)):
|
||||
prev_text = chunks[i-1].text
|
||||
if len(prev_text) > chunk_overlap:
|
||||
# Find a good break point (end of sentence or paragraph)
|
||||
overlap_text = prev_text[-chunk_overlap:]
|
||||
# Try to start at a sentence boundary
|
||||
sentence_match = re.search(r'[.!?]\s+', overlap_text)
|
||||
if sentence_match:
|
||||
overlap_text = overlap_text[sentence_match.end():]
|
||||
chunks[i].text = f"...{overlap_text}\n\n{chunks[i].text}"
|
||||
|
||||
return chunks
|
||||
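# Illustrative sketch (not part of the original commit): how chunk_markdown behaves on a
# small document with a deliberately tiny chunk_size. The _demo_chunking name is
# hypothetical and the function is never called by the indexing pipeline.
def _demo_chunking() -> None:
    sample = (
        "# Getting Started\n\n"
        "Install the CLI and authenticate.\n\n"
        "## Configuration\n\n"
        "Set the API key in your environment before running any agent.\n"
    )
    for chunk in chunk_markdown(sample, "getting-started.md", "Getting Started",
                                chunk_size=80, chunk_overlap=20):
        # Each chunk carries its heading context and character offsets
        print(chunk.chunk_id, repr(chunk.heading), len(chunk.text))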
|
||||
|
||||
# ============================================================================
|
||||
# Embeddings
|
||||
# ============================================================================
|
||||
|
||||
def create_embeddings_local(
|
||||
chunks: list[Chunk],
|
||||
model_name: str = DEFAULT_EMBEDDING_MODEL,
|
||||
batch_size: int = 32
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings using sentence-transformers (local model).
|
||||
|
||||
Uses all-MiniLM-L6-v2 by default which is compatible with Transformers.js
|
||||
for client-side query embedding.
|
||||
"""
|
||||
if not HAS_SENTENCE_TRANSFORMERS:
|
||||
raise RuntimeError(
|
||||
"sentence-transformers not installed. Run: pip install sentence-transformers"
|
||||
)
|
||||
|
||||
print(f"Loading embedding model: {model_name}")
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
print(f"Creating embeddings for {len(chunks)} chunks...")
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
|
||||
# Encode with progress
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=True,
|
||||
convert_to_numpy=True
|
||||
)
|
||||
|
||||
return embeddings.astype(np.float32)
|
||||
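# Illustrative sketch (not part of the original commit): a query embedded with the same
# sentence-transformers model can be compared against the chunk matrix with a plain dot
# product once both sides are L2-normalized. Assumes sentence-transformers is installed;
# _demo_query_scores is a hypothetical helper, not called anywhere.
def _demo_query_scores(query: str, embeddings: np.ndarray) -> np.ndarray:
    model = SentenceTransformer(DEFAULT_EMBEDDING_MODEL)
    q = model.encode(query, convert_to_numpy=True).astype(np.float32)
    q /= np.linalg.norm(q) + 1e-10
    chunk_norms = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-10)
    return chunk_norms @ q  # one cosine similarity per chunk, higher is closer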
|
||||
|
||||
def create_embeddings_openai(
|
||||
chunks: list[Chunk],
|
||||
model: str = "text-embedding-3-small",
|
||||
batch_size: int = 100
|
||||
) -> np.ndarray:
|
||||
"""Create OpenAI embeddings for all chunks (requires API key)."""
|
||||
if not HAS_OPENAI:
|
||||
raise RuntimeError("OpenAI library not installed")
|
||||
|
||||
client = OpenAI()
|
||||
embeddings = []
|
||||
|
||||
print(f"Creating OpenAI embeddings for {len(chunks)} chunks...")
|
||||
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch = chunks[i:i + batch_size]
|
||||
texts = [c.text[:8000] for c in batch]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=model,
|
||||
input=texts
|
||||
)
|
||||
|
||||
for embedding_data in response.data:
|
||||
embeddings.append(embedding_data.embedding)
|
||||
|
||||
print(f" Processed {min(i + batch_size, len(chunks))}/{len(chunks)} chunks")
|
||||
|
||||
return np.array(embeddings, dtype=np.float32)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# BM25 Index
|
||||
# ============================================================================
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
# Lowercase and extract words
|
||||
text = text.lower()
|
||||
# Remove code blocks
|
||||
text = re.sub(r'```[\s\S]*?```', '', text)
|
||||
text = re.sub(r'`[^`]+`', '', text)
|
||||
# Extract words
|
||||
words = re.findall(r'\b[a-z][a-z0-9_-]*\b', text)
|
||||
# Remove very short words and stopwords
|
||||
stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been',
|
||||
'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
|
||||
'would', 'could', 'should', 'may', 'might', 'must', 'shall',
|
||||
'can', 'need', 'dare', 'ought', 'used', 'to', 'of', 'in',
|
||||
'for', 'on', 'with', 'at', 'by', 'from', 'as', 'into',
|
||||
'through', 'during', 'before', 'after', 'above', 'below',
|
||||
'between', 'under', 'again', 'further', 'then', 'once',
|
||||
'and', 'but', 'or', 'nor', 'so', 'yet', 'both', 'either',
|
||||
'neither', 'not', 'only', 'own', 'same', 'than', 'too',
|
||||
'very', 'just', 'also', 'now', 'here', 'there', 'when',
|
||||
'where', 'why', 'how', 'all', 'each', 'every', 'both',
|
||||
'few', 'more', 'most', 'other', 'some', 'such', 'no',
|
||||
'any', 'this', 'that', 'these', 'those', 'it', 'its'}
|
||||
return [w for w in words if len(w) > 2 and w not in stopwords]
|
||||
|
||||
|
||||
def build_bm25_corpus(chunks: list[Chunk]) -> list[list[str]]:
|
||||
"""Build tokenized corpus for BM25."""
|
||||
return [tokenize(chunk.text) for chunk in chunks]
|
||||
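# Illustrative sketch (not part of the original commit): the tokenized corpus plugs
# straight into rank_bm25 at query time. Assumes rank_bm25 is installed; _demo_bm25 is a
# hypothetical helper, not called anywhere.
def _demo_bm25(chunks: list[Chunk], query: str) -> None:
    corpus = build_bm25_corpus(chunks)
    bm25 = BM25Okapi(corpus)
    scores = bm25.get_scores(tokenize(query))  # one lexical score per chunk
    best = int(np.argmax(scores))
    print(f"Best chunk: {chunks[best].doc_path} ({scores[best]:.2f})")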
|
||||
|
||||
# ============================================================================
|
||||
# PageRank
|
||||
# ============================================================================
|
||||
|
||||
def build_link_graph(documents: list[Document]) -> dict[str, list[str]]:
|
||||
"""Build adjacency list from document links."""
|
||||
# Create path lookup
|
||||
valid_paths = {doc.path for doc in documents}
|
||||
|
||||
graph = defaultdict(list)
|
||||
for doc in documents:
|
||||
for link in doc.outgoing_links:
|
||||
# Normalize link path
|
||||
normalized = link.lstrip('./')
|
||||
if normalized in valid_paths:
|
||||
graph[doc.path].append(normalized)
|
||||
|
||||
return dict(graph)
|
||||
|
||||
|
||||
def compute_pagerank(
|
||||
documents: list[Document],
|
||||
damping: float = 0.85,
|
||||
max_iterations: int = 100,
|
||||
tolerance: float = 1e-6
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Compute PageRank scores using power iteration.
|
||||
|
||||
Args:
|
||||
documents: List of documents with outgoing_links
|
||||
damping: Damping factor (probability of following a link)
|
||||
max_iterations: Maximum iterations for convergence
|
||||
tolerance: Convergence threshold
|
||||
|
||||
Returns:
|
||||
Dictionary mapping document paths to PageRank scores
|
||||
"""
|
||||
n = len(documents)
|
||||
if n == 0:
|
||||
return {}
|
||||
|
||||
# Build path to index mapping
|
||||
path_to_idx = {doc.path: i for i, doc in enumerate(documents)}
|
||||
valid_paths = set(path_to_idx.keys())
|
||||
|
||||
# Build adjacency matrix
|
||||
# out_links[i] = list of indices that document i links to
|
||||
out_links = []
|
||||
for doc in documents:
|
||||
links = []
|
||||
for link in doc.outgoing_links:
|
||||
normalized = link.lstrip('./')
|
||||
if normalized in valid_paths:
|
||||
links.append(path_to_idx[normalized])
|
||||
out_links.append(links)
|
||||
|
||||
# Initialize PageRank scores uniformly
|
||||
pr = np.ones(n) / n
|
||||
|
||||
# Power iteration
|
||||
for iteration in range(max_iterations):
|
||||
new_pr = np.ones(n) * (1 - damping) / n
|
||||
|
||||
for i in range(n):
|
||||
if out_links[i]:
|
||||
# Distribute PageRank to outgoing links
|
||||
contribution = damping * pr[i] / len(out_links[i])
|
||||
for j in out_links[i]:
|
||||
new_pr[j] += contribution
|
||||
else:
|
||||
# Dangling node: distribute to all nodes
|
||||
new_pr += damping * pr[i] / n
|
||||
|
||||
# Check convergence
|
||||
diff = np.abs(new_pr - pr).sum()
|
||||
pr = new_pr
|
||||
|
||||
if diff < tolerance:
|
||||
print(f" PageRank converged after {iteration + 1} iterations")
|
||||
break
|
||||
|
||||
# Normalize to [0, 1] range
|
||||
pr = (pr - pr.min()) / (pr.max() - pr.min() + 1e-10)
|
||||
|
||||
return {documents[i].path: float(pr[i]) for i in range(n)}
|
||||
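# Illustrative sketch (not part of the original commit): on a tiny link graph where
# a.md and b.md both link to c.md, c.md should receive the highest normalized PageRank.
# _demo_pagerank is a hypothetical helper, not called anywhere.
def _demo_pagerank() -> None:
    docs = [
        Document(path="a.md", title="A", content="", outgoing_links=["c.md"]),
        Document(path="b.md", title="B", content="", outgoing_links=["c.md"]),
        Document(path="c.md", title="C", content="", outgoing_links=[]),
    ]
    scores = compute_pagerank(docs)
    print(scores)  # expected: c.md at 1.0 after min-max normalization, a.md and b.md equal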
|
||||
|
||||
# ============================================================================
|
||||
# Title Index
|
||||
# ============================================================================
|
||||
|
||||
def build_title_index(documents: list[Document]) -> dict[str, list[tuple[int, float]]]:
|
||||
"""
|
||||
Build inverted index for title search.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping words to list of (doc_index, score) tuples
|
||||
"""
|
||||
index = defaultdict(list)
|
||||
|
||||
for doc_idx, doc in enumerate(documents):
|
||||
# Tokenize title
|
||||
words = tokenize(doc.title)
|
||||
word_set = set(words)
|
||||
|
||||
for word in word_set:
|
||||
# Score based on word position and frequency
|
||||
score = 1.0
|
||||
if words and words[0] == word:
|
||||
score += 0.5 # Bonus for first word
|
||||
index[word].append((doc_idx, score))
|
||||
|
||||
return dict(index)
|
||||
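# Illustrative sketch (not part of the original commit): looking up a single query word
# in the inverted title index. _demo_title_lookup is a hypothetical helper.
def _demo_title_lookup(documents: list[Document], word: str) -> None:
    title_index = build_title_index(documents)
    for doc_idx, score in title_index.get(word, []):
        print(f"{documents[doc_idx].path}: {score}")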
|
||||
|
||||
# ============================================================================
|
||||
# Main Indexing Pipeline
|
||||
# ============================================================================
|
||||
|
||||
def compute_docs_hash(docs_dir: Path) -> str:
|
||||
"""Compute hash of all doc files for cache invalidation."""
|
||||
hasher = hashlib.md5()
|
||||
for path in sorted(docs_dir.rglob('*.md')):
|
||||
hasher.update(path.read_bytes())
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
def build_index(
|
||||
docs_dir: Path,
|
||||
embedding_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
chunk_size: int = 6000,
|
||||
chunk_overlap: int = 200,
|
||||
skip_embeddings: bool = False,
|
||||
use_openai: bool = False
|
||||
) -> SearchIndex:
|
||||
"""
|
||||
Build complete search index from documentation directory.
|
||||
|
||||
Args:
|
||||
docs_dir: Path to documentation directory
|
||||
embedding_model: Embedding model to use (default: all-MiniLM-L6-v2)
|
||||
chunk_size: Target chunk size in characters
|
||||
chunk_overlap: Overlap between chunks
|
||||
skip_embeddings: Skip embedding generation (for testing)
|
||||
use_openai: Use OpenAI embeddings instead of local model
|
||||
|
||||
Returns:
|
||||
Complete SearchIndex ready for search
|
||||
"""
|
||||
print(f"Building index from {docs_dir}")
|
||||
|
||||
# Find all markdown files
|
||||
md_files = list(docs_dir.rglob('*.md'))
|
||||
print(f"Found {len(md_files)} markdown files")
|
||||
|
||||
if not md_files:
|
||||
raise ValueError(f"No markdown files found in {docs_dir}")
|
||||
|
||||
# Parse all documents
|
||||
print("Parsing documents...")
|
||||
documents = [parse_document(path, docs_dir) for path in md_files]
|
||||
|
||||
# Create path to index mapping
|
||||
path_to_idx = {doc.path: i for i, doc in enumerate(documents)}
|
||||
|
||||
# Chunk all documents
|
||||
print("Chunking documents...")
|
||||
all_chunks = []
|
||||
for doc in documents:
|
||||
doc_chunks = chunk_markdown(
|
||||
doc.content,
|
||||
doc.path,
|
||||
doc.title,
|
||||
chunk_size,
|
||||
chunk_overlap
|
||||
)
|
||||
all_chunks.extend(doc_chunks)
|
||||
print(f"Created {len(all_chunks)} chunks")
|
||||
|
||||
# Build BM25 corpus
|
||||
print("Building BM25 index...")
|
||||
bm25_corpus = build_bm25_corpus(all_chunks)
|
||||
|
||||
# Compute PageRank
|
||||
print("Computing PageRank...")
|
||||
pagerank = compute_pagerank(documents)
|
||||
|
||||
# Build title index
|
||||
print("Building title index...")
|
||||
title_index = build_title_index(documents)
|
||||
|
||||
# Create embeddings
|
||||
embeddings = None
|
||||
embedding_dim = DEFAULT_EMBEDDING_DIM
|
||||
if not skip_embeddings:
|
||||
if use_openai:
|
||||
if HAS_OPENAI:
|
||||
embeddings = create_embeddings_openai(all_chunks, embedding_model)
|
||||
embedding_dim = embeddings.shape[1]
|
||||
else:
|
||||
print("Skipping embeddings (openai not installed)")
|
||||
else:
|
||||
if HAS_SENTENCE_TRANSFORMERS:
|
||||
embeddings = create_embeddings_local(all_chunks, embedding_model)
|
||||
embedding_dim = embeddings.shape[1]
|
||||
else:
|
||||
print("Skipping embeddings (sentence-transformers not installed)")
|
||||
|
||||
# Compute docs hash
|
||||
docs_hash = compute_docs_hash(docs_dir)
|
||||
|
||||
# Build final index
|
||||
index = SearchIndex(
|
||||
version="1.0.0",
|
||||
docs_hash=docs_hash,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dim=embedding_dim,
|
||||
documents=documents,
|
||||
chunks=all_chunks,
|
||||
embeddings=embeddings,
|
||||
bm25_corpus=bm25_corpus,
|
||||
pagerank=pagerank,
|
||||
title_index=title_index,
|
||||
path_to_idx=path_to_idx
|
||||
)
|
||||
|
||||
return index
|
||||
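# Illustrative sketch (not part of the original commit): building and saving an index
# programmatically instead of via the CLI. The ./content/platform and ./index.bin paths
# mirror the CLI defaults; skip_embeddings=True keeps the run model-free for testing.
def _demo_build(docs_dir: str = "./content/platform") -> SearchIndex:
    index = build_index(Path(docs_dir), skip_embeddings=True)
    save_index(index, Path("./index.bin"))
    return index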
|
||||
|
||||
def save_index(index: SearchIndex, output_path: Path) -> None:
|
||||
"""Save index to binary file."""
|
||||
print(f"Saving index to {output_path}")
|
||||
|
||||
# Convert embeddings to float16 for space savings
|
||||
if index.embeddings is not None:
|
||||
index.embeddings = index.embeddings.astype(np.float16)
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
pickle.dump(index, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
size_mb = output_path.stat().st_size / (1024 * 1024)
|
||||
print(f"Index saved ({size_mb:.2f} MB)")
|
||||
|
||||
|
||||
def save_index_json(index: SearchIndex, output_path: Path) -> None:
|
||||
"""
|
||||
Save index to JSON format for client-side JavaScript search.
|
||||
|
||||
The JSON structure is optimized for browser loading:
|
||||
- Chunks with text, metadata, and embeddings
|
||||
- BM25 vocabulary and document frequencies
|
||||
- PageRank scores
|
||||
- Title index
|
||||
"""
|
||||
import json
|
||||
import base64
|
||||
|
||||
print(f"Saving JSON index to {output_path}")
|
||||
|
||||
# Build chunks array
|
||||
chunks_data = []
|
||||
for i, chunk in enumerate(index.chunks):
|
||||
chunk_data = {
|
||||
"id": i,
|
||||
"doc": chunk.doc_path,
|
||||
"title": chunk.doc_title,
|
||||
"heading": chunk.heading,
|
||||
"text": chunk.text,
|
||||
}
|
||||
|
||||
# Add embedding if available (as base64 float32)
|
||||
if index.embeddings is not None:
|
||||
emb = index.embeddings[i].astype(np.float32)
|
||||
chunk_data["emb"] = base64.b64encode(emb.tobytes()).decode('ascii')
|
||||
|
||||
chunks_data.append(chunk_data)
|
||||
|
||||
# Build BM25 data
|
||||
# Calculate IDF for each term
|
||||
bm25_data = {}
|
||||
if index.bm25_corpus:
|
||||
# Build vocabulary with document frequencies
|
||||
doc_freq = {}
|
||||
for doc_tokens in index.bm25_corpus:
|
||||
seen = set()
|
||||
for token in doc_tokens:
|
||||
if token not in seen:
|
||||
doc_freq[token] = doc_freq.get(token, 0) + 1
|
||||
seen.add(token)
|
||||
|
||||
n_docs = len(index.bm25_corpus)
|
||||
bm25_data = {
|
||||
"n_docs": n_docs,
|
||||
"avgdl": sum(len(d) for d in index.bm25_corpus) / max(n_docs, 1),
|
||||
"df": doc_freq, # document frequency per term
|
||||
"doc_lens": [len(d) for d in index.bm25_corpus],
|
||||
}
|
||||
|
||||
# Build output structure
|
||||
output = {
|
||||
"version": index.version,
|
||||
"embedding_model": index.embedding_model,
|
||||
"embedding_dim": index.embedding_dim,
|
||||
"chunks": chunks_data,
|
||||
"bm25": bm25_data,
|
||||
"pagerank": index.pagerank,
|
||||
"title_index": {k: list(v) for k, v in index.title_index.items()},
|
||||
}
|
||||
|
||||
# Write JSON
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(output, f, separators=(',', ':')) # Compact JSON
|
||||
|
||||
size_mb = output_path.stat().st_size / (1024 * 1024)
|
||||
print(f"JSON index saved ({size_mb:.2f} MB)")
|
||||
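# Illustrative sketch (not part of the original commit): the base64 "emb" field written
# above round-trips back to a float32 vector, which is what a browser client would
# reconstruct before computing cosine similarity. _demo_decode_embedding is hypothetical.
def _demo_decode_embedding(chunk_data: dict) -> Optional[np.ndarray]:
    import base64
    if "emb" not in chunk_data:
        return None
    return np.frombuffer(base64.b64decode(chunk_data["emb"]), dtype=np.float32)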
|
||||
|
||||
def load_index(index_path: Path) -> SearchIndex:
|
||||
"""Load index from binary file."""
|
||||
with open(index_path, 'rb') as f:
|
||||
index = pickle.load(f)
|
||||
|
||||
# Convert embeddings back to float32 for computation
|
||||
if index.embeddings is not None:
|
||||
index.embeddings = index.embeddings.astype(np.float32)
|
||||
|
||||
return index
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI
|
||||
# ============================================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Index documentation for hybrid search"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--docs-dir',
|
||||
type=Path,
|
||||
default=Path('./content/platform'),
|
||||
help='Path to documentation directory (default: ./content/platform)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output',
|
||||
type=Path,
|
||||
default=Path('./index.bin'),
|
||||
help='Output index file path (default: ./index.bin)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--json-output',
|
||||
type=Path,
|
||||
default=None,
|
||||
help='Output path for JSON index (default: same as --output with .json extension)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--embedding-model',
|
||||
type=str,
|
||||
default=DEFAULT_EMBEDDING_MODEL,
|
||||
help=f'Embedding model (default: {DEFAULT_EMBEDDING_MODEL})'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use-openai',
|
||||
action='store_true',
|
||||
help='Use OpenAI embeddings instead of local sentence-transformers'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--chunk-size',
|
||||
type=int,
|
||||
default=6000,
|
||||
help='Chunk size in characters (default: 6000)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--chunk-overlap',
|
||||
type=int,
|
||||
default=200,
|
||||
help='Chunk overlap in characters (default: 200)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--skip-embeddings',
|
||||
action='store_true',
|
||||
help='Skip embedding generation (for testing)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--json',
|
||||
action='store_true',
|
||||
help='Also output JSON format for client-side JavaScript search'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--json-only',
|
||||
action='store_true',
|
||||
help='Only output JSON format (skip binary pickle)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.docs_dir.exists():
|
||||
print(f"Error: Documentation directory not found: {args.docs_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
index = build_index(
|
||||
args.docs_dir,
|
||||
embedding_model=args.embedding_model,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
skip_embeddings=args.skip_embeddings,
|
||||
use_openai=args.use_openai
|
||||
)
|
||||
|
||||
# Save binary format unless json-only
|
||||
if not args.json_only:
|
||||
save_index(index, args.output)
|
||||
|
||||
# Save JSON format if requested
|
||||
if args.json or args.json_only:
|
||||
json_path = args.json_output if args.json_output else args.output.with_suffix('.json')
|
||||
save_index_json(index, json_path)
|
||||
|
||||
print("\nIndex Statistics:")
|
||||
print(f" Documents: {len(index.documents)}")
|
||||
print(f" Chunks: {len(index.chunks)}")
|
||||
print(f" Embeddings: {'Yes' if index.embeddings is not None else 'No'}")
|
||||
print(f" PageRank scores: {len(index.pagerank)}")
|
||||
print(f" Title index terms: {len(index.title_index)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error building index: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
docs/search.py (715 lines added, new file)
@@ -0,0 +1,715 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Documentation Search
|
||||
|
||||
Hybrid search combining:
|
||||
- Semantic search (OpenAI embeddings + cosine similarity)
|
||||
- Lexical search (BM25)
|
||||
- Authority ranking (PageRank)
|
||||
- Title matching
|
||||
- Content quality signals
|
||||
|
||||
Based on ZIM-Plus search architecture with tunable weights.
|
||||
|
||||
Usage:
|
||||
python search.py "your query" [--index index.bin] [--top-k 10]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
HAS_OPENAI = True
|
||||
except ImportError:
|
||||
HAS_OPENAI = False
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
HAS_BM25 = True
|
||||
except ImportError:
|
||||
HAS_BM25 = False
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
HAS_SENTENCE_TRANSFORMERS = True
|
||||
except ImportError:
|
||||
HAS_SENTENCE_TRANSFORMERS = False
|
||||
|
||||
from index import SearchIndex, Chunk, Document, load_index, tokenize
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Search Configuration
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class SearchWeights:
|
||||
"""
|
||||
Hybrid search weight configuration.
|
||||
|
||||
Based on ZIM-Plus reranking signals:
|
||||
- semantic: Cosine similarity from embeddings
|
||||
- title_match: Query terms appearing in title
|
||||
- url_path_match: Query terms appearing in URL/path
|
||||
- bm25: Sparse lexical matching score
|
||||
- content_quality: Penalizes TOC/nav/boilerplate chunks
|
||||
- pagerank: Link authority score
|
||||
- position_boost: Prefers earlier chunks in document
|
||||
|
||||
All weights should sum to 1.0 for interpretability.
|
||||
"""
|
||||
semantic: float = 0.30
|
||||
title_match: float = 0.20
|
||||
url_path_match: float = 0.15
|
||||
bm25: float = 0.15
|
||||
content_quality: float = 0.10
|
||||
pagerank: float = 0.05
|
||||
position_boost: float = 0.05
|
||||
|
||||
# Diversity penalty: max chunks per document
|
||||
max_chunks_per_doc: int = 2
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Ensure weights are valid."""
|
||||
total = (self.semantic + self.title_match + self.url_path_match +
|
||||
self.bm25 + self.content_quality + self.pagerank +
|
||||
self.position_boost)
|
||||
if abs(total - 1.0) > 0.01:
|
||||
print(f"Warning: Weights sum to {total:.3f}, not 1.0")
|
||||
|
||||
|
||||
# Default weights (tuned for documentation search)
|
||||
DEFAULT_WEIGHTS = SearchWeights()
|
||||
|
||||
# Alternative weight presets for different use cases
|
||||
WEIGHT_PRESETS = {
|
||||
"semantic_heavy": SearchWeights(
|
||||
semantic=0.50, title_match=0.15, url_path_match=0.10,
|
||||
bm25=0.10, content_quality=0.05, pagerank=0.05, position_boost=0.05
|
||||
),
|
||||
"keyword_heavy": SearchWeights(
|
||||
semantic=0.20, title_match=0.20, url_path_match=0.15,
|
||||
bm25=0.30, content_quality=0.05, pagerank=0.05, position_boost=0.05
|
||||
),
|
||||
"authority_heavy": SearchWeights(
|
||||
semantic=0.25, title_match=0.15, url_path_match=0.10,
|
||||
bm25=0.15, content_quality=0.10, pagerank=0.20, position_boost=0.05
|
||||
),
|
||||
}
|
||||
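# Illustrative sketch (not part of the original commit): defining a custom preset.
# CODE_HEAVY_WEIGHTS is a hypothetical name; validate() only warns when the weights
# drift away from summing to 1.0.
CODE_HEAVY_WEIGHTS = SearchWeights(
    semantic=0.35, title_match=0.15, url_path_match=0.10,
    bm25=0.25, content_quality=0.05, pagerank=0.05, position_boost=0.05
)
CODE_HEAVY_WEIGHTS.validate()  # sums to 1.00, so no warning is printed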
|
||||
|
||||
# ============================================================================
|
||||
# Search Result
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""A single search result with scoring breakdown."""
|
||||
chunk: Chunk
|
||||
score: float # Final combined score
|
||||
|
||||
# Individual signal scores (for debugging/tuning)
|
||||
semantic_score: float = 0.0
|
||||
title_score: float = 0.0
|
||||
path_score: float = 0.0
|
||||
bm25_score: float = 0.0
|
||||
quality_score: float = 0.0
|
||||
pagerank_score: float = 0.0
|
||||
position_score: float = 0.0
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"[{self.score:.3f}] {self.chunk.doc_title}\n"
|
||||
f" Path: {self.chunk.doc_path}\n"
|
||||
f" Heading: {self.chunk.heading}\n"
|
||||
f" Scores: sem={self.semantic_score:.2f} title={self.title_score:.2f} "
|
||||
f"path={self.path_score:.2f} bm25={self.bm25_score:.2f} "
|
||||
f"qual={self.quality_score:.2f} pr={self.pagerank_score:.2f} "
|
||||
f"pos={self.position_score:.2f}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Search Engine
|
||||
# ============================================================================
|
||||
|
||||
class HybridSearcher:
|
||||
"""
|
||||
Hybrid search engine combining multiple ranking signals.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: SearchIndex,
|
||||
weights: SearchWeights = DEFAULT_WEIGHTS,
|
||||
openai_client: Optional["OpenAI"] = None  # quoted so the module imports even without openai installed
|
||||
):
|
||||
self.index = index
|
||||
self.weights = weights
|
||||
self.weights.validate()
|
||||
|
||||
# Detect if index was built with local embeddings (sentence-transformers)
|
||||
self.use_local_embeddings = index.embedding_model in [
|
||||
"all-MiniLM-L6-v2", "all-mpnet-base-v2", "paraphrase-MiniLM-L6-v2"
|
||||
]
|
||||
|
||||
# Initialize local embedding model if needed
|
||||
self.local_model = None
|
||||
if self.use_local_embeddings and HAS_SENTENCE_TRANSFORMERS:
|
||||
self.local_model = SentenceTransformer(index.embedding_model)
|
||||
|
||||
# Initialize OpenAI client for query embedding (only if not using local)
|
||||
self.openai_client = None
|
||||
if not self.use_local_embeddings:
|
||||
self.openai_client = openai_client
|
||||
if openai_client is None and HAS_OPENAI:
|
||||
self.openai_client = OpenAI()
|
||||
|
||||
# Build BM25 index from stored corpus
|
||||
self.bm25 = None
|
||||
if HAS_BM25 and index.bm25_corpus:
|
||||
self.bm25 = BM25Okapi(index.bm25_corpus)
|
||||
|
||||
# Precompute normalized embeddings for faster cosine similarity
|
||||
self.normalized_embeddings = None
|
||||
if index.embeddings is not None:
|
||||
norms = np.linalg.norm(index.embeddings, axis=1, keepdims=True)
|
||||
self.normalized_embeddings = index.embeddings / (norms + 1e-10)
|
||||
|
||||
def embed_query(self, query: str) -> Optional[np.ndarray]:
|
||||
"""Get embedding for search query."""
|
||||
# Use local model if available
|
||||
if self.local_model is not None:
|
||||
embedding = self.local_model.encode(query, convert_to_numpy=True)
|
||||
return embedding.astype(np.float32)
|
||||
|
||||
# Fall back to OpenAI
|
||||
if self.openai_client is None:
|
||||
return None
|
||||
|
||||
response = self.openai_client.embeddings.create(
|
||||
model=self.index.embedding_model,
|
||||
input=query
|
||||
)
|
||||
return np.array(response.data[0].embedding, dtype=np.float32)
|
||||
|
||||
def compute_semantic_scores(self, query_embedding: np.ndarray) -> np.ndarray:
|
||||
"""Compute cosine similarity between query and all chunks."""
|
||||
if self.normalized_embeddings is None:
|
||||
return np.zeros(len(self.index.chunks))
|
||||
|
||||
# Normalize query embedding
|
||||
query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
|
||||
|
||||
# Cosine similarity via dot product of normalized vectors
|
||||
similarities = self.normalized_embeddings @ query_norm
|
||||
|
||||
# Normalize to [0, 1] range
|
||||
similarities = (similarities + 1) / 2 # cosine ranges from -1 to 1
|
||||
|
||||
return similarities
|
||||
|
||||
def compute_bm25_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute BM25 scores for all chunks."""
|
||||
if self.bm25 is None:
|
||||
return np.zeros(len(self.index.chunks))
|
||||
|
||||
scores = self.bm25.get_scores(query_tokens)
|
||||
|
||||
# Normalize to [0, 1] range
|
||||
if scores.max() > 0:
|
||||
scores = scores / scores.max()
|
||||
|
||||
return scores
|
||||
|
||||
def compute_title_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute title match scores for all chunks."""
|
||||
scores = np.zeros(len(self.index.chunks))
|
||||
query_set = set(query_tokens)
|
||||
|
||||
for chunk_idx, chunk in enumerate(self.index.chunks):
|
||||
title_tokens = set(tokenize(chunk.doc_title))
|
||||
|
||||
# Exact matches
|
||||
exact_matches = len(query_set & title_tokens)
|
||||
|
||||
# Partial matches (substring)
|
||||
partial_matches = 0
|
||||
for qt in query_tokens:
|
||||
for tt in title_tokens:
|
||||
if qt in tt or tt in qt:
|
||||
partial_matches += 0.5
|
||||
|
||||
# Compute score
|
||||
if query_tokens:
|
||||
scores[chunk_idx] = (exact_matches * 2 + partial_matches) / (len(query_tokens) * 2)
|
||||
|
||||
return np.clip(scores, 0, 1)
|
||||
|
||||
def compute_path_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute URL/path match scores for all chunks."""
|
||||
scores = np.zeros(len(self.index.chunks))
|
||||
query_set = set(query_tokens)
|
||||
|
||||
for chunk_idx, chunk in enumerate(self.index.chunks):
|
||||
# Extract path components
|
||||
path_parts = re.split(r'[/_-]', chunk.doc_path.lower())
|
||||
path_parts = [p.replace('.md', '') for p in path_parts if p]
|
||||
path_set = set(path_parts)
|
||||
|
||||
# Count matches
|
||||
matches = len(query_set & path_set)
|
||||
|
||||
# Partial matches
|
||||
partial = 0
|
||||
for qt in query_tokens:
|
||||
for pp in path_parts:
|
||||
if qt in pp or pp in qt:
|
||||
partial += 0.5
|
||||
|
||||
if query_tokens:
|
||||
scores[chunk_idx] = (matches * 2 + partial) / (len(query_tokens) * 2)
|
||||
|
||||
return np.clip(scores, 0, 1)
|
||||
|
||||
def compute_quality_scores(self) -> np.ndarray:
|
||||
"""
|
||||
Compute content quality scores.
|
||||
|
||||
Penalizes:
|
||||
- TOC/navigation chunks (lots of links, little content)
|
||||
- Very short chunks
|
||||
- Chunks that are mostly code
|
||||
"""
|
||||
scores = np.ones(len(self.index.chunks))
|
||||
|
||||
for chunk_idx, chunk in enumerate(self.index.chunks):
|
||||
text = chunk.text
|
||||
penalty = 0.0
|
||||
|
||||
# Penalize TOC-like content (many links)
|
||||
link_count = len(re.findall(r'\[([^\]]+)\]\([^)]+\)', text))
|
||||
if link_count > 10:
|
||||
penalty += 0.3
|
||||
|
||||
# Penalize very short chunks
|
||||
if len(text) < 200:
|
||||
penalty += 0.2
|
||||
|
||||
# Penalize chunks that are mostly code
|
||||
code_blocks = re.findall(r'```[\s\S]*?```', text)
|
||||
code_length = sum(len(b) for b in code_blocks)
|
||||
if len(text) > 0 and code_length / len(text) > 0.8:
|
||||
penalty += 0.2
|
||||
|
||||
# Penalize index/navigation pages
|
||||
if chunk.doc_path.endswith('index.md'):
|
||||
penalty += 0.1
|
||||
|
||||
scores[chunk_idx] = max(0, 1 - penalty)
|
||||
|
||||
return scores
|
||||
|
||||
def compute_pagerank_scores(self) -> np.ndarray:
|
||||
"""Get PageRank scores for all chunks (by document)."""
|
||||
scores = np.zeros(len(self.index.chunks))
|
||||
|
||||
for chunk_idx, chunk in enumerate(self.index.chunks):
|
||||
scores[chunk_idx] = self.index.pagerank.get(chunk.doc_path, 0.0)
|
||||
|
||||
return scores
|
||||
|
||||
def compute_position_scores(self) -> np.ndarray:
|
||||
"""Compute position boost (prefer earlier chunks in document)."""
|
||||
scores = np.zeros(len(self.index.chunks))
|
||||
|
||||
# Group chunks by document
|
||||
doc_chunks = {}
|
||||
for chunk_idx, chunk in enumerate(self.index.chunks):
|
||||
if chunk.doc_path not in doc_chunks:
|
||||
doc_chunks[chunk.doc_path] = []
|
||||
doc_chunks[chunk.doc_path].append(chunk_idx)
|
||||
|
||||
for doc_path, chunk_indices in doc_chunks.items():
|
||||
n = len(chunk_indices)
|
||||
for i, chunk_idx in enumerate(chunk_indices):
|
||||
# Earlier chunks get higher scores (linear decay)
|
||||
scores[chunk_idx] = 1 - (i / max(n, 1))
|
||||
|
||||
return scores
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
apply_diversity: bool = True
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Perform hybrid search.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
top_k: Number of results to return
|
||||
apply_diversity: Apply diversity penalty (max chunks per doc)
|
||||
|
||||
Returns:
|
||||
List of SearchResult objects sorted by score
|
||||
"""
|
||||
# Tokenize query
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
# Compute all signal scores
|
||||
semantic_scores = np.zeros(len(self.index.chunks))
|
||||
if self.normalized_embeddings is not None and (self.local_model or self.openai_client):
|
||||
query_embedding = self.embed_query(query)
|
||||
if query_embedding is not None:
|
||||
semantic_scores = self.compute_semantic_scores(query_embedding)
|
||||
|
||||
bm25_scores = self.compute_bm25_scores(query_tokens)
|
||||
title_scores = self.compute_title_scores(query_tokens)
|
||||
path_scores = self.compute_path_scores(query_tokens)
|
||||
quality_scores = self.compute_quality_scores()
|
||||
pagerank_scores = self.compute_pagerank_scores()
|
||||
position_scores = self.compute_position_scores()
|
||||
|
||||
# Combine scores using weights
|
||||
w = self.weights
|
||||
combined_scores = (
|
||||
w.semantic * semantic_scores +
|
||||
w.title_match * title_scores +
|
||||
w.url_path_match * path_scores +
|
||||
w.bm25 * bm25_scores +
|
||||
w.content_quality * quality_scores +
|
||||
w.pagerank * pagerank_scores +
|
||||
w.position_boost * position_scores
|
||||
)
|
||||
|
||||
# Create results
|
||||
results = []
|
||||
for chunk_idx in range(len(self.index.chunks)):
|
||||
results.append(SearchResult(
|
||||
chunk=self.index.chunks[chunk_idx],
|
||||
score=combined_scores[chunk_idx],
|
||||
semantic_score=semantic_scores[chunk_idx],
|
||||
title_score=title_scores[chunk_idx],
|
||||
path_score=path_scores[chunk_idx],
|
||||
bm25_score=bm25_scores[chunk_idx],
|
||||
quality_score=quality_scores[chunk_idx],
|
||||
pagerank_score=pagerank_scores[chunk_idx],
|
||||
position_score=position_scores[chunk_idx]
|
||||
))
|
||||
|
||||
# Sort by score
|
||||
results.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
# Apply diversity penalty
|
||||
if apply_diversity:
|
||||
results = self._apply_diversity(results, top_k)
|
||||
|
||||
return results[:top_k]
|
||||
|
||||
def _apply_diversity(
|
||||
self,
|
||||
results: list[SearchResult],
|
||||
target_k: int
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Deduplicate results from the same document unless they point to
|
||||
different sections (headings) within the page.
|
||||
|
||||
Logic:
|
||||
1. Only keep one result per unique (doc_path, heading) combination
|
||||
2. Additionally limit total chunks per document to max_chunks_per_doc
|
||||
"""
|
||||
seen_sections = set() # (doc_path, heading) tuples
|
||||
doc_counts = {} # doc_path -> count
|
||||
filtered = []
|
||||
|
||||
for result in results:
|
||||
doc_path = result.chunk.doc_path
|
||||
heading = result.chunk.heading
|
||||
section_key = (doc_path, heading)
|
||||
|
||||
# Skip if we've already seen this exact section
|
||||
if section_key in seen_sections:
|
||||
continue
|
||||
|
||||
# Also enforce max chunks per document
|
||||
doc_count = doc_counts.get(doc_path, 0)
|
||||
if doc_count >= self.weights.max_chunks_per_doc:
|
||||
continue
|
||||
|
||||
# Keep this result
|
||||
seen_sections.add(section_key)
|
||||
doc_counts[doc_path] = doc_count + 1
|
||||
filtered.append(result)
|
||||
|
||||
if len(filtered) >= target_k * 2: # Get extra for buffer
|
||||
break
|
||||
|
||||
return filtered
|
||||
|
||||
def search_title_only(self, query: str, top_k: int = 10) -> list[SearchResult]:
|
||||
"""
|
||||
Fallback search using only title index and PageRank.
|
||||
Useful when embeddings aren't available.
|
||||
"""
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
# Score documents by title match
|
||||
doc_scores = {}
|
||||
for token in query_tokens:
|
||||
if token in self.index.title_index:
|
||||
for doc_idx, score in self.index.title_index[token]:
|
||||
doc_path = self.index.documents[doc_idx].path
|
||||
doc_scores[doc_path] = doc_scores.get(doc_path, 0) + score
|
||||
|
||||
# Boost by PageRank
|
||||
for doc_path in doc_scores:
|
||||
pr = self.index.pagerank.get(doc_path, 0.0)
|
||||
doc_scores[doc_path] *= (1 + pr)
|
||||
|
||||
# Get top documents and their first chunks
|
||||
sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
results = []
|
||||
for doc_path, score in sorted_docs[:top_k]:
|
||||
# Find first chunk of this document
|
||||
for chunk in self.index.chunks:
|
||||
if chunk.doc_path == doc_path:
|
||||
results.append(SearchResult(
|
||||
chunk=chunk,
|
||||
score=score,
|
||||
title_score=score
|
||||
))
|
||||
break
|
||||
|
||||
return results
|
||||
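# Illustrative sketch (not part of the original commit): loading an index built by
# index.py and running a hybrid query programmatically. The index.bin path mirrors the
# CLI default and the query string is made up; _demo_search is a hypothetical helper.
def _demo_search(query: str = "how do agents store memory") -> None:
    index = load_index(Path("./index.bin"))
    searcher = HybridSearcher(index, WEIGHT_PRESETS["semantic_heavy"])
    for result in searcher.search(query, top_k=3):
        print(f"{result.score:.3f}  {result.chunk.doc_path}  ({result.chunk.heading})")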
|
||||
|
||||
# ============================================================================
|
||||
# Reciprocal Rank Fusion (Alternative scoring method)
|
||||
# ============================================================================
|
||||
|
||||
def reciprocal_rank_fusion(
|
||||
rankings: list[list[int]],
|
||||
k: int = 60
|
||||
) -> list[tuple[int, float]]:
|
||||
"""
|
||||
Combine multiple rankings using Reciprocal Rank Fusion.
|
||||
|
||||
RRF is an alternative to weighted linear combination that's
|
||||
less sensitive to score scale differences.
|
||||
|
||||
Args:
|
||||
rankings: List of rankings (each is list of chunk indices)
|
||||
k: RRF parameter (default 60 works well)
|
||||
|
||||
Returns:
|
||||
List of (chunk_idx, rrf_score) tuples sorted by score
|
||||
"""
|
||||
scores = {}
|
||||
|
||||
for ranking in rankings:
|
||||
for rank, chunk_idx in enumerate(ranking):
|
||||
if chunk_idx not in scores:
|
||||
scores[chunk_idx] = 0
|
||||
scores[chunk_idx] += 1 / (k + rank + 1)
|
||||
|
||||
return sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
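# Illustrative sketch (not part of the original commit): fusing two toy rankings.
# Chunk 0 is ranked first by both signals, so it gets the largest RRF score
# (2 / (60 + 1) with the default k). _demo_rrf is a hypothetical helper.
def _demo_rrf() -> None:
    fused = reciprocal_rank_fusion([[0, 1, 2], [0, 2, 1]])
    print(fused)  # expected: chunk 0 first, chunks 1 and 2 tied behind it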
|
||||
|
||||
class RRFSearcher(HybridSearcher):
|
||||
"""
|
||||
Alternative searcher using Reciprocal Rank Fusion instead of
|
||||
weighted linear combination.
|
||||
"""
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
apply_diversity: bool = True
|
||||
) -> list[SearchResult]:
|
||||
"""Search using RRF fusion."""
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
# Get individual rankings
|
||||
rankings = []
|
||||
|
||||
# Semantic ranking
|
||||
if self.normalized_embeddings is not None and (self.local_model or self.openai_client):
|
||||
query_embedding = self.embed_query(query)
|
||||
if query_embedding is not None:
|
||||
scores = self.compute_semantic_scores(query_embedding)
|
||||
rankings.append(np.argsort(scores)[::-1].tolist())
|
||||
|
||||
# BM25 ranking
|
||||
if self.bm25:
|
||||
scores = self.compute_bm25_scores(query_tokens)
|
||||
rankings.append(np.argsort(scores)[::-1].tolist())
|
||||
|
||||
# Title ranking
|
||||
scores = self.compute_title_scores(query_tokens)
|
||||
rankings.append(np.argsort(scores)[::-1].tolist())
|
||||
|
||||
if not rankings:
|
||||
return []
|
||||
|
||||
# Fuse rankings
|
||||
fused = reciprocal_rank_fusion(rankings)
|
||||
|
||||
# Build results
|
||||
results = []
|
||||
for chunk_idx, score in fused[:top_k * 3]: # Extra for diversity
|
||||
chunk = self.index.chunks[chunk_idx]
|
||||
results.append(SearchResult(
|
||||
chunk=chunk,
|
||||
score=score
|
||||
))
|
||||
|
||||
# Apply diversity
|
||||
if apply_diversity:
|
||||
results = self._apply_diversity(results, top_k)
|
||||
|
||||
return results[:top_k]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI
|
||||
# ============================================================================
|
||||
|
||||
def format_result(result: SearchResult, show_text: bool = True) -> str:
|
||||
"""Format a search result for display."""
|
||||
lines = [
|
||||
f"\n{'='*60}",
|
||||
f"Score: {result.score:.3f}",
|
||||
f"Title: {result.chunk.doc_title}",
|
||||
f"Path: {result.chunk.doc_path}",
|
||||
f"Section: {result.chunk.heading}",
|
||||
]
|
||||
|
||||
if show_text:
|
||||
# Truncate text for display
|
||||
text = result.chunk.text[:500]
|
||||
if len(result.chunk.text) > 500:
|
||||
text += "..."
|
||||
lines.append(f"\n{text}")
|
||||
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Search documentation using hybrid search"
|
||||
)
|
||||
parser.add_argument(
|
||||
'query',
|
||||
type=str,
|
||||
help='Search query'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--index',
|
||||
type=Path,
|
||||
default=Path('./index.bin'),
|
||||
help='Path to index file (default: ./index.bin)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--top-k',
|
||||
type=int,
|
||||
default=5,
|
||||
help='Number of results (default: 5)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--weights',
|
||||
type=str,
|
||||
choices=['default', 'semantic_heavy', 'keyword_heavy', 'authority_heavy'],
|
||||
default='default',
|
||||
help='Weight preset (default: default)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--rrf',
|
||||
action='store_true',
|
||||
help='Use Reciprocal Rank Fusion instead of weighted combination'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-diversity',
|
||||
action='store_true',
|
||||
help='Disable diversity penalty'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
help='Show detailed scoring breakdown'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-text',
|
||||
action='store_true',
|
||||
help='Hide result text snippets'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.index.exists():
|
||||
print(f"Error: Index file not found: {args.index}")
|
||||
print("Run 'python index.py' first to create the index.")
|
||||
sys.exit(1)
|
||||
|
||||
# Load index
|
||||
print(f"Loading index from {args.index}...")
|
||||
index = load_index(args.index)
|
||||
print(f"Loaded {len(index.chunks)} chunks from {len(index.documents)} documents")
|
||||
|
||||
# Select weights
|
||||
weights = DEFAULT_WEIGHTS
|
||||
if args.weights != 'default':
|
||||
weights = WEIGHT_PRESETS[args.weights]
|
||||
|
||||
# Create searcher
|
||||
SearcherClass = RRFSearcher if args.rrf else HybridSearcher
|
||||
searcher = SearcherClass(index, weights)
|
||||
|
||||
# Search
|
||||
print(f"\nSearching for: '{args.query}'")
|
||||
results = searcher.search(
|
||||
args.query,
|
||||
top_k=args.top_k,
|
||||
apply_diversity=not args.no_diversity
|
||||
)
|
||||
|
||||
if not results:
|
||||
print("No results found.")
|
||||
return
|
||||
|
||||
print(f"\nFound {len(results)} results:")
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"\n--- Result {i} ---")
|
||||
print(format_result(result, show_text=not args.no_text))
|
||||
|
||||
if args.debug:
|
||||
print(f"\nScore breakdown:")
|
||||
print(f" Semantic: {result.semantic_score:.3f}")
|
||||
print(f" Title match: {result.title_score:.3f}")
|
||||
print(f" Path match: {result.path_score:.3f}")
|
||||
print(f" BM25: {result.bm25_score:.3f}")
|
||||
print(f" Quality: {result.quality_score:.3f}")
|
||||
print(f" PageRank: {result.pagerank_score:.3f}")
|
||||
print(f" Position: {result.position_score:.3f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||