fix(kb-perms): search tool perms to use new system (#786)

* fix(kb): search tool perms

* fix tests
This commit is contained in:
Vikhyath Mondreti
2025-07-24 20:44:26 -07:00
committed by GitHub
parent af1c7dc39d
commit 2f57d8a884
3 changed files with 89 additions and 39 deletions

View File

@@ -51,6 +51,11 @@ vi.mock('@/providers/utils', () => ({
}),
}))
const mockCheckKnowledgeBaseAccess = vi.fn()
vi.mock('@/app/api/knowledge/utils', () => ({
checkKnowledgeBaseAccess: mockCheckKnowledgeBaseAccess,
}))
mockConsoleLogger()
describe('Knowledge Search API Route', () => {
@@ -132,7 +137,11 @@ describe('Knowledge Search API Route', () => {
it('should perform search successfully with single knowledge base', async () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
// Mock knowledge base access check to return success
mockCheckKnowledgeBaseAccess.mockResolvedValue({
hasAccess: true,
knowledgeBase: mockKnowledgeBases[0],
})
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
@@ -149,6 +158,10 @@ describe('Knowledge Search API Route', () => {
const response = await POST(req)
const data = await response.json()
if (response.status !== 200) {
console.log('Test failed with response:', data)
}
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(data.data.results).toHaveLength(2)
@@ -171,7 +184,10 @@ describe('Knowledge Search API Route', () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce(multiKbs)
// Mock knowledge base access check to return success for both KBs
mockCheckKnowledgeBaseAccess
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: multiKbs[0] })
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: multiKbs[1] })
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
@@ -201,9 +217,13 @@ describe('Knowledge Search API Route', () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // First call: get knowledge bases
// Mock knowledge base access check to return success
mockCheckKnowledgeBaseAccess.mockResolvedValue({
hasAccess: true,
knowledgeBase: mockKnowledgeBases[0],
})
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Second call: search results
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Search results
mockFetch.mockResolvedValue({
ok: true,
@@ -255,7 +275,11 @@ describe('Knowledge Search API Route', () => {
it('should return not found for non-existent knowledge base', async () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce([]) // No knowledge bases found
// Mock knowledge base access check to return no access
mockCheckKnowledgeBaseAccess.mockResolvedValue({
hasAccess: false,
notFound: true,
})
const req = createMockRequest('POST', validSearchData)
const { POST } = await import('./route')
@@ -274,7 +298,10 @@ describe('Knowledge Search API Route', () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // Only kb-123 found
// Mock access check: first KB has access, second doesn't
mockCheckKnowledgeBaseAccess
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: mockKnowledgeBases[0] })
.mockResolvedValueOnce({ hasAccess: false, notFound: true })
const req = createMockRequest('POST', multiKbData)
const { POST } = await import('./route')
@@ -282,7 +309,7 @@ describe('Knowledge Search API Route', () => {
const data = await response.json()
expect(response.status).toBe(404)
expect(data.error).toBe('Knowledge bases not found: kb-missing')
expect(data.error).toBe('Knowledge bases not found or access denied: kb-missing')
})
it.concurrent('should validate search parameters', async () => {
@@ -310,9 +337,13 @@ describe('Knowledge Search API Route', () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // First call: get knowledge bases
// Mock knowledge base access check to return success
mockCheckKnowledgeBaseAccess.mockResolvedValue({
hasAccess: true,
knowledgeBase: mockKnowledgeBases[0],
})
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Second call: search results
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Search results
mockFetch.mockResolvedValue({
ok: true,
@@ -416,7 +447,13 @@ describe('Knowledge Search API Route', () => {
describe('Cost tracking', () => {
it.concurrent('should include cost information in successful search response', async () => {
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
// Mock knowledge base access check to return success
mockCheckKnowledgeBaseAccess.mockResolvedValue({
hasAccess: true,
knowledgeBase: mockKnowledgeBases[0],
})
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
mockFetch.mockResolvedValue({
@@ -458,7 +495,13 @@ describe('Knowledge Search API Route', () => {
const { calculateCost } = await import('@/providers/utils')
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
// Mock knowledge base access check to return success
mockCheckKnowledgeBaseAccess.mockResolvedValue({
hasAccess: true,
knowledgeBase: mockKnowledgeBases[0],
})
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
mockFetch.mockResolvedValue({
@@ -509,7 +552,13 @@ describe('Knowledge Search API Route', () => {
}
mockGetUserId.mockResolvedValue('user-123')
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
// Mock knowledge base access check to return success
mockCheckKnowledgeBaseAccess.mockResolvedValue({
hasAccess: true,
knowledgeBase: mockKnowledgeBases[0],
})
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
mockFetch.mockResolvedValue({

View File

@@ -1,4 +1,4 @@
import { and, eq, inArray, isNull, sql } from 'drizzle-orm'
import { and, eq, inArray, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
@@ -6,8 +6,9 @@ import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console-logger'
import { estimateTokenCount } from '@/lib/tokenization/estimators'
import { getUserId } from '@/app/api/auth/oauth/utils'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { embedding, knowledgeBase } from '@/db/schema'
import { embedding } from '@/db/schema'
import { calculateCost } from '@/providers/utils'
const logger = createLogger('VectorSearchAPI')
@@ -261,39 +262,37 @@ export async function POST(request: NextRequest) {
? validatedData.knowledgeBaseIds
: [validatedData.knowledgeBaseIds]
const [kb, queryEmbedding] = await Promise.all([
db
.select()
.from(knowledgeBase)
.where(
and(
inArray(knowledgeBase.id, knowledgeBaseIds),
eq(knowledgeBase.userId, userId),
isNull(knowledgeBase.deletedAt)
)
),
generateSearchEmbedding(validatedData.query),
])
// Check access permissions for each knowledge base using proper workspace-based permissions
const accessibleKbIds: string[] = []
for (const kbId of knowledgeBaseIds) {
const accessCheck = await checkKnowledgeBaseAccess(kbId, userId)
if (accessCheck.hasAccess) {
accessibleKbIds.push(kbId)
}
}
if (kb.length === 0) {
if (accessibleKbIds.length === 0) {
return NextResponse.json(
{ error: 'Knowledge base not found or access denied' },
{ status: 404 }
)
}
const foundKbIds = kb.map((k) => k.id)
const missingKbIds = knowledgeBaseIds.filter((id) => !foundKbIds.includes(id))
// Generate query embedding in parallel with access checks
const queryEmbedding = await generateSearchEmbedding(validatedData.query)
if (missingKbIds.length > 0) {
// Check if any requested knowledge bases were not accessible
const inaccessibleKbIds = knowledgeBaseIds.filter((id) => !accessibleKbIds.includes(id))
if (inaccessibleKbIds.length > 0) {
return NextResponse.json(
{ error: `Knowledge bases not found: ${missingKbIds.join(', ')}` },
{ error: `Knowledge bases not found or access denied: ${inaccessibleKbIds.join(', ')}` },
{ status: 404 }
)
}
// Adaptive query strategy based on KB count and parameters
const strategy = getQueryStrategy(foundKbIds.length, validatedData.topK)
// Adaptive query strategy based on accessible KB count and parameters
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
const queryVector = JSON.stringify(queryEmbedding)
let results: any[]
@@ -301,7 +300,7 @@ export async function POST(request: NextRequest) {
if (strategy.useParallel) {
// Execute parallel queries for better performance with many KBs
const parallelResults = await executeParallelQueries(
foundKbIds,
accessibleKbIds,
queryVector,
validatedData.topK,
strategy.distanceThreshold,
@@ -311,7 +310,7 @@ export async function POST(request: NextRequest) {
} else {
// Execute single optimized query for fewer KBs
results = await executeSingleQuery(
foundKbIds,
accessibleKbIds,
queryVector,
validatedData.topK,
strategy.distanceThreshold,
@@ -350,8 +349,8 @@ export async function POST(request: NextRequest) {
similarity: 1 - result.distance,
})),
query: validatedData.query,
knowledgeBaseIds: foundKbIds,
knowledgeBaseId: foundKbIds[0],
knowledgeBaseIds: accessibleKbIds,
knowledgeBaseId: accessibleKbIds[0],
topK: validatedData.topK,
totalResults: results.length,
...(cost && tokenCount

View File

@@ -93,7 +93,9 @@ export class WorkflowBlockHandler implements BlockHandler {
})
const startTime = performance.now()
const result = await subExecutor.execute(executionId)
// Use the actual child workflow ID for authentication, not the execution ID
// This ensures knowledge base and other API calls can properly authenticate
const result = await subExecutor.execute(workflowId)
const duration = performance.now() - startTime
// Remove current execution from stack after completion