mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-07 22:24:06 -05:00
fix(kb-uploads): created knowledge, chunks, tags services and use redis for queueing docs in kb (#1143)
* improvement(kb): created knowledge, chunks, tags services and use redis for queueing docs in kb
* moved directories around
* cleanup
* bulk create document records after upload is completed
* fix(copilot): send api key to sim agent (#1142)
* Fix api key auth
* Lint
* ack PR comments
* added sort by functionality for headers in kb table
* updated
* test fallback from redis, fix styling
* cleanup copilot, fixed tooltips
* feat: local auto layout (#1144)
* feat: added llms.txt and robots.txt (#1145)
* fix(condition-block): edges not following blocks, duplicate issues (#1146)
* fix(condition-block): edges not following blocks, duplicate issues
* add subblock update to setActiveWorkflow
* Update apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-block/components/sub-block/components/condition-input.tsx
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
---------
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
* fix dependency array
* fix(copilot-cleanup): support azure blob upload in copilot, remove dead code & consolidate other copilot files (#1147)
* cleanup
* support azure blob image upload
* imports cleanup
* PR comments
* ack PR comments
* fix key validation
* improvement(forwarding+excel): added forwarding and improve excel read (#1136)
* added forwarding for outlook
* lint
* improved excel sheet read
* addressed greptile
* fixed bodytext getting truncated
* fixed any type
* added html func
---------
Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
* revert agent const
* update docs
---------
Co-authored-by: Siddharth Ganesan <33737564+Sg312@users.noreply.github.com>
Co-authored-by: Emir Karabeg <78010029+emir-karabeg@users.noreply.github.com>
Co-authored-by: Vikhyath Mondreti <vikhyathvikku@gmail.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
Co-authored-by: Adam Gough <77861281+aadamgough@users.noreply.github.com>
Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
@@ -109,7 +109,7 @@ Read data from a Microsoft Excel spreadsheet
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `spreadsheetId` | string | Yes | The ID of the spreadsheet to read from |
| `range` | string | No | The range of cells to read from |
| `range` | string | No | The range of cells to read from. Accepts "SheetName!A1:B2" for explicit ranges or just "SheetName" to read the used range of that sheet. If omitted, reads the used range of the first sheet. |

#### Output

@@ -68,7 +68,7 @@ Upload a file to OneDrive
| `fileName` | string | Yes | The name of the file to upload |
| `content` | string | Yes | The content of the file to upload |
| `folderSelector` | string | No | Select the folder to upload the file to |
| `folderId` | string | No | The ID of the folder to upload the file to \(internal use\) |
| `manualFolderId` | string | No | Manually entered folder ID \(advanced mode\) |

#### Output

@@ -87,7 +87,7 @@ Create a new folder in OneDrive
| --------- | ---- | -------- | ----------- |
| `folderName` | string | Yes | Name of the folder to create |
| `folderSelector` | string | No | Select the parent folder to create the folder in |
| `folderId` | string | No | ID of the parent folder \(internal use\) |
| `manualFolderId` | string | No | Manually entered parent folder ID \(advanced mode\) |

#### Output

@@ -105,7 +105,7 @@ List files and folders in OneDrive
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `folderSelector` | string | No | Select the folder to list files from |
| `folderId` | string | No | The ID of the folder to list files from \(internal use\) |
| `manualFolderId` | string | No | The manually entered folder ID \(advanced mode\) |
| `query` | string | No | A query to filter the files |
| `pageSize` | number | No | The number of files to return |

@@ -211,10 +211,27 @@ Read emails from Outlook

| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Email read operation success status |
| `messageCount` | number | Number of emails retrieved |
| `messages` | array | Array of email message objects |
| `message` | string | Success or status message |
| `results` | array | Array of email message objects |

### `outlook_forward`

Forward an existing Outlook message to specified recipients

#### Input

| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `messageId` | string | Yes | The ID of the message to forward |
| `to` | string | Yes | Recipient email address\(es\), comma-separated |
| `comment` | string | No | Optional comment to include with the forwarded message |

#### Output

| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Success or error message |
| `results` | object | Delivery result details |

@@ -49,15 +49,12 @@ const PASSWORD_VALIDATIONS = {
},
}

// Validate callback URL to prevent open redirect vulnerabilities
const validateCallbackUrl = (url: string): boolean => {
try {
// If it's a relative URL, it's safe
if (url.startsWith('/')) {
return true
}

// If absolute URL, check if it belongs to the same origin
const currentOrigin = typeof window !== 'undefined' ? window.location.origin : ''
if (url.startsWith(currentOrigin)) {
return true
@@ -70,7 +67,6 @@ const validateCallbackUrl = (url: string): boolean => {
}
}

// Validate password and return array of error messages
const validatePassword = (passwordValue: string): string[] => {
const errors: string[] = []

@@ -521,9 +517,7 @@ export default function LoginPage({
</div>
{resetStatus.type && (
<div
className={`text-sm ${
resetStatus.type === 'success' ? 'text-[#4CAF50]' : 'text-red-500'
}`}
className={`text-sm ${resetStatus.type === 'success' ? 'text-[#4CAF50]' : 'text-red-500'}`}
>
{resetStatus.message}
</div>

@@ -109,7 +109,9 @@ export async function PUT(request: NextRequest) {
// If we can't decrypt the existing value, treat as changed and re-encrypt
logger.warn(
`[${requestId}] Could not decrypt existing variable ${key}, re-encrypting`,
{ error: decryptError }
{
error: decryptError,
}
)
variablesToEncrypt[key] = newValue
updatedVariables.push(key)

@@ -1,16 +1,8 @@
import {
AbortMultipartUploadCommand,
CompleteMultipartUploadCommand,
CreateMultipartUploadCommand,
UploadPartCommand,
} from '@aws-sdk/client-s3'
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
import { type NextRequest, NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
import { S3_KB_CONFIG } from '@/lib/uploads/setup'
import { BLOB_KB_CONFIG } from '@/lib/uploads/setup'

const logger = createLogger('MultipartUploadAPI')

@@ -26,15 +18,6 @@ interface GetPartUrlsRequest {
partNumbers: number[]
}

interface CompleteMultipartRequest {
uploadId: string
key: string
parts: Array<{
ETag: string
PartNumber: number
}>
}

export async function POST(request: NextRequest) {
try {
const session = await getSession()
@@ -44,106 +27,214 @@ export async function POST(request: NextRequest) {

const action = request.nextUrl.searchParams.get('action')

if (!isUsingCloudStorage() || getStorageProvider() !== 's3') {
if (!isUsingCloudStorage()) {
return NextResponse.json(
{ error: 'Multipart upload is only available with S3 storage' },
{ error: 'Multipart upload is only available with cloud storage (S3 or Azure Blob)' },
{ status: 400 }
)
}

const { getS3Client } = await import('@/lib/uploads/s3/s3-client')
const s3Client = getS3Client()
const storageProvider = getStorageProvider()

switch (action) {
case 'initiate': {
const data: InitiateMultipartRequest = await request.json()
const { fileName, contentType } = data
const { fileName, contentType, fileSize } = data

const safeFileName = fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
const uniqueKey = `kb/${uuidv4()}-${safeFileName}`
if (storageProvider === 's3') {
const { initiateS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')

const command = new CreateMultipartUploadCommand({
Bucket: S3_KB_CONFIG.bucket,
Key: uniqueKey,
ContentType: contentType,
Metadata: {
originalName: fileName,
uploadedAt: new Date().toISOString(),
purpose: 'knowledge-base',
},
})
const result = await initiateS3MultipartUpload({
fileName,
contentType,
fileSize,
})

const response = await s3Client.send(command)
logger.info(`Initiated S3 multipart upload for ${fileName}: ${result.uploadId}`)

logger.info(`Initiated multipart upload for ${fileName}: ${response.UploadId}`)
return NextResponse.json({
uploadId: result.uploadId,
key: result.key,
})
}
if (storageProvider === 'blob') {
const { initiateMultipartUpload } = await import('@/lib/uploads/blob/blob-client')

return NextResponse.json({
uploadId: response.UploadId,
key: uniqueKey,
})
const result = await initiateMultipartUpload({
fileName,
contentType,
fileSize,
customConfig: {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
},
})

logger.info(`Initiated Azure multipart upload for ${fileName}: ${result.uploadId}`)

return NextResponse.json({
uploadId: result.uploadId,
key: result.key,
})
}

return NextResponse.json(
{ error: `Unsupported storage provider: ${storageProvider}` },
{ status: 400 }
)
}

case 'get-part-urls': {
const data: GetPartUrlsRequest = await request.json()
const { uploadId, key, partNumbers } = data

const presignedUrls = await Promise.all(
partNumbers.map(async (partNumber) => {
const command = new UploadPartCommand({
Bucket: S3_KB_CONFIG.bucket,
Key: key,
PartNumber: partNumber,
UploadId: uploadId,
})
if (storageProvider === 's3') {
const { getS3MultipartPartUrls } = await import('@/lib/uploads/s3/s3-client')

const url = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
return { partNumber, url }
const presignedUrls = await getS3MultipartPartUrls(key, uploadId, partNumbers)

return NextResponse.json({ presignedUrls })
}
if (storageProvider === 'blob') {
const { getMultipartPartUrls } = await import('@/lib/uploads/blob/blob-client')

const presignedUrls = await getMultipartPartUrls(key, uploadId, partNumbers, {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
})
)

return NextResponse.json({ presignedUrls })
return NextResponse.json({ presignedUrls })
}

return NextResponse.json(
{ error: `Unsupported storage provider: ${storageProvider}` },
{ status: 400 }
)
}

case 'complete': {
const data: CompleteMultipartRequest = await request.json()
const data = await request.json()

// Handle batch completion
if ('uploads' in data) {
const results = await Promise.all(
data.uploads.map(async (upload: any) => {
const { uploadId, key } = upload

if (storageProvider === 's3') {
const { completeS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
const parts = upload.parts // S3 format: { ETag, PartNumber }

const result = await completeS3MultipartUpload(key, uploadId, parts)

return {
success: true,
location: result.location,
path: result.path,
key: result.key,
}
}
if (storageProvider === 'blob') {
const { completeMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
const parts = upload.parts // Azure format: { blockId, partNumber }

const result = await completeMultipartUpload(key, uploadId, parts, {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
})

return {
success: true,
location: result.location,
path: result.path,
key: result.key,
}
}

throw new Error(`Unsupported storage provider: ${storageProvider}`)
})
)

logger.info(`Completed ${data.uploads.length} multipart uploads`)
return NextResponse.json({ results })
}

// Handle single completion
const { uploadId, key, parts } = data

const command = new CompleteMultipartUploadCommand({
Bucket: S3_KB_CONFIG.bucket,
Key: key,
UploadId: uploadId,
MultipartUpload: {
Parts: parts.sort((a, b) => a.PartNumber - b.PartNumber),
},
})
if (storageProvider === 's3') {
const { completeS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')

const response = await s3Client.send(command)
const result = await completeS3MultipartUpload(key, uploadId, parts)

logger.info(`Completed multipart upload for key ${key}`)
logger.info(`Completed S3 multipart upload for key ${key}`)

const finalPath = `/api/files/serve/s3/${encodeURIComponent(key)}`
return NextResponse.json({
success: true,
location: result.location,
path: result.path,
key: result.key,
})
}
if (storageProvider === 'blob') {
const { completeMultipartUpload } = await import('@/lib/uploads/blob/blob-client')

return NextResponse.json({
success: true,
location: response.Location,
path: finalPath,
key,
})
const result = await completeMultipartUpload(key, uploadId, parts, {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
})

logger.info(`Completed Azure multipart upload for key ${key}`)

return NextResponse.json({
success: true,
location: result.location,
path: result.path,
key: result.key,
})
}

return NextResponse.json(
{ error: `Unsupported storage provider: ${storageProvider}` },
{ status: 400 }
)
}

case 'abort': {
const data = await request.json()
const { uploadId, key } = data

const command = new AbortMultipartUploadCommand({
Bucket: S3_KB_CONFIG.bucket,
Key: key,
UploadId: uploadId,
})
if (storageProvider === 's3') {
const { abortS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')

await s3Client.send(command)
await abortS3MultipartUpload(key, uploadId)

logger.info(`Aborted multipart upload for key ${key}`)
logger.info(`Aborted S3 multipart upload for key ${key}`)
} else if (storageProvider === 'blob') {
const { abortMultipartUpload } = await import('@/lib/uploads/blob/blob-client')

await abortMultipartUpload(key, uploadId, {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
})

logger.info(`Aborted Azure multipart upload for key ${key}`)
} else {
return NextResponse.json(
{ error: `Unsupported storage provider: ${storageProvider}` },
{ status: 400 }
)
}

return NextResponse.json({ success: true })
}

361 apps/sim/app/api/files/presigned/batch/route.ts Normal file
@@ -0,0 +1,361 @@
|
||||
import { PutObjectCommand } from '@aws-sdk/client-s3'
|
||||
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
|
||||
import {
|
||||
BLOB_CHAT_CONFIG,
|
||||
BLOB_CONFIG,
|
||||
BLOB_COPILOT_CONFIG,
|
||||
BLOB_KB_CONFIG,
|
||||
S3_CHAT_CONFIG,
|
||||
S3_CONFIG,
|
||||
S3_COPILOT_CONFIG,
|
||||
S3_KB_CONFIG,
|
||||
} from '@/lib/uploads/setup'
|
||||
import { validateFileType } from '@/lib/uploads/validation'
|
||||
import { createErrorResponse, createOptionsResponse } from '@/app/api/files/utils'
|
||||
|
||||
const logger = createLogger('BatchPresignedUploadAPI')
|
||||
|
||||
interface BatchFileRequest {
|
||||
fileName: string
|
||||
contentType: string
|
||||
fileSize: number
|
||||
}
|
||||
|
||||
interface BatchPresignedUrlRequest {
|
||||
files: BatchFileRequest[]
|
||||
}
|
||||
|
||||
type UploadType = 'general' | 'knowledge-base' | 'chat' | 'copilot'
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
let data: BatchPresignedUrlRequest
|
||||
try {
|
||||
data = await request.json()
|
||||
} catch {
|
||||
return NextResponse.json({ error: 'Invalid JSON in request body' }, { status: 400 })
|
||||
}
|
||||
|
||||
const { files } = data
|
||||
|
||||
if (!files || !Array.isArray(files) || files.length === 0) {
|
||||
return NextResponse.json(
|
||||
{ error: 'files array is required and cannot be empty' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (files.length > 100) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Cannot process more than 100 files at once' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const uploadTypeParam = request.nextUrl.searchParams.get('type')
|
||||
const uploadType: UploadType =
|
||||
uploadTypeParam === 'knowledge-base'
|
||||
? 'knowledge-base'
|
||||
: uploadTypeParam === 'chat'
|
||||
? 'chat'
|
||||
: uploadTypeParam === 'copilot'
|
||||
? 'copilot'
|
||||
: 'general'
|
||||
|
||||
const MAX_FILE_SIZE = 100 * 1024 * 1024
|
||||
for (const file of files) {
|
||||
if (!file.fileName?.trim()) {
|
||||
return NextResponse.json({ error: 'fileName is required for all files' }, { status: 400 })
|
||||
}
|
||||
if (!file.contentType?.trim()) {
|
||||
return NextResponse.json(
|
||||
{ error: 'contentType is required for all files' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
if (!file.fileSize || file.fileSize <= 0) {
|
||||
return NextResponse.json(
|
||||
{ error: 'fileSize must be positive for all files' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
if (file.fileSize > MAX_FILE_SIZE) {
|
||||
return NextResponse.json(
|
||||
{ error: `File ${file.fileName} exceeds maximum size of ${MAX_FILE_SIZE} bytes` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (uploadType === 'knowledge-base') {
|
||||
const fileValidationError = validateFileType(file.fileName, file.contentType)
|
||||
if (fileValidationError) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: fileValidationError.message,
|
||||
code: fileValidationError.code,
|
||||
supportedTypes: fileValidationError.supportedTypes,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const sessionUserId = session.user.id
|
||||
|
||||
if (uploadType === 'copilot' && !sessionUserId?.trim()) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Authenticated user session is required for copilot uploads' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (!isUsingCloudStorage()) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Direct uploads are only available when cloud storage is enabled' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const storageProvider = getStorageProvider()
|
||||
logger.info(
|
||||
`Generating batch ${uploadType} presigned URLs for ${files.length} files using ${storageProvider}`
|
||||
)
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
let result
|
||||
switch (storageProvider) {
|
||||
case 's3':
|
||||
result = await handleBatchS3PresignedUrls(files, uploadType, sessionUserId)
|
||||
break
|
||||
case 'blob':
|
||||
result = await handleBatchBlobPresignedUrls(files, uploadType, sessionUserId)
|
||||
break
|
||||
default:
|
||||
return NextResponse.json(
|
||||
{ error: `Unknown storage provider: ${storageProvider}` },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(
|
||||
`Generated ${files.length} presigned URLs in ${duration}ms (avg ${Math.round(duration / files.length)}ms per file)`
|
||||
)
|
||||
|
||||
return NextResponse.json(result)
|
||||
} catch (error) {
|
||||
logger.error('Error generating batch presigned URLs:', error)
|
||||
return createErrorResponse(
|
||||
error instanceof Error ? error : new Error('Failed to generate batch presigned URLs')
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async function handleBatchS3PresignedUrls(
|
||||
files: BatchFileRequest[],
|
||||
uploadType: UploadType,
|
||||
userId?: string
|
||||
) {
|
||||
const config =
|
||||
uploadType === 'knowledge-base'
|
||||
? S3_KB_CONFIG
|
||||
: uploadType === 'chat'
|
||||
? S3_CHAT_CONFIG
|
||||
: uploadType === 'copilot'
|
||||
? S3_COPILOT_CONFIG
|
||||
: S3_CONFIG
|
||||
|
||||
if (!config.bucket || !config.region) {
|
||||
throw new Error(`S3 configuration missing for ${uploadType} uploads`)
|
||||
}
|
||||
|
||||
const { getS3Client, sanitizeFilenameForMetadata } = await import('@/lib/uploads/s3/s3-client')
|
||||
const s3Client = getS3Client()
|
||||
|
||||
let prefix = ''
|
||||
if (uploadType === 'knowledge-base') {
|
||||
prefix = 'kb/'
|
||||
} else if (uploadType === 'chat') {
|
||||
prefix = 'chat/'
|
||||
} else if (uploadType === 'copilot') {
|
||||
prefix = `${userId}/`
|
||||
}
|
||||
|
||||
const baseMetadata: Record<string, string> = {
|
||||
uploadedAt: new Date().toISOString(),
|
||||
}
|
||||
|
||||
if (uploadType === 'knowledge-base') {
|
||||
baseMetadata.purpose = 'knowledge-base'
|
||||
} else if (uploadType === 'chat') {
|
||||
baseMetadata.purpose = 'chat'
|
||||
} else if (uploadType === 'copilot') {
|
||||
baseMetadata.purpose = 'copilot'
|
||||
baseMetadata.userId = userId || ''
|
||||
}
|
||||
|
||||
const results = await Promise.all(
|
||||
files.map(async (file) => {
|
||||
const safeFileName = file.fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
|
||||
const uniqueKey = `${prefix}${uuidv4()}-${safeFileName}`
|
||||
const sanitizedOriginalName = sanitizeFilenameForMetadata(file.fileName)
|
||||
|
||||
const metadata = {
|
||||
...baseMetadata,
|
||||
originalName: sanitizedOriginalName,
|
||||
}
|
||||
|
||||
const command = new PutObjectCommand({
|
||||
Bucket: config.bucket,
|
||||
Key: uniqueKey,
|
||||
ContentType: file.contentType,
|
||||
Metadata: metadata,
|
||||
})
|
||||
|
||||
const presignedUrl = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
|
||||
|
||||
const finalPath =
|
||||
uploadType === 'chat'
|
||||
? `https://${config.bucket}.s3.${config.region}.amazonaws.com/${uniqueKey}`
|
||||
: `/api/files/serve/s3/${encodeURIComponent(uniqueKey)}`
|
||||
|
||||
return {
|
||||
fileName: file.fileName,
|
||||
presignedUrl,
|
||||
fileInfo: {
|
||||
path: finalPath,
|
||||
key: uniqueKey,
|
||||
name: file.fileName,
|
||||
size: file.fileSize,
|
||||
type: file.contentType,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return {
|
||||
files: results,
|
||||
directUploadSupported: true,
|
||||
}
|
||||
}
|
||||
|
||||
async function handleBatchBlobPresignedUrls(
|
||||
files: BatchFileRequest[],
|
||||
uploadType: UploadType,
|
||||
userId?: string
|
||||
) {
|
||||
const config =
|
||||
uploadType === 'knowledge-base'
|
||||
? BLOB_KB_CONFIG
|
||||
: uploadType === 'chat'
|
||||
? BLOB_CHAT_CONFIG
|
||||
: uploadType === 'copilot'
|
||||
? BLOB_COPILOT_CONFIG
|
||||
: BLOB_CONFIG
|
||||
|
||||
if (
|
||||
!config.accountName ||
|
||||
!config.containerName ||
|
||||
(!config.accountKey && !config.connectionString)
|
||||
) {
|
||||
throw new Error(`Azure Blob configuration missing for ${uploadType} uploads`)
|
||||
}
|
||||
|
||||
const { getBlobServiceClient } = await import('@/lib/uploads/blob/blob-client')
|
||||
const { BlobSASPermissions, generateBlobSASQueryParameters, StorageSharedKeyCredential } =
|
||||
await import('@azure/storage-blob')
|
||||
|
||||
const blobServiceClient = getBlobServiceClient()
|
||||
const containerClient = blobServiceClient.getContainerClient(config.containerName)
|
||||
|
||||
let prefix = ''
|
||||
if (uploadType === 'knowledge-base') {
|
||||
prefix = 'kb/'
|
||||
} else if (uploadType === 'chat') {
|
||||
prefix = 'chat/'
|
||||
} else if (uploadType === 'copilot') {
|
||||
prefix = `${userId}/`
|
||||
}
|
||||
|
||||
const baseUploadHeaders: Record<string, string> = {
|
||||
'x-ms-blob-type': 'BlockBlob',
|
||||
'x-ms-meta-uploadedat': new Date().toISOString(),
|
||||
}
|
||||
|
||||
if (uploadType === 'knowledge-base') {
|
||||
baseUploadHeaders['x-ms-meta-purpose'] = 'knowledge-base'
|
||||
} else if (uploadType === 'chat') {
|
||||
baseUploadHeaders['x-ms-meta-purpose'] = 'chat'
|
||||
} else if (uploadType === 'copilot') {
|
||||
baseUploadHeaders['x-ms-meta-purpose'] = 'copilot'
|
||||
baseUploadHeaders['x-ms-meta-userid'] = encodeURIComponent(userId || '')
|
||||
}
|
||||
|
||||
const results = await Promise.all(
|
||||
files.map(async (file) => {
|
||||
const safeFileName = file.fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
|
||||
const uniqueKey = `${prefix}${uuidv4()}-${safeFileName}`
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(uniqueKey)
|
||||
|
||||
const sasOptions = {
|
||||
containerName: config.containerName,
|
||||
blobName: uniqueKey,
|
||||
permissions: BlobSASPermissions.parse('w'),
|
||||
startsOn: new Date(),
|
||||
expiresOn: new Date(Date.now() + 3600 * 1000),
|
||||
}
|
||||
|
||||
const sasToken = generateBlobSASQueryParameters(
|
||||
sasOptions,
|
||||
new StorageSharedKeyCredential(config.accountName, config.accountKey || '')
|
||||
).toString()
|
||||
|
||||
const presignedUrl = `${blockBlobClient.url}?${sasToken}`
|
||||
|
||||
const finalPath =
|
||||
uploadType === 'chat'
|
||||
? blockBlobClient.url
|
||||
: `/api/files/serve/blob/${encodeURIComponent(uniqueKey)}`
|
||||
|
||||
const uploadHeaders = {
|
||||
...baseUploadHeaders,
|
||||
'x-ms-blob-content-type': file.contentType,
|
||||
'x-ms-meta-originalname': encodeURIComponent(file.fileName),
|
||||
}
|
||||
|
||||
return {
|
||||
fileName: file.fileName,
|
||||
presignedUrl,
|
||||
fileInfo: {
|
||||
path: finalPath,
|
||||
key: uniqueKey,
|
||||
name: file.fileName,
|
||||
size: file.fileSize,
|
||||
type: file.contentType,
|
||||
},
|
||||
uploadHeaders,
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return {
|
||||
files: results,
|
||||
directUploadSupported: true,
|
||||
}
|
||||
}
|
||||
|
||||
export async function OPTIONS() {
|
||||
return createOptionsResponse()
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import {
S3_COPILOT_CONFIG,
S3_KB_CONFIG,
} from '@/lib/uploads/setup'
import { validateFileType } from '@/lib/uploads/validation'
import { createErrorResponse, createOptionsResponse } from '@/app/api/files/utils'

const logger = createLogger('PresignedUploadAPI')
@@ -96,6 +97,13 @@ export async function POST(request: NextRequest) {
? 'copilot'
: 'general'

if (uploadType === 'knowledge-base') {
const fileValidationError = validateFileType(fileName, contentType)
if (fileValidationError) {
throw new ValidationError(`${fileValidationError.message}`)
}
}

// Evaluate user id from session for copilot uploads
const sessionUserId = session.user.id

@@ -1,12 +1,10 @@
import { createHash, randomUUID } from 'crypto'
import { eq, sql } from 'drizzle-orm'
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { deleteChunk, updateChunk } from '@/lib/knowledge/chunks/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkChunkAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, embedding } from '@/db/schema'

const logger = createLogger('ChunkByIdAPI')

@@ -102,33 +100,7 @@ export async function PUT(
try {
const validatedData = UpdateChunkSchema.parse(body)

const updateData: Partial<{
content: string
contentLength: number
tokenCount: number
chunkHash: string
enabled: boolean
updatedAt: Date
}> = {}

if (validatedData.content) {
updateData.content = validatedData.content
updateData.contentLength = validatedData.content.length
// Update token count estimation (rough approximation: 4 chars per token)
updateData.tokenCount = Math.ceil(validatedData.content.length / 4)
updateData.chunkHash = createHash('sha256').update(validatedData.content).digest('hex')
}

if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled

await db.update(embedding).set(updateData).where(eq(embedding.id, chunkId))

// Fetch the updated chunk
const updatedChunk = await db
.select()
.from(embedding)
.where(eq(embedding.id, chunkId))
.limit(1)
const updatedChunk = await updateChunk(chunkId, validatedData, requestId)

logger.info(
`[${requestId}] Chunk updated: ${chunkId} in document ${documentId} in knowledge base ${knowledgeBaseId}`
@@ -136,7 +108,7 @@ export async function PUT(

return NextResponse.json({
success: true,
data: updatedChunk[0],
data: updatedChunk,
})
} catch (validationError) {
if (validationError instanceof z.ZodError) {
@@ -190,37 +162,7 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}

// Use transaction to atomically delete chunk and update document statistics
await db.transaction(async (tx) => {
// Get chunk data before deletion for statistics update
const chunkToDelete = await tx
.select({
tokenCount: embedding.tokenCount,
contentLength: embedding.contentLength,
})
.from(embedding)
.where(eq(embedding.id, chunkId))
.limit(1)

if (chunkToDelete.length === 0) {
throw new Error('Chunk not found')
}

const chunk = chunkToDelete[0]

// Delete the chunk
await tx.delete(embedding).where(eq(embedding.id, chunkId))

// Update document statistics
await tx
.update(document)
.set({
chunkCount: sql`${document.chunkCount} - 1`,
tokenCount: sql`${document.tokenCount} - ${chunk.tokenCount}`,
characterCount: sql`${document.characterCount} - ${chunk.contentLength}`,
})
.where(eq(document.id, documentId))
})
await deleteChunk(chunkId, documentId, requestId)

logger.info(
`[${requestId}] Chunk deleted: ${chunkId} from document ${documentId} in knowledge base ${knowledgeBaseId}`

@@ -1,378 +0,0 @@
|
||||
/**
|
||||
* Tests for knowledge document chunks API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockAuth,
|
||||
mockConsoleLogger,
|
||||
mockDrizzleOrm,
|
||||
mockKnowledgeSchemas,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
mockKnowledgeSchemas()
|
||||
mockDrizzleOrm()
|
||||
mockConsoleLogger()
|
||||
|
||||
vi.mock('@/lib/tokenization/estimators', () => ({
|
||||
estimateTokenCount: vi.fn().mockReturnValue({ count: 452 }),
|
||||
}))
|
||||
|
||||
vi.mock('@/providers/utils', () => ({
|
||||
calculateCost: vi.fn().mockReturnValue({
|
||||
input: 0.00000904,
|
||||
output: 0,
|
||||
total: 0.00000904,
|
||||
pricing: {
|
||||
input: 0.02,
|
||||
output: 0,
|
||||
updatedAt: '2025-07-10',
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
checkKnowledgeBaseAccess: vi.fn(),
|
||||
checkKnowledgeBaseWriteAccess: vi.fn(),
|
||||
checkDocumentAccess: vi.fn(),
|
||||
checkDocumentWriteAccess: vi.fn(),
|
||||
checkChunkAccess: vi.fn(),
|
||||
generateEmbeddings: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3, 0.4, 0.5]]),
|
||||
processDocumentAsync: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('Knowledge Document Chunks API Route', () => {
|
||||
const mockAuth$ = mockAuth()
|
||||
|
||||
const mockDbChain = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockReturnThis(),
|
||||
offset: vi.fn().mockReturnThis(),
|
||||
insert: vi.fn().mockReturnThis(),
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
returning: vi.fn().mockResolvedValue([]),
|
||||
delete: vi.fn().mockReturnThis(),
|
||||
transaction: vi.fn(),
|
||||
}
|
||||
|
||||
const mockGetUserId = vi.fn()
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
vi.doMock('@/db', () => ({
|
||||
db: mockDbChain,
|
||||
}))
|
||||
|
||||
vi.doMock('@/app/api/auth/oauth/utils', () => ({
|
||||
getUserId: mockGetUserId,
|
||||
}))
|
||||
|
||||
Object.values(mockDbChain).forEach((fn) => {
|
||||
if (typeof fn === 'function' && fn !== mockDbChain.values && fn !== mockDbChain.returning) {
|
||||
fn.mockClear().mockReturnThis()
|
||||
}
|
||||
})
|
||||
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue('mock-chunk-uuid-1234'),
|
||||
createHash: vi.fn().mockReturnValue({
|
||||
update: vi.fn().mockReturnThis(),
|
||||
digest: vi.fn().mockReturnValue('mock-hash-123'),
|
||||
}),
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('POST /api/knowledge/[id]/documents/[documentId]/chunks', () => {
|
||||
const validChunkData = {
|
||||
content: 'This is test chunk content for uploading to the knowledge base document.',
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
const mockDocumentAccess = {
|
||||
hasAccess: true,
|
||||
notFound: false,
|
||||
reason: '',
|
||||
document: {
|
||||
id: 'doc-123',
|
||||
processingStatus: 'completed',
|
||||
tag1: 'tag1-value',
|
||||
tag2: 'tag2-value',
|
||||
tag3: null,
|
||||
tag4: null,
|
||||
tag5: null,
|
||||
tag6: null,
|
||||
tag7: null,
|
||||
},
|
||||
}
|
||||
|
||||
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
|
||||
|
||||
it('should create chunk successfully with cost tracking', async () => {
|
||||
const { checkDocumentWriteAccess, generateEmbeddings } = await import(
|
||||
'@/app/api/knowledge/utils'
|
||||
)
|
||||
const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
|
||||
const { calculateCost } = await import('@/providers/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
// Mock generateEmbeddings
|
||||
vi.mocked(generateEmbeddings).mockResolvedValue([[0.1, 0.2, 0.3]])
|
||||
|
||||
// Mock transaction
|
||||
const mockTx = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockResolvedValue([{ chunkIndex: 0 }]),
|
||||
insert: vi.fn().mockReturnThis(),
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
}
|
||||
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
return await callback(mockTx)
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
|
||||
// Verify cost tracking
|
||||
expect(data.data.cost).toBeDefined()
|
||||
expect(data.data.cost.input).toBe(0.00000904)
|
||||
expect(data.data.cost.output).toBe(0)
|
||||
expect(data.data.cost.total).toBe(0.00000904)
|
||||
expect(data.data.cost.tokens).toEqual({
|
||||
prompt: 452,
|
||||
completion: 0,
|
||||
total: 452,
|
||||
})
|
||||
expect(data.data.cost.model).toBe('text-embedding-3-small')
|
||||
expect(data.data.cost.pricing).toEqual({
|
||||
input: 0.02,
|
||||
output: 0,
|
||||
updatedAt: '2025-07-10',
|
||||
})
|
||||
|
||||
// Verify function calls
|
||||
expect(estimateTokenCount).toHaveBeenCalledWith(validChunkData.content, 'openai')
|
||||
expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 452, 0, false)
|
||||
})
|
||||
|
||||
it('should handle workflow-based authentication', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
const workflowData = {
|
||||
...validChunkData,
|
||||
workflowId: 'workflow-123',
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
const mockTx = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
insert: vi.fn().mockReturnThis(),
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
}
|
||||
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
return await callback(mockTx)
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', workflowData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123')
|
||||
})
|
||||
|
||||
it.concurrent('should return unauthorized for unauthenticated request', async () => {
|
||||
mockGetUserId.mockResolvedValue(null)
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
expect(data.error).toBe('Unauthorized')
|
||||
})
|
||||
|
||||
it('should return not found for workflow that does not exist', async () => {
|
||||
const workflowData = {
|
||||
...validChunkData,
|
||||
workflowId: 'nonexistent-workflow',
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue(null)
|
||||
|
||||
const req = createMockRequest('POST', workflowData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data.error).toBe('Workflow not found')
|
||||
})
|
||||
|
||||
it.concurrent('should return not found for document access denied', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
reason: 'Document not found',
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data.error).toBe('Document not found')
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthorized document access', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
hasAccess: false,
|
||||
notFound: false,
|
||||
reason: 'Unauthorized access',
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
expect(data.error).toBe('Unauthorized')
|
||||
})
|
||||
|
||||
it('should reject chunks for failed documents', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
document: {
|
||||
...mockDocumentAccess.document!,
|
||||
processingStatus: 'failed',
|
||||
},
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data.error).toBe('Cannot add chunks to failed document')
|
||||
})
|
||||
|
||||
it.concurrent('should validate chunk data', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
const invalidData = {
|
||||
content: '', // Empty content
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
const req = createMockRequest('POST', invalidData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data.error).toBe('Invalid request data')
|
||||
expect(data.details).toBeDefined()
|
||||
})
|
||||
|
||||
it('should inherit tags from parent document', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
const mockTx = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
insert: vi.fn().mockReturnThis(),
|
||||
values: vi.fn().mockImplementation((data) => {
|
||||
// Verify that tags are inherited from document
|
||||
expect(data.tag1).toBe('tag1-value')
|
||||
expect(data.tag2).toBe('tag2-value')
|
||||
expect(data.tag3).toBe(null)
|
||||
return Promise.resolve(undefined)
|
||||
}),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
}
|
||||
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
return await callback(mockTx)
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
await POST(req, { params: mockParams })
|
||||
|
||||
expect(mockTx.values).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// REMOVED: "should handle cost calculation with different content lengths" test - it was failing
|
||||
})
|
||||
})
|
||||
@@ -1,18 +1,11 @@
|
||||
import crypto from 'crypto'
|
||||
import { and, asc, eq, ilike, inArray, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { batchChunkOperation, createChunk, queryChunks } from '@/lib/knowledge/chunks/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { estimateTokenCount } from '@/lib/tokenization/estimators'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import {
|
||||
checkDocumentAccess,
|
||||
checkDocumentWriteAccess,
|
||||
generateEmbeddings,
|
||||
} from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding } from '@/db/schema'
|
||||
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
|
||||
import { calculateCost } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('DocumentChunksAPI')
|
||||
@@ -66,7 +59,6 @@ export async function GET(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if document processing is completed
|
||||
const doc = accessCheck.document
|
||||
if (!doc) {
|
||||
logger.warn(
|
||||
@@ -89,7 +81,6 @@ export async function GET(
|
||||
)
|
||||
}
|
||||
|
||||
// Parse query parameters
|
||||
const { searchParams } = new URL(req.url)
|
||||
const queryParams = GetChunksQuerySchema.parse({
|
||||
search: searchParams.get('search') || undefined,
|
||||
@@ -98,67 +89,12 @@ export async function GET(
|
||||
offset: searchParams.get('offset') || undefined,
|
||||
})
|
||||
|
||||
// Build query conditions
|
||||
const conditions = [eq(embedding.documentId, documentId)]
|
||||
|
||||
// Add enabled filter
|
||||
if (queryParams.enabled === 'true') {
|
||||
conditions.push(eq(embedding.enabled, true))
|
||||
} else if (queryParams.enabled === 'false') {
|
||||
conditions.push(eq(embedding.enabled, false))
|
||||
}
|
||||
|
||||
// Add search filter
|
||||
if (queryParams.search) {
|
||||
conditions.push(ilike(embedding.content, `%${queryParams.search}%`))
|
||||
}
|
||||
|
||||
// Fetch chunks
|
||||
const chunks = await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
content: embedding.content,
|
||||
contentLength: embedding.contentLength,
|
||||
tokenCount: embedding.tokenCount,
|
||||
enabled: embedding.enabled,
|
||||
startOffset: embedding.startOffset,
|
||||
endOffset: embedding.endOffset,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
createdAt: embedding.createdAt,
|
||||
updatedAt: embedding.updatedAt,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(and(...conditions))
|
||||
.orderBy(asc(embedding.chunkIndex))
|
||||
.limit(queryParams.limit)
|
||||
.offset(queryParams.offset)
|
||||
|
||||
// Get total count for pagination
|
||||
const totalCount = await db
|
||||
.select({ count: sql`count(*)` })
|
||||
.from(embedding)
|
||||
.where(and(...conditions))
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Retrieved ${chunks.length} chunks for document ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
const result = await queryChunks(documentId, queryParams, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: chunks,
|
||||
pagination: {
|
||||
total: Number(totalCount[0]?.count || 0),
|
||||
limit: queryParams.limit,
|
||||
offset: queryParams.offset,
|
||||
hasMore: chunks.length === queryParams.limit,
|
||||
},
|
||||
data: result.chunks,
|
||||
pagination: result.pagination,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error fetching chunks`, error)
|
||||
@@ -219,76 +155,27 @@ export async function POST(
|
||||
try {
|
||||
const validatedData = CreateChunkSchema.parse(searchParams)
|
||||
|
||||
// Generate embedding for the content first (outside transaction for performance)
|
||||
logger.info(`[${requestId}] Generating embedding for manual chunk`)
|
||||
const embeddings = await generateEmbeddings([validatedData.content])
|
||||
const docTags = {
|
||||
tag1: doc.tag1 ?? null,
|
||||
tag2: doc.tag2 ?? null,
|
||||
tag3: doc.tag3 ?? null,
|
||||
tag4: doc.tag4 ?? null,
|
||||
tag5: doc.tag5 ?? null,
|
||||
tag6: doc.tag6 ?? null,
|
||||
tag7: doc.tag7 ?? null,
|
||||
}
|
||||
|
||||
// Calculate accurate token count for both database storage and cost calculation
|
||||
const tokenCount = estimateTokenCount(validatedData.content, 'openai')
|
||||
const newChunk = await createChunk(
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
docTags,
|
||||
validatedData,
|
||||
requestId
|
||||
)
|
||||
|
||||
const chunkId = crypto.randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
// Use transaction to atomically get next index and insert chunk
|
||||
const newChunk = await db.transaction(async (tx) => {
|
||||
// Get the next chunk index atomically within the transaction
|
||||
const lastChunk = await tx
|
||||
.select({ chunkIndex: embedding.chunkIndex })
|
||||
.from(embedding)
|
||||
.where(eq(embedding.documentId, documentId))
|
||||
.orderBy(sql`${embedding.chunkIndex} DESC`)
|
||||
.limit(1)
|
||||
|
||||
const nextChunkIndex = lastChunk.length > 0 ? lastChunk[0].chunkIndex + 1 : 0
|
||||
|
||||
const chunkData = {
|
||||
id: chunkId,
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
chunkIndex: nextChunkIndex,
|
||||
chunkHash: crypto.createHash('sha256').update(validatedData.content).digest('hex'),
|
||||
content: validatedData.content,
|
||||
contentLength: validatedData.content.length,
|
||||
tokenCount: tokenCount.count, // Use accurate token count
|
||||
embedding: embeddings[0],
|
||||
embeddingModel: 'text-embedding-3-small',
|
||||
startOffset: 0, // Manual chunks don't have document offsets
|
||||
endOffset: validatedData.content.length,
|
||||
// Inherit tags from parent document
|
||||
tag1: doc.tag1,
|
||||
tag2: doc.tag2,
|
||||
tag3: doc.tag3,
|
||||
tag4: doc.tag4,
|
||||
tag5: doc.tag5,
|
||||
tag6: doc.tag6,
|
||||
tag7: doc.tag7,
|
||||
enabled: validatedData.enabled,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
// Insert the new chunk
|
||||
await tx.insert(embedding).values(chunkData)
|
||||
|
||||
// Update document statistics
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: sql`${document.chunkCount} + 1`,
|
||||
tokenCount: sql`${document.tokenCount} + ${chunkData.tokenCount}`,
|
||||
characterCount: sql`${document.characterCount} + ${chunkData.contentLength}`,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
|
||||
return chunkData
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Manual chunk created: ${chunkId} in document ${documentId}`)
|
||||
|
||||
// Calculate cost for the embedding (with fallback if calculation fails)
|
||||
let cost = null
|
||||
try {
|
||||
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
|
||||
cost = calculateCost('text-embedding-3-small', newChunk.tokenCount, 0, false)
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
@@ -307,9 +194,9 @@ export async function POST(
|
||||
output: cost.output,
|
||||
total: cost.total,
|
||||
tokens: {
|
||||
prompt: tokenCount.count,
|
||||
prompt: newChunk.tokenCount,
|
||||
completion: 0,
|
||||
total: tokenCount.count,
|
||||
total: newChunk.tokenCount,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
pricing: cost.pricing,
|
||||
@@ -371,92 +258,16 @@ export async function PATCH(
|
||||
const validatedData = BatchOperationSchema.parse(body)
|
||||
const { operation, chunkIds } = validatedData
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Starting batch ${operation} operation on ${chunkIds.length} chunks for document ${documentId}`
|
||||
)
|
||||
|
||||
const results = []
|
||||
let successCount = 0
|
||||
const errorCount = 0
|
||||
|
||||
if (operation === 'delete') {
|
||||
// Handle batch delete with transaction for consistency
|
||||
await db.transaction(async (tx) => {
|
||||
// Get chunks to delete for statistics update
|
||||
const chunksToDelete = await tx
|
||||
.select({
|
||||
id: embedding.id,
|
||||
tokenCount: embedding.tokenCount,
|
||||
contentLength: embedding.contentLength,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
|
||||
if (chunksToDelete.length === 0) {
|
||||
throw new Error('No valid chunks found to delete')
|
||||
}
|
||||
|
||||
// Delete chunks
|
||||
await tx
|
||||
.delete(embedding)
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
|
||||
// Update document statistics
|
||||
const totalTokens = chunksToDelete.reduce((sum, chunk) => sum + chunk.tokenCount, 0)
|
||||
const totalCharacters = chunksToDelete.reduce(
|
||||
(sum, chunk) => sum + chunk.contentLength,
|
||||
0
|
||||
)
|
||||
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: sql`${document.chunkCount} - ${chunksToDelete.length}`,
|
||||
tokenCount: sql`${document.tokenCount} - ${totalTokens}`,
|
||||
characterCount: sql`${document.characterCount} - ${totalCharacters}`,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
|
||||
successCount = chunksToDelete.length
|
||||
results.push({
|
||||
operation: 'delete',
|
||||
deletedCount: chunksToDelete.length,
|
||||
chunkIds: chunksToDelete.map((c) => c.id),
|
||||
})
|
||||
})
|
||||
} else {
|
||||
// Handle batch enable/disable
|
||||
const enabled = operation === 'enable'
|
||||
|
||||
// Update chunks in a single query
|
||||
const updateResult = await db
|
||||
.update(embedding)
|
||||
.set({
|
||||
enabled,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
.returning({ id: embedding.id })
|
||||
|
||||
successCount = updateResult.length
|
||||
results.push({
|
||||
operation,
|
||||
updatedCount: updateResult.length,
|
||||
chunkIds: updateResult.map((r) => r.id),
|
||||
})
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Batch ${operation} operation completed: ${successCount} successful, ${errorCount} errors`
|
||||
)
|
||||
const result = await batchChunkOperation(documentId, operation, chunkIds, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
operation,
|
||||
-        successCount,
-        errorCount,
-        results,
+        successCount: result.processed,
+        errorCount: result.errors.length,
+        processed: result.processed,
+        errors: result.errors,
|
||||
},
|
||||
})
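The batch PATCH handler now hands the heavy lifting to `batchChunkOperation(documentId, operation, chunkIds, requestId)` and only unpacks `processed` and `errors` from its result. The service itself is not shown in this hunk; the sketch below reconstructs a plausible shape from the call site and from the inline transaction logic the route used to contain (the error format and early-return behaviour are assumptions):

```ts
import { and, eq, inArray, sql } from 'drizzle-orm'
import { db } from '@/db'
import { document, embedding } from '@/db/schema'

export type BatchChunkOp = 'enable' | 'disable' | 'delete'

export interface BatchChunkResult {
  processed: number
  errors: string[]
}

// Sketch of a chunk-batch service with the signature the route above uses.
// requestId is accepted for log correlation, mirroring the route's call.
export async function batchChunkOperation(
  documentId: string,
  operation: BatchChunkOp,
  chunkIds: string[],
  requestId: string
): Promise<BatchChunkResult> {
  if (operation === 'delete') {
    return db.transaction(async (tx) => {
      // Read the chunks first so the parent document's statistics can be adjusted.
      const toDelete = await tx
        .select({
          id: embedding.id,
          tokenCount: embedding.tokenCount,
          contentLength: embedding.contentLength,
        })
        .from(embedding)
        .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))

      if (toDelete.length === 0) {
        return { processed: 0, errors: ['No valid chunks found to delete'] }
      }

      await tx
        .delete(embedding)
        .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))

      const tokens = toDelete.reduce((sum, c) => sum + c.tokenCount, 0)
      const chars = toDelete.reduce((sum, c) => sum + c.contentLength, 0)

      // Keep the document's aggregate counters in sync with the deleted chunks.
      await tx
        .update(document)
        .set({
          chunkCount: sql`${document.chunkCount} - ${toDelete.length}`,
          tokenCount: sql`${document.tokenCount} - ${tokens}`,
          characterCount: sql`${document.characterCount} - ${chars}`,
        })
        .where(eq(document.id, documentId))

      return { processed: toDelete.length, errors: [] }
    })
  }

  // enable / disable: a single bulk update on the selected chunks.
  const updated = await db
    .update(embedding)
    .set({ enabled: operation === 'enable', updatedAt: new Date() })
    .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
    .returning({ id: embedding.id })

  return { processed: updated.length, errors: [] }
}
```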
|
||||
} catch (validationError) {
|
||||
|
||||
@@ -24,7 +24,14 @@ vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
processDocumentAsync: vi.fn(),
|
||||
}))
|
||||
|
||||
// Setup common mocks
|
||||
vi.mock('@/lib/knowledge/documents/service', () => ({
|
||||
updateDocument: vi.fn(),
|
||||
deleteDocument: vi.fn(),
|
||||
markDocumentAsFailedTimeout: vi.fn(),
|
||||
retryDocumentProcessing: vi.fn(),
|
||||
processDocumentAsync: vi.fn(),
|
||||
}))
|
||||
|
||||
mockDrizzleOrm()
|
||||
mockConsoleLogger()
|
||||
|
||||
@@ -42,8 +49,6 @@ describe('Document By ID API Route', () => {
|
||||
transaction: vi.fn(),
|
||||
}
|
||||
|
||||
// Mock functions will be imported dynamically in tests
|
||||
|
||||
const mockDocument = {
|
||||
id: 'doc-123',
|
||||
knowledgeBaseId: 'kb-123',
|
||||
@@ -73,7 +78,6 @@ describe('Document By ID API Route', () => {
|
||||
}
|
||||
}
|
||||
})
|
||||
// Mock functions are cleared automatically by vitest
|
||||
}
|
||||
|
||||
beforeEach(async () => {
|
||||
@@ -83,8 +87,6 @@ describe('Document By ID API Route', () => {
|
||||
db: mockDbChain,
|
||||
}))
|
||||
|
||||
// Utils are mocked at the top level
|
||||
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
|
||||
})
|
||||
@@ -195,6 +197,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should update document successfully', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { updateDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
@@ -203,31 +206,12 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Create a sequence of mocks for the database operations
|
||||
const updateChain = {
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue(undefined), // Update operation completes
|
||||
}),
|
||||
const updatedDocument = {
|
||||
...mockDocument,
|
||||
...validUpdateData,
|
||||
deletedAt: null,
|
||||
}
|
||||
|
||||
const selectChain = {
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([{ ...mockDocument, ...validUpdateData }]),
|
||||
}),
|
||||
}),
|
||||
}
|
||||
|
||||
// Mock transaction
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
update: vi.fn().mockReturnValue(updateChain),
|
||||
}
|
||||
await callback(mockTx)
|
||||
})
|
||||
|
||||
// Mock db operations in sequence
|
||||
mockDbChain.select.mockReturnValue(selectChain)
|
||||
vi.mocked(updateDocument).mockResolvedValue(updatedDocument)
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
@@ -238,8 +222,11 @@ describe('Document By ID API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.filename).toBe('updated-document.pdf')
|
||||
expect(data.data.enabled).toBe(false)
|
||||
expect(mockDbChain.transaction).toHaveBeenCalled()
|
||||
expect(mockDbChain.select).toHaveBeenCalled()
|
||||
expect(vi.mocked(updateDocument)).toHaveBeenCalledWith(
|
||||
'doc-123',
|
||||
validUpdateData,
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should validate update data', async () => {
|
||||
@@ -274,6 +261,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should mark document as failed due to timeout successfully', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { markDocumentAsFailedTimeout } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
const processingDocument = {
|
||||
...mockDocument,
|
||||
@@ -288,34 +276,11 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Create a sequence of mocks for the database operations
|
||||
const updateChain = {
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue(undefined), // Update operation completes
|
||||
}),
|
||||
}
|
||||
|
||||
const selectChain = {
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi
|
||||
.fn()
|
||||
.mockResolvedValue([{ ...processingDocument, processingStatus: 'failed' }]),
|
||||
}),
|
||||
}),
|
||||
}
|
||||
|
||||
// Mock transaction
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
update: vi.fn().mockReturnValue(updateChain),
|
||||
}
|
||||
await callback(mockTx)
|
||||
vi.mocked(markDocumentAsFailedTimeout).mockResolvedValue({
|
||||
success: true,
|
||||
processingDuration: 200000,
|
||||
})
|
||||
|
||||
// Mock db operations in sequence
|
||||
mockDbChain.select.mockReturnValue(selectChain)
|
||||
|
||||
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
const response = await PUT(req, { params: mockParams })
|
||||
@@ -323,13 +288,13 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(mockDbChain.transaction).toHaveBeenCalled()
|
||||
expect(updateChain.set).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
processingStatus: 'failed',
|
||||
processingError: 'Processing timed out - background process may have been terminated',
|
||||
processingCompletedAt: expect.any(Date),
|
||||
})
|
||||
expect(data.data.documentId).toBe('doc-123')
|
||||
expect(data.data.status).toBe('failed')
|
||||
expect(data.data.message).toBe('Document marked as failed due to timeout')
|
||||
expect(vi.mocked(markDocumentAsFailedTimeout)).toHaveBeenCalledWith(
|
||||
'doc-123',
|
||||
processingDocument.processingStartedAt,
|
||||
expect.any(String)
|
||||
)
|
||||
})
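The `markDocumentAsFailedTimeout` service mocked in this test replaces the inline timeout handling the route used to perform (visible further down in this diff). A sketch of what the test implies, with the 150-second threshold, error strings, and updated fields taken from that removed code and everything else assumed:

```ts
import { eq } from 'drizzle-orm'
import { db } from '@/db'
import { document } from '@/db/schema'

// Threshold below which a "processing" document is not yet considered dead.
const DEAD_PROCESS_THRESHOLD_MS = 150 * 1000

// requestId is kept for log correlation, matching how the route calls the service.
export async function markDocumentAsFailedTimeout(
  documentId: string,
  processingStartedAt: Date | string,
  requestId: string
): Promise<{ success: boolean; processingDuration: number }> {
  const processingDuration = Date.now() - new Date(processingStartedAt).getTime()

  if (processingDuration <= DEAD_PROCESS_THRESHOLD_MS) {
    throw new Error('Document has not been processing long enough to be considered dead')
  }

  await db
    .update(document)
    .set({
      processingStatus: 'failed',
      processingError: 'Processing timed out - background process may have been terminated',
      processingCompletedAt: new Date(),
    })
    .where(eq(document.id, documentId))

  return { success: true, processingDuration }
}
```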
|
||||
|
||||
@@ -354,6 +319,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should reject marking failed for recently started processing', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { markDocumentAsFailedTimeout } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
const recentProcessingDocument = {
|
||||
...mockDocument,
|
||||
@@ -368,6 +334,10 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
vi.mocked(markDocumentAsFailedTimeout).mockRejectedValue(
|
||||
new Error('Document has not been processing long enough to be considered dead')
|
||||
)
|
||||
|
||||
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
const response = await PUT(req, { params: mockParams })
|
||||
@@ -382,9 +352,8 @@ describe('Document By ID API Route', () => {
|
||||
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
|
||||
|
||||
it('should retry processing successfully', async () => {
|
||||
const { checkDocumentWriteAccess, processDocumentAsync } = await import(
|
||||
'@/app/api/knowledge/utils'
|
||||
)
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { retryDocumentProcessing } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
const failedDocument = {
|
||||
...mockDocument,
|
||||
@@ -399,23 +368,12 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock transaction
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
delete: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
update: vi.fn().mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
}),
|
||||
}
|
||||
return await callback(mockTx)
|
||||
vi.mocked(retryDocumentProcessing).mockResolvedValue({
|
||||
success: true,
|
||||
status: 'pending',
|
||||
message: 'Document retry processing started',
|
||||
})
|
||||
|
||||
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
|
||||
|
||||
const req = createMockRequest('PUT', { retryProcessing: true })
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
const response = await PUT(req, { params: mockParams })
|
||||
@@ -425,8 +383,17 @@ describe('Document By ID API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.status).toBe('pending')
|
||||
expect(data.data.message).toBe('Document retry processing started')
|
||||
expect(mockDbChain.transaction).toHaveBeenCalled()
|
||||
expect(vi.mocked(processDocumentAsync)).toHaveBeenCalled()
|
||||
expect(vi.mocked(retryDocumentProcessing)).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
'doc-123',
|
||||
{
|
||||
filename: failedDocument.filename,
|
||||
fileUrl: failedDocument.fileUrl,
|
||||
fileSize: failedDocument.fileSize,
|
||||
mimeType: failedDocument.mimeType,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
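`retryDocumentProcessing` likewise absorbs the retry branch that used to live in the route: clear any partial embeddings, reset the document row, then kick off background processing. A sketch under the assumption that the service wraps the same steps shown in the removed route code later in this diff, reusing `processDocumentAsync` from the knowledge utils:

```ts
import { eq } from 'drizzle-orm'
import { processDocumentAsync } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, embedding } from '@/db/schema'

// Sketch reconstructed from the retry branch removed from the route below;
// the chunking option values are the ones that branch used.
export async function retryDocumentProcessing(
  knowledgeBaseId: string,
  documentId: string,
  docData: { filename: string; fileUrl: string; fileSize: number; mimeType: string },
  requestId: string
): Promise<{ success: boolean; status: 'pending'; message: string }> {
  // Drop any partial embeddings and reset the document's processing state.
  await db.transaction(async (tx) => {
    await tx.delete(embedding).where(eq(embedding.documentId, documentId))
    await tx
      .update(document)
      .set({
        processingStatus: 'pending',
        processingStartedAt: null,
        processingCompletedAt: null,
        processingError: null,
        chunkCount: 0,
        tokenCount: 0,
        characterCount: 0,
      })
      .where(eq(document.id, documentId))
  })

  // Fire-and-forget background reprocessing with the defaults the route used.
  const processingOptions = { chunkSize: 1024, minCharactersPerChunk: 24, recipe: 'default', lang: 'en' }
  processDocumentAsync(knowledgeBaseId, documentId, docData, processingOptions).catch((error: unknown) => {
    console.error(`[${requestId}] Background retry processing error:`, error)
  })

  return { success: true, status: 'pending', message: 'Document retry processing started' }
}
```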
|
||||
|
||||
it('should reject retry for non-failed document', async () => {
|
||||
@@ -486,6 +453,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should handle database errors during update', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { updateDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
@@ -494,8 +462,7 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock transaction to throw an error
|
||||
mockDbChain.transaction.mockRejectedValue(new Error('Database error'))
|
||||
vi.mocked(updateDocument).mockRejectedValue(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
@@ -512,6 +479,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should delete document successfully', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { deleteDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
@@ -520,10 +488,10 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Properly chain the mock database operations for soft delete
|
||||
mockDbChain.update.mockReturnValue(mockDbChain)
|
||||
mockDbChain.set.mockReturnValue(mockDbChain)
|
||||
mockDbChain.where.mockResolvedValue(undefined) // Update operation resolves
|
||||
vi.mocked(deleteDocument).mockResolvedValue({
|
||||
success: true,
|
||||
message: 'Document deleted successfully',
|
||||
})
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
@@ -533,12 +501,7 @@ describe('Document By ID API Route', () => {
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.message).toBe('Document deleted successfully')
|
||||
expect(mockDbChain.update).toHaveBeenCalled()
|
||||
expect(mockDbChain.set).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
deletedAt: expect.any(Date),
|
||||
})
|
||||
)
|
||||
expect(vi.mocked(deleteDocument)).toHaveBeenCalledWith('doc-123', expect.any(String))
|
||||
})
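`deleteDocument` is mocked here to resolve with a success message, matching the soft-delete behaviour the DELETE route previously implemented inline: set `deletedAt` instead of removing the row. A minimal sketch along those lines (return shape follows the mock):

```ts
import { eq } from 'drizzle-orm'
import { db } from '@/db'
import { document } from '@/db/schema'

// Soft delete: the row is kept and filtered out of reads via its deletedAt
// timestamp. requestId is accepted for log correlation only.
export async function deleteDocument(
  documentId: string,
  requestId: string
): Promise<{ success: boolean; message: string }> {
  await db
    .update(document)
    .set({ deletedAt: new Date() })
    .where(eq(document.id, documentId))

  return { success: true, message: 'Document deleted successfully' }
}
```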
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -592,6 +555,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should handle database errors during deletion', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { deleteDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
@@ -599,7 +563,7 @@ describe('Document By ID API Route', () => {
|
||||
document: mockDocument,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
mockDbChain.set.mockRejectedValue(new Error('Database error'))
|
||||
vi.mocked(deleteDocument).mockRejectedValue(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { TAG_SLOTS } from '@/lib/constants/knowledge'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import {
|
||||
checkDocumentAccess,
|
||||
checkDocumentWriteAccess,
|
||||
processDocumentAsync,
|
||||
} from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding } from '@/db/schema'
|
||||
deleteDocument,
|
||||
markDocumentAsFailedTimeout,
|
||||
retryDocumentProcessing,
|
||||
updateDocument,
|
||||
} from '@/lib/knowledge/documents/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
|
||||
|
||||
const logger = createLogger('DocumentByIdAPI')
|
||||
|
||||
@@ -113,9 +111,7 @@ export async function PUT(
|
||||
|
||||
const updateData: any = {}
|
||||
|
||||
// Handle special operations first
|
||||
if (validatedData.markFailedDueToTimeout) {
|
||||
// Mark document as failed due to timeout (replaces mark-failed endpoint)
|
||||
const doc = accessCheck.document
|
||||
|
||||
if (doc.processingStatus !== 'processing') {
|
||||
@@ -132,58 +128,30 @@ export async function PUT(
|
||||
)
|
||||
}
|
||||
|
||||
const now = new Date()
|
||||
const processingDuration = now.getTime() - new Date(doc.processingStartedAt).getTime()
|
||||
const DEAD_PROCESS_THRESHOLD_MS = 150 * 1000
|
||||
try {
|
||||
await markDocumentAsFailedTimeout(documentId, doc.processingStartedAt, requestId)
|
||||
|
||||
if (processingDuration <= DEAD_PROCESS_THRESHOLD_MS) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Document has not been processing long enough to be considered dead' },
|
||||
{ status: 400 }
|
||||
)
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
documentId,
|
||||
status: 'failed',
|
||||
message: 'Document marked as failed due to timeout',
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
return NextResponse.json({ error: error.message }, { status: 400 })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
|
||||
updateData.processingStatus = 'failed'
|
||||
updateData.processingError =
|
||||
'Processing timed out - background process may have been terminated'
|
||||
updateData.processingCompletedAt = now
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Marked document ${documentId} as failed due to dead process (processing time: ${Math.round(processingDuration / 1000)}s)`
|
||||
)
|
||||
} else if (validatedData.retryProcessing) {
|
||||
// Retry processing (replaces retry endpoint)
|
||||
const doc = accessCheck.document
|
||||
|
||||
if (doc.processingStatus !== 'failed') {
|
||||
return NextResponse.json({ error: 'Document is not in failed state' }, { status: 400 })
|
||||
}
|
||||
|
||||
// Clear existing embeddings and reset document state
|
||||
await db.transaction(async (tx) => {
|
||||
await tx.delete(embedding).where(eq(embedding.documentId, documentId))
|
||||
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
processingStatus: 'pending',
|
||||
processingStartedAt: null,
|
||||
processingCompletedAt: null,
|
||||
processingError: null,
|
||||
chunkCount: 0,
|
||||
tokenCount: 0,
|
||||
characterCount: 0,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
})
|
||||
|
||||
const processingOptions = {
|
||||
chunkSize: 1024,
|
||||
minCharactersPerChunk: 24,
|
||||
recipe: 'default',
|
||||
lang: 'en',
|
||||
}
|
||||
|
||||
const docData = {
|
||||
filename: doc.filename,
|
||||
fileUrl: doc.fileUrl,
|
||||
@@ -191,80 +159,33 @@ export async function PUT(
|
||||
mimeType: doc.mimeType,
|
||||
}
|
||||
|
||||
processDocumentAsync(knowledgeBaseId, documentId, docData, processingOptions).catch(
|
||||
(error: unknown) => {
|
||||
logger.error(`[${requestId}] Background retry processing error:`, error)
|
||||
}
|
||||
const result = await retryDocumentProcessing(
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
docData,
|
||||
requestId
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Document retry initiated: ${documentId}`)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
documentId,
|
||||
-          status: 'pending',
-          message: 'Document retry processing started',
+          status: result.status,
+          message: result.message,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// Regular field updates
|
||||
if (validatedData.filename !== undefined) updateData.filename = validatedData.filename
|
||||
if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled
|
||||
if (validatedData.chunkCount !== undefined) updateData.chunkCount = validatedData.chunkCount
|
||||
if (validatedData.tokenCount !== undefined) updateData.tokenCount = validatedData.tokenCount
|
||||
if (validatedData.characterCount !== undefined)
|
||||
updateData.characterCount = validatedData.characterCount
|
||||
if (validatedData.processingStatus !== undefined)
|
||||
updateData.processingStatus = validatedData.processingStatus
|
||||
if (validatedData.processingError !== undefined)
|
||||
updateData.processingError = validatedData.processingError
|
||||
const updatedDocument = await updateDocument(documentId, validatedData, requestId)
|
||||
|
||||
// Tag field updates
|
||||
TAG_SLOTS.forEach((slot) => {
|
||||
if ((validatedData as any)[slot] !== undefined) {
|
||||
;(updateData as any)[slot] = (validatedData as any)[slot]
|
||||
}
|
||||
logger.info(
|
||||
`[${requestId}] Document updated: ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: updatedDocument,
|
||||
})
|
||||
}
|
||||
|
||||
await db.transaction(async (tx) => {
|
||||
// Update the document
|
||||
await tx.update(document).set(updateData).where(eq(document.id, documentId))
|
||||
|
||||
// If any tag fields were updated, also update the embeddings
|
||||
const hasTagUpdates = TAG_SLOTS.some((field) => (validatedData as any)[field] !== undefined)
|
||||
|
||||
if (hasTagUpdates) {
|
||||
const embeddingUpdateData: Record<string, string | null> = {}
|
||||
TAG_SLOTS.forEach((field) => {
|
||||
if ((validatedData as any)[field] !== undefined) {
|
||||
embeddingUpdateData[field] = (validatedData as any)[field] || null
|
||||
}
|
||||
})
|
||||
|
||||
await tx
|
||||
.update(embedding)
|
||||
.set(embeddingUpdateData)
|
||||
.where(eq(embedding.documentId, documentId))
|
||||
}
|
||||
})
|
||||
|
||||
// Fetch the updated document
|
||||
const updatedDocument = await db
|
||||
.select()
|
||||
.from(document)
|
||||
.where(eq(document.id, documentId))
|
||||
.limit(1)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Document updated: ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: updatedDocument[0],
|
||||
})
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
logger.warn(`[${requestId}] Invalid document update data`, {
|
||||
@@ -313,13 +234,7 @@ export async function DELETE(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Soft delete by setting deletedAt timestamp
|
||||
await db
|
||||
.update(document)
|
||||
.set({
|
||||
deletedAt: new Date(),
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
const result = await deleteDocument(documentId, requestId)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Document deleted: ${documentId} from knowledge base ${knowledgeBaseId}`
|
||||
@@ -327,7 +242,7 @@ export async function DELETE(
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: { message: 'Document deleted successfully' },
|
||||
data: result,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error deleting document`, error)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { SUPPORTED_FIELD_TYPES } from '@/lib/constants/knowledge'
|
||||
import {
|
||||
getMaxSlotsForFieldType,
|
||||
getSlotsForFieldType,
|
||||
SUPPORTED_FIELD_TYPES,
|
||||
} from '@/lib/constants/knowledge'
|
||||
cleanupUnusedTagDefinitions,
|
||||
createOrUpdateTagDefinitionsBulk,
|
||||
deleteAllTagDefinitions,
|
||||
getDocumentTagDefinitions,
|
||||
} from '@/lib/knowledge/tags/service'
|
||||
import type { BulkTagDefinitionsData } from '@/lib/knowledge/tags/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
@@ -29,106 +29,6 @@ const BulkTagDefinitionsSchema = z.object({
|
||||
definitions: z.array(TagDefinitionSchema),
|
||||
})
|
||||
|
||||
// Helper function to get the next available slot for a knowledge base and field type
|
||||
async function getNextAvailableSlot(
|
||||
knowledgeBaseId: string,
|
||||
fieldType: string,
|
||||
existingBySlot?: Map<string, any>
|
||||
): Promise<string | null> {
|
||||
// Get available slots for this field type
|
||||
const availableSlots = getSlotsForFieldType(fieldType)
|
||||
let usedSlots: Set<string>
|
||||
|
||||
if (existingBySlot) {
|
||||
// Use provided map if available (for performance in batch operations)
|
||||
// Filter by field type
|
||||
usedSlots = new Set(
|
||||
Array.from(existingBySlot.entries())
|
||||
.filter(([_, def]) => def.fieldType === fieldType)
|
||||
.map(([slot, _]) => slot)
|
||||
)
|
||||
} else {
|
||||
// Query database for existing tag definitions of the same field type
|
||||
const existingDefinitions = await db
|
||||
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
|
||||
)
|
||||
)
|
||||
|
||||
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
|
||||
}
|
||||
|
||||
// Find the first available slot for this field type
|
||||
for (const slot of availableSlots) {
|
||||
if (!usedSlots.has(slot)) {
|
||||
return slot
|
||||
}
|
||||
}
|
||||
|
||||
return null // No available slots for this field type
|
||||
}
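For reference, a typical call of the helper above when a brand-new tag needs a slot (the knowledge-base id and field type are illustrative values taken from the surrounding tests):

```ts
// Illustrative usage: reserve a free slot for a new 'text' tag definition.
const targetSlot = await getNextAvailableSlot('kb-123', 'text')
if (!targetSlot) {
  // All slots for this field type (tag1..tag7 for text) are already taken.
  throw new Error("No available slots for field type 'text'")
}
// targetSlot (e.g. 'tag3') can now be written to knowledgeBaseTagDefinitions.tagSlot.
```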
|
||||
|
||||
// Helper function to clean up unused tag definitions
|
||||
async function cleanupUnusedTagDefinitions(knowledgeBaseId: string, requestId: string) {
|
||||
try {
|
||||
logger.info(`[${requestId}] Starting cleanup for KB ${knowledgeBaseId}`)
|
||||
|
||||
// Get all tag definitions for this KB
|
||||
const allDefinitions = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
|
||||
logger.info(`[${requestId}] Found ${allDefinitions.length} tag definitions to check`)
|
||||
|
||||
if (allDefinitions.length === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
let cleanedCount = 0
|
||||
|
||||
// For each tag definition, check if any documents use that tag slot
|
||||
for (const definition of allDefinitions) {
|
||||
const slot = definition.tagSlot
|
||||
|
||||
// Use raw SQL with proper column name injection
|
||||
const countResult = await db.execute(sql`
|
||||
SELECT count(*) as count
|
||||
FROM document
|
||||
WHERE knowledge_base_id = ${knowledgeBaseId}
|
||||
AND ${sql.raw(slot)} IS NOT NULL
|
||||
AND trim(${sql.raw(slot)}) != ''
|
||||
`)
|
||||
const count = Number(countResult[0]?.count) || 0
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Tag ${definition.displayName} (${slot}): ${count} documents using it`
|
||||
)
|
||||
|
||||
// If count is 0, remove this tag definition
|
||||
if (count === 0) {
|
||||
await db
|
||||
.delete(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, definition.id))
|
||||
|
||||
cleanedCount++
|
||||
logger.info(
|
||||
`[${requestId}] Removed unused tag definition: ${definition.displayName} (${definition.tagSlot})`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return cleanedCount
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to cleanup unused tag definitions:`, error)
|
||||
return 0 // Don't fail the main operation if cleanup fails
|
||||
}
|
||||
}
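With this commit the same cleanup is also exposed from the tags service (see the `cleanupUnusedTagDefinitions` import near the top of this file), so the route-level helper above is being retired. A hedged sketch of the delegating call; the argument list and the returned count are assumptions modelled on the local helper:

```ts
// Assumed usage of the service-level replacement for the helper above.
const cleanedCount = await cleanupUnusedTagDefinitions(knowledgeBaseId, requestId)
logger.info(`[${requestId}] Removed ${cleanedCount} unused tag definitions`)
```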
|
||||
|
||||
// GET /api/knowledge/[id]/documents/[documentId]/tag-definitions - Get tag definitions for a document
|
||||
export async function GET(
|
||||
req: NextRequest,
|
||||
@@ -145,35 +45,22 @@ export async function GET(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Verify document exists and belongs to the knowledge base
|
||||
const documentExists = await db
|
||||
.select({ id: document.id })
|
||||
.from(document)
|
||||
.where(and(eq(document.id, documentId), eq(document.knowledgeBaseId, knowledgeBaseId)))
|
||||
.limit(1)
|
||||
|
||||
if (documentExists.length === 0) {
|
||||
return NextResponse.json({ error: 'Document not found' }, { status: 404 })
|
||||
const accessCheck = await checkDocumentAccess(knowledgeBaseId, documentId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
if (accessCheck.notFound) {
|
||||
logger.warn(
|
||||
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
|
||||
)
|
||||
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
|
||||
}
|
||||
logger.warn(
|
||||
`[${requestId}] User ${session.user.id} attempted unauthorized document access: ${accessCheck.reason}`
|
||||
)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Get tag definitions for the knowledge base
|
||||
-    const tagDefinitions = await db
-      .select({
-        id: knowledgeBaseTagDefinitions.id,
-        tagSlot: knowledgeBaseTagDefinitions.tagSlot,
-        displayName: knowledgeBaseTagDefinitions.displayName,
-        fieldType: knowledgeBaseTagDefinitions.fieldType,
-        createdAt: knowledgeBaseTagDefinitions.createdAt,
-        updatedAt: knowledgeBaseTagDefinitions.updatedAt,
-      })
-      .from(knowledgeBaseTagDefinitions)
-      .where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
+    const tagDefinitions = await getDocumentTagDefinitions(knowledgeBaseId)
|
||||
|
||||
logger.info(`[${requestId}] Retrieved ${tagDefinitions.length} tag definitions`)
|
||||
|
||||
@@ -203,21 +90,19 @@ export async function POST(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has write access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseWriteAccess(knowledgeBaseId, session.user.id)
|
||||
// Verify document exists and user has write access
|
||||
const accessCheck = await checkDocumentWriteAccess(knowledgeBaseId, documentId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Verify document exists and belongs to the knowledge base
|
||||
const documentExists = await db
|
||||
.select({ id: document.id })
|
||||
.from(document)
|
||||
.where(and(eq(document.id, documentId), eq(document.knowledgeBaseId, knowledgeBaseId)))
|
||||
.limit(1)
|
||||
|
||||
if (documentExists.length === 0) {
|
||||
return NextResponse.json({ error: 'Document not found' }, { status: 404 })
|
||||
if (accessCheck.notFound) {
|
||||
logger.warn(
|
||||
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
|
||||
)
|
||||
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
|
||||
}
|
||||
logger.warn(
|
||||
`[${requestId}] User ${session.user.id} attempted unauthorized document write access: ${accessCheck.reason}`
|
||||
)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
let body
|
||||
@@ -238,197 +123,24 @@ export async function POST(
|
||||
|
||||
const validatedData = BulkTagDefinitionsSchema.parse(body)
|
||||
|
||||
// Validate slots are valid for their field types
|
||||
for (const definition of validatedData.definitions) {
|
||||
const validSlots = getSlotsForFieldType(definition.fieldType)
|
||||
if (validSlots.length === 0) {
|
||||
return NextResponse.json(
|
||||
{ error: `Unsupported field type: ${definition.fieldType}` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (!validSlots.includes(definition.tagSlot)) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: `Invalid slot '${definition.tagSlot}' for field type '${definition.fieldType}'. Valid slots: ${validSlots.join(', ')}`,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
const bulkData: BulkTagDefinitionsData = {
|
||||
definitions: validatedData.definitions.map((def) => ({
|
||||
tagSlot: def.tagSlot,
|
||||
displayName: def.displayName,
|
||||
fieldType: def.fieldType,
|
||||
originalDisplayName: def._originalDisplayName,
|
||||
})),
|
||||
}
|
||||
|
||||
// Validate no duplicate tag slots within the same field type
|
||||
const slotsByFieldType = new Map<string, Set<string>>()
|
||||
for (const definition of validatedData.definitions) {
|
||||
if (!slotsByFieldType.has(definition.fieldType)) {
|
||||
slotsByFieldType.set(definition.fieldType, new Set())
|
||||
}
|
||||
const slotsForType = slotsByFieldType.get(definition.fieldType)!
|
||||
if (slotsForType.has(definition.tagSlot)) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: `Duplicate slot '${definition.tagSlot}' for field type '${definition.fieldType}'`,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
slotsForType.add(definition.tagSlot)
|
||||
}
|
||||
|
||||
const now = new Date()
|
||||
const createdDefinitions: (typeof knowledgeBaseTagDefinitions.$inferSelect)[] = []
|
||||
|
||||
// Get existing definitions
|
||||
const existingDefinitions = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
|
||||
// Group by field type for validation
|
||||
const existingByFieldType = new Map<string, number>()
|
||||
for (const def of existingDefinitions) {
|
||||
existingByFieldType.set(def.fieldType, (existingByFieldType.get(def.fieldType) || 0) + 1)
|
||||
}
|
||||
|
||||
// Validate we don't exceed limits per field type
|
||||
const newByFieldType = new Map<string, number>()
|
||||
for (const definition of validatedData.definitions) {
|
||||
// Skip validation for edit operations - they don't create new slots
|
||||
if (definition._originalDisplayName) {
|
||||
continue
|
||||
}
|
||||
|
||||
const existingTagNames = new Set(
|
||||
existingDefinitions
|
||||
.filter((def) => def.fieldType === definition.fieldType)
|
||||
.map((def) => def.displayName)
|
||||
)
|
||||
|
||||
if (!existingTagNames.has(definition.displayName)) {
|
||||
newByFieldType.set(
|
||||
definition.fieldType,
|
||||
(newByFieldType.get(definition.fieldType) || 0) + 1
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
for (const [fieldType, newCount] of newByFieldType.entries()) {
|
||||
const existingCount = existingByFieldType.get(fieldType) || 0
|
||||
const maxSlots = getMaxSlotsForFieldType(fieldType)
|
||||
|
||||
if (existingCount + newCount > maxSlots) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: `Cannot create ${newCount} new '${fieldType}' tags. Knowledge base already has ${existingCount} '${fieldType}' tag definitions. Maximum is ${maxSlots} per field type.`,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Use transaction to ensure consistency
|
||||
await db.transaction(async (tx) => {
|
||||
// Create maps for lookups
|
||||
const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
|
||||
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))
|
||||
|
||||
// Process each definition
|
||||
for (const definition of validatedData.definitions) {
|
||||
if (definition._originalDisplayName) {
|
||||
// This is an EDIT operation - find by original name and update
|
||||
const originalDefinition = existingByName.get(definition._originalDisplayName)
|
||||
|
||||
if (originalDefinition) {
|
||||
logger.info(
|
||||
`[${requestId}] Editing tag definition: ${definition._originalDisplayName} -> ${definition.displayName} (slot ${originalDefinition.tagSlot})`
|
||||
)
|
||||
|
||||
await tx
|
||||
.update(knowledgeBaseTagDefinitions)
|
||||
.set({
|
||||
displayName: definition.displayName,
|
||||
fieldType: definition.fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, originalDefinition.id))
|
||||
|
||||
createdDefinitions.push({
|
||||
...originalDefinition,
|
||||
displayName: definition.displayName,
|
||||
fieldType: definition.fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
continue
|
||||
}
|
||||
logger.warn(
|
||||
`[${requestId}] Could not find original definition for: ${definition._originalDisplayName}`
|
||||
)
|
||||
}
|
||||
|
||||
// Regular create/update logic
|
||||
const existingByDisplayName = existingByName.get(definition.displayName)
|
||||
|
||||
if (existingByDisplayName) {
|
||||
// Display name exists - UPDATE operation
|
||||
logger.info(
|
||||
`[${requestId}] Updating existing tag definition: ${definition.displayName} (slot ${existingByDisplayName.tagSlot})`
|
||||
)
|
||||
|
||||
await tx
|
||||
.update(knowledgeBaseTagDefinitions)
|
||||
.set({
|
||||
fieldType: definition.fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, existingByDisplayName.id))
|
||||
|
||||
createdDefinitions.push({
|
||||
...existingByDisplayName,
|
||||
fieldType: definition.fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
} else {
|
||||
// Display name doesn't exist - CREATE operation
|
||||
const targetSlot = await getNextAvailableSlot(
|
||||
knowledgeBaseId,
|
||||
definition.fieldType,
|
||||
existingBySlot
|
||||
)
|
||||
|
||||
if (!targetSlot) {
|
||||
logger.error(
|
||||
`[${requestId}] No available slots for new tag definition: ${definition.displayName}`
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Creating new tag definition: ${definition.displayName} -> ${targetSlot}`
|
||||
)
|
||||
|
||||
const newDefinition = {
|
||||
id: randomUUID(),
|
||||
knowledgeBaseId,
|
||||
tagSlot: targetSlot as any,
|
||||
displayName: definition.displayName,
|
||||
fieldType: definition.fieldType,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
await tx.insert(knowledgeBaseTagDefinitions).values(newDefinition)
|
||||
existingBySlot.set(targetSlot as any, newDefinition)
|
||||
createdDefinitions.push(newDefinition as any)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Created/updated ${createdDefinitions.length} tag definitions`)
|
||||
const result = await createOrUpdateTagDefinitionsBulk(knowledgeBaseId, bulkData, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: createdDefinitions,
|
||||
data: {
|
||||
created: result.created,
|
||||
updated: result.updated,
|
||||
errors: result.errors,
|
||||
},
|
||||
})
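The POST handler now returns whatever `createOrUpdateTagDefinitionsBulk` reports. Its result type is not shown in this diff; a type-level sketch of the shape implied by the call site and the response body above (the element types are assumptions):

```ts
import type { BulkTagDefinitionsData } from '@/lib/knowledge/tags/types'

// Assumed result shape: `created`/`updated` presumably carry the affected tag
// definitions and `errors` the per-definition failures (e.g. no free slot).
export interface BulkTagDefinitionsResult {
  created: unknown[]
  updated: unknown[]
  errors: string[]
}

export declare function createOrUpdateTagDefinitionsBulk(
  knowledgeBaseId: string,
  data: BulkTagDefinitionsData,
  requestId: string
): Promise<BulkTagDefinitionsResult>
```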
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
@@ -459,10 +171,19 @@ export async function DELETE(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has write access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseWriteAccess(knowledgeBaseId, session.user.id)
|
||||
// Verify document exists and user has write access
|
||||
const accessCheck = await checkDocumentWriteAccess(knowledgeBaseId, documentId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
if (accessCheck.notFound) {
|
||||
logger.warn(
|
||||
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
|
||||
)
|
||||
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
|
||||
}
|
||||
logger.warn(
|
||||
`[${requestId}] User ${session.user.id} attempted unauthorized document write access: ${accessCheck.reason}`
|
||||
)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
if (action === 'cleanup') {
|
||||
@@ -478,13 +199,12 @@ export async function DELETE(
|
||||
// Delete all tag definitions (original behavior)
|
||||
logger.info(`[${requestId}] Deleting all tag definitions for KB ${knowledgeBaseId}`)
|
||||
|
||||
-      const result = await db
-        .delete(knowledgeBaseTagDefinitions)
-        .where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
+      const deletedCount = await deleteAllTagDefinitions(knowledgeBaseId, requestId)

      return NextResponse.json({
        success: true,
        message: 'Tag definitions deleted successfully',
+        data: { deleted: deletedCount },
      })
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error with tag definitions operation`, error)
|
||||
|
||||
@@ -24,6 +24,19 @@ vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
processDocumentAsync: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/knowledge/documents/service', () => ({
|
||||
getDocuments: vi.fn(),
|
||||
createSingleDocument: vi.fn(),
|
||||
createDocumentRecords: vi.fn(),
|
||||
processDocumentsWithQueue: vi.fn(),
|
||||
getProcessingConfig: vi.fn(),
|
||||
bulkDocumentOperation: vi.fn(),
|
||||
updateDocument: vi.fn(),
|
||||
deleteDocument: vi.fn(),
|
||||
markDocumentAsFailedTimeout: vi.fn(),
|
||||
retryDocumentProcessing: vi.fn(),
|
||||
}))
|
||||
|
||||
mockDrizzleOrm()
|
||||
mockConsoleLogger()
|
||||
|
||||
@@ -72,7 +85,6 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
}
|
||||
}
|
||||
})
|
||||
// Clear all mocks - they will be set up in individual tests
|
||||
}
|
||||
|
||||
beforeEach(async () => {
|
||||
@@ -96,6 +108,7 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
|
||||
it('should retrieve documents successfully for authenticated user', async () => {
|
||||
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { getDocuments } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
|
||||
@@ -103,11 +116,15 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock the count query (first query)
|
||||
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
|
||||
|
||||
// Mock the documents query (second query)
|
||||
mockDbChain.offset.mockResolvedValue([mockDocument])
|
||||
vi.mocked(getDocuments).mockResolvedValue({
|
||||
documents: [mockDocument],
|
||||
pagination: {
|
||||
total: 1,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
hasMore: false,
|
||||
},
|
||||
})
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
@@ -118,12 +135,22 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.documents).toHaveLength(1)
|
||||
expect(data.data.documents[0].id).toBe('doc-123')
|
||||
expect(mockDbChain.select).toHaveBeenCalled()
|
||||
expect(vi.mocked(checkKnowledgeBaseAccess)).toHaveBeenCalledWith('kb-123', 'user-123')
|
||||
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
{
|
||||
includeDisabled: false,
|
||||
search: undefined,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should filter disabled documents by default', async () => {
|
||||
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { getDocuments } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
|
||||
@@ -131,22 +158,36 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock the count query (first query)
|
||||
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
|
||||
|
||||
// Mock the documents query (second query)
|
||||
mockDbChain.offset.mockResolvedValue([mockDocument])
|
||||
vi.mocked(getDocuments).mockResolvedValue({
|
||||
documents: [mockDocument],
|
||||
pagination: {
|
||||
total: 1,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
hasMore: false,
|
||||
},
|
||||
})
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
const response = await GET(req, { params: mockParams })
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(mockDbChain.where).toHaveBeenCalled()
|
||||
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
{
|
||||
includeDisabled: false,
|
||||
search: undefined,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should include disabled documents when requested', async () => {
|
||||
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { getDocuments } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
|
||||
@@ -154,11 +195,15 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock the count query (first query)
|
||||
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
|
||||
|
||||
// Mock the documents query (second query)
|
||||
mockDbChain.offset.mockResolvedValue([mockDocument])
|
||||
vi.mocked(getDocuments).mockResolvedValue({
|
||||
documents: [mockDocument],
|
||||
pagination: {
|
||||
total: 1,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
hasMore: false,
|
||||
},
|
||||
})
|
||||
|
||||
const url = 'http://localhost:3000/api/knowledge/kb-123/documents?includeDisabled=true'
|
||||
const req = new Request(url, { method: 'GET' }) as any
|
||||
@@ -167,6 +212,16 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
const response = await GET(req, { params: mockParams })
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
{
|
||||
includeDisabled: true,
|
||||
search: undefined,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -216,13 +271,14 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
|
||||
it('should handle database errors', async () => {
|
||||
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { getDocuments } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
mockDbChain.orderBy.mockRejectedValue(new Error('Database error'))
|
||||
vi.mocked(getDocuments).mockRejectedValue(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
@@ -245,13 +301,35 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
|
||||
it('should create single document successfully', async () => {
|
||||
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { createSingleDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
mockDbChain.values.mockResolvedValue(undefined)
|
||||
|
||||
const createdDocument = {
|
||||
id: 'doc-123',
|
||||
knowledgeBaseId: 'kb-123',
|
||||
filename: validDocumentData.filename,
|
||||
fileUrl: validDocumentData.fileUrl,
|
||||
fileSize: validDocumentData.fileSize,
|
||||
mimeType: validDocumentData.mimeType,
|
||||
chunkCount: 0,
|
||||
tokenCount: 0,
|
||||
characterCount: 0,
|
||||
enabled: true,
|
||||
uploadedAt: new Date(),
|
||||
tag1: null,
|
||||
tag2: null,
|
||||
tag3: null,
|
||||
tag4: null,
|
||||
tag5: null,
|
||||
tag6: null,
|
||||
tag7: null,
|
||||
}
|
||||
vi.mocked(createSingleDocument).mockResolvedValue(createdDocument)
|
||||
|
||||
const req = createMockRequest('POST', validDocumentData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
@@ -262,7 +340,11 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.filename).toBe(validDocumentData.filename)
|
||||
expect(data.data.fileUrl).toBe(validDocumentData.fileUrl)
|
||||
expect(mockDbChain.insert).toHaveBeenCalled()
|
||||
expect(vi.mocked(createSingleDocument)).toHaveBeenCalledWith(
|
||||
validDocumentData,
|
||||
'kb-123',
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should validate single document data', async () => {
|
||||
@@ -320,9 +402,9 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
}
|
||||
|
||||
it('should create bulk documents successfully', async () => {
|
||||
const { checkKnowledgeBaseWriteAccess, processDocumentAsync } = await import(
|
||||
'@/app/api/knowledge/utils'
|
||||
)
|
||||
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { createDocumentRecords, processDocumentsWithQueue, getProcessingConfig } =
|
||||
await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
|
||||
@@ -330,17 +412,31 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock transaction to return the created documents
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
insert: vi.fn().mockReturnValue({
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
}
|
||||
return await callback(mockTx)
|
||||
})
|
||||
const createdDocuments = [
|
||||
{
|
||||
documentId: 'doc-1',
|
||||
filename: 'doc1.pdf',
|
||||
fileUrl: 'https://example.com/doc1.pdf',
|
||||
fileSize: 1024,
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
{
|
||||
documentId: 'doc-2',
|
||||
filename: 'doc2.pdf',
|
||||
fileUrl: 'https://example.com/doc2.pdf',
|
||||
fileSize: 2048,
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
]
|
||||
|
||||
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
|
||||
vi.mocked(createDocumentRecords).mockResolvedValue(createdDocuments)
|
||||
vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined)
|
||||
vi.mocked(getProcessingConfig).mockReturnValue({
|
||||
maxConcurrentDocuments: 8,
|
||||
batchSize: 20,
|
||||
delayBetweenBatches: 100,
|
||||
delayBetweenDocuments: 0,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validBulkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
@@ -352,7 +448,12 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
expect(data.data.total).toBe(2)
|
||||
expect(data.data.documentsCreated).toHaveLength(2)
|
||||
expect(data.data.processingMethod).toBe('background')
|
||||
expect(mockDbChain.transaction).toHaveBeenCalled()
|
||||
expect(vi.mocked(createDocumentRecords)).toHaveBeenCalledWith(
|
||||
validBulkData.documents,
|
||||
'kb-123',
|
||||
expect.any(String)
|
||||
)
|
||||
expect(vi.mocked(processDocumentsWithQueue)).toHaveBeenCalled()
|
||||
})
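The bulk path asserted here is: create the document rows first, then hand them to a background queue. A sketch of how the route presumably wires these calls together; the exact argument list of `processDocumentsWithQueue` is not visible in this diff, so treat the order below as a guess:

```ts
import {
  createDocumentRecords,
  getProcessingConfig,
  processDocumentsWithQueue,
} from '@/lib/knowledge/documents/service'

// Assumed wiring of the bulk upload path; only the calls themselves are
// asserted in the test above.
async function startBulkProcessing(
  knowledgeBaseId: string,
  documents: Array<{ filename: string; fileUrl: string; fileSize: number; mimeType: string }>,
  processingOptions: { chunkSize: number; minCharactersPerChunk: number; recipe: string; lang: string },
  requestId: string
) {
  const createdDocuments = await createDocumentRecords(documents, knowledgeBaseId, requestId)
  const config = getProcessingConfig()

  // Fire-and-forget: respond to the client once the rows exist and let the
  // queue work through the documents with bounded concurrency.
  processDocumentsWithQueue(knowledgeBaseId, createdDocuments, processingOptions, config).catch(
    (error: unknown) => {
      console.error(`[${requestId}] Background bulk processing error:`, error)
    }
  )

  return createdDocuments
}
```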
|
||||
|
||||
it('should validate bulk document data', async () => {
|
||||
@@ -394,9 +495,9 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
})
|
||||
|
||||
it('should handle processing errors gracefully', async () => {
|
||||
const { checkKnowledgeBaseWriteAccess, processDocumentAsync } = await import(
|
||||
'@/app/api/knowledge/utils'
|
||||
)
|
||||
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { createDocumentRecords, processDocumentsWithQueue, getProcessingConfig } =
|
||||
await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
|
||||
@@ -404,26 +505,30 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock transaction to succeed but processing to fail
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
insert: vi.fn().mockReturnValue({
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
}
|
||||
return await callback(mockTx)
|
||||
})
|
||||
const createdDocuments = [
|
||||
{
|
||||
documentId: 'doc-1',
|
||||
filename: 'doc1.pdf',
|
||||
fileUrl: 'https://example.com/doc1.pdf',
|
||||
fileSize: 1024,
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
]
|
||||
|
||||
// Don't reject the promise - the processing is async and catches errors internally
|
||||
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
|
||||
vi.mocked(createDocumentRecords).mockResolvedValue(createdDocuments)
|
||||
vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined)
|
||||
vi.mocked(getProcessingConfig).mockReturnValue({
|
||||
maxConcurrentDocuments: 8,
|
||||
batchSize: 20,
|
||||
delayBetweenBatches: 100,
|
||||
delayBetweenDocuments: 0,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validBulkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
// The endpoint should still return success since documents are created
|
||||
// and processing happens asynchronously
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
})
|
||||
@@ -485,13 +590,14 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
|
||||
it('should handle database errors during creation', async () => {
|
||||
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { createSingleDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
mockDbChain.values.mockRejectedValue(new Error('Database error'))
|
||||
vi.mocked(createSingleDocument).mockRejectedValue(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('POST', validDocumentData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
|
||||
@@ -1,279 +1,22 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, desc, eq, inArray, isNull, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getSlotsForFieldType } from '@/lib/constants/knowledge'
|
||||
import {
|
||||
bulkDocumentOperation,
|
||||
createDocumentRecords,
|
||||
createSingleDocument,
|
||||
getDocuments,
|
||||
getProcessingConfig,
|
||||
processDocumentsWithQueue,
|
||||
} from '@/lib/knowledge/documents/service'
|
||||
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import {
|
||||
checkKnowledgeBaseAccess,
|
||||
checkKnowledgeBaseWriteAccess,
|
||||
processDocumentAsync,
|
||||
} from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
|
||||
|
||||
const logger = createLogger('DocumentsAPI')
|
||||
|
||||
const PROCESSING_CONFIG = {
|
||||
maxConcurrentDocuments: 3,
|
||||
batchSize: 5,
|
||||
delayBetweenBatches: 1000,
|
||||
delayBetweenDocuments: 500,
|
||||
}
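The documents route previously pinned these values in the local `PROCESSING_CONFIG` constant above, while the tests elsewhere in this diff mock a `getProcessingConfig()` service returning the same four fields. A sketch of such a function that keeps the old constants as defaults (the environment-variable names are purely illustrative assumptions):

```ts
export interface ProcessingConfig {
  maxConcurrentDocuments: number
  batchSize: number
  delayBetweenBatches: number
  delayBetweenDocuments: number
}

// Sketch: the shape matches the mocked return value in the tests; the env
// variable names below are hypothetical, only the defaults come from the
// old PROCESSING_CONFIG constant.
export function getProcessingConfig(): ProcessingConfig {
  return {
    maxConcurrentDocuments: Number(process.env.KB_MAX_CONCURRENT_DOCUMENTS ?? 3),
    batchSize: Number(process.env.KB_PROCESSING_BATCH_SIZE ?? 5),
    delayBetweenBatches: Number(process.env.KB_DELAY_BETWEEN_BATCHES_MS ?? 1000),
    delayBetweenDocuments: Number(process.env.KB_DELAY_BETWEEN_DOCUMENTS_MS ?? 500),
  }
}
```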
|
||||
|
||||
// Helper function to get the next available slot for a knowledge base and field type
|
||||
async function getNextAvailableSlot(
|
||||
knowledgeBaseId: string,
|
||||
fieldType: string,
|
||||
existingBySlot?: Map<string, any>
|
||||
): Promise<string | null> {
|
||||
let usedSlots: Set<string>
|
||||
|
||||
if (existingBySlot) {
|
||||
// Use provided map if available (for performance in batch operations)
|
||||
// Filter by field type
|
||||
usedSlots = new Set(
|
||||
Array.from(existingBySlot.entries())
|
||||
.filter(([_, def]) => def.fieldType === fieldType)
|
||||
.map(([slot, _]) => slot)
|
||||
)
|
||||
} else {
|
||||
// Query database for existing tag definitions of the same field type
|
||||
const existingDefinitions = await db
|
||||
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
|
||||
)
|
||||
)
|
||||
|
||||
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
|
||||
}
|
||||
|
||||
// Find the first available slot for this field type
|
||||
const availableSlots = getSlotsForFieldType(fieldType)
|
||||
for (const slot of availableSlots) {
|
||||
if (!usedSlots.has(slot)) {
|
||||
return slot
|
||||
}
|
||||
}
|
||||
|
||||
return null // No available slots for this field type
|
||||
}
|
||||
|
||||
// Helper function to process structured document tags
async function processDocumentTags(
knowledgeBaseId: string,
tagData: Array<{ tagName: string; fieldType: string; value: string }>,
requestId: string
): Promise<Record<string, string | null>> {
const result: Record<string, string | null> = {}

// Initialize all text tag slots to null (only text type is supported currently)
const textSlots = getSlotsForFieldType('text')
textSlots.forEach((slot) => {
result[slot] = null
})

if (!Array.isArray(tagData) || tagData.length === 0) {
return result
}

try {
// Get existing tag definitions
const existingDefinitions = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))

const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))

// Process each tag
for (const tag of tagData) {
if (!tag.tagName?.trim() || !tag.value?.trim()) continue

const tagName = tag.tagName.trim()
const fieldType = tag.fieldType
const value = tag.value.trim()

let targetSlot: string | null = null

// Check if tag definition already exists
const existingDef = existingByName.get(tagName)
if (existingDef) {
targetSlot = existingDef.tagSlot
} else {
// Find next available slot using the helper function
targetSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)

// Create new tag definition if we have a slot
if (targetSlot) {
const newDefinition = {
id: randomUUID(),
knowledgeBaseId,
tagSlot: targetSlot as any,
displayName: tagName,
fieldType,
createdAt: new Date(),
updatedAt: new Date(),
}

await db.insert(knowledgeBaseTagDefinitions).values(newDefinition)
existingBySlot.set(targetSlot as any, newDefinition)

logger.info(`[${requestId}] Created tag definition: ${tagName} -> ${targetSlot}`)
}
}

// Assign value to the slot
if (targetSlot) {
result[targetSlot] = value
}
}

return result
} catch (error) {
logger.error(`[${requestId}] Error processing document tags:`, error)
return result
}
}
async function processDocumentsWithConcurrencyControl(
createdDocuments: Array<{
documentId: string
filename: string
fileUrl: string
fileSize: number
mimeType: string
}>,
knowledgeBaseId: string,
processingOptions: {
chunkSize: number
minCharactersPerChunk: number
recipe: string
lang: string
chunkOverlap: number
},
requestId: string
): Promise<void> {
const totalDocuments = createdDocuments.length
const batches = []

for (let i = 0; i < totalDocuments; i += PROCESSING_CONFIG.batchSize) {
batches.push(createdDocuments.slice(i, i + PROCESSING_CONFIG.batchSize))
}

logger.info(`[${requestId}] Processing ${totalDocuments} documents in ${batches.length} batches`)

for (const [batchIndex, batch] of batches.entries()) {
logger.info(
`[${requestId}] Starting batch ${batchIndex + 1}/${batches.length} with ${batch.length} documents`
)

await processBatchWithConcurrency(batch, knowledgeBaseId, processingOptions, requestId)

if (batchIndex < batches.length - 1) {
await new Promise((resolve) => setTimeout(resolve, PROCESSING_CONFIG.delayBetweenBatches))
}
}

logger.info(`[${requestId}] Completed processing initiation for all ${totalDocuments} documents`)
}
async function processBatchWithConcurrency(
batch: Array<{
documentId: string
filename: string
fileUrl: string
fileSize: number
mimeType: string
}>,
knowledgeBaseId: string,
processingOptions: {
chunkSize: number
minCharactersPerChunk: number
recipe: string
lang: string
chunkOverlap: number
},
requestId: string
): Promise<void> {
const semaphore = new Array(PROCESSING_CONFIG.maxConcurrentDocuments).fill(0)
const processingPromises = batch.map(async (doc, index) => {
if (index > 0) {
await new Promise((resolve) =>
setTimeout(resolve, index * PROCESSING_CONFIG.delayBetweenDocuments)
)
}

await new Promise<void>((resolve) => {
const checkSlot = () => {
const availableIndex = semaphore.findIndex((slot) => slot === 0)
if (availableIndex !== -1) {
semaphore[availableIndex] = 1
resolve()
} else {
setTimeout(checkSlot, 100)
}
}
checkSlot()
})

try {
logger.info(`[${requestId}] Starting processing for document: ${doc.filename}`)

await processDocumentAsync(
knowledgeBaseId,
doc.documentId,
{
filename: doc.filename,
fileUrl: doc.fileUrl,
fileSize: doc.fileSize,
mimeType: doc.mimeType,
},
processingOptions
)

logger.info(`[${requestId}] Successfully initiated processing for document: ${doc.filename}`)
} catch (error: unknown) {
logger.error(`[${requestId}] Failed to process document: ${doc.filename}`, {
documentId: doc.documentId,
filename: doc.filename,
error: error instanceof Error ? error.message : 'Unknown error',
})

try {
await db
.update(document)
.set({
processingStatus: 'failed',
processingError:
error instanceof Error ? error.message : 'Failed to initiate processing',
processingCompletedAt: new Date(),
})
.where(eq(document.id, doc.documentId))
} catch (dbError: unknown) {
logger.error(
`[${requestId}] Failed to update document status for failed document: ${doc.documentId}`,
dbError
)
}
} finally {
const slotIndex = semaphore.findIndex((slot) => slot === 1)
if (slotIndex !== -1) {
semaphore[slotIndex] = 0
}
}
})

await Promise.allSettled(processingPromises)
}

const CreateDocumentSchema = z.object({
filename: z.string().min(1, 'Filename is required'),
fileUrl: z.string().url('File URL must be valid'),
@@ -337,83 +80,50 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:

const url = new URL(req.url)
const includeDisabled = url.searchParams.get('includeDisabled') === 'true'
const search = url.searchParams.get('search')
const search = url.searchParams.get('search') || undefined
const limit = Number.parseInt(url.searchParams.get('limit') || '50')
const offset = Number.parseInt(url.searchParams.get('offset') || '0')
const sortByParam = url.searchParams.get('sortBy')
const sortOrderParam = url.searchParams.get('sortOrder')

// Build where conditions
const whereConditions = [
eq(document.knowledgeBaseId, knowledgeBaseId),
isNull(document.deletedAt),
// Validate sort parameters
const validSortFields: DocumentSortField[] = [
'filename',
'fileSize',
'tokenCount',
'chunkCount',
'uploadedAt',
'processingStatus',
]
const validSortOrders: SortOrder[] = ['asc', 'desc']

// Filter out disabled documents unless specifically requested
if (!includeDisabled) {
whereConditions.push(eq(document.enabled, true))
}
const sortBy =
sortByParam && validSortFields.includes(sortByParam as DocumentSortField)
? (sortByParam as DocumentSortField)
: undefined
const sortOrder =
sortOrderParam && validSortOrders.includes(sortOrderParam as SortOrder)
? (sortOrderParam as SortOrder)
: undefined

// Add search condition if provided
if (search) {
whereConditions.push(
// Search in filename
sql`LOWER(${document.filename}) LIKE LOWER(${`%${search}%`})`
)
}

// Get total count for pagination
const totalResult = await db
.select({ count: sql<number>`COUNT(*)` })
.from(document)
.where(and(...whereConditions))

const total = totalResult[0]?.count || 0
const hasMore = offset + limit < total

const documents = await db
.select({
id: document.id,
filename: document.filename,
fileUrl: document.fileUrl,
fileSize: document.fileSize,
mimeType: document.mimeType,
chunkCount: document.chunkCount,
tokenCount: document.tokenCount,
characterCount: document.characterCount,
processingStatus: document.processingStatus,
processingStartedAt: document.processingStartedAt,
processingCompletedAt: document.processingCompletedAt,
processingError: document.processingError,
enabled: document.enabled,
uploadedAt: document.uploadedAt,
// Include tags in response
tag1: document.tag1,
tag2: document.tag2,
tag3: document.tag3,
tag4: document.tag4,
tag5: document.tag5,
tag6: document.tag6,
tag7: document.tag7,
})
.from(document)
.where(and(...whereConditions))
.orderBy(desc(document.uploadedAt))
.limit(limit)
.offset(offset)

logger.info(
`[${requestId}] Retrieved ${documents.length} documents (${offset}-${offset + documents.length} of ${total}) for knowledge base ${knowledgeBaseId}`
const result = await getDocuments(
knowledgeBaseId,
{
includeDisabled,
search,
limit,
offset,
...(sortBy && { sortBy }),
...(sortOrder && { sortOrder }),
},
requestId
)

return NextResponse.json({
success: true,
data: {
documents,
pagination: {
total,
limit,
offset,
hasMore,
},
documents: result.documents,
pagination: result.pagination,
},
})
} catch (error) {
@@ -462,80 +172,21 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if this is a bulk operation
|
||||
if (body.bulk === true) {
|
||||
// Handle bulk processing (replaces process-documents endpoint)
|
||||
try {
|
||||
const validatedData = BulkCreateDocumentsSchema.parse(body)
|
||||
|
||||
const createdDocuments = await db.transaction(async (tx) => {
|
||||
const documentPromises = validatedData.documents.map(async (docData) => {
|
||||
const documentId = randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
// Process documentTagsData if provided (for knowledge base block)
|
||||
let processedTags: Record<string, string | null> = {
|
||||
tag1: null,
|
||||
tag2: null,
|
||||
tag3: null,
|
||||
tag4: null,
|
||||
tag5: null,
|
||||
tag6: null,
|
||||
tag7: null,
|
||||
}
|
||||
|
||||
if (docData.documentTagsData) {
|
||||
try {
|
||||
const tagData = JSON.parse(docData.documentTagsData)
|
||||
if (Array.isArray(tagData)) {
|
||||
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`[${requestId}] Failed to parse documentTagsData for bulk document:`,
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const newDocument = {
|
||||
id: documentId,
|
||||
knowledgeBaseId,
|
||||
filename: docData.filename,
|
||||
fileUrl: docData.fileUrl,
|
||||
fileSize: docData.fileSize,
|
||||
mimeType: docData.mimeType,
|
||||
chunkCount: 0,
|
||||
tokenCount: 0,
|
||||
characterCount: 0,
|
||||
processingStatus: 'pending' as const,
|
||||
enabled: true,
|
||||
uploadedAt: now,
|
||||
// Use processed tags if available, otherwise fall back to individual tag fields
|
||||
tag1: processedTags.tag1 || docData.tag1 || null,
|
||||
tag2: processedTags.tag2 || docData.tag2 || null,
|
||||
tag3: processedTags.tag3 || docData.tag3 || null,
|
||||
tag4: processedTags.tag4 || docData.tag4 || null,
|
||||
tag5: processedTags.tag5 || docData.tag5 || null,
|
||||
tag6: processedTags.tag6 || docData.tag6 || null,
|
||||
tag7: processedTags.tag7 || docData.tag7 || null,
|
||||
}
|
||||
|
||||
await tx.insert(document).values(newDocument)
|
||||
logger.info(
|
||||
`[${requestId}] Document record created: ${documentId} for file: ${docData.filename}`
|
||||
)
|
||||
return { documentId, ...docData }
|
||||
})
|
||||
|
||||
return await Promise.all(documentPromises)
|
||||
})
|
||||
const createdDocuments = await createDocumentRecords(
|
||||
validatedData.documents,
|
||||
knowledgeBaseId,
|
||||
requestId
|
||||
)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Starting controlled async processing of ${createdDocuments.length} documents`
|
||||
)
|
||||
|
||||
processDocumentsWithConcurrencyControl(
|
||||
processDocumentsWithQueue(
|
||||
createdDocuments,
|
||||
knowledgeBaseId,
|
||||
validatedData.processingOptions,
|
||||
@@ -555,9 +206,9 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
})),
|
||||
processingMethod: 'background',
|
||||
processingConfig: {
|
||||
maxConcurrentDocuments: PROCESSING_CONFIG.maxConcurrentDocuments,
|
||||
batchSize: PROCESSING_CONFIG.batchSize,
|
||||
totalBatches: Math.ceil(createdDocuments.length / PROCESSING_CONFIG.batchSize),
|
||||
maxConcurrentDocuments: getProcessingConfig().maxConcurrentDocuments,
|
||||
batchSize: getProcessingConfig().batchSize,
|
||||
totalBatches: Math.ceil(createdDocuments.length / getProcessingConfig().batchSize),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -578,52 +229,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
try {
|
||||
const validatedData = CreateDocumentSchema.parse(body)
|
||||
|
||||
const documentId = randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
// Process structured tag data if provided
|
||||
let processedTags: Record<string, string | null> = {
|
||||
tag1: validatedData.tag1 || null,
|
||||
tag2: validatedData.tag2 || null,
|
||||
tag3: validatedData.tag3 || null,
|
||||
tag4: validatedData.tag4 || null,
|
||||
tag5: validatedData.tag5 || null,
|
||||
tag6: validatedData.tag6 || null,
|
||||
tag7: validatedData.tag7 || null,
|
||||
}
|
||||
|
||||
if (validatedData.documentTagsData) {
|
||||
try {
|
||||
const tagData = JSON.parse(validatedData.documentTagsData)
|
||||
if (Array.isArray(tagData)) {
|
||||
// Process structured tag data and create tag definitions
|
||||
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to parse documentTagsData:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
const newDocument = {
|
||||
id: documentId,
|
||||
knowledgeBaseId,
|
||||
filename: validatedData.filename,
|
||||
fileUrl: validatedData.fileUrl,
|
||||
fileSize: validatedData.fileSize,
|
||||
mimeType: validatedData.mimeType,
|
||||
chunkCount: 0,
|
||||
tokenCount: 0,
|
||||
characterCount: 0,
|
||||
enabled: true,
|
||||
uploadedAt: now,
|
||||
...processedTags,
|
||||
}
|
||||
|
||||
await db.insert(document).values(newDocument)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Document created: ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
const newDocument = await createSingleDocument(validatedData, knowledgeBaseId, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
@@ -649,7 +255,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
}
|
||||
|
||||
export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id: string }> }) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
const { id: knowledgeBaseId } = await params
|
||||
|
||||
try {
|
||||
@@ -678,89 +284,28 @@ export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id
|
||||
const validatedData = BulkUpdateDocumentsSchema.parse(body)
|
||||
const { operation, documentIds } = validatedData
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Starting bulk ${operation} operation on ${documentIds.length} documents in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
// Verify all documents belong to this knowledge base and user has access
|
||||
const documentsToUpdate = await db
|
||||
.select({
|
||||
id: document.id,
|
||||
enabled: document.enabled,
|
||||
})
|
||||
.from(document)
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
inArray(document.id, documentIds),
|
||||
isNull(document.deletedAt)
|
||||
)
|
||||
)
|
||||
|
||||
if (documentsToUpdate.length === 0) {
|
||||
return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 })
|
||||
}
|
||||
|
||||
if (documentsToUpdate.length !== documentIds.length) {
|
||||
logger.warn(
|
||||
`[${requestId}] Some documents not found or don't belong to knowledge base. Requested: ${documentIds.length}, Found: ${documentsToUpdate.length}`
|
||||
)
|
||||
}
|
||||
|
||||
// Perform the bulk operation
|
||||
let updateResult: Array<{ id: string; enabled?: boolean; deletedAt?: Date | null }>
|
||||
let successCount: number
|
||||
|
||||
if (operation === 'delete') {
|
||||
// Handle bulk soft delete
|
||||
updateResult = await db
|
||||
.update(document)
|
||||
.set({
|
||||
deletedAt: new Date(),
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
inArray(document.id, documentIds),
|
||||
isNull(document.deletedAt)
|
||||
)
|
||||
)
|
||||
.returning({ id: document.id, deletedAt: document.deletedAt })
|
||||
|
||||
successCount = updateResult.length
|
||||
} else {
|
||||
// Handle bulk enable/disable
|
||||
const enabled = operation === 'enable'
|
||||
|
||||
updateResult = await db
|
||||
.update(document)
|
||||
.set({
|
||||
enabled,
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
inArray(document.id, documentIds),
|
||||
isNull(document.deletedAt)
|
||||
)
|
||||
)
|
||||
.returning({ id: document.id, enabled: document.enabled })
|
||||
|
||||
successCount = updateResult.length
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Bulk ${operation} operation completed: ${successCount} documents updated in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
try {
|
||||
const result = await bulkDocumentOperation(
|
||||
knowledgeBaseId,
|
||||
operation,
|
||||
successCount,
|
||||
updatedDocuments: updateResult,
|
||||
},
|
||||
})
|
||||
documentIds,
|
||||
requestId
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
operation,
|
||||
successCount: result.successCount,
|
||||
updatedDocuments: result.updatedDocuments,
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.message === 'No valid documents found to update') {
|
||||
return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
logger.warn(`[${requestId}] Invalid bulk operation data`, {
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getMaxSlotsForFieldType, getSlotsForFieldType } from '@/lib/constants/knowledge'
|
||||
import { getNextAvailableSlot, getTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('NextAvailableSlotAPI')
|
||||
|
||||
@@ -31,51 +28,36 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has read access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Get available slots for this field type
|
||||
const availableSlots = getSlotsForFieldType(fieldType)
|
||||
const maxSlots = getMaxSlotsForFieldType(fieldType)
|
||||
// Get existing definitions once and reuse
|
||||
const existingDefinitions = await getTagDefinitions(knowledgeBaseId)
|
||||
const usedSlots = existingDefinitions
|
||||
.filter((def) => def.fieldType === fieldType)
|
||||
.map((def) => def.tagSlot)
|
||||
|
||||
// Get existing tag definitions to find used slots for this field type
|
||||
const existingDefinitions = await db
|
||||
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
|
||||
)
|
||||
)
|
||||
|
||||
const usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot as string))
|
||||
|
||||
// Find the first available slot for this field type
|
||||
let nextAvailableSlot: string | null = null
|
||||
for (const slot of availableSlots) {
|
||||
if (!usedSlots.has(slot)) {
|
||||
nextAvailableSlot = slot
|
||||
break
|
||||
}
|
||||
}
|
||||
// Create a map for efficient lookup and pass to avoid redundant query
|
||||
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot as string, def]))
|
||||
const nextAvailableSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Next available slot for fieldType ${fieldType}: ${nextAvailableSlot}`
|
||||
)
|
||||
|
||||
const result = {
|
||||
nextAvailableSlot,
|
||||
fieldType,
|
||||
usedSlots,
|
||||
totalSlots: 7,
|
||||
availableSlots: nextAvailableSlot ? 7 - usedSlots.length : 0,
|
||||
}
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
nextAvailableSlot,
|
||||
fieldType,
|
||||
usedSlots: Array.from(usedSlots),
|
||||
totalSlots: maxSlots,
|
||||
availableSlots: maxSlots - usedSlots.size,
|
||||
},
|
||||
data: result,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error getting next available slot`, error)
|
||||
|
||||
@@ -16,9 +16,26 @@ mockKnowledgeSchemas()
|
||||
mockDrizzleOrm()
|
||||
mockConsoleLogger()
|
||||
|
||||
vi.mock('@/lib/knowledge/service', () => ({
|
||||
getKnowledgeBaseById: vi.fn(),
|
||||
updateKnowledgeBase: vi.fn(),
|
||||
deleteKnowledgeBase: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
checkKnowledgeBaseAccess: vi.fn(),
|
||||
checkKnowledgeBaseWriteAccess: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('Knowledge Base By ID API Route', () => {
|
||||
const mockAuth$ = mockAuth()
|
||||
|
||||
let mockGetKnowledgeBaseById: any
|
||||
let mockUpdateKnowledgeBase: any
|
||||
let mockDeleteKnowledgeBase: any
|
||||
let mockCheckKnowledgeBaseAccess: any
|
||||
let mockCheckKnowledgeBaseWriteAccess: any
|
||||
|
||||
const mockDbChain = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
@@ -62,6 +79,15 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
|
||||
})
|
||||
|
||||
const knowledgeService = await import('@/lib/knowledge/service')
|
||||
const knowledgeUtils = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetKnowledgeBaseById = knowledgeService.getKnowledgeBaseById as any
|
||||
mockUpdateKnowledgeBase = knowledgeService.updateKnowledgeBase as any
|
||||
mockDeleteKnowledgeBase = knowledgeService.deleteKnowledgeBase as any
|
||||
mockCheckKnowledgeBaseAccess = knowledgeUtils.checkKnowledgeBaseAccess as any
|
||||
mockCheckKnowledgeBaseWriteAccess = knowledgeUtils.checkKnowledgeBaseWriteAccess as any
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
@@ -74,9 +100,12 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should retrieve knowledge base successfully for authenticated user', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([mockKnowledgeBase])
|
||||
mockGetKnowledgeBaseById.mockResolvedValueOnce(mockKnowledgeBase)
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -87,7 +116,8 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.id).toBe('kb-123')
|
||||
expect(data.data.name).toBe('Test Knowledge Base')
|
||||
expect(mockDbChain.select).toHaveBeenCalled()
|
||||
expect(mockCheckKnowledgeBaseAccess).toHaveBeenCalledWith('kb-123', 'user-123')
|
||||
expect(mockGetKnowledgeBaseById).toHaveBeenCalledWith('kb-123')
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -105,7 +135,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should return not found for non-existent knowledge base', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([])
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
})
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -119,7 +152,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should return unauthorized for knowledge base owned by different user', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }])
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: false,
|
||||
})
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -130,9 +166,29 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
expect(data.error).toBe('Unauthorized')
|
||||
})
|
||||
|
||||
it('should return not found when service returns null', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockGetKnowledgeBaseById.mockResolvedValueOnce(null)
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
const response = await GET(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data.error).toBe('Knowledge base not found')
|
||||
})
|
||||
|
||||
it('should handle database errors', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
mockDbChain.limit.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
mockCheckKnowledgeBaseAccess.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -156,13 +212,13 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(undefined)
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ ...mockKnowledgeBase, ...validUpdateData }])
|
||||
const updatedKnowledgeBase = { ...mockKnowledgeBase, ...validUpdateData }
|
||||
mockUpdateKnowledgeBase.mockResolvedValueOnce(updatedKnowledgeBase)
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -172,7 +228,16 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.name).toBe('Updated Knowledge Base')
|
||||
expect(mockDbChain.update).toHaveBeenCalled()
|
||||
expect(mockCheckKnowledgeBaseWriteAccess).toHaveBeenCalledWith('kb-123', 'user-123')
|
||||
expect(mockUpdateKnowledgeBase).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
{
|
||||
name: validUpdateData.name,
|
||||
description: validUpdateData.description,
|
||||
chunkingConfig: undefined,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -192,8 +257,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
})
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -209,8 +276,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
const invalidData = {
|
||||
name: '',
|
||||
@@ -229,9 +298,13 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should handle database errors during update', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
// Mock successful write access check
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.where.mockRejectedValueOnce(new Error('Database error'))
|
||||
mockUpdateKnowledgeBase.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -251,10 +324,12 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(undefined)
|
||||
mockDeleteKnowledgeBase.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -264,7 +339,8 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.message).toBe('Knowledge base deleted successfully')
|
||||
expect(mockDbChain.update).toHaveBeenCalled()
|
||||
expect(mockCheckKnowledgeBaseWriteAccess).toHaveBeenCalledWith('kb-123', 'user-123')
|
||||
expect(mockDeleteKnowledgeBase).toHaveBeenCalledWith('kb-123', expect.any(String))
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -284,8 +360,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
})
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -301,8 +379,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: false,
|
||||
})
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -316,9 +396,12 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should handle database errors during delete', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.where.mockRejectedValueOnce(new Error('Database error'))
|
||||
mockDeleteKnowledgeBase.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import { and, eq, isNull } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
deleteKnowledgeBase,
|
||||
getKnowledgeBaseById,
|
||||
updateKnowledgeBase,
|
||||
} from '@/lib/knowledge/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { knowledgeBase } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeBaseByIdAPI')
|
||||
|
||||
@@ -48,13 +50,9 @@ export async function GET(_req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const knowledgeBases = await db
|
||||
.select()
|
||||
.from(knowledgeBase)
|
||||
.where(and(eq(knowledgeBase.id, id), isNull(knowledgeBase.deletedAt)))
|
||||
.limit(1)
|
||||
const knowledgeBaseData = await getKnowledgeBaseById(id)
|
||||
|
||||
if (knowledgeBases.length === 0) {
|
||||
if (!knowledgeBaseData) {
|
||||
return NextResponse.json({ error: 'Knowledge base not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
@@ -62,7 +60,7 @@ export async function GET(_req: NextRequest, { params }: { params: Promise<{ id:
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: knowledgeBases[0],
|
||||
data: knowledgeBaseData,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error fetching knowledge base`, error)
|
||||
@@ -99,42 +97,21 @@ export async function PUT(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
try {
|
||||
const validatedData = UpdateKnowledgeBaseSchema.parse(body)
|
||||
|
||||
const updateData: any = {
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
if (validatedData.name !== undefined) updateData.name = validatedData.name
|
||||
if (validatedData.description !== undefined)
|
||||
updateData.description = validatedData.description
|
||||
if (validatedData.workspaceId !== undefined)
|
||||
updateData.workspaceId = validatedData.workspaceId
|
||||
|
||||
// Handle embedding model and dimension together to ensure consistency
|
||||
if (
|
||||
validatedData.embeddingModel !== undefined ||
|
||||
validatedData.embeddingDimension !== undefined
|
||||
) {
|
||||
updateData.embeddingModel = 'text-embedding-3-small'
|
||||
updateData.embeddingDimension = 1536
|
||||
}
|
||||
|
||||
if (validatedData.chunkingConfig !== undefined)
|
||||
updateData.chunkingConfig = validatedData.chunkingConfig
|
||||
|
||||
await db.update(knowledgeBase).set(updateData).where(eq(knowledgeBase.id, id))
|
||||
|
||||
// Fetch the updated knowledge base
|
||||
const updatedKnowledgeBase = await db
|
||||
.select()
|
||||
.from(knowledgeBase)
|
||||
.where(eq(knowledgeBase.id, id))
|
||||
.limit(1)
|
||||
const updatedKnowledgeBase = await updateKnowledgeBase(
|
||||
id,
|
||||
{
|
||||
name: validatedData.name,
|
||||
description: validatedData.description,
|
||||
chunkingConfig: validatedData.chunkingConfig,
|
||||
},
|
||||
requestId
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Knowledge base updated: ${id} for user ${session.user.id}`)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: updatedKnowledgeBase[0],
|
||||
data: updatedKnowledgeBase,
|
||||
})
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
@@ -178,14 +155,7 @@ export async function DELETE(_req: NextRequest, { params }: { params: Promise<{
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Soft delete by setting deletedAt timestamp
|
||||
await db
|
||||
.update(knowledgeBase)
|
||||
.set({
|
||||
deletedAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(eq(knowledgeBase.id, id))
|
||||
await deleteKnowledgeBase(id, requestId)
|
||||
|
||||
logger.info(`[${requestId}] Knowledge base deleted: ${id} for user ${session.user.id}`)
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq, isNotNull } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { deleteTagDefinition } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
@@ -29,87 +27,16 @@ export async function DELETE(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Get the tag definition to find which slot it uses
|
||||
const tagDefinition = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.id, tagId),
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (tagDefinition.length === 0) {
|
||||
return NextResponse.json({ error: 'Tag definition not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
const tagDef = tagDefinition[0]
|
||||
|
||||
// Delete the tag definition and clear all document tags in a transaction
|
||||
await db.transaction(async (tx) => {
|
||||
logger.info(`[${requestId}] Starting transaction to delete ${tagDef.tagSlot}`)
|
||||
|
||||
try {
|
||||
// Clear the tag from documents that actually have this tag set
|
||||
logger.info(`[${requestId}] Clearing tag from documents...`)
|
||||
await tx
|
||||
.update(document)
|
||||
.set({ [tagDef.tagSlot]: null })
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
isNotNull(document[tagDef.tagSlot as keyof typeof document.$inferSelect])
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Documents updated successfully`)
|
||||
|
||||
// Clear the tag from embeddings that actually have this tag set
|
||||
logger.info(`[${requestId}] Clearing tag from embeddings...`)
|
||||
await tx
|
||||
.update(embedding)
|
||||
.set({ [tagDef.tagSlot]: null })
|
||||
.where(
|
||||
and(
|
||||
eq(embedding.knowledgeBaseId, knowledgeBaseId),
|
||||
isNotNull(embedding[tagDef.tagSlot as keyof typeof embedding.$inferSelect])
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Embeddings updated successfully`)
|
||||
|
||||
// Delete the tag definition
|
||||
logger.info(`[${requestId}] Deleting tag definition...`)
|
||||
await tx
|
||||
.delete(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, tagId))
|
||||
|
||||
logger.info(`[${requestId}] Tag definition deleted successfully`)
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error in transaction:`, error)
|
||||
throw error
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Successfully deleted tag definition ${tagDef.displayName} (${tagDef.tagSlot})`
|
||||
)
|
||||
const deletedTag = await deleteTagDefinition(tagId, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
message: `Tag definition "${tagDef.displayName}" deleted successfully`,
|
||||
message: `Tag definition "${deletedTag.displayName}" deleted successfully`,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error deleting tag definition`, error)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { SUPPORTED_FIELD_TYPES } from '@/lib/constants/knowledge'
|
||||
import { createTagDefinition, getTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
@@ -24,25 +24,12 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Get tag definitions for the knowledge base
|
||||
const tagDefinitions = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
fieldType: knowledgeBaseTagDefinitions.fieldType,
|
||||
createdAt: knowledgeBaseTagDefinitions.createdAt,
|
||||
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
.orderBy(knowledgeBaseTagDefinitions.tagSlot)
|
||||
const tagDefinitions = await getTagDefinitions(knowledgeBaseId)
|
||||
|
||||
logger.info(`[${requestId}] Retrieved ${tagDefinitions.length} tag definitions`)
|
||||
|
||||
@@ -69,68 +56,43 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
const { tagSlot, displayName, fieldType } = body
|
||||
|
||||
if (!tagSlot || !displayName || !fieldType) {
|
||||
return NextResponse.json(
|
||||
{ error: 'tagSlot, displayName, and fieldType are required' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
const CreateTagDefinitionSchema = z.object({
|
||||
tagSlot: z.string().min(1, 'Tag slot is required'),
|
||||
displayName: z.string().min(1, 'Display name is required'),
|
||||
fieldType: z.enum(SUPPORTED_FIELD_TYPES as [string, ...string[]], {
|
||||
errorMap: () => ({ message: 'Invalid field type' }),
|
||||
}),
|
||||
})
|
||||
|
||||
// Check if tag slot is already used
|
||||
const existingTag = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.tagSlot, tagSlot)
|
||||
let validatedData
|
||||
try {
|
||||
validatedData = CreateTagDefinitionSchema.parse(body)
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Invalid request data', details: error.errors },
|
||||
{ status: 400 }
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (existingTag.length > 0) {
|
||||
return NextResponse.json({ error: 'Tag slot is already in use' }, { status: 409 })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
|
||||
// Check if display name is already used
|
||||
const existingName = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.displayName, displayName)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (existingName.length > 0) {
|
||||
return NextResponse.json({ error: 'Tag name is already in use' }, { status: 409 })
|
||||
}
|
||||
|
||||
// Create the new tag definition
|
||||
const newTagDefinition = {
|
||||
id: randomUUID(),
|
||||
knowledgeBaseId,
|
||||
tagSlot,
|
||||
displayName,
|
||||
fieldType,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
await db.insert(knowledgeBaseTagDefinitions).values(newTagDefinition)
|
||||
|
||||
logger.info(`[${requestId}] Successfully created tag definition ${displayName} (${tagSlot})`)
|
||||
const newTagDefinition = await createTagDefinition(
|
||||
{
|
||||
knowledgeBaseId,
|
||||
tagSlot: validatedData.tagSlot,
|
||||
displayName: validatedData.displayName,
|
||||
fieldType: validatedData.fieldType,
|
||||
},
|
||||
requestId
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq, isNotNull } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getTagUsage } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
@@ -24,57 +22,15 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Get all tag definitions for the knowledge base
|
||||
const tagDefinitions = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
|
||||
// Get usage statistics for each tag definition
|
||||
const usageStats = await Promise.all(
|
||||
tagDefinitions.map(async (tagDef) => {
|
||||
// Count documents using this tag slot
|
||||
const tagSlotColumn = tagDef.tagSlot as keyof typeof document.$inferSelect
|
||||
|
||||
const documentsWithTag = await db
|
||||
.select({
|
||||
id: document.id,
|
||||
filename: document.filename,
|
||||
[tagDef.tagSlot]: document[tagSlotColumn as keyof typeof document.$inferSelect] as any,
|
||||
})
|
||||
.from(document)
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
isNotNull(document[tagSlotColumn as keyof typeof document.$inferSelect])
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
tagName: tagDef.displayName,
|
||||
tagSlot: tagDef.tagSlot,
|
||||
documentCount: documentsWithTag.length,
|
||||
documents: documentsWithTag.map((doc) => ({
|
||||
id: doc.id,
|
||||
name: doc.filename,
|
||||
tagValue: doc[tagDef.tagSlot],
|
||||
})),
|
||||
}
|
||||
})
|
||||
)
|
||||
const usageStats = await getTagUsage(knowledgeBaseId, requestId)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Retrieved usage statistics for ${tagDefinitions.length} tag definitions`
|
||||
`[${requestId}] Retrieved usage statistics for ${usageStats.length} tag definitions`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { and, count, eq, isNotNull, isNull, or } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createKnowledgeBase, getKnowledgeBases } from '@/lib/knowledge/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserEntityPermissions } from '@/lib/permissions/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBase, permissions } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeBaseAPI')
|
||||
|
||||
@@ -41,60 +38,10 @@ export async function GET(req: NextRequest) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check for workspace filtering
|
||||
const { searchParams } = new URL(req.url)
|
||||
const workspaceId = searchParams.get('workspaceId')
|
||||
|
||||
// Get knowledge bases that user can access through direct ownership OR workspace permissions
|
||||
const knowledgeBasesWithCounts = await db
|
||||
.select({
|
||||
id: knowledgeBase.id,
|
||||
name: knowledgeBase.name,
|
||||
description: knowledgeBase.description,
|
||||
tokenCount: knowledgeBase.tokenCount,
|
||||
embeddingModel: knowledgeBase.embeddingModel,
|
||||
embeddingDimension: knowledgeBase.embeddingDimension,
|
||||
chunkingConfig: knowledgeBase.chunkingConfig,
|
||||
createdAt: knowledgeBase.createdAt,
|
||||
updatedAt: knowledgeBase.updatedAt,
|
||||
workspaceId: knowledgeBase.workspaceId,
|
||||
docCount: count(document.id),
|
||||
})
|
||||
.from(knowledgeBase)
|
||||
.leftJoin(
|
||||
document,
|
||||
and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt))
|
||||
)
|
||||
.leftJoin(
|
||||
permissions,
|
||||
and(
|
||||
eq(permissions.entityType, 'workspace'),
|
||||
eq(permissions.entityId, knowledgeBase.workspaceId),
|
||||
eq(permissions.userId, session.user.id)
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
isNull(knowledgeBase.deletedAt),
|
||||
workspaceId
|
||||
? // When filtering by workspace
|
||||
or(
|
||||
// Knowledge bases belonging to the specified workspace (user must have workspace permissions)
|
||||
and(eq(knowledgeBase.workspaceId, workspaceId), isNotNull(permissions.userId)),
|
||||
// Fallback: User-owned knowledge bases without workspace (legacy)
|
||||
and(eq(knowledgeBase.userId, session.user.id), isNull(knowledgeBase.workspaceId))
|
||||
)
|
||||
: // When not filtering by workspace, use original logic
|
||||
or(
|
||||
// User owns the knowledge base directly
|
||||
eq(knowledgeBase.userId, session.user.id),
|
||||
// User has permissions on the knowledge base's workspace
|
||||
isNotNull(permissions.userId)
|
||||
)
|
||||
)
|
||||
)
|
||||
.groupBy(knowledgeBase.id)
|
||||
.orderBy(knowledgeBase.createdAt)
|
||||
const knowledgeBasesWithCounts = await getKnowledgeBases(session.user.id, workspaceId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
@@ -121,49 +68,16 @@ export async function POST(req: NextRequest) {
|
||||
try {
|
||||
const validatedData = CreateKnowledgeBaseSchema.parse(body)
|
||||
|
||||
// If creating in a workspace, check if user has write/admin permissions
|
||||
if (validatedData.workspaceId) {
|
||||
const userPermission = await getUserEntityPermissions(
|
||||
session.user.id,
|
||||
'workspace',
|
||||
validatedData.workspaceId
|
||||
)
|
||||
if (userPermission !== 'write' && userPermission !== 'admin') {
|
||||
logger.warn(
|
||||
`[${requestId}] User ${session.user.id} denied permission to create knowledge base in workspace ${validatedData.workspaceId}`
|
||||
)
|
||||
return NextResponse.json(
|
||||
{ error: 'Insufficient permissions to create knowledge base in this workspace' },
|
||||
{ status: 403 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const id = crypto.randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
const newKnowledgeBase = {
|
||||
id,
|
||||
const createData = {
|
||||
...validatedData,
|
||||
userId: session.user.id,
|
||||
workspaceId: validatedData.workspaceId || null,
|
||||
name: validatedData.name,
|
||||
description: validatedData.description || null,
|
||||
tokenCount: 0,
|
||||
embeddingModel: validatedData.embeddingModel,
|
||||
embeddingDimension: validatedData.embeddingDimension,
|
||||
chunkingConfig: validatedData.chunkingConfig || {
|
||||
maxSize: 1024,
|
||||
minSize: 100,
|
||||
overlap: 200,
|
||||
},
|
||||
docCount: 0,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
await db.insert(knowledgeBase).values(newKnowledgeBase)
|
||||
const newKnowledgeBase = await createKnowledgeBase(createData, requestId)
|
||||
|
||||
logger.info(`[${requestId}] Knowledge base created: ${id} for user ${session.user.id}`)
|
||||
logger.info(
|
||||
`[${requestId}] Knowledge base created: ${newKnowledgeBase.id} for user ${session.user.id}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { TAG_SLOTS } from '@/lib/constants/knowledge'
|
||||
import { getDocumentTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { estimateTokenCount } from '@/lib/tokenization/estimators'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
import { calculateCost } from '@/providers/utils'
|
||||
import {
|
||||
generateSearchEmbedding,
|
||||
@@ -94,13 +92,7 @@ export async function POST(request: NextRequest) {
|
||||
try {
|
||||
// Fetch tag definitions for the first accessible KB (since we're using single KB now)
|
||||
const kbId = accessibleKbIds[0]
|
||||
const tagDefs = await db
|
||||
.select({
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, kbId))
|
||||
const tagDefs = await getDocumentTagDefinitions(kbId)
|
||||
|
||||
logger.debug(`[${requestId}] Found tag definitions:`, tagDefs)
|
||||
logger.debug(`[${requestId}] Original filters:`, validatedData.filters)
|
||||
@@ -224,13 +216,7 @@ export async function POST(request: NextRequest) {
|
||||
const tagDefinitionsMap: Record<string, Record<string, string>> = {}
|
||||
for (const kbId of accessibleKbIds) {
|
||||
try {
|
||||
const tagDefs = await db
|
||||
.select({
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, kbId))
|
||||
const tagDefs = await getDocumentTagDefinitions(kbId)
|
||||
|
||||
tagDefinitionsMap[kbId] = {}
|
||||
tagDefs.forEach((def) => {
|
||||
|
||||
@@ -16,7 +16,7 @@ vi.mock('@/lib/logs/console/logger', () => ({
|
||||
})),
|
||||
}))
|
||||
vi.mock('@/db')
|
||||
vi.mock('@/lib/documents/utils', () => ({
|
||||
vi.mock('@/lib/knowledge/documents/utils', () => ({
|
||||
retryWithExponentialBackoff: (fn: any) => fn(),
|
||||
}))
|
||||
|
||||
|
||||
@@ -21,11 +21,11 @@ vi.mock('@/lib/env', () => ({
|
||||
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/documents/utils', () => ({
|
||||
vi.mock('@/lib/knowledge/documents/utils', () => ({
|
||||
retryWithExponentialBackoff: (fn: any) => fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/documents/document-processor', () => ({
|
||||
vi.mock('@/lib/knowledge/documents/document-processor', () => ({
|
||||
processDocument: vi.fn().mockResolvedValue({
|
||||
chunks: [
|
||||
{
|
||||
@@ -149,12 +149,12 @@ vi.mock('@/db', () => {
|
||||
}
|
||||
})
|
||||
|
||||
import { generateEmbeddings } from '@/lib/embeddings/utils'
|
||||
import { processDocumentAsync } from '@/lib/knowledge/documents/service'
|
||||
import {
|
||||
checkChunkAccess,
|
||||
checkDocumentAccess,
|
||||
checkKnowledgeBaseAccess,
|
||||
generateEmbeddings,
|
||||
processDocumentAsync,
|
||||
} from '@/app/api/knowledge/utils'
|
||||
|
||||
describe('Knowledge Utils', () => {
|
||||
|
||||
@@ -1,35 +1,8 @@
|
||||
import crypto from 'crypto'
|
||||
import { and, eq, isNull } from 'drizzle-orm'
|
||||
import { processDocument } from '@/lib/documents/document-processor'
|
||||
import { generateEmbeddings } from '@/lib/embeddings/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserEntityPermissions } from '@/lib/permissions/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding, knowledgeBase } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeUtils')

const TIMEOUTS = {
  OVERALL_PROCESSING: 150000, // 150 seconds (2.5 minutes)
  EMBEDDINGS_API: 60000, // 60 seconds per batch
} as const

/**
 * Create a timeout wrapper for async operations
 */
function withTimeout<T>(
  promise: Promise<T>,
  timeoutMs: number,
  operation = 'Operation'
): Promise<T> {
  return Promise.race([
    promise,
    new Promise<never>((_, reject) =>
      setTimeout(() => reject(new Error(`${operation} timed out after ${timeoutMs}ms`)), timeoutMs)
    ),
  ])
}
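A minimal usage sketch of the withTimeout helper above; the wrapped call and label are illustrative, not part of this change:

// Hypothetical usage: bound an embeddings batch call by the per-batch timeout.
const batchEmbeddings = await withTimeout(
  generateEmbeddings(['chunk one', 'chunk two']),
  TIMEOUTS.EMBEDDINGS_API,
  'Embeddings API'
)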
|
||||
|
||||
export interface KnowledgeBaseData {
|
||||
id: string
|
||||
userId: string
|
||||
@@ -380,154 +353,3 @@ export async function checkChunkAccess(
|
||||
knowledgeBase: kbAccess.knowledgeBase!,
|
||||
}
|
||||
}
|
||||
|
||||
// Export for external use
|
||||
export { generateEmbeddings }
|
||||
|
||||
/**
|
||||
* Process a document asynchronously with full error handling
|
||||
*/
|
||||
export async function processDocumentAsync(
|
||||
knowledgeBaseId: string,
|
||||
documentId: string,
|
||||
docData: {
|
||||
filename: string
|
||||
fileUrl: string
|
||||
fileSize: number
|
||||
mimeType: string
|
||||
},
|
||||
processingOptions: {
|
||||
chunkSize?: number
|
||||
minCharactersPerChunk?: number
|
||||
recipe?: string
|
||||
lang?: string
|
||||
chunkOverlap?: number
|
||||
}
|
||||
): Promise<void> {
|
||||
const startTime = Date.now()
|
||||
try {
|
||||
logger.info(`[${documentId}] Starting document processing: ${docData.filename}`)
|
||||
|
||||
// Set status to processing
|
||||
await db
|
||||
.update(document)
|
||||
.set({
|
||||
processingStatus: 'processing',
|
||||
processingStartedAt: new Date(),
|
||||
processingError: null, // Clear any previous error
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
|
||||
logger.info(`[${documentId}] Status updated to 'processing', starting document processor`)
|
||||
|
||||
// Wrap the entire processing operation with the overall processing timeout (150 seconds / 2.5 minutes)
|
||||
await withTimeout(
|
||||
(async () => {
|
||||
const processed = await processDocument(
|
||||
docData.fileUrl,
|
||||
docData.filename,
|
||||
docData.mimeType,
|
||||
processingOptions.chunkSize || 1000,
|
||||
processingOptions.chunkOverlap || 200,
|
||||
processingOptions.minCharactersPerChunk || 1
|
||||
)
|
||||
|
||||
const now = new Date()
|
||||
|
||||
logger.info(
|
||||
`[${documentId}] Document parsed successfully, generating embeddings for ${processed.chunks.length} chunks`
|
||||
)
|
||||
|
||||
const chunkTexts = processed.chunks.map((chunk) => chunk.text)
|
||||
const embeddings = chunkTexts.length > 0 ? await generateEmbeddings(chunkTexts) : []
|
||||
|
||||
logger.info(`[${documentId}] Embeddings generated, fetching document tags`)
|
||||
|
||||
// Fetch document to get tags
|
||||
const documentRecord = await db
|
||||
.select({
|
||||
tag1: document.tag1,
|
||||
tag2: document.tag2,
|
||||
tag3: document.tag3,
|
||||
tag4: document.tag4,
|
||||
tag5: document.tag5,
|
||||
tag6: document.tag6,
|
||||
tag7: document.tag7,
|
||||
})
|
||||
.from(document)
|
||||
.where(eq(document.id, documentId))
|
||||
.limit(1)
|
||||
|
||||
const documentTags = documentRecord[0] || {}
|
||||
|
||||
logger.info(`[${documentId}] Creating embedding records with tags`)
|
||||
|
||||
const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({
|
||||
id: crypto.randomUUID(),
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
chunkIndex,
|
||||
chunkHash: crypto.createHash('sha256').update(chunk.text).digest('hex'),
|
||||
content: chunk.text,
|
||||
contentLength: chunk.text.length,
|
||||
tokenCount: Math.ceil(chunk.text.length / 4),
|
||||
embedding: embeddings[chunkIndex] || null,
|
||||
embeddingModel: 'text-embedding-3-small',
|
||||
startOffset: chunk.metadata.startIndex,
|
||||
endOffset: chunk.metadata.endIndex,
|
||||
// Copy tags from document
|
||||
tag1: documentTags.tag1,
|
||||
tag2: documentTags.tag2,
|
||||
tag3: documentTags.tag3,
|
||||
tag4: documentTags.tag4,
|
||||
tag5: documentTags.tag5,
|
||||
tag6: documentTags.tag6,
|
||||
tag7: documentTags.tag7,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}))
|
||||
|
||||
await db.transaction(async (tx) => {
|
||||
if (embeddingRecords.length > 0) {
|
||||
await tx.insert(embedding).values(embeddingRecords)
|
||||
}
|
||||
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: processed.metadata.chunkCount,
|
||||
tokenCount: processed.metadata.tokenCount,
|
||||
characterCount: processed.metadata.characterCount,
|
||||
processingStatus: 'completed',
|
||||
processingCompletedAt: now,
|
||||
processingError: null,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
})
|
||||
})(),
|
||||
TIMEOUTS.OVERALL_PROCESSING,
|
||||
'Document processing'
|
||||
)
|
||||
|
||||
const processingTime = Date.now() - startTime
|
||||
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)
|
||||
} catch (error) {
|
||||
const processingTime = Date.now() - startTime
|
||||
logger.error(`[${documentId}] Failed to process document after ${processingTime}ms:`, {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
filename: docData.filename,
|
||||
fileUrl: docData.fileUrl,
|
||||
mimeType: docData.mimeType,
|
||||
})
|
||||
|
||||
await db
|
||||
.update(document)
|
||||
.set({
|
||||
processingStatus: 'failed',
|
||||
processingError: error instanceof Error ? error.message : 'Unknown error',
|
||||
processingCompletedAt: new Date(),
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
}
|
||||
}
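For reference, a hedged sketch of how processDocumentAsync above would be invoked; the IDs, URL, and sizes are invented for illustration:

// Illustrative call only; the option names mirror the parameters defined above.
await processDocumentAsync(
  'kb_123',
  'doc_456',
  {
    filename: 'handbook.pdf',
    fileUrl: 'https://storage.example.com/kb/handbook.pdf',
    fileSize: 1_048_576,
    mimeType: 'application/pdf',
  },
  { chunkSize: 1000, chunkOverlap: 200, minCharactersPerChunk: 1 }
)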
|
||||
|
||||
@@ -64,7 +64,9 @@ export async function POST(request: Request) {
|
||||
|
||||
return new NextResponse(
|
||||
`Internal Server Error: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
{ status: 500 }
|
||||
{
|
||||
status: 500,
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,7 +112,9 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
return new Response(
|
||||
`Internal Server Error: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
{ status: 500 }
|
||||
{
|
||||
status: 500,
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -495,7 +495,9 @@ async function createAirtableWebhookSubscription(
|
||||
} else {
|
||||
logger.info(
|
||||
`[${requestId}] Successfully created webhook in Airtable for webhook ${webhookData.id}.`,
|
||||
{ airtableWebhookId: responseBody.id }
|
||||
{
|
||||
airtableWebhookId: responseBody.id,
|
||||
}
|
||||
)
|
||||
// Store the airtableWebhookId (responseBody.id) within the providerConfig
|
||||
try {
|
||||
|
||||
@@ -4,8 +4,10 @@ import { useCallback, useEffect, useState } from 'react'
|
||||
import { format } from 'date-fns'
|
||||
import {
|
||||
AlertCircle,
|
||||
ChevronDown,
|
||||
ChevronLeft,
|
||||
ChevronRight,
|
||||
ChevronUp,
|
||||
Circle,
|
||||
CircleOff,
|
||||
FileText,
|
||||
@@ -29,6 +31,7 @@ import { Button } from '@/components/ui/button'
|
||||
import { Checkbox } from '@/components/ui/checkbox'
|
||||
import { SearchHighlight } from '@/components/ui/search-highlight'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'
|
||||
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import {
|
||||
ActionBar,
|
||||
@@ -47,7 +50,6 @@ import { type DocumentData, useKnowledgeStore } from '@/stores/knowledge/store'
|
||||
|
||||
const logger = createLogger('KnowledgeBase')
|
||||
|
||||
// Constants
|
||||
const DOCUMENTS_PER_PAGE = 50
|
||||
|
||||
interface KnowledgeBaseProps {
|
||||
@@ -143,6 +145,8 @@ export function KnowledgeBase({
|
||||
const [isDeleting, setIsDeleting] = useState(false)
|
||||
const [isBulkOperating, setIsBulkOperating] = useState(false)
|
||||
const [currentPage, setCurrentPage] = useState(1)
|
||||
const [sortBy, setSortBy] = useState<DocumentSortField>('uploadedAt')
|
||||
const [sortOrder, setSortOrder] = useState<SortOrder>('desc')
|
||||
|
||||
const {
|
||||
knowledgeBase,
|
||||
@@ -160,6 +164,8 @@ export function KnowledgeBase({
|
||||
search: searchQuery || undefined,
|
||||
limit: DOCUMENTS_PER_PAGE,
|
||||
offset: (currentPage - 1) * DOCUMENTS_PER_PAGE,
|
||||
sortBy,
|
||||
sortOrder,
|
||||
})
|
||||
|
||||
const router = useRouter()
|
||||
@@ -194,6 +200,41 @@ export function KnowledgeBase({
|
||||
}
|
||||
}, [hasPrevPage])
|
||||
|
||||
const handleSort = useCallback(
|
||||
(field: DocumentSortField) => {
|
||||
if (sortBy === field) {
|
||||
// Toggle sort order if same field
|
||||
setSortOrder(sortOrder === 'asc' ? 'desc' : 'asc')
|
||||
} else {
|
||||
// Set new field with default desc order
|
||||
setSortBy(field)
|
||||
setSortOrder('desc')
|
||||
}
|
||||
// Reset to first page when sorting changes
|
||||
setCurrentPage(1)
|
||||
},
|
||||
[sortBy, sortOrder]
|
||||
)
|
||||
|
||||
// Helper function to render sortable header
|
||||
const renderSortableHeader = (field: DocumentSortField, label: string, className = '') => (
|
||||
<th className={`px-4 pt-2 pb-3 text-left font-medium ${className}`}>
|
||||
<button
|
||||
type='button'
|
||||
onClick={() => handleSort(field)}
|
||||
className='flex items-center gap-1 text-muted-foreground text-xs leading-none transition-colors hover:text-foreground'
|
||||
>
|
||||
<span>{label}</span>
|
||||
{sortBy === field &&
|
||||
(sortOrder === 'asc' ? (
|
||||
<ChevronUp className='h-3 w-3' />
|
||||
) : (
|
||||
<ChevronDown className='h-3 w-3' />
|
||||
))}
|
||||
</button>
|
||||
</th>
|
||||
)
|
||||
|
||||
// Auto-refresh documents when there are processing documents
|
||||
useEffect(() => {
|
||||
const hasProcessingDocuments = documents.some(
|
||||
@@ -677,6 +718,7 @@ export function KnowledgeBase({
|
||||
value={searchQuery}
|
||||
onChange={handleSearchChange}
|
||||
placeholder='Search documents...'
|
||||
isLoading={isLoadingDocuments}
|
||||
/>
|
||||
|
||||
<div className='flex items-center gap-3'>
|
||||
@@ -732,26 +774,12 @@ export function KnowledgeBase({
|
||||
className='h-3.5 w-3.5 border-gray-300 focus-visible:ring-[var(--brand-primary-hex)]/20 data-[state=checked]:border-[var(--brand-primary-hex)] data-[state=checked]:bg-[var(--brand-primary-hex)] [&>*]:h-3 [&>*]:w-3'
|
||||
/>
|
||||
</th>
|
||||
<th className='px-4 pt-2 pb-3 text-left font-medium'>
|
||||
<span className='text-muted-foreground text-xs leading-none'>Name</span>
|
||||
</th>
|
||||
<th className='px-4 pt-2 pb-3 text-left font-medium'>
|
||||
<span className='text-muted-foreground text-xs leading-none'>Size</span>
|
||||
</th>
|
||||
<th className='px-4 pt-2 pb-3 text-left font-medium'>
|
||||
<span className='text-muted-foreground text-xs leading-none'>Tokens</span>
|
||||
</th>
|
||||
<th className='hidden px-4 pt-2 pb-3 text-left font-medium lg:table-cell'>
|
||||
<span className='text-muted-foreground text-xs leading-none'>Chunks</span>
|
||||
</th>
|
||||
<th className='px-4 pt-2 pb-3 text-left font-medium'>
|
||||
<span className='text-muted-foreground text-xs leading-none'>
|
||||
Uploaded
|
||||
</span>
|
||||
</th>
|
||||
<th className='px-4 pt-2 pb-3 text-left font-medium'>
|
||||
<span className='text-muted-foreground text-xs leading-none'>Status</span>
|
||||
</th>
|
||||
{renderSortableHeader('filename', 'Name')}
|
||||
{renderSortableHeader('fileSize', 'Size')}
|
||||
{renderSortableHeader('tokenCount', 'Tokens')}
|
||||
{renderSortableHeader('chunkCount', 'Chunks', 'hidden lg:table-cell')}
|
||||
{renderSortableHeader('uploadedAt', 'Uploaded')}
|
||||
{renderSortableHeader('processingStatus', 'Status')}
|
||||
<th className='px-4 pt-2 pb-3 text-left font-medium'>
|
||||
<span className='text-muted-foreground text-xs leading-none'>
|
||||
Actions
|
||||
@@ -865,11 +893,7 @@ export function KnowledgeBase({
|
||||
key={doc.id}
|
||||
className={`border-b transition-colors hover:bg-accent/30 ${
|
||||
isSelected ? 'bg-accent/30' : ''
|
||||
} ${
|
||||
doc.processingStatus === 'completed'
|
||||
? 'cursor-pointer'
|
||||
: 'cursor-default'
|
||||
}`}
|
||||
} ${doc.processingStatus === 'completed' ? 'cursor-pointer' : 'cursor-default'}`}
|
||||
onClick={() => {
|
||||
if (doc.processingStatus === 'completed') {
|
||||
handleDocumentClick(doc.id)
|
||||
|
||||
@@ -166,12 +166,6 @@ export function UploadModal({
|
||||
return `${Number.parseFloat((bytes / k ** i).toFixed(1))} ${sizes[i]}`
|
||||
}
|
||||
|
||||
// Calculate progress percentage
|
||||
const progressPercentage =
|
||||
uploadProgress.totalFiles > 0
|
||||
? Math.round((uploadProgress.filesCompleted / uploadProgress.totalFiles) * 100)
|
||||
: 0
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={handleClose}>
|
||||
<DialogContent className='flex max-h-[95vh] max-w-2xl flex-col overflow-hidden'>
|
||||
@@ -296,23 +290,26 @@ export function UploadModal({
|
||||
</div>
|
||||
|
||||
{/* Footer */}
|
||||
<div className='flex justify-end gap-3 border-t pt-4'>
|
||||
<Button variant='outline' onClick={handleClose} disabled={isUploading}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleUpload}
|
||||
disabled={files.length === 0 || isUploading}
|
||||
className='bg-[var(--brand-primary-hex)] font-[480] text-primary-foreground shadow-[0_0_0_0_var(--brand-primary-hex)] transition-all duration-200 hover:bg-[var(--brand-primary-hover-hex)] hover:shadow-[0_0_0_4px_rgba(127,47,255,0.15)]'
|
||||
>
|
||||
{isUploading
|
||||
? uploadProgress.stage === 'uploading'
|
||||
? `Uploading ${uploadProgress.filesCompleted + 1}/${uploadProgress.totalFiles}...`
|
||||
: uploadProgress.stage === 'processing'
|
||||
? 'Processing...'
|
||||
: 'Uploading...'
|
||||
: `Upload ${files.length} file${files.length !== 1 ? 's' : ''}`}
|
||||
</Button>
|
||||
<div className='flex justify-between border-t pt-4'>
|
||||
<div className='flex gap-3' />
|
||||
<div className='flex gap-3'>
|
||||
<Button variant='outline' onClick={handleClose} disabled={isUploading}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleUpload}
|
||||
disabled={files.length === 0 || isUploading}
|
||||
className='bg-[var(--brand-primary-hex)] font-[480] text-primary-foreground shadow-[0_0_0_0_var(--brand-primary-hex)] transition-all duration-200 hover:bg-[var(--brand-primary-hover-hex)] hover:shadow-[0_0_0_4px_rgba(127,47,255,0.15)]'
|
||||
>
|
||||
{isUploading
|
||||
? uploadProgress.stage === 'uploading'
|
||||
? `Uploading ${uploadProgress.filesCompleted + 1}/${uploadProgress.totalFiles}...`
|
||||
: uploadProgress.stage === 'processing'
|
||||
? 'Processing...'
|
||||
: 'Uploading...'
|
||||
: `Upload ${files.length} file${files.length !== 1 ? 's' : ''}`}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { zodResolver } from '@hookform/resolvers/zod'
|
||||
import { AlertCircle, X } from 'lucide-react'
|
||||
import { AlertCircle, Check, Loader2, X } from 'lucide-react'
|
||||
import { useParams } from 'next/navigation'
|
||||
import { useForm } from 'react-hook-form'
|
||||
import { z } from 'zod'
|
||||
@@ -11,6 +11,7 @@ import { Button } from '@/components/ui/button'
|
||||
import { Dialog, DialogContent, DialogHeader, DialogTitle } from '@/components/ui/dialog'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { Textarea } from '@/components/ui/textarea'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getDocumentIcon } from '@/app/workspace/[workspaceId]/knowledge/components'
|
||||
@@ -88,9 +89,10 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
const scrollContainerRef = useRef<HTMLDivElement>(null)
|
||||
const dropZoneRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const { uploadFiles } = useKnowledgeUpload({
|
||||
const { uploadFiles, isUploading, uploadProgress } = useKnowledgeUpload({
|
||||
onUploadComplete: (uploadedFiles) => {
|
||||
logger.info(`Successfully uploaded ${uploadedFiles.length} files`)
|
||||
// Files uploaded and document records created - processing will continue in background
|
||||
},
|
||||
})
|
||||
|
||||
@@ -303,6 +305,12 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
const newKnowledgeBase = result.data
|
||||
|
||||
if (files.length > 0) {
|
||||
newKnowledgeBase.docCount = files.length
|
||||
|
||||
if (onKnowledgeBaseCreated) {
|
||||
onKnowledgeBaseCreated(newKnowledgeBase)
|
||||
}
|
||||
|
||||
const uploadedFiles = await uploadFiles(files, newKnowledgeBase.id, {
|
||||
chunkSize: data.maxChunkSize,
|
||||
minCharactersPerChunk: data.minChunkSize,
|
||||
@@ -310,22 +318,17 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
recipe: 'default',
|
||||
})
|
||||
|
||||
// Update the knowledge base object with the correct document count
|
||||
newKnowledgeBase.docCount = uploadedFiles.length
|
||||
|
||||
logger.info(`Successfully uploaded ${uploadedFiles.length} files`)
|
||||
logger.info(`Started processing ${uploadedFiles.length} documents in the background`)
|
||||
} else {
|
||||
if (onKnowledgeBaseCreated) {
|
||||
onKnowledgeBaseCreated(newKnowledgeBase)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up file previews
|
||||
files.forEach((file) => URL.revokeObjectURL(file.preview))
|
||||
setFiles([])
|
||||
|
||||
// Call the callback if provided
|
||||
if (onKnowledgeBaseCreated) {
|
||||
onKnowledgeBaseCreated(newKnowledgeBase)
|
||||
}
|
||||
|
||||
// Close modal immediately - no need for success message
|
||||
onOpenChange(false)
|
||||
} catch (error) {
|
||||
logger.error('Error creating knowledge base:', error)
|
||||
@@ -557,29 +560,57 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
|
||||
{/* File list */}
|
||||
<div className='space-y-2'>
|
||||
{files.map((file, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className='flex items-center gap-3 rounded-md border p-3'
|
||||
>
|
||||
{getFileIcon(file.type, file.name)}
|
||||
<div className='min-w-0 flex-1'>
|
||||
<p className='truncate font-medium text-sm'>{file.name}</p>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{formatFileSize(file.size)}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
type='button'
|
||||
variant='ghost'
|
||||
size='sm'
|
||||
onClick={() => removeFile(index)}
|
||||
className='h-8 w-8 p-0 text-muted-foreground hover:text-destructive'
|
||||
{files.map((file, index) => {
|
||||
const fileStatus = uploadProgress.fileStatuses?.[index]
|
||||
const isCurrentlyUploading = fileStatus?.status === 'uploading'
|
||||
const isCompleted = fileStatus?.status === 'completed'
|
||||
const isFailed = fileStatus?.status === 'failed'
|
||||
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className='flex items-center gap-3 rounded-md border p-3'
|
||||
>
|
||||
<X className='h-4 w-4' />
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
{getFileIcon(file.type, file.name)}
|
||||
<div className='min-w-0 flex-1'>
|
||||
<div className='flex items-center gap-2'>
|
||||
{isCurrentlyUploading && (
|
||||
<Loader2 className='h-4 w-4 animate-spin text-[var(--brand-primary-hex)]' />
|
||||
)}
|
||||
{isCompleted && <Check className='h-4 w-4 text-green-500' />}
|
||||
{isFailed && <X className='h-4 w-4 text-red-500' />}
|
||||
<p className='truncate font-medium text-sm'>{file.name}</p>
|
||||
</div>
|
||||
<div className='flex items-center gap-2'>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
{formatFileSize(file.size)}
|
||||
</p>
|
||||
{isCurrentlyUploading && (
|
||||
<div className='min-w-0 max-w-32 flex-1'>
|
||||
<Progress
|
||||
value={fileStatus?.progress || 0}
|
||||
className='h-1'
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{isFailed && fileStatus?.error && (
|
||||
<p className='mt-1 text-red-500 text-xs'>{fileStatus.error}</p>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
type='button'
|
||||
variant='ghost'
|
||||
size='sm'
|
||||
onClick={() => removeFile(index)}
|
||||
disabled={isUploading}
|
||||
className='h-8 w-8 p-0 text-muted-foreground hover:text-destructive'
|
||||
>
|
||||
<X className='h-4 w-4' />
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
@@ -606,7 +637,15 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
disabled={isSubmitting || !nameValue?.trim()}
|
||||
className='bg-[var(--brand-primary-hex)] font-[480] text-primary-foreground shadow-[0_0_0_0_var(--brand-primary-hex)] transition-all duration-200 hover:bg-[var(--brand-primary-hover-hex)] hover:shadow-[0_0_0_4px_rgba(127,47,255,0.15)] disabled:opacity-50 disabled:hover:shadow-none'
|
||||
>
|
||||
{isSubmitting ? 'Creating...' : 'Create Knowledge Base'}
|
||||
{isSubmitting
|
||||
? isUploading
|
||||
? uploadProgress.stage === 'uploading'
|
||||
? `Uploading ${uploadProgress.filesCompleted}/${uploadProgress.totalFiles}...`
|
||||
: uploadProgress.stage === 'processing'
|
||||
? 'Processing...'
|
||||
: 'Creating...'
|
||||
: 'Creating...'
|
||||
: 'Create Knowledge Base'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -83,12 +83,11 @@ class ProcessingError extends KnowledgeUploadError {
|
||||
}
|
||||
}
|
||||
|
||||
// Upload configuration constants
// Vercel has a 4.5MB body size limit for API routes
const UPLOAD_CONFIG = {
  BATCH_SIZE: 5, // Upload 5 files in parallel
  MAX_RETRIES: 3, // Retry failed uploads up to 3 times
  RETRY_DELAY: 1000, // Initial retry delay in ms
  BATCH_SIZE: 15, // Upload files in parallel - this is fast and not the bottleneck
  MAX_RETRIES: 3, // Standard retry count
  RETRY_DELAY: 2000, // Initial retry delay in ms (2 seconds)
  RETRY_MULTIPLIER: 2, // Standard exponential backoff (2s, 4s, 8s)
  CHUNK_SIZE: 5 * 1024 * 1024,
  VERCEL_MAX_BODY_SIZE: 4.5 * 1024 * 1024, // Vercel's 4.5MB limit
  DIRECT_UPLOAD_THRESHOLD: 4 * 1024 * 1024, // Files > 4MB must use presigned URLs
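To make the retry timing above concrete, a small standalone sketch of the delay schedule these constants imply (not part of the diff):

// Exponential backoff: 2000ms * 2^attempt => 2s, 4s, 8s for attempts 0..2.
const retryDelays = Array.from({ length: 3 }, (_, attempt) => 2000 * 2 ** attempt)
// retryDelays === [2000, 4000, 8000]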
|
||||
@@ -205,7 +204,7 @@ export function useKnowledgeUpload(options: UseKnowledgeUploadOptions = {}) {
|
||||
// Use presigned URLs for all uploads when cloud storage is available
|
||||
// Check if file needs multipart upload for large files
|
||||
if (file.size > UPLOAD_CONFIG.LARGE_FILE_THRESHOLD) {
|
||||
return await uploadFileInChunks(file, presignedData, fileIndex)
|
||||
return await uploadFileInChunks(file, presignedData)
|
||||
}
|
||||
return await uploadFileDirectly(file, presignedData, fileIndex)
|
||||
}
|
||||
@@ -233,13 +232,16 @@ export function useKnowledgeUpload(options: UseKnowledgeUploadOptions = {}) {
|
||||
|
||||
// Retry logic
|
||||
if (retryCount < UPLOAD_CONFIG.MAX_RETRIES) {
|
||||
const delay = UPLOAD_CONFIG.RETRY_DELAY * 2 ** retryCount // Exponential backoff
|
||||
// Only log essential info for debugging
|
||||
const delay = UPLOAD_CONFIG.RETRY_DELAY * UPLOAD_CONFIG.RETRY_MULTIPLIER ** retryCount // More aggressive exponential backoff
|
||||
if (isTimeout || isNetwork) {
|
||||
logger.warn(`Upload failed (${isTimeout ? 'timeout' : 'network'}), retrying...`, {
|
||||
attempt: retryCount + 1,
|
||||
fileSize: file.size,
|
||||
})
|
||||
logger.warn(
|
||||
`Upload failed (${isTimeout ? 'timeout' : 'network'}), retrying in ${delay / 1000}s...`,
|
||||
{
|
||||
attempt: retryCount + 1,
|
||||
fileSize: file.size,
|
||||
delay: delay,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// Reset progress to 0 before retry to indicate restart
|
||||
@@ -321,7 +323,9 @@ export function useKnowledgeUpload(options: UseKnowledgeUploadOptions = {}) {
|
||||
reject(
|
||||
new DirectUploadError(
|
||||
`Direct upload failed for ${file.name}: ${xhr.status} ${xhr.statusText}`,
|
||||
{ uploadResponse: xhr.statusText }
|
||||
{
|
||||
uploadResponse: xhr.statusText,
|
||||
}
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -362,11 +366,7 @@ export function useKnowledgeUpload(options: UseKnowledgeUploadOptions = {}) {
|
||||
/**
|
||||
* Upload large file in chunks (multipart upload)
|
||||
*/
|
||||
const uploadFileInChunks = async (
|
||||
file: File,
|
||||
presignedData: any,
|
||||
fileIndex?: number
|
||||
): Promise<UploadedFile> => {
|
||||
const uploadFileInChunks = async (file: File, presignedData: any): Promise<UploadedFile> => {
|
||||
logger.info(
|
||||
`Uploading large file ${file.name} (${(file.size / 1024 / 1024).toFixed(2)}MB) using multipart upload`
|
||||
)
|
||||
@@ -538,10 +538,10 @@ export function useKnowledgeUpload(options: UseKnowledgeUploadOptions = {}) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Upload files with a constant pool of concurrent uploads
|
||||
* Upload files using batch presigned URLs (works for both S3 and Azure Blob)
|
||||
*/
|
||||
const uploadFilesInBatches = async (files: File[]): Promise<UploadedFile[]> => {
|
||||
const uploadedFiles: UploadedFile[] = []
|
||||
const results: UploadedFile[] = []
|
||||
const failedFiles: Array<{ file: File; error: Error }> = []
|
||||
|
||||
// Initialize file statuses
|
||||
@@ -557,57 +557,100 @@ export function useKnowledgeUpload(options: UseKnowledgeUploadOptions = {}) {
|
||||
fileStatuses,
|
||||
}))
|
||||
|
||||
// Create a queue of files to upload
|
||||
const fileQueue = files.map((file, index) => ({ file, index }))
|
||||
const activeUploads = new Map<number, Promise<any>>()
|
||||
logger.info(`Starting batch upload of ${files.length} files`)
|
||||
|
||||
logger.info(
|
||||
`Starting upload of ${files.length} files with concurrency ${UPLOAD_CONFIG.BATCH_SIZE}`
|
||||
)
|
||||
try {
|
||||
const BATCH_SIZE = 100 // Process 100 files at a time
|
||||
const batches = []
|
||||
|
||||
// Function to start an upload for a file
|
||||
const startUpload = async (file: File, fileIndex: number) => {
|
||||
// Mark file as uploading (only if not already processing)
|
||||
setUploadProgress((prev) => {
|
||||
const currentStatus = prev.fileStatuses?.[fileIndex]?.status
|
||||
// Don't re-upload files that are already completed or currently uploading
|
||||
if (currentStatus === 'completed' || currentStatus === 'uploading') {
|
||||
return prev
|
||||
// Create all batches
|
||||
for (let batchStart = 0; batchStart < files.length; batchStart += BATCH_SIZE) {
|
||||
const batchFiles = files.slice(batchStart, batchStart + BATCH_SIZE)
|
||||
const batchIndexOffset = batchStart
|
||||
batches.push({ batchFiles, batchIndexOffset })
|
||||
}
|
||||
|
||||
logger.info(`Starting parallel processing of ${batches.length} batches`)
|
||||
|
||||
// Step 1: Get ALL presigned URLs in parallel
|
||||
const presignedPromises = batches.map(async ({ batchFiles }, batchIndex) => {
|
||||
logger.info(
|
||||
`Getting presigned URLs for batch ${batchIndex + 1}/${batches.length} (${batchFiles.length} files)`
|
||||
)
|
||||
|
||||
const batchRequest = {
|
||||
files: batchFiles.map((file) => ({
|
||||
fileName: file.name,
|
||||
contentType: file.type,
|
||||
fileSize: file.size,
|
||||
})),
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
fileStatuses: prev.fileStatuses?.map((fs, idx) =>
|
||||
idx === fileIndex ? { ...fs, status: 'uploading' as const, progress: 0 } : fs
|
||||
),
|
||||
|
||||
const batchResponse = await fetch('/api/files/presigned/batch?type=knowledge-base', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(batchRequest),
|
||||
})
|
||||
|
||||
if (!batchResponse.ok) {
|
||||
throw new Error(
|
||||
`Batch ${batchIndex + 1} presigned URL generation failed: ${batchResponse.statusText}`
|
||||
)
|
||||
}
|
||||
|
||||
const { files: presignedData } = await batchResponse.json()
|
||||
return { batchFiles, presignedData, batchIndex }
|
||||
})
|
||||
|
||||
try {
|
||||
const result = await uploadSingleFileWithRetry(file, 0, fileIndex)
|
||||
const allPresignedData = await Promise.all(presignedPromises)
|
||||
logger.info(`Got all presigned URLs, starting uploads`)
|
||||
|
||||
// Mark file as completed (with atomic update)
|
||||
setUploadProgress((prev) => {
|
||||
// Only mark as completed if still uploading (prevent race conditions)
|
||||
if (prev.fileStatuses?.[fileIndex]?.status === 'uploading') {
|
||||
return {
|
||||
// Step 2: Upload all files with global concurrency control
|
||||
const allUploads = allPresignedData.flatMap(({ batchFiles, presignedData, batchIndex }) => {
|
||||
const batchIndexOffset = batchIndex * BATCH_SIZE
|
||||
|
||||
return batchFiles.map((file, batchFileIndex) => {
|
||||
const fileIndex = batchIndexOffset + batchFileIndex
|
||||
const presigned = presignedData[batchFileIndex]
|
||||
|
||||
return { file, presigned, fileIndex }
|
||||
})
|
||||
})
|
||||
|
||||
// Process all uploads with concurrency control
|
||||
for (let i = 0; i < allUploads.length; i += UPLOAD_CONFIG.BATCH_SIZE) {
|
||||
const concurrentBatch = allUploads.slice(i, i + UPLOAD_CONFIG.BATCH_SIZE)
|
||||
|
||||
const uploadPromises = concurrentBatch.map(async ({ file, presigned, fileIndex }) => {
|
||||
if (!presigned) {
|
||||
throw new Error(`No presigned data for file ${file.name}`)
|
||||
}
|
||||
|
||||
// Mark as uploading
|
||||
setUploadProgress((prev) => ({
|
||||
...prev,
|
||||
fileStatuses: prev.fileStatuses?.map((fs, idx) =>
|
||||
idx === fileIndex ? { ...fs, status: 'uploading' as const } : fs
|
||||
),
|
||||
}))
|
||||
|
||||
try {
|
||||
// Upload directly to storage
|
||||
const result = await uploadFileDirectly(file, presigned, fileIndex)
|
||||
|
||||
// Mark as completed
|
||||
setUploadProgress((prev) => ({
|
||||
...prev,
|
||||
filesCompleted: prev.filesCompleted + 1,
|
||||
fileStatuses: prev.fileStatuses?.map((fs, idx) =>
|
||||
idx === fileIndex ? { ...fs, status: 'completed' as const, progress: 100 } : fs
|
||||
),
|
||||
}
|
||||
}
|
||||
return prev
|
||||
})
|
||||
}))
|
||||
|
||||
uploadedFiles.push(result)
|
||||
return { success: true, file, result }
|
||||
} catch (error) {
|
||||
// Mark file as failed (with atomic update)
|
||||
setUploadProgress((prev) => {
|
||||
// Only mark as failed if still uploading
|
||||
if (prev.fileStatuses?.[fileIndex]?.status === 'uploading') {
|
||||
return {
|
||||
return result
|
||||
} catch (error) {
|
||||
// Mark as failed
|
||||
setUploadProgress((prev) => ({
|
||||
...prev,
|
||||
fileStatuses: prev.fileStatuses?.map((fs, idx) =>
|
||||
idx === fileIndex
|
||||
@@ -618,52 +661,44 @@ export function useKnowledgeUpload(options: UseKnowledgeUploadOptions = {}) {
|
||||
}
|
||||
: fs
|
||||
),
|
||||
}
|
||||
}))
|
||||
throw error
|
||||
}
|
||||
return prev
|
||||
})
|
||||
|
||||
failedFiles.push({
|
||||
file,
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
})
|
||||
const batchResults = await Promise.allSettled(uploadPromises)
|
||||
|
||||
return {
|
||||
success: false,
|
||||
file,
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
for (let j = 0; j < batchResults.length; j++) {
|
||||
const result = batchResults[j]
|
||||
if (result.status === 'fulfilled') {
|
||||
results.push(result.value)
|
||||
} else {
|
||||
failedFiles.push({
|
||||
file: concurrentBatch[j].file,
|
||||
error:
|
||||
result.reason instanceof Error ? result.reason : new Error(String(result.reason)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process files with constant concurrency pool
|
||||
while (fileQueue.length > 0 || activeUploads.size > 0) {
|
||||
// Start new uploads up to the batch size limit
|
||||
while (fileQueue.length > 0 && activeUploads.size < UPLOAD_CONFIG.BATCH_SIZE) {
|
||||
const { file, index } = fileQueue.shift()!
|
||||
const uploadPromise = startUpload(file, index).finally(() => {
|
||||
activeUploads.delete(index)
|
||||
})
|
||||
activeUploads.set(index, uploadPromise)
|
||||
if (failedFiles.length > 0) {
|
||||
logger.error(`Failed to upload ${failedFiles.length} files`)
|
||||
throw new KnowledgeUploadError(
|
||||
`Failed to upload ${failedFiles.length} file(s)`,
|
||||
'PARTIAL_UPLOAD_FAILURE',
|
||||
{
|
||||
failedFiles,
|
||||
uploadedFiles: results,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// Wait for at least one upload to complete if we're at capacity or done with queue
|
||||
if (activeUploads.size > 0) {
|
||||
await Promise.race(Array.from(activeUploads.values()))
|
||||
}
|
||||
return results
|
||||
} catch (error) {
|
||||
logger.error('Batch upload failed:', error)
|
||||
throw error
|
||||
}
|
||||
|
||||
// Report failed files
|
||||
if (failedFiles.length > 0) {
|
||||
logger.error(`Failed to upload ${failedFiles.length} files:`, failedFiles)
|
||||
const errorMessage = `Failed to upload ${failedFiles.length} file(s): ${failedFiles.map((f) => f.file.name).join(', ')}`
|
||||
throw new KnowledgeUploadError(errorMessage, 'PARTIAL_UPLOAD_FAILURE', {
|
||||
failedFiles,
|
||||
uploadedFiles,
|
||||
})
|
||||
}
|
||||
|
||||
return uploadedFiles
|
||||
}
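The batching logic above boils down to running uploads in fixed-size concurrent slices and collecting failures separately; a generic, hedged sketch of that pattern follows (the helper name and shape are invented for illustration):

// Hypothetical helper showing the slice-and-settle concurrency pattern used above.
async function runInSlices<T, R>(
  items: T[],
  sliceSize: number,
  worker: (item: T) => Promise<R>
): Promise<{ ok: R[]; failed: Array<{ item: T; error: Error }> }> {
  const ok: R[] = []
  const failed: Array<{ item: T; error: Error }> = []
  for (let i = 0; i < items.length; i += sliceSize) {
    const slice = items.slice(i, i + sliceSize)
    const settled = await Promise.allSettled(slice.map(worker))
    settled.forEach((result, j) => {
      if (result.status === 'fulfilled') {
        ok.push(result.value)
      } else {
        const error =
          result.reason instanceof Error ? result.reason : new Error(String(result.reason))
        failed.push({ item: slice[j], error })
      }
    })
  }
  return { ok, failed }
}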
|
||||
|
||||
const uploadFiles = async (
|
||||
|
||||
@@ -48,26 +48,29 @@ export function SubdomainInput({
|
||||
Subdomain
|
||||
</Label>
|
||||
<div className='relative flex items-center rounded-md ring-offset-background focus-within:ring-2 focus-within:ring-ring focus-within:ring-offset-2'>
|
||||
<Input
|
||||
id='subdomain'
|
||||
placeholder='company-name'
|
||||
value={value}
|
||||
onChange={(e) => handleChange(e.target.value)}
|
||||
required
|
||||
disabled={disabled}
|
||||
className={cn(
|
||||
'rounded-r-none border-r-0 focus-visible:ring-0 focus-visible:ring-offset-0',
|
||||
error && 'border-destructive focus-visible:border-destructive'
|
||||
<div className='relative flex-1'>
|
||||
<Input
|
||||
id='subdomain'
|
||||
placeholder='company-name'
|
||||
value={value}
|
||||
onChange={(e) => handleChange(e.target.value)}
|
||||
required
|
||||
disabled={disabled}
|
||||
className={cn(
|
||||
'rounded-r-none border-r-0 focus-visible:ring-0 focus-visible:ring-offset-0',
|
||||
isChecking && 'pr-8',
|
||||
error && 'border-destructive focus-visible:border-destructive'
|
||||
)}
|
||||
/>
|
||||
{isChecking && (
|
||||
<div className='-translate-y-1/2 absolute top-1/2 right-2'>
|
||||
<div className='h-[18px] w-[18px] animate-spin rounded-full border-2 border-gray-300 border-t-[var(--brand-primary-hex)]' />
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className='flex h-10 items-center whitespace-nowrap rounded-r-md border border-l-0 bg-muted px-3 font-medium text-muted-foreground text-sm'>
|
||||
{getDomainSuffix()}
|
||||
</div>
|
||||
{isChecking && (
|
||||
<div className='absolute right-14 flex items-center'>
|
||||
<div className='h-4 w-4 animate-spin rounded-full border-2 border-gray-300 border-t-blue-600' />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{error && <p className='mt-1 text-destructive text-sm'>{error}</p>}
|
||||
</div>
|
||||
|
||||
@@ -355,9 +355,7 @@ export function OutputSelect({
|
||||
</span>
|
||||
)}
|
||||
<ChevronDown
|
||||
className={`ml-1 h-4 w-4 flex-shrink-0 transition-transform ${
|
||||
isOutputDropdownOpen ? 'rotate-180' : ''
|
||||
}`}
|
||||
className={`ml-1 h-4 w-4 flex-shrink-0 transition-transform ${isOutputDropdownOpen ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
|
||||
|
||||
@@ -417,9 +417,9 @@ export const Copilot = forwardRef<CopilotRef, CopilotProps>(({ panelWidth }, ref
|
||||
onClick={scrollToBottom}
|
||||
size='sm'
|
||||
variant='outline'
|
||||
className='flex items-center gap-1 rounded-full border border-gray-200 bg-white px-3 py-1 shadow-lg transition-all hover:bg-gray-50'
|
||||
className='flex items-center gap-1 rounded-full border border-gray-200 bg-white px-3 py-1 shadow-lg transition-all hover:bg-gray-50 dark:border-gray-600 dark:bg-gray-800 dark:hover:bg-gray-700'
|
||||
>
|
||||
<ArrowDown className='h-3.5 w-3.5' />
|
||||
<ArrowDown className='h-3.5 w-3.5 text-gray-700 dark:text-gray-300' />
|
||||
<span className='sr-only'>Scroll to bottom</span>
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
@@ -387,15 +387,19 @@ export function Panel() {
|
||||
open={isHistoryDropdownOpen}
|
||||
onOpenChange={handleHistoryDropdownOpen}
|
||||
>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
className='font-medium text-md leading-normal transition-[filter] hover:brightness-75 focus:outline-none focus-visible:outline-none active:outline-none dark:hover:brightness-125'
|
||||
style={{ color: 'var(--base-muted-foreground)' }}
|
||||
title='Chat history'
|
||||
>
|
||||
<History className='h-4 w-4' strokeWidth={2} />
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
<Tooltip>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
className='font-medium text-md leading-normal transition-[filter] hover:brightness-75 focus:outline-none focus-visible:outline-none active:outline-none dark:hover:brightness-125'
|
||||
style={{ color: 'var(--base-muted-foreground)' }}
|
||||
>
|
||||
<History className='h-4 w-4' strokeWidth={2} />
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
</DropdownMenuTrigger>
|
||||
<TooltipContent side='bottom'>Chat history</TooltipContent>
|
||||
</Tooltip>
|
||||
<DropdownMenuContent
|
||||
align='end'
|
||||
className='z-[200] w-48 rounded-lg border-[#E5E5E5] bg-[#FFFFFF] shadow-xs dark:border-[#414141] dark:bg-[var(--surface-elevated)]'
|
||||
@@ -478,13 +482,18 @@ export function Panel() {
|
||||
<TooltipContent side='bottom'>Clear {activeTab}</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
<button
|
||||
onClick={handleClosePanel}
|
||||
className='font-medium text-md leading-normal transition-[filter] hover:brightness-75 focus:outline-none focus-visible:outline-none active:outline-none dark:hover:brightness-125'
|
||||
style={{ color: 'var(--base-muted-foreground)' }}
|
||||
>
|
||||
<X className='h-4 w-4' strokeWidth={2} />
|
||||
</button>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
onClick={handleClosePanel}
|
||||
className='font-medium text-md leading-normal transition-[filter] hover:brightness-75 focus:outline-none focus-visible:outline-none active:outline-none dark:hover:brightness-125'
|
||||
style={{ color: 'var(--base-muted-foreground)' }}
|
||||
>
|
||||
<X className='h-4 w-4' strokeWidth={2} />
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side='bottom'>Close panel</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -155,7 +155,9 @@ export function FolderSelector({
|
||||
if (!accessToken) return null
|
||||
const resp = await fetch(
|
||||
`https://graph.microsoft.com/v1.0/me/mailFolders/${encodeURIComponent(folderId)}`,
|
||||
{ headers: { Authorization: `Bearer ${accessToken}` } }
|
||||
{
|
||||
headers: { Authorization: `Bearer ${accessToken}` },
|
||||
}
|
||||
)
|
||||
if (!resp.ok) return null
|
||||
const folder = await resp.json()
|
||||
|
||||
@@ -1440,16 +1440,12 @@ export function ToolInput({
|
||||
Auto
|
||||
</span>
|
||||
<span
|
||||
className={`font-medium text-xs ${
|
||||
tool.usageControl === 'force' ? 'block' : 'hidden'
|
||||
}`}
|
||||
className={`font-medium text-xs ${tool.usageControl === 'force' ? 'block' : 'hidden'}`}
|
||||
>
|
||||
Force
|
||||
</span>
|
||||
<span
|
||||
className={`font-medium text-xs ${
|
||||
tool.usageControl === 'none' ? 'block' : 'hidden'
|
||||
}`}
|
||||
className={`font-medium text-xs ${tool.usageControl === 'none' ? 'block' : 'hidden'}`}
|
||||
>
|
||||
None
|
||||
</span>
|
||||
|
||||
@@ -552,9 +552,7 @@ const WorkflowContent = React.memo(() => {
|
||||
|
||||
// Create a new block with a unique ID
|
||||
const id = crypto.randomUUID()
|
||||
const name = `${blockConfig.name} ${
|
||||
Object.values(blocks).filter((b) => b.type === type).length + 1
|
||||
}`
|
||||
const name = `${blockConfig.name} ${Object.values(blocks).filter((b) => b.type === type).length + 1}`
|
||||
|
||||
// Auto-connect logic
|
||||
const isAutoConnectEnabled = useGeneralStore.getState().isAutoConnectEnabled
|
||||
|
||||
@@ -889,9 +889,7 @@ export function Sidebar() {
|
||||
|
||||
{/* 2. Workspace Selector */}
|
||||
<div
|
||||
className={`pointer-events-auto flex-shrink-0 ${
|
||||
!isWorkspaceSelectorVisible ? 'hidden' : ''
|
||||
}`}
|
||||
className={`pointer-events-auto flex-shrink-0 ${!isWorkspaceSelectorVisible ? 'hidden' : ''}`}
|
||||
>
|
||||
<WorkspaceSelector
|
||||
workspaces={workspaces}
|
||||
|
||||
@@ -100,7 +100,9 @@ describe('StreamingResponseFormatProcessor', () => {
|
||||
mockStream,
|
||||
'block-1',
|
||||
['block-1_username', 'block-1_age'],
|
||||
{ schema: { properties: { username: { type: 'string' }, age: { type: 'number' } } } }
|
||||
{
|
||||
schema: { properties: { username: { type: 'string' }, age: { type: 'number' } } },
|
||||
}
|
||||
)
|
||||
|
||||
const reader = processedStream.getReader()
|
||||
@@ -132,7 +134,9 @@ describe('StreamingResponseFormatProcessor', () => {
|
||||
mockStream,
|
||||
'block-1',
|
||||
['block-1_config', 'block-1_count'],
|
||||
{ schema: { properties: { config: { type: 'object' }, count: { type: 'number' } } } }
|
||||
{
|
||||
schema: { properties: { config: { type: 'object' }, count: { type: 'number' } } },
|
||||
}
|
||||
)
|
||||
|
||||
const reader = processedStream.getReader()
|
||||
|
||||
@@ -45,7 +45,13 @@ const DEFAULT_PAGE_SIZE = 50
|
||||
|
||||
export function useKnowledgeBaseDocuments(
|
||||
knowledgeBaseId: string,
|
||||
options?: { search?: string; limit?: number; offset?: number }
|
||||
options?: {
|
||||
search?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
sortBy?: string
|
||||
sortOrder?: string
|
||||
}
|
||||
) {
|
||||
const { getDocuments, getCachedDocuments, loadingDocuments, updateDocument, refreshDocuments } =
|
||||
useKnowledgeStore()
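A hedged sketch of calling the extended hook above; the knowledge base ID and option values are illustrative, and the return shape is not asserted here:

// Hypothetical call; sortBy/sortOrder mirror the new options added in this change.
const docsState = useKnowledgeBaseDocuments('kb_123', {
  search: 'invoice',
  limit: 50,
  offset: 0,
  sortBy: 'uploadedAt',
  sortOrder: 'desc',
})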
|
||||
@@ -55,10 +61,12 @@ export function useKnowledgeBaseDocuments(
|
||||
const documentsCache = getCachedDocuments(knowledgeBaseId)
|
||||
const isLoading = loadingDocuments.has(knowledgeBaseId)
|
||||
|
||||
// Load documents with server-side pagination and search
|
||||
// Load documents with server-side pagination, search, and sorting
|
||||
const requestLimit = options?.limit || DEFAULT_PAGE_SIZE
|
||||
const requestOffset = options?.offset || 0
|
||||
const requestSearch = options?.search
|
||||
const requestSortBy = options?.sortBy
|
||||
const requestSortOrder = options?.sortOrder
|
||||
|
||||
useEffect(() => {
|
||||
if (!knowledgeBaseId || isLoading) return
|
||||
@@ -72,6 +80,8 @@ export function useKnowledgeBaseDocuments(
|
||||
search: requestSearch,
|
||||
limit: requestLimit,
|
||||
offset: requestOffset,
|
||||
sortBy: requestSortBy,
|
||||
sortOrder: requestSortOrder,
|
||||
})
|
||||
} catch (err) {
|
||||
if (isMounted) {
|
||||
@@ -85,7 +95,16 @@ export function useKnowledgeBaseDocuments(
|
||||
return () => {
|
||||
isMounted = false
|
||||
}
|
||||
}, [knowledgeBaseId, isLoading, getDocuments, requestSearch, requestLimit, requestOffset])
|
||||
}, [
|
||||
knowledgeBaseId,
|
||||
isLoading,
|
||||
getDocuments,
|
||||
requestSearch,
|
||||
requestLimit,
|
||||
requestOffset,
|
||||
requestSortBy,
|
||||
requestSortOrder,
|
||||
])
|
||||
|
||||
// Use server-side filtered and paginated results directly
|
||||
const documents = documentsCache?.documents || []
|
||||
@@ -103,11 +122,21 @@ export function useKnowledgeBaseDocuments(
|
||||
search: requestSearch,
|
||||
limit: requestLimit,
|
||||
offset: requestOffset,
|
||||
sortBy: requestSortBy,
|
||||
sortOrder: requestSortOrder,
|
||||
})
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to refresh documents')
|
||||
}
|
||||
}, [knowledgeBaseId, refreshDocuments, requestSearch, requestLimit, requestOffset])
|
||||
}, [
|
||||
knowledgeBaseId,
|
||||
refreshDocuments,
|
||||
requestSearch,
|
||||
requestLimit,
|
||||
requestOffset,
|
||||
requestSortBy,
|
||||
requestSortOrder,
|
||||
])
|
||||
|
||||
const updateDocumentLocal = useCallback(
|
||||
(documentId: string, updates: Partial<DocumentData>) => {
|
||||
|
||||
@@ -17,15 +17,14 @@ export const searchDocumentationServerTool: BaseServerTool<DocsSearchParams, any
|
||||
const { query, topK = 10, threshold } = params
|
||||
if (!query || typeof query !== 'string') throw new Error('query is required')
|
||||
|
||||
logger.info('Executing docs search (new runtime)', { query, topK })
|
||||
logger.info('Executing docs search', { query, topK })
|
||||
|
||||
const { getCopilotConfig } = await import('@/lib/copilot/config')
|
||||
const config = getCopilotConfig()
|
||||
const similarityThreshold = threshold ?? config.rag.similarityThreshold
|
||||
|
||||
const { generateEmbeddings } = await import('@/app/api/knowledge/utils')
|
||||
const embeddings = await generateEmbeddings([query])
|
||||
const queryEmbedding = embeddings[0]
|
||||
const { generateSearchEmbedding } = await import('@/lib/embeddings/utils')
|
||||
const queryEmbedding = await generateSearchEmbedding(query)
|
||||
if (!queryEmbedding || queryEmbedding.length === 0) {
|
||||
return { results: [], query, totalResults: 0 }
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { isRetryableError, retryWithExponentialBackoff } from '@/lib/documents/utils'
|
||||
import { env } from '@/lib/env'
|
||||
import { isRetryableError, retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('EmbeddingUtils')
|
||||
@@ -104,7 +104,7 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings for multiple texts with batching
|
||||
* Generate embeddings for multiple texts with simple batching
|
||||
*/
|
||||
export async function generateEmbeddings(
|
||||
texts: string[],
|
||||
|
||||
@@ -2,6 +2,7 @@ import { createReadStream, existsSync } from 'fs'
|
||||
import { Readable } from 'stream'
|
||||
import csvParser from 'csv-parser'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('CsvParser')
|
||||
@@ -41,17 +42,20 @@ export class CsvParser implements FileParser {
|
||||
|
||||
// Add headers
|
||||
if (headers.length > 0) {
|
||||
content += `${headers.join(', ')}\n`
|
||||
const cleanHeaders = headers.map((h) => sanitizeTextForUTF8(String(h)))
|
||||
content += `${cleanHeaders.join(', ')}\n`
|
||||
}
|
||||
|
||||
// Add rows
|
||||
results.forEach((row) => {
|
||||
const rowValues = Object.values(row).join(', ')
|
||||
content += `${rowValues}\n`
|
||||
const cleanValues = Object.values(row).map((v) =>
|
||||
sanitizeTextForUTF8(String(v || ''))
|
||||
)
|
||||
content += `${cleanValues.join(', ')}\n`
|
||||
})
|
||||
|
||||
resolve({
|
||||
content,
|
||||
content: sanitizeTextForUTF8(content),
|
||||
metadata: {
|
||||
rowCount: results.length,
|
||||
headers: headers,
|
||||
@@ -101,17 +105,20 @@ export class CsvParser implements FileParser {
|
||||
|
||||
// Add headers
|
||||
if (headers.length > 0) {
|
||||
content += `${headers.join(', ')}\n`
|
||||
const cleanHeaders = headers.map((h) => sanitizeTextForUTF8(String(h)))
|
||||
content += `${cleanHeaders.join(', ')}\n`
|
||||
}
|
||||
|
||||
// Add rows
|
||||
results.forEach((row) => {
|
||||
const rowValues = Object.values(row).join(', ')
|
||||
content += `${rowValues}\n`
|
||||
const cleanValues = Object.values(row).map((v) =>
|
||||
sanitizeTextForUTF8(String(v || ''))
|
||||
)
|
||||
content += `${cleanValues.join(', ')}\n`
|
||||
})
|
||||
|
||||
resolve({
|
||||
content,
|
||||
content: sanitizeTextForUTF8(content),
|
||||
metadata: {
|
||||
rowCount: results.length,
|
||||
headers: headers,
|
||||
|
||||
apps/sim/lib/file-parsers/doc-parser.ts (new file, +126)
@@ -0,0 +1,126 @@
|
||||
import { existsSync } from 'fs'
|
||||
import { readFile } from 'fs/promises'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('DocParser')
|
||||
|
||||
export class DocParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
// Validate input
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if (!existsSync(filePath)) {
|
||||
throw new Error(`File not found: ${filePath}`)
|
||||
}
|
||||
|
||||
logger.info(`Parsing DOC file: ${filePath}`)
|
||||
|
||||
// Read the file
|
||||
const buffer = await readFile(filePath)
|
||||
return this.parseBuffer(buffer)
|
||||
} catch (error) {
|
||||
logger.error('DOC file parsing error:', error)
|
||||
throw new Error(`Failed to parse DOC file: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
|
||||
async parseBuffer(buffer: Buffer): Promise<FileParseResult> {
|
||||
try {
|
||||
logger.info('Parsing DOC buffer, size:', buffer.length)
|
||||
|
||||
if (!buffer || buffer.length === 0) {
|
||||
throw new Error('Empty buffer provided')
|
||||
}
|
||||
|
||||
// Try to dynamically import the word extractor
|
||||
let WordExtractor
|
||||
try {
|
||||
WordExtractor = (await import('word-extractor')).default
|
||||
} catch (importError) {
|
||||
logger.warn('word-extractor not available, using fallback extraction')
|
||||
return this.fallbackExtraction(buffer)
|
||||
}
|
||||
|
||||
try {
|
||||
const extractor = new WordExtractor()
|
||||
const extracted = await extractor.extract(buffer)
|
||||
|
||||
const content = sanitizeTextForUTF8(extracted.getBody())
|
||||
const headers = extracted.getHeaders()
|
||||
const footers = extracted.getFooters()
|
||||
|
||||
// Combine body with headers/footers if they exist
|
||||
let fullContent = content
|
||||
if (headers?.trim()) {
|
||||
fullContent = `${sanitizeTextForUTF8(headers)}\n\n${fullContent}`
|
||||
}
|
||||
if (footers?.trim()) {
|
||||
fullContent = `${fullContent}\n\n${sanitizeTextForUTF8(footers)}`
|
||||
}
|
||||
|
||||
logger.info('DOC parsing completed successfully')
|
||||
|
||||
return {
|
||||
content: fullContent.trim(),
|
||||
metadata: {
|
||||
hasHeaders: !!headers?.trim(),
|
||||
hasFooters: !!footers?.trim(),
|
||||
characterCount: fullContent.length,
|
||||
extractionMethod: 'word-extractor',
|
||||
},
|
||||
}
|
||||
} catch (extractError) {
|
||||
logger.warn('word-extractor failed, using fallback:', extractError)
|
||||
return this.fallbackExtraction(buffer)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('DOC buffer parsing error:', error)
|
||||
throw new Error(`Failed to parse DOC buffer: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fallback extraction method for when word-extractor is not available
|
||||
* This is a very basic extraction that looks for readable text in the binary
|
||||
*/
|
||||
private fallbackExtraction(buffer: Buffer): FileParseResult {
|
||||
logger.info('Using fallback text extraction for DOC file')
|
||||
|
||||
// Convert buffer to string and try to extract readable text
|
||||
// This is very basic and won't work well for complex DOC files
|
||||
const text = buffer.toString('utf8', 0, Math.min(buffer.length, 100000)) // Limit to first 100KB
|
||||
|
||||
// Extract sequences of printable ASCII characters
|
||||
const readableText = text
|
||||
.match(/[\x20-\x7E\s]{4,}/g) // Find sequences of 4+ printable characters
|
||||
?.filter(
|
||||
(chunk) =>
|
||||
chunk.trim().length > 10 && // Minimum length
|
||||
/[a-zA-Z]/.test(chunk) && // Must contain letters
|
||||
!/^[\x00-\x1F]*$/.test(chunk) // Not just control characters
|
||||
)
|
||||
.join(' ')
|
||||
.replace(/\s+/g, ' ')
|
||||
.trim()
|
||||
|
||||
const content = readableText
|
||||
? sanitizeTextForUTF8(readableText)
|
||||
: 'Unable to extract text from DOC file. Please convert to DOCX format for better results.'
|
||||
|
||||
return {
|
||||
content,
|
||||
metadata: {
|
||||
extractionMethod: 'fallback',
|
||||
characterCount: content.length,
|
||||
warning:
|
||||
'Basic text extraction used. For better results, install word-extractor package or convert to DOCX format.',
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
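A minimal usage sketch of the parser above, assuming it is exported from '@/lib/file-parsers/doc-parser' as in this diff; the sample path is hypothetical:

```ts
import { DocParser } from '@/lib/file-parsers/doc-parser'

// Hypothetical caller: parse a legacy .doc upload and report which extraction path ran.
async function parseLegacyDoc(filePath: string) {
  const parser = new DocParser()
  // Falls back to the basic printable-text scan when word-extractor is unavailable.
  const result = await parser.parseFile(filePath)
  console.log(result.metadata?.extractionMethod, `${result.content.length} chars`)
  return result
}

parseLegacyDoc('/tmp/example.doc').catch(console.error)
```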
||||
@@ -76,6 +76,13 @@ function getParserInstances(): Record<string, FileParser> {
|
||||
logger.error('Failed to load DOCX parser:', error)
|
||||
}
|
||||
|
||||
try {
|
||||
const { DocParser } = require('@/lib/file-parsers/doc-parser')
|
||||
parserInstances.doc = new DocParser()
|
||||
} catch (error) {
|
||||
logger.error('Failed to load DOC parser:', error)
|
||||
}
|
||||
|
||||
try {
|
||||
const { TxtParser } = require('@/lib/file-parsers/txt-parser')
|
||||
parserInstances.txt = new TxtParser()
|
||||
@@ -102,7 +109,6 @@ function getParserInstances(): Record<string, FileParser> {
|
||||
}
|
||||
}
|
||||
|
||||
logger.info('Available parsers:', Object.keys(parserInstances))
|
||||
return parserInstances
|
||||
}
|
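For context, a hypothetical lookup over the extension-to-parser map that getParserInstances() builds; the function is module-internal in this diff, so the sketch assumes it is called from within the same file:

```ts
// Hypothetical helper inside the same module: resolve a parser for a file extension,
// returning undefined when that parser failed to load above.
function getParserForExtension(extension: string): FileParser | undefined {
  const parsers = getParserInstances()
  return parsers[extension.toLowerCase().replace(/^\./, '')]
}

// e.g. getParserForExtension('.doc') -> DocParser instance, or undefined
```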
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { readFile } from 'fs/promises'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('TxtParser')
|
||||
@@ -27,8 +28,9 @@ export class TxtParser implements FileParser {
|
||||
try {
|
||||
logger.info('Parsing buffer, size:', buffer.length)
|
||||
|
||||
// Extract content
|
||||
const result = buffer.toString('utf-8')
|
||||
// Extract content and sanitize for UTF-8 storage
|
||||
const rawContent = buffer.toString('utf-8')
|
||||
const result = sanitizeTextForUTF8(rawContent)
|
||||
|
||||
return {
|
||||
content: result,
|
||||
|
||||
@@ -8,4 +8,4 @@ export interface FileParser {
|
||||
parseBuffer?(buffer: Buffer): Promise<FileParseResult>
|
||||
}
|
||||
|
||||
export type SupportedFileType = 'pdf' | 'csv' | 'docx' | 'xlsx' | 'xls'
|
||||
export type SupportedFileType = 'pdf' | 'csv' | 'doc' | 'docx' | 'txt' | 'md' | 'xlsx' | 'xls'
|
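A small hypothetical guard over the widened union, assuming SupportedFileType is imported from '@/lib/file-parsers/types':

```ts
import type { SupportedFileType } from '@/lib/file-parsers/types'

const SUPPORTED: readonly string[] = ['pdf', 'csv', 'doc', 'docx', 'txt', 'md', 'xlsx', 'xls']

// Narrow a raw extension (without the dot) to the SupportedFileType union.
function isSupportedFileType(ext: string): ext is SupportedFileType {
  return SUPPORTED.includes(ext.toLowerCase())
}
```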
||||
|
||||
42
apps/sim/lib/file-parsers/utils.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Utility functions for file parsing
|
||||
*/
|
||||
|
||||
/**
|
||||
* Clean text content to ensure it's safe for UTF-8 storage in PostgreSQL
|
||||
* Removes null bytes and control characters that can cause encoding errors
|
||||
*/
|
||||
export function sanitizeTextForUTF8(text: string): string {
|
||||
if (!text || typeof text !== 'string') {
|
||||
return ''
|
||||
}
|
||||
|
||||
return text
|
||||
.replace(/\0/g, '') // Remove null bytes (0x00)
|
||||
.replace(/[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]/g, '') // Remove control characters except \t(0x09), \n(0x0A), \r(0x0D)
|
||||
.replace(/\uFFFD/g, '') // Remove Unicode replacement character
|
||||
.replace(/[\uD800-\uDFFF]/g, '') // Remove unpaired surrogate characters
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize an array of strings
|
||||
*/
|
||||
export function sanitizeTextArray(texts: string[]): string[] {
|
||||
return texts.map((text) => sanitizeTextForUTF8(text))
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a string contains problematic characters for UTF-8 storage
|
||||
*/
|
||||
export function hasInvalidUTF8Characters(text: string): boolean {
|
||||
if (!text || typeof text !== 'string') {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for null bytes and control characters
|
||||
return (
|
||||
/[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]/.test(text) ||
|
||||
/\uFFFD/.test(text) ||
|
||||
/[\uD800-\uDFFF]/.test(text)
|
||||
)
|
||||
}
|
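A short usage sketch of the helpers above; the import path is assumed from the file location (apps/sim/lib/file-parsers/utils.ts):

```ts
import { hasInvalidUTF8Characters, sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'

// Strip null bytes and control characters from parsed output before it reaches a Postgres text column.
const raw = 'Quarterly report\u0000 totals\u0007 attached'
const clean = hasInvalidUTF8Characters(raw) ? sanitizeTextForUTF8(raw) : raw
// clean === 'Quarterly report totals attached'
```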
||||
@@ -1,6 +1,7 @@
|
||||
import { existsSync } from 'fs'
|
||||
import * as XLSX from 'xlsx'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('XlsxParser')
|
||||
@@ -61,21 +62,22 @@ export class XlsxParser implements FileParser {
|
||||
sheets[sheetName] = sheetData
|
||||
totalRows += sheetData.length
|
||||
|
||||
// Add sheet content to the overall content string
|
||||
content += `Sheet: ${sheetName}\n`
|
||||
content += `=${'='.repeat(sheetName.length + 6)}\n\n`
|
||||
// Add sheet content to the overall content string (clean sheet name)
|
||||
const cleanSheetName = sanitizeTextForUTF8(sheetName)
|
||||
content += `Sheet: ${cleanSheetName}\n`
|
||||
content += `=${'='.repeat(cleanSheetName.length + 6)}\n\n`
|
||||
|
||||
if (sheetData.length > 0) {
|
||||
// Process each row
|
||||
sheetData.forEach((row: unknown, rowIndex: number) => {
|
||||
if (Array.isArray(row) && row.length > 0) {
|
||||
// Convert row to string, handling undefined/null values
|
||||
// Convert row to string, handling undefined/null values and cleaning non-UTF8 characters
|
||||
const rowString = row
|
||||
.map((cell) => {
|
||||
if (cell === null || cell === undefined) {
|
||||
return ''
|
||||
}
|
||||
return String(cell)
|
||||
return sanitizeTextForUTF8(String(cell))
|
||||
})
|
||||
.join('\t')
|
||||
|
||||
@@ -91,8 +93,11 @@ export class XlsxParser implements FileParser {
|
||||
|
||||
logger.info(`XLSX parsing completed: ${sheetNames.length} sheets, ${totalRows} total rows`)
|
||||
|
||||
// Final cleanup of the entire content to ensure UTF-8 compatibility
|
||||
const cleanContent = sanitizeTextForUTF8(content).trim()
|
||||
|
||||
return {
|
||||
content: content.trim(),
|
||||
content: cleanContent,
|
||||
metadata: {
|
||||
sheetCount: sheetNames.length,
|
||||
sheetNames: sheetNames,
|
||||
|
||||
470
apps/sim/lib/knowledge/chunks/service.ts
Normal file
@@ -0,0 +1,470 @@
|
||||
import { createHash, randomUUID } from 'crypto'
|
||||
import { and, asc, eq, ilike, inArray, sql } from 'drizzle-orm'
|
||||
import { generateEmbeddings } from '@/lib/embeddings/utils'
|
||||
import type {
|
||||
BatchOperationResult,
|
||||
ChunkData,
|
||||
ChunkFilters,
|
||||
ChunkQueryResult,
|
||||
CreateChunkData,
|
||||
} from '@/lib/knowledge/chunks/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { estimateTokenCount } from '@/lib/tokenization/estimators'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('ChunksService')
|
||||
|
||||
/**
|
||||
* Query chunks for a document with filtering and pagination
|
||||
*/
|
||||
export async function queryChunks(
|
||||
documentId: string,
|
||||
filters: ChunkFilters,
|
||||
requestId: string
|
||||
): Promise<ChunkQueryResult> {
|
||||
const { search, enabled = 'all', limit = 50, offset = 0 } = filters
|
||||
|
||||
// Build query conditions
|
||||
const conditions = [eq(embedding.documentId, documentId)]
|
||||
|
||||
// Add enabled filter
|
||||
if (enabled === 'true') {
|
||||
conditions.push(eq(embedding.enabled, true))
|
||||
} else if (enabled === 'false') {
|
||||
conditions.push(eq(embedding.enabled, false))
|
||||
}
|
||||
|
||||
// Add search filter
|
||||
if (search) {
|
||||
conditions.push(ilike(embedding.content, `%${search}%`))
|
||||
}
|
||||
|
||||
// Fetch chunks
|
||||
const chunks = await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
content: embedding.content,
|
||||
contentLength: embedding.contentLength,
|
||||
tokenCount: embedding.tokenCount,
|
||||
enabled: embedding.enabled,
|
||||
startOffset: embedding.startOffset,
|
||||
endOffset: embedding.endOffset,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
createdAt: embedding.createdAt,
|
||||
updatedAt: embedding.updatedAt,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(and(...conditions))
|
||||
.orderBy(asc(embedding.chunkIndex))
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
|
||||
// Get total count for pagination
|
||||
const totalCount = await db
|
||||
.select({ count: sql`count(*)` })
|
||||
.from(embedding)
|
||||
.where(and(...conditions))
|
||||
|
||||
logger.info(`[${requestId}] Retrieved ${chunks.length} chunks for document ${documentId}`)
|
||||
|
||||
return {
|
||||
chunks: chunks as ChunkData[],
|
||||
pagination: {
|
||||
total: Number(totalCount[0]?.count || 0),
|
||||
limit,
|
||||
offset,
|
||||
hasMore: chunks.length === limit,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new chunk for a document
|
||||
*/
|
||||
export async function createChunk(
|
||||
knowledgeBaseId: string,
|
||||
documentId: string,
|
||||
docTags: Record<string, string | null>,
|
||||
chunkData: CreateChunkData,
|
||||
requestId: string
|
||||
): Promise<ChunkData> {
|
||||
// Generate embedding for the content first (outside transaction for performance)
|
||||
logger.info(`[${requestId}] Generating embedding for manual chunk`)
|
||||
const embeddings = await generateEmbeddings([chunkData.content])
|
||||
|
||||
// Calculate accurate token count
|
||||
const tokenCount = estimateTokenCount(chunkData.content, 'openai')
|
||||
|
||||
const chunkId = randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
// Use transaction to atomically get next index and insert chunk
|
||||
const newChunk = await db.transaction(async (tx) => {
|
||||
// Get the next chunk index atomically within the transaction
|
||||
const lastChunk = await tx
|
||||
.select({ chunkIndex: embedding.chunkIndex })
|
||||
.from(embedding)
|
||||
.where(eq(embedding.documentId, documentId))
|
||||
.orderBy(sql`${embedding.chunkIndex} DESC`)
|
||||
.limit(1)
|
||||
|
||||
const nextChunkIndex = lastChunk.length > 0 ? lastChunk[0].chunkIndex + 1 : 0
|
||||
|
||||
const chunkDBData = {
|
||||
id: chunkId,
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
chunkIndex: nextChunkIndex,
|
||||
chunkHash: createHash('sha256').update(chunkData.content).digest('hex'),
|
||||
content: chunkData.content,
|
||||
contentLength: chunkData.content.length,
|
||||
tokenCount: tokenCount.count,
|
||||
embedding: embeddings[0],
|
||||
embeddingModel: 'text-embedding-3-small',
|
||||
startOffset: 0, // Manual chunks don't have document offsets
|
||||
endOffset: chunkData.content.length,
|
||||
// Inherit tags from parent document
|
||||
tag1: docTags.tag1,
|
||||
tag2: docTags.tag2,
|
||||
tag3: docTags.tag3,
|
||||
tag4: docTags.tag4,
|
||||
tag5: docTags.tag5,
|
||||
tag6: docTags.tag6,
|
||||
tag7: docTags.tag7,
|
||||
enabled: chunkData.enabled ?? true,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
await tx.insert(embedding).values(chunkDBData)
|
||||
|
||||
// Update document statistics
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: sql`${document.chunkCount} + 1`,
|
||||
tokenCount: sql`${document.tokenCount} + ${tokenCount.count}`,
|
||||
characterCount: sql`${document.characterCount} + ${chunkData.content.length}`,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
|
||||
return {
|
||||
id: chunkId,
|
||||
chunkIndex: nextChunkIndex,
|
||||
content: chunkData.content,
|
||||
contentLength: chunkData.content.length,
|
||||
tokenCount: tokenCount.count,
|
||||
enabled: chunkData.enabled ?? true,
|
||||
startOffset: 0,
|
||||
endOffset: chunkData.content.length,
|
||||
tag1: docTags.tag1,
|
||||
tag2: docTags.tag2,
|
||||
tag3: docTags.tag3,
|
||||
tag4: docTags.tag4,
|
||||
tag5: docTags.tag5,
|
||||
tag6: docTags.tag6,
|
||||
tag7: docTags.tag7,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
} as ChunkData
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Created chunk ${chunkId} in document ${documentId}`)
|
||||
|
||||
return newChunk
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform batch operations on chunks
|
||||
*/
|
||||
export async function batchChunkOperation(
|
||||
documentId: string,
|
||||
operation: 'enable' | 'disable' | 'delete',
|
||||
chunkIds: string[],
|
||||
requestId: string
|
||||
): Promise<BatchOperationResult> {
|
||||
logger.info(
|
||||
`[${requestId}] Starting batch ${operation} operation on ${chunkIds.length} chunks for document ${documentId}`
|
||||
)
|
||||
|
||||
const errors: string[] = []
|
||||
let successCount = 0
|
||||
|
||||
if (operation === 'delete') {
|
||||
// Handle batch delete with transaction for consistency
|
||||
await db.transaction(async (tx) => {
|
||||
// Get chunks to delete for statistics update
|
||||
const chunksToDelete = await tx
|
||||
.select({
|
||||
id: embedding.id,
|
||||
tokenCount: embedding.tokenCount,
|
||||
contentLength: embedding.contentLength,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
|
||||
if (chunksToDelete.length === 0) {
|
||||
errors.push('No matching chunks found to delete')
|
||||
return
|
||||
}
|
||||
|
||||
const totalTokensToRemove = chunksToDelete.reduce((sum, chunk) => sum + chunk.tokenCount, 0)
|
||||
const totalCharsToRemove = chunksToDelete.reduce((sum, chunk) => sum + chunk.contentLength, 0)
|
||||
|
||||
// Delete chunks
|
||||
const deleteResult = await tx
|
||||
.delete(embedding)
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
|
||||
// Update document statistics
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: sql`${document.chunkCount} - ${chunksToDelete.length}`,
|
||||
tokenCount: sql`${document.tokenCount} - ${totalTokensToRemove}`,
|
||||
characterCount: sql`${document.characterCount} - ${totalCharsToRemove}`,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
|
||||
successCount = chunksToDelete.length
|
||||
})
|
||||
} else {
|
||||
// Handle enable/disable operations
|
||||
const enabled = operation === 'enable'
|
||||
|
||||
await db
|
||||
.update(embedding)
|
||||
.set({
|
||||
enabled,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
|
||||
// For enable/disable, we assume all chunks were processed successfully
|
||||
successCount = chunkIds.length
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Batch ${operation} completed: ${successCount} chunks processed, ${errors.length} errors`
|
||||
)
|
||||
|
||||
return {
|
||||
success: errors.length === 0,
|
||||
processed: successCount,
|
||||
errors,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a single chunk
|
||||
*/
|
||||
export async function updateChunk(
|
||||
chunkId: string,
|
||||
updateData: {
|
||||
content?: string
|
||||
enabled?: boolean
|
||||
},
|
||||
requestId: string
|
||||
): Promise<ChunkData> {
|
||||
const dbUpdateData: {
|
||||
updatedAt: Date
|
||||
content?: string
|
||||
contentLength?: number
|
||||
tokenCount?: number
|
||||
chunkHash?: string
|
||||
embedding?: number[]
|
||||
enabled?: boolean
|
||||
} = {
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
// Use transaction if content is being updated to ensure consistent document statistics
|
||||
if (updateData.content !== undefined && typeof updateData.content === 'string') {
|
||||
return await db.transaction(async (tx) => {
|
||||
// Get current chunk data for character count calculation and content comparison
|
||||
const currentChunk = await tx
|
||||
.select({
|
||||
documentId: embedding.documentId,
|
||||
content: embedding.content,
|
||||
contentLength: embedding.contentLength,
|
||||
tokenCount: embedding.tokenCount,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(eq(embedding.id, chunkId))
|
||||
.limit(1)
|
||||
|
||||
if (currentChunk.length === 0) {
|
||||
throw new Error(`Chunk ${chunkId} not found`)
|
||||
}
|
||||
|
||||
const oldContentLength = currentChunk[0].contentLength
|
||||
const oldTokenCount = currentChunk[0].tokenCount
|
||||
const content = updateData.content! // We know it's defined from the if check above
|
||||
const newContentLength = content.length
|
||||
|
||||
// Only regenerate embedding if content actually changed
|
||||
if (content !== currentChunk[0].content) {
|
||||
logger.info(`[${requestId}] Content changed, regenerating embedding for chunk ${chunkId}`)
|
||||
|
||||
// Generate new embedding for the updated content
|
||||
const embeddings = await generateEmbeddings([content])
|
||||
|
||||
// Calculate accurate token count
|
||||
const tokenCount = estimateTokenCount(content, 'openai')
|
||||
|
||||
dbUpdateData.content = content
|
||||
dbUpdateData.contentLength = newContentLength
|
||||
dbUpdateData.tokenCount = tokenCount.count
|
||||
dbUpdateData.chunkHash = createHash('sha256').update(content).digest('hex')
|
||||
// Add the embedding field to the update data
|
||||
dbUpdateData.embedding = embeddings[0]
|
||||
} else {
|
||||
// Content hasn't changed, just update other fields if needed
|
||||
dbUpdateData.content = content
|
||||
dbUpdateData.contentLength = newContentLength
|
||||
dbUpdateData.tokenCount = oldTokenCount // Keep the same token count if content is identical
|
||||
dbUpdateData.chunkHash = createHash('sha256').update(content).digest('hex')
|
||||
}
|
||||
|
||||
if (updateData.enabled !== undefined) {
|
||||
dbUpdateData.enabled = updateData.enabled
|
||||
}
|
||||
|
||||
// Update the chunk
|
||||
await tx.update(embedding).set(dbUpdateData).where(eq(embedding.id, chunkId))
|
||||
|
||||
// Update document statistics for the character and token count changes
|
||||
const charDiff = newContentLength - oldContentLength
|
||||
const tokenDiff = dbUpdateData.tokenCount! - oldTokenCount
|
||||
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
characterCount: sql`${document.characterCount} + ${charDiff}`,
|
||||
tokenCount: sql`${document.tokenCount} + ${tokenDiff}`,
|
||||
})
|
||||
.where(eq(document.id, currentChunk[0].documentId))
|
||||
|
||||
// Fetch and return the updated chunk
|
||||
const updatedChunk = await tx
|
||||
.select({
|
||||
id: embedding.id,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
content: embedding.content,
|
||||
contentLength: embedding.contentLength,
|
||||
tokenCount: embedding.tokenCount,
|
||||
enabled: embedding.enabled,
|
||||
startOffset: embedding.startOffset,
|
||||
endOffset: embedding.endOffset,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
createdAt: embedding.createdAt,
|
||||
updatedAt: embedding.updatedAt,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(eq(embedding.id, chunkId))
|
||||
.limit(1)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Updated chunk: ${chunkId}${updateData.content !== currentChunk[0].content ? ' (regenerated embedding)' : ''}`
|
||||
)
|
||||
|
||||
return updatedChunk[0] as ChunkData
|
||||
})
|
||||
}
|
||||
|
||||
// If only enabled status is being updated, no need for transaction
|
||||
if (updateData.enabled !== undefined) {
|
||||
dbUpdateData.enabled = updateData.enabled
|
||||
}
|
||||
|
||||
await db.update(embedding).set(dbUpdateData).where(eq(embedding.id, chunkId))
|
||||
|
||||
// Fetch the updated chunk
|
||||
const updatedChunk = await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
content: embedding.content,
|
||||
contentLength: embedding.contentLength,
|
||||
tokenCount: embedding.tokenCount,
|
||||
enabled: embedding.enabled,
|
||||
startOffset: embedding.startOffset,
|
||||
endOffset: embedding.endOffset,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
createdAt: embedding.createdAt,
|
||||
updatedAt: embedding.updatedAt,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(eq(embedding.id, chunkId))
|
||||
.limit(1)
|
||||
|
||||
if (updatedChunk.length === 0) {
|
||||
throw new Error(`Chunk ${chunkId} not found`)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Updated chunk: ${chunkId}`)
|
||||
|
||||
return updatedChunk[0] as ChunkData
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a single chunk with document statistics updates
|
||||
*/
|
||||
export async function deleteChunk(
|
||||
chunkId: string,
|
||||
documentId: string,
|
||||
requestId: string
|
||||
): Promise<void> {
|
||||
await db.transaction(async (tx) => {
|
||||
// Get chunk data before deletion for statistics update
|
||||
const chunkToDelete = await tx
|
||||
.select({
|
||||
tokenCount: embedding.tokenCount,
|
||||
contentLength: embedding.contentLength,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(eq(embedding.id, chunkId))
|
||||
.limit(1)
|
||||
|
||||
if (chunkToDelete.length === 0) {
|
||||
throw new Error('Chunk not found')
|
||||
}
|
||||
|
||||
const chunk = chunkToDelete[0]
|
||||
|
||||
// Delete the chunk
|
||||
await tx.delete(embedding).where(eq(embedding.id, chunkId))
|
||||
|
||||
// Update document statistics
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: sql`${document.chunkCount} - 1`,
|
||||
tokenCount: sql`${document.tokenCount} - ${chunk.tokenCount}`,
|
||||
characterCount: sql`${document.characterCount} - ${chunk.contentLength}`,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Deleted chunk: ${chunkId}`)
|
||||
}
|
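A hedged usage sketch of the chunk service above; the import path follows the new file location, and the IDs and search term are placeholders:

```ts
import { batchChunkOperation, queryChunks } from '@/lib/knowledge/chunks/service'

// Hypothetical cleanup task: find disabled chunks that mention "draft" and delete them.
async function pruneDraftChunks(documentId: string, requestId: string) {
  const { chunks } = await queryChunks(
    documentId,
    { search: 'draft', enabled: 'false', limit: 100, offset: 0 },
    requestId
  )
  if (chunks.length === 0) return

  const result = await batchChunkOperation(
    documentId,
    'delete',
    chunks.map((chunk) => chunk.id),
    requestId
  )
  console.log(`Deleted ${result.processed} chunks (${result.errors.length} errors)`)
}
```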
||||
47
apps/sim/lib/knowledge/chunks/types.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
export interface ChunkFilters {
|
||||
search?: string
|
||||
enabled?: 'true' | 'false' | 'all'
|
||||
limit?: number
|
||||
offset?: number
|
||||
}
|
||||
|
||||
export interface ChunkData {
|
||||
id: string
|
||||
chunkIndex: number
|
||||
content: string
|
||||
contentLength: number
|
||||
tokenCount: number
|
||||
enabled: boolean
|
||||
startOffset: number
|
||||
endOffset: number
|
||||
tag1?: string | null
|
||||
tag2?: string | null
|
||||
tag3?: string | null
|
||||
tag4?: string | null
|
||||
tag5?: string | null
|
||||
tag6?: string | null
|
||||
tag7?: string | null
|
||||
createdAt: Date
|
||||
updatedAt: Date
|
||||
}
|
||||
|
||||
export interface ChunkQueryResult {
|
||||
chunks: ChunkData[]
|
||||
pagination: {
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
hasMore: boolean
|
||||
}
|
||||
}
|
||||
|
||||
export interface CreateChunkData {
|
||||
content: string
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
export interface BatchOperationResult {
|
||||
success: boolean
|
||||
processed: number
|
||||
errors: string[]
|
||||
}
|
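For reference, a minimal sketch of appending a manual chunk with these types; the knowledge base ID, document ID, and tag values are placeholders:

```ts
import { createChunk } from '@/lib/knowledge/chunks/service'
import type { CreateChunkData } from '@/lib/knowledge/chunks/types'

async function appendManualChunk(requestId: string) {
  const chunkData: CreateChunkData = {
    content: 'Refund policy: 30 days from delivery.',
    enabled: true,
  }
  // Tag values are inherited from the parent document row.
  const docTags = { tag1: 'policies', tag2: null, tag3: null, tag4: null, tag5: null, tag6: null, tag7: null }

  return createChunk('kb_123', 'doc_456', docTags, chunkData, requestId)
}
```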
||||
@@ -26,7 +26,7 @@ export interface Chunk {
|
||||
|
||||
/**
|
||||
* Lightweight text chunker optimized for RAG applications
|
||||
* Uses hierarchical splitting with smart token estimation
|
||||
* Uses hierarchical splitting with simple character-based token estimation
|
||||
*/
|
||||
export class TextChunker {
|
||||
private readonly chunkSize: number
|
||||
@@ -62,39 +62,20 @@ export class TextChunker {
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate token count - optimized for common tokenizers
|
||||
* Simple token estimation using character count
|
||||
*/
|
||||
private estimateTokens(text: string): number {
|
||||
// Handle empty or whitespace-only text
|
||||
if (!text?.trim()) return 0
|
||||
|
||||
const words = text.trim().split(/\s+/)
|
||||
let tokenCount = 0
|
||||
|
||||
for (const word of words) {
|
||||
if (word.length === 0) continue
|
||||
|
||||
// Short words (1-4 chars) are usually 1 token
|
||||
if (word.length <= 4) {
|
||||
tokenCount += 1
|
||||
}
|
||||
// Medium words (5-8 chars) are usually 1-2 tokens
|
||||
else if (word.length <= 8) {
|
||||
tokenCount += Math.ceil(word.length / 5)
|
||||
}
|
||||
// Long words get split more by subword tokenization
|
||||
else {
|
||||
tokenCount += Math.ceil(word.length / 4)
|
||||
}
|
||||
}
|
||||
|
||||
return tokenCount
|
||||
// Simple estimation: ~4 characters per token
|
||||
return Math.ceil(text.length / 4)
|
||||
}
|
||||
|
||||
/**
|
||||
* Split text recursively using hierarchical separators
|
||||
*/
|
||||
private splitRecursively(text: string, separatorIndex = 0): string[] {
|
||||
private async splitRecursively(text: string, separatorIndex = 0): Promise<string[]> {
|
||||
const tokenCount = this.estimateTokens(text)
|
||||
|
||||
// If chunk is small enough, return it
|
||||
@@ -121,7 +102,7 @@ export class TextChunker {
|
||||
|
||||
// If no split occurred, try next separator
|
||||
if (parts.length <= 1) {
|
||||
return this.splitRecursively(text, separatorIndex + 1)
|
||||
return await this.splitRecursively(text, separatorIndex + 1)
|
||||
}
|
||||
|
||||
const chunks: string[] = []
|
||||
@@ -141,7 +122,7 @@ export class TextChunker {
|
||||
// Start new chunk with current part
|
||||
// If part itself is too large, split it further
|
||||
if (this.estimateTokens(part) > this.chunkSize) {
|
||||
chunks.push(...this.splitRecursively(part, separatorIndex + 1))
|
||||
chunks.push(...(await this.splitRecursively(part, separatorIndex + 1)))
|
||||
currentChunk = ''
|
||||
} else {
|
||||
currentChunk = part
|
||||
@@ -212,14 +193,14 @@ export class TextChunker {
|
||||
const cleanedText = this.cleanText(text)
|
||||
|
||||
// Split into chunks
|
||||
let chunks = this.splitRecursively(cleanedText)
|
||||
let chunks = await this.splitRecursively(cleanedText)
|
||||
|
||||
// Add overlap if configured
|
||||
chunks = this.addOverlap(chunks)
|
||||
|
||||
// Convert to Chunk objects with metadata
|
||||
let previousEndIndex = 0
|
||||
return chunks.map((chunkText, index) => {
|
||||
const chunkPromises = chunks.map(async (chunkText, index) => {
|
||||
let startIndex: number
|
||||
let actualContentLength: number
|
||||
|
||||
@@ -256,5 +237,7 @@ export class TextChunker {
|
||||
previousEndIndex = endIndexSafe
|
||||
return chunk
|
||||
})
|
||||
|
||||
return await Promise.all(chunkPromises)
|
||||
}
|
||||
}
|
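The estimator above now uses a flat ~4 characters per token, so a 2,000-character block counts as roughly 500 tokens when deciding whether to split. A usage sketch follows; the constructor option names and the public chunk() method name are not shown in this hunk, so both are assumptions:

```ts
import { TextChunker } from '@/lib/knowledge/documents/chunker'

async function chunkDocument(text: string) {
  // Assumed option shape; only the chunk-size/overlap behavior is visible in the diff.
  const chunker = new TextChunker({ chunkSize: 512, overlap: 50 })
  // The chunking entry point is async now that splitRecursively returns a Promise.
  const chunks = await chunker.chunk(text)
  console.log(`Produced ${chunks.length} chunks`)
  return chunks
}
```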
||||
@@ -1,10 +1,10 @@
|
||||
import fs from 'fs/promises'
|
||||
import path from 'path'
|
||||
import { TextChunker } from '@/lib/documents/chunker'
|
||||
import type { DocChunk, DocsChunkerOptions, HeaderInfo } from '@/lib/documents/types'
|
||||
import { generateEmbeddings } from '@/lib/embeddings/utils'
|
||||
import { isDev } from '@/lib/environment'
|
||||
import { TextChunker } from '@/lib/knowledge/documents/chunker'
|
||||
import type { DocChunk, DocsChunkerOptions, HeaderInfo } from '@/lib/knowledge/documents/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { generateEmbeddings } from '@/app/api/knowledge/utils'
|
||||
|
||||
interface Frontmatter {
|
||||
title?: string
|
||||
@@ -1,9 +1,14 @@
|
||||
import { type Chunk, TextChunker } from '@/lib/documents/chunker'
|
||||
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
|
||||
import { env } from '@/lib/env'
|
||||
import { parseBuffer, parseFile } from '@/lib/file-parsers'
|
||||
import { type Chunk, TextChunker } from '@/lib/knowledge/documents/chunker'
|
||||
import { retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getPresignedUrlWithConfig, getStorageProvider, uploadFile } from '@/lib/uploads'
|
||||
import {
|
||||
type CustomStorageConfig,
|
||||
getPresignedUrlWithConfig,
|
||||
getStorageProvider,
|
||||
uploadFile,
|
||||
} from '@/lib/uploads'
|
||||
import { BLOB_KB_CONFIG, S3_KB_CONFIG } from '@/lib/uploads/setup'
|
||||
import { mistralParserTool } from '@/tools/mistral/parser'
|
||||
|
||||
@@ -14,19 +19,33 @@ const TIMEOUTS = {
|
||||
MISTRAL_OCR_API: 90000,
|
||||
} as const
|
||||
|
||||
type S3Config = {
|
||||
bucket: string
|
||||
region: string
|
||||
type OCRResult = {
|
||||
success: boolean
|
||||
error?: string
|
||||
output?: {
|
||||
content?: string
|
||||
}
|
||||
}
|
||||
|
||||
type BlobConfig = {
|
||||
containerName: string
|
||||
accountName: string
|
||||
accountKey?: string
|
||||
connectionString?: string
|
||||
type OCRPage = {
|
||||
markdown?: string
|
||||
}
|
||||
|
||||
const getKBConfig = (): S3Config | BlobConfig => {
|
||||
type OCRRequestBody = {
|
||||
model: string
|
||||
document: {
|
||||
type: string
|
||||
document_url: string
|
||||
}
|
||||
include_image_base64: boolean
|
||||
}
|
||||
|
||||
type AzureOCRResponse = {
|
||||
pages?: OCRPage[]
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
const getKBConfig = (): CustomStorageConfig => {
|
||||
const provider = getStorageProvider()
|
||||
return provider === 'blob'
|
||||
? {
|
||||
@@ -148,8 +167,8 @@ async function handleFileForOCR(fileUrl: string, filename: string, mimeType: str
|
||||
validateCloudConfig(kbConfig)
|
||||
|
||||
try {
|
||||
const cloudResult = await uploadFile(buffer, filename, mimeType, kbConfig as any)
|
||||
const httpsUrl = await getPresignedUrlWithConfig(cloudResult.key, kbConfig as any, 900)
|
||||
const cloudResult = await uploadFile(buffer, filename, mimeType, kbConfig)
|
||||
const httpsUrl = await getPresignedUrlWithConfig(cloudResult.key, kbConfig, 900)
|
||||
logger.info(`Successfully uploaded for OCR: ${cloudResult.key}`)
|
||||
return { httpsUrl, cloudUrl: httpsUrl }
|
||||
} catch (uploadError) {
|
||||
@@ -199,28 +218,26 @@ async function downloadFileForBase64(fileUrl: string): Promise<Buffer> {
|
||||
return fs.readFile(fileUrl)
|
||||
}
|
||||
|
||||
function validateCloudConfig(kbConfig: S3Config | BlobConfig) {
|
||||
function validateCloudConfig(kbConfig: CustomStorageConfig) {
|
||||
const provider = getStorageProvider()
|
||||
|
||||
if (provider === 'blob') {
|
||||
const config = kbConfig as BlobConfig
|
||||
if (
|
||||
!config.containerName ||
|
||||
(!config.connectionString && (!config.accountName || !config.accountKey))
|
||||
!kbConfig.containerName ||
|
||||
(!kbConfig.connectionString && (!kbConfig.accountName || !kbConfig.accountKey))
|
||||
) {
|
||||
throw new Error(
|
||||
'Azure Blob configuration missing. Set AZURE_CONNECTION_STRING or AZURE_ACCOUNT_NAME + AZURE_ACCOUNT_KEY + AZURE_KB_CONTAINER_NAME'
|
||||
)
|
||||
}
|
||||
} else {
|
||||
const config = kbConfig as S3Config
|
||||
if (!config.bucket || !config.region) {
|
||||
if (!kbConfig.bucket || !kbConfig.region) {
|
||||
throw new Error('S3 configuration missing. Set AWS_REGION and S3_KB_BUCKET_NAME')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function processOCRContent(result: any, filename: string): string {
|
||||
function processOCRContent(result: OCRResult, filename: string): string {
|
||||
if (!result.success) {
|
||||
throw new Error(`OCR processing failed: ${result.error || 'Unknown error'}`)
|
||||
}
|
||||
@@ -245,7 +262,7 @@ function validateOCRConfig(
|
||||
if (!modelName) throw new Error(`${service} model name required`)
|
||||
}
|
||||
|
||||
function extractPageContent(pages: any[]): string {
|
||||
function extractPageContent(pages: OCRPage[]): string {
|
||||
if (!pages?.length) return ''
|
||||
|
||||
return pages
|
||||
@@ -254,7 +271,11 @@ function extractPageContent(pages: any[]): string {
|
||||
.join('\n\n')
|
||||
}
|
||||
|
||||
async function makeOCRRequest(endpoint: string, headers: Record<string, string>, body: any) {
|
||||
async function makeOCRRequest(
|
||||
endpoint: string,
|
||||
headers: Record<string, string>,
|
||||
body: OCRRequestBody
|
||||
): Promise<Response> {
|
||||
const controller = new AbortController()
|
||||
const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.MISTRAL_OCR_API)
|
||||
|
||||
@@ -309,7 +330,7 @@ async function parseWithAzureMistralOCR(fileUrl: string, filename: string, mimeT
|
||||
Authorization: `Bearer ${env.OCR_AZURE_API_KEY}`,
|
||||
},
|
||||
{
|
||||
model: env.OCR_AZURE_MODEL_NAME,
|
||||
model: env.OCR_AZURE_MODEL_NAME!,
|
||||
document: {
|
||||
type: 'document_url',
|
||||
document_url: dataUri,
|
||||
@@ -320,8 +341,8 @@ async function parseWithAzureMistralOCR(fileUrl: string, filename: string, mimeT
|
||||
{ maxRetries: 3, initialDelayMs: 1000, maxDelayMs: 10000 }
|
||||
)
|
||||
|
||||
const ocrResult = await response.json()
|
||||
const content = extractPageContent(ocrResult.pages) || JSON.stringify(ocrResult, null, 2)
|
||||
const ocrResult = (await response.json()) as AzureOCRResponse
|
||||
const content = extractPageContent(ocrResult.pages || []) || JSON.stringify(ocrResult, null, 2)
|
||||
|
||||
if (!content.trim()) {
|
||||
throw new Error('Azure Mistral OCR returned empty content')
|
||||
@@ -365,13 +386,13 @@ async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType:
|
||||
? mistralParserTool.request!.headers(params)
|
||||
: mistralParserTool.request!.headers
|
||||
|
||||
const requestBody = mistralParserTool.request!.body!(params)
|
||||
const requestBody = mistralParserTool.request!.body!(params) as OCRRequestBody
|
||||
return makeOCRRequest(url, headers as Record<string, string>, requestBody)
|
||||
},
|
||||
{ maxRetries: 3, initialDelayMs: 1000, maxDelayMs: 10000 }
|
||||
)
|
||||
|
||||
const result = await mistralParserTool.transformResponse!(response, params)
|
||||
const result = (await mistralParserTool.transformResponse!(response, params)) as OCRResult
|
||||
const content = processOCRContent(result, filename)
|
||||
|
||||
return { content, processingMethod: 'mistral-ocr' as const, cloudUrl }
|
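A small illustration of how the new OCR types compose; extractPageContent and AzureOCRResponse are internal to this module, so the snippet assumes it lives alongside them, and the response values are made up:

```ts
// Hypothetical OCR response; only the fields read by extractPageContent matter here.
const ocrResult: AzureOCRResponse = {
  pages: [{ markdown: '# Invoice 42' }, { markdown: 'Total: $120.00' }, {}],
}

// Joins the page markdown with blank lines, roughly "# Invoice 42\n\nTotal: $120.00".
const content = extractPageContent(ocrResult.pages ?? [])
```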
||||
264
apps/sim/lib/knowledge/documents/queue.ts
Normal file
@@ -0,0 +1,264 @@
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getRedisClient } from '@/lib/redis'
|
||||
|
||||
const logger = createLogger('DocumentQueue')
|
||||
|
||||
interface QueueJob<T = unknown> {
|
||||
id: string
|
||||
type: string
|
||||
data: T
|
||||
timestamp: number
|
||||
attempts: number
|
||||
maxAttempts: number
|
||||
}
|
||||
|
||||
interface QueueConfig {
|
||||
maxConcurrent: number
|
||||
retryDelay: number
|
||||
maxRetries: number
|
||||
}
|
||||
|
||||
export class DocumentProcessingQueue {
|
||||
private config: QueueConfig
|
||||
private processing = new Map<string, Promise<void>>()
|
||||
private fallbackQueue: QueueJob[] = []
|
||||
private fallbackProcessing = 0
|
||||
private processingStarted = false
|
||||
|
||||
constructor(config: QueueConfig) {
|
||||
this.config = config
|
||||
}
|
||||
|
||||
private isRedisAvailable(): boolean {
|
||||
const redis = getRedisClient()
|
||||
return redis !== null
|
||||
}
|
||||
|
||||
async addJob<T>(type: string, data: T, options: { maxAttempts?: number } = {}): Promise<string> {
|
||||
const job: QueueJob = {
|
||||
id: `${type}-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`,
|
||||
type,
|
||||
data,
|
||||
timestamp: Date.now(),
|
||||
attempts: 0,
|
||||
maxAttempts: options.maxAttempts || this.config.maxRetries,
|
||||
}
|
||||
|
||||
if (this.isRedisAvailable()) {
|
||||
try {
|
||||
const redis = getRedisClient()!
|
||||
await redis.lpush('document-queue', JSON.stringify(job))
|
||||
logger.info(`Job ${job.id} added to Redis queue`)
|
||||
return job.id
|
||||
} catch (error) {
|
||||
logger.warn('Failed to add job to Redis, using fallback:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to in-memory queue
|
||||
this.fallbackQueue.push(job)
|
||||
logger.info(`Job ${job.id} added to in-memory fallback queue`)
|
||||
return job.id
|
||||
}
|
||||
|
||||
async processJobs(processor: (job: QueueJob) => Promise<void>): Promise<void> {
|
||||
if (this.processingStarted) {
|
||||
logger.info('Queue processing already started, skipping')
|
||||
return
|
||||
}
|
||||
|
||||
this.processingStarted = true
|
||||
logger.info('Starting queue processing')
|
||||
|
||||
if (this.isRedisAvailable()) {
|
||||
await this.processRedisJobs(processor)
|
||||
} else {
|
||||
await this.processFallbackJobs(processor)
|
||||
}
|
||||
}
|
||||
|
||||
private async processRedisJobs(processor: (job: QueueJob) => Promise<void>) {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) {
|
||||
logger.warn('Redis client not available, falling back to in-memory processing')
|
||||
await this.processFallbackJobs(processor)
|
||||
return
|
||||
}
|
||||
|
||||
const processJobsContinuously = async () => {
|
||||
let consecutiveErrors = 0
|
||||
while (true) {
|
||||
if (this.processing.size >= this.config.maxConcurrent) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 100)) // Wait before checking again
|
||||
continue
|
||||
}
|
||||
|
||||
try {
|
||||
const currentRedis = getRedisClient()
|
||||
if (!currentRedis) {
|
||||
logger.warn('Redis connection lost, switching to fallback processing')
|
||||
await this.processFallbackJobs(processor)
|
||||
return
|
||||
}
|
||||
|
||||
const result = await currentRedis.brpop('document-queue', 1)
|
||||
if (!result || !result[1]) {
|
||||
consecutiveErrors = 0 // Reset error counter on successful operation
|
||||
continue // Continue polling for jobs
|
||||
}
|
||||
|
||||
const job: QueueJob = JSON.parse(result[1])
|
||||
const promise = this.executeJob(job, processor)
|
||||
this.processing.set(job.id, promise)
|
||||
|
||||
promise.finally(() => {
|
||||
this.processing.delete(job.id)
|
||||
})
|
||||
|
||||
consecutiveErrors = 0 // Reset error counter on success
|
||||
// Don't await here - let it process in background while we get next job
|
||||
} catch (error: any) {
|
||||
consecutiveErrors++
|
||||
|
||||
if (
|
||||
error.message?.includes('Connection is closed') ||
|
||||
error.message?.includes('ECONNREFUSED') ||
|
||||
error.code === 'ECONNREFUSED' ||
|
||||
consecutiveErrors >= 5
|
||||
) {
|
||||
logger.warn(
|
||||
`Redis connection failed (${consecutiveErrors} consecutive errors), switching to fallback processing:`,
|
||||
error.message
|
||||
)
|
||||
await this.processFallbackJobs(processor)
|
||||
return
|
||||
}
|
||||
|
||||
logger.error('Error processing Redis job:', error)
|
||||
await new Promise((resolve) =>
|
||||
setTimeout(resolve, Math.min(1000 * consecutiveErrors, 5000))
|
||||
) // Exponential backoff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start multiple concurrent processors that run continuously
|
||||
const processors = Array(this.config.maxConcurrent)
|
||||
.fill(null)
|
||||
.map(() => processJobsContinuously())
|
||||
|
||||
// Don't await - let processors run in background
|
||||
Promise.allSettled(processors).catch((error) => {
|
||||
logger.error('Error in Redis queue processors:', error)
|
||||
})
|
||||
}
|
||||
|
||||
private async processFallbackJobs(processor: (job: QueueJob) => Promise<void>) {
|
||||
const processFallbackContinuously = async () => {
|
||||
while (true) {
|
||||
if (this.fallbackProcessing >= this.config.maxConcurrent) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 100))
|
||||
continue
|
||||
}
|
||||
|
||||
const job = this.fallbackQueue.shift()
|
||||
if (!job) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500)) // Wait for new jobs
|
||||
continue
|
||||
}
|
||||
|
||||
this.fallbackProcessing++
|
||||
|
||||
this.executeJob(job, processor).finally(() => {
|
||||
this.fallbackProcessing--
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Start multiple concurrent processors for fallback queue
|
||||
const processors = Array(this.config.maxConcurrent)
|
||||
.fill(null)
|
||||
.map(() => processFallbackContinuously())
|
||||
|
||||
// Don't await - let processors run in background
|
||||
Promise.allSettled(processors).catch((error) => {
|
||||
logger.error('Error in fallback queue processors:', error)
|
||||
})
|
||||
}
|
||||
|
||||
private async executeJob(
|
||||
job: QueueJob,
|
||||
processor: (job: QueueJob) => Promise<void>
|
||||
): Promise<void> {
|
||||
try {
|
||||
job.attempts++
|
||||
logger.info(`Processing job ${job.id} (attempt ${job.attempts}/${job.maxAttempts})`)
|
||||
|
||||
await processor(job)
|
||||
logger.info(`Job ${job.id} completed successfully`)
|
||||
} catch (error) {
|
||||
logger.error(`Job ${job.id} failed (attempt ${job.attempts}):`, error)
|
||||
|
||||
if (job.attempts < job.maxAttempts) {
|
||||
// Retry logic with exponential backoff
|
||||
const delay = this.config.retryDelay * 2 ** (job.attempts - 1)
|
||||
|
||||
setTimeout(async () => {
|
||||
if (this.isRedisAvailable()) {
|
||||
try {
|
||||
const redis = getRedisClient()!
|
||||
await redis.lpush('document-queue', JSON.stringify(job))
|
||||
} catch (retryError) {
|
||||
logger.warn('Failed to requeue job to Redis, using fallback:', retryError)
|
||||
this.fallbackQueue.push(job)
|
||||
}
|
||||
} else {
|
||||
this.fallbackQueue.push(job)
|
||||
}
|
||||
}, delay)
|
||||
|
||||
logger.info(`Job ${job.id} will retry in ${delay}ms`)
|
||||
} else {
|
||||
logger.error(`Job ${job.id} failed permanently after ${job.attempts} attempts`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async getQueueStats(): Promise<{ pending: number; processing: number; redisAvailable: boolean }> {
|
||||
let pending = 0
|
||||
const redisAvailable = this.isRedisAvailable()
|
||||
|
||||
if (redisAvailable) {
|
||||
try {
|
||||
const redis = getRedisClient()!
|
||||
pending = await redis.llen('document-queue')
|
||||
} catch (error) {
|
||||
logger.warn('Failed to get Redis queue stats:', error)
|
||||
pending = this.fallbackQueue.length
|
||||
}
|
||||
} else {
|
||||
pending = this.fallbackQueue.length
|
||||
}
|
||||
|
||||
return {
|
||||
pending,
|
||||
processing: redisAvailable ? this.processing.size : this.fallbackProcessing,
|
||||
redisAvailable,
|
||||
}
|
||||
}
|
||||
|
||||
async clearQueue(): Promise<void> {
|
||||
if (this.isRedisAvailable()) {
|
||||
try {
|
||||
const redis = getRedisClient()!
|
||||
await redis.del('document-queue')
|
||||
logger.info('Redis queue cleared')
|
||||
} catch (error) {
|
||||
logger.error('Failed to clear Redis queue:', error)
|
||||
}
|
||||
}
|
||||
|
||||
this.fallbackQueue.length = 0
|
||||
logger.info('Fallback queue cleared')
|
||||
}
|
||||
}
|
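A usage sketch of the queue above; the job type string and payload shape are placeholders, and this assumes it is wired up once at startup:

```ts
import { DocumentProcessingQueue } from '@/lib/knowledge/documents/queue'

const queue = new DocumentProcessingQueue({ maxConcurrent: 3, retryDelay: 2000, maxRetries: 3 })

async function startDocumentWorkers() {
  // Uses Redis (lpush/brpop on 'document-queue') when available, the in-memory fallback otherwise.
  await queue.addJob('process-document', { documentId: 'doc_123', knowledgeBaseId: 'kb_456' })

  // Processors poll continuously in the background; call this once.
  await queue.processJobs(async (job) => {
    console.log(`Handling ${job.type} job ${job.id}`, job.data)
  })

  console.log(await queue.getQueueStats()) // { pending, processing, redisAvailable }
}
```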
||||
1235
apps/sim/lib/knowledge/documents/service.ts
Normal file
File diff suppressed because it is too large
@@ -1,3 +1,18 @@
|
||||
// Document sorting options
|
||||
export type DocumentSortField =
|
||||
| 'filename'
|
||||
| 'fileSize'
|
||||
| 'tokenCount'
|
||||
| 'chunkCount'
|
||||
| 'uploadedAt'
|
||||
| 'processingStatus'
|
||||
export type SortOrder = 'asc' | 'desc'
|
||||
|
||||
export interface DocumentSortOptions {
|
||||
sortBy?: DocumentSortField
|
||||
sortOrder?: SortOrder
|
||||
}
|
||||
|
||||
export interface DocChunk {
|
||||
/** The chunk text content */
|
||||
text: string
|
||||
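A small example of the new sort options; the import path matches the file shown in this diff, while the query helper it is passed to is hypothetical:

```ts
import type { DocumentSortOptions } from '@/lib/knowledge/documents/types'

// Sort the KB documents table by file size, largest first.
const sort: DocumentSortOptions = { sortBy: 'fileSize', sortOrder: 'desc' }

// e.g. spread into a hypothetical documents query: getDocuments(knowledgeBaseId, { ...filters, ...sort })
```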
266
apps/sim/lib/knowledge/service.ts
Normal file
@@ -0,0 +1,266 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, count, eq, isNotNull, isNull, or } from 'drizzle-orm'
|
||||
import type {
|
||||
ChunkingConfig,
|
||||
CreateKnowledgeBaseData,
|
||||
KnowledgeBaseWithCounts,
|
||||
} from '@/lib/knowledge/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserEntityPermissions } from '@/lib/permissions/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBase, permissions } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeBaseService')
|
||||
|
||||
/**
|
||||
* Get knowledge bases that a user can access
|
||||
*/
|
||||
export async function getKnowledgeBases(
|
||||
userId: string,
|
||||
workspaceId?: string | null
|
||||
): Promise<KnowledgeBaseWithCounts[]> {
|
||||
const knowledgeBasesWithCounts = await db
|
||||
.select({
|
||||
id: knowledgeBase.id,
|
||||
name: knowledgeBase.name,
|
||||
description: knowledgeBase.description,
|
||||
tokenCount: knowledgeBase.tokenCount,
|
||||
embeddingModel: knowledgeBase.embeddingModel,
|
||||
embeddingDimension: knowledgeBase.embeddingDimension,
|
||||
chunkingConfig: knowledgeBase.chunkingConfig,
|
||||
createdAt: knowledgeBase.createdAt,
|
||||
updatedAt: knowledgeBase.updatedAt,
|
||||
workspaceId: knowledgeBase.workspaceId,
|
||||
docCount: count(document.id),
|
||||
})
|
||||
.from(knowledgeBase)
|
||||
.leftJoin(
|
||||
document,
|
||||
and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt))
|
||||
)
|
||||
.leftJoin(
|
||||
permissions,
|
||||
and(
|
||||
eq(permissions.entityType, 'workspace'),
|
||||
eq(permissions.entityId, knowledgeBase.workspaceId),
|
||||
eq(permissions.userId, userId)
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
isNull(knowledgeBase.deletedAt),
|
||||
workspaceId
|
||||
? // When filtering by workspace
|
||||
or(
|
||||
// Knowledge bases belonging to the specified workspace (user must have workspace permissions)
|
||||
and(eq(knowledgeBase.workspaceId, workspaceId), isNotNull(permissions.userId)),
|
||||
// Fallback: User-owned knowledge bases without workspace (legacy)
|
||||
and(eq(knowledgeBase.userId, userId), isNull(knowledgeBase.workspaceId))
|
||||
)
|
||||
: // When not filtering by workspace, use original logic
|
||||
or(
|
||||
// User owns the knowledge base directly
|
||||
eq(knowledgeBase.userId, userId),
|
||||
// User has permissions on the knowledge base's workspace
|
||||
isNotNull(permissions.userId)
|
||||
)
|
||||
)
|
||||
)
|
||||
.groupBy(knowledgeBase.id)
|
||||
.orderBy(knowledgeBase.createdAt)
|
||||
|
||||
return knowledgeBasesWithCounts.map((kb) => ({
|
||||
...kb,
|
||||
chunkingConfig: kb.chunkingConfig as ChunkingConfig,
|
||||
docCount: Number(kb.docCount),
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new knowledge base
|
||||
*/
|
||||
export async function createKnowledgeBase(
|
||||
data: CreateKnowledgeBaseData,
|
||||
requestId: string
|
||||
): Promise<KnowledgeBaseWithCounts> {
|
||||
const kbId = randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
if (data.workspaceId) {
|
||||
const hasPermission = await getUserEntityPermissions(data.userId, 'workspace', data.workspaceId)
|
||||
if (hasPermission === null) {
|
||||
throw new Error('User does not have permission to create knowledge bases in this workspace')
|
||||
}
|
||||
}
|
||||
|
||||
const newKnowledgeBase = {
|
||||
id: kbId,
|
||||
name: data.name,
|
||||
description: data.description ?? null,
|
||||
workspaceId: data.workspaceId ?? null,
|
||||
userId: data.userId,
|
||||
tokenCount: 0,
|
||||
embeddingModel: data.embeddingModel,
|
||||
embeddingDimension: data.embeddingDimension,
|
||||
chunkingConfig: data.chunkingConfig,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
deletedAt: null,
|
||||
}
|
||||
|
||||
await db.insert(knowledgeBase).values(newKnowledgeBase)
|
||||
|
||||
logger.info(`[${requestId}] Created knowledge base: ${data.name} (${kbId})`)
|
||||
|
||||
return {
|
||||
id: kbId,
|
||||
name: data.name,
|
||||
description: data.description ?? null,
|
||||
tokenCount: 0,
|
||||
embeddingModel: data.embeddingModel,
|
||||
embeddingDimension: data.embeddingDimension,
|
||||
chunkingConfig: data.chunkingConfig,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
workspaceId: data.workspaceId ?? null,
|
||||
docCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a knowledge base
|
||||
*/
|
||||
export async function updateKnowledgeBase(
|
||||
knowledgeBaseId: string,
|
||||
updates: {
|
||||
name?: string
|
||||
description?: string
|
||||
chunkingConfig?: {
|
||||
maxSize: number
|
||||
minSize: number
|
||||
overlap: number
|
||||
}
|
||||
},
|
||||
requestId: string
|
||||
): Promise<KnowledgeBaseWithCounts> {
|
||||
const now = new Date()
|
||||
const updateData: {
|
||||
updatedAt: Date
|
||||
name?: string
|
||||
description?: string | null
|
||||
chunkingConfig?: {
|
||||
maxSize: number
|
||||
minSize: number
|
||||
overlap: number
|
||||
}
|
||||
embeddingModel?: string
|
||||
embeddingDimension?: number
|
||||
} = {
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
if (updates.name !== undefined) updateData.name = updates.name
|
||||
if (updates.description !== undefined) updateData.description = updates.description
|
||||
if (updates.chunkingConfig !== undefined) {
|
||||
updateData.chunkingConfig = updates.chunkingConfig
|
||||
updateData.embeddingModel = 'text-embedding-3-small'
|
||||
updateData.embeddingDimension = 1536
|
||||
}
|
||||
|
||||
await db.update(knowledgeBase).set(updateData).where(eq(knowledgeBase.id, knowledgeBaseId))
|
||||
|
||||
const updatedKb = await db
|
||||
.select({
|
||||
id: knowledgeBase.id,
|
||||
name: knowledgeBase.name,
|
||||
description: knowledgeBase.description,
|
||||
tokenCount: knowledgeBase.tokenCount,
|
||||
embeddingModel: knowledgeBase.embeddingModel,
|
||||
embeddingDimension: knowledgeBase.embeddingDimension,
|
||||
chunkingConfig: knowledgeBase.chunkingConfig,
|
||||
createdAt: knowledgeBase.createdAt,
|
||||
updatedAt: knowledgeBase.updatedAt,
|
||||
workspaceId: knowledgeBase.workspaceId,
|
||||
docCount: count(document.id),
|
||||
})
|
||||
.from(knowledgeBase)
|
||||
.leftJoin(
|
||||
document,
|
||||
and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt))
|
||||
)
|
||||
.where(eq(knowledgeBase.id, knowledgeBaseId))
|
||||
.groupBy(knowledgeBase.id)
|
||||
.limit(1)
|
||||
|
||||
if (updatedKb.length === 0) {
|
||||
throw new Error(`Knowledge base ${knowledgeBaseId} not found`)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Updated knowledge base: ${knowledgeBaseId}`)
|
||||
|
||||
return {
|
||||
...updatedKb[0],
|
||||
chunkingConfig: updatedKb[0].chunkingConfig as ChunkingConfig,
|
||||
docCount: Number(updatedKb[0].docCount),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a single knowledge base by ID
|
||||
*/
|
||||
export async function getKnowledgeBaseById(
|
||||
knowledgeBaseId: string
|
||||
): Promise<KnowledgeBaseWithCounts | null> {
|
||||
const result = await db
|
||||
.select({
|
||||
id: knowledgeBase.id,
|
||||
name: knowledgeBase.name,
|
||||
description: knowledgeBase.description,
|
||||
tokenCount: knowledgeBase.tokenCount,
|
||||
embeddingModel: knowledgeBase.embeddingModel,
|
||||
embeddingDimension: knowledgeBase.embeddingDimension,
|
||||
chunkingConfig: knowledgeBase.chunkingConfig,
|
||||
createdAt: knowledgeBase.createdAt,
|
||||
updatedAt: knowledgeBase.updatedAt,
|
||||
workspaceId: knowledgeBase.workspaceId,
|
||||
docCount: count(document.id),
|
||||
})
|
||||
.from(knowledgeBase)
|
||||
.leftJoin(
|
||||
document,
|
||||
and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt))
|
||||
)
|
||||
.where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt)))
|
||||
.groupBy(knowledgeBase.id)
|
||||
.limit(1)
|
||||
|
||||
if (result.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
return {
|
||||
...result[0],
|
||||
chunkingConfig: result[0].chunkingConfig as ChunkingConfig,
|
||||
docCount: Number(result[0].docCount),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a knowledge base (soft delete)
|
||||
*/
|
||||
export async function deleteKnowledgeBase(
|
||||
knowledgeBaseId: string,
|
||||
requestId: string
|
||||
): Promise<void> {
|
||||
const now = new Date()
|
||||
|
||||
await db
|
||||
.update(knowledgeBase)
|
||||
.set({
|
||||
deletedAt: now,
|
||||
updatedAt: now,
|
||||
})
|
||||
.where(eq(knowledgeBase.id, knowledgeBaseId))
|
||||
|
||||
logger.info(`[${requestId}] Soft deleted knowledge base: ${knowledgeBaseId}`)
|
||||
}
|
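A hedged setup sketch for the service above; the CreateKnowledgeBaseData fields are inferred from how the function reads `data`, and the chunking values are placeholders:

```ts
import { createKnowledgeBase, getKnowledgeBases } from '@/lib/knowledge/service'

async function bootstrapKnowledgeBase(userId: string, workspaceId: string, requestId: string) {
  const kb = await createKnowledgeBase(
    {
      name: 'Support Docs',
      userId,
      workspaceId,
      embeddingModel: 'text-embedding-3-small',
      embeddingDimension: 1536,
      chunkingConfig: { maxSize: 512, minSize: 100, overlap: 50 },
    },
    requestId
  )

  // Lists KBs the user can see in this workspace (workspace permission or legacy ownership).
  const visible = await getKnowledgeBases(userId, workspaceId)
  return { kb, visibleCount: visible.length }
}
```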
||||
649
apps/sim/lib/knowledge/tags/service.ts
Normal file
@@ -0,0 +1,649 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq, isNotNull, isNull, sql } from 'drizzle-orm'
|
||||
import {
|
||||
getSlotsForFieldType,
|
||||
SUPPORTED_FIELD_TYPES,
|
||||
type TAG_SLOT_CONFIG,
|
||||
} from '@/lib/constants/knowledge'
|
||||
import type { BulkTagDefinitionsData, DocumentTagDefinition } from '@/lib/knowledge/tags/types'
|
||||
import type {
|
||||
CreateTagDefinitionData,
|
||||
TagDefinition,
|
||||
UpdateTagDefinitionData,
|
||||
} from '@/lib/knowledge/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('TagsService')
|
||||
|
||||
const VALID_TAG_SLOTS = ['tag1', 'tag2', 'tag3', 'tag4', 'tag5', 'tag6', 'tag7'] as const
|
||||
|
||||
function validateTagSlot(tagSlot: string): asserts tagSlot is (typeof VALID_TAG_SLOTS)[number] {
|
||||
if (!VALID_TAG_SLOTS.includes(tagSlot as (typeof VALID_TAG_SLOTS)[number])) {
|
||||
throw new Error(`Invalid tag slot: ${tagSlot}. Must be one of: ${VALID_TAG_SLOTS.join(', ')}`)
|
||||
}
|
||||
}
|
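A minimal sketch of the slot guard above; the call site is hypothetical and assumes it lives in the same module, since validateTagSlot is not exported:

```ts
// Hypothetical call site: validate an untrusted slot string before using it as a tag column key.
function readTagSlot(raw: string): (typeof VALID_TAG_SLOTS)[number] {
  validateTagSlot(raw) // throws unless raw is one of tag1..tag7
  return raw // narrowed by the assertion signature
}
```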
||||
|
||||
/**
|
||||
* Get the next available slot for a knowledge base and field type
|
||||
*/
|
||||
export async function getNextAvailableSlot(
|
||||
knowledgeBaseId: string,
|
||||
fieldType: string,
|
||||
existingBySlot?: Map<string, any>
|
||||
): Promise<string | null> {
|
||||
const availableSlots = getSlotsForFieldType(fieldType)
|
||||
let usedSlots: Set<string>
|
||||
|
||||
if (existingBySlot) {
|
||||
usedSlots = new Set(
|
||||
Array.from(existingBySlot.entries())
|
||||
.filter(([_, def]) => def.fieldType === fieldType)
|
||||
.map(([slot, _]) => slot)
|
||||
)
|
||||
} else {
|
||||
const existingDefinitions = await db
|
||||
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
|
||||
)
|
||||
)
|
||||
|
||||
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
|
||||
}
|
||||
|
||||
for (const slot of availableSlots) {
|
||||
if (!usedSlots.has(slot)) {
|
||||
return slot
|
||||
}
|
||||
}
|
||||
|
||||
return null // All slots for this field type are used
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all tag definitions for a knowledge base
|
||||
*/
|
||||
export async function getDocumentTagDefinitions(
|
||||
knowledgeBaseId: string
|
||||
): Promise<DocumentTagDefinition[]> {
|
||||
const definitions = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
knowledgeBaseId: knowledgeBaseTagDefinitions.knowledgeBaseId,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
fieldType: knowledgeBaseTagDefinitions.fieldType,
|
||||
createdAt: knowledgeBaseTagDefinitions.createdAt,
|
||||
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
.orderBy(knowledgeBaseTagDefinitions.tagSlot)
|
||||
|
||||
return definitions.map((def) => ({
|
||||
...def,
|
||||
tagSlot: def.tagSlot as string,
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all tag definitions for a knowledge base (alias for compatibility)
|
||||
*/
|
||||
export async function getTagDefinitions(knowledgeBaseId: string): Promise<TagDefinition[]> {
|
||||
const tagDefinitions = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
fieldType: knowledgeBaseTagDefinitions.fieldType,
|
||||
createdAt: knowledgeBaseTagDefinitions.createdAt,
|
||||
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
.orderBy(knowledgeBaseTagDefinitions.tagSlot)
|
||||
|
||||
return tagDefinitions.map((def) => ({
|
||||
...def,
|
||||
tagSlot: def.tagSlot as string,
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Create or update tag definitions in bulk
|
||||
*/
|
||||
export async function createOrUpdateTagDefinitionsBulk(
|
||||
knowledgeBaseId: string,
|
||||
bulkData: BulkTagDefinitionsData,
|
||||
requestId: string
|
||||
): Promise<{
|
||||
created: DocumentTagDefinition[]
|
||||
updated: DocumentTagDefinition[]
|
||||
errors: string[]
|
||||
}> {
|
||||
const { definitions } = bulkData
|
||||
const created: DocumentTagDefinition[] = []
|
||||
const updated: DocumentTagDefinition[] = []
|
||||
const errors: string[] = []
|
||||
|
||||
// Get existing definitions to check for conflicts and determine operations
|
||||
const existingDefinitions = await getDocumentTagDefinitions(knowledgeBaseId)
|
||||
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))
|
||||
const existingByDisplayName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
|
||||
|
||||
// Process each definition
|
||||
for (const defData of definitions) {
|
||||
try {
|
||||
const { tagSlot, displayName, fieldType, originalDisplayName } = defData
|
||||
|
||||
// Validate field type
|
||||
if (!SUPPORTED_FIELD_TYPES.includes(fieldType as (typeof SUPPORTED_FIELD_TYPES)[number])) {
|
||||
errors.push(`Invalid field type: ${fieldType}`)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is an update (has originalDisplayName) or create
|
||||
const isUpdate = !!originalDisplayName
|
||||
|
||||
if (isUpdate) {
|
||||
// Update existing definition
|
||||
const existingDef = existingByDisplayName.get(originalDisplayName!)
|
||||
if (!existingDef) {
|
||||
errors.push(`Tag definition with display name "${originalDisplayName}" not found`)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if new display name conflicts with another definition
|
||||
if (displayName !== originalDisplayName && existingByDisplayName.has(displayName)) {
|
||||
errors.push(`Display name "${displayName}" already exists`)
|
||||
continue
|
||||
}
|
||||
|
||||
const now = new Date()
|
||||
await db
|
||||
.update(knowledgeBaseTagDefinitions)
|
||||
.set({
|
||||
displayName,
|
||||
fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, existingDef.id))
|
||||
|
||||
updated.push({
|
||||
id: existingDef.id,
|
||||
knowledgeBaseId,
|
||||
tagSlot: existingDef.tagSlot,
|
||||
displayName,
|
||||
fieldType,
|
||||
createdAt: existingDef.createdAt,
|
||||
updatedAt: now,
|
||||
})
|
||||
} else {
|
||||
// Create new definition
|
||||
let finalTagSlot = tagSlot
|
||||
|
||||
// If no slot provided or slot is taken, find next available
|
||||
if (!finalTagSlot || existingBySlot.has(finalTagSlot)) {
|
||||
const nextSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)
|
||||
if (!nextSlot) {
|
||||
errors.push(`No available slots for field type "${fieldType}"`)
|
||||
continue
|
||||
}
|
||||
finalTagSlot = nextSlot
|
||||
}
|
||||
|
||||
// Check slot conflicts
|
||||
if (existingBySlot.has(finalTagSlot)) {
|
||||
errors.push(`Tag slot "${finalTagSlot}" is already in use`)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check display name conflicts
|
||||
if (existingByDisplayName.has(displayName)) {
|
||||
errors.push(`Display name "${displayName}" already exists`)
|
||||
continue
|
||||
}
|
||||
|
||||
const id = randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
const newDefinition = {
|
||||
id,
|
||||
knowledgeBaseId,
|
||||
tagSlot: finalTagSlot as (typeof TAG_SLOT_CONFIG.text.slots)[number],
|
||||
displayName,
|
||||
fieldType,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
await db.insert(knowledgeBaseTagDefinitions).values(newDefinition)
|
||||
|
||||
// Add to maps to track for subsequent definitions in this batch
|
||||
existingBySlot.set(finalTagSlot, newDefinition)
|
||||
existingByDisplayName.set(displayName, newDefinition)
|
||||
|
||||
created.push(newDefinition as DocumentTagDefinition)
|
||||
}
|
||||
} catch (error) {
|
||||
errors.push(`Error processing definition "${defData.displayName}": ${error}`)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Bulk tag definitions processed: ${created.length} created, ${updated.length} updated, ${errors.length} errors`
|
||||
)
|
||||
|
||||
return { created, updated, errors }
|
||||
}
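// Hypothetical usage sketch for the bulk helper above. An entry with
// `originalDisplayName` is treated as a rename/update; entries without it are
// created and auto-assigned a free slot when `tagSlot` is omitted or taken.
// The slot names 'tag1'/'tag2' and the field type 'text' are placeholders;
// real values come from TAG_SLOT_CONFIG and SUPPORTED_FIELD_TYPES.
async function exampleBulkUpsert(knowledgeBaseId: string, requestId: string) {
  const result = await createOrUpdateTagDefinitionsBulk(
    knowledgeBaseId,
    {
      definitions: [
        { tagSlot: 'tag1', displayName: 'Department', fieldType: 'text' },
        { tagSlot: 'tag2', displayName: 'Region', fieldType: 'text', originalDisplayName: 'Location' },
      ],
    },
    requestId
  )
  return result // { created, updated, errors }
}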
|
||||
|
||||
/**
|
||||
* Get a single tag definition by ID
|
||||
*/
|
||||
export async function getTagDefinitionById(
|
||||
tagDefinitionId: string
|
||||
): Promise<DocumentTagDefinition | null> {
|
||||
const result = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
knowledgeBaseId: knowledgeBaseTagDefinitions.knowledgeBaseId,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
fieldType: knowledgeBaseTagDefinitions.fieldType,
|
||||
createdAt: knowledgeBaseTagDefinitions.createdAt,
|
||||
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, tagDefinitionId))
|
||||
.limit(1)
|
||||
|
||||
if (result.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
const def = result[0]
|
||||
return {
|
||||
...def,
|
||||
tagSlot: def.tagSlot as string,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update tags on all documents and chunks when a tag value is changed
|
||||
*/
|
||||
export async function updateTagValuesInDocumentsAndChunks(
|
||||
knowledgeBaseId: string,
|
||||
tagSlot: string,
|
||||
oldValue: string | null,
|
||||
newValue: string | null,
|
||||
requestId: string
|
||||
): Promise<{ documentsUpdated: number; chunksUpdated: number }> {
|
||||
validateTagSlot(tagSlot)
|
||||
|
||||
let documentsUpdated = 0
|
||||
let chunksUpdated = 0
|
||||
|
||||
await db.transaction(async (tx) => {
|
||||
if (oldValue) {
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
[tagSlot]: newValue,
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(sql.raw(`${document}.${tagSlot}`), oldValue)
|
||||
)
|
||||
)
|
||||
documentsUpdated = 1
|
||||
}
|
||||
|
||||
if (oldValue) {
|
||||
await tx
|
||||
.update(embedding)
|
||||
.set({
|
||||
[tagSlot]: newValue,
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(embedding.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(sql.raw(`${embedding}.${tagSlot}`), oldValue)
|
||||
)
|
||||
)
|
||||
chunksUpdated = 1
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Updated tag values: ${documentsUpdated} documents, ${chunksUpdated} chunks`
|
||||
)
|
||||
|
||||
return { documentsUpdated, chunksUpdated }
|
||||
}
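// Hypothetical usage sketch: renaming a tag value across a knowledge base,
// e.g. after a user edits "Q1-2024" to "Q1 2024" in the UI. The slot name
// 'tag1' is a placeholder for one of the configured tag slots.
async function exampleRenameTagValue(knowledgeBaseId: string, requestId: string) {
  return updateTagValuesInDocumentsAndChunks(knowledgeBaseId, 'tag1', 'Q1-2024', 'Q1 2024', requestId)
}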
|
||||
|
||||
/**
|
||||
* Cleanup unused tag definitions for a knowledge base
|
||||
*/
|
||||
export async function cleanupUnusedTagDefinitions(
|
||||
knowledgeBaseId: string,
|
||||
requestId: string
|
||||
): Promise<number> {
|
||||
const definitions = await getDocumentTagDefinitions(knowledgeBaseId)
|
||||
let cleanedUp = 0
|
||||
|
||||
for (const def of definitions) {
|
||||
const tagSlot = def.tagSlot
|
||||
validateTagSlot(tagSlot)
|
||||
|
||||
const docCountResult = await db
|
||||
.select({ count: sql<number>`count(*)` })
|
||||
.from(document)
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
isNull(document.deletedAt),
|
||||
sql`${sql.raw(tagSlot)} IS NOT NULL`
|
||||
)
|
||||
)
|
||||
|
||||
const chunkCountResult = await db
|
||||
.select({ count: sql<number>`count(*)` })
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(eq(embedding.knowledgeBaseId, knowledgeBaseId), sql`${sql.raw(tagSlot)} IS NOT NULL`)
|
||||
)
|
||||
|
||||
const docCount = Number(docCountResult[0]?.count || 0)
|
||||
const chunkCount = Number(chunkCountResult[0]?.count || 0)
|
||||
|
||||
if (docCount === 0 && chunkCount === 0) {
|
||||
await db.delete(knowledgeBaseTagDefinitions).where(eq(knowledgeBaseTagDefinitions.id, def.id))
|
||||
|
||||
cleanedUp++
|
||||
logger.info(
|
||||
`[${requestId}] Cleaned up unused tag definition: ${def.displayName} (${def.tagSlot})`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Cleanup completed: ${cleanedUp} unused tag definitions removed`)
|
||||
return cleanedUp
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all tag definitions for a knowledge base
|
||||
*/
|
||||
export async function deleteAllTagDefinitions(
|
||||
knowledgeBaseId: string,
|
||||
requestId: string
|
||||
): Promise<number> {
|
||||
const result = await db
|
||||
.delete(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
.returning({ id: knowledgeBaseTagDefinitions.id })
|
||||
|
||||
const deletedCount = result.length
|
||||
logger.info(`[${requestId}] Deleted ${deletedCount} tag definitions for KB: ${knowledgeBaseId}`)
|
||||
|
||||
return deletedCount
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a tag definition with comprehensive cleanup
|
||||
* This removes the definition and clears all document/chunk references
|
||||
*/
|
||||
export async function deleteTagDefinition(
|
||||
tagDefinitionId: string,
|
||||
requestId: string
|
||||
): Promise<{ tagSlot: string; displayName: string }> {
|
||||
const tagDef = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
knowledgeBaseId: knowledgeBaseTagDefinitions.knowledgeBaseId,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, tagDefinitionId))
|
||||
.limit(1)
|
||||
|
||||
if (tagDef.length === 0) {
|
||||
throw new Error(`Tag definition ${tagDefinitionId} not found`)
|
||||
}
|
||||
|
||||
const definition = tagDef[0]
|
||||
const knowledgeBaseId = definition.knowledgeBaseId
|
||||
const tagSlot = definition.tagSlot as string
|
||||
|
||||
validateTagSlot(tagSlot)
|
||||
|
||||
await db.transaction(async (tx) => {
|
||||
await tx
|
||||
.update(document)
|
||||
.set({ [tagSlot]: null })
|
||||
.where(
|
||||
and(eq(document.knowledgeBaseId, knowledgeBaseId), isNotNull(sql`${sql.raw(tagSlot)}`))
|
||||
)
|
||||
|
||||
await tx
|
||||
.update(embedding)
|
||||
.set({ [tagSlot]: null })
|
||||
.where(
|
||||
and(eq(embedding.knowledgeBaseId, knowledgeBaseId), isNotNull(sql`${sql.raw(tagSlot)}`))
|
||||
)
|
||||
|
||||
await tx
|
||||
.delete(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, tagDefinitionId))
|
||||
})
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Deleted tag definition with cleanup: ${definition.displayName} (${tagSlot})`
|
||||
)
|
||||
|
||||
return {
|
||||
tagSlot,
|
||||
displayName: definition.displayName,
|
||||
}
|
||||
}
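// Hypothetical usage sketch: deleting a definition and reporting what was
// cleared. deleteTagDefinition both removes the definition row and nulls the
// corresponding slot on documents and chunks inside a single transaction.
async function exampleDeleteTag(tagDefinitionId: string, requestId: string) {
  const { tagSlot, displayName } = await deleteTagDefinition(tagDefinitionId, requestId)
  return `Removed "${displayName}" and cleared ${tagSlot} on all documents and chunks`
}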
|
||||
|
||||
/**
|
||||
* Create a new tag definition
|
||||
*/
|
||||
export async function createTagDefinition(
|
||||
data: CreateTagDefinitionData,
|
||||
requestId: string
|
||||
): Promise<TagDefinition> {
|
||||
const tagDefinitionId = randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
const newDefinition = {
|
||||
id: tagDefinitionId,
|
||||
knowledgeBaseId: data.knowledgeBaseId,
|
||||
tagSlot: data.tagSlot as (typeof TAG_SLOT_CONFIG.text.slots)[number],
|
||||
displayName: data.displayName,
|
||||
fieldType: data.fieldType,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
await db.insert(knowledgeBaseTagDefinitions).values(newDefinition)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Created tag definition: ${data.displayName} -> ${data.tagSlot} in KB ${data.knowledgeBaseId}`
|
||||
)
|
||||
|
||||
return {
|
||||
id: tagDefinitionId,
|
||||
tagSlot: data.tagSlot,
|
||||
displayName: data.displayName,
|
||||
fieldType: data.fieldType,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update an existing tag definition
|
||||
*/
|
||||
export async function updateTagDefinition(
|
||||
tagDefinitionId: string,
|
||||
data: UpdateTagDefinitionData,
|
||||
requestId: string
|
||||
): Promise<TagDefinition> {
|
||||
const now = new Date()
|
||||
|
||||
const updateData: {
|
||||
updatedAt: Date
|
||||
displayName?: string
|
||||
fieldType?: string
|
||||
} = {
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
if (data.displayName !== undefined) {
|
||||
updateData.displayName = data.displayName
|
||||
}
|
||||
|
||||
if (data.fieldType !== undefined) {
|
||||
updateData.fieldType = data.fieldType
|
||||
}
|
||||
|
||||
const updatedRows = await db
|
||||
.update(knowledgeBaseTagDefinitions)
|
||||
.set(updateData)
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, tagDefinitionId))
|
||||
.returning({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
fieldType: knowledgeBaseTagDefinitions.fieldType,
|
||||
createdAt: knowledgeBaseTagDefinitions.createdAt,
|
||||
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
|
||||
})
|
||||
|
||||
if (updatedRows.length === 0) {
|
||||
throw new Error(`Tag definition ${tagDefinitionId} not found`)
|
||||
}
|
||||
|
||||
const updated = updatedRows[0]
|
||||
logger.info(`[${requestId}] Updated tag definition: ${tagDefinitionId}`)
|
||||
|
||||
return {
|
||||
...updated,
|
||||
tagSlot: updated.tagSlot as string,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get tag usage with detailed document information (original format)
|
||||
*/
|
||||
export async function getTagUsage(
|
||||
knowledgeBaseId: string,
|
||||
requestId = 'api'
|
||||
): Promise<
|
||||
Array<{
|
||||
tagName: string
|
||||
tagSlot: string
|
||||
documentCount: number
|
||||
documents: Array<{ id: string; name: string; tagValue: string }>
|
||||
}>
|
||||
> {
|
||||
const definitions = await getDocumentTagDefinitions(knowledgeBaseId)
|
||||
const usage = []
|
||||
|
||||
for (const def of definitions) {
|
||||
const tagSlot = def.tagSlot
|
||||
validateTagSlot(tagSlot)
|
||||
|
||||
const documentsWithTag = await db
|
||||
.select({
|
||||
id: document.id,
|
||||
filename: document.filename,
|
||||
tagValue: sql<string>`${sql.raw(tagSlot)}`,
|
||||
})
|
||||
.from(document)
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
isNull(document.deletedAt),
|
||||
isNotNull(sql`${sql.raw(tagSlot)}`)
|
||||
)
|
||||
)
|
||||
|
||||
usage.push({
|
||||
tagName: def.displayName,
|
||||
tagSlot: def.tagSlot,
|
||||
documentCount: documentsWithTag.length,
|
||||
documents: documentsWithTag.map((doc) => ({
|
||||
id: doc.id,
|
||||
name: doc.filename,
|
||||
tagValue: doc.tagValue || '',
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Retrieved detailed tag usage for ${usage.length} definitions`)
|
||||
|
||||
return usage
|
||||
}
|
||||
|
||||
/**
|
||||
* Get tag usage statistics
|
||||
*/
|
||||
export async function getTagUsageStats(
|
||||
knowledgeBaseId: string,
|
||||
requestId: string
|
||||
): Promise<
|
||||
Array<{
|
||||
tagSlot: string
|
||||
displayName: string
|
||||
fieldType: string
|
||||
documentCount: number
|
||||
chunkCount: number
|
||||
}>
|
||||
> {
|
||||
const definitions = await getDocumentTagDefinitions(knowledgeBaseId)
|
||||
const stats = []
|
||||
|
||||
for (const def of definitions) {
|
||||
const tagSlot = def.tagSlot
|
||||
validateTagSlot(tagSlot)
|
||||
|
||||
const docCountResult = await db
|
||||
.select({ count: sql<number>`count(*)` })
|
||||
.from(document)
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
isNull(document.deletedAt),
|
||||
sql`${sql.raw(tagSlot)} IS NOT NULL`
|
||||
)
|
||||
)
|
||||
|
||||
const chunkCountResult = await db
|
||||
.select({ count: sql<number>`count(*)` })
|
||||
.from(embedding)
|
||||
.where(
|
||||
and(eq(embedding.knowledgeBaseId, knowledgeBaseId), sql`${sql.raw(tagSlot)} IS NOT NULL`)
|
||||
)
|
||||
|
||||
stats.push({
|
||||
tagSlot: def.tagSlot,
|
||||
displayName: def.displayName,
|
||||
fieldType: def.fieldType,
|
||||
documentCount: Number(docCountResult[0]?.count || 0),
|
||||
chunkCount: Number(chunkCountResult[0]?.count || 0),
|
||||
})
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Retrieved tag usage stats for ${stats.length} definitions`)
|
||||
|
||||
return stats
|
||||
}
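// Hypothetical usage sketch: combining the stats helper above to list
// definitions that are no longer referenced anywhere, mirroring what
// cleanupUnusedTagDefinitions does internally before deleting them.
async function exampleFindUnusedTags(knowledgeBaseId: string, requestId: string) {
  const stats = await getTagUsageStats(knowledgeBaseId, requestId)
  return stats.filter((s) => s.documentCount === 0 && s.chunkCount === 0).map((s) => s.displayName)
}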
|
||||
apps/sim/lib/knowledge/tags/types.ts (new file, 20 lines)
@@ -0,0 +1,20 @@
export interface DocumentTagDefinition {
  id: string
  knowledgeBaseId: string
  tagSlot: string
  displayName: string
  fieldType: string
  createdAt: Date
  updatedAt: Date
}

export interface CreateTagDefinitionData {
  tagSlot: string
  displayName: string
  fieldType: string
  originalDisplayName?: string
}

export interface BulkTagDefinitionsData {
  definitions: CreateTagDefinitionData[]
}
apps/sim/lib/knowledge/types.ts (new file, 50 lines)
@@ -0,0 +1,50 @@
export interface ChunkingConfig {
  maxSize: number
  minSize: number
  overlap: number
}

export interface KnowledgeBaseWithCounts {
  id: string
  name: string
  description: string | null
  tokenCount: number
  embeddingModel: string
  embeddingDimension: number
  chunkingConfig: ChunkingConfig
  createdAt: Date
  updatedAt: Date
  workspaceId: string | null
  docCount: number
}

export interface CreateKnowledgeBaseData {
  name: string
  description?: string
  workspaceId?: string
  embeddingModel: 'text-embedding-3-small'
  embeddingDimension: 1536
  chunkingConfig: ChunkingConfig
  userId: string
}

export interface TagDefinition {
  id: string
  tagSlot: string
  displayName: string
  fieldType: string
  createdAt: Date
  updatedAt: Date
}

export interface CreateTagDefinitionData {
  knowledgeBaseId: string
  tagSlot: string
  displayName: string
  fieldType: string
}

export interface UpdateTagDefinitionData {
  displayName?: string
  fieldType?: string
}
@@ -4,8 +4,8 @@ import { createLogger } from '@/lib/logs/console/logger'

 const logger = createLogger('Redis')

-// Default to localhost if REDIS_URL is not provided
-const redisUrl = env.REDIS_URL || 'redis://localhost:6379'
+// Only use Redis if explicitly configured
+const redisUrl = env.REDIS_URL

 // Global Redis client for connection pooling
 // This is important for serverless environments like Vercel
@@ -24,6 +24,11 @@ export function getRedisClient(): Redis | null {
   // For server-side only
   if (typeof window !== 'undefined') return null

+  // Return null immediately if no Redis URL is configured
+  if (!redisUrl) {
+    return null
+  }
+
   if (globalRedisClient) return globalRedisClient

   try {
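// Hedged sketch of the caller-side implication of this change: getRedisClient()
// can now return null when REDIS_URL is unset, so queueing code needs an
// explicit fallback path. The queue key and the in-memory fallback below are
// illustrative assumptions, not part of the change itself.
const memoryQueue: string[] = []

async function enqueueDocument(docId: string): Promise<void> {
  const redis = getRedisClient()
  if (redis) {
    await redis.rpush('kb:doc-queue', docId) // 'kb:doc-queue' is a placeholder key name
  } else {
    memoryQueue.push(docId) // no Redis configured: fall back to a process-local queue
  }
}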
|
||||
|
||||
@@ -121,7 +121,10 @@ export function validateTokenizationInput(
     throw createTokenizationError(
       'MISSING_TEXT',
       'Either input text or output text must be provided',
-      { inputLength: inputText?.length || 0, outputLength: outputText?.length || 0 }
+      {
+        inputLength: inputText?.length || 0,
+        outputLength: outputText?.length || 0,
+      }
     )
   }
 }
@@ -1,11 +1,15 @@
|
||||
import {
|
||||
BlobSASPermissions,
|
||||
BlobServiceClient,
|
||||
type BlockBlobClient,
|
||||
generateBlobSASQueryParameters,
|
||||
StorageSharedKeyCredential,
|
||||
} from '@azure/storage-blob'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { BLOB_CONFIG } from '@/lib/uploads/setup'
|
||||
|
||||
const logger = createLogger('BlobClient')
|
||||
|
||||
// Lazily create a single Blob service client instance.
|
||||
let _blobServiceClient: BlobServiceClient | null = null
|
||||
|
||||
@@ -133,8 +137,6 @@ export async function uploadToBlob(
|
||||
fileSize = configOrSize ?? file.length
|
||||
}
|
||||
|
||||
// Create a unique filename with timestamp to prevent collisions
|
||||
// Use a simple timestamp without directory structure
|
||||
const safeFileName = fileName.replace(/\s+/g, '-') // Replace spaces with hyphens
|
||||
const uniqueKey = `${Date.now()}-${safeFileName}`
|
||||
|
||||
@@ -142,7 +144,6 @@ export async function uploadToBlob(
|
||||
const containerClient = blobServiceClient.getContainerClient(config.containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(uniqueKey)
|
||||
|
||||
// Upload the file to Azure Blob Storage
|
||||
await blockBlobClient.upload(file, fileSize, {
|
||||
blobHTTPHeaders: {
|
||||
blobContentType: contentType,
|
||||
@@ -153,7 +154,6 @@ export async function uploadToBlob(
|
||||
},
|
||||
})
|
||||
|
||||
// Create a path for API to serve the file
|
||||
const servePath = `/api/files/serve/blob/${encodeURIComponent(uniqueKey)}`
|
||||
|
||||
return {
|
||||
@@ -176,7 +176,6 @@ export async function getPresignedUrl(key: string, expiresIn = 3600) {
|
||||
const containerClient = blobServiceClient.getContainerClient(BLOB_CONFIG.containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(key)
|
||||
|
||||
// Generate SAS token for the blob
|
||||
const sasOptions = {
|
||||
containerName: BLOB_CONFIG.containerName,
|
||||
blobName: key,
|
||||
@@ -211,7 +210,6 @@ export async function getPresignedUrlWithConfig(
|
||||
customConfig: CustomBlobConfig,
|
||||
expiresIn = 3600
|
||||
) {
|
||||
// Create a temporary client for the custom config
|
||||
let tempBlobServiceClient: BlobServiceClient
|
||||
|
||||
if (customConfig.connectionString) {
|
||||
@@ -234,7 +232,6 @@ export async function getPresignedUrlWithConfig(
|
||||
const containerClient = tempBlobServiceClient.getContainerClient(customConfig.containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(key)
|
||||
|
||||
// Generate SAS token for the blob
|
||||
const sasOptions = {
|
||||
containerName: customConfig.containerName,
|
||||
blobName: key,
|
||||
@@ -280,7 +277,6 @@ export async function downloadFromBlob(
|
||||
let containerName: string
|
||||
|
||||
if (customConfig) {
|
||||
// Use custom configuration
|
||||
if (customConfig.connectionString) {
|
||||
blobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
|
||||
} else if (customConfig.accountName && customConfig.accountKey) {
|
||||
@@ -297,7 +293,6 @@ export async function downloadFromBlob(
|
||||
}
|
||||
containerName = customConfig.containerName
|
||||
} else {
|
||||
// Use default configuration
|
||||
blobServiceClient = getBlobServiceClient()
|
||||
containerName = BLOB_CONFIG.containerName
|
||||
}
|
||||
@@ -332,7 +327,6 @@ export async function deleteFromBlob(key: string, customConfig?: CustomBlobConfi
|
||||
let containerName: string
|
||||
|
||||
if (customConfig) {
|
||||
// Use custom configuration
|
||||
if (customConfig.connectionString) {
|
||||
blobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
|
||||
} else if (customConfig.accountName && customConfig.accountKey) {
|
||||
@@ -349,7 +343,6 @@ export async function deleteFromBlob(key: string, customConfig?: CustomBlobConfi
|
||||
}
|
||||
containerName = customConfig.containerName
|
||||
} else {
|
||||
// Use default configuration
|
||||
blobServiceClient = getBlobServiceClient()
|
||||
containerName = BLOB_CONFIG.containerName
|
||||
}
|
||||
@@ -375,3 +368,273 @@ async function streamToBuffer(readableStream: NodeJS.ReadableStream): Promise<Bu
|
||||
readableStream.on('error', reject)
|
||||
})
|
||||
}
|
||||
|
||||
// Multipart upload interfaces
|
||||
export interface AzureMultipartUploadInit {
|
||||
fileName: string
|
||||
contentType: string
|
||||
fileSize: number
|
||||
customConfig?: CustomBlobConfig
|
||||
}
|
||||
|
||||
export interface AzureMultipartUploadResult {
|
||||
uploadId: string
|
||||
key: string
|
||||
blockBlobClient: BlockBlobClient
|
||||
}
|
||||
|
||||
export interface AzurePartUploadUrl {
|
||||
partNumber: number
|
||||
blockId: string
|
||||
url: string
|
||||
}
|
||||
|
||||
export interface AzureMultipartPart {
|
||||
blockId: string
|
||||
partNumber: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiate a multipart upload for Azure Blob Storage
|
||||
*/
|
||||
export async function initiateMultipartUpload(
|
||||
options: AzureMultipartUploadInit
|
||||
): Promise<{ uploadId: string; key: string }> {
|
||||
const { fileName, contentType, customConfig } = options
|
||||
|
||||
let blobServiceClient: BlobServiceClient
|
||||
let containerName: string
|
||||
|
||||
if (customConfig) {
|
||||
if (customConfig.connectionString) {
|
||||
blobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
|
||||
} else if (customConfig.accountName && customConfig.accountKey) {
|
||||
const credential = new StorageSharedKeyCredential(
|
||||
customConfig.accountName,
|
||||
customConfig.accountKey
|
||||
)
|
||||
blobServiceClient = new BlobServiceClient(
|
||||
`https://${customConfig.accountName}.blob.core.windows.net`,
|
||||
credential
|
||||
)
|
||||
} else {
|
||||
throw new Error('Invalid custom blob configuration')
|
||||
}
|
||||
containerName = customConfig.containerName
|
||||
} else {
|
||||
blobServiceClient = getBlobServiceClient()
|
||||
containerName = BLOB_CONFIG.containerName
|
||||
}
|
||||
|
||||
// Create unique key for the blob
|
||||
const safeFileName = fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
|
||||
const { v4: uuidv4 } = await import('uuid')
|
||||
const uniqueKey = `kb/${uuidv4()}-${safeFileName}`
|
||||
|
||||
// Generate a unique upload ID (Azure doesn't have native multipart like S3)
|
||||
const uploadId = uuidv4()
|
||||
|
||||
// Store the blob client reference for later use (in a real implementation, you'd use Redis or similar)
|
||||
const containerClient = blobServiceClient.getContainerClient(containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(uniqueKey)
|
||||
|
||||
// Set metadata to track the multipart upload
|
||||
await blockBlobClient.setMetadata({
|
||||
uploadId,
|
||||
fileName: encodeURIComponent(fileName),
|
||||
contentType,
|
||||
uploadStarted: new Date().toISOString(),
|
||||
multipartUpload: 'true',
|
||||
})
|
||||
|
||||
return {
|
||||
uploadId,
|
||||
key: uniqueKey,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate presigned URLs for uploading parts
|
||||
*/
|
||||
export async function getMultipartPartUrls(
|
||||
key: string,
|
||||
_uploadId: string, // Not used in Azure Blob, kept for interface consistency
|
||||
partNumbers: number[],
|
||||
customConfig?: CustomBlobConfig
|
||||
): Promise<AzurePartUploadUrl[]> {
|
||||
let blobServiceClient: BlobServiceClient
|
||||
let containerName: string
|
||||
let accountName: string
|
||||
let accountKey: string
|
||||
|
||||
if (customConfig) {
|
||||
if (customConfig.connectionString) {
|
||||
blobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
|
||||
// Extract account name from connection string
|
||||
const match = customConfig.connectionString.match(/AccountName=([^;]+)/)
|
||||
if (!match) throw new Error('Cannot extract account name from connection string')
|
||||
accountName = match[1]
|
||||
|
||||
const keyMatch = customConfig.connectionString.match(/AccountKey=([^;]+)/)
|
||||
if (!keyMatch) throw new Error('Cannot extract account key from connection string')
|
||||
accountKey = keyMatch[1]
|
||||
} else if (customConfig.accountName && customConfig.accountKey) {
|
||||
const credential = new StorageSharedKeyCredential(
|
||||
customConfig.accountName,
|
||||
customConfig.accountKey
|
||||
)
|
||||
blobServiceClient = new BlobServiceClient(
|
||||
`https://${customConfig.accountName}.blob.core.windows.net`,
|
||||
credential
|
||||
)
|
||||
accountName = customConfig.accountName
|
||||
accountKey = customConfig.accountKey
|
||||
} else {
|
||||
throw new Error('Invalid custom blob configuration')
|
||||
}
|
||||
containerName = customConfig.containerName
|
||||
} else {
|
||||
blobServiceClient = getBlobServiceClient()
|
||||
containerName = BLOB_CONFIG.containerName
|
||||
accountName = BLOB_CONFIG.accountName
|
||||
accountKey =
|
||||
BLOB_CONFIG.accountKey ||
|
||||
(() => {
|
||||
throw new Error('AZURE_ACCOUNT_KEY is required')
|
||||
})()
|
||||
}
|
||||
|
||||
const containerClient = blobServiceClient.getContainerClient(containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(key)
|
||||
|
||||
return partNumbers.map((partNumber) => {
|
||||
// Azure uses block IDs instead of part numbers
|
||||
// Block IDs must be base64 encoded and all the same length
|
||||
const blockId = Buffer.from(`block-${partNumber.toString().padStart(6, '0')}`).toString(
|
||||
'base64'
|
||||
)
|
||||
|
||||
// Generate SAS token for uploading this specific block
|
||||
const sasOptions = {
|
||||
containerName,
|
||||
blobName: key,
|
||||
permissions: BlobSASPermissions.parse('w'), // Write permission
|
||||
startsOn: new Date(),
|
||||
expiresOn: new Date(Date.now() + 3600 * 1000), // 1 hour
|
||||
}
|
||||
|
||||
const sasToken = generateBlobSASQueryParameters(
|
||||
sasOptions,
|
||||
new StorageSharedKeyCredential(accountName, accountKey)
|
||||
).toString()
|
||||
|
||||
return {
|
||||
partNumber,
|
||||
blockId,
|
||||
url: `${blockBlobClient.url}?comp=block&blockid=${encodeURIComponent(blockId)}&${sasToken}`,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Complete multipart upload by committing all blocks
|
||||
*/
|
||||
export async function completeMultipartUpload(
|
||||
key: string,
|
||||
_uploadId: string, // Not used in Azure Blob, kept for interface consistency
|
||||
parts: Array<{ blockId: string; partNumber: number }>,
|
||||
customConfig?: CustomBlobConfig
|
||||
): Promise<{ location: string; path: string; key: string }> {
|
||||
let blobServiceClient: BlobServiceClient
|
||||
let containerName: string
|
||||
|
||||
if (customConfig) {
|
||||
if (customConfig.connectionString) {
|
||||
blobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
|
||||
} else if (customConfig.accountName && customConfig.accountKey) {
|
||||
const credential = new StorageSharedKeyCredential(
|
||||
customConfig.accountName,
|
||||
customConfig.accountKey
|
||||
)
|
||||
blobServiceClient = new BlobServiceClient(
|
||||
`https://${customConfig.accountName}.blob.core.windows.net`,
|
||||
credential
|
||||
)
|
||||
} else {
|
||||
throw new Error('Invalid custom blob configuration')
|
||||
}
|
||||
containerName = customConfig.containerName
|
||||
} else {
|
||||
blobServiceClient = getBlobServiceClient()
|
||||
containerName = BLOB_CONFIG.containerName
|
||||
}
|
||||
|
||||
const containerClient = blobServiceClient.getContainerClient(containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(key)
|
||||
|
||||
// Sort parts by part number and extract block IDs
|
||||
const sortedBlockIds = parts
|
||||
.sort((a, b) => a.partNumber - b.partNumber)
|
||||
.map((part) => part.blockId)
|
||||
|
||||
// Commit the block list to create the final blob
|
||||
await blockBlobClient.commitBlockList(sortedBlockIds, {
|
||||
metadata: {
|
||||
multipartUpload: 'completed',
|
||||
uploadCompletedAt: new Date().toISOString(),
|
||||
},
|
||||
})
|
||||
|
||||
const location = blockBlobClient.url
|
||||
const path = `/api/files/serve/blob/${encodeURIComponent(key)}`
|
||||
|
||||
return {
|
||||
location,
|
||||
path,
|
||||
key,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Abort multipart upload by deleting the blob if it exists
|
||||
*/
|
||||
export async function abortMultipartUpload(
|
||||
key: string,
|
||||
_uploadId: string, // Not used in Azure Blob, kept for interface consistency
|
||||
customConfig?: CustomBlobConfig
|
||||
): Promise<void> {
|
||||
let blobServiceClient: BlobServiceClient
|
||||
let containerName: string
|
||||
|
||||
if (customConfig) {
|
||||
if (customConfig.connectionString) {
|
||||
blobServiceClient = BlobServiceClient.fromConnectionString(customConfig.connectionString)
|
||||
} else if (customConfig.accountName && customConfig.accountKey) {
|
||||
const credential = new StorageSharedKeyCredential(
|
||||
customConfig.accountName,
|
||||
customConfig.accountKey
|
||||
)
|
||||
blobServiceClient = new BlobServiceClient(
|
||||
`https://${customConfig.accountName}.blob.core.windows.net`,
|
||||
credential
|
||||
)
|
||||
} else {
|
||||
throw new Error('Invalid custom blob configuration')
|
||||
}
|
||||
containerName = customConfig.containerName
|
||||
} else {
|
||||
blobServiceClient = getBlobServiceClient()
|
||||
containerName = BLOB_CONFIG.containerName
|
||||
}
|
||||
|
||||
const containerClient = blobServiceClient.getContainerClient(containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(key)
|
||||
|
||||
try {
|
||||
// Delete the blob if it exists (this also cleans up any uncommitted blocks)
|
||||
await blockBlobClient.deleteIfExists()
|
||||
} catch (error) {
|
||||
// Ignore errors since we're just cleaning up
|
||||
logger.warn('Error cleaning up multipart upload:', error)
|
||||
}
|
||||
}
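// Hypothetical end-to-end sketch of the Azure multipart helpers above: the
// caller initiates the upload, PUTs each part to its presigned block URL
// (shown here with fetch), and then commits the block list. Part sizing and
// error handling are simplified for illustration.
async function exampleAzureMultipartUpload(fileName: string, contentType: string, parts: Buffer[]) {
  const { uploadId, key } = await initiateMultipartUpload({
    fileName,
    contentType,
    fileSize: parts.reduce((sum, p) => sum + p.length, 0),
  })

  const partNumbers = parts.map((_, i) => i + 1)
  const urls = await getMultipartPartUrls(key, uploadId, partNumbers)

  try {
    for (const { partNumber, url } of urls) {
      // Each URL already carries comp=block, the block ID and a write-only SAS token
      await fetch(url, { method: 'PUT', body: parts[partNumber - 1] })
    }
    return await completeMultipartUpload(
      key,
      uploadId,
      urls.map(({ blockId, partNumber }) => ({ blockId, partNumber }))
    )
  } catch (error) {
    await abortMultipartUpload(key, uploadId)
    throw error
  }
}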
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import {
|
||||
AbortMultipartUploadCommand,
|
||||
CompleteMultipartUploadCommand,
|
||||
CreateMultipartUploadCommand,
|
||||
DeleteObjectCommand,
|
||||
GetObjectCommand,
|
||||
PutObjectCommand,
|
||||
S3Client,
|
||||
UploadPartCommand,
|
||||
} from '@aws-sdk/client-s3'
|
||||
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
|
||||
import { env } from '@/lib/env'
|
||||
import { S3_CONFIG } from '@/lib/uploads/setup'
|
||||
import { S3_CONFIG, S3_KB_CONFIG } from '@/lib/uploads/setup'
|
||||
|
||||
// Lazily create a single S3 client instance.
|
||||
let _s3Client: S3Client | null = null
|
||||
@@ -287,3 +291,142 @@ export async function deleteFromS3(key: string, customConfig?: CustomS3Config):
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
// Multipart upload interfaces
|
||||
export interface S3MultipartUploadInit {
|
||||
fileName: string
|
||||
contentType: string
|
||||
fileSize: number
|
||||
customConfig?: CustomS3Config
|
||||
}
|
||||
|
||||
export interface S3PartUploadUrl {
|
||||
partNumber: number
|
||||
url: string
|
||||
}
|
||||
|
||||
export interface S3MultipartPart {
|
||||
ETag: string
|
||||
PartNumber: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiate a multipart upload for S3
|
||||
*/
|
||||
export async function initiateS3MultipartUpload(
|
||||
options: S3MultipartUploadInit
|
||||
): Promise<{ uploadId: string; key: string }> {
|
||||
const { fileName, contentType, customConfig } = options
|
||||
|
||||
const config = customConfig || { bucket: S3_KB_CONFIG.bucket, region: S3_KB_CONFIG.region }
|
||||
const s3Client = getS3Client()
|
||||
|
||||
// Create unique key for the object
|
||||
const safeFileName = fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
|
||||
const { v4: uuidv4 } = await import('uuid')
|
||||
const uniqueKey = `kb/${uuidv4()}-${safeFileName}`
|
||||
|
||||
const command = new CreateMultipartUploadCommand({
|
||||
Bucket: config.bucket,
|
||||
Key: uniqueKey,
|
||||
ContentType: contentType,
|
||||
Metadata: {
|
||||
originalName: sanitizeFilenameForMetadata(fileName),
|
||||
uploadedAt: new Date().toISOString(),
|
||||
purpose: 'knowledge-base',
|
||||
},
|
||||
})
|
||||
|
||||
const response = await s3Client.send(command)
|
||||
|
||||
if (!response.UploadId) {
|
||||
throw new Error('Failed to initiate S3 multipart upload')
|
||||
}
|
||||
|
||||
return {
|
||||
uploadId: response.UploadId,
|
||||
key: uniqueKey,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate presigned URLs for uploading parts to S3
|
||||
*/
|
||||
export async function getS3MultipartPartUrls(
|
||||
key: string,
|
||||
uploadId: string,
|
||||
partNumbers: number[],
|
||||
customConfig?: CustomS3Config
|
||||
): Promise<S3PartUploadUrl[]> {
|
||||
const config = customConfig || { bucket: S3_KB_CONFIG.bucket, region: S3_KB_CONFIG.region }
|
||||
const s3Client = getS3Client()
|
||||
|
||||
const presignedUrls = await Promise.all(
|
||||
partNumbers.map(async (partNumber) => {
|
||||
const command = new UploadPartCommand({
|
||||
Bucket: config.bucket,
|
||||
Key: key,
|
||||
PartNumber: partNumber,
|
||||
UploadId: uploadId,
|
||||
})
|
||||
|
||||
const url = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
|
||||
return { partNumber, url }
|
||||
})
|
||||
)
|
||||
|
||||
return presignedUrls
|
||||
}
|
||||
|
||||
/**
|
||||
* Complete multipart upload for S3
|
||||
*/
|
||||
export async function completeS3MultipartUpload(
|
||||
key: string,
|
||||
uploadId: string,
|
||||
parts: S3MultipartPart[],
|
||||
customConfig?: CustomS3Config
|
||||
): Promise<{ location: string; path: string; key: string }> {
|
||||
const config = customConfig || { bucket: S3_KB_CONFIG.bucket, region: S3_KB_CONFIG.region }
|
||||
const s3Client = getS3Client()
|
||||
|
||||
const command = new CompleteMultipartUploadCommand({
|
||||
Bucket: config.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
MultipartUpload: {
|
||||
Parts: parts.sort((a, b) => a.PartNumber - b.PartNumber),
|
||||
},
|
||||
})
|
||||
|
||||
const response = await s3Client.send(command)
|
||||
const location =
|
||||
response.Location || `https://${config.bucket}.s3.${config.region}.amazonaws.com/${key}`
|
||||
const path = `/api/files/serve/s3/${encodeURIComponent(key)}`
|
||||
|
||||
return {
|
||||
location,
|
||||
path,
|
||||
key,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Abort multipart upload for S3
|
||||
*/
|
||||
export async function abortS3MultipartUpload(
|
||||
key: string,
|
||||
uploadId: string,
|
||||
customConfig?: CustomS3Config
|
||||
): Promise<void> {
|
||||
const config = customConfig || { bucket: S3_KB_CONFIG.bucket, region: S3_KB_CONFIG.region }
|
||||
const s3Client = getS3Client()
|
||||
|
||||
const command = new AbortMultipartUploadCommand({
|
||||
Bucket: config.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
})
|
||||
|
||||
await s3Client.send(command)
|
||||
}
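// Hypothetical end-to-end sketch of the S3 multipart helpers above. Unlike the
// Azure variant, each part upload returns an ETag that must be echoed back in
// the completion call; the fetch-based part upload is illustrative only.
async function exampleS3MultipartUpload(fileName: string, contentType: string, parts: Buffer[]) {
  const { uploadId, key } = await initiateS3MultipartUpload({
    fileName,
    contentType,
    fileSize: parts.reduce((sum, p) => sum + p.length, 0),
  })

  try {
    const urls = await getS3MultipartPartUrls(key, uploadId, parts.map((_, i) => i + 1))
    const completedParts: S3MultipartPart[] = []
    for (const { partNumber, url } of urls) {
      const res = await fetch(url, { method: 'PUT', body: parts[partNumber - 1] })
      completedParts.push({ ETag: res.headers.get('etag') ?? '', PartNumber: partNumber })
    }
    return await completeS3MultipartUpload(key, uploadId, completedParts)
  } catch (error) {
    await abortS3MultipartUpload(key, uploadId)
    throw error
  }
}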
|
||||
|
||||
apps/sim/lib/uploads/validation.ts (new file, 76 lines)
@@ -0,0 +1,76 @@
import path from 'path'

export const SUPPORTED_DOCUMENT_EXTENSIONS = [
  'pdf',
  'csv',
  'doc',
  'docx',
  'txt',
  'md',
  'xlsx',
  'xls',
] as const

export type SupportedDocumentExtension = (typeof SUPPORTED_DOCUMENT_EXTENSIONS)[number]

export const SUPPORTED_MIME_TYPES: Record<SupportedDocumentExtension, string[]> = {
  pdf: ['application/pdf'],
  csv: ['text/csv', 'application/csv'],
  doc: ['application/msword'],
  docx: ['application/vnd.openxmlformats-officedocument.wordprocessingml.document'],
  txt: ['text/plain'],
  md: ['text/markdown', 'text/x-markdown'],
  xlsx: ['application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'],
  xls: ['application/vnd.ms-excel'],
}

export interface FileValidationError {
  code: 'UNSUPPORTED_FILE_TYPE' | 'MIME_TYPE_MISMATCH'
  message: string
  supportedTypes: string[]
}

/**
 * Validate if a file type is supported for document processing
 */
export function validateFileType(fileName: string, mimeType: string): FileValidationError | null {
  const extension = path.extname(fileName).toLowerCase().substring(1) as SupportedDocumentExtension

  if (!SUPPORTED_DOCUMENT_EXTENSIONS.includes(extension)) {
    return {
      code: 'UNSUPPORTED_FILE_TYPE',
      message: `Unsupported file type: ${extension}. Supported types are: ${SUPPORTED_DOCUMENT_EXTENSIONS.join(', ')}`,
      supportedTypes: [...SUPPORTED_DOCUMENT_EXTENSIONS],
    }
  }

  const allowedMimeTypes = SUPPORTED_MIME_TYPES[extension]
  if (!allowedMimeTypes.includes(mimeType)) {
    return {
      code: 'MIME_TYPE_MISMATCH',
      message: `MIME type ${mimeType} does not match file extension ${extension}. Expected: ${allowedMimeTypes.join(', ')}`,
      supportedTypes: allowedMimeTypes,
    }
  }

  return null
}

/**
 * Check if file extension is supported
 */
export function isSupportedExtension(extension: string): extension is SupportedDocumentExtension {
  return SUPPORTED_DOCUMENT_EXTENSIONS.includes(
    extension.toLowerCase() as SupportedDocumentExtension
  )
}

/**
 * Get supported MIME types for an extension
 */
export function getSupportedMimeTypes(extension: string): string[] {
  if (isSupportedExtension(extension)) {
    return SUPPORTED_MIME_TYPES[extension as SupportedDocumentExtension]
  }
  return []
}
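// Hypothetical usage sketch: gating a knowledge-base upload on validateFileType
// before any bytes are sent to storage. The HTTP-style result shape is assumed
// for illustration.
function exampleCheckUpload(fileName: string, mimeType: string) {
  const validationError = validateFileType(fileName, mimeType)
  if (validationError) {
    return { status: 400, body: { code: validationError.code, message: validationError.message } }
  }
  return { status: 200, body: { accepted: true } }
}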
|
||||
@@ -125,6 +125,7 @@
     "tailwindcss-animate": "^1.0.7",
     "three": "0.177.0",
     "uuid": "^11.1.0",
+    "word-extractor": "1.0.4",
     "xlsx": "0.18.5",
     "zod": "^3.24.2"
   },
|
||||
|
||||
@@ -1,8 +1,8 @@
 #!/usr/bin/env bun

 import path from 'path'
-import { DocsChunker } from '@/lib/documents/docs-chunker'
-import type { DocChunk } from '@/lib/documents/types'
+import { DocsChunker } from '@/lib/knowledge/documents/docs-chunker'
+import type { DocChunk } from '@/lib/knowledge/documents/types'
 import { createLogger } from '@/lib/logs/console/logger'

 const logger = createLogger('ChunkDocsScript')
|
||||
|
||||
@@ -2,8 +2,8 @@

 import path from 'path'
 import { sql } from 'drizzle-orm'
-import { DocsChunker } from '@/lib/documents/docs-chunker'
 import { isDev } from '@/lib/environment'
+import { DocsChunker } from '@/lib/knowledge/documents/docs-chunker'
 import { createLogger } from '@/lib/logs/console/logger'
 import { db } from '@/db'
 import { docsEmbeddings } from '@/db/schema'
|
||||
|
||||
@@ -99,6 +99,8 @@ export interface DocumentsCache {
|
||||
documents: DocumentData[]
|
||||
pagination: DocumentsPagination
|
||||
searchQuery?: string
|
||||
sortBy?: string
|
||||
sortOrder?: string
|
||||
lastFetchTime: number
|
||||
}
|
||||
|
||||
@@ -120,7 +122,13 @@ interface KnowledgeStore {
|
||||
getKnowledgeBase: (id: string) => Promise<KnowledgeBaseData | null>
|
||||
getDocuments: (
|
||||
knowledgeBaseId: string,
|
||||
options?: { search?: string; limit?: number; offset?: number }
|
||||
options?: {
|
||||
search?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
sortBy?: string
|
||||
sortOrder?: string
|
||||
}
|
||||
) => Promise<DocumentData[]>
|
||||
getChunks: (
|
||||
knowledgeBaseId: string,
|
||||
@@ -130,7 +138,13 @@ interface KnowledgeStore {
|
||||
getKnowledgeBasesList: (workspaceId?: string) => Promise<KnowledgeBaseData[]>
|
||||
refreshDocuments: (
|
||||
knowledgeBaseId: string,
|
||||
options?: { search?: string; limit?: number; offset?: number }
|
||||
options?: {
|
||||
search?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
sortBy?: string
|
||||
sortOrder?: string
|
||||
}
|
||||
) => Promise<DocumentData[]>
|
||||
refreshChunks: (
|
||||
knowledgeBaseId: string,
|
||||
@@ -257,7 +271,13 @@ export const useKnowledgeStore = create<KnowledgeStore>((set, get) => ({
|
||||
|
||||
getDocuments: async (
|
||||
knowledgeBaseId: string,
|
||||
options?: { search?: string; limit?: number; offset?: number }
|
||||
options?: {
|
||||
search?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
sortBy?: string
|
||||
sortOrder?: string
|
||||
}
|
||||
) => {
|
||||
const state = get()
|
||||
|
||||
@@ -266,12 +286,16 @@ export const useKnowledgeStore = create<KnowledgeStore>((set, get) => ({
|
||||
const requestLimit = options?.limit || 50
|
||||
const requestOffset = options?.offset || 0
|
||||
const requestSearch = options?.search
|
||||
const requestSortBy = options?.sortBy
|
||||
const requestSortOrder = options?.sortOrder
|
||||
|
||||
if (
|
||||
cached &&
|
||||
cached.searchQuery === requestSearch &&
|
||||
cached.pagination.limit === requestLimit &&
|
||||
cached.pagination.offset === requestOffset
|
||||
cached.pagination.offset === requestOffset &&
|
||||
cached.sortBy === requestSortBy &&
|
||||
cached.sortOrder === requestSortOrder
|
||||
) {
|
||||
return cached.documents
|
||||
}
|
||||
@@ -289,6 +313,8 @@ export const useKnowledgeStore = create<KnowledgeStore>((set, get) => ({
|
||||
// Build query parameters using the same defaults as caching
|
||||
const params = new URLSearchParams()
|
||||
if (requestSearch) params.set('search', requestSearch)
|
||||
if (requestSortBy) params.set('sortBy', requestSortBy)
|
||||
if (requestSortOrder) params.set('sortOrder', requestSortOrder)
|
||||
params.set('limit', requestLimit.toString())
|
||||
params.set('offset', requestOffset.toString())
|
||||
|
||||
@@ -317,6 +343,8 @@ export const useKnowledgeStore = create<KnowledgeStore>((set, get) => ({
|
||||
documents,
|
||||
pagination,
|
||||
searchQuery: requestSearch,
|
||||
sortBy: requestSortBy,
|
||||
sortOrder: requestSortOrder,
|
||||
lastFetchTime: Date.now(),
|
||||
}
|
||||
|
||||
@@ -510,7 +538,13 @@ export const useKnowledgeStore = create<KnowledgeStore>((set, get) => ({
|
||||
|
||||
refreshDocuments: async (
|
||||
knowledgeBaseId: string,
|
||||
options?: { search?: string; limit?: number; offset?: number }
|
||||
options?: {
|
||||
search?: string
|
||||
limit?: number
|
||||
offset?: number
|
||||
sortBy?: string
|
||||
sortOrder?: string
|
||||
}
|
||||
) => {
|
||||
const state = get()
|
||||
|
||||
@@ -528,9 +562,13 @@ export const useKnowledgeStore = create<KnowledgeStore>((set, get) => ({
|
||||
const requestLimit = options?.limit || 50
|
||||
const requestOffset = options?.offset || 0
|
||||
const requestSearch = options?.search
|
||||
const requestSortBy = options?.sortBy
|
||||
const requestSortOrder = options?.sortOrder
|
||||
|
||||
const params = new URLSearchParams()
|
||||
if (requestSearch) params.set('search', requestSearch)
|
||||
if (requestSortBy) params.set('sortBy', requestSortBy)
|
||||
if (requestSortOrder) params.set('sortOrder', requestSortOrder)
|
||||
params.set('limit', requestLimit.toString())
|
||||
params.set('offset', requestOffset.toString())
|
||||
|
||||
@@ -559,6 +597,8 @@ export const useKnowledgeStore = create<KnowledgeStore>((set, get) => ({
|
||||
documents,
|
||||
pagination,
|
||||
searchQuery: requestSearch,
|
||||
sortBy: requestSortBy,
|
||||
sortOrder: requestSortOrder,
|
||||
lastFetchTime: Date.now(),
|
||||
}
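// Hypothetical usage sketch of the new sort parameters threaded through the
// store: identical options hit the cache, while a different sortBy/sortOrder
// combination forces a refetch. The column names are placeholders.
async function exampleSortedFetch(knowledgeBaseId: string) {
  const { getDocuments } = useKnowledgeStore.getState()
  return getDocuments(knowledgeBaseId, {
    limit: 50,
    offset: 0,
    sortBy: 'filename', // placeholder column name
    sortOrder: 'desc',
  })
}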
|
||||
|
||||
|
||||
bun.lock (20 changes)
@@ -15,6 +15,7 @@
     "devDependencies": {
       "@biomejs/biome": "2.0.0-beta.5",
       "@next/env": "^15.3.2",
+      "@types/word-extractor": "1.0.6",
       "dotenv-cli": "^8.0.0",
       "husky": "9.1.7",
       "lint-staged": "16.0.0",
|
||||
@@ -154,6 +155,7 @@
       "tailwindcss-animate": "^1.0.7",
       "three": "0.177.0",
       "uuid": "^11.1.0",
+      "word-extractor": "1.0.4",
       "xlsx": "0.18.5",
       "zod": "^3.24.2",
     },
|
||||
@@ -1436,6 +1438,8 @@
|
||||
|
||||
"@types/webxr": ["@types/webxr@0.5.22", "", {}, "sha512-Vr6Stjv5jPRqH690f5I5GLjVk8GSsoQSYJ2FVd/3jJF7KaqfwPi3ehfBS96mlQ2kPCwZaX6U0rG2+NGHBKkA/A=="],
|
||||
|
||||
"@types/word-extractor": ["@types/word-extractor@1.0.6", "", { "dependencies": { "@types/node": "*" } }, "sha512-NDrvZXGJi7cTKXGr8GTP08HiqiueggR1wfHZvBj1sfL8e52qecBSlvl1rBWrvOY0LLkk1DISkKVlFqMTfipLbQ=="],
|
||||
|
||||
"@types/xlsx": ["@types/xlsx@0.0.36", "", { "dependencies": { "xlsx": "*" } }, "sha512-mvfrKiKKMErQzLMF8ElYEH21qxWCZtN59pHhWGmWCWFJStYdMWjkDSAy6mGowFxHXaXZWe5/TW7pBUiWclIVOw=="],
|
||||
|
||||
"@typespec/ts-http-runtime": ["@typespec/ts-http-runtime@0.3.0", "", { "dependencies": { "http-proxy-agent": "^7.0.0", "https-proxy-agent": "^7.0.0", "tslib": "^2.6.2" } }, "sha512-sOx1PKSuFwnIl7z4RN0Ls7N9AQawmR9r66eI5rFCzLDIs8HTIYrIpH9QjYWoX0lkgGrkLxXhi4QnK7MizPRrIg=="],
|
||||
@@ -1604,6 +1608,8 @@
|
||||
|
||||
"buffer": ["buffer@5.7.1", "", { "dependencies": { "base64-js": "^1.3.1", "ieee754": "^1.1.13" } }, "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ=="],
|
||||
|
||||
"buffer-crc32": ["buffer-crc32@0.2.13", "", {}, "sha512-VO9Ht/+p3SN7SKWqcrgEzjGbRSJYTx+Q1pTQC0wrWqHx0vpJraQ6GtHx8tvcg1rlK1byhU5gccxgOgj7B0TDkQ=="],
|
||||
|
||||
"buffer-equal-constant-time": ["buffer-equal-constant-time@1.0.1", "", {}, "sha512-zRpUiDwd/xk6ADqPMATG8vc9VPrkck7T07OIx0gnjmJAnHnTVXNQG3vfvWNuiZIkwu9KrKdA1iJKfsfTVxE6NA=="],
|
||||
|
||||
"buffer-from": ["buffer-from@1.1.2", "", {}, "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ=="],
|
||||
@@ -1974,6 +1980,8 @@
|
||||
|
||||
"fastq": ["fastq@1.19.1", "", { "dependencies": { "reusify": "^1.0.4" } }, "sha512-GwLTyxkCXjXbxqIhTsMI2Nui8huMPtnxg7krajPJAjnEG/iiOS7i+zCtWGZR9G0NBKbXKh6X9m9UIsYX/N6vvQ=="],
|
||||
|
||||
"fd-slicer": ["fd-slicer@1.1.0", "", { "dependencies": { "pend": "~1.2.0" } }, "sha512-cE1qsB/VwyQozZ+q1dGxR8LBYNZeofhEdUNGSMbQD3Gw2lAzX9Zb3uIU6Ebc/Fmyjo9AWWfnn0AUCHqtevs/8g=="],
|
||||
|
||||
"fdir": ["fdir@6.5.0", "", { "peerDependencies": { "picomatch": "^3 || ^4" }, "optionalPeers": ["picomatch"] }, "sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg=="],
|
||||
|
||||
"fetch-blob": ["fetch-blob@3.2.0", "", { "dependencies": { "node-domexception": "^1.0.0", "web-streams-polyfill": "^3.0.3" } }, "sha512-7yAQpD2UMJzLi1Dqv7qFYnPbaPx7ZfFK6PiIxQ4PfkGPyNyl2Ugx+a/umUonmKqjhM4DnfbMvdX6otXq83soQQ=="],
|
||||
@@ -2600,6 +2608,8 @@
|
||||
|
||||
"peberminta": ["peberminta@0.9.0", "", {}, "sha512-XIxfHpEuSJbITd1H3EeQwpcZbTLHc+VVr8ANI9t5sit565tsI4/xK3KWTUFE2e6QiangUkh3B0jihzmGnNrRsQ=="],
|
||||
|
||||
"pend": ["pend@1.2.0", "", {}, "sha512-F3asv42UuXchdzt+xXqfW1OGlVBe+mxa2mqI0pg5yAHZPvFmY3Y6drSf/GQ1A86WgWEN9Kzh/WrgKa6iGcHXLg=="],
|
||||
|
||||
"pg": ["pg@8.16.3", "", { "dependencies": { "pg-connection-string": "^2.9.1", "pg-pool": "^3.10.1", "pg-protocol": "^1.10.3", "pg-types": "2.2.0", "pgpass": "1.0.5" }, "optionalDependencies": { "pg-cloudflare": "^1.2.7" }, "peerDependencies": { "pg-native": ">=3.0.1" }, "optionalPeers": ["pg-native"] }, "sha512-enxc1h0jA/aq5oSDMvqyW3q89ra6XIIDZgCX9vkMrnz5DFTw/Ny3Li2lFQ+pt3L6MCgm/5o2o8HW9hiJji+xvw=="],
|
||||
|
||||
"pg-cloudflare": ["pg-cloudflare@1.2.7", "", {}, "sha512-YgCtzMH0ptvZJslLM1ffsY4EuGaU0cx4XSdXLRFae8bPP4dS5xL1tNB3k2o/N64cHJpwU7dxKli/nZ2lUa5fLg=="],
|
||||
@@ -3174,6 +3184,8 @@
|
||||
|
||||
"word": ["word@0.3.0", "", {}, "sha512-OELeY0Q61OXpdUfTp+oweA/vtLVg5VDOXh+3he3PNzLGG/y0oylSOC1xRVj0+l4vQ3tj/bB1HVHv1ocXkQceFA=="],
|
||||
|
||||
"word-extractor": ["word-extractor@1.0.4", "", { "dependencies": { "saxes": "^5.0.1", "yauzl": "^2.10.0" } }, "sha512-PyAGZQ2gjnVA5kcZAOAxoYciCMaAvu0dbVlw/zxHphhy+3be8cDeYKHJPO8iedIM3Sx0arA/ugKTJyXhZNgo6g=="],
|
||||
|
||||
"wrap-ansi": ["wrap-ansi@6.2.0", "", { "dependencies": { "ansi-styles": "^4.0.0", "string-width": "^4.1.0", "strip-ansi": "^6.0.0" } }, "sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA=="],
|
||||
|
||||
"wrap-ansi-cjs": ["wrap-ansi@7.0.0", "", { "dependencies": { "ansi-styles": "^4.0.0", "string-width": "^4.1.0", "strip-ansi": "^6.0.0" } }, "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q=="],
|
||||
@@ -3206,6 +3218,8 @@
|
||||
|
||||
"yargs-parser": ["yargs-parser@21.1.1", "", {}, "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw=="],
|
||||
|
||||
"yauzl": ["yauzl@2.10.0", "", { "dependencies": { "buffer-crc32": "~0.2.3", "fd-slicer": "~1.1.0" } }, "sha512-p4a9I6X6nu6IhoGmBqAcbJy1mlC4j27vEPZX9F4L4/vZT3Lyq1VkFHw/V/PUcB9Buo+DG3iHkT0x3Qya58zc3g=="],
|
||||
|
||||
"yocto-queue": ["yocto-queue@0.1.0", "", {}, "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q=="],
|
||||
|
||||
"yoctocolors": ["yoctocolors@2.1.2", "", {}, "sha512-CzhO+pFNo8ajLM2d2IW/R93ipy99LWjtwblvC1RsoSUMZgyLbYFr221TnSNT7GjGdYui6P459mw9JH/g/zW2ug=="],
|
||||
@@ -3630,6 +3644,8 @@
|
||||
|
||||
"@types/webpack/@types/node": ["@types/node@24.2.1", "", { "dependencies": { "undici-types": "~7.10.0" } }, "sha512-DRh5K+ka5eJic8CjH7td8QpYEV6Zo10gfRkjHCO3weqZHWDtAaSTFtl4+VMqOJ4N5jcuhZ9/l+yy8rVgw7BQeQ=="],
|
||||
|
||||
"@types/word-extractor/@types/node": ["@types/node@24.2.1", "", { "dependencies": { "undici-types": "~7.10.0" } }, "sha512-DRh5K+ka5eJic8CjH7td8QpYEV6Zo10gfRkjHCO3weqZHWDtAaSTFtl4+VMqOJ4N5jcuhZ9/l+yy8rVgw7BQeQ=="],
|
||||
|
||||
"@vitejs/plugin-react/@babel/core": ["@babel/core@7.28.3", "", { "dependencies": { "@ampproject/remapping": "^2.2.0", "@babel/code-frame": "^7.27.1", "@babel/generator": "^7.28.3", "@babel/helper-compilation-targets": "^7.27.2", "@babel/helper-module-transforms": "^7.28.3", "@babel/helpers": "^7.28.3", "@babel/parser": "^7.28.3", "@babel/template": "^7.27.2", "@babel/traverse": "^7.28.3", "@babel/types": "^7.28.2", "convert-source-map": "^2.0.0", "debug": "^4.1.0", "gensync": "^1.0.0-beta.2", "json5": "^2.2.3", "semver": "^6.3.1" } }, "sha512-yDBHV9kQNcr2/sUr9jghVyz9C3Y5G2zUM2H2lo+9mKv4sFgbA8s8Z9t8D1jiTkGoO/NoIfKMyKWr4s6CN23ZwQ=="],
|
||||
|
||||
"accepts/mime-types": ["mime-types@2.1.35", "", { "dependencies": { "mime-db": "1.52.0" } }, "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw=="],
|
||||
@@ -3834,6 +3850,8 @@
|
||||
|
||||
"webpack/mime-types": ["mime-types@2.1.35", "", { "dependencies": { "mime-db": "1.52.0" } }, "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw=="],
|
||||
|
||||
"word-extractor/saxes": ["saxes@5.0.1", "", { "dependencies": { "xmlchars": "^2.2.0" } }, "sha512-5LBh1Tls8c9xgGjw3QrMwETmTMVk0oFgvrFSvWx62llR2hcEInrKNZ2GZCCuuy2lvWrdl5jhbpeqc5hRYKFOcw=="],
|
||||
|
||||
"@anthropic-ai/sdk/@types/node/undici-types": ["undici-types@5.26.5", "", {}, "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA=="],
|
||||
|
||||
"@anthropic-ai/sdk/node-fetch/whatwg-url": ["whatwg-url@5.0.0", "", { "dependencies": { "tr46": "~0.0.3", "webidl-conversions": "^3.0.0" } }, "sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw=="],
|
||||
@@ -4202,6 +4220,8 @@
|
||||
|
||||
"@types/webpack/@types/node/undici-types": ["undici-types@7.10.0", "", {}, "sha512-t5Fy/nfn+14LuOc2KNYg75vZqClpAiqscVvMygNnlsHBFpSXdJaYtXMcdNLpl/Qvc3P2cB3s6lOV51nqsFq4ag=="],
|
||||
|
||||
"@types/word-extractor/@types/node/undici-types": ["undici-types@7.10.0", "", {}, "sha512-t5Fy/nfn+14LuOc2KNYg75vZqClpAiqscVvMygNnlsHBFpSXdJaYtXMcdNLpl/Qvc3P2cB3s6lOV51nqsFq4ag=="],
|
||||
|
||||
"@vitejs/plugin-react/@babel/core/@babel/parser": ["@babel/parser@7.28.3", "", { "dependencies": { "@babel/types": "^7.28.2" }, "bin": "./bin/babel-parser.js" }, "sha512-7+Ey1mAgYqFAx2h0RuoxcQT5+MlG3GTV0TQrgr7/ZliKsm/MNDxVVutlWaziMq7wJNAz8MTqz55XLpWvva6StA=="],
|
||||
|
||||
"@vitejs/plugin-react/@babel/core/@babel/traverse": ["@babel/traverse@7.28.3", "", { "dependencies": { "@babel/code-frame": "^7.27.1", "@babel/generator": "^7.28.3", "@babel/helper-globals": "^7.28.0", "@babel/parser": "^7.28.3", "@babel/template": "^7.27.2", "@babel/types": "^7.28.2", "debug": "^4.3.1" } }, "sha512-7w4kZYHneL3A6NP2nxzHvT3HCZ7puDZZjFMqDpBPECub79sTtSO5CGXDkKrTQq8ksAwfD/XI2MRFX23njdDaIQ=="],
|
||||
|
||||
@@ -41,6 +41,7 @@
   "devDependencies": {
     "@biomejs/biome": "2.0.0-beta.5",
     "@next/env": "^15.3.2",
+    "@types/word-extractor": "1.0.6",
    "dotenv-cli": "^8.0.0",
    "husky": "9.1.7",
    "lint-staged": "16.0.0",
|
||||
|
||||