fix(kb): fix mistral parse and kb uploads, include userId in internal auth (#1767)

* fix(kb): fix mistral parse and kb uploads, include userId in internal auth

* update updated_at for kb when adding a new doc via knowledge block

* update tests
This commit is contained in:
Waleed
2025-10-29 23:18:39 -07:00
committed by GitHub
parent 48f520b3c7
commit 47913f87de
8 changed files with 146 additions and 42 deletions

View File

@@ -117,9 +117,23 @@ vi.mock('@sim/db', () => {
return {
db: {
select: vi.fn(() => selectBuilder),
update: () => ({
set: () => ({
where: () => Promise.resolve(),
update: (table: any) => ({
set: (payload: any) => ({
where: () => {
const tableSymbols = Object.getOwnPropertySymbols(table || {})
const baseNameSymbol = tableSymbols.find((s) => s.toString().includes('BaseName'))
const tableName = baseNameSymbol ? table[baseNameSymbol] : ''
if (tableName === 'knowledge_base') {
dbOps.order.push('updateKb')
dbOps.updatePayloads.push(payload)
} else if (tableName === 'document') {
if (payload.processingStatus !== 'processing') {
dbOps.order.push('updateDoc')
dbOps.updatePayloads.push(payload)
}
}
return Promise.resolve()
},
}),
}),
transaction: vi.fn(async (fn: any) => {
@@ -131,11 +145,11 @@ vi.mock('@sim/db', () => {
return Promise.resolve()
},
}),
update: () => ({
update: (table: any) => ({
set: (payload: any) => ({
where: () => {
dbOps.updatePayloads.push(payload)
const label = dbOps.updatePayloads.length === 1 ? 'updateDoc' : 'updateKb'
const label = payload.processingStatus !== undefined ? 'updateDoc' : 'updateKb'
dbOps.order.push(label)
return Promise.resolve()
},
@@ -169,6 +183,9 @@ describe('Knowledge Utils', () => {
describe('processDocumentAsync', () => {
it.concurrent('should insert embeddings before updating document counters', async () => {
kbRows.push({ id: 'kb1', userId: 'user1', workspaceId: null })
docRows.push({ id: 'doc1', knowledgeBaseId: 'kb1' })
await processDocumentAsync(
'kb1',
'doc1',

View File

@@ -29,7 +29,8 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
if (authHeader?.startsWith('Bearer ')) {
const token = authHeader.split(' ')[1]
isInternalCall = await verifyInternalToken(token)
const verification = await verifyInternalToken(token)
isInternalCall = verification.valid
}
if (!isInternalCall) {

View File

@@ -37,7 +37,8 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
if (authHeader?.startsWith('Bearer ')) {
const token = authHeader.split(' ')[1]
isInternalCall = await verifyInternalToken(token)
const verification = await verifyInternalToken(token)
isInternalCall = verification.valid
}
let userId: string | null = null

View File

@@ -33,17 +33,19 @@ export async function checkHybridAuth(
const authHeader = request.headers.get('authorization')
if (authHeader?.startsWith('Bearer ')) {
const token = authHeader.split(' ')[1]
const isInternalCall = await verifyInternalToken(token)
const verification = await verifyInternalToken(token)
if (isInternalCall) {
// For internal calls, we need workflowId to determine user context
if (verification.valid) {
let workflowId: string | null = null
let userId: string | null = verification.userId || null
// Try to get workflowId from query params or request body
const { searchParams } = new URL(request.url)
workflowId = searchParams.get('workflowId')
if (!userId) {
userId = searchParams.get('userId')
}
if (!workflowId && request.method === 'POST') {
if (!workflowId && !userId && request.method === 'POST') {
try {
// Clone the request to avoid consuming the original body
const clonedRequest = request.clone()
@@ -51,21 +53,22 @@ export async function checkHybridAuth(
if (bodyText) {
const body = JSON.parse(bodyText)
workflowId = body.workflowId || body._context?.workflowId
userId = userId || body.userId || body._context?.userId
}
} catch {
// Ignore JSON parse errors
}
}
if (!workflowId && options.requireWorkflowId !== false) {
if (userId) {
return {
success: false,
error: 'workflowId required for internal JWT calls',
success: true,
userId,
authType: 'internal_jwt',
}
}
if (workflowId) {
// Get workflow owner as user context
const [workflowData] = await db
.select({ userId: workflow.userId })
.from(workflow)
@@ -85,7 +88,14 @@ export async function checkHybridAuth(
authType: 'internal_jwt',
}
}
// Internal call without workflow context - still valid for some routes
if (options.requireWorkflowId !== false) {
return {
success: false,
error: 'workflowId or userId required for internal JWT calls',
}
}
return {
success: true,
authType: 'internal_jwt',

View File

@@ -5,7 +5,6 @@ import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('CronAuth')
// Create a secret key for JWT signing
const getJwtSecret = () => {
const secret = new TextEncoder().encode(env.INTERNAL_API_SECRET)
return secret
@@ -14,11 +13,17 @@ const getJwtSecret = () => {
/**
* Generate an internal JWT token for server-side API calls
* Token expires in 5 minutes to keep it short-lived
* @param userId Optional user ID to embed in token payload
*/
export async function generateInternalToken(): Promise<string> {
export async function generateInternalToken(userId?: string): Promise<string> {
const secret = getJwtSecret()
const token = await new SignJWT({ type: 'internal' })
const payload: { type: string; userId?: string } = { type: 'internal' }
if (userId) {
payload.userId = userId
}
const token = await new SignJWT(payload)
.setProtectedHeader({ alg: 'HS256' })
.setIssuedAt()
.setExpirationTime('5m')
@@ -31,9 +36,11 @@ export async function generateInternalToken(): Promise<string> {
/**
* Verify an internal JWT token
* Returns true if valid, false otherwise
* Returns verification result with userId if present in token
*/
export async function verifyInternalToken(token: string): Promise<boolean> {
export async function verifyInternalToken(
token: string
): Promise<{ valid: boolean; userId?: string }> {
try {
const secret = getJwtSecret()
@@ -43,10 +50,17 @@ export async function verifyInternalToken(token: string): Promise<boolean> {
})
// Check that it's an internal token
return payload.type === 'internal'
if (payload.type === 'internal') {
return {
valid: true,
userId: typeof payload.userId === 'string' ? payload.userId : undefined,
}
}
return { valid: false }
} catch (error) {
// Token verification failed
return false
return { valid: false }
}
}

View File

@@ -56,7 +56,9 @@ export async function processDocument(
mimeType: string,
chunkSize = 1000,
chunkOverlap = 200,
minChunkSize = 1
minChunkSize = 1,
userId?: string,
workspaceId?: string | null
): Promise<{
chunks: Chunk[]
metadata: {
@@ -73,7 +75,7 @@ export async function processDocument(
logger.info(`Processing document: ${filename}`)
try {
const parseResult = await parseDocument(fileUrl, filename, mimeType)
const parseResult = await parseDocument(fileUrl, filename, mimeType, userId, workspaceId)
const { content, processingMethod } = parseResult
const cloudUrl = 'cloudUrl' in parseResult ? parseResult.cloudUrl : undefined
@@ -131,7 +133,9 @@ export async function processDocument(
async function parseDocument(
fileUrl: string,
filename: string,
mimeType: string
mimeType: string,
userId?: string,
workspaceId?: string | null
): Promise<{
content: string
processingMethod: 'file-parser' | 'mistral-ocr'
@@ -146,12 +150,12 @@ async function parseDocument(
if (isPDF && (hasAzureMistralOCR || hasMistralOCR)) {
if (hasAzureMistralOCR) {
logger.info(`Using Azure Mistral OCR: ${filename}`)
return parseWithAzureMistralOCR(fileUrl, filename, mimeType)
return parseWithAzureMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
}
if (hasMistralOCR) {
logger.info(`Using Mistral OCR: ${filename}`)
return parseWithMistralOCR(fileUrl, filename, mimeType)
return parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
}
}
@@ -159,7 +163,13 @@ async function parseDocument(
return parseWithFileParser(fileUrl, filename, mimeType)
}
async function handleFileForOCR(fileUrl: string, filename: string, mimeType: string) {
async function handleFileForOCR(
fileUrl: string,
filename: string,
mimeType: string,
userId?: string,
workspaceId?: string | null
) {
const isExternalHttps = fileUrl.startsWith('https://') && !fileUrl.includes('/api/files/serve/')
if (isExternalHttps) {
@@ -175,6 +185,8 @@ async function handleFileForOCR(fileUrl: string, filename: string, mimeType: str
originalName: filename,
uploadedAt: new Date().toISOString(),
purpose: 'knowledge-base',
...(userId && { userId }),
...(workspaceId && { workspaceId }),
}
const cloudResult = await StorageService.uploadFile({
@@ -288,7 +300,13 @@ async function makeOCRRequest(
}
}
async function parseWithAzureMistralOCR(fileUrl: string, filename: string, mimeType: string) {
async function parseWithAzureMistralOCR(
fileUrl: string,
filename: string,
mimeType: string,
userId?: string,
workspaceId?: string | null
) {
validateOCRConfig(
env.OCR_AZURE_API_KEY,
env.OCR_AZURE_ENDPOINT,
@@ -336,12 +354,18 @@ async function parseWithAzureMistralOCR(fileUrl: string, filename: string, mimeT
})
return env.MISTRAL_API_KEY
? parseWithMistralOCR(fileUrl, filename, mimeType)
? parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
: parseWithFileParser(fileUrl, filename, mimeType)
}
}
async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType: string) {
async function parseWithMistralOCR(
fileUrl: string,
filename: string,
mimeType: string,
userId?: string,
workspaceId?: string | null
) {
if (!env.MISTRAL_API_KEY) {
throw new Error('Mistral API key required')
}
@@ -350,7 +374,13 @@ async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType:
throw new Error('Mistral parser tool not configured')
}
const { httpsUrl, cloudUrl } = await handleFileForOCR(fileUrl, filename, mimeType)
const { httpsUrl, cloudUrl } = await handleFileForOCR(
fileUrl,
filename,
mimeType,
userId,
workspaceId
)
const params = { filePath: httpsUrl, apiKey: env.MISTRAL_API_KEY, resultType: 'text' as const }
try {
@@ -361,7 +391,9 @@ async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType:
? mistralParserTool.request!.url(params)
: mistralParserTool.request!.url
if (url.startsWith('/')) {
const isInternalRoute = url.startsWith('/')
if (isInternalRoute) {
const { getBaseUrl } = await import('@/lib/urls/utils')
url = `${getBaseUrl()}${url}`
}
@@ -371,9 +403,9 @@ async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType:
? mistralParserTool.request!.headers(params)
: mistralParserTool.request!.headers
if (url.includes('/api/tools/mistral/parse')) {
if (isInternalRoute) {
const { generateInternalToken } = await import('@/lib/auth/internal')
const internalToken = await generateInternalToken()
const internalToken = await generateInternalToken(userId)
headers = {
...headers,
Authorization: `Bearer ${internalToken}`,

View File

@@ -439,6 +439,19 @@ export async function processDocumentAsync(
try {
logger.info(`[${documentId}] Starting document processing: ${docData.filename}`)
const kb = await db
.select({
userId: knowledgeBase.userId,
workspaceId: knowledgeBase.workspaceId,
})
.from(knowledgeBase)
.where(eq(knowledgeBase.id, knowledgeBaseId))
.limit(1)
if (kb.length === 0) {
throw new Error(`Knowledge base not found: ${knowledgeBaseId}`)
}
await db
.update(document)
.set({
@@ -458,7 +471,9 @@ export async function processDocumentAsync(
docData.mimeType,
processingOptions.chunkSize || 512,
processingOptions.chunkOverlap || 200,
processingOptions.minCharactersPerChunk || 1
processingOptions.minCharactersPerChunk || 1,
kb[0].userId,
kb[0].workspaceId
)
if (processed.chunks.length > LARGE_DOC_CONFIG.MAX_CHUNKS_PER_DOCUMENT) {
@@ -758,7 +773,11 @@ export async function createDocumentRecords(
`[${requestId}] Bulk created ${documentRecords.length} document records in knowledge base ${knowledgeBaseId}`
)
// Increment storage usage tracking
await tx
.update(knowledgeBase)
.set({ updatedAt: now })
.where(eq(knowledgeBase.id, knowledgeBaseId))
if (userId) {
const totalSize = documents.reduce((sum, doc) => sum + doc.fileSize, 0)
@@ -1070,9 +1089,13 @@ export async function createSingleDocument(
await db.insert(document).values(newDocument)
await db
.update(knowledgeBase)
.set({ updatedAt: now })
.where(eq(knowledgeBase.id, knowledgeBaseId))
logger.info(`[${requestId}] Document created: ${documentId} in knowledge base ${knowledgeBaseId}`)
// Increment storage usage tracking
if (userId) {
// Get knowledge base owner
const kb = await db

View File

@@ -122,10 +122,16 @@ export const mistralParserTool: ToolConfig<MistralParserInput, MistralParserOutp
throw new Error('Missing or invalid file path: Please provide a URL to a PDF document')
}
// Validate and normalize URL
let filePathToValidate = params.filePath.trim()
if (filePathToValidate.startsWith('/')) {
const baseUrl = getBaseUrl()
if (!baseUrl) throw new Error('Failed to get base URL for file path conversion')
filePathToValidate = `${baseUrl}${filePathToValidate}`
}
let url
try {
url = new URL(params.filePath.trim())
url = new URL(filePathToValidate)
// Validate protocol
if (!['http:', 'https:'].includes(url.protocol)) {