mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
fix(kb-perms): search tool perms to use new system (#786)
* fix(kb): search tool perms * fix tests
This commit is contained in:
committed by
GitHub
parent
af1c7dc39d
commit
2f57d8a884
@@ -51,6 +51,11 @@ vi.mock('@/providers/utils', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockCheckKnowledgeBaseAccess = vi.fn()
|
||||
vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
checkKnowledgeBaseAccess: mockCheckKnowledgeBaseAccess,
|
||||
}))
|
||||
|
||||
mockConsoleLogger()
|
||||
|
||||
describe('Knowledge Search API Route', () => {
|
||||
@@ -132,7 +137,11 @@ describe('Knowledge Search API Route', () => {
|
||||
it('should perform search successfully with single knowledge base', async () => {
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
|
||||
@@ -149,6 +158,10 @@ describe('Knowledge Search API Route', () => {
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
if (response.status !== 200) {
|
||||
console.log('Test failed with response:', data)
|
||||
}
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.results).toHaveLength(2)
|
||||
@@ -171,7 +184,10 @@ describe('Knowledge Search API Route', () => {
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(multiKbs)
|
||||
// Mock knowledge base access check to return success for both KBs
|
||||
mockCheckKnowledgeBaseAccess
|
||||
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: multiKbs[0] })
|
||||
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: multiKbs[1] })
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
|
||||
@@ -201,9 +217,13 @@ describe('Knowledge Search API Route', () => {
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // First call: get knowledge bases
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Second call: search results
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Search results
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
@@ -255,7 +275,11 @@ describe('Knowledge Search API Route', () => {
|
||||
it('should return not found for non-existent knowledge base', async () => {
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce([]) // No knowledge bases found
|
||||
// Mock knowledge base access check to return no access
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validSearchData)
|
||||
const { POST } = await import('./route')
|
||||
@@ -274,7 +298,10 @@ describe('Knowledge Search API Route', () => {
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // Only kb-123 found
|
||||
// Mock access check: first KB has access, second doesn't
|
||||
mockCheckKnowledgeBaseAccess
|
||||
.mockResolvedValueOnce({ hasAccess: true, knowledgeBase: mockKnowledgeBases[0] })
|
||||
.mockResolvedValueOnce({ hasAccess: false, notFound: true })
|
||||
|
||||
const req = createMockRequest('POST', multiKbData)
|
||||
const { POST } = await import('./route')
|
||||
@@ -282,7 +309,7 @@ describe('Knowledge Search API Route', () => {
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data.error).toBe('Knowledge bases not found: kb-missing')
|
||||
expect(data.error).toBe('Knowledge bases not found or access denied: kb-missing')
|
||||
})
|
||||
|
||||
it.concurrent('should validate search parameters', async () => {
|
||||
@@ -310,9 +337,13 @@ describe('Knowledge Search API Route', () => {
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // First call: get knowledge bases
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Second call: search results
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Search results
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
@@ -416,7 +447,13 @@ describe('Knowledge Search API Route', () => {
|
||||
describe('Cost tracking', () => {
|
||||
it.concurrent('should include cost information in successful search response', async () => {
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
|
||||
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
@@ -458,7 +495,13 @@ describe('Knowledge Search API Route', () => {
|
||||
const { calculateCost } = await import('@/providers/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
|
||||
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
@@ -509,7 +552,13 @@ describe('Knowledge Search API Route', () => {
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
|
||||
|
||||
// Mock knowledge base access check to return success
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: mockKnowledgeBases[0],
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { and, eq, inArray, isNull, sql } from 'drizzle-orm'
|
||||
import { and, eq, inArray, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
|
||||
@@ -6,8 +6,9 @@ import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console-logger'
|
||||
import { estimateTokenCount } from '@/lib/tokenization/estimators'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { embedding, knowledgeBase } from '@/db/schema'
|
||||
import { embedding } from '@/db/schema'
|
||||
import { calculateCost } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('VectorSearchAPI')
|
||||
@@ -261,39 +262,37 @@ export async function POST(request: NextRequest) {
|
||||
? validatedData.knowledgeBaseIds
|
||||
: [validatedData.knowledgeBaseIds]
|
||||
|
||||
const [kb, queryEmbedding] = await Promise.all([
|
||||
db
|
||||
.select()
|
||||
.from(knowledgeBase)
|
||||
.where(
|
||||
and(
|
||||
inArray(knowledgeBase.id, knowledgeBaseIds),
|
||||
eq(knowledgeBase.userId, userId),
|
||||
isNull(knowledgeBase.deletedAt)
|
||||
)
|
||||
),
|
||||
generateSearchEmbedding(validatedData.query),
|
||||
])
|
||||
// Check access permissions for each knowledge base using proper workspace-based permissions
|
||||
const accessibleKbIds: string[] = []
|
||||
for (const kbId of knowledgeBaseIds) {
|
||||
const accessCheck = await checkKnowledgeBaseAccess(kbId, userId)
|
||||
if (accessCheck.hasAccess) {
|
||||
accessibleKbIds.push(kbId)
|
||||
}
|
||||
}
|
||||
|
||||
if (kb.length === 0) {
|
||||
if (accessibleKbIds.length === 0) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Knowledge base not found or access denied' },
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
|
||||
const foundKbIds = kb.map((k) => k.id)
|
||||
const missingKbIds = knowledgeBaseIds.filter((id) => !foundKbIds.includes(id))
|
||||
// Generate query embedding in parallel with access checks
|
||||
const queryEmbedding = await generateSearchEmbedding(validatedData.query)
|
||||
|
||||
if (missingKbIds.length > 0) {
|
||||
// Check if any requested knowledge bases were not accessible
|
||||
const inaccessibleKbIds = knowledgeBaseIds.filter((id) => !accessibleKbIds.includes(id))
|
||||
|
||||
if (inaccessibleKbIds.length > 0) {
|
||||
return NextResponse.json(
|
||||
{ error: `Knowledge bases not found: ${missingKbIds.join(', ')}` },
|
||||
{ error: `Knowledge bases not found or access denied: ${inaccessibleKbIds.join(', ')}` },
|
||||
{ status: 404 }
|
||||
)
|
||||
}
|
||||
|
||||
// Adaptive query strategy based on KB count and parameters
|
||||
const strategy = getQueryStrategy(foundKbIds.length, validatedData.topK)
|
||||
// Adaptive query strategy based on accessible KB count and parameters
|
||||
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
|
||||
const queryVector = JSON.stringify(queryEmbedding)
|
||||
|
||||
let results: any[]
|
||||
@@ -301,7 +300,7 @@ export async function POST(request: NextRequest) {
|
||||
if (strategy.useParallel) {
|
||||
// Execute parallel queries for better performance with many KBs
|
||||
const parallelResults = await executeParallelQueries(
|
||||
foundKbIds,
|
||||
accessibleKbIds,
|
||||
queryVector,
|
||||
validatedData.topK,
|
||||
strategy.distanceThreshold,
|
||||
@@ -311,7 +310,7 @@ export async function POST(request: NextRequest) {
|
||||
} else {
|
||||
// Execute single optimized query for fewer KBs
|
||||
results = await executeSingleQuery(
|
||||
foundKbIds,
|
||||
accessibleKbIds,
|
||||
queryVector,
|
||||
validatedData.topK,
|
||||
strategy.distanceThreshold,
|
||||
@@ -350,8 +349,8 @@ export async function POST(request: NextRequest) {
|
||||
similarity: 1 - result.distance,
|
||||
})),
|
||||
query: validatedData.query,
|
||||
knowledgeBaseIds: foundKbIds,
|
||||
knowledgeBaseId: foundKbIds[0],
|
||||
knowledgeBaseIds: accessibleKbIds,
|
||||
knowledgeBaseId: accessibleKbIds[0],
|
||||
topK: validatedData.topK,
|
||||
totalResults: results.length,
|
||||
...(cost && tokenCount
|
||||
|
||||
@@ -93,7 +93,9 @@ export class WorkflowBlockHandler implements BlockHandler {
|
||||
})
|
||||
|
||||
const startTime = performance.now()
|
||||
const result = await subExecutor.execute(executionId)
|
||||
// Use the actual child workflow ID for authentication, not the execution ID
|
||||
// This ensures knowledge base and other API calls can properly authenticate
|
||||
const result = await subExecutor.execute(workflowId)
|
||||
const duration = performance.now() - startTime
|
||||
|
||||
// Remove current execution from stack after completion
|
||||
|
||||
Reference in New Issue
Block a user