fix(kb): exclude deleted docs from embeddings/vector search (#1319)

* update infra and remove railway

* fix(kb): exclude deleted docs from queries

* Revert "update infra and remove railway"

This reverts commit b23258a5a1.
This commit is contained in:
Waleed
2025-09-11 12:09:03 -07:00
committed by GitHub
parent 2dc75b1ac1
commit 6cf02b9b5a
3 changed files with 232 additions and 2 deletions

View File

@@ -1006,4 +1006,210 @@ describe('Knowledge Search API Route', () => {
expect(mockGenerateSearchEmbedding).not.toHaveBeenCalled() // No embedding for tag-only
})
})
describe('Deleted document filtering', () => {
  // NOTE(review): the search handlers are stubbed through their mock* handles in
  // every case below, so these tests pin what the route returns for rows that
  // belong to active documents; they do not appear to exercise the deletedAt
  // filter itself — confirm that filter is covered at the query level.

  /** Fields that differ between the stubbed embedding rows used in this suite. */
  type ChunkSeed = {
    id: string
    content: string
    documentId: string
    tag1: string | null
    distance: number
  }

  /** Builds one stubbed embedding row for knowledge base 'kb-123' with tags 2-7 unset. */
  const buildChunk = ({ id, content, documentId, tag1, distance }: ChunkSeed) => ({
    id,
    content,
    documentId,
    chunkIndex: 0,
    tag1,
    tag2: null,
    tag3: null,
    tag4: null,
    tag5: null,
    tag6: null,
    tag7: null,
    distance,
    knowledgeBaseId: 'kb-123',
  })

  /** Applies the stubs that are identical across all three cases. */
  const primeSharedMocks = () => {
    mockGetUserId.mockResolvedValue('user-123')

    mockCheckKnowledgeBaseAccess.mockResolvedValue({
      hasAccess: true,
      knowledgeBase: {
        id: 'kb-123',
        userId: 'user-123',
        name: 'Test KB',
        deletedAt: null,
      },
    })

    mockGetQueryStrategy.mockReturnValue({
      useParallel: false,
      distanceThreshold: 1.0,
      parallelLimit: 15,
      singleQueryOptimized: true,
    })

    // One-shot select chain that resolves to an empty list
    // (presumably the tag-definition lookup — verify against the route).
    const emptySelectChain = {
      select: vi.fn().mockReturnThis(),
      from: vi.fn().mockReturnThis(),
      where: vi.fn().mockResolvedValue([]),
    }
    mockDbChain.select.mockReturnValueOnce(emptySelectChain)
  }

  /** POSTs `body` to the search route and returns the response with its parsed JSON. */
  const postSearch = async (body: Parameters<typeof createMockRequest>[1]) => {
    const request = createMockRequest('POST', body)
    const { POST } = await import('@/app/api/knowledge/search/route')
    const response = await POST(request)
    const data = await response.json()
    return { response, data }
  }

  it('should exclude results from deleted documents in vector search', async () => {
    primeSharedMocks()
    mockHandleVectorOnlySearch.mockResolvedValue([
      buildChunk({
        id: 'chunk-1',
        content: 'Content from active document',
        documentId: 'doc-active',
        tag1: null,
        distance: 0.2,
      }),
    ])
    mockGenerateSearchEmbedding.mockResolvedValue([0.1, 0.2, 0.3])
    mockGetDocumentNamesByIds.mockResolvedValue({
      'doc-active': 'Active Document.pdf',
    })

    const { response, data } = await postSearch({
      knowledgeBaseIds: ['kb-123'],
      query: 'test query',
      topK: 10,
    })

    expect(response.status).toBe(200)
    expect(data.success).toBe(true)
    expect(data.data.results).toHaveLength(1)

    const hit = data.data.results[0]
    expect(hit.documentId).toBe('doc-active')
    expect(hit.documentName).toBe('Active Document.pdf')
  })

  it('should exclude results from deleted documents in tag search', async () => {
    primeSharedMocks()
    mockHandleTagOnlySearch.mockResolvedValue([
      buildChunk({
        id: 'chunk-2',
        content: 'Content from active document with tag',
        documentId: 'doc-active-tagged',
        tag1: 'api',
        distance: 0,
      }),
    ])
    mockGetDocumentNamesByIds.mockResolvedValue({
      'doc-active-tagged': 'Active Tagged Document.pdf',
    })

    // No query text is sent: this request carries tag filters only.
    const { response, data } = await postSearch({
      knowledgeBaseIds: ['kb-123'],
      filters: { tag1: 'api' },
      topK: 10,
    })

    expect(response.status).toBe(200)
    expect(data.success).toBe(true)
    expect(data.data.results).toHaveLength(1)

    const hit = data.data.results[0]
    expect(hit.documentId).toBe('doc-active-tagged')
    expect(hit.documentName).toBe('Active Tagged Document.pdf')
    expect(hit.metadata).toEqual({ tag1: 'api' })
  })

  it('should exclude results from deleted documents in combined tag+vector search', async () => {
    primeSharedMocks()
    mockHandleTagAndVectorSearch.mockResolvedValue([
      buildChunk({
        id: 'chunk-3',
        content: 'Relevant content from active document',
        documentId: 'doc-active-combined',
        tag1: 'guide',
        distance: 0.15,
      }),
    ])
    mockGenerateSearchEmbedding.mockResolvedValue([0.1, 0.2, 0.3])
    mockGetDocumentNamesByIds.mockResolvedValue({
      'doc-active-combined': 'Active Combined Search.pdf',
    })

    const { response, data } = await postSearch({
      knowledgeBaseIds: ['kb-123'],
      query: 'relevant content',
      filters: { tag1: 'guide' },
      topK: 10,
    })

    expect(response.status).toBe(200)
    expect(data.success).toBe(true)
    expect(data.data.results).toHaveLength(1)

    const hit = data.data.results[0]
    expect(hit.documentId).toBe('doc-active-combined')
    expect(hit.documentName).toBe('Active Combined Search.pdf')
    expect(hit.metadata).toEqual({ tag1: 'guide' })
    expect(hit.similarity).toBe(0.85) // 1 - 0.15 distance
  })
})
})

View File

@@ -422,4 +422,14 @@ describe('Knowledge Search Utils', () => {
Object.keys(env).forEach((key) => delete (env as any)[key])
})
})
describe('getDocumentNamesByIds', () => {
  // Passing an empty id list is expected to resolve to an empty object.
  it('should handle empty input gracefully', async () => {
    const utils = await import('./utils')
    const nameMap = await utils.getDocumentNamesByIds([])

    expect(nameMap).toEqual({})
  })
})
})

View File

@@ -1,4 +1,4 @@
import { and, eq, inArray, sql } from 'drizzle-orm'
import { and, eq, inArray, isNull, sql } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { document, embedding } from '@/db/schema'
@@ -19,7 +19,7 @@ export async function getDocumentNamesByIds(
filename: document.filename,
})
.from(document)
.where(inArray(document.id, uniqueIds))
.where(and(inArray(document.id, uniqueIds), isNull(document.deletedAt)))
const documentNameMap: Record<string, string> = {}
documents.forEach((doc) => {
@@ -119,10 +119,12 @@ async function executeTagFilterQuery(
return await db
.select({ id: embedding.id })
.from(embedding)
.innerJoin(document, eq(embedding.documentId, document.id))
.where(
and(
eq(embedding.knowledgeBaseId, knowledgeBaseIds[0]),
eq(embedding.enabled, true),
isNull(document.deletedAt),
...getTagFilters(filters, embedding)
)
)
@@ -130,10 +132,12 @@ async function executeTagFilterQuery(
return await db
.select({ id: embedding.id })
.from(embedding)
.innerJoin(document, eq(embedding.documentId, document.id))
.where(
and(
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
eq(embedding.enabled, true),
isNull(document.deletedAt),
...getTagFilters(filters, embedding)
)
)
@@ -166,9 +170,11 @@ async function executeVectorSearchOnIds(
knowledgeBaseId: embedding.knowledgeBaseId,
})
.from(embedding)
.innerJoin(document, eq(embedding.documentId, document.id))
.where(
and(
inArray(embedding.id, embeddingIds),
isNull(document.deletedAt),
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
)
)
@@ -209,10 +215,12 @@ export async function handleTagOnlySearch(params: SearchParams): Promise<SearchR
knowledgeBaseId: embedding.knowledgeBaseId,
})
.from(embedding)
.innerJoin(document, eq(embedding.documentId, document.id))
.where(
and(
eq(embedding.knowledgeBaseId, kbId),
eq(embedding.enabled, true),
isNull(document.deletedAt),
...getTagFilters(filters, embedding)
)
)
@@ -240,10 +248,12 @@ export async function handleTagOnlySearch(params: SearchParams): Promise<SearchR
knowledgeBaseId: embedding.knowledgeBaseId,
})
.from(embedding)
.innerJoin(document, eq(embedding.documentId, document.id))
.where(
and(
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
eq(embedding.enabled, true),
isNull(document.deletedAt),
...getTagFilters(filters, embedding)
)
)
@@ -283,10 +293,12 @@ export async function handleVectorOnlySearch(params: SearchParams): Promise<Sear
knowledgeBaseId: embedding.knowledgeBaseId,
})
.from(embedding)
.innerJoin(document, eq(embedding.documentId, document.id))
.where(
and(
eq(embedding.knowledgeBaseId, kbId),
eq(embedding.enabled, true),
isNull(document.deletedAt),
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
)
)
@@ -316,10 +328,12 @@ export async function handleVectorOnlySearch(params: SearchParams): Promise<Sear
knowledgeBaseId: embedding.knowledgeBaseId,
})
.from(embedding)
.innerJoin(document, eq(embedding.documentId, document.id))
.where(
and(
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
eq(embedding.enabled, true),
isNull(document.deletedAt),
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
)
)