mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
This PR adds hybrid search functionality combining semantic embeddings with traditional text search for improved store listing discovery. ### Changes 🏗️ - Add `embeddings.py` - OpenAI-based embedding generation and similarity search - Add `hybrid_search.py` - Combines vector similarity with text matching for better search results - Add `backfill_embeddings.py` - Script to generate embeddings for existing store listings - Update `db.py` - Integrate hybrid search into store database queries - Update `schema.prisma` - Add embedding storage fields and indexes - Add migrations for embedding columns and HNSW index for vector search ### Architecture Decisions 🏛️ **Fail-Fast Approach (No Silent Fallbacks)** We explicitly chose NOT to implement graceful degradation when hybrid search fails. Here's why: ✅ **Benefits:** - Errors surface immediately → faster fixes - Tests verify hybrid search actually works (not just fallback) - Consistent search quality for all users - Forces proper infrastructure setup (API keys, database) ❌ **Why Not Fallback:** - Silent degradation hides production issues - Users get inconsistent results without knowing why - Tests can pass even when hybrid search is broken - Reduces operational visibility **How We Prevent Failures:** 1. Embedding generation in approval flow (db.py:1545) 2. Error logging with `logger.error` (not warning) 3. Clear error messages (ValueError explains what's wrong) 4. Comprehensive test coverage (9/9 tests passing) If embeddings fail, it indicates a real infrastructure issue (missing API key, OpenAI down, database issues) that needs immediate attention, not silent degradation. 
### Test Coverage ✅ **All tests passing (1625 total):** - 9/9 hybrid_search tests (including fail-fast validation) - 3/3 db search integration tests - Full schema compatibility (public/platform schemas) - Error handling verification ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Test hybrid search returns relevant results - [x] Test embedding generation for new listings - [x] Test backfill script on existing data - [x] Verify search performance with embeddings - [x] Test fail-fast behavior when embeddings unavailable #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - [x] Configuration: Requires `openai_internal_api_key` in secrets --------- Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
466 lines
16 KiB
Python
466 lines
16 KiB
Python
import json
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, patch
|
|
from uuid import UUID
|
|
|
|
import fastapi.exceptions
|
|
import pytest
|
|
from pytest_snapshot.plugin import Snapshot
|
|
|
|
import backend.api.features.store.model as store
|
|
from backend.api.model import CreateGraph
|
|
from backend.blocks.basic import StoreValueBlock
|
|
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
|
from backend.data.block import BlockSchema, BlockSchemaInput
|
|
from backend.data.graph import Graph, Link, Node
|
|
from backend.data.model import SchemaField
|
|
from backend.data.user import DEFAULT_USER_ID
|
|
from backend.usecases.sample import create_test_user
|
|
from backend.util.test import SpinTestServer
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
    """
    Replace the store's `ensure_embedding` with an async mock for the session.

    The patched target is `backend.api.features.store.db.ensure_embedding`;
    the mock resolves to `True`, so tests in this session do not invoke the
    real embedding function. The patch is undone when the session ends.
    """
    embedding_patch = patch(
        "backend.api.features.store.db.ensure_embedding",
        new_callable=AsyncMock,
        return_value=True,
    )
    # Keep the patch active for as long as the fixture is alive.
    with embedding_patch:
        yield
|
|
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
    """
    Create a three-node, one-link graph and check what the server returns.

    Checks performed on the created graph:
    1. Graph and node ids parse as UUIDs and the name is preserved.
    2. Exactly three nodes and one link exist, and the link joins two
       distinct nodes of the graph.
    3. An id-free summary of the structure matches the stored snapshot.

    Args:
        server (SpinTestServer): The test server instance.
        snapshot (Snapshot): Snapshot fixture used to record the structure.
    """
    store_value_id = StoreValueBlock().id
    agent_input_id = AgentInputBlock().id

    node_specs = [
        Node(id="node_1", block_id=store_value_id),
        Node(id="node_2", block_id=agent_input_id, input_default={"name": "input"}),
        Node(id="node_3", block_id=store_value_id),
    ]
    link_specs = [
        Link(
            source_id="node_1",
            sink_id="node_2",
            source_name="output",
            sink_name="name",
        ),
    ]
    graph = Graph(
        id="test_graph",
        name="TestGraph",
        description="Test graph",
        nodes=node_specs,
        links=link_specs,
    )

    created_graph = await server.agent_server.test_create_graph(
        CreateGraph(graph=graph), DEFAULT_USER_ID
    )

    assert UUID(created_graph.id)
    assert created_graph.name == "TestGraph"

    created_nodes = created_graph.nodes
    assert len(created_nodes) == 3
    for created_node in created_nodes:
        assert UUID(created_node.id)

    created_links = created_graph.links
    assert len(created_links) == 1
    only_link = created_links[0]
    known_node_ids = {created_node.id for created_node in created_nodes}
    assert only_link.source_id != only_link.sink_id
    assert only_link.source_id in known_node_ids
    assert only_link.sink_id in known_node_ids

    # Snapshot payload: leave out the generated ids so the snapshot is
    # reproducible between runs.
    link_structure = [
        {"source_name": link.source_name, "sink_name": link.sink_name}
        for link in created_graph.links
    ]
    graph_data = {
        "name": created_graph.name,
        "description": created_graph.description,
        "nodes_count": len(created_graph.nodes),
        "links_count": len(created_graph.links),
        "node_blocks": [node.block_id for node in created_graph.nodes],
        "link_structure": link_structure,
    }
    snapshot.snapshot_dir = "snapshots"
    serialized = json.dumps(graph_data, indent=2, sort_keys=True)
    snapshot.assert_match(serialized, "grph_struct")
|
|
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
async def test_get_input_schema(server: SpinTestServer, snapshot: Snapshot):
    """
    Check the input and output schemas exposed by a created graph.

    The graph wires two agent-input nodes into one value node, which in turn
    feeds a single agent-output node. After creation the test:
    1. Compares `input_schema` (with its title overridden) against
       `ExpectedInputSchema.jsonschema()`.
    2. Compares `output_schema` (with its title overridden) against
       `ExpectedOutputSchema.jsonschema()`.
    3. Records both schemas as snapshots.

    Args:
        server (SpinTestServer): The test server instance.
        snapshot (Snapshot): Snapshot fixture used to record the schemas.
    """

    def to_snapshot_text(schema: Any) -> str:
        # Stable textual form used for every snapshot in this test.
        return json.dumps(schema, indent=2, sort_keys=True)

    store_value_id = StoreValueBlock().id
    agent_input_id = AgentInputBlock().id
    agent_output_id = AgentOutputBlock().id

    input_node_a = Node(
        id="node_0_a",
        block_id=agent_input_id,
        input_default={
            "name": "in_key_a",
            "title": "Key A",
            "value": "A",
            "advanced": True,
        },
        metadata={"id": "node_0_a"},
    )
    input_node_b = Node(
        id="node_0_b",
        block_id=agent_input_id,
        input_default={"name": "in_key_b", "advanced": True},
        metadata={"id": "node_0_b"},
    )
    value_node = Node(id="node_1", block_id=store_value_id, metadata={"id": "node_1"})
    output_node = Node(
        id="node_2",
        block_id=agent_output_id,
        input_default={
            "name": "out_key",
            "description": "This is an output key",
        },
        metadata={"id": "node_2"},
    )

    link_a_to_value = Link(
        source_id="node_0_a",
        sink_id="node_1",
        source_name="result",
        sink_name="input",
    )
    link_b_to_value = Link(
        source_id="node_0_b",
        sink_id="node_1",
        source_name="result",
        sink_name="input",
    )
    link_value_to_output = Link(
        source_id="node_1",
        sink_id="node_2",
        source_name="output",
        sink_name="value",
    )

    graph = Graph(
        name="TestInputSchema",
        description="Test input schema",
        nodes=[input_node_a, input_node_b, value_node, output_node],
        links=[link_a_to_value, link_b_to_value, link_value_to_output],
    )

    created_graph = await server.agent_server.test_create_graph(
        CreateGraph(graph=graph), DEFAULT_USER_ID
    )

    class ExpectedInputSchema(BlockSchemaInput):
        in_key_a: Any = SchemaField(title="Key A", default="A", advanced=True)
        in_key_b: Any = SchemaField(title="in_key_b", advanced=False)

    class ExpectedOutputSchema(BlockSchema):
        # Note: Graph output schemas are dynamically generated and don't inherit
        # from BlockSchemaOutput, so we use BlockSchema as the base instead
        out_key: Any = SchemaField(
            description="This is an output key",
            title="out_key",
            advanced=False,
        )

    # Input side: align the title with the expected class name, then compare.
    actual_input_schema = created_graph.input_schema
    actual_input_schema["title"] = "ExpectedInputSchema"
    assert actual_input_schema == ExpectedInputSchema.jsonschema()

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(to_snapshot_text(actual_input_schema), "grph_in_schm")

    # Output side: same procedure with the output schema.
    actual_output_schema = created_graph.output_schema
    actual_output_schema["title"] = "ExpectedOutputSchema"
    assert actual_output_schema == ExpectedOutputSchema.jsonschema()

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(to_snapshot_text(actual_output_schema), "grph_out_schm")
|
|
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
async def test_clean_graph(server: SpinTestServer):
    """
    Check what a graph fetched with `for_export=True` keeps and drops.

    Expectations on the exported graph:
    1. Ordinary input values (including a non-secret input block's value)
       are preserved.
    2. The value of an input block marked `secret` is removed.
    3. Keys such as `api_key`, `password`, `token` and credential fields are
       removed from node inputs.
    4. No node carries webhook information.
    """
    plain_input_node = Node(
        block_id=AgentInputBlock().id,
        input_default={
            "_test_id": "input_node",
            "name": "test_input",
            "value": "test value",  # expected to survive the export
            "description": "Test input description",
        },
    )
    secret_input_node = Node(
        block_id=AgentInputBlock().id,
        input_default={
            "_test_id": "input_node_secret",
            "name": "secret_input",
            "value": "another value",
            "secret": True,  # marks this input as secret
        },
    )
    node_with_secrets = Node(
        block_id=StoreValueBlock().id,
        input_default={
            "_test_id": "node_with_secrets",
            "input": "normal_value",
            "control_test_input": "should be preserved",
            "api_key": "secret_api_key_123",  # expected to be filtered
            "password": "secret_password_456",  # expected to be filtered
            "token": "secret_token_789",  # expected to be filtered
            "credentials": {  # expected to be filtered
                "id": "fake-github-credentials-id",
                "provider": "github",
                "type": "api_key",
            },
            "anthropic_credentials": {  # expected to be filtered
                "id": "fake-anthropic-credentials-id",
                "provider": "anthropic",
                "type": "api_key",
            },
        },
    )
    graph = Graph(
        id="test_clean_graph",
        name="Test Clean Graph",
        description="Test graph cleaning",
        nodes=[plain_input_node, secret_input_node, node_with_secrets],
        links=[],
    )

    # Persist the graph, then fetch its export form.
    created_graph = await server.agent_server.test_create_graph(
        CreateGraph(graph=graph), DEFAULT_USER_ID
    )
    cleaned_graph = await server.agent_server.test_get_graph(
        created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
    )

    def exported_node(test_id: str):
        # Nodes are located through the `_test_id` marker placed in their inputs.
        return next(
            n for n in cleaned_graph.nodes if n.input_default["_test_id"] == test_id
        )

    # Plain input node: regular fields are kept, sensitive keys are absent.
    plain_defaults = exported_node("input_node").input_default
    assert plain_defaults["name"] == "test_input"
    assert plain_defaults["value"] == "test value"
    assert plain_defaults["description"] == "Test input description"
    for absent_key in ("api_key", "password"):
        assert absent_key not in plain_defaults

    # Secret input node: name and flag are kept, the value itself is dropped.
    secret_defaults = exported_node("input_node_secret").input_default
    assert secret_defaults["name"] == "secret_input"
    assert "value" not in secret_defaults
    assert secret_defaults["secret"] is True

    # Node with secrets: normal inputs are kept, every sensitive key is dropped.
    mixed_defaults = exported_node("node_with_secrets").input_default
    assert mixed_defaults["input"] == "normal_value"
    assert mixed_defaults["control_test_input"] == "should be preserved"
    for filtered_key in (
        "api_key",
        "password",
        "token",
        "credentials",
        "anthropic_credentials",
    ):
        assert filtered_key not in mixed_defaults

    # Webhook details must be absent on every exported node.
    for exported in cleaned_graph.nodes:
        assert exported.webhook_id is None
        assert exported.webhook is None
|
|
|
|
|
|
@pytest.mark.asyncio(loop_scope="session")
async def test_access_store_listing_graph(server: SpinTestServer):
    """
    Test that a graph becomes readable by other users once it is listed.

    Flow:
    1. The default user creates a graph.
    2. A different user requesting it gets an HTTP 404.
    3. The default user submits the graph as a store listing, and an
       alternate test user approves that submission.
    4. The different user can now fetch the graph.

    Args:
        server (SpinTestServer): The test server instance.
    """
    # Id of a user who does not own the graph; used before and after approval.
    other_user_id = "3e53486c-cf57-477e-ba2a-cb02dc828e1b"

    graph = Graph(
        id="test_clean_graph",
        name="Test Clean Graph",
        description="Test graph cleaning",
        nodes=[
            Node(
                id="input_node",
                block_id=AgentInputBlock().id,
                input_default={
                    "name": "test_input",
                    "value": "test value",
                    "description": "Test input description",
                },
            ),
        ],
        links=[],
    )

    # Create graph and get model
    create_graph = CreateGraph(graph=graph)
    created_graph = await server.agent_server.test_create_graph(
        create_graph, DEFAULT_USER_ID
    )

    store_submission_request = store.StoreSubmissionRequest(
        agent_id=created_graph.id,
        agent_version=created_graph.version,
        slug=created_graph.id,
        name="Test name",
        sub_heading="Test sub heading",
        video_url=None,
        image_urls=[],
        description="Test description",
        categories=[],
    )

    # First we check the graph cannot be accessed by a different user
    with pytest.raises(fastapi.exceptions.HTTPException) as exc_info:
        await server.agent_server.test_get_graph(
            created_graph.id,
            created_graph.version,
            other_user_id,
        )
    assert exc_info.value.status_code == 404
    assert "Graph" in str(exc_info.value.detail)

    # Now we create a store listing
    store_listing = await server.agent_server.test_create_store_listing(
        store_submission_request, DEFAULT_USER_ID
    )

    # A JSONResponse here means the submission was rejected. pytest.fail is
    # used instead of `assert False` so the check survives `python -O`.
    if isinstance(store_listing, fastapi.responses.JSONResponse):
        pytest.fail("Failed to create store listing")

    slv_id = store_listing.store_listing_version_id
    assert slv_id is not None

    # Approve the submission as the alternate test user.
    admin_user = await create_test_user(alt_user=True)
    await server.agent_server.test_review_store_listing(
        store.ReviewSubmissionRequest(
            store_listing_version_id=slv_id,
            is_approved=True,
            comments="Test comments",
        ),
        user_id=admin_user.id,
    )

    # Now we check the graph can be accessed by a user that does not own the graph
    got_graph = await server.agent_server.test_get_graph(
        created_graph.id, created_graph.version, other_user_id
    )
    assert got_graph is not None
|
|
|
|
|
|
# ============================================================================
|
|
# Tests for Optional Credentials Feature
|
|
# ============================================================================
|
|
|
|
|
|
def test_node_credentials_optional_default():
    """A node whose metadata is empty reports `credentials_optional` as False."""
    empty_metadata: dict = {}
    bare_node = Node(
        id="test_node",
        block_id=StoreValueBlock().id,
        input_default={},
        metadata=empty_metadata,
    )

    assert bare_node.credentials_optional is False
|
|
|
|
|
|
def test_node_credentials_optional_true():
    """A node with `credentials_optional: True` in its metadata reports True."""
    opted_in_metadata = {"credentials_optional": True}
    opted_in_node = Node(
        id="test_node",
        block_id=StoreValueBlock().id,
        input_default={},
        metadata=opted_in_metadata,
    )

    assert opted_in_node.credentials_optional is True
|
|
|
|
|
|
def test_node_credentials_optional_false():
    """A node with `credentials_optional: False` in its metadata reports False."""
    opted_out_metadata = {"credentials_optional": False}
    opted_out_node = Node(
        id="test_node",
        block_id=StoreValueBlock().id,
        input_default={},
        metadata=opted_out_metadata,
    )

    assert opted_out_node.credentials_optional is False
|
|
|
|
|
|
def test_node_credentials_optional_with_other_metadata():
    """
    `credentials_optional` is read correctly when other metadata keys exist.

    The unrelated `position` and `customized_name` entries must also be
    returned unchanged from `node.metadata`.
    """
    rich_metadata = {
        "position": {"x": 100, "y": 200},
        "customized_name": "My Custom Node",
        "credentials_optional": True,
    }
    decorated_node = Node(
        id="test_node",
        block_id=StoreValueBlock().id,
        input_default={},
        metadata=rich_metadata,
    )

    assert decorated_node.credentials_optional is True
    assert decorated_node.metadata["position"] == {"x": 100, "y": 200}
    assert decorated_node.metadata["customized_name"] == "My Custom Node"
|