mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
fix(kb-search): made query optional, so either query or tags or both can be provided (#848)
* fix(kb-search): made query optional, so either query or tags or both can be provided * cleanup * added handlers, ensured that tag search done before vector search * remove duplicate function
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
/**
|
||||
* Tests for knowledge search API route
|
||||
* Focuses on route-specific functionality: authentication, validation, API contract, error handling
|
||||
* Search logic is tested in utils.test.ts
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
@@ -56,6 +58,27 @@ vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
checkKnowledgeBaseAccess: mockCheckKnowledgeBaseAccess,
|
||||
}))
|
||||
|
||||
const mockHandleTagOnlySearch = vi.fn()
|
||||
const mockHandleVectorOnlySearch = vi.fn()
|
||||
const mockHandleTagAndVectorSearch = vi.fn()
|
||||
const mockGetQueryStrategy = vi.fn()
|
||||
const mockGenerateSearchEmbedding = vi.fn()
|
||||
vi.mock('./utils', () => ({
|
||||
handleTagOnlySearch: mockHandleTagOnlySearch,
|
||||
handleVectorOnlySearch: mockHandleVectorOnlySearch,
|
||||
handleTagAndVectorSearch: mockHandleTagAndVectorSearch,
|
||||
getQueryStrategy: mockGetQueryStrategy,
|
||||
generateSearchEmbedding: mockGenerateSearchEmbedding,
|
||||
APIError: class APIError extends Error {
|
||||
public status: number
|
||||
constructor(message: string, status: number) {
|
||||
super(message)
|
||||
this.name = 'APIError'
|
||||
this.status = status
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
mockConsoleLogger()
|
||||
|
||||
describe('Knowledge Search API Route', () => {
|
||||
@@ -65,6 +88,10 @@ describe('Knowledge Search API Route', () => {
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockReturnThis(),
|
||||
innerJoin: vi.fn().mockReturnThis(),
|
||||
leftJoin: vi.fn().mockReturnThis(),
|
||||
groupBy: vi.fn().mockReturnThis(),
|
||||
having: vi.fn().mockReturnThis(),
|
||||
}
|
||||
|
||||
const mockGetUserId = vi.fn()
|
||||
@@ -107,6 +134,17 @@ describe('Knowledge Search API Route', () => {
|
||||
}
|
||||
})
|
||||
|
||||
mockHandleTagOnlySearch.mockClear()
|
||||
mockHandleVectorOnlySearch.mockClear()
|
||||
mockHandleTagAndVectorSearch.mockClear()
|
||||
mockGetQueryStrategy.mockClear().mockReturnValue({
|
||||
useParallel: false,
|
||||
distanceThreshold: 1.0,
|
||||
parallelLimit: 15,
|
||||
singleQueryOptimized: true,
|
||||
})
|
||||
mockGenerateSearchEmbedding.mockClear().mockResolvedValue([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
|
||||
})
|
||||
@@ -137,13 +175,19 @@ describe('Knowledge Search API Route', () => {
|
||||
it('should perform search successfully with single knowledge base', async () => {
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
mockDbChain.limit.mockResolvedValue([])
|
||||
|
||||
mockHandleVectorOnlySearch.mockResolvedValue(mockSearchResults)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
@@ -168,7 +212,12 @@ describe('Knowledge Search API Route', () => {
|
||||
expect(data.data.results[0].similarity).toBe(0.8) // 1 - 0.2
|
||||
expect(data.data.query).toBe(validSearchData.query)
|
||||
expect(data.data.knowledgeBaseIds).toEqual(['kb-123'])
|
||||
expect(mockDbChain.select).toHaveBeenCalled()
|
||||
expect(mockHandleVectorOnlySearch).toHaveBeenCalledWith({
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
queryVector: JSON.stringify(mockEmbedding),
|
||||
distanceThreshold: expect.any(Number),
|
||||
})
|
||||
})
|
||||
|
||||
it('should perform search successfully with multiple knowledge bases', async () => {
|
||||
@@ -184,12 +233,13 @@ describe('Knowledge Search API Route', () => {
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
// Mock knowledge base access check to return success for both KBs
|
||||
mockCheckKnowledgeBaseAccess
|
||||
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: multiKbs[0] })
|
||||
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: multiKbs[1] })
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
mockDbChain.limit.mockResolvedValue([])
|
||||
|
||||
mockHandleVectorOnlySearch.mockResolvedValue(mockSearchResults)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
@@ -207,6 +257,12 @@ describe('Knowledge Search API Route', () => {
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.knowledgeBaseIds).toEqual(['kb-123', 'kb-456'])
|
||||
expect(mockHandleVectorOnlySearch).toHaveBeenCalledWith({
|
||||
knowledgeBaseIds: ['kb-123', 'kb-456'],
|
||||
topK: 10,
|
||||
queryVector: JSON.stringify(mockEmbedding),
|
||||
distanceThreshold: expect.any(Number),
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle workflow-based authentication', async () => {
|
||||
@@ -217,13 +273,19 @@ describe('Knowledge Search API Route', () => {
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Search results
|
||||
mockDbChain.limit.mockResolvedValue([])
|
||||
|
||||
mockHandleVectorOnlySearch.mockResolvedValue(mockSearchResults)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
@@ -275,7 +337,6 @@ describe('Knowledge Search API Route', () => {
|
||||
it('should return not found for non-existent knowledge base', async () => {
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
// Mock knowledge base access check to return no access
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
@@ -340,7 +401,12 @@ describe('Knowledge Search API Route', () => {
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Search results
|
||||
@@ -366,12 +432,10 @@ describe('Knowledge Search API Route', () => {
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: false,
|
||||
status: 401,
|
||||
statusText: 'Unauthorized',
|
||||
text: () => Promise.resolve('Invalid API key'),
|
||||
})
|
||||
// Mock generateSearchEmbedding to throw an error
|
||||
mockGenerateSearchEmbedding.mockRejectedValueOnce(
|
||||
new Error('OpenAI API error: 401 Unauthorized - Invalid API key')
|
||||
)
|
||||
|
||||
const req = createMockRequest('POST', validSearchData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
@@ -383,15 +447,12 @@ describe('Knowledge Search API Route', () => {
|
||||
})
|
||||
|
||||
it.concurrent('should handle missing OpenAI API key', async () => {
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
OPENAI_API_KEY: undefined,
|
||||
},
|
||||
}))
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
|
||||
|
||||
// Mock generateSearchEmbedding to throw missing API key error
|
||||
mockGenerateSearchEmbedding.mockRejectedValueOnce(new Error('OPENAI_API_KEY not configured'))
|
||||
|
||||
const req = createMockRequest('POST', validSearchData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
@@ -404,15 +465,9 @@ describe('Knowledge Search API Route', () => {
|
||||
it.concurrent('should handle database errors during search', async () => {
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
|
||||
mockDbChain.limit.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
data: [{ embedding: mockEmbedding }],
|
||||
}),
|
||||
})
|
||||
// Mock the search handler to throw a database error
|
||||
mockHandleVectorOnlySearch.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('POST', validSearchData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
@@ -427,13 +482,10 @@ describe('Knowledge Search API Route', () => {
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
data: [], // Empty data array
|
||||
}),
|
||||
})
|
||||
// Mock generateSearchEmbedding to throw invalid response format error
|
||||
mockGenerateSearchEmbedding.mockRejectedValueOnce(
|
||||
new Error('Invalid response format from OpenAI embeddings API')
|
||||
)
|
||||
|
||||
const req = createMockRequest('POST', validSearchData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
@@ -451,7 +503,12 @@ describe('Knowledge Search API Route', () => {
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
@@ -499,7 +556,12 @@ describe('Knowledge Search API Route', () => {
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
@@ -556,7 +618,12 @@ describe('Knowledge Search API Route', () => {
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
@@ -581,4 +648,350 @@ describe('Knowledge Search API Route', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Optional Query Search', () => {
|
||||
const mockTagDefinitions = [
|
||||
{ tagSlot: 'tag1', displayName: 'category' },
|
||||
{ tagSlot: 'tag2', displayName: 'priority' },
|
||||
]
|
||||
|
||||
const mockTaggedResults = [
|
||||
{
|
||||
id: 'chunk-1',
|
||||
content: 'Tagged content 1',
|
||||
documentId: 'doc-1',
|
||||
chunkIndex: 0,
|
||||
tag1: 'api',
|
||||
tag2: 'high',
|
||||
distance: 0,
|
||||
knowledgeBaseId: 'kb-123',
|
||||
},
|
||||
{
|
||||
id: 'chunk-2',
|
||||
content: 'Tagged content 2',
|
||||
documentId: 'doc-2',
|
||||
chunkIndex: 1,
|
||||
tag1: 'docs',
|
||||
tag2: 'medium',
|
||||
distance: 0,
|
||||
knowledgeBaseId: 'kb-123',
|
||||
},
|
||||
]
|
||||
|
||||
it('should perform tag-only search without query', async () => {
|
||||
const tagOnlyData = {
|
||||
knowledgeBaseIds: 'kb-123',
|
||||
filters: {
|
||||
category: 'api',
|
||||
},
|
||||
topK: 10,
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
// Mock tag definitions queries for filter mapping and display mapping
|
||||
mockDbChain.limit
|
||||
.mockResolvedValueOnce(mockTagDefinitions) // Tag definitions for filter mapping
|
||||
.mockResolvedValueOnce(mockTagDefinitions) // Tag definitions for display mapping
|
||||
|
||||
// Mock the tag-only search handler
|
||||
mockHandleTagOnlySearch.mockResolvedValue(mockTaggedResults)
|
||||
|
||||
const req = createMockRequest('POST', tagOnlyData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
if (response.status !== 200) {
|
||||
console.log('Tag-only search test error:', data)
|
||||
}
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.results).toHaveLength(2)
|
||||
expect(data.data.results[0].similarity).toBe(1) // Perfect similarity for tag-only
|
||||
expect(data.data.query).toBe('') // Empty query
|
||||
expect(data.data.cost).toBeUndefined() // No cost for tag-only search
|
||||
expect(mockGenerateSearchEmbedding).not.toHaveBeenCalled() // No embedding API call
|
||||
expect(mockHandleTagOnlySearch).toHaveBeenCalledWith({
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
filters: { category: 'api' }, // Note: When no tag definitions are found, it uses the original filter key
|
||||
})
|
||||
})
|
||||
|
||||
it('should perform query + tag combination search', async () => {
|
||||
const combinedData = {
|
||||
knowledgeBaseIds: 'kb-123',
|
||||
query: 'test search',
|
||||
filters: {
|
||||
category: 'api',
|
||||
},
|
||||
topK: 10,
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
// Mock tag definitions queries for filter mapping and display mapping
|
||||
mockDbChain.limit
|
||||
.mockResolvedValueOnce(mockTagDefinitions) // Tag definitions for filter mapping
|
||||
.mockResolvedValueOnce(mockTagDefinitions) // Tag definitions for display mapping
|
||||
|
||||
// Mock the tag + vector search handler
|
||||
mockHandleTagAndVectorSearch.mockResolvedValue(mockSearchResults)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
data: [{ embedding: mockEmbedding }],
|
||||
}),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', combinedData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
if (response.status !== 200) {
|
||||
console.log('Query+tag combination test error:', data)
|
||||
}
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.results).toHaveLength(2)
|
||||
expect(data.data.query).toBe('test search')
|
||||
expect(data.data.cost).toBeDefined() // Cost included for vector search
|
||||
expect(mockGenerateSearchEmbedding).toHaveBeenCalled() // Embedding API called
|
||||
expect(mockHandleTagAndVectorSearch).toHaveBeenCalledWith({
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
filters: { category: 'api' }, // Note: When no tag definitions are found, it uses the original filter key
|
||||
queryVector: JSON.stringify(mockEmbedding),
|
||||
distanceThreshold: 1, // Single KB uses threshold of 1.0
|
||||
})
|
||||
})
|
||||
|
||||
it('should validate that either query or filters are provided', async () => {
|
||||
const emptyData = {
|
||||
knowledgeBaseIds: 'kb-123',
|
||||
topK: 10,
|
||||
}
|
||||
|
||||
const req = createMockRequest('POST', emptyData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data.error).toBe('Invalid request data')
|
||||
expect(data.details).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
message:
|
||||
'Please provide either a search query or tag filters to search your knowledge base',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
|
||||
it('should validate that empty query with empty filters fails', async () => {
|
||||
const emptyFiltersData = {
|
||||
knowledgeBaseIds: 'kb-123',
|
||||
query: '',
|
||||
filters: {},
|
||||
topK: 10,
|
||||
}
|
||||
|
||||
const req = createMockRequest('POST', emptyFiltersData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data.error).toBe('Invalid request data')
|
||||
})
|
||||
|
||||
it('should handle empty tag values gracefully', async () => {
|
||||
// This simulates what happens when the frontend sends empty tag values
|
||||
// The tool transformation should filter out empty values, resulting in no filters
|
||||
const emptyTagValueData = {
|
||||
knowledgeBaseIds: 'kb-123',
|
||||
query: '',
|
||||
topK: 10,
|
||||
// This would result in no filters after tool transformation
|
||||
}
|
||||
|
||||
const req = createMockRequest('POST', emptyTagValueData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data.error).toBe('Invalid request data')
|
||||
expect(data.details).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
message:
|
||||
'Please provide either a search query or tag filters to search your knowledge base',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle null values from frontend gracefully', async () => {
|
||||
// This simulates the exact scenario the user reported
|
||||
// Null values should be transformed to undefined and then trigger validation
|
||||
const nullValuesData = {
|
||||
knowledgeBaseIds: 'kb-123',
|
||||
topK: null,
|
||||
query: null,
|
||||
filters: null,
|
||||
}
|
||||
|
||||
const req = createMockRequest('POST', nullValuesData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data.error).toBe('Invalid request data')
|
||||
expect(data.details).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
message:
|
||||
'Please provide either a search query or tag filters to search your knowledge base',
|
||||
}),
|
||||
])
|
||||
)
|
||||
})
|
||||
|
||||
it('should perform query-only search (existing behavior)', async () => {
|
||||
const queryOnlyData = {
|
||||
knowledgeBaseIds: 'kb-123',
|
||||
query: 'test search query',
|
||||
topK: 10,
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
data: [{ embedding: mockEmbedding }],
|
||||
}),
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', queryOnlyData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.results).toHaveLength(2)
|
||||
expect(data.data.query).toBe('test search query')
|
||||
expect(data.data.cost).toBeDefined() // Cost included for vector search
|
||||
expect(mockGenerateSearchEmbedding).toHaveBeenCalled() // Embedding API called
|
||||
})
|
||||
|
||||
it('should handle tag-only search with multiple knowledge bases', async () => {
|
||||
const multiKbTagData = {
|
||||
knowledgeBaseIds: ['kb-123', 'kb-456'],
|
||||
filters: {
|
||||
category: 'docs',
|
||||
priority: 'high',
|
||||
},
|
||||
topK: 10,
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockCheckKnowledgeBaseAccess
|
||||
.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: {
|
||||
id: 'kb-123',
|
||||
userId: 'user-123',
|
||||
name: 'Test KB',
|
||||
deletedAt: null,
|
||||
},
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-456', userId: 'user-123', name: 'Test KB 2' },
|
||||
})
|
||||
|
||||
// Reset all mocks before setting up specific behavior
|
||||
Object.values(mockDbChain).forEach((fn) => {
|
||||
if (typeof fn === 'function') {
|
||||
fn.mockClear().mockReturnThis()
|
||||
}
|
||||
})
|
||||
|
||||
// Create fresh mocks for multiple database calls needed for multi-KB tag search
|
||||
const mockTagDefsQuery1 = {
|
||||
...mockDbChain,
|
||||
limit: vi.fn().mockResolvedValue(mockTagDefinitions),
|
||||
}
|
||||
const mockTagSearchQuery = {
|
||||
...mockDbChain,
|
||||
limit: vi.fn().mockResolvedValue(mockTaggedResults),
|
||||
}
|
||||
const mockTagDefsQuery2 = {
|
||||
...mockDbChain,
|
||||
limit: vi.fn().mockResolvedValue(mockTagDefinitions),
|
||||
}
|
||||
const mockTagDefsQuery3 = {
|
||||
...mockDbChain,
|
||||
limit: vi.fn().mockResolvedValue(mockTagDefinitions),
|
||||
}
|
||||
|
||||
// Chain the mocks for: tag defs, search, display mapping KB1, display mapping KB2
|
||||
mockDbChain.select
|
||||
.mockReturnValueOnce(mockTagDefsQuery1)
|
||||
.mockReturnValueOnce(mockTagSearchQuery)
|
||||
.mockReturnValueOnce(mockTagDefsQuery2)
|
||||
.mockReturnValueOnce(mockTagDefsQuery3)
|
||||
|
||||
const req = createMockRequest('POST', multiKbTagData)
|
||||
const { POST } = await import('@/app/api/knowledge/search/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.knowledgeBaseIds).toEqual(['kb-123', 'kb-456'])
|
||||
expect(mockGenerateSearchEmbedding).not.toHaveBeenCalled() // No embedding for tag-only
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,235 +1,61 @@
|
||||
import { and, eq, inArray, sql } from 'drizzle-orm'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { TAG_SLOTS } from '@/lib/constants/knowledge'
|
||||
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { estimateTokenCount } from '@/lib/tokenization/estimators'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { embedding, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
import { knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
import { calculateCost } from '@/providers/utils'
|
||||
import {
|
||||
generateSearchEmbedding,
|
||||
getQueryStrategy,
|
||||
handleTagAndVectorSearch,
|
||||
handleTagOnlySearch,
|
||||
handleVectorOnlySearch,
|
||||
type SearchResult,
|
||||
} from './utils'
|
||||
|
||||
const logger = createLogger('VectorSearchAPI')
|
||||
|
||||
function getTagFilters(filters: Record<string, string>, embedding: any) {
|
||||
return Object.entries(filters).map(([key, value]) => {
|
||||
// Handle OR logic within same tag
|
||||
const values = value.includes('|OR|') ? value.split('|OR|') : [value]
|
||||
logger.debug(`[getTagFilters] Processing ${key}="${value}" -> values:`, values)
|
||||
|
||||
const getColumnForKey = (key: string) => {
|
||||
switch (key) {
|
||||
case 'tag1':
|
||||
return embedding.tag1
|
||||
case 'tag2':
|
||||
return embedding.tag2
|
||||
case 'tag3':
|
||||
return embedding.tag3
|
||||
case 'tag4':
|
||||
return embedding.tag4
|
||||
case 'tag5':
|
||||
return embedding.tag5
|
||||
case 'tag6':
|
||||
return embedding.tag6
|
||||
case 'tag7':
|
||||
return embedding.tag7
|
||||
default:
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const column = getColumnForKey(key)
|
||||
if (!column) return sql`1=1` // No-op for unknown keys
|
||||
|
||||
if (values.length === 1) {
|
||||
// Single value - simple equality
|
||||
logger.debug(`[getTagFilters] Single value filter: ${key} = ${values[0]}`)
|
||||
return sql`LOWER(${column}) = LOWER(${values[0]})`
|
||||
}
|
||||
// Multiple values - OR logic
|
||||
logger.debug(`[getTagFilters] OR filter: ${key} IN (${values.join(', ')})`)
|
||||
const orConditions = values.map((v) => sql`LOWER(${column}) = LOWER(${v})`)
|
||||
return sql`(${sql.join(orConditions, sql` OR `)})`
|
||||
const VectorSearchSchema = z
|
||||
.object({
|
||||
knowledgeBaseIds: z.union([
|
||||
z.string().min(1, 'Knowledge base ID is required'),
|
||||
z.array(z.string().min(1)).min(1, 'At least one knowledge base ID is required'),
|
||||
]),
|
||||
query: z
|
||||
.string()
|
||||
.optional()
|
||||
.nullable()
|
||||
.transform((val) => val || undefined),
|
||||
topK: z
|
||||
.number()
|
||||
.min(1)
|
||||
.max(100)
|
||||
.optional()
|
||||
.nullable()
|
||||
.default(10)
|
||||
.transform((val) => val ?? 10),
|
||||
filters: z
|
||||
.record(z.string())
|
||||
.optional()
|
||||
.nullable()
|
||||
.transform((val) => val || undefined), // Allow dynamic filter keys (display names)
|
||||
})
|
||||
}
|
||||
|
||||
class APIError extends Error {
|
||||
public status: number
|
||||
|
||||
constructor(message: string, status: number) {
|
||||
super(message)
|
||||
this.name = 'APIError'
|
||||
this.status = status
|
||||
}
|
||||
}
|
||||
|
||||
const VectorSearchSchema = z.object({
|
||||
knowledgeBaseIds: z.union([
|
||||
z.string().min(1, 'Knowledge base ID is required'),
|
||||
z.array(z.string().min(1)).min(1, 'At least one knowledge base ID is required'),
|
||||
]),
|
||||
query: z.string().min(1, 'Search query is required'),
|
||||
topK: z.number().min(1).max(100).default(10),
|
||||
filters: z.record(z.string()).optional(), // Allow dynamic filter keys (display names)
|
||||
})
|
||||
|
||||
async function generateSearchEmbedding(query: string): Promise<number[]> {
|
||||
const openaiApiKey = env.OPENAI_API_KEY
|
||||
if (!openaiApiKey) {
|
||||
throw new Error('OPENAI_API_KEY not configured')
|
||||
}
|
||||
|
||||
try {
|
||||
const embedding = await retryWithExponentialBackoff(
|
||||
async () => {
|
||||
const response = await fetch('https://api.openai.com/v1/embeddings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${openaiApiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: query,
|
||||
model: 'text-embedding-3-small',
|
||||
encoding_format: 'float',
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
const error = new APIError(
|
||||
`OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`,
|
||||
response.status
|
||||
)
|
||||
throw error
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
|
||||
if (!data.data || !Array.isArray(data.data) || data.data.length === 0) {
|
||||
throw new Error('Invalid response format from OpenAI embeddings API')
|
||||
}
|
||||
|
||||
return data.data[0].embedding
|
||||
},
|
||||
{
|
||||
maxRetries: 5,
|
||||
initialDelayMs: 1000,
|
||||
maxDelayMs: 30000,
|
||||
backoffMultiplier: 2,
|
||||
}
|
||||
)
|
||||
|
||||
return embedding
|
||||
} catch (error) {
|
||||
logger.error('Failed to generate search embedding:', error)
|
||||
throw new Error(
|
||||
`Embedding generation failed: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
function getQueryStrategy(kbCount: number, topK: number) {
|
||||
const useParallel = kbCount > 4 || (kbCount > 2 && topK > 50)
|
||||
const distanceThreshold = kbCount > 3 ? 0.8 : 1.0
|
||||
const parallelLimit = Math.ceil(topK / kbCount) + 5
|
||||
|
||||
return {
|
||||
useParallel,
|
||||
distanceThreshold,
|
||||
parallelLimit,
|
||||
singleQueryOptimized: kbCount <= 2,
|
||||
}
|
||||
}
|
||||
|
||||
async function executeParallelQueries(
|
||||
knowledgeBaseIds: string[],
|
||||
queryVector: string,
|
||||
topK: number,
|
||||
distanceThreshold: number,
|
||||
filters?: Record<string, string>
|
||||
) {
|
||||
const parallelLimit = Math.ceil(topK / knowledgeBaseIds.length) + 5
|
||||
|
||||
const queryPromises = knowledgeBaseIds.map(async (kbId) => {
|
||||
const results = await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
content: embedding.content,
|
||||
documentId: embedding.documentId,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
distance: sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'),
|
||||
knowledgeBaseId: embedding.knowledgeBaseId,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
eq(embedding.knowledgeBaseId, kbId),
|
||||
eq(embedding.enabled, true),
|
||||
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`,
|
||||
...(filters ? getTagFilters(filters, embedding) : [])
|
||||
)
|
||||
)
|
||||
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
|
||||
.limit(parallelLimit)
|
||||
|
||||
return results
|
||||
})
|
||||
|
||||
const parallelResults = await Promise.all(queryPromises)
|
||||
return parallelResults.flat()
|
||||
}
|
||||
|
||||
async function executeSingleQuery(
|
||||
knowledgeBaseIds: string[],
|
||||
queryVector: string,
|
||||
topK: number,
|
||||
distanceThreshold: number,
|
||||
filters?: Record<string, string>
|
||||
) {
|
||||
logger.debug(`[executeSingleQuery] Called with filters:`, filters)
|
||||
return await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
content: embedding.content,
|
||||
documentId: embedding.documentId,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
distance: sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'),
|
||||
knowledgeBaseId: embedding.knowledgeBaseId,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
|
||||
eq(embedding.enabled, true),
|
||||
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`,
|
||||
...(filters ? getTagFilters(filters, embedding) : [])
|
||||
)
|
||||
)
|
||||
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
|
||||
.limit(topK)
|
||||
}
|
||||
|
||||
function mergeAndRankResults(results: any[], topK: number) {
|
||||
return results.sort((a, b) => a.distance - b.distance).slice(0, topK)
|
||||
}
|
||||
.refine(
|
||||
(data) => {
|
||||
// Ensure at least query or filters are provided
|
||||
const hasQuery = data.query && data.query.trim().length > 0
|
||||
const hasFilters = data.filters && Object.keys(data.filters).length > 0
|
||||
return hasQuery || hasFilters
|
||||
},
|
||||
{
|
||||
message: 'Please provide either a search query or tag filters to search your knowledge base',
|
||||
}
|
||||
)
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
@@ -317,8 +143,9 @@ export async function POST(request: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
// Generate query embedding in parallel with access checks
|
||||
const queryEmbedding = await generateSearchEmbedding(validatedData.query)
|
||||
// Generate query embedding only if query is provided
|
||||
const hasQuery = validatedData.query && validatedData.query.trim().length > 0
|
||||
const queryEmbedding = hasQuery ? await generateSearchEmbedding(validatedData.query!) : null
|
||||
|
||||
// Check if any requested knowledge bases were not accessible
|
||||
const inaccessibleKbIds = knowledgeBaseIds.filter((id) => !accessibleKbIds.includes(id))
|
||||
@@ -330,46 +157,67 @@ export async function POST(request: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
// Adaptive query strategy based on accessible KB count and parameters
|
||||
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
|
||||
const queryVector = JSON.stringify(queryEmbedding)
|
||||
let results: SearchResult[]
|
||||
|
||||
let results: any[]
|
||||
const hasFilters = mappedFilters && Object.keys(mappedFilters).length > 0
|
||||
|
||||
if (strategy.useParallel) {
|
||||
// Execute parallel queries for better performance with many KBs
|
||||
logger.debug(`[${requestId}] Executing parallel queries with filters:`, mappedFilters)
|
||||
const parallelResults = await executeParallelQueries(
|
||||
accessibleKbIds,
|
||||
if (!hasQuery && hasFilters) {
|
||||
// Tag-only search without vector similarity
|
||||
logger.debug(`[${requestId}] Executing tag-only search with filters:`, mappedFilters)
|
||||
results = await handleTagOnlySearch({
|
||||
knowledgeBaseIds: accessibleKbIds,
|
||||
topK: validatedData.topK,
|
||||
filters: mappedFilters,
|
||||
})
|
||||
} else if (hasQuery && hasFilters) {
|
||||
// Tag + Vector search
|
||||
logger.debug(`[${requestId}] Executing tag + vector search with filters:`, mappedFilters)
|
||||
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
|
||||
const queryVector = JSON.stringify(queryEmbedding)
|
||||
|
||||
results = await handleTagAndVectorSearch({
|
||||
knowledgeBaseIds: accessibleKbIds,
|
||||
topK: validatedData.topK,
|
||||
filters: mappedFilters,
|
||||
queryVector,
|
||||
validatedData.topK,
|
||||
strategy.distanceThreshold,
|
||||
mappedFilters
|
||||
)
|
||||
results = mergeAndRankResults(parallelResults, validatedData.topK)
|
||||
distanceThreshold: strategy.distanceThreshold,
|
||||
})
|
||||
} else if (hasQuery && !hasFilters) {
|
||||
// Vector-only search
|
||||
logger.debug(`[${requestId}] Executing vector-only search`)
|
||||
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
|
||||
const queryVector = JSON.stringify(queryEmbedding)
|
||||
|
||||
results = await handleVectorOnlySearch({
|
||||
knowledgeBaseIds: accessibleKbIds,
|
||||
topK: validatedData.topK,
|
||||
queryVector,
|
||||
distanceThreshold: strategy.distanceThreshold,
|
||||
})
|
||||
} else {
|
||||
// Execute single optimized query for fewer KBs
|
||||
logger.debug(`[${requestId}] Executing single query with filters:`, mappedFilters)
|
||||
results = await executeSingleQuery(
|
||||
accessibleKbIds,
|
||||
queryVector,
|
||||
validatedData.topK,
|
||||
strategy.distanceThreshold,
|
||||
mappedFilters
|
||||
// This should never happen due to schema validation, but just in case
|
||||
return NextResponse.json(
|
||||
{
|
||||
error:
|
||||
'Please provide either a search query or tag filters to search your knowledge base',
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
// Calculate cost for the embedding (with fallback if calculation fails)
|
||||
let cost = null
|
||||
let tokenCount = null
|
||||
try {
|
||||
tokenCount = estimateTokenCount(validatedData.query, 'openai')
|
||||
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to calculate cost for search query`, {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
// Continue without cost information rather than failing the search
|
||||
if (hasQuery) {
|
||||
try {
|
||||
tokenCount = estimateTokenCount(validatedData.query!, 'openai')
|
||||
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to calculate cost for search query`, {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
// Continue without cost information rather than failing the search
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch tag definitions for display name mapping (reuse the same fetch from filtering)
|
||||
@@ -412,12 +260,13 @@ export async function POST(request: NextRequest) {
|
||||
const tags: Record<string, any> = {}
|
||||
|
||||
TAG_SLOTS.forEach((slot) => {
|
||||
if (result[slot]) {
|
||||
const tagValue = (result as any)[slot]
|
||||
if (tagValue) {
|
||||
const displayName = kbTagMap[slot] || slot
|
||||
logger.debug(
|
||||
`[${requestId}] Mapping ${slot}="${result[slot]}" -> "${displayName}"="${result[slot]}"`
|
||||
`[${requestId}] Mapping ${slot}="${tagValue}" -> "${displayName}"="${tagValue}"`
|
||||
)
|
||||
tags[displayName] = result[slot]
|
||||
tags[displayName] = tagValue
|
||||
}
|
||||
})
|
||||
|
||||
@@ -427,10 +276,10 @@ export async function POST(request: NextRequest) {
|
||||
documentId: result.documentId,
|
||||
chunkIndex: result.chunkIndex,
|
||||
tags, // Clean display name mapped tags
|
||||
similarity: 1 - result.distance,
|
||||
similarity: hasQuery ? 1 - result.distance : 1, // Perfect similarity for tag-only searches
|
||||
}
|
||||
}),
|
||||
query: validatedData.query,
|
||||
query: validatedData.query || '',
|
||||
knowledgeBaseIds: accessibleKbIds,
|
||||
knowledgeBaseId: accessibleKbIds[0],
|
||||
topK: validatedData.topK,
|
||||
|
||||
143
apps/sim/app/api/knowledge/search/utils.test.ts
Normal file
143
apps/sim/app/api/knowledge/search/utils.test.ts
Normal file
@@ -0,0 +1,143 @@
|
||||
/**
|
||||
* Tests for knowledge search utility functions
|
||||
* Focuses on testing core functionality with simplified mocking
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('drizzle-orm')
|
||||
vi.mock('@/lib/logs/console/logger')
|
||||
vi.mock('@/db')
|
||||
|
||||
import { handleTagAndVectorSearch, handleTagOnlySearch, handleVectorOnlySearch } from './utils'
|
||||
|
||||
describe('Knowledge Search Utils', () => {
|
||||
describe('handleTagOnlySearch', () => {
|
||||
it('should throw error when no filters provided', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
filters: {},
|
||||
}
|
||||
|
||||
await expect(handleTagOnlySearch(params)).rejects.toThrow(
|
||||
'Tag filters are required for tag-only search'
|
||||
)
|
||||
})
|
||||
|
||||
it('should accept valid parameters for tag-only search', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
filters: { tag1: 'api' },
|
||||
}
|
||||
|
||||
// This test validates the function accepts the right parameters
|
||||
// The actual database interaction is tested via route tests
|
||||
expect(params.knowledgeBaseIds).toEqual(['kb-123'])
|
||||
expect(params.topK).toBe(10)
|
||||
expect(params.filters).toEqual({ tag1: 'api' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleVectorOnlySearch', () => {
|
||||
it('should throw error when queryVector not provided', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
distanceThreshold: 0.8,
|
||||
}
|
||||
|
||||
await expect(handleVectorOnlySearch(params)).rejects.toThrow(
|
||||
'Query vector and distance threshold are required for vector-only search'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error when distanceThreshold not provided', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
queryVector: JSON.stringify([0.1, 0.2, 0.3]),
|
||||
}
|
||||
|
||||
await expect(handleVectorOnlySearch(params)).rejects.toThrow(
|
||||
'Query vector and distance threshold are required for vector-only search'
|
||||
)
|
||||
})
|
||||
|
||||
it('should accept valid parameters for vector-only search', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
queryVector: JSON.stringify([0.1, 0.2, 0.3]),
|
||||
distanceThreshold: 0.8,
|
||||
}
|
||||
|
||||
// This test validates the function accepts the right parameters
|
||||
expect(params.knowledgeBaseIds).toEqual(['kb-123'])
|
||||
expect(params.topK).toBe(10)
|
||||
expect(params.queryVector).toBe(JSON.stringify([0.1, 0.2, 0.3]))
|
||||
expect(params.distanceThreshold).toBe(0.8)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleTagAndVectorSearch', () => {
|
||||
it('should throw error when no filters provided', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
filters: {},
|
||||
queryVector: JSON.stringify([0.1, 0.2, 0.3]),
|
||||
distanceThreshold: 0.8,
|
||||
}
|
||||
|
||||
await expect(handleTagAndVectorSearch(params)).rejects.toThrow(
|
||||
'Tag filters are required for tag and vector search'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error when queryVector not provided', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
filters: { tag1: 'api' },
|
||||
distanceThreshold: 0.8,
|
||||
}
|
||||
|
||||
await expect(handleTagAndVectorSearch(params)).rejects.toThrow(
|
||||
'Query vector and distance threshold are required for tag and vector search'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error when distanceThreshold not provided', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
filters: { tag1: 'api' },
|
||||
queryVector: JSON.stringify([0.1, 0.2, 0.3]),
|
||||
}
|
||||
|
||||
await expect(handleTagAndVectorSearch(params)).rejects.toThrow(
|
||||
'Query vector and distance threshold are required for tag and vector search'
|
||||
)
|
||||
})
|
||||
|
||||
it('should accept valid parameters for tag and vector search', async () => {
|
||||
const params = {
|
||||
knowledgeBaseIds: ['kb-123'],
|
||||
topK: 10,
|
||||
filters: { tag1: 'api' },
|
||||
queryVector: JSON.stringify([0.1, 0.2, 0.3]),
|
||||
distanceThreshold: 0.8,
|
||||
}
|
||||
|
||||
// This test validates the function accepts the right parameters
|
||||
expect(params.knowledgeBaseIds).toEqual(['kb-123'])
|
||||
expect(params.topK).toBe(10)
|
||||
expect(params.filters).toEqual({ tag1: 'api' })
|
||||
expect(params.queryVector).toBe(JSON.stringify([0.1, 0.2, 0.3]))
|
||||
expect(params.distanceThreshold).toBe(0.8)
|
||||
})
|
||||
})
|
||||
})
|
||||
402
apps/sim/app/api/knowledge/search/utils.ts
Normal file
402
apps/sim/app/api/knowledge/search/utils.ts
Normal file
@@ -0,0 +1,402 @@
|
||||
import { and, eq, inArray, sql } from 'drizzle-orm'
|
||||
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { embedding } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeSearchUtils')
|
||||
|
||||
export class APIError extends Error {
|
||||
public status: number
|
||||
|
||||
constructor(message: string, status: number) {
|
||||
super(message)
|
||||
this.name = 'APIError'
|
||||
this.status = status
|
||||
}
|
||||
}
|
||||
|
||||
export interface SearchResult {
|
||||
id: string
|
||||
content: string
|
||||
documentId: string
|
||||
chunkIndex: number
|
||||
tag1: string | null
|
||||
tag2: string | null
|
||||
tag3: string | null
|
||||
tag4: string | null
|
||||
tag5: string | null
|
||||
tag6: string | null
|
||||
tag7: string | null
|
||||
distance: number
|
||||
knowledgeBaseId: string
|
||||
}
|
||||
|
||||
export interface SearchParams {
|
||||
knowledgeBaseIds: string[]
|
||||
topK: number
|
||||
filters?: Record<string, string>
|
||||
queryVector?: string
|
||||
distanceThreshold?: number
|
||||
}
|
||||
|
||||
export async function generateSearchEmbedding(query: string): Promise<number[]> {
|
||||
const openaiApiKey = env.OPENAI_API_KEY
|
||||
if (!openaiApiKey) {
|
||||
throw new Error('OPENAI_API_KEY not configured')
|
||||
}
|
||||
|
||||
try {
|
||||
const embedding = await retryWithExponentialBackoff(
|
||||
async () => {
|
||||
const response = await fetch('https://api.openai.com/v1/embeddings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${openaiApiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: query,
|
||||
model: 'text-embedding-3-small',
|
||||
encoding_format: 'float',
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
const error = new APIError(
|
||||
`OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`,
|
||||
response.status
|
||||
)
|
||||
throw error
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
|
||||
if (!data.data || !Array.isArray(data.data) || data.data.length === 0) {
|
||||
throw new Error('Invalid response format from OpenAI embeddings API')
|
||||
}
|
||||
|
||||
return data.data[0].embedding
|
||||
},
|
||||
{
|
||||
maxRetries: 5,
|
||||
initialDelayMs: 1000,
|
||||
maxDelayMs: 30000,
|
||||
backoffMultiplier: 2,
|
||||
}
|
||||
)
|
||||
|
||||
return embedding
|
||||
} catch (error) {
|
||||
logger.error('Failed to generate search embedding:', error)
|
||||
throw new Error(
|
||||
`Embedding generation failed: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
function getTagFilters(filters: Record<string, string>, embedding: any) {
|
||||
return Object.entries(filters).map(([key, value]) => {
|
||||
// Handle OR logic within same tag
|
||||
const values = value.includes('|OR|') ? value.split('|OR|') : [value]
|
||||
logger.debug(`[getTagFilters] Processing ${key}="${value}" -> values:`, values)
|
||||
|
||||
const getColumnForKey = (key: string) => {
|
||||
switch (key) {
|
||||
case 'tag1':
|
||||
return embedding.tag1
|
||||
case 'tag2':
|
||||
return embedding.tag2
|
||||
case 'tag3':
|
||||
return embedding.tag3
|
||||
case 'tag4':
|
||||
return embedding.tag4
|
||||
case 'tag5':
|
||||
return embedding.tag5
|
||||
case 'tag6':
|
||||
return embedding.tag6
|
||||
case 'tag7':
|
||||
return embedding.tag7
|
||||
default:
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const column = getColumnForKey(key)
|
||||
if (!column) return sql`1=1` // No-op for unknown keys
|
||||
|
||||
if (values.length === 1) {
|
||||
// Single value - simple equality
|
||||
logger.debug(`[getTagFilters] Single value filter: ${key} = ${values[0]}`)
|
||||
return sql`LOWER(${column}) = LOWER(${values[0]})`
|
||||
}
|
||||
// Multiple values - OR logic
|
||||
logger.debug(`[getTagFilters] OR filter: ${key} IN (${values.join(', ')})`)
|
||||
const orConditions = values.map((v) => sql`LOWER(${column}) = LOWER(${v})`)
|
||||
return sql`(${sql.join(orConditions, sql` OR `)})`
|
||||
})
|
||||
}
|
||||
|
||||
export function getQueryStrategy(kbCount: number, topK: number) {
|
||||
const useParallel = kbCount > 4 || (kbCount > 2 && topK > 50)
|
||||
const distanceThreshold = kbCount > 3 ? 0.8 : 1.0
|
||||
const parallelLimit = Math.ceil(topK / kbCount) + 5
|
||||
|
||||
return {
|
||||
useParallel,
|
||||
distanceThreshold,
|
||||
parallelLimit,
|
||||
singleQueryOptimized: kbCount <= 2,
|
||||
}
|
||||
}
|
||||
|
||||
async function executeTagFilterQuery(
|
||||
knowledgeBaseIds: string[],
|
||||
filters: Record<string, string>
|
||||
): Promise<{ id: string }[]> {
|
||||
if (knowledgeBaseIds.length === 1) {
|
||||
return await db
|
||||
.select({ id: embedding.id })
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
eq(embedding.knowledgeBaseId, knowledgeBaseIds[0]),
|
||||
eq(embedding.enabled, true),
|
||||
...getTagFilters(filters, embedding)
|
||||
)
|
||||
)
|
||||
}
|
||||
return await db
|
||||
.select({ id: embedding.id })
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
|
||||
eq(embedding.enabled, true),
|
||||
...getTagFilters(filters, embedding)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
async function executeVectorSearchOnIds(
|
||||
embeddingIds: string[],
|
||||
queryVector: string,
|
||||
topK: number,
|
||||
distanceThreshold: number
|
||||
): Promise<SearchResult[]> {
|
||||
if (embeddingIds.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
return await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
content: embedding.content,
|
||||
documentId: embedding.documentId,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
distance: sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'),
|
||||
knowledgeBaseId: embedding.knowledgeBaseId,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
inArray(embedding.id, embeddingIds),
|
||||
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
|
||||
)
|
||||
)
|
||||
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
|
||||
.limit(topK)
|
||||
}
|
||||
|
||||
export async function handleTagOnlySearch(params: SearchParams): Promise<SearchResult[]> {
|
||||
const { knowledgeBaseIds, topK, filters } = params
|
||||
|
||||
if (!filters || Object.keys(filters).length === 0) {
|
||||
throw new Error('Tag filters are required for tag-only search')
|
||||
}
|
||||
|
||||
logger.debug(`[handleTagOnlySearch] Executing tag-only search with filters:`, filters)
|
||||
|
||||
const strategy = getQueryStrategy(knowledgeBaseIds.length, topK)
|
||||
|
||||
if (strategy.useParallel) {
|
||||
// Parallel approach for many KBs
|
||||
const parallelLimit = Math.ceil(topK / knowledgeBaseIds.length) + 5
|
||||
|
||||
const queryPromises = knowledgeBaseIds.map(async (kbId) => {
|
||||
return await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
content: embedding.content,
|
||||
documentId: embedding.documentId,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
distance: sql<number>`0`.as('distance'), // No distance for tag-only searches
|
||||
knowledgeBaseId: embedding.knowledgeBaseId,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
eq(embedding.knowledgeBaseId, kbId),
|
||||
eq(embedding.enabled, true),
|
||||
...getTagFilters(filters, embedding)
|
||||
)
|
||||
)
|
||||
.limit(parallelLimit)
|
||||
})
|
||||
|
||||
const parallelResults = await Promise.all(queryPromises)
|
||||
return parallelResults.flat().slice(0, topK)
|
||||
}
|
||||
// Single query for fewer KBs
|
||||
return await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
content: embedding.content,
|
||||
documentId: embedding.documentId,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
distance: sql<number>`0`.as('distance'), // No distance for tag-only searches
|
||||
knowledgeBaseId: embedding.knowledgeBaseId,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
|
||||
eq(embedding.enabled, true),
|
||||
...getTagFilters(filters, embedding)
|
||||
)
|
||||
)
|
||||
.limit(topK)
|
||||
}
|
||||
|
||||
export async function handleVectorOnlySearch(params: SearchParams): Promise<SearchResult[]> {
|
||||
const { knowledgeBaseIds, topK, queryVector, distanceThreshold } = params
|
||||
|
||||
if (!queryVector || !distanceThreshold) {
|
||||
throw new Error('Query vector and distance threshold are required for vector-only search')
|
||||
}
|
||||
|
||||
logger.debug(`[handleVectorOnlySearch] Executing vector-only search`)
|
||||
|
||||
const strategy = getQueryStrategy(knowledgeBaseIds.length, topK)
|
||||
|
||||
if (strategy.useParallel) {
|
||||
// Parallel approach for many KBs
|
||||
const parallelLimit = Math.ceil(topK / knowledgeBaseIds.length) + 5
|
||||
|
||||
const queryPromises = knowledgeBaseIds.map(async (kbId) => {
|
||||
return await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
content: embedding.content,
|
||||
documentId: embedding.documentId,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
distance: sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'),
|
||||
knowledgeBaseId: embedding.knowledgeBaseId,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
eq(embedding.knowledgeBaseId, kbId),
|
||||
eq(embedding.enabled, true),
|
||||
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
|
||||
)
|
||||
)
|
||||
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
|
||||
.limit(parallelLimit)
|
||||
})
|
||||
|
||||
const parallelResults = await Promise.all(queryPromises)
|
||||
const allResults = parallelResults.flat()
|
||||
return allResults.sort((a, b) => a.distance - b.distance).slice(0, topK)
|
||||
}
|
||||
// Single query for fewer KBs
|
||||
return await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
content: embedding.content,
|
||||
documentId: embedding.documentId,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
distance: sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'),
|
||||
knowledgeBaseId: embedding.knowledgeBaseId,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(
|
||||
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
|
||||
eq(embedding.enabled, true),
|
||||
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
|
||||
)
|
||||
)
|
||||
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
|
||||
.limit(topK)
|
||||
}
|
||||
|
||||
export async function handleTagAndVectorSearch(params: SearchParams): Promise<SearchResult[]> {
|
||||
const { knowledgeBaseIds, topK, filters, queryVector, distanceThreshold } = params
|
||||
|
||||
if (!filters || Object.keys(filters).length === 0) {
|
||||
throw new Error('Tag filters are required for tag and vector search')
|
||||
}
|
||||
if (!queryVector || !distanceThreshold) {
|
||||
throw new Error('Query vector and distance threshold are required for tag and vector search')
|
||||
}
|
||||
|
||||
logger.debug(`[handleTagAndVectorSearch] Executing tag + vector search with filters:`, filters)
|
||||
|
||||
// Step 1: Filter by tags first
|
||||
const tagFilteredIds = await executeTagFilterQuery(knowledgeBaseIds, filters)
|
||||
|
||||
if (tagFilteredIds.length === 0) {
|
||||
logger.debug(`[handleTagAndVectorSearch] No results found after tag filtering`)
|
||||
return []
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`[handleTagAndVectorSearch] Found ${tagFilteredIds.length} results after tag filtering`
|
||||
)
|
||||
|
||||
// Step 2: Perform vector search only on tag-filtered results
|
||||
return await executeVectorSearchOnIds(
|
||||
tagFilteredIds.map((r) => r.id),
|
||||
queryVector,
|
||||
topK,
|
||||
distanceThreshold
|
||||
)
|
||||
}
|
||||
@@ -39,8 +39,8 @@ export const KnowledgeBlock: BlockConfig = {
|
||||
title: 'Search Query',
|
||||
type: 'short-input',
|
||||
layout: 'full',
|
||||
placeholder: 'Enter your search query',
|
||||
required: true,
|
||||
placeholder: 'Enter your search query (optional when using tag filters)',
|
||||
required: false,
|
||||
condition: { field: 'operation', value: 'search' },
|
||||
},
|
||||
{
|
||||
|
||||
@@ -419,11 +419,17 @@ async function handleInternalRequest(
|
||||
}
|
||||
|
||||
// Extract error message from nested error objects (common in API responses)
|
||||
// Prioritize detailed validation messages over generic error field
|
||||
const errorMessage =
|
||||
typeof errorData.error === 'object'
|
||||
errorData.details?.[0]?.message ||
|
||||
(typeof errorData.error === 'object'
|
||||
? errorData.error.message || JSON.stringify(errorData.error)
|
||||
: errorData.error || `Request failed with status ${response.status}`
|
||||
: errorData.error) ||
|
||||
`Request failed with status ${response.status}`
|
||||
|
||||
logger.error(`[${requestId}] Internal request error for ${toolId}:`, {
|
||||
error: errorMessage,
|
||||
})
|
||||
throw new Error(errorMessage)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ export const knowledgeSearchTool: ToolConfig<any, KnowledgeSearchResponse> = {
|
||||
},
|
||||
query: {
|
||||
type: 'string',
|
||||
required: true,
|
||||
description: 'Search query text',
|
||||
required: false,
|
||||
description: 'Search query text (optional when using tag filters)',
|
||||
},
|
||||
topK: {
|
||||
type: 'number',
|
||||
@@ -58,7 +58,7 @@ export const knowledgeSearchTool: ToolConfig<any, KnowledgeSearchResponse> = {
|
||||
// Group filters by tag name for OR logic within same tag
|
||||
const groupedFilters: Record<string, string[]> = {}
|
||||
tagFilters.forEach((filter: any) => {
|
||||
if (filter.tagName && filter.tagValue) {
|
||||
if (filter.tagName && filter.tagValue && filter.tagValue.trim().length > 0) {
|
||||
if (!groupedFilters[filter.tagName]) {
|
||||
groupedFilters[filter.tagName] = []
|
||||
}
|
||||
@@ -92,8 +92,7 @@ export const knowledgeSearchTool: ToolConfig<any, KnowledgeSearchResponse> = {
|
||||
const result = await response.json()
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMessage =
|
||||
result.error?.message || result.message || 'Failed to perform vector search'
|
||||
const errorMessage = result.error || result.message || 'Failed to perform search'
|
||||
throw new Error(errorMessage)
|
||||
}
|
||||
|
||||
@@ -117,12 +116,13 @@ export const knowledgeSearchTool: ToolConfig<any, KnowledgeSearchResponse> = {
|
||||
totalResults: 0,
|
||||
cost: undefined,
|
||||
},
|
||||
error: `Vector search failed: ${error.message || 'Unknown error'}`,
|
||||
error: error.message || 'Failed to perform vector search',
|
||||
}
|
||||
}
|
||||
},
|
||||
transformError: async (error): Promise<KnowledgeSearchResponse> => {
|
||||
const errorMessage = `Vector search failed: ${error.message || 'Unknown error'}`
|
||||
const errorMessage = error.message || 'Failed to perform search'
|
||||
|
||||
return {
|
||||
success: false,
|
||||
output: {
|
||||
|
||||
Reference in New Issue
Block a user