mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
fix(kb): fix mistral parse and kb uploads, include userId in internal auth (#1767)
* fix(kb): fix mistral parse and kb uploads, include userId in internal auth * update updated_at for kb when adding a new doc via knowledge block * update tests
This commit is contained in:
@@ -117,9 +117,23 @@ vi.mock('@sim/db', () => {
|
||||
return {
|
||||
db: {
|
||||
select: vi.fn(() => selectBuilder),
|
||||
update: () => ({
|
||||
set: () => ({
|
||||
where: () => Promise.resolve(),
|
||||
update: (table: any) => ({
|
||||
set: (payload: any) => ({
|
||||
where: () => {
|
||||
const tableSymbols = Object.getOwnPropertySymbols(table || {})
|
||||
const baseNameSymbol = tableSymbols.find((s) => s.toString().includes('BaseName'))
|
||||
const tableName = baseNameSymbol ? table[baseNameSymbol] : ''
|
||||
if (tableName === 'knowledge_base') {
|
||||
dbOps.order.push('updateKb')
|
||||
dbOps.updatePayloads.push(payload)
|
||||
} else if (tableName === 'document') {
|
||||
if (payload.processingStatus !== 'processing') {
|
||||
dbOps.order.push('updateDoc')
|
||||
dbOps.updatePayloads.push(payload)
|
||||
}
|
||||
}
|
||||
return Promise.resolve()
|
||||
},
|
||||
}),
|
||||
}),
|
||||
transaction: vi.fn(async (fn: any) => {
|
||||
@@ -131,11 +145,11 @@ vi.mock('@sim/db', () => {
|
||||
return Promise.resolve()
|
||||
},
|
||||
}),
|
||||
update: () => ({
|
||||
update: (table: any) => ({
|
||||
set: (payload: any) => ({
|
||||
where: () => {
|
||||
dbOps.updatePayloads.push(payload)
|
||||
const label = dbOps.updatePayloads.length === 1 ? 'updateDoc' : 'updateKb'
|
||||
const label = payload.processingStatus !== undefined ? 'updateDoc' : 'updateKb'
|
||||
dbOps.order.push(label)
|
||||
return Promise.resolve()
|
||||
},
|
||||
@@ -169,6 +183,9 @@ describe('Knowledge Utils', () => {
|
||||
|
||||
describe('processDocumentAsync', () => {
|
||||
it.concurrent('should insert embeddings before updating document counters', async () => {
|
||||
kbRows.push({ id: 'kb1', userId: 'user1', workspaceId: null })
|
||||
docRows.push({ id: 'doc1', knowledgeBaseId: 'kb1' })
|
||||
|
||||
await processDocumentAsync(
|
||||
'kb1',
|
||||
'doc1',
|
||||
|
||||
@@ -29,7 +29,8 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
|
||||
|
||||
if (authHeader?.startsWith('Bearer ')) {
|
||||
const token = authHeader.split(' ')[1]
|
||||
isInternalCall = await verifyInternalToken(token)
|
||||
const verification = await verifyInternalToken(token)
|
||||
isInternalCall = verification.valid
|
||||
}
|
||||
|
||||
if (!isInternalCall) {
|
||||
|
||||
@@ -37,7 +37,8 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
|
||||
|
||||
if (authHeader?.startsWith('Bearer ')) {
|
||||
const token = authHeader.split(' ')[1]
|
||||
isInternalCall = await verifyInternalToken(token)
|
||||
const verification = await verifyInternalToken(token)
|
||||
isInternalCall = verification.valid
|
||||
}
|
||||
|
||||
let userId: string | null = null
|
||||
|
||||
@@ -33,17 +33,19 @@ export async function checkHybridAuth(
|
||||
const authHeader = request.headers.get('authorization')
|
||||
if (authHeader?.startsWith('Bearer ')) {
|
||||
const token = authHeader.split(' ')[1]
|
||||
const isInternalCall = await verifyInternalToken(token)
|
||||
const verification = await verifyInternalToken(token)
|
||||
|
||||
if (isInternalCall) {
|
||||
// For internal calls, we need workflowId to determine user context
|
||||
if (verification.valid) {
|
||||
let workflowId: string | null = null
|
||||
let userId: string | null = verification.userId || null
|
||||
|
||||
// Try to get workflowId from query params or request body
|
||||
const { searchParams } = new URL(request.url)
|
||||
workflowId = searchParams.get('workflowId')
|
||||
if (!userId) {
|
||||
userId = searchParams.get('userId')
|
||||
}
|
||||
|
||||
if (!workflowId && request.method === 'POST') {
|
||||
if (!workflowId && !userId && request.method === 'POST') {
|
||||
try {
|
||||
// Clone the request to avoid consuming the original body
|
||||
const clonedRequest = request.clone()
|
||||
@@ -51,21 +53,22 @@ export async function checkHybridAuth(
|
||||
if (bodyText) {
|
||||
const body = JSON.parse(bodyText)
|
||||
workflowId = body.workflowId || body._context?.workflowId
|
||||
userId = userId || body.userId || body._context?.userId
|
||||
}
|
||||
} catch {
|
||||
// Ignore JSON parse errors
|
||||
}
|
||||
}
|
||||
|
||||
if (!workflowId && options.requireWorkflowId !== false) {
|
||||
if (userId) {
|
||||
return {
|
||||
success: false,
|
||||
error: 'workflowId required for internal JWT calls',
|
||||
success: true,
|
||||
userId,
|
||||
authType: 'internal_jwt',
|
||||
}
|
||||
}
|
||||
|
||||
if (workflowId) {
|
||||
// Get workflow owner as user context
|
||||
const [workflowData] = await db
|
||||
.select({ userId: workflow.userId })
|
||||
.from(workflow)
|
||||
@@ -85,7 +88,14 @@ export async function checkHybridAuth(
|
||||
authType: 'internal_jwt',
|
||||
}
|
||||
}
|
||||
// Internal call without workflow context - still valid for some routes
|
||||
|
||||
if (options.requireWorkflowId !== false) {
|
||||
return {
|
||||
success: false,
|
||||
error: 'workflowId or userId required for internal JWT calls',
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
success: true,
|
||||
authType: 'internal_jwt',
|
||||
|
||||
@@ -5,7 +5,6 @@ import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('CronAuth')
|
||||
|
||||
// Create a secret key for JWT signing
|
||||
const getJwtSecret = () => {
|
||||
const secret = new TextEncoder().encode(env.INTERNAL_API_SECRET)
|
||||
return secret
|
||||
@@ -14,11 +13,17 @@ const getJwtSecret = () => {
|
||||
/**
|
||||
* Generate an internal JWT token for server-side API calls
|
||||
* Token expires in 5 minutes to keep it short-lived
|
||||
* @param userId Optional user ID to embed in token payload
|
||||
*/
|
||||
export async function generateInternalToken(): Promise<string> {
|
||||
export async function generateInternalToken(userId?: string): Promise<string> {
|
||||
const secret = getJwtSecret()
|
||||
|
||||
const token = await new SignJWT({ type: 'internal' })
|
||||
const payload: { type: string; userId?: string } = { type: 'internal' }
|
||||
if (userId) {
|
||||
payload.userId = userId
|
||||
}
|
||||
|
||||
const token = await new SignJWT(payload)
|
||||
.setProtectedHeader({ alg: 'HS256' })
|
||||
.setIssuedAt()
|
||||
.setExpirationTime('5m')
|
||||
@@ -31,9 +36,11 @@ export async function generateInternalToken(): Promise<string> {
|
||||
|
||||
/**
|
||||
* Verify an internal JWT token
|
||||
* Returns true if valid, false otherwise
|
||||
* Returns verification result with userId if present in token
|
||||
*/
|
||||
export async function verifyInternalToken(token: string): Promise<boolean> {
|
||||
export async function verifyInternalToken(
|
||||
token: string
|
||||
): Promise<{ valid: boolean; userId?: string }> {
|
||||
try {
|
||||
const secret = getJwtSecret()
|
||||
|
||||
@@ -43,10 +50,17 @@ export async function verifyInternalToken(token: string): Promise<boolean> {
|
||||
})
|
||||
|
||||
// Check that it's an internal token
|
||||
return payload.type === 'internal'
|
||||
if (payload.type === 'internal') {
|
||||
return {
|
||||
valid: true,
|
||||
userId: typeof payload.userId === 'string' ? payload.userId : undefined,
|
||||
}
|
||||
}
|
||||
|
||||
return { valid: false }
|
||||
} catch (error) {
|
||||
// Token verification failed
|
||||
return false
|
||||
return { valid: false }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -56,7 +56,9 @@ export async function processDocument(
|
||||
mimeType: string,
|
||||
chunkSize = 1000,
|
||||
chunkOverlap = 200,
|
||||
minChunkSize = 1
|
||||
minChunkSize = 1,
|
||||
userId?: string,
|
||||
workspaceId?: string | null
|
||||
): Promise<{
|
||||
chunks: Chunk[]
|
||||
metadata: {
|
||||
@@ -73,7 +75,7 @@ export async function processDocument(
|
||||
logger.info(`Processing document: ${filename}`)
|
||||
|
||||
try {
|
||||
const parseResult = await parseDocument(fileUrl, filename, mimeType)
|
||||
const parseResult = await parseDocument(fileUrl, filename, mimeType, userId, workspaceId)
|
||||
const { content, processingMethod } = parseResult
|
||||
const cloudUrl = 'cloudUrl' in parseResult ? parseResult.cloudUrl : undefined
|
||||
|
||||
@@ -131,7 +133,9 @@ export async function processDocument(
|
||||
async function parseDocument(
|
||||
fileUrl: string,
|
||||
filename: string,
|
||||
mimeType: string
|
||||
mimeType: string,
|
||||
userId?: string,
|
||||
workspaceId?: string | null
|
||||
): Promise<{
|
||||
content: string
|
||||
processingMethod: 'file-parser' | 'mistral-ocr'
|
||||
@@ -146,12 +150,12 @@ async function parseDocument(
|
||||
if (isPDF && (hasAzureMistralOCR || hasMistralOCR)) {
|
||||
if (hasAzureMistralOCR) {
|
||||
logger.info(`Using Azure Mistral OCR: ${filename}`)
|
||||
return parseWithAzureMistralOCR(fileUrl, filename, mimeType)
|
||||
return parseWithAzureMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
|
||||
}
|
||||
|
||||
if (hasMistralOCR) {
|
||||
logger.info(`Using Mistral OCR: ${filename}`)
|
||||
return parseWithMistralOCR(fileUrl, filename, mimeType)
|
||||
return parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,7 +163,13 @@ async function parseDocument(
|
||||
return parseWithFileParser(fileUrl, filename, mimeType)
|
||||
}
|
||||
|
||||
async function handleFileForOCR(fileUrl: string, filename: string, mimeType: string) {
|
||||
async function handleFileForOCR(
|
||||
fileUrl: string,
|
||||
filename: string,
|
||||
mimeType: string,
|
||||
userId?: string,
|
||||
workspaceId?: string | null
|
||||
) {
|
||||
const isExternalHttps = fileUrl.startsWith('https://') && !fileUrl.includes('/api/files/serve/')
|
||||
|
||||
if (isExternalHttps) {
|
||||
@@ -175,6 +185,8 @@ async function handleFileForOCR(fileUrl: string, filename: string, mimeType: str
|
||||
originalName: filename,
|
||||
uploadedAt: new Date().toISOString(),
|
||||
purpose: 'knowledge-base',
|
||||
...(userId && { userId }),
|
||||
...(workspaceId && { workspaceId }),
|
||||
}
|
||||
|
||||
const cloudResult = await StorageService.uploadFile({
|
||||
@@ -288,7 +300,13 @@ async function makeOCRRequest(
|
||||
}
|
||||
}
|
||||
|
||||
async function parseWithAzureMistralOCR(fileUrl: string, filename: string, mimeType: string) {
|
||||
async function parseWithAzureMistralOCR(
|
||||
fileUrl: string,
|
||||
filename: string,
|
||||
mimeType: string,
|
||||
userId?: string,
|
||||
workspaceId?: string | null
|
||||
) {
|
||||
validateOCRConfig(
|
||||
env.OCR_AZURE_API_KEY,
|
||||
env.OCR_AZURE_ENDPOINT,
|
||||
@@ -336,12 +354,18 @@ async function parseWithAzureMistralOCR(fileUrl: string, filename: string, mimeT
|
||||
})
|
||||
|
||||
return env.MISTRAL_API_KEY
|
||||
? parseWithMistralOCR(fileUrl, filename, mimeType)
|
||||
? parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
|
||||
: parseWithFileParser(fileUrl, filename, mimeType)
|
||||
}
|
||||
}
|
||||
|
||||
async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType: string) {
|
||||
async function parseWithMistralOCR(
|
||||
fileUrl: string,
|
||||
filename: string,
|
||||
mimeType: string,
|
||||
userId?: string,
|
||||
workspaceId?: string | null
|
||||
) {
|
||||
if (!env.MISTRAL_API_KEY) {
|
||||
throw new Error('Mistral API key required')
|
||||
}
|
||||
@@ -350,7 +374,13 @@ async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType:
|
||||
throw new Error('Mistral parser tool not configured')
|
||||
}
|
||||
|
||||
const { httpsUrl, cloudUrl } = await handleFileForOCR(fileUrl, filename, mimeType)
|
||||
const { httpsUrl, cloudUrl } = await handleFileForOCR(
|
||||
fileUrl,
|
||||
filename,
|
||||
mimeType,
|
||||
userId,
|
||||
workspaceId
|
||||
)
|
||||
const params = { filePath: httpsUrl, apiKey: env.MISTRAL_API_KEY, resultType: 'text' as const }
|
||||
|
||||
try {
|
||||
@@ -361,7 +391,9 @@ async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType:
|
||||
? mistralParserTool.request!.url(params)
|
||||
: mistralParserTool.request!.url
|
||||
|
||||
if (url.startsWith('/')) {
|
||||
const isInternalRoute = url.startsWith('/')
|
||||
|
||||
if (isInternalRoute) {
|
||||
const { getBaseUrl } = await import('@/lib/urls/utils')
|
||||
url = `${getBaseUrl()}${url}`
|
||||
}
|
||||
@@ -371,9 +403,9 @@ async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType:
|
||||
? mistralParserTool.request!.headers(params)
|
||||
: mistralParserTool.request!.headers
|
||||
|
||||
if (url.includes('/api/tools/mistral/parse')) {
|
||||
if (isInternalRoute) {
|
||||
const { generateInternalToken } = await import('@/lib/auth/internal')
|
||||
const internalToken = await generateInternalToken()
|
||||
const internalToken = await generateInternalToken(userId)
|
||||
headers = {
|
||||
...headers,
|
||||
Authorization: `Bearer ${internalToken}`,
|
||||
|
||||
@@ -439,6 +439,19 @@ export async function processDocumentAsync(
|
||||
try {
|
||||
logger.info(`[${documentId}] Starting document processing: ${docData.filename}`)
|
||||
|
||||
const kb = await db
|
||||
.select({
|
||||
userId: knowledgeBase.userId,
|
||||
workspaceId: knowledgeBase.workspaceId,
|
||||
})
|
||||
.from(knowledgeBase)
|
||||
.where(eq(knowledgeBase.id, knowledgeBaseId))
|
||||
.limit(1)
|
||||
|
||||
if (kb.length === 0) {
|
||||
throw new Error(`Knowledge base not found: ${knowledgeBaseId}`)
|
||||
}
|
||||
|
||||
await db
|
||||
.update(document)
|
||||
.set({
|
||||
@@ -458,7 +471,9 @@ export async function processDocumentAsync(
|
||||
docData.mimeType,
|
||||
processingOptions.chunkSize || 512,
|
||||
processingOptions.chunkOverlap || 200,
|
||||
processingOptions.minCharactersPerChunk || 1
|
||||
processingOptions.minCharactersPerChunk || 1,
|
||||
kb[0].userId,
|
||||
kb[0].workspaceId
|
||||
)
|
||||
|
||||
if (processed.chunks.length > LARGE_DOC_CONFIG.MAX_CHUNKS_PER_DOCUMENT) {
|
||||
@@ -758,7 +773,11 @@ export async function createDocumentRecords(
|
||||
`[${requestId}] Bulk created ${documentRecords.length} document records in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
// Increment storage usage tracking
|
||||
await tx
|
||||
.update(knowledgeBase)
|
||||
.set({ updatedAt: now })
|
||||
.where(eq(knowledgeBase.id, knowledgeBaseId))
|
||||
|
||||
if (userId) {
|
||||
const totalSize = documents.reduce((sum, doc) => sum + doc.fileSize, 0)
|
||||
|
||||
@@ -1070,9 +1089,13 @@ export async function createSingleDocument(
|
||||
|
||||
await db.insert(document).values(newDocument)
|
||||
|
||||
await db
|
||||
.update(knowledgeBase)
|
||||
.set({ updatedAt: now })
|
||||
.where(eq(knowledgeBase.id, knowledgeBaseId))
|
||||
|
||||
logger.info(`[${requestId}] Document created: ${documentId} in knowledge base ${knowledgeBaseId}`)
|
||||
|
||||
// Increment storage usage tracking
|
||||
if (userId) {
|
||||
// Get knowledge base owner
|
||||
const kb = await db
|
||||
|
||||
@@ -122,10 +122,16 @@ export const mistralParserTool: ToolConfig<MistralParserInput, MistralParserOutp
|
||||
throw new Error('Missing or invalid file path: Please provide a URL to a PDF document')
|
||||
}
|
||||
|
||||
// Validate and normalize URL
|
||||
let filePathToValidate = params.filePath.trim()
|
||||
if (filePathToValidate.startsWith('/')) {
|
||||
const baseUrl = getBaseUrl()
|
||||
if (!baseUrl) throw new Error('Failed to get base URL for file path conversion')
|
||||
filePathToValidate = `${baseUrl}${filePathToValidate}`
|
||||
}
|
||||
|
||||
let url
|
||||
try {
|
||||
url = new URL(params.filePath.trim())
|
||||
url = new URL(filePathToValidate)
|
||||
|
||||
// Validate protocol
|
||||
if (!['http:', 'https:'].includes(url.protocol)) {
|
||||
|
||||
Reference in New Issue
Block a user