mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-11 07:58:06 -05:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee17cf461a | ||
|
|
43cb124d97 | ||
|
|
76889fde26 | ||
|
|
7780d9b32b | ||
|
|
4a703a02cb | ||
|
|
a969d09782 | ||
|
|
0bc778130f | ||
|
|
df3d532495 | ||
|
|
f4f8fc051e | ||
|
|
76fac13f3d | ||
|
|
a3838302e0 | ||
|
|
4310dd6c15 | ||
|
|
813a0fb741 | ||
|
|
7e23e942d7 | ||
|
|
7fcbafab97 | ||
|
|
056dc2879c | ||
|
|
1aec32b7e2 | ||
|
|
316c9704af | ||
|
|
4e3a3bd1b1 | ||
|
|
36773e8cdb | ||
|
|
7ac89e35a1 | ||
|
|
faa094195a | ||
|
|
69319d21cd | ||
|
|
8362fd7a83 | ||
|
|
39ad793a9a | ||
|
|
921c755711 | ||
|
|
41ec75fcad | ||
|
|
f2502f5e48 | ||
|
|
f3c4f7e20a | ||
|
|
f578f43c9a | ||
|
|
5c73038023 | ||
|
|
92132024ca | ||
|
|
ed11456de3 | ||
|
|
8739a3d378 | ||
|
|
ca015deea9 |
@@ -9,8 +9,8 @@ services:
|
||||
command: sleep infinity
|
||||
environment:
|
||||
- NODE_ENV=development
|
||||
- DATABASE_URL=postgresql://postgres:postgres@db:${POSTGRES_PORT:-5432}/simstudio
|
||||
- POSTGRES_URL=postgresql://postgres:postgres@db:${POSTGRES_PORT:-5432}/simstudio
|
||||
- DATABASE_URL=postgresql://postgres:postgres@db:5432/simstudio
|
||||
- POSTGRES_URL=postgresql://postgres:postgres@db:5432/simstudio
|
||||
- BETTER_AUTH_URL=http://localhost:3000
|
||||
- NEXT_PUBLIC_APP_URL=http://localhost:3000
|
||||
- BUN_INSTALL_CACHE_DIR=/home/bun/.bun/cache
|
||||
@@ -39,7 +39,7 @@ services:
|
||||
command: sleep infinity
|
||||
environment:
|
||||
- NODE_ENV=development
|
||||
- DATABASE_URL=postgresql://postgres:postgres@db:${POSTGRES_PORT:-5432}/simstudio
|
||||
- DATABASE_URL=postgresql://postgres:postgres@db:5432/simstudio
|
||||
- BETTER_AUTH_URL=http://localhost:3000
|
||||
- NEXT_PUBLIC_APP_URL=http://localhost:3000
|
||||
depends_on:
|
||||
@@ -60,7 +60,7 @@ services:
|
||||
context: ..
|
||||
dockerfile: docker/db.Dockerfile
|
||||
environment:
|
||||
- DATABASE_URL=postgresql://postgres:postgres@db:${POSTGRES_PORT:-5432}/simstudio
|
||||
- DATABASE_URL=postgresql://postgres:postgres@db:5432/simstudio
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
|
||||
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
node-version: latest
|
||||
|
||||
- name: Install dependencies
|
||||
run: bun install
|
||||
run: bun install --frozen-lockfile
|
||||
|
||||
- name: Run tests with coverage
|
||||
env:
|
||||
|
||||
@@ -2,7 +2,7 @@ import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getSimplifiedBillingSummary } from '@/lib/billing/core/billing'
|
||||
import { getOrganizationBillingData } from '@/lib/billing/core/organization-billing'
|
||||
import { getOrganizationBillingData } from '@/lib/billing/core/organization'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { member, userStats } from '@/db/schema'
|
||||
|
||||
@@ -12,9 +12,9 @@ import {
|
||||
import { getCopilotModel } from '@/lib/copilot/config'
|
||||
import type { CopilotProviderConfig } from '@/lib/copilot/types'
|
||||
import { env } from '@/lib/env'
|
||||
import { generateChatTitle } from '@/lib/generate-chat-title'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
|
||||
import { generateChatTitle } from '@/lib/sim-agent/utils'
|
||||
import { createFileContent, isSupportedFileType } from '@/lib/uploads/file-utils'
|
||||
import { S3_COPILOT_CONFIG } from '@/lib/uploads/setup'
|
||||
import { downloadFile, getStorageProvider } from '@/lib/uploads/storage-client'
|
||||
|
||||
@@ -76,11 +76,9 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
logger.info('File parse request received:', { filePath, fileType })
|
||||
|
||||
// Handle multiple files
|
||||
if (Array.isArray(filePath)) {
|
||||
const results = []
|
||||
for (const path of filePath) {
|
||||
// Skip empty or invalid paths
|
||||
if (!path || (typeof path === 'string' && path.trim() === '')) {
|
||||
results.push({
|
||||
success: false,
|
||||
@@ -91,12 +89,10 @@ export async function POST(request: NextRequest) {
|
||||
}
|
||||
|
||||
const result = await parseFileSingle(path, fileType)
|
||||
// Add processing time to metadata
|
||||
if (result.metadata) {
|
||||
result.metadata.processingTime = Date.now() - startTime
|
||||
}
|
||||
|
||||
// Transform each result to match expected frontend format
|
||||
if (result.success) {
|
||||
results.push({
|
||||
success: true,
|
||||
@@ -105,7 +101,7 @@ export async function POST(request: NextRequest) {
|
||||
name: result.filePath.split('/').pop() || 'unknown',
|
||||
fileType: result.metadata?.fileType || 'application/octet-stream',
|
||||
size: result.metadata?.size || 0,
|
||||
binary: false, // We only return text content
|
||||
binary: false,
|
||||
},
|
||||
filePath: result.filePath,
|
||||
})
|
||||
@@ -120,15 +116,12 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
}
|
||||
|
||||
// Handle single file
|
||||
const result = await parseFileSingle(filePath, fileType)
|
||||
|
||||
// Add processing time to metadata
|
||||
if (result.metadata) {
|
||||
result.metadata.processingTime = Date.now() - startTime
|
||||
}
|
||||
|
||||
// Transform single file result to match expected frontend format
|
||||
if (result.success) {
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
@@ -142,8 +135,6 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
}
|
||||
|
||||
// Only return 500 for actual server errors, not file processing failures
|
||||
// File processing failures (like file not found, parsing errors) should return 200 with success:false
|
||||
return NextResponse.json(result)
|
||||
} catch (error) {
|
||||
logger.error('Error in file parse API:', error)
|
||||
@@ -164,7 +155,6 @@ export async function POST(request: NextRequest) {
|
||||
async function parseFileSingle(filePath: string, fileType?: string): Promise<ParseResult> {
|
||||
logger.info('Parsing file:', filePath)
|
||||
|
||||
// Validate that filePath is not empty
|
||||
if (!filePath || filePath.trim() === '') {
|
||||
return {
|
||||
success: false,
|
||||
@@ -173,7 +163,6 @@ async function parseFileSingle(filePath: string, fileType?: string): Promise<Par
|
||||
}
|
||||
}
|
||||
|
||||
// Validate path for security before any processing
|
||||
const pathValidation = validateFilePath(filePath)
|
||||
if (!pathValidation.isValid) {
|
||||
return {
|
||||
@@ -183,49 +172,40 @@ async function parseFileSingle(filePath: string, fileType?: string): Promise<Par
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is an external URL
|
||||
if (filePath.startsWith('http://') || filePath.startsWith('https://')) {
|
||||
return handleExternalUrl(filePath, fileType)
|
||||
}
|
||||
|
||||
// Check if this is a cloud storage path (S3 or Blob)
|
||||
const isS3Path = filePath.includes('/api/files/serve/s3/')
|
||||
const isBlobPath = filePath.includes('/api/files/serve/blob/')
|
||||
|
||||
// Use cloud handler if it's a cloud path or we're in cloud mode
|
||||
if (isS3Path || isBlobPath || isUsingCloudStorage()) {
|
||||
return handleCloudFile(filePath, fileType)
|
||||
}
|
||||
|
||||
// Use local handler for local files
|
||||
return handleLocalFile(filePath, fileType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate file path for security
|
||||
* Validate file path for security - prevents null byte injection and path traversal attacks
|
||||
*/
|
||||
function validateFilePath(filePath: string): { isValid: boolean; error?: string } {
|
||||
// Check for null bytes
|
||||
if (filePath.includes('\0')) {
|
||||
return { isValid: false, error: 'Invalid path: null byte detected' }
|
||||
}
|
||||
|
||||
// Check for path traversal attempts
|
||||
if (filePath.includes('..')) {
|
||||
return { isValid: false, error: 'Access denied: path traversal detected' }
|
||||
}
|
||||
|
||||
// Check for tilde characters (home directory access)
|
||||
if (filePath.includes('~')) {
|
||||
return { isValid: false, error: 'Invalid path: tilde character not allowed' }
|
||||
}
|
||||
|
||||
// Check for absolute paths outside allowed directories
|
||||
if (filePath.startsWith('/') && !filePath.startsWith('/api/files/serve/')) {
|
||||
return { isValid: false, error: 'Path outside allowed directory' }
|
||||
}
|
||||
|
||||
// Check for Windows absolute paths
|
||||
if (/^[A-Za-z]:\\/.test(filePath)) {
|
||||
return { isValid: false, error: 'Path outside allowed directory' }
|
||||
}
|
||||
@@ -260,12 +240,10 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseR
|
||||
|
||||
logger.info(`Downloaded file from URL: ${url}, size: ${buffer.length} bytes`)
|
||||
|
||||
// Extract filename from URL
|
||||
const urlPath = new URL(url).pathname
|
||||
const filename = urlPath.split('/').pop() || 'download'
|
||||
const extension = path.extname(filename).toLowerCase().substring(1)
|
||||
|
||||
// Process the file based on its content type
|
||||
if (extension === 'pdf') {
|
||||
return await handlePdfBuffer(buffer, filename, fileType, url)
|
||||
}
|
||||
@@ -276,7 +254,6 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseR
|
||||
return await handleGenericTextBuffer(buffer, filename, extension, fileType, url)
|
||||
}
|
||||
|
||||
// For binary or unknown files
|
||||
return handleGenericBuffer(buffer, filename, extension, fileType)
|
||||
} catch (error) {
|
||||
logger.error(`Error handling external URL ${url}:`, error)
|
||||
@@ -289,35 +266,29 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseR
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle file stored in cloud storage (S3 or Azure Blob)
|
||||
* Handle file stored in cloud storage
|
||||
*/
|
||||
async function handleCloudFile(filePath: string, fileType?: string): Promise<ParseResult> {
|
||||
try {
|
||||
// Extract the cloud key from the path
|
||||
let cloudKey: string
|
||||
if (filePath.includes('/api/files/serve/s3/')) {
|
||||
cloudKey = decodeURIComponent(filePath.split('/api/files/serve/s3/')[1])
|
||||
} else if (filePath.includes('/api/files/serve/blob/')) {
|
||||
cloudKey = decodeURIComponent(filePath.split('/api/files/serve/blob/')[1])
|
||||
} else if (filePath.startsWith('/api/files/serve/')) {
|
||||
// Backwards-compatibility: path like "/api/files/serve/<key>"
|
||||
cloudKey = decodeURIComponent(filePath.substring('/api/files/serve/'.length))
|
||||
} else {
|
||||
// Assume raw key provided
|
||||
cloudKey = filePath
|
||||
}
|
||||
|
||||
logger.info('Extracted cloud key:', cloudKey)
|
||||
|
||||
// Download the file from cloud storage - this can throw for access errors
|
||||
const fileBuffer = await downloadFile(cloudKey)
|
||||
logger.info(`Downloaded file from cloud storage: ${cloudKey}, size: ${fileBuffer.length} bytes`)
|
||||
|
||||
// Extract the filename from the cloud key
|
||||
const filename = cloudKey.split('/').pop() || cloudKey
|
||||
const extension = path.extname(filename).toLowerCase().substring(1)
|
||||
|
||||
// Process the file based on its content type
|
||||
if (extension === 'pdf') {
|
||||
return await handlePdfBuffer(fileBuffer, filename, fileType, filePath)
|
||||
}
|
||||
@@ -325,22 +296,19 @@ async function handleCloudFile(filePath: string, fileType?: string): Promise<Par
|
||||
return await handleCsvBuffer(fileBuffer, filename, fileType, filePath)
|
||||
}
|
||||
if (isSupportedFileType(extension)) {
|
||||
// For other supported types that we have parsers for
|
||||
return await handleGenericTextBuffer(fileBuffer, filename, extension, fileType, filePath)
|
||||
}
|
||||
// For binary or unknown files
|
||||
return handleGenericBuffer(fileBuffer, filename, extension, fileType)
|
||||
} catch (error) {
|
||||
logger.error(`Error handling cloud file ${filePath}:`, error)
|
||||
|
||||
// Check if this is a download/access error that should trigger a 500 response
|
||||
// For download/access errors, throw to trigger 500 response
|
||||
const errorMessage = (error as Error).message
|
||||
if (errorMessage.includes('Access denied') || errorMessage.includes('Forbidden')) {
|
||||
// For access errors, throw to trigger 500 response
|
||||
throw new Error(`Error accessing file from cloud storage: ${errorMessage}`)
|
||||
}
|
||||
|
||||
// For other errors (parsing, processing), return success:false
|
||||
// For other errors (parsing, processing), return success:false and an error message
|
||||
return {
|
||||
success: false,
|
||||
error: `Error accessing file from cloud storage: ${errorMessage}`,
|
||||
@@ -354,28 +322,23 @@ async function handleCloudFile(filePath: string, fileType?: string): Promise<Par
|
||||
*/
|
||||
async function handleLocalFile(filePath: string, fileType?: string): Promise<ParseResult> {
|
||||
try {
|
||||
// Extract filename from path
|
||||
const filename = filePath.split('/').pop() || filePath
|
||||
const fullPath = path.join(UPLOAD_DIR_SERVER, filename)
|
||||
|
||||
logger.info('Processing local file:', fullPath)
|
||||
|
||||
// Check if file exists
|
||||
try {
|
||||
await fsPromises.access(fullPath)
|
||||
} catch {
|
||||
throw new Error(`File not found: ${filename}`)
|
||||
}
|
||||
|
||||
// Parse the file directly
|
||||
const result = await parseFile(fullPath)
|
||||
|
||||
// Get file stats for metadata
|
||||
const stats = await fsPromises.stat(fullPath)
|
||||
const fileBuffer = await readFile(fullPath)
|
||||
const hash = createHash('md5').update(fileBuffer).digest('hex')
|
||||
|
||||
// Extract file extension for type detection
|
||||
const extension = path.extname(filename).toLowerCase().substring(1)
|
||||
|
||||
return {
|
||||
@@ -386,7 +349,7 @@ async function handleLocalFile(filePath: string, fileType?: string): Promise<Par
|
||||
fileType: fileType || getMimeType(extension),
|
||||
size: stats.size,
|
||||
hash,
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -425,15 +388,14 @@ async function handlePdfBuffer(
|
||||
fileType: fileType || 'application/pdf',
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to parse PDF in memory:', error)
|
||||
|
||||
// Create fallback message for PDF parsing failure
|
||||
const content = createPdfFailureMessage(
|
||||
0, // We can't determine page count without parsing
|
||||
0,
|
||||
fileBuffer.length,
|
||||
originalPath || filename,
|
||||
(error as Error).message
|
||||
@@ -447,7 +409,7 @@ async function handlePdfBuffer(
|
||||
fileType: fileType || 'application/pdf',
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -465,7 +427,6 @@ async function handleCsvBuffer(
|
||||
try {
|
||||
logger.info(`Parsing CSV in memory: ${filename}`)
|
||||
|
||||
// Use the parseBuffer function from our library
|
||||
const { parseBuffer } = await import('@/lib/file-parsers')
|
||||
const result = await parseBuffer(fileBuffer, 'csv')
|
||||
|
||||
@@ -477,7 +438,7 @@ async function handleCsvBuffer(
|
||||
fileType: fileType || 'text/csv',
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -490,7 +451,7 @@ async function handleCsvBuffer(
|
||||
fileType: 'text/csv',
|
||||
size: 0,
|
||||
hash: '',
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -509,7 +470,6 @@ async function handleGenericTextBuffer(
|
||||
try {
|
||||
logger.info(`Parsing text file in memory: ${filename}`)
|
||||
|
||||
// Try to use a specialized parser if available
|
||||
try {
|
||||
const { parseBuffer, isSupportedFileType } = await import('@/lib/file-parsers')
|
||||
|
||||
@@ -524,7 +484,7 @@ async function handleGenericTextBuffer(
|
||||
fileType: fileType || getMimeType(extension),
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -532,7 +492,6 @@ async function handleGenericTextBuffer(
|
||||
logger.warn('Specialized parser failed, falling back to generic parsing:', parserError)
|
||||
}
|
||||
|
||||
// Fallback to generic text parsing
|
||||
const content = fileBuffer.toString('utf-8')
|
||||
|
||||
return {
|
||||
@@ -543,7 +502,7 @@ async function handleGenericTextBuffer(
|
||||
fileType: fileType || getMimeType(extension),
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -556,7 +515,7 @@ async function handleGenericTextBuffer(
|
||||
fileType: 'text/plain',
|
||||
size: 0,
|
||||
hash: '',
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -584,7 +543,7 @@ function handleGenericBuffer(
|
||||
fileType: fileType || getMimeType(extension),
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -594,8 +553,6 @@ function handleGenericBuffer(
|
||||
*/
|
||||
async function parseBufferAsPdf(buffer: Buffer) {
|
||||
try {
|
||||
// Import parsers dynamically to avoid initialization issues in tests
|
||||
// First try to use the main PDF parser
|
||||
try {
|
||||
const { PdfParser } = await import('@/lib/file-parsers/pdf-parser')
|
||||
const parser = new PdfParser()
|
||||
@@ -606,7 +563,6 @@ async function parseBufferAsPdf(buffer: Buffer) {
|
||||
}
|
||||
throw new Error('PDF parser does not support buffer parsing')
|
||||
} catch (error) {
|
||||
// Fallback to raw PDF parser
|
||||
logger.warn('Main PDF parser failed, using raw parser for buffer:', error)
|
||||
const { RawPdfParser } = await import('@/lib/file-parsers/raw-pdf-parser')
|
||||
const rawParser = new RawPdfParser()
|
||||
@@ -655,7 +611,7 @@ Please use a PDF viewer for best results.`
|
||||
}
|
||||
|
||||
/**
|
||||
* Create error message for PDF parsing failure
|
||||
* Create error message for PDF parsing failure and make it more readable
|
||||
*/
|
||||
function createPdfFailureMessage(
|
||||
pageCount: number,
|
||||
|
||||
@@ -2,7 +2,7 @@ import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { SUPPORTED_FIELD_TYPES } from '@/lib/constants/knowledge'
|
||||
import { SUPPORTED_FIELD_TYPES } from '@/lib/knowledge/consts'
|
||||
import {
|
||||
cleanupUnusedTagDefinitions,
|
||||
createOrUpdateTagDefinitionsBulk,
|
||||
|
||||
@@ -2,7 +2,7 @@ import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { SUPPORTED_FIELD_TYPES } from '@/lib/constants/knowledge'
|
||||
import { SUPPORTED_FIELD_TYPES } from '@/lib/knowledge/consts'
|
||||
import { createTagDefinition, getTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { TAG_SLOTS } from '@/lib/constants/knowledge'
|
||||
import { TAG_SLOTS } from '@/lib/knowledge/consts'
|
||||
import { getDocumentTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { estimateTokenCount } from '@/lib/tokenization/estimators'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -11,13 +12,13 @@ const DeleteSchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
table: z.string().min(1, 'Table name is required'),
|
||||
where: z.string().min(1, 'WHERE clause is required'),
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -11,12 +12,12 @@ const ExecuteSchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
query: z.string().min(1, 'Query is required'),
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
@@ -26,7 +27,6 @@ export async function POST(request: NextRequest) {
|
||||
`[${requestId}] Executing raw SQL on ${params.host}:${params.port}/${params.database}`
|
||||
)
|
||||
|
||||
// Validate query before execution
|
||||
const validation = validateQuery(params.query)
|
||||
if (!validation.isValid) {
|
||||
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -11,7 +12,7 @@ const InsertSchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
table: z.string().min(1, 'Table name is required'),
|
||||
data: z.union([
|
||||
z
|
||||
@@ -38,13 +39,10 @@ const InsertSchema = z.object({
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
|
||||
logger.info(`[${requestId}] Received data field type: ${typeof body.data}, value:`, body.data)
|
||||
|
||||
const params = InsertSchema.parse(body)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -11,12 +12,12 @@ const QuerySchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
query: z.string().min(1, 'Query is required'),
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
@@ -26,7 +27,6 @@ export async function POST(request: NextRequest) {
|
||||
`[${requestId}] Executing MySQL query on ${params.host}:${params.port}/${params.database}`
|
||||
)
|
||||
|
||||
// Validate query before execution
|
||||
const validation = validateQuery(params.query)
|
||||
if (!validation.isValid) {
|
||||
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -11,7 +12,7 @@ const UpdateSchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
table: z.string().min(1, 'Table name is required'),
|
||||
data: z.union([
|
||||
z
|
||||
@@ -36,7 +37,7 @@ const UpdateSchema = z.object({
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
|
||||
@@ -6,7 +6,7 @@ export interface MySQLConnectionConfig {
|
||||
database: string
|
||||
username: string
|
||||
password: string
|
||||
ssl?: string
|
||||
ssl?: 'disabled' | 'required' | 'preferred'
|
||||
}
|
||||
|
||||
export async function createMySQLConnection(config: MySQLConnectionConfig) {
|
||||
@@ -18,13 +18,13 @@ export async function createMySQLConnection(config: MySQLConnectionConfig) {
|
||||
password: config.password,
|
||||
}
|
||||
|
||||
// Handle SSL configuration
|
||||
if (config.ssl === 'required') {
|
||||
if (config.ssl === 'disabled') {
|
||||
// Don't set ssl property at all to disable SSL
|
||||
} else if (config.ssl === 'required') {
|
||||
connectionConfig.ssl = { rejectUnauthorized: true }
|
||||
} else if (config.ssl === 'preferred') {
|
||||
connectionConfig.ssl = { rejectUnauthorized: false }
|
||||
}
|
||||
// For 'disabled', we don't set the ssl property at all
|
||||
|
||||
return mysql.createConnection(connectionConfig)
|
||||
}
|
||||
@@ -54,7 +54,6 @@ export async function executeQuery(
|
||||
export function validateQuery(query: string): { isValid: boolean; error?: string } {
|
||||
const trimmedQuery = query.trim().toLowerCase()
|
||||
|
||||
// Block dangerous SQL operations
|
||||
const dangerousPatterns = [
|
||||
/drop\s+database/i,
|
||||
/drop\s+schema/i,
|
||||
@@ -91,7 +90,6 @@ export function validateQuery(query: string): { isValid: boolean; error?: string
|
||||
}
|
||||
}
|
||||
|
||||
// Only allow specific statement types for execute endpoint
|
||||
const allowedStatements = /^(select|insert|update|delete|with|show|describe|explain)\s+/i
|
||||
if (!allowedStatements.test(trimmedQuery)) {
|
||||
return {
|
||||
@@ -116,6 +114,8 @@ export function buildInsertQuery(table: string, data: Record<string, unknown>) {
|
||||
}
|
||||
|
||||
export function buildUpdateQuery(table: string, data: Record<string, unknown>, where: string) {
|
||||
validateWhereClause(where)
|
||||
|
||||
const sanitizedTable = sanitizeIdentifier(table)
|
||||
const columns = Object.keys(data)
|
||||
const values = Object.values(data)
|
||||
@@ -127,14 +127,33 @@ export function buildUpdateQuery(table: string, data: Record<string, unknown>, w
|
||||
}
|
||||
|
||||
export function buildDeleteQuery(table: string, where: string) {
|
||||
validateWhereClause(where)
|
||||
|
||||
const sanitizedTable = sanitizeIdentifier(table)
|
||||
const query = `DELETE FROM ${sanitizedTable} WHERE ${where}`
|
||||
|
||||
return { query, values: [] }
|
||||
}
|
||||
|
||||
function validateWhereClause(where: string): void {
|
||||
const dangerousPatterns = [
|
||||
/;\s*(drop|delete|insert|update|create|alter|grant|revoke)/i,
|
||||
/union\s+select/i,
|
||||
/into\s+outfile/i,
|
||||
/load_file/i,
|
||||
/--/,
|
||||
/\/\*/,
|
||||
/\*\//,
|
||||
]
|
||||
|
||||
for (const pattern of dangerousPatterns) {
|
||||
if (pattern.test(where)) {
|
||||
throw new Error('WHERE clause contains potentially dangerous operation')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function sanitizeIdentifier(identifier: string): string {
|
||||
// Handle schema.table format
|
||||
if (identifier.includes('.')) {
|
||||
const parts = identifier.split('.')
|
||||
return parts.map((part) => sanitizeSingleIdentifier(part)).join('.')
|
||||
@@ -144,16 +163,13 @@ export function sanitizeIdentifier(identifier: string): string {
|
||||
}
|
||||
|
||||
function sanitizeSingleIdentifier(identifier: string): string {
|
||||
// Remove any existing backticks to prevent double-escaping
|
||||
const cleaned = identifier.replace(/`/g, '')
|
||||
|
||||
// Validate identifier contains only safe characters
|
||||
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(cleaned)) {
|
||||
throw new Error(
|
||||
`Invalid identifier: ${identifier}. Identifiers must start with a letter or underscore and contain only letters, numbers, and underscores.`
|
||||
)
|
||||
}
|
||||
|
||||
// Wrap in backticks for MySQL
|
||||
return `\`${cleaned}\``
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import {
|
||||
buildDeleteQuery,
|
||||
createPostgresConnection,
|
||||
executeQuery,
|
||||
} from '@/app/api/tools/postgresql/utils'
|
||||
import { createPostgresConnection, executeDelete } from '@/app/api/tools/postgresql/utils'
|
||||
|
||||
const logger = createLogger('PostgreSQLDeleteAPI')
|
||||
|
||||
@@ -15,13 +12,13 @@ const DeleteSchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
table: z.string().min(1, 'Table name is required'),
|
||||
where: z.string().min(1, 'WHERE clause is required'),
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
@@ -31,7 +28,7 @@ export async function POST(request: NextRequest) {
|
||||
`[${requestId}] Deleting data from ${params.table} on ${params.host}:${params.port}/${params.database}`
|
||||
)
|
||||
|
||||
const client = await createPostgresConnection({
|
||||
const sql = createPostgresConnection({
|
||||
host: params.host,
|
||||
port: params.port,
|
||||
database: params.database,
|
||||
@@ -41,8 +38,7 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
|
||||
try {
|
||||
const { query, values } = buildDeleteQuery(params.table, params.where)
|
||||
const result = await executeQuery(client, query, values)
|
||||
const result = await executeDelete(sql, params.table, params.where)
|
||||
|
||||
logger.info(`[${requestId}] Delete executed successfully, ${result.rowCount} row(s) deleted`)
|
||||
|
||||
@@ -52,7 +48,7 @@ export async function POST(request: NextRequest) {
|
||||
rowCount: result.rowCount,
|
||||
})
|
||||
} finally {
|
||||
await client.end()
|
||||
await sql.end()
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -15,12 +16,12 @@ const ExecuteSchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
query: z.string().min(1, 'Query is required'),
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
@@ -30,7 +31,6 @@ export async function POST(request: NextRequest) {
|
||||
`[${requestId}] Executing raw SQL on ${params.host}:${params.port}/${params.database}`
|
||||
)
|
||||
|
||||
// Validate query before execution
|
||||
const validation = validateQuery(params.query)
|
||||
if (!validation.isValid) {
|
||||
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)
|
||||
@@ -40,7 +40,7 @@ export async function POST(request: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
const client = await createPostgresConnection({
|
||||
const sql = createPostgresConnection({
|
||||
host: params.host,
|
||||
port: params.port,
|
||||
database: params.database,
|
||||
@@ -50,7 +50,7 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
|
||||
try {
|
||||
const result = await executeQuery(client, params.query)
|
||||
const result = await executeQuery(sql, params.query)
|
||||
|
||||
logger.info(`[${requestId}] SQL executed successfully, ${result.rowCount} row(s) affected`)
|
||||
|
||||
@@ -60,7 +60,7 @@ export async function POST(request: NextRequest) {
|
||||
rowCount: result.rowCount,
|
||||
})
|
||||
} finally {
|
||||
await client.end()
|
||||
await sql.end()
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import {
|
||||
buildInsertQuery,
|
||||
createPostgresConnection,
|
||||
executeQuery,
|
||||
} from '@/app/api/tools/postgresql/utils'
|
||||
import { createPostgresConnection, executeInsert } from '@/app/api/tools/postgresql/utils'
|
||||
|
||||
const logger = createLogger('PostgreSQLInsertAPI')
|
||||
|
||||
@@ -15,7 +12,7 @@ const InsertSchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
table: z.string().min(1, 'Table name is required'),
|
||||
data: z.union([
|
||||
z
|
||||
@@ -42,21 +39,18 @@ const InsertSchema = z.object({
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
|
||||
// Debug: Log the data field to see what we're getting
|
||||
logger.info(`[${requestId}] Received data field type: ${typeof body.data}, value:`, body.data)
|
||||
|
||||
const params = InsertSchema.parse(body)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Inserting data into ${params.table} on ${params.host}:${params.port}/${params.database}`
|
||||
)
|
||||
|
||||
const client = await createPostgresConnection({
|
||||
const sql = createPostgresConnection({
|
||||
host: params.host,
|
||||
port: params.port,
|
||||
database: params.database,
|
||||
@@ -66,8 +60,7 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
|
||||
try {
|
||||
const { query, values } = buildInsertQuery(params.table, params.data)
|
||||
const result = await executeQuery(client, query, values)
|
||||
const result = await executeInsert(sql, params.table, params.data)
|
||||
|
||||
logger.info(`[${requestId}] Insert executed successfully, ${result.rowCount} row(s) inserted`)
|
||||
|
||||
@@ -77,7 +70,7 @@ export async function POST(request: NextRequest) {
|
||||
rowCount: result.rowCount,
|
||||
})
|
||||
} finally {
|
||||
await client.end()
|
||||
await sql.end()
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -11,12 +12,12 @@ const QuerySchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
query: z.string().min(1, 'Query is required'),
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
@@ -26,7 +27,7 @@ export async function POST(request: NextRequest) {
|
||||
`[${requestId}] Executing PostgreSQL query on ${params.host}:${params.port}/${params.database}`
|
||||
)
|
||||
|
||||
const client = await createPostgresConnection({
|
||||
const sql = createPostgresConnection({
|
||||
host: params.host,
|
||||
port: params.port,
|
||||
database: params.database,
|
||||
@@ -36,7 +37,7 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
|
||||
try {
|
||||
const result = await executeQuery(client, params.query)
|
||||
const result = await executeQuery(sql, params.query)
|
||||
|
||||
logger.info(`[${requestId}] Query executed successfully, returned ${result.rowCount} rows`)
|
||||
|
||||
@@ -46,7 +47,7 @@ export async function POST(request: NextRequest) {
|
||||
rowCount: result.rowCount,
|
||||
})
|
||||
} finally {
|
||||
await client.end()
|
||||
await sql.end()
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import {
|
||||
buildUpdateQuery,
|
||||
createPostgresConnection,
|
||||
executeQuery,
|
||||
} from '@/app/api/tools/postgresql/utils'
|
||||
import { createPostgresConnection, executeUpdate } from '@/app/api/tools/postgresql/utils'
|
||||
|
||||
const logger = createLogger('PostgreSQLUpdateAPI')
|
||||
|
||||
@@ -15,7 +12,7 @@ const UpdateSchema = z.object({
|
||||
database: z.string().min(1, 'Database name is required'),
|
||||
username: z.string().min(1, 'Username is required'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
|
||||
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
|
||||
table: z.string().min(1, 'Table name is required'),
|
||||
data: z.union([
|
||||
z
|
||||
@@ -40,7 +37,7 @@ const UpdateSchema = z.object({
|
||||
})
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
@@ -50,7 +47,7 @@ export async function POST(request: NextRequest) {
|
||||
`[${requestId}] Updating data in ${params.table} on ${params.host}:${params.port}/${params.database}`
|
||||
)
|
||||
|
||||
const client = await createPostgresConnection({
|
||||
const sql = createPostgresConnection({
|
||||
host: params.host,
|
||||
port: params.port,
|
||||
database: params.database,
|
||||
@@ -60,8 +57,7 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
|
||||
try {
|
||||
const { query, values } = buildUpdateQuery(params.table, params.data, params.where)
|
||||
const result = await executeQuery(client, query, values)
|
||||
const result = await executeUpdate(sql, params.table, params.data, params.where)
|
||||
|
||||
logger.info(`[${requestId}] Update executed successfully, ${result.rowCount} row(s) updated`)
|
||||
|
||||
@@ -71,7 +67,7 @@ export async function POST(request: NextRequest) {
|
||||
rowCount: result.rowCount,
|
||||
})
|
||||
} finally {
|
||||
await client.end()
|
||||
await sql.end()
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
|
||||
@@ -1,43 +1,41 @@
|
||||
import { Client } from 'pg'
|
||||
import postgres from 'postgres'
|
||||
import type { PostgresConnectionConfig } from '@/tools/postgresql/types'
|
||||
|
||||
export async function createPostgresConnection(config: PostgresConnectionConfig): Promise<Client> {
|
||||
const client = new Client({
|
||||
export function createPostgresConnection(config: PostgresConnectionConfig) {
|
||||
const sslConfig =
|
||||
config.ssl === 'disabled'
|
||||
? false
|
||||
: config.ssl === 'required'
|
||||
? 'require'
|
||||
: config.ssl === 'preferred'
|
||||
? 'prefer'
|
||||
: 'require'
|
||||
|
||||
const sql = postgres({
|
||||
host: config.host,
|
||||
port: config.port,
|
||||
database: config.database,
|
||||
user: config.username,
|
||||
username: config.username,
|
||||
password: config.password,
|
||||
ssl:
|
||||
config.ssl === 'disabled'
|
||||
? false
|
||||
: config.ssl === 'required'
|
||||
? true
|
||||
: config.ssl === 'preferred'
|
||||
? { rejectUnauthorized: false }
|
||||
: false,
|
||||
connectionTimeoutMillis: 10000, // 10 seconds
|
||||
query_timeout: 30000, // 30 seconds
|
||||
ssl: sslConfig,
|
||||
connect_timeout: 10, // 10 seconds
|
||||
idle_timeout: 20, // 20 seconds
|
||||
max_lifetime: 60 * 30, // 30 minutes
|
||||
max: 1, // Single connection for tool usage
|
||||
})
|
||||
|
||||
try {
|
||||
await client.connect()
|
||||
return client
|
||||
} catch (error) {
|
||||
await client.end()
|
||||
throw error
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
export async function executeQuery(
|
||||
client: Client,
|
||||
sql: any,
|
||||
query: string,
|
||||
params: unknown[] = []
|
||||
): Promise<{ rows: unknown[]; rowCount: number }> {
|
||||
const result = await client.query(query, params)
|
||||
const result = await sql.unsafe(query, params)
|
||||
return {
|
||||
rows: result.rows || [],
|
||||
rowCount: result.rowCount || 0,
|
||||
rows: Array.isArray(result) ? result : [result],
|
||||
rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +82,6 @@ export function validateQuery(query: string): { isValid: boolean; error?: string
|
||||
}
|
||||
}
|
||||
|
||||
// Only allow specific statement types for execute endpoint
|
||||
const allowedStatements = /^(select|insert|update|delete|with|explain|analyze|show)\s+/i
|
||||
if (!allowedStatements.test(trimmedQuery)) {
|
||||
return {
|
||||
@@ -98,7 +95,6 @@ export function validateQuery(query: string): { isValid: boolean; error?: string
|
||||
}
|
||||
|
||||
export function sanitizeIdentifier(identifier: string): string {
|
||||
// Handle schema.table format
|
||||
if (identifier.includes('.')) {
|
||||
const parts = identifier.split('.')
|
||||
return parts.map((part) => sanitizeSingleIdentifier(part)).join('.')
|
||||
@@ -107,28 +103,41 @@ export function sanitizeIdentifier(identifier: string): string {
|
||||
return sanitizeSingleIdentifier(identifier)
|
||||
}
|
||||
|
||||
function validateWhereClause(where: string): void {
|
||||
const dangerousPatterns = [
|
||||
/;\s*(drop|delete|insert|update|create|alter|grant|revoke)/i,
|
||||
/union\s+select/i,
|
||||
/into\s+outfile/i,
|
||||
/load_file/i,
|
||||
/--/,
|
||||
/\/\*/,
|
||||
/\*\//,
|
||||
]
|
||||
|
||||
for (const pattern of dangerousPatterns) {
|
||||
if (pattern.test(where)) {
|
||||
throw new Error('WHERE clause contains potentially dangerous operation')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function sanitizeSingleIdentifier(identifier: string): string {
|
||||
// Remove any existing double quotes to prevent double-escaping
|
||||
const cleaned = identifier.replace(/"/g, '')
|
||||
|
||||
// Validate identifier contains only safe characters
|
||||
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(cleaned)) {
|
||||
throw new Error(
|
||||
`Invalid identifier: ${identifier}. Identifiers must start with a letter or underscore and contain only letters, numbers, and underscores.`
|
||||
)
|
||||
}
|
||||
|
||||
// Wrap in double quotes for PostgreSQL
|
||||
return `"${cleaned}"`
|
||||
}
|
||||
|
||||
export function buildInsertQuery(
|
||||
export async function executeInsert(
|
||||
sql: any,
|
||||
table: string,
|
||||
data: Record<string, unknown>
|
||||
): {
|
||||
query: string
|
||||
values: unknown[]
|
||||
} {
|
||||
): Promise<{ rows: unknown[]; rowCount: number }> {
|
||||
const sanitizedTable = sanitizeIdentifier(table)
|
||||
const columns = Object.keys(data)
|
||||
const sanitizedColumns = columns.map((col) => sanitizeIdentifier(col))
|
||||
@@ -136,18 +145,22 @@ export function buildInsertQuery(
|
||||
const values = columns.map((col) => data[col])
|
||||
|
||||
const query = `INSERT INTO ${sanitizedTable} (${sanitizedColumns.join(', ')}) VALUES (${placeholders.join(', ')}) RETURNING *`
|
||||
const result = await sql.unsafe(query, values)
|
||||
|
||||
return { query, values }
|
||||
return {
|
||||
rows: Array.isArray(result) ? result : [result],
|
||||
rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
|
||||
}
|
||||
}
|
||||
|
||||
export function buildUpdateQuery(
|
||||
export async function executeUpdate(
|
||||
sql: any,
|
||||
table: string,
|
||||
data: Record<string, unknown>,
|
||||
where: string
|
||||
): {
|
||||
query: string
|
||||
values: unknown[]
|
||||
} {
|
||||
): Promise<{ rows: unknown[]; rowCount: number }> {
|
||||
validateWhereClause(where)
|
||||
|
||||
const sanitizedTable = sanitizeIdentifier(table)
|
||||
const columns = Object.keys(data)
|
||||
const sanitizedColumns = columns.map((col) => sanitizeIdentifier(col))
|
||||
@@ -155,19 +168,27 @@ export function buildUpdateQuery(
|
||||
const values = columns.map((col) => data[col])
|
||||
|
||||
const query = `UPDATE ${sanitizedTable} SET ${setClause} WHERE ${where} RETURNING *`
|
||||
const result = await sql.unsafe(query, values)
|
||||
|
||||
return { query, values }
|
||||
return {
|
||||
rows: Array.isArray(result) ? result : [result],
|
||||
rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
|
||||
}
|
||||
}
|
||||
|
||||
export function buildDeleteQuery(
|
||||
export async function executeDelete(
|
||||
sql: any,
|
||||
table: string,
|
||||
where: string
|
||||
): {
|
||||
query: string
|
||||
values: unknown[]
|
||||
} {
|
||||
): Promise<{ rows: unknown[]; rowCount: number }> {
|
||||
validateWhereClause(where)
|
||||
|
||||
const sanitizedTable = sanitizeIdentifier(table)
|
||||
const query = `DELETE FROM ${sanitizedTable} WHERE ${where} RETURNING *`
|
||||
const result = await sql.unsafe(query, [])
|
||||
|
||||
return { query, values: [] }
|
||||
return {
|
||||
rows: Array.isArray(result) ? result : [result],
|
||||
rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getUserUsageLimitInfo, updateUserUsageLimit } from '@/lib/billing'
|
||||
import { getOrganizationBillingData } from '@/lib/billing/core/organization-billing'
|
||||
import {
|
||||
getOrganizationBillingData,
|
||||
isOrganizationOwnerOrAdmin,
|
||||
} from '@/lib/billing/core/organization'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { isOrganizationOwnerOrAdmin } from '@/lib/permissions/utils'
|
||||
|
||||
const logger = createLogger('UnifiedUsageLimitsAPI')
|
||||
|
||||
@@ -25,7 +27,6 @@ export async function GET(request: NextRequest) {
|
||||
const userId = searchParams.get('userId') || session.user.id
|
||||
const organizationId = searchParams.get('organizationId')
|
||||
|
||||
// Validate context
|
||||
if (!['user', 'organization'].includes(context)) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Invalid context. Must be "user" or "organization"' },
|
||||
@@ -33,7 +34,6 @@ export async function GET(request: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
// For user context, ensure they can only view their own info
|
||||
if (context === 'user' && userId !== session.user.id) {
|
||||
return NextResponse.json(
|
||||
{ error: "Cannot view other users' usage information" },
|
||||
@@ -41,7 +41,6 @@ export async function GET(request: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
// Get usage limit info
|
||||
if (context === 'organization') {
|
||||
if (!organizationId) {
|
||||
return NextResponse.json(
|
||||
@@ -107,10 +106,8 @@ export async function PUT(request: NextRequest) {
|
||||
}
|
||||
|
||||
if (context === 'user') {
|
||||
// Update user's own usage limit
|
||||
await updateUserUsageLimit(userId, limit)
|
||||
} else if (context === 'organization') {
|
||||
// context === 'organization'
|
||||
if (!organizationId) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Organization ID is required when context=organization' },
|
||||
@@ -123,10 +120,7 @@ export async function PUT(request: NextRequest) {
|
||||
return NextResponse.json({ error: 'Permission denied' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Use the dedicated function to update org usage limit
|
||||
const { updateOrganizationUsageLimit } = await import(
|
||||
'@/lib/billing/core/organization-billing'
|
||||
)
|
||||
const { updateOrganizationUsageLimit } = await import('@/lib/billing/core/organization')
|
||||
const result = await updateOrganizationUsageLimit(organizationId, limit)
|
||||
|
||||
if (!result.success) {
|
||||
@@ -137,7 +131,6 @@ export async function PUT(request: NextRequest) {
|
||||
return NextResponse.json({ success: true, context, userId, organizationId, data: updated })
|
||||
}
|
||||
|
||||
// Return updated limit info
|
||||
const updatedInfo = await getUserUsageLimitInfo(userId)
|
||||
|
||||
return NextResponse.json({
|
||||
|
||||
@@ -4,7 +4,7 @@ import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
export const runtime = 'edge'
|
||||
export const runtime = 'nodejs'
|
||||
export const maxDuration = 60
|
||||
|
||||
const logger = createLogger('WandGenerateAPI')
|
||||
@@ -49,6 +49,15 @@ interface RequestBody {
|
||||
history?: ChatMessage[]
|
||||
}
|
||||
|
||||
// Helper: safe stringify for error payloads that may include circular structures
|
||||
function safeStringify(value: unknown): string {
|
||||
try {
|
||||
return JSON.stringify(value)
|
||||
} catch {
|
||||
return '[unserializable]'
|
||||
}
|
||||
}
|
||||
|
||||
export async function POST(req: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
logger.info(`[${requestId}] Received wand generation request`)
|
||||
@@ -110,124 +119,172 @@ export async function POST(req: NextRequest) {
|
||||
`[${requestId}] About to create stream with model: ${useWandAzure ? wandModelName : 'gpt-4o'}`
|
||||
)
|
||||
|
||||
// Add AbortController with timeout
|
||||
const abortController = new AbortController()
|
||||
const timeoutId = setTimeout(() => {
|
||||
abortController.abort('Stream timeout after 30 seconds')
|
||||
}, 30000)
|
||||
// Use native fetch for streaming to avoid OpenAI SDK issues with Node.js runtime
|
||||
const apiUrl = useWandAzure
|
||||
? `${azureEndpoint}/openai/deployments/${wandModelName}/chat/completions?api-version=${azureApiVersion}`
|
||||
: 'https://api.openai.com/v1/chat/completions'
|
||||
|
||||
// Forward request abort signal if available
|
||||
req.signal?.addEventListener('abort', () => {
|
||||
abortController.abort('Request cancelled by client')
|
||||
})
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
const streamCompletion = await client.chat.completions.create(
|
||||
{
|
||||
if (useWandAzure) {
|
||||
headers['api-key'] = azureApiKey!
|
||||
} else {
|
||||
headers.Authorization = `Bearer ${openaiApiKey}`
|
||||
}
|
||||
|
||||
logger.debug(`[${requestId}] Making streaming request to: ${apiUrl}`)
|
||||
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
model: useWandAzure ? wandModelName : 'gpt-4o',
|
||||
messages: messages,
|
||||
temperature: 0.3,
|
||||
max_tokens: 10000,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
},
|
||||
{
|
||||
signal: abortController.signal, // Add AbortSignal
|
||||
}
|
||||
)
|
||||
|
||||
clearTimeout(timeoutId) // Clear timeout after successful creation
|
||||
logger.info(`[${requestId}] Stream created successfully, starting reader pattern`)
|
||||
|
||||
logger.debug(`[${requestId}] Stream connection established successfully`)
|
||||
|
||||
return new Response(
|
||||
new ReadableStream({
|
||||
async start(controller) {
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
try {
|
||||
logger.info(`[${requestId}] Starting streaming with timeout protection`)
|
||||
let chunkCount = 0
|
||||
let hasUsageData = false
|
||||
|
||||
// Use for await with AbortController timeout protection
|
||||
for await (const chunk of streamCompletion) {
|
||||
chunkCount++
|
||||
|
||||
if (chunkCount === 1) {
|
||||
logger.info(`[${requestId}] Received first chunk via for await`)
|
||||
}
|
||||
|
||||
// Process the chunk
|
||||
const content = chunk.choices?.[0]?.delta?.content || ''
|
||||
if (content) {
|
||||
// Use SSE format identical to chat streaming
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ chunk: content })}\n\n`)
|
||||
)
|
||||
}
|
||||
|
||||
// Check for usage data
|
||||
if (chunk.usage) {
|
||||
hasUsageData = true
|
||||
logger.info(
|
||||
`[${requestId}] Received usage data: ${JSON.stringify(chunk.usage)}`
|
||||
)
|
||||
}
|
||||
|
||||
// Log every 5th chunk to avoid spam
|
||||
if (chunkCount % 5 === 0) {
|
||||
logger.debug(`[${requestId}] Processed ${chunkCount} chunks so far`)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Reader pattern completed. Total chunks: ${chunkCount}, Usage data received: ${hasUsageData}`
|
||||
)
|
||||
|
||||
// Send completion signal in SSE format
|
||||
logger.info(`[${requestId}] Sending completion signal`)
|
||||
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
|
||||
|
||||
logger.info(`[${requestId}] Closing controller`)
|
||||
controller.close()
|
||||
|
||||
logger.info(`[${requestId}] Wand generation streaming completed successfully`)
|
||||
} catch (streamError: any) {
|
||||
if (streamError.name === 'AbortError') {
|
||||
logger.info(
|
||||
`[${requestId}] Stream was aborted (timeout or cancel): ${streamError.message}`
|
||||
)
|
||||
controller.enqueue(
|
||||
encoder.encode(
|
||||
`data: ${JSON.stringify({ error: 'Stream cancelled', done: true })}\n\n`
|
||||
)
|
||||
)
|
||||
} else {
|
||||
logger.error(`[${requestId}] Streaming error`, { error: streamError.message })
|
||||
controller.enqueue(
|
||||
encoder.encode(
|
||||
`data: ${JSON.stringify({ error: 'Streaming failed', done: true })}\n\n`
|
||||
)
|
||||
)
|
||||
}
|
||||
controller.close()
|
||||
}
|
||||
},
|
||||
}),
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive',
|
||||
'X-Accel-Buffering': 'no',
|
||||
},
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
logger.error(`[${requestId}] API request failed`, {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
error: errorText,
|
||||
})
|
||||
throw new Error(`API request failed: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Stream response received, starting processing`)
|
||||
|
||||
// Create a TransformStream to process the SSE data
|
||||
const encoder = new TextEncoder()
|
||||
const decoder = new TextDecoder()
|
||||
|
||||
const readable = new ReadableStream({
|
||||
async start(controller) {
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
let buffer = ''
|
||||
let chunkCount = 0
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
|
||||
if (done) {
|
||||
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
|
||||
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
|
||||
controller.close()
|
||||
break
|
||||
}
|
||||
|
||||
// Decode the chunk
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
|
||||
// Process complete SSE messages
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || '' // Keep incomplete line in buffer
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6).trim()
|
||||
|
||||
if (data === '[DONE]') {
|
||||
logger.info(`[${requestId}] Received [DONE] signal`)
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
|
||||
)
|
||||
controller.close()
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
const content = parsed.choices?.[0]?.delta?.content
|
||||
|
||||
if (content) {
|
||||
chunkCount++
|
||||
if (chunkCount === 1) {
|
||||
logger.info(`[${requestId}] Received first content chunk`)
|
||||
}
|
||||
|
||||
// Forward the content
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ chunk: content })}\n\n`)
|
||||
)
|
||||
}
|
||||
|
||||
// Log usage if present
|
||||
if (parsed.usage) {
|
||||
logger.info(
|
||||
`[${requestId}] Received usage data: ${JSON.stringify(parsed.usage)}`
|
||||
)
|
||||
}
|
||||
|
||||
// Log progress periodically
|
||||
if (chunkCount % 10 === 0) {
|
||||
logger.debug(`[${requestId}] Processed ${chunkCount} chunks`)
|
||||
}
|
||||
} catch (parseError) {
|
||||
// Skip invalid JSON lines
|
||||
logger.debug(
|
||||
`[${requestId}] Skipped non-JSON line: ${data.substring(0, 100)}`
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Wand generation streaming completed successfully`)
|
||||
} catch (streamError: any) {
|
||||
logger.error(`[${requestId}] Streaming error`, {
|
||||
name: streamError?.name,
|
||||
message: streamError?.message || 'Unknown error',
|
||||
stack: streamError?.stack,
|
||||
})
|
||||
|
||||
// Send error to client
|
||||
const errorData = `data: ${JSON.stringify({ error: 'Streaming failed', done: true })}\n\n`
|
||||
controller.enqueue(encoder.encode(errorData))
|
||||
controller.close()
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
// Return Response with proper headers for Node.js runtime
|
||||
return new Response(readable, {
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache, no-transform',
|
||||
Connection: 'keep-alive',
|
||||
'X-Accel-Buffering': 'no', // Disable Nginx buffering
|
||||
'Transfer-Encoding': 'chunked', // Important for Node.js runtime
|
||||
},
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Streaming error`, {
|
||||
error: error.message || 'Unknown error',
|
||||
stack: error.stack,
|
||||
logger.error(`[${requestId}] Failed to create stream`, {
|
||||
name: error?.name,
|
||||
message: error?.message || 'Unknown error',
|
||||
code: error?.code,
|
||||
status: error?.status,
|
||||
responseStatus: error?.response?.status,
|
||||
responseData: error?.response?.data ? safeStringify(error.response.data) : undefined,
|
||||
stack: error?.stack,
|
||||
useWandAzure,
|
||||
model: useWandAzure ? wandModelName : 'gpt-4o',
|
||||
endpoint: useWandAzure ? azureEndpoint : 'api.openai.com',
|
||||
apiVersion: useWandAzure ? azureApiVersion : 'N/A',
|
||||
})
|
||||
|
||||
return NextResponse.json(
|
||||
@@ -261,8 +318,19 @@ export async function POST(req: NextRequest) {
|
||||
return NextResponse.json({ success: true, content: generatedContent })
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Wand generation failed`, {
|
||||
error: error.message || 'Unknown error',
|
||||
stack: error.stack,
|
||||
name: error?.name,
|
||||
message: error?.message || 'Unknown error',
|
||||
code: error?.code,
|
||||
status: error?.status,
|
||||
responseStatus: error instanceof OpenAI.APIError ? error.status : error?.response?.status,
|
||||
responseData: (error as any)?.response?.data
|
||||
? safeStringify((error as any).response.data)
|
||||
: undefined,
|
||||
stack: error?.stack,
|
||||
useWandAzure,
|
||||
model: useWandAzure ? wandModelName : 'gpt-4o',
|
||||
endpoint: useWandAzure ? azureEndpoint : 'api.openai.com',
|
||||
apiVersion: useWandAzure ? azureApiVersion : 'N/A',
|
||||
})
|
||||
|
||||
let clientErrorMessage = 'Wand generation failed. Please try again later.'
|
||||
|
||||
@@ -2,16 +2,19 @@ import crypto from 'crypto'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUsersWithPermissions, hasWorkspaceAdminAccess } from '@/lib/permissions/utils'
|
||||
import { db } from '@/db'
|
||||
import { permissions, type permissionTypeEnum } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('WorkspacesPermissionsAPI')
|
||||
|
||||
type PermissionType = (typeof permissionTypeEnum.enumValues)[number]
|
||||
|
||||
interface UpdatePermissionsRequest {
|
||||
updates: Array<{
|
||||
userId: string
|
||||
permissions: PermissionType // Single permission type instead of object with booleans
|
||||
permissions: PermissionType
|
||||
}>
|
||||
}
|
||||
|
||||
@@ -33,7 +36,6 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
|
||||
return NextResponse.json({ error: 'Authentication required' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Verify the current user has access to this workspace
|
||||
const userPermission = await db
|
||||
.select()
|
||||
.from(permissions)
|
||||
@@ -57,7 +59,7 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
|
||||
total: result.length,
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('Error fetching workspace permissions:', error)
|
||||
logger.error('Error fetching workspace permissions:', error)
|
||||
return NextResponse.json({ error: 'Failed to fetch workspace permissions' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
@@ -81,7 +83,6 @@ export async function PATCH(request: NextRequest, { params }: { params: Promise<
|
||||
return NextResponse.json({ error: 'Authentication required' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Verify the current user has admin access to this workspace (either direct or through organization)
|
||||
const hasAdminAccess = await hasWorkspaceAdminAccess(session.user.id, workspaceId)
|
||||
|
||||
if (!hasAdminAccess) {
|
||||
@@ -91,10 +92,8 @@ export async function PATCH(request: NextRequest, { params }: { params: Promise<
|
||||
)
|
||||
}
|
||||
|
||||
// Parse and validate request body
|
||||
const body: UpdatePermissionsRequest = await request.json()
|
||||
|
||||
// Prevent users from modifying their own admin permissions
|
||||
const selfUpdate = body.updates.find((update) => update.userId === session.user.id)
|
||||
if (selfUpdate && selfUpdate.permissions !== 'admin') {
|
||||
return NextResponse.json(
|
||||
@@ -103,10 +102,8 @@ export async function PATCH(request: NextRequest, { params }: { params: Promise<
|
||||
)
|
||||
}
|
||||
|
||||
// Process updates in a transaction
|
||||
await db.transaction(async (tx) => {
|
||||
for (const update of body.updates) {
|
||||
// Delete existing permissions for this user and workspace
|
||||
await tx
|
||||
.delete(permissions)
|
||||
.where(
|
||||
@@ -117,7 +114,6 @@ export async function PATCH(request: NextRequest, { params }: { params: Promise<
|
||||
)
|
||||
)
|
||||
|
||||
// Insert the single new permission
|
||||
await tx.insert(permissions).values({
|
||||
id: crypto.randomUUID(),
|
||||
userId: update.userId,
|
||||
@@ -138,7 +134,7 @@ export async function PATCH(request: NextRequest, { params }: { params: Promise<
|
||||
total: updatedUsers.length,
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('Error updating workspace permissions:', error)
|
||||
logger.error('Error updating workspace permissions:', error)
|
||||
return NextResponse.json({ error: 'Failed to update workspace permissions' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getAssetUrl } from '@/lib/utils'
|
||||
import '@/app/globals.css'
|
||||
|
||||
import { SessionProvider } from '@/lib/session-context'
|
||||
import { SessionProvider } from '@/lib/session/session-context'
|
||||
import { ThemeProvider } from '@/app/theme-provider'
|
||||
import { ZoomPrevention } from '@/app/zoom-prevention'
|
||||
|
||||
|
||||
@@ -7,22 +7,12 @@ import { Dialog, DialogContent, DialogHeader, DialogTitle } from '@/components/u
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { ACCEPT_ATTRIBUTE, ACCEPTED_FILE_TYPES, MAX_FILE_SIZE } from '@/lib/uploads/validation'
|
||||
import { getDocumentIcon } from '@/app/workspace/[workspaceId]/knowledge/components'
|
||||
import { useKnowledgeUpload } from '@/app/workspace/[workspaceId]/knowledge/hooks/use-knowledge-upload'
|
||||
|
||||
const logger = createLogger('UploadModal')
|
||||
|
||||
const MAX_FILE_SIZE = 100 * 1024 * 1024 // 100MB
|
||||
const ACCEPTED_FILE_TYPES = [
|
||||
'application/pdf',
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'text/plain',
|
||||
'text/csv',
|
||||
'application/vnd.ms-excel',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
]
|
||||
|
||||
interface FileWithPreview extends File {
|
||||
preview: string
|
||||
}
|
||||
@@ -74,7 +64,7 @@ export function UploadModal({
|
||||
return `File "${file.name}" is too large. Maximum size is 100MB.`
|
||||
}
|
||||
if (!ACCEPTED_FILE_TYPES.includes(file.type)) {
|
||||
return `File "${file.name}" has an unsupported format. Please use PDF, DOC, DOCX, TXT, CSV, XLS, or XLSX files.`
|
||||
return `File "${file.name}" has an unsupported format. Please use PDF, DOC, DOCX, TXT, CSV, XLS, XLSX, MD, PPT, PPTX, or HTML files.`
|
||||
}
|
||||
return null
|
||||
}
|
||||
@@ -168,7 +158,7 @@ export function UploadModal({
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={handleClose}>
|
||||
<DialogContent className='flex max-h-[95vh] max-w-2xl flex-col overflow-hidden'>
|
||||
<DialogContent className='flex max-h-[95vh] flex-col overflow-hidden sm:max-w-[600px]'>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Upload Documents</DialogTitle>
|
||||
</DialogHeader>
|
||||
@@ -193,7 +183,7 @@ export function UploadModal({
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type='file'
|
||||
accept={ACCEPTED_FILE_TYPES.join(',')}
|
||||
accept={ACCEPT_ATTRIBUTE}
|
||||
onChange={handleFileChange}
|
||||
className='hidden'
|
||||
multiple
|
||||
@@ -203,7 +193,8 @@ export function UploadModal({
|
||||
{isDragging ? 'Drop files here!' : 'Drop files here or click to browse'}
|
||||
</p>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
Supports PDF, DOC, DOCX, TXT, CSV, XLS, XLSX (max 100MB each)
|
||||
Supports PDF, DOC, DOCX, TXT, CSV, XLS, XLSX, MD, PPT, PPTX, HTML (max 100MB
|
||||
each)
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -223,7 +214,7 @@ export function UploadModal({
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type='file'
|
||||
accept={ACCEPTED_FILE_TYPES.join(',')}
|
||||
accept={ACCEPT_ATTRIBUTE}
|
||||
onChange={handleFileChange}
|
||||
className='hidden'
|
||||
multiple
|
||||
@@ -233,7 +224,7 @@ export function UploadModal({
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className='max-h-60 space-y-2 overflow-auto'>
|
||||
<div className='max-h-80 space-y-2 overflow-auto'>
|
||||
{files.map((file, index) => {
|
||||
const fileStatus = uploadProgress.fileStatuses?.[index]
|
||||
const isCurrentlyUploading = fileStatus?.status === 'uploading'
|
||||
|
||||
@@ -14,23 +14,13 @@ import { Label } from '@/components/ui/label'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { Textarea } from '@/components/ui/textarea'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { ACCEPT_ATTRIBUTE, ACCEPTED_FILE_TYPES, MAX_FILE_SIZE } from '@/lib/uploads/validation'
|
||||
import { getDocumentIcon } from '@/app/workspace/[workspaceId]/knowledge/components'
|
||||
import { useKnowledgeUpload } from '@/app/workspace/[workspaceId]/knowledge/hooks/use-knowledge-upload'
|
||||
import type { KnowledgeBaseData } from '@/stores/knowledge/store'
|
||||
|
||||
const logger = createLogger('CreateModal')
|
||||
|
||||
const MAX_FILE_SIZE = 100 * 1024 * 1024 // 100MB
|
||||
const ACCEPTED_FILE_TYPES = [
|
||||
'application/pdf',
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'text/plain',
|
||||
'text/csv',
|
||||
'application/vnd.ms-excel',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
]
|
||||
|
||||
interface FileWithPreview extends File {
|
||||
preview: string
|
||||
}
|
||||
@@ -168,7 +158,7 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
// Check file type
|
||||
if (!ACCEPTED_FILE_TYPES.includes(file.type)) {
|
||||
setFileError(
|
||||
`File ${file.name} has an unsupported format. Please use PDF, DOC, DOCX, TXT, CSV, XLS, or XLSX.`
|
||||
`File ${file.name} has an unsupported format. Please use PDF, DOC, DOCX, TXT, CSV, XLS, XLSX, MD, PPT, PPTX, or HTML.`
|
||||
)
|
||||
hasError = true
|
||||
continue
|
||||
@@ -494,7 +484,7 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type='file'
|
||||
accept={ACCEPTED_FILE_TYPES.join(',')}
|
||||
accept={ACCEPT_ATTRIBUTE}
|
||||
onChange={handleFileChange}
|
||||
className='hidden'
|
||||
multiple
|
||||
@@ -511,7 +501,8 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
: 'Drop files here or click to browse'}
|
||||
</p>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
Supports PDF, DOC, DOCX, TXT, CSV, XLS, XLSX (max 100MB each)
|
||||
Supports PDF, DOC, DOCX, TXT, CSV, XLS, XLSX, MD, PPT, PPTX, HTML (max
|
||||
100MB each)
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -535,7 +526,7 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type='file'
|
||||
accept={ACCEPTED_FILE_TYPES.join(',')}
|
||||
accept={ACCEPT_ATTRIBUTE}
|
||||
onChange={handleFileChange}
|
||||
className='hidden'
|
||||
multiple
|
||||
@@ -552,7 +543,8 @@ export function CreateModal({ open, onOpenChange, onKnowledgeBaseCreated }: Crea
|
||||
: 'Drop more files or click to browse'}
|
||||
</p>
|
||||
<p className='text-muted-foreground text-xs'>
|
||||
PDF, DOC, DOCX, TXT, CSV, XLS, XLSX (max 100MB each)
|
||||
PDF, DOC, DOCX, TXT, CSV, XLS, XLSX, MD, PPT, PPTX, HTML (max 100MB
|
||||
each)
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -25,7 +25,7 @@ import {
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui'
|
||||
import { MAX_TAG_SLOTS, type TagSlot } from '@/lib/constants/knowledge'
|
||||
import { MAX_TAG_SLOTS, type TagSlot } from '@/lib/knowledge/consts'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { useKnowledgeBaseTagDefinitions } from '@/hooks/use-knowledge-base-tag-definitions'
|
||||
import { useNextAvailableSlot } from '@/hooks/use-next-available-slot'
|
||||
|
||||
@@ -6,7 +6,7 @@ import { Button } from '@/components/ui/button'
|
||||
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from '@/components/ui/collapsible'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { TAG_SLOTS, type TagSlot } from '@/lib/constants/knowledge'
|
||||
import { TAG_SLOTS, type TagSlot } from '@/lib/knowledge/consts'
|
||||
import { useKnowledgeBaseTagDefinitions } from '@/hooks/use-knowledge-base-tag-definitions'
|
||||
|
||||
export type TagData = {
|
||||
|
||||
@@ -12,15 +12,17 @@ import {
|
||||
extractPathFromOutputId,
|
||||
parseOutputContentSafely,
|
||||
} from '@/lib/response-format'
|
||||
import { ChatMessage } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/chat/components/chat-message/chat-message'
|
||||
import { OutputSelect } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/chat/components/output-select/output-select'
|
||||
import {
|
||||
ChatFileUpload,
|
||||
ChatMessage,
|
||||
OutputSelect,
|
||||
} from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/chat/components'
|
||||
import { useWorkflowExecution } from '@/app/workspace/[workspaceId]/w/[workflowId]/hooks/use-workflow-execution'
|
||||
import type { BlockLog, ExecutionResult } from '@/executor/types'
|
||||
import { useExecutionStore } from '@/stores/execution/store'
|
||||
import { useChatStore } from '@/stores/panel/chat/store'
|
||||
import { useConsoleStore } from '@/stores/panel/console/store'
|
||||
import { useWorkflowRegistry } from '@/stores/workflows/registry/store'
|
||||
import { ChatFileUpload } from './components/chat-file-upload'
|
||||
|
||||
const logger = createLogger('ChatPanel')
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
export { ChatFileUpload } from './chat-file-upload/chat-file-upload'
|
||||
export { ChatMessage } from './chat-message/chat-message'
|
||||
export { OutputSelect } from './output-select/output-select'
|
||||
@@ -155,7 +155,7 @@ const ImagePreview = ({
|
||||
className='h-auto w-full rounded-lg border'
|
||||
unoptimized
|
||||
onError={(e) => {
|
||||
console.error('Image failed to load:', imageSrc)
|
||||
logger.error('Image failed to load:', imageSrc)
|
||||
setLoadError(true)
|
||||
onLoadError?.(true)
|
||||
}}
|
||||
@@ -333,7 +333,7 @@ export function ConsoleEntry({ entry, consoleWidth }: ConsoleEntryProps) {
|
||||
// Clean up the URL
|
||||
setTimeout(() => URL.revokeObjectURL(url), 100)
|
||||
} catch (error) {
|
||||
console.error('Error downloading image:', error)
|
||||
logger.error('Error downloading image:', error)
|
||||
alert('Failed to download image. Please try again later.')
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
} from '@/components/ui/dropdown-menu'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { useCopilotStore } from '@/stores/copilot/store'
|
||||
import { useChatStore } from '@/stores/panel/chat/store'
|
||||
import { useConsoleStore } from '@/stores/panel/console/store'
|
||||
@@ -19,6 +20,8 @@ import { Console } from './components/console/console'
|
||||
import { Copilot } from './components/copilot/copilot'
|
||||
import { Variables } from './components/variables/variables'
|
||||
|
||||
const logger = createLogger('Panel')
|
||||
|
||||
export function Panel() {
|
||||
const [chatMessage, setChatMessage] = useState<string>('')
|
||||
const [isHistoryDropdownOpen, setIsHistoryDropdownOpen] = useState(false)
|
||||
@@ -67,7 +70,7 @@ export function Panel() {
|
||||
try {
|
||||
await deleteChat(chatId)
|
||||
} catch (error) {
|
||||
console.error('Error deleting chat:', error)
|
||||
logger.error('Error deleting chat:', error)
|
||||
}
|
||||
},
|
||||
[deleteChat]
|
||||
@@ -101,7 +104,7 @@ export function Panel() {
|
||||
lastLoadedWorkflowRef.current = activeWorkflowId
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load copilot data:', error)
|
||||
logger.error('Failed to load copilot data:', error)
|
||||
}
|
||||
},
|
||||
[
|
||||
@@ -134,14 +137,14 @@ export function Panel() {
|
||||
if (!areChatsFresh(activeWorkflowId)) {
|
||||
// Don't await - let it load in background while dropdown is already open
|
||||
ensureCopilotDataLoaded(false).catch((error) => {
|
||||
console.error('Failed to load chat history:', error)
|
||||
logger.error('Failed to load chat history:', error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// If streaming, just log that we're showing cached data
|
||||
if (open && isSendingMessage) {
|
||||
console.log('Chat history opened during stream - showing cached data only')
|
||||
logger.info('Chat history opened during stream - showing cached data only')
|
||||
}
|
||||
},
|
||||
[ensureCopilotDataLoaded, activeWorkflowId, areChatsFresh, isSendingMessage]
|
||||
@@ -278,7 +281,7 @@ export function Panel() {
|
||||
// This is a real workflow change, not just a tab switch
|
||||
if (copilotWorkflowId !== activeWorkflowId || !copilotWorkflowId) {
|
||||
ensureCopilotDataLoaded().catch((error) => {
|
||||
console.error('Failed to auto-load copilot data on workflow change:', error)
|
||||
logger.error('Failed to auto-load copilot data on workflow change:', error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,16 +385,16 @@ export function Code({
|
||||
|
||||
<div
|
||||
className={cn(
|
||||
'group relative min-h-[100px] rounded-md border bg-background font-mono text-sm transition-colors',
|
||||
'group relative min-h-[100px] rounded-md border border-input bg-background font-mono text-sm transition-colors',
|
||||
isConnecting && 'ring-2 ring-blue-500 ring-offset-2',
|
||||
!isValidJson && 'border-2 border-destructive bg-destructive/10'
|
||||
!isValidJson && 'border-destructive bg-destructive/10'
|
||||
)}
|
||||
title={!isValidJson ? 'Invalid JSON' : undefined}
|
||||
onDragOver={(e) => e.preventDefault()}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<div className='absolute top-2 right-3 z-10 flex items-center gap-1 opacity-0 transition-opacity group-hover:opacity-100'>
|
||||
{!isCollapsed && !isAiStreaming && !isPreview && (
|
||||
{wandConfig?.enabled && !isCollapsed && !isAiStreaming && !isPreview && (
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='icon'
|
||||
@@ -486,7 +486,7 @@ export function Code({
|
||||
outline: 'none',
|
||||
}}
|
||||
className={cn(
|
||||
'code-editor-area caret-primary',
|
||||
'code-editor-area caret-primary dark:caret-white',
|
||||
'bg-transparent focus:outline-none',
|
||||
(isCollapsed || isAiStreaming) && 'cursor-not-allowed opacity-50'
|
||||
)}
|
||||
|
||||
@@ -6,7 +6,7 @@ import { Button } from '@/components/ui/button'
|
||||
import { formatDisplayText } from '@/components/ui/formatted-text'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { checkTagTrigger, TagDropdown } from '@/components/ui/tag-dropdown'
|
||||
import { MAX_TAG_SLOTS } from '@/lib/constants/knowledge'
|
||||
import { MAX_TAG_SLOTS } from '@/lib/knowledge/consts'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { useSubBlockValue } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-block/components/sub-block/hooks/use-sub-block-value'
|
||||
import type { SubBlockConfig } from '@/blocks/types'
|
||||
|
||||
@@ -235,7 +235,7 @@ export function FileUpload({
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error uploading ${file.name}:`, error)
|
||||
logger.error(`Error uploading ${file.name}:`, error)
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
|
||||
uploadErrors.push(`${file.name}: ${errorMessage}`)
|
||||
}
|
||||
@@ -428,7 +428,7 @@ export function FileUpload({
|
||||
deletionResults.failures.push(`${file.name}: ${errorMessage}`)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Failed to delete file ${file.name}:`, error)
|
||||
logger.error(`Failed to delete file ${file.name}:`, error)
|
||||
deletionResults.failures.push(
|
||||
`${file.name}: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { Wand2 } from 'lucide-react'
|
||||
import { useReactFlow } from 'reactflow'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { checkEnvVarTrigger, EnvVarDropdown } from '@/components/ui/env-var-dropdown'
|
||||
import { formatDisplayText } from '@/components/ui/formatted-text'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { checkTagTrigger, TagDropdown } from '@/components/ui/tag-dropdown'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { WandPromptBar } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/wand-prompt-bar/wand-prompt-bar'
|
||||
import { useSubBlockValue } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-block/components/sub-block/hooks/use-sub-block-value'
|
||||
import { useWand } from '@/app/workspace/[workspaceId]/w/[workflowId]/hooks/use-wand'
|
||||
import type { SubBlockConfig } from '@/blocks/types'
|
||||
import { useTagSelection } from '@/hooks/use-tag-selection'
|
||||
import { useOperationQueueStore } from '@/stores/operation-queue/store'
|
||||
@@ -40,19 +44,39 @@ export function ShortInput({
|
||||
previewValue,
|
||||
disabled = false,
|
||||
}: ShortInputProps) {
|
||||
// Local state for immediate UI updates during streaming
|
||||
const [localContent, setLocalContent] = useState<string>('')
|
||||
const [isFocused, setIsFocused] = useState(false)
|
||||
const [showEnvVars, setShowEnvVars] = useState(false)
|
||||
const [showTags, setShowTags] = useState(false)
|
||||
const validatePropValue = (value: any): string => {
|
||||
if (value === undefined || value === null) return ''
|
||||
if (typeof value === 'string') return value
|
||||
try {
|
||||
return String(value)
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
const [storeValue, setStoreValue] = useSubBlockValue(blockId, subBlockId)
|
||||
|
||||
// Wand functionality (only if wandConfig is enabled)
|
||||
const wandHook = config.wandConfig?.enabled
|
||||
? useWand({
|
||||
wandConfig: config.wandConfig,
|
||||
currentValue: localContent,
|
||||
onStreamStart: () => {
|
||||
// Clear the content when streaming starts
|
||||
setLocalContent('')
|
||||
},
|
||||
onStreamChunk: (chunk) => {
|
||||
// Update local content with each chunk as it arrives
|
||||
setLocalContent((current) => current + chunk)
|
||||
},
|
||||
onGeneratedContent: (content) => {
|
||||
// Final content update
|
||||
setLocalContent(content)
|
||||
},
|
||||
})
|
||||
: null
|
||||
// State management - useSubBlockValue with explicit streaming control
|
||||
const [storeValue, setStoreValue] = useSubBlockValue(blockId, subBlockId, false, {
|
||||
isStreaming: wandHook?.isStreaming || false,
|
||||
onStreamingEnd: () => {
|
||||
logger.debug('Wand streaming ended, value persisted', { blockId, subBlockId })
|
||||
},
|
||||
})
|
||||
|
||||
const [searchTerm, setSearchTerm] = useState('')
|
||||
const [cursorPosition, setCursorPosition] = useState(0)
|
||||
const inputRef = useRef<HTMLInputElement>(null)
|
||||
@@ -65,7 +89,29 @@ export function ShortInput({
|
||||
const reactFlowInstance = useReactFlow()
|
||||
|
||||
// Use preview value when in preview mode, otherwise use store value or prop value
|
||||
const value = isPreview ? previewValue : propValue !== undefined ? propValue : storeValue
|
||||
const baseValue = isPreview ? previewValue : propValue !== undefined ? propValue : storeValue
|
||||
|
||||
// During streaming, use local content; otherwise use base value
|
||||
const value = wandHook?.isStreaming ? localContent : baseValue
|
||||
|
||||
// Sync local content with base value when not streaming
|
||||
useEffect(() => {
|
||||
if (!wandHook?.isStreaming) {
|
||||
const baseValueString = baseValue?.toString() ?? ''
|
||||
if (baseValueString !== localContent) {
|
||||
setLocalContent(baseValueString)
|
||||
}
|
||||
}
|
||||
}, [baseValue, wandHook?.isStreaming])
|
||||
|
||||
// Update store value during streaming (but won't persist until streaming ends)
|
||||
useEffect(() => {
|
||||
if (wandHook?.isStreaming && localContent !== '') {
|
||||
if (!isPreview && !disabled) {
|
||||
setStoreValue(localContent)
|
||||
}
|
||||
}
|
||||
}, [localContent, wandHook?.isStreaming, isPreview, disabled, setStoreValue])
|
||||
|
||||
// Check if this input is API key related
|
||||
const isApiKeyField = useMemo(() => {
|
||||
@@ -297,91 +343,130 @@ export function ShortInput({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className='relative w-full'>
|
||||
<Input
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
'allow-scroll w-full overflow-auto text-transparent caret-foreground placeholder:text-muted-foreground/50',
|
||||
isConnecting &&
|
||||
config?.connectionDroppable !== false &&
|
||||
'ring-2 ring-blue-500 ring-offset-2 focus-visible:ring-blue-500'
|
||||
)}
|
||||
placeholder={placeholder ?? ''}
|
||||
type='text'
|
||||
value={displayValue}
|
||||
onChange={handleChange}
|
||||
onFocus={() => {
|
||||
setIsFocused(true)
|
||||
<>
|
||||
<WandPromptBar
|
||||
isVisible={wandHook?.isPromptVisible || false}
|
||||
isLoading={wandHook?.isLoading || false}
|
||||
isStreaming={wandHook?.isStreaming || false}
|
||||
promptValue={wandHook?.promptInputValue || ''}
|
||||
onSubmit={(prompt: string) => wandHook?.generateStream({ prompt }) || undefined}
|
||||
onCancel={
|
||||
wandHook?.isStreaming
|
||||
? wandHook?.cancelGeneration
|
||||
: wandHook?.hidePromptInline || (() => {})
|
||||
}
|
||||
onChange={(value: string) => wandHook?.updatePromptValue?.(value)}
|
||||
placeholder={config.wandConfig?.placeholder || 'Describe what you want to generate...'}
|
||||
/>
|
||||
|
||||
// If this is an API key field, automatically show env vars dropdown
|
||||
if (isApiKeyField) {
|
||||
setShowEnvVars(true)
|
||||
setSearchTerm('')
|
||||
<div className='group relative w-full'>
|
||||
<Input
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
'allow-scroll w-full overflow-auto text-transparent caret-foreground placeholder:text-muted-foreground/50',
|
||||
isConnecting &&
|
||||
config?.connectionDroppable !== false &&
|
||||
'ring-2 ring-blue-500 ring-offset-2 focus-visible:ring-blue-500'
|
||||
)}
|
||||
placeholder={placeholder ?? ''}
|
||||
type='text'
|
||||
value={displayValue}
|
||||
onChange={handleChange}
|
||||
onFocus={() => {
|
||||
setIsFocused(true)
|
||||
|
||||
// Set cursor position to the end of the input
|
||||
const inputLength = value?.toString().length ?? 0
|
||||
setCursorPosition(inputLength)
|
||||
} else {
|
||||
// If this is an API key field, automatically show env vars dropdown
|
||||
if (isApiKeyField) {
|
||||
setShowEnvVars(true)
|
||||
setSearchTerm('')
|
||||
|
||||
// Set cursor position to the end of the input
|
||||
const inputLength = value?.toString().length ?? 0
|
||||
setCursorPosition(inputLength)
|
||||
} else {
|
||||
setShowEnvVars(false)
|
||||
setShowTags(false)
|
||||
setSearchTerm('')
|
||||
}
|
||||
}}
|
||||
onBlur={() => {
|
||||
setIsFocused(false)
|
||||
setShowEnvVars(false)
|
||||
setShowTags(false)
|
||||
setSearchTerm('')
|
||||
}
|
||||
}}
|
||||
onBlur={() => {
|
||||
setIsFocused(false)
|
||||
setShowEnvVars(false)
|
||||
try {
|
||||
useOperationQueueStore.getState().flushDebouncedForBlock(blockId)
|
||||
} catch {}
|
||||
}}
|
||||
onDrop={handleDrop}
|
||||
onDragOver={handleDragOver}
|
||||
onScroll={handleScroll}
|
||||
onPaste={handlePaste}
|
||||
onWheel={handleWheel}
|
||||
onKeyDown={handleKeyDown}
|
||||
autoComplete='off'
|
||||
style={{ overflowX: 'auto' }}
|
||||
disabled={disabled}
|
||||
/>
|
||||
<div
|
||||
ref={overlayRef}
|
||||
className='pointer-events-none absolute inset-0 flex items-center overflow-x-auto bg-transparent px-3 text-sm'
|
||||
style={{ overflowX: 'auto' }}
|
||||
>
|
||||
try {
|
||||
useOperationQueueStore.getState().flushDebouncedForBlock(blockId)
|
||||
} catch {}
|
||||
}}
|
||||
onDrop={handleDrop}
|
||||
onDragOver={handleDragOver}
|
||||
onScroll={handleScroll}
|
||||
onPaste={handlePaste}
|
||||
onWheel={handleWheel}
|
||||
onKeyDown={handleKeyDown}
|
||||
autoComplete='off'
|
||||
style={{ overflowX: 'auto' }}
|
||||
disabled={disabled}
|
||||
/>
|
||||
<div
|
||||
className='w-full whitespace-pre'
|
||||
style={{ scrollbarWidth: 'none', minWidth: 'fit-content' }}
|
||||
ref={overlayRef}
|
||||
className='pointer-events-none absolute inset-0 flex items-center overflow-x-auto bg-transparent px-3 text-sm'
|
||||
style={{ overflowX: 'auto' }}
|
||||
>
|
||||
{password && !isFocused
|
||||
? '•'.repeat(value?.toString().length ?? 0)
|
||||
: formatDisplayText(value?.toString() ?? '', true)}
|
||||
<div
|
||||
className='w-full whitespace-pre'
|
||||
style={{ scrollbarWidth: 'none', minWidth: 'fit-content' }}
|
||||
>
|
||||
{password && !isFocused
|
||||
? '•'.repeat(value?.toString().length ?? 0)
|
||||
: formatDisplayText(value?.toString() ?? '', true)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<EnvVarDropdown
|
||||
visible={showEnvVars}
|
||||
onSelect={handleEnvVarSelect}
|
||||
searchTerm={searchTerm}
|
||||
inputValue={value?.toString() ?? ''}
|
||||
cursorPosition={cursorPosition}
|
||||
onClose={() => {
|
||||
setShowEnvVars(false)
|
||||
setSearchTerm('')
|
||||
}}
|
||||
/>
|
||||
<TagDropdown
|
||||
visible={showTags}
|
||||
onSelect={handleEnvVarSelect}
|
||||
blockId={blockId}
|
||||
activeSourceBlockId={activeSourceBlockId}
|
||||
inputValue={value?.toString() ?? ''}
|
||||
cursorPosition={cursorPosition}
|
||||
onClose={() => {
|
||||
setShowTags(false)
|
||||
setActiveSourceBlockId(null)
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{/* Wand Button */}
|
||||
{wandHook && !isPreview && !wandHook.isStreaming && (
|
||||
<div className='-translate-y-1/2 absolute top-1/2 right-3 z-10 flex items-center gap-1 opacity-0 transition-opacity group-hover:opacity-100'>
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='icon'
|
||||
onClick={
|
||||
wandHook.isPromptVisible ? wandHook.hidePromptInline : wandHook.showPromptInline
|
||||
}
|
||||
disabled={wandHook.isLoading || wandHook.isStreaming || disabled}
|
||||
aria-label='Generate content with AI'
|
||||
className='h-8 w-8 rounded-full border border-transparent bg-muted/80 text-muted-foreground shadow-sm transition-all duration-200 hover:border-primary/20 hover:bg-muted hover:text-primary hover:shadow'
|
||||
>
|
||||
<Wand2 className='h-4 w-4' />
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!wandHook?.isStreaming && (
|
||||
<>
|
||||
<EnvVarDropdown
|
||||
visible={showEnvVars}
|
||||
onSelect={handleEnvVarSelect}
|
||||
searchTerm={searchTerm}
|
||||
inputValue={value?.toString() ?? ''}
|
||||
cursorPosition={cursorPosition}
|
||||
onClose={() => {
|
||||
setShowEnvVars(false)
|
||||
setSearchTerm('')
|
||||
}}
|
||||
/>
|
||||
<TagDropdown
|
||||
visible={showTags}
|
||||
onSelect={handleEnvVarSelect}
|
||||
blockId={blockId}
|
||||
activeSourceBlockId={activeSourceBlockId}
|
||||
inputValue={value?.toString() ?? ''}
|
||||
cursorPosition={cursorPosition}
|
||||
onClose={() => {
|
||||
setShowTags(false)
|
||||
setActiveSourceBlockId(null)
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -483,7 +483,7 @@ export function ToolInput({
|
||||
try {
|
||||
return block.tools.config.tool({ operation })
|
||||
} catch (error) {
|
||||
console.error('Error selecting tool for operation:', error)
|
||||
logger.error('Error selecting tool for operation:', error)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -486,10 +486,15 @@ export function SubBlock({
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
{config.id === 'responseFormat' && !isValidJson && (
|
||||
{config.id === 'responseFormat' && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<AlertTriangle className='h-4 w-4 cursor-pointer text-destructive' />
|
||||
<AlertTriangle
|
||||
className={cn(
|
||||
'h-4 w-4 cursor-pointer text-destructive',
|
||||
!isValidJson ? 'opacity-100' : 'opacity-0'
|
||||
)}
|
||||
/>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side='top'>
|
||||
<p>Invalid JSON</p>
|
||||
|
||||
@@ -6,6 +6,7 @@ import { Badge } from '@/components/ui/badge'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Card } from '@/components/ui/card'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { parseCronToHumanReadable } from '@/lib/schedules/utils'
|
||||
import { cn, validateName } from '@/lib/utils'
|
||||
import { type DiffStatus, hasDiffStatus } from '@/lib/workflows/diff/types'
|
||||
@@ -23,6 +24,8 @@ import { ActionBar } from './components/action-bar/action-bar'
|
||||
import { ConnectionBlocks } from './components/connection-blocks/connection-blocks'
|
||||
import { SubBlock } from './components/sub-block/sub-block'
|
||||
|
||||
const logger = createLogger('WorkflowBlock')
|
||||
|
||||
interface WorkflowBlockProps {
|
||||
type: string
|
||||
config: BlockConfig
|
||||
@@ -232,10 +235,10 @@ export function WorkflowBlock({ id, data }: NodeProps<WorkflowBlockProps>) {
|
||||
fetchScheduleInfo(currentWorkflowId)
|
||||
}
|
||||
} else {
|
||||
console.error('Failed to reactivate schedule')
|
||||
logger.error('Failed to reactivate schedule')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error reactivating schedule:', error)
|
||||
logger.error('Error reactivating schedule:', error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,10 +258,10 @@ export function WorkflowBlock({ id, data }: NodeProps<WorkflowBlockProps>) {
|
||||
fetchScheduleInfo(currentWorkflowId)
|
||||
}
|
||||
} else {
|
||||
console.error('Failed to disable schedule')
|
||||
logger.error('Failed to disable schedule')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error disabling schedule:', error)
|
||||
logger.error('Error disabling schedule:', error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,12 +331,12 @@ export function WorkflowBlock({ id, data }: NodeProps<WorkflowBlockProps>) {
|
||||
return
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Error fetching schedule status:', err)
|
||||
logger.error('Error fetching schedule status:', err)
|
||||
}
|
||||
|
||||
setScheduleInfo(baseInfo)
|
||||
} catch (error) {
|
||||
console.error('Error fetching schedule info:', error)
|
||||
logger.error('Error fetching schedule info:', error)
|
||||
setScheduleInfo(null)
|
||||
} finally {
|
||||
setIsLoadingScheduleInfo(false)
|
||||
|
||||
@@ -26,7 +26,7 @@ import {
|
||||
AlertDialogTitle,
|
||||
} from '@/components/ui/alert-dialog'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { MAX_TAG_SLOTS } from '@/lib/constants/knowledge'
|
||||
import { MAX_TAG_SLOTS } from '@/lib/knowledge/consts'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getDocumentIcon } from '@/app/workspace/[workspaceId]/knowledge/components/icons/document-icons'
|
||||
import { useUserPermissionsContext } from '@/app/workspace/[workspaceId]/providers/workspace-permissions-provider'
|
||||
|
||||
@@ -17,7 +17,7 @@ import {
|
||||
SelectValue,
|
||||
} from '@/components/ui'
|
||||
import { ScrollArea } from '@/components/ui/scroll-area'
|
||||
import { MAX_TAG_SLOTS, TAG_SLOTS, type TagSlot } from '@/lib/constants/knowledge'
|
||||
import { MAX_TAG_SLOTS, TAG_SLOTS, type TagSlot } from '@/lib/knowledge/consts'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { DocumentTag } from '@/app/workspace/[workspaceId]/knowledge/components/document-tag-entry/document-tag-entry'
|
||||
import { useUserPermissionsContext } from '@/app/workspace/[workspaceId]/providers/workspace-permissions-provider'
|
||||
|
||||
@@ -15,9 +15,12 @@ import {
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Skeleton } from '@/components/ui/skeleton'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { useEnvironmentStore } from '@/stores/settings/environment/store'
|
||||
import type { EnvironmentVariable as StoreEnvironmentVariable } from '@/stores/settings/environment/types'
|
||||
|
||||
const logger = createLogger('EnvironmentVariables')
|
||||
|
||||
// Constants
|
||||
const GRID_COLS = 'grid grid-cols-[minmax(0,1fr),minmax(0,1fr),40px] gap-4'
|
||||
const INITIAL_ENV_VAR: UIEnvironmentVariable = { key: '', value: '' }
|
||||
@@ -263,7 +266,7 @@ export function EnvironmentVariables({
|
||||
// Single store update that triggers sync
|
||||
useEnvironmentStore.getState().setVariables(validVariables)
|
||||
} catch (error) {
|
||||
console.error('Failed to save environment variables:', error)
|
||||
logger.error('Failed to save environment variables:', error)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -197,10 +197,10 @@ export function Subscription({ onOpenChange }: SubscriptionProps) {
|
||||
const activeOrgId = activeOrganization?.id
|
||||
|
||||
useEffect(() => {
|
||||
if (subscription.isTeam && activeOrgId) {
|
||||
if ((subscription.isTeam || subscription.isEnterprise) && activeOrgId) {
|
||||
loadOrganizationBillingData(activeOrgId)
|
||||
}
|
||||
}, [activeOrgId, subscription.isTeam, loadOrganizationBillingData])
|
||||
}, [activeOrgId, subscription.isTeam, subscription.isEnterprise, loadOrganizationBillingData])
|
||||
|
||||
// Auto-clear upgrade error
|
||||
useEffect(() => {
|
||||
@@ -349,22 +349,39 @@ export function Subscription({ onOpenChange }: SubscriptionProps) {
|
||||
badgeText={badgeText}
|
||||
onBadgeClick={handleBadgeClick}
|
||||
seatsText={
|
||||
permissions.canManageTeam
|
||||
permissions.canManageTeam || subscription.isEnterprise
|
||||
? `${organizationBillingData?.totalSeats || subscription.seats || 1} seats`
|
||||
: undefined
|
||||
}
|
||||
current={usage.current}
|
||||
current={
|
||||
subscription.isEnterprise || subscription.isTeam
|
||||
? organizationBillingData?.totalCurrentUsage || 0
|
||||
: usage.current
|
||||
}
|
||||
limit={
|
||||
!subscription.isFree &&
|
||||
(permissions.canEditUsageLimit ||
|
||||
permissions.showTeamMemberView ||
|
||||
subscription.isEnterprise)
|
||||
? usage.current // placeholder; rightContent will render UsageLimit
|
||||
: usage.limit
|
||||
subscription.isEnterprise || subscription.isTeam
|
||||
? organizationBillingData?.totalUsageLimit ||
|
||||
organizationBillingData?.minimumBillingAmount ||
|
||||
0
|
||||
: !subscription.isFree &&
|
||||
(permissions.canEditUsageLimit || permissions.showTeamMemberView)
|
||||
? usage.current // placeholder; rightContent will render UsageLimit
|
||||
: usage.limit
|
||||
}
|
||||
isBlocked={Boolean(subscriptionData?.billingBlocked)}
|
||||
status={billingStatus === 'unknown' ? 'ok' : billingStatus}
|
||||
percentUsed={Math.round(usage.percentUsed)}
|
||||
percentUsed={
|
||||
subscription.isEnterprise || subscription.isTeam
|
||||
? organizationBillingData?.totalUsageLimit &&
|
||||
organizationBillingData.totalUsageLimit > 0
|
||||
? Math.round(
|
||||
(organizationBillingData.totalCurrentUsage /
|
||||
organizationBillingData.totalUsageLimit) *
|
||||
100
|
||||
)
|
||||
: 0
|
||||
: Math.round(usage.percentUsed)
|
||||
}
|
||||
onResolvePayment={async () => {
|
||||
try {
|
||||
const res = await fetch('/api/billing/portal', {
|
||||
@@ -387,9 +404,7 @@ export function Subscription({ onOpenChange }: SubscriptionProps) {
|
||||
}}
|
||||
rightContent={
|
||||
!subscription.isFree &&
|
||||
(permissions.canEditUsageLimit ||
|
||||
permissions.showTeamMemberView ||
|
||||
subscription.isEnterprise) ? (
|
||||
(permissions.canEditUsageLimit || permissions.showTeamMemberView) ? (
|
||||
<UsageLimit
|
||||
ref={usageLimitRef}
|
||||
currentLimit={
|
||||
@@ -398,7 +413,7 @@ export function Subscription({ onOpenChange }: SubscriptionProps) {
|
||||
: usageLimitData?.currentLimit || usage.limit
|
||||
}
|
||||
currentUsage={usage.current}
|
||||
canEdit={permissions.canEditUsageLimit && !subscription.isEnterprise}
|
||||
canEdit={permissions.canEditUsageLimit}
|
||||
minimumLimit={
|
||||
subscription.isTeam && isTeamAdmin
|
||||
? organizationBillingData?.minimumBillingAmount ||
|
||||
|
||||
@@ -1007,8 +1007,11 @@ export function Sidebar() {
|
||||
>
|
||||
<UsageIndicator
|
||||
onClick={() => {
|
||||
const isBlocked = useSubscriptionStore.getState().getBillingStatus() === 'blocked'
|
||||
if (isBlocked) {
|
||||
const subscriptionStore = useSubscriptionStore.getState()
|
||||
const isBlocked = subscriptionStore.getBillingStatus() === 'blocked'
|
||||
const canUpgrade = subscriptionStore.canUpgrade()
|
||||
|
||||
if (isBlocked || !canUpgrade) {
|
||||
if (typeof window !== 'undefined') {
|
||||
window.dispatchEvent(
|
||||
new CustomEvent('open-settings', { detail: { tab: 'subscription' } })
|
||||
@@ -1036,6 +1039,7 @@ export function Sidebar() {
|
||||
<HelpModal open={showHelp} onOpenChange={setShowHelp} />
|
||||
<InviteModal open={showInviteMembers} onOpenChange={setShowInviteMembers} />
|
||||
<SubscriptionModal open={showSubscriptionModal} onOpenChange={setShowSubscriptionModal} />
|
||||
|
||||
<SearchModal
|
||||
open={showSearchModal}
|
||||
onOpenChange={setShowSearchModal}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { task } from '@trigger.dev/sdk'
|
||||
import { env } from '@/lib/env'
|
||||
import { processDocumentAsync } from '@/lib/knowledge/documents/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
@@ -25,15 +26,15 @@ export type DocumentProcessingPayload = {
|
||||
|
||||
export const processDocument = task({
|
||||
id: 'knowledge-process-document',
|
||||
maxDuration: 300,
|
||||
maxDuration: env.KB_CONFIG_MAX_DURATION || 300,
|
||||
retry: {
|
||||
maxAttempts: 3,
|
||||
factor: 2,
|
||||
minTimeoutInMs: 1000,
|
||||
maxTimeoutInMs: 10000,
|
||||
maxAttempts: env.KB_CONFIG_MAX_ATTEMPTS || 3,
|
||||
factor: env.KB_CONFIG_RETRY_FACTOR || 2,
|
||||
minTimeoutInMs: env.KB_CONFIG_MIN_TIMEOUT || 1000,
|
||||
maxTimeoutInMs: env.KB_CONFIG_MAX_TIMEOUT || 10000,
|
||||
},
|
||||
queue: {
|
||||
concurrencyLimit: 20,
|
||||
concurrencyLimit: env.KB_CONFIG_CONCURRENCY_LIMIT || 20,
|
||||
name: 'document-processing-queue',
|
||||
},
|
||||
run: async (payload: DocumentProcessingPayload) => {
|
||||
|
||||
@@ -118,6 +118,72 @@ export const MySQLBlock: BlockConfig<MySQLResponse> = {
|
||||
placeholder: 'SELECT * FROM users WHERE active = true',
|
||||
condition: { field: 'operation', value: 'query' },
|
||||
required: true,
|
||||
wandConfig: {
|
||||
enabled: true,
|
||||
maintainHistory: true,
|
||||
prompt: `You are an expert MySQL database developer. Write MySQL SQL queries based on the user's request.
|
||||
|
||||
### CONTEXT
|
||||
{context}
|
||||
|
||||
### CRITICAL INSTRUCTION
|
||||
Return ONLY the SQL query. Do not include any explanations, markdown formatting, comments, or additional text. Just the raw SQL query.
|
||||
|
||||
### QUERY GUIDELINES
|
||||
1. **Syntax**: Use MySQL-specific syntax and functions
|
||||
2. **Performance**: Write efficient queries with proper indexing considerations
|
||||
3. **Security**: Use parameterized queries when applicable
|
||||
4. **Readability**: Format queries with proper indentation and spacing
|
||||
5. **Best Practices**: Follow MySQL naming conventions
|
||||
|
||||
### MYSQL FEATURES
|
||||
- Use MySQL-specific functions (IFNULL, DATE_FORMAT, CONCAT, etc.)
|
||||
- Leverage MySQL features like GROUP_CONCAT, AUTO_INCREMENT
|
||||
- Use proper MySQL data types (VARCHAR, DATETIME, DECIMAL, JSON, etc.)
|
||||
- Include appropriate LIMIT clauses for large result sets
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
**Simple Select**: "Get all active users"
|
||||
→ SELECT id, name, email, created_at
|
||||
FROM users
|
||||
WHERE active = 1
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
**Complex Join**: "Get users with their order counts and total spent"
|
||||
→ SELECT
|
||||
u.id,
|
||||
u.name,
|
||||
u.email,
|
||||
COUNT(o.id) as order_count,
|
||||
IFNULL(SUM(o.total), 0) as total_spent
|
||||
FROM users u
|
||||
LEFT JOIN orders o ON u.id = o.user_id
|
||||
WHERE u.active = 1
|
||||
GROUP BY u.id, u.name, u.email
|
||||
HAVING COUNT(o.id) > 0
|
||||
ORDER BY total_spent DESC;
|
||||
|
||||
**With Subquery**: "Get top 10 products by sales"
|
||||
→ SELECT
|
||||
p.id,
|
||||
p.name,
|
||||
(SELECT SUM(oi.quantity * oi.price)
|
||||
FROM order_items oi
|
||||
JOIN orders o ON oi.order_id = o.id
|
||||
WHERE oi.product_id = p.id
|
||||
AND o.created_at >= DATE_SUB(NOW(), INTERVAL 30 DAY)
|
||||
) as total_sales
|
||||
FROM products p
|
||||
WHERE p.active = 1
|
||||
ORDER BY total_sales DESC
|
||||
LIMIT 10;
|
||||
|
||||
### REMEMBER
|
||||
Return ONLY the SQL query - no explanations, no markdown, no extra text.`,
|
||||
placeholder: 'Describe the SQL query you need...',
|
||||
generationType: 'sql-query',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'query',
|
||||
@@ -127,6 +193,72 @@ export const MySQLBlock: BlockConfig<MySQLResponse> = {
|
||||
placeholder: 'SELECT * FROM table_name',
|
||||
condition: { field: 'operation', value: 'execute' },
|
||||
required: true,
|
||||
wandConfig: {
|
||||
enabled: true,
|
||||
maintainHistory: true,
|
||||
prompt: `You are an expert MySQL database developer. Write MySQL SQL queries based on the user's request.
|
||||
|
||||
### CONTEXT
|
||||
{context}
|
||||
|
||||
### CRITICAL INSTRUCTION
|
||||
Return ONLY the SQL query. Do not include any explanations, markdown formatting, comments, or additional text. Just the raw SQL query.
|
||||
|
||||
### QUERY GUIDELINES
|
||||
1. **Syntax**: Use MySQL-specific syntax and functions
|
||||
2. **Performance**: Write efficient queries with proper indexing considerations
|
||||
3. **Security**: Use parameterized queries when applicable
|
||||
4. **Readability**: Format queries with proper indentation and spacing
|
||||
5. **Best Practices**: Follow MySQL naming conventions
|
||||
|
||||
### MYSQL FEATURES
|
||||
- Use MySQL-specific functions (IFNULL, DATE_FORMAT, CONCAT, etc.)
|
||||
- Leverage MySQL features like GROUP_CONCAT, AUTO_INCREMENT
|
||||
- Use proper MySQL data types (VARCHAR, DATETIME, DECIMAL, JSON, etc.)
|
||||
- Include appropriate LIMIT clauses for large result sets
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
**Simple Select**: "Get all active users"
|
||||
→ SELECT id, name, email, created_at
|
||||
FROM users
|
||||
WHERE active = 1
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
**Complex Join**: "Get users with their order counts and total spent"
|
||||
→ SELECT
|
||||
u.id,
|
||||
u.name,
|
||||
u.email,
|
||||
COUNT(o.id) as order_count,
|
||||
IFNULL(SUM(o.total), 0) as total_spent
|
||||
FROM users u
|
||||
LEFT JOIN orders o ON u.id = o.user_id
|
||||
WHERE u.active = 1
|
||||
GROUP BY u.id, u.name, u.email
|
||||
HAVING COUNT(o.id) > 0
|
||||
ORDER BY total_spent DESC;
|
||||
|
||||
**With Subquery**: "Get top 10 products by sales"
|
||||
→ SELECT
|
||||
p.id,
|
||||
p.name,
|
||||
(SELECT SUM(oi.quantity * oi.price)
|
||||
FROM order_items oi
|
||||
JOIN orders o ON oi.order_id = o.id
|
||||
WHERE oi.product_id = p.id
|
||||
AND o.created_at >= DATE_SUB(NOW(), INTERVAL 30 DAY)
|
||||
) as total_sales
|
||||
FROM products p
|
||||
WHERE p.active = 1
|
||||
ORDER BY total_sales DESC
|
||||
LIMIT 10;
|
||||
|
||||
### REMEMBER
|
||||
Return ONLY the SQL query - no explanations, no markdown, no extra text.`,
|
||||
placeholder: 'Describe the SQL query you need...',
|
||||
generationType: 'sql-query',
|
||||
},
|
||||
},
|
||||
// Data for insert operations
|
||||
{
|
||||
|
||||
@@ -118,6 +118,73 @@ export const PostgreSQLBlock: BlockConfig<PostgresResponse> = {
|
||||
placeholder: 'SELECT * FROM users WHERE active = true',
|
||||
condition: { field: 'operation', value: 'query' },
|
||||
required: true,
|
||||
wandConfig: {
|
||||
enabled: true,
|
||||
maintainHistory: true,
|
||||
prompt: `You are an expert PostgreSQL database developer. Write PostgreSQL SQL queries based on the user's request.
|
||||
|
||||
### CONTEXT
|
||||
{context}
|
||||
|
||||
### CRITICAL INSTRUCTION
|
||||
Return ONLY the SQL query. Do not include any explanations, markdown formatting, comments, or additional text. Just the raw SQL query.
|
||||
|
||||
### QUERY GUIDELINES
|
||||
1. **Syntax**: Use PostgreSQL-specific syntax and functions
|
||||
2. **Performance**: Write efficient queries with proper indexing considerations
|
||||
3. **Security**: Use parameterized queries when applicable
|
||||
4. **Readability**: Format queries with proper indentation and spacing
|
||||
5. **Best Practices**: Follow PostgreSQL naming conventions
|
||||
|
||||
### POSTGRESQL FEATURES
|
||||
- Use PostgreSQL-specific functions (COALESCE, EXTRACT, etc.)
|
||||
- Leverage advanced features like CTEs, window functions, arrays
|
||||
- Use proper PostgreSQL data types (TEXT, TIMESTAMPTZ, JSONB, etc.)
|
||||
- Include appropriate LIMIT clauses for large result sets
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
**Simple Select**: "Get all active users"
|
||||
→ SELECT id, name, email, created_at
|
||||
FROM users
|
||||
WHERE active = true
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
**Complex Join**: "Get users with their order counts and total spent"
|
||||
→ SELECT
|
||||
u.id,
|
||||
u.name,
|
||||
u.email,
|
||||
COUNT(o.id) as order_count,
|
||||
COALESCE(SUM(o.total), 0) as total_spent
|
||||
FROM users u
|
||||
LEFT JOIN orders o ON u.id = o.user_id
|
||||
WHERE u.active = true
|
||||
GROUP BY u.id, u.name, u.email
|
||||
HAVING COUNT(o.id) > 0
|
||||
ORDER BY total_spent DESC;
|
||||
|
||||
**With CTE**: "Get top 10 products by sales"
|
||||
→ WITH product_sales AS (
|
||||
SELECT
|
||||
p.id,
|
||||
p.name,
|
||||
SUM(oi.quantity * oi.price) as total_sales
|
||||
FROM products p
|
||||
JOIN order_items oi ON p.id = oi.product_id
|
||||
JOIN orders o ON oi.order_id = o.id
|
||||
WHERE o.created_at >= CURRENT_DATE - INTERVAL '30 days'
|
||||
GROUP BY p.id, p.name
|
||||
)
|
||||
SELECT * FROM product_sales
|
||||
ORDER BY total_sales DESC
|
||||
LIMIT 10;
|
||||
|
||||
### REMEMBER
|
||||
Return ONLY the SQL query - no explanations, no markdown, no extra text.`,
|
||||
placeholder: 'Describe the SQL query you need...',
|
||||
generationType: 'sql-query',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'query',
|
||||
@@ -127,6 +194,73 @@ export const PostgreSQLBlock: BlockConfig<PostgresResponse> = {
|
||||
placeholder: 'SELECT * FROM table_name',
|
||||
condition: { field: 'operation', value: 'execute' },
|
||||
required: true,
|
||||
wandConfig: {
|
||||
enabled: true,
|
||||
maintainHistory: true,
|
||||
prompt: `You are an expert PostgreSQL database developer. Write PostgreSQL SQL queries based on the user's request.
|
||||
|
||||
### CONTEXT
|
||||
{context}
|
||||
|
||||
### CRITICAL INSTRUCTION
|
||||
Return ONLY the SQL query. Do not include any explanations, markdown formatting, comments, or additional text. Just the raw SQL query.
|
||||
|
||||
### QUERY GUIDELINES
|
||||
1. **Syntax**: Use PostgreSQL-specific syntax and functions
|
||||
2. **Performance**: Write efficient queries with proper indexing considerations
|
||||
3. **Security**: Use parameterized queries when applicable
|
||||
4. **Readability**: Format queries with proper indentation and spacing
|
||||
5. **Best Practices**: Follow PostgreSQL naming conventions
|
||||
|
||||
### POSTGRESQL FEATURES
|
||||
- Use PostgreSQL-specific functions (COALESCE, EXTRACT, etc.)
|
||||
- Leverage advanced features like CTEs, window functions, arrays
|
||||
- Use proper PostgreSQL data types (TEXT, TIMESTAMPTZ, JSONB, etc.)
|
||||
- Include appropriate LIMIT clauses for large result sets
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
**Simple Select**: "Get all active users"
|
||||
→ SELECT id, name, email, created_at
|
||||
FROM users
|
||||
WHERE active = true
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
**Complex Join**: "Get users with their order counts and total spent"
|
||||
→ SELECT
|
||||
u.id,
|
||||
u.name,
|
||||
u.email,
|
||||
COUNT(o.id) as order_count,
|
||||
COALESCE(SUM(o.total), 0) as total_spent
|
||||
FROM users u
|
||||
LEFT JOIN orders o ON u.id = o.user_id
|
||||
WHERE u.active = true
|
||||
GROUP BY u.id, u.name, u.email
|
||||
HAVING COUNT(o.id) > 0
|
||||
ORDER BY total_spent DESC;
|
||||
|
||||
**With CTE**: "Get top 10 products by sales"
|
||||
→ WITH product_sales AS (
|
||||
SELECT
|
||||
p.id,
|
||||
p.name,
|
||||
SUM(oi.quantity * oi.price) as total_sales
|
||||
FROM products p
|
||||
JOIN order_items oi ON p.id = oi.product_id
|
||||
JOIN orders o ON oi.order_id = o.id
|
||||
WHERE o.created_at >= CURRENT_DATE - INTERVAL '30 days'
|
||||
GROUP BY p.id, p.name
|
||||
)
|
||||
SELECT * FROM product_sales
|
||||
ORDER BY total_sales DESC
|
||||
LIMIT 10;
|
||||
|
||||
### REMEMBER
|
||||
Return ONLY the SQL query - no explanations, no markdown, no extra text.`,
|
||||
placeholder: 'Describe the SQL query you need...',
|
||||
generationType: 'sql-query',
|
||||
},
|
||||
},
|
||||
// Data for insert operations
|
||||
{
|
||||
|
||||
@@ -94,6 +94,66 @@ export const SupabaseBlock: BlockConfig<SupabaseResponse> = {
|
||||
placeholder: 'id=eq.123',
|
||||
condition: { field: 'operation', value: 'get_row' },
|
||||
required: true,
|
||||
wandConfig: {
|
||||
enabled: true,
|
||||
maintainHistory: true,
|
||||
prompt: `You are an expert in PostgREST API syntax. Generate PostgREST filter expressions based on the user's request.
|
||||
|
||||
### CONTEXT
|
||||
{context}
|
||||
|
||||
### CRITICAL INSTRUCTION
|
||||
Return ONLY the PostgREST filter expression. Do not include any explanations, markdown formatting, or additional text. Just the raw filter expression.
|
||||
|
||||
### POSTGREST FILTER SYNTAX
|
||||
PostgREST uses a specific syntax for filtering data. The format is:
|
||||
column=operator.value
|
||||
|
||||
### OPERATORS
|
||||
- **eq** - equals: \`id=eq.123\`
|
||||
- **neq** - not equals: \`status=neq.inactive\`
|
||||
- **gt** - greater than: \`age=gt.18\`
|
||||
- **gte** - greater than or equal: \`score=gte.80\`
|
||||
- **lt** - less than: \`price=lt.100\`
|
||||
- **lte** - less than or equal: \`rating=lte.5\`
|
||||
- **like** - pattern matching: \`name=like.*john*\`
|
||||
- **ilike** - case-insensitive like: \`email=ilike.*@gmail.com\`
|
||||
- **in** - in list: \`category=in.(tech,science,art)\`
|
||||
- **is** - is null/not null: \`deleted_at=is.null\`
|
||||
- **not** - negation: \`not.and=(status.eq.active,verified.eq.true)\`
|
||||
|
||||
### COMBINING FILTERS
|
||||
- **AND**: Use \`&\` or \`and=(...)\`: \`id=eq.123&status=eq.active\`
|
||||
- **OR**: Use \`or=(...)\`: \`or=(status.eq.active,status.eq.pending)\`
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
**Simple equality**: "Find user with ID 123"
|
||||
→ id=eq.123
|
||||
|
||||
**Text search**: "Find users with Gmail addresses"
|
||||
→ email=ilike.*@gmail.com
|
||||
|
||||
**Range filter**: "Find products under $50"
|
||||
→ price=lt.50
|
||||
|
||||
**Multiple conditions**: "Find active users over 18"
|
||||
→ age=gt.18&status=eq.active
|
||||
|
||||
**OR condition**: "Find active or pending orders"
|
||||
→ or=(status.eq.active,status.eq.pending)
|
||||
|
||||
**In list**: "Find posts in specific categories"
|
||||
→ category=in.(tech,science,health)
|
||||
|
||||
**Null check**: "Find users without a profile picture"
|
||||
→ profile_image=is.null
|
||||
|
||||
### REMEMBER
|
||||
Return ONLY the PostgREST filter expression - no explanations, no markdown, no extra text.`,
|
||||
placeholder: 'Describe the filter condition you need...',
|
||||
generationType: 'postgrest',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'filter',
|
||||
@@ -103,6 +163,66 @@ export const SupabaseBlock: BlockConfig<SupabaseResponse> = {
|
||||
placeholder: 'id=eq.123',
|
||||
condition: { field: 'operation', value: 'update' },
|
||||
required: true,
|
||||
wandConfig: {
|
||||
enabled: true,
|
||||
maintainHistory: true,
|
||||
prompt: `You are an expert in PostgREST API syntax. Generate PostgREST filter expressions based on the user's request.
|
||||
|
||||
### CONTEXT
|
||||
{context}
|
||||
|
||||
### CRITICAL INSTRUCTION
|
||||
Return ONLY the PostgREST filter expression. Do not include any explanations, markdown formatting, or additional text. Just the raw filter expression.
|
||||
|
||||
### POSTGREST FILTER SYNTAX
|
||||
PostgREST uses a specific syntax for filtering data. The format is:
|
||||
column=operator.value
|
||||
|
||||
### OPERATORS
|
||||
- **eq** - equals: \`id=eq.123\`
|
||||
- **neq** - not equals: \`status=neq.inactive\`
|
||||
- **gt** - greater than: \`age=gt.18\`
|
||||
- **gte** - greater than or equal: \`score=gte.80\`
|
||||
- **lt** - less than: \`price=lt.100\`
|
||||
- **lte** - less than or equal: \`rating=lte.5\`
|
||||
- **like** - pattern matching: \`name=like.*john*\`
|
||||
- **ilike** - case-insensitive like: \`email=ilike.*@gmail.com\`
|
||||
- **in** - in list: \`category=in.(tech,science,art)\`
|
||||
- **is** - is null/not null: \`deleted_at=is.null\`
|
||||
- **not** - negation: \`not.and=(status.eq.active,verified.eq.true)\`
|
||||
|
||||
### COMBINING FILTERS
|
||||
- **AND**: Use \`&\` or \`and=(...)\`: \`id=eq.123&status=eq.active\`
|
||||
- **OR**: Use \`or=(...)\`: \`or=(status.eq.active,status.eq.pending)\`
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
**Simple equality**: "Find user with ID 123"
|
||||
→ id=eq.123
|
||||
|
||||
**Text search**: "Find users with Gmail addresses"
|
||||
→ email=ilike.*@gmail.com
|
||||
|
||||
**Range filter**: "Find products under $50"
|
||||
→ price=lt.50
|
||||
|
||||
**Multiple conditions**: "Find active users over 18"
|
||||
→ age=gt.18&status=eq.active
|
||||
|
||||
**OR condition**: "Find active or pending orders"
|
||||
→ or=(status.eq.active,status.eq.pending)
|
||||
|
||||
**In list**: "Find posts in specific categories"
|
||||
→ category=in.(tech,science,health)
|
||||
|
||||
**Null check**: "Find users without a profile picture"
|
||||
→ profile_image=is.null
|
||||
|
||||
### REMEMBER
|
||||
Return ONLY the PostgREST filter expression - no explanations, no markdown, no extra text.`,
|
||||
placeholder: 'Describe the filter condition you need...',
|
||||
generationType: 'postgrest',
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'filter',
|
||||
@@ -112,6 +232,66 @@ export const SupabaseBlock: BlockConfig<SupabaseResponse> = {
|
||||
placeholder: 'id=eq.123',
|
||||
condition: { field: 'operation', value: 'delete' },
|
||||
required: true,
|
||||
wandConfig: {
|
||||
enabled: true,
|
||||
maintainHistory: true,
|
||||
prompt: `You are an expert in PostgREST API syntax. Generate PostgREST filter expressions based on the user's request.
|
||||
|
||||
### CONTEXT
|
||||
{context}
|
||||
|
||||
### CRITICAL INSTRUCTION
|
||||
Return ONLY the PostgREST filter expression. Do not include any explanations, markdown formatting, or additional text. Just the raw filter expression.
|
||||
|
||||
### POSTGREST FILTER SYNTAX
|
||||
PostgREST uses a specific syntax for filtering data. The format is:
|
||||
column=operator.value
|
||||
|
||||
### OPERATORS
|
||||
- **eq** - equals: \`id=eq.123\`
|
||||
- **neq** - not equals: \`status=neq.inactive\`
|
||||
- **gt** - greater than: \`age=gt.18\`
|
||||
- **gte** - greater than or equal: \`score=gte.80\`
|
||||
- **lt** - less than: \`price=lt.100\`
|
||||
- **lte** - less than or equal: \`rating=lte.5\`
|
||||
- **like** - pattern matching: \`name=like.*john*\`
|
||||
- **ilike** - case-insensitive like: \`email=ilike.*@gmail.com\`
|
||||
- **in** - in list: \`category=in.(tech,science,art)\`
|
||||
- **is** - is null/not null: \`deleted_at=is.null\`
|
||||
- **not** - negation: \`not.and=(status.eq.active,verified.eq.true)\`
|
||||
|
||||
### COMBINING FILTERS
|
||||
- **AND**: Use \`&\` or \`and=(...)\`: \`id=eq.123&status=eq.active\`
|
||||
- **OR**: Use \`or=(...)\`: \`or=(status.eq.active,status.eq.pending)\`
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
**Simple equality**: "Find user with ID 123"
|
||||
→ id=eq.123
|
||||
|
||||
**Text search**: "Find users with Gmail addresses"
|
||||
→ email=ilike.*@gmail.com
|
||||
|
||||
**Range filter**: "Find products under $50"
|
||||
→ price=lt.50
|
||||
|
||||
**Multiple conditions**: "Find active users over 18"
|
||||
→ age=gt.18&status=eq.active
|
||||
|
||||
**OR condition**: "Find active or pending orders"
|
||||
→ or=(status.eq.active,status.eq.pending)
|
||||
|
||||
**In list**: "Find posts in specific categories"
|
||||
→ category=in.(tech,science,health)
|
||||
|
||||
**Null check**: "Find users without a profile picture"
|
||||
→ profile_image=is.null
|
||||
|
||||
### REMEMBER
|
||||
Return ONLY the PostgREST filter expression - no explanations, no markdown, no extra text.`,
|
||||
placeholder: 'Describe the filter condition you need...',
|
||||
generationType: 'postgrest',
|
||||
},
|
||||
},
|
||||
// Optional filter for query operation
|
||||
{
|
||||
@@ -121,6 +301,66 @@ export const SupabaseBlock: BlockConfig<SupabaseResponse> = {
|
||||
layout: 'full',
|
||||
placeholder: 'status=eq.active',
|
||||
condition: { field: 'operation', value: 'query' },
|
||||
wandConfig: {
|
||||
enabled: true,
|
||||
maintainHistory: true,
|
||||
prompt: `You are an expert in PostgREST API syntax. Generate PostgREST filter expressions based on the user's request.
|
||||
|
||||
### CONTEXT
|
||||
{context}
|
||||
|
||||
### CRITICAL INSTRUCTION
|
||||
Return ONLY the PostgREST filter expression. Do not include any explanations, markdown formatting, or additional text. Just the raw filter expression.
|
||||
|
||||
### POSTGREST FILTER SYNTAX
|
||||
PostgREST uses a specific syntax for filtering data. The format is:
|
||||
column=operator.value
|
||||
|
||||
### OPERATORS
|
||||
- **eq** - equals: \`id=eq.123\`
|
||||
- **neq** - not equals: \`status=neq.inactive\`
|
||||
- **gt** - greater than: \`age=gt.18\`
|
||||
- **gte** - greater than or equal: \`score=gte.80\`
|
||||
- **lt** - less than: \`price=lt.100\`
|
||||
- **lte** - less than or equal: \`rating=lte.5\`
|
||||
- **like** - pattern matching: \`name=like.*john*\`
|
||||
- **ilike** - case-insensitive like: \`email=ilike.*@gmail.com\`
|
||||
- **in** - in list: \`category=in.(tech,science,art)\`
|
||||
- **is** - is null/not null: \`deleted_at=is.null\`
|
||||
- **not** - negation: \`not.and=(status.eq.active,verified.eq.true)\`
|
||||
|
||||
### COMBINING FILTERS
|
||||
- **AND**: Use \`&\` or \`and=(...)\`: \`id=eq.123&status=eq.active\`
|
||||
- **OR**: Use \`or=(...)\`: \`or=(status.eq.active,status.eq.pending)\`
|
||||
|
||||
### EXAMPLES
|
||||
|
||||
**Simple equality**: "Find user with ID 123"
|
||||
→ id=eq.123
|
||||
|
||||
**Text search**: "Find users with Gmail addresses"
|
||||
→ email=ilike.*@gmail.com
|
||||
|
||||
**Range filter**: "Find products under $50"
|
||||
→ price=lt.50
|
||||
|
||||
**Multiple conditions**: "Find active users over 18"
|
||||
→ age=gt.18&status=eq.active
|
||||
|
||||
**OR condition**: "Find active or pending orders"
|
||||
→ or=(status.eq.active,status.eq.pending)
|
||||
|
||||
**In list**: "Find posts in specific categories"
|
||||
→ category=in.(tech,science,health)
|
||||
|
||||
**Null check**: "Find users without a profile picture"
|
||||
→ profile_image=is.null
|
||||
|
||||
### REMEMBER
|
||||
Return ONLY the PostgREST filter expression - no explanations, no markdown, no extra text.`,
|
||||
placeholder: 'Describe the filter condition...',
|
||||
generationType: 'postgrest',
|
||||
},
|
||||
},
|
||||
// Optional order by for query operation
|
||||
{
|
||||
|
||||
@@ -17,6 +17,8 @@ export type GenerationType =
|
||||
| 'json-object'
|
||||
| 'system-prompt'
|
||||
| 'custom-tool-schema'
|
||||
| 'sql-query'
|
||||
| 'postgrest'
|
||||
|
||||
// SubBlock types
|
||||
export type SubBlockType =
|
||||
|
||||
122
apps/sim/components/emails/enterprise-subscription-email.tsx
Normal file
122
apps/sim/components/emails/enterprise-subscription-email.tsx
Normal file
@@ -0,0 +1,122 @@
|
||||
import {
|
||||
Body,
|
||||
Column,
|
||||
Container,
|
||||
Head,
|
||||
Html,
|
||||
Img,
|
||||
Link,
|
||||
Preview,
|
||||
Row,
|
||||
Section,
|
||||
Text,
|
||||
} from '@react-email/components'
|
||||
import { format } from 'date-fns'
|
||||
import { getBrandConfig } from '@/lib/branding/branding'
|
||||
import { env } from '@/lib/env'
|
||||
import { getAssetUrl } from '@/lib/utils'
|
||||
import { baseStyles } from './base-styles'
|
||||
import EmailFooter from './footer'
|
||||
|
||||
interface EnterpriseSubscriptionEmailProps {
|
||||
userName?: string
|
||||
userEmail?: string
|
||||
loginLink?: string
|
||||
createdDate?: Date
|
||||
}
|
||||
|
||||
const baseUrl = env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
|
||||
|
||||
export const EnterpriseSubscriptionEmail = ({
|
||||
userName = 'Valued User',
|
||||
userEmail = '',
|
||||
loginLink = `${baseUrl}/login`,
|
||||
createdDate = new Date(),
|
||||
}: EnterpriseSubscriptionEmailProps) => {
|
||||
const brand = getBrandConfig()
|
||||
|
||||
return (
|
||||
<Html>
|
||||
<Head />
|
||||
<Body style={baseStyles.main}>
|
||||
<Preview>Your Enterprise Plan is now active on Sim</Preview>
|
||||
<Container style={baseStyles.container}>
|
||||
<Section style={{ padding: '30px 0', textAlign: 'center' }}>
|
||||
<Row>
|
||||
<Column style={{ textAlign: 'center' }}>
|
||||
<Img
|
||||
src={brand.logoUrl || getAssetUrl('static/sim.png')}
|
||||
width='114'
|
||||
alt={brand.name}
|
||||
style={{
|
||||
margin: '0 auto',
|
||||
}}
|
||||
/>
|
||||
</Column>
|
||||
</Row>
|
||||
</Section>
|
||||
|
||||
<Section style={baseStyles.sectionsBorders}>
|
||||
<Row>
|
||||
<Column style={baseStyles.sectionBorder} />
|
||||
<Column style={baseStyles.sectionCenter} />
|
||||
<Column style={baseStyles.sectionBorder} />
|
||||
</Row>
|
||||
</Section>
|
||||
|
||||
<Section style={baseStyles.content}>
|
||||
<Text style={baseStyles.paragraph}>Hello {userName},</Text>
|
||||
<Text style={baseStyles.paragraph}>
|
||||
Great news! Your <strong>Enterprise Plan</strong> has been activated on Sim. You now
|
||||
have access to advanced features and increased capacity for your workflows.
|
||||
</Text>
|
||||
|
||||
<Text style={baseStyles.paragraph}>
|
||||
Your account has been set up with full access to your organization. Click below to log
|
||||
in and start exploring your new Enterprise features:
|
||||
</Text>
|
||||
|
||||
<Link href={loginLink} style={{ textDecoration: 'none' }}>
|
||||
<Text style={baseStyles.button}>Access Your Enterprise Account</Text>
|
||||
</Link>
|
||||
|
||||
<Text style={baseStyles.paragraph}>
|
||||
<strong>What's next?</strong>
|
||||
</Text>
|
||||
<Text style={baseStyles.paragraph}>
|
||||
• Invite team members to your organization
|
||||
<br />• Begin building your workflows
|
||||
</Text>
|
||||
|
||||
<Text style={baseStyles.paragraph}>
|
||||
If you have any questions or need assistance getting started, our support team is here
|
||||
to help.
|
||||
</Text>
|
||||
|
||||
<Text style={baseStyles.paragraph}>
|
||||
Welcome to Sim Enterprise!
|
||||
<br />
|
||||
The Sim Team
|
||||
</Text>
|
||||
|
||||
<Text
|
||||
style={{
|
||||
...baseStyles.footerText,
|
||||
marginTop: '40px',
|
||||
textAlign: 'left',
|
||||
color: '#666666',
|
||||
}}
|
||||
>
|
||||
This email was sent on {format(createdDate, 'MMMM do, yyyy')} to {userEmail}
|
||||
regarding your Enterprise plan activation on Sim.
|
||||
</Text>
|
||||
</Section>
|
||||
</Container>
|
||||
|
||||
<EmailFooter baseUrl={baseUrl} />
|
||||
</Body>
|
||||
</Html>
|
||||
)
|
||||
}
|
||||
|
||||
export default EnterpriseSubscriptionEmail
|
||||
@@ -1,5 +1,6 @@
|
||||
export * from './base-styles'
|
||||
export { BatchInvitationEmail } from './batch-invitation-email'
|
||||
export { EnterpriseSubscriptionEmail } from './enterprise-subscription-email'
|
||||
export { default as EmailFooter } from './footer'
|
||||
export { HelpConfirmationEmail } from './help-confirmation-email'
|
||||
export { InvitationEmail } from './invitation-email'
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
import { format } from 'date-fns'
|
||||
import { getBrandConfig } from '@/lib/branding/branding'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getAssetUrl } from '@/lib/utils'
|
||||
import { baseStyles } from './base-styles'
|
||||
import EmailFooter from './footer'
|
||||
@@ -28,6 +29,8 @@ interface InvitationEmailProps {
|
||||
|
||||
const baseUrl = env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
|
||||
|
||||
const logger = createLogger('InvitationEmail')
|
||||
|
||||
export const InvitationEmail = ({
|
||||
inviterName = 'A team member',
|
||||
organizationName = 'an organization',
|
||||
@@ -49,7 +52,7 @@ export const InvitationEmail = ({
|
||||
enhancedLink = `${baseUrl}/invite/${invitationId}?token=${invitationId}`
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error parsing invite link:', e)
|
||||
logger.error('Error parsing invite link:', e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { render } from '@react-email/components'
|
||||
import {
|
||||
BatchInvitationEmail,
|
||||
EnterpriseSubscriptionEmail,
|
||||
HelpConfirmationEmail,
|
||||
InvitationEmail,
|
||||
OTPVerificationEmail,
|
||||
@@ -82,6 +83,23 @@ export async function renderHelpConfirmationEmail(
|
||||
)
|
||||
}
|
||||
|
||||
export async function renderEnterpriseSubscriptionEmail(
|
||||
userName: string,
|
||||
userEmail: string
|
||||
): Promise<string> {
|
||||
const baseUrl = process.env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
|
||||
const loginLink = `${baseUrl}/login`
|
||||
|
||||
return await render(
|
||||
EnterpriseSubscriptionEmail({
|
||||
userName,
|
||||
userEmail,
|
||||
loginLink,
|
||||
createdDate: new Date(),
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
export function getEmailSubject(
|
||||
type:
|
||||
| 'sign-in'
|
||||
@@ -91,6 +109,7 @@ export function getEmailSubject(
|
||||
| 'invitation'
|
||||
| 'batch-invitation'
|
||||
| 'help-confirmation'
|
||||
| 'enterprise-subscription'
|
||||
): string {
|
||||
const brandName = getBrandConfig().name
|
||||
|
||||
@@ -109,6 +128,8 @@ export function getEmailSubject(
|
||||
return `You've been invited to join a team and workspaces on ${brandName}`
|
||||
case 'help-confirmation':
|
||||
return 'Your request has been received'
|
||||
case 'enterprise-subscription':
|
||||
return `Your Enterprise Plan is now active on ${brandName}`
|
||||
default:
|
||||
return brandName
|
||||
}
|
||||
|
||||
@@ -13,10 +13,13 @@ import {
|
||||
} from '@react-email/components'
|
||||
import { getBrandConfig } from '@/lib/branding/branding'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getAssetUrl } from '@/lib/utils'
|
||||
import { baseStyles } from './base-styles'
|
||||
import EmailFooter from './footer'
|
||||
|
||||
const logger = createLogger('WorkspaceInvitationEmail')
|
||||
|
||||
interface WorkspaceInvitationEmailProps {
|
||||
workspaceName?: string
|
||||
inviterName?: string
|
||||
@@ -45,7 +48,7 @@ export const WorkspaceInvitationEmail = ({
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error enhancing invitation link:', e)
|
||||
logger.error('Error enhancing invitation link:', e)
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -14,9 +14,9 @@ const Slider = React.forwardRef<
|
||||
{...props}
|
||||
>
|
||||
<SliderPrimitive.Track className='relative h-2 w-full grow overflow-hidden rounded-full bg-secondary'>
|
||||
<SliderPrimitive.Range className='absolute h-full bg-primary' />
|
||||
<SliderPrimitive.Range className='absolute h-full bg-primary dark:bg-white' />
|
||||
</SliderPrimitive.Track>
|
||||
<SliderPrimitive.Thumb className='block h-5 w-5 rounded-full border-2 border-primary bg-background ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50' />
|
||||
<SliderPrimitive.Thumb className='block h-5 w-5 rounded-full border-2 border-primary bg-background ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 dark:border-white dark:bg-black' />
|
||||
</SliderPrimitive.Root>
|
||||
))
|
||||
Slider.displayName = SliderPrimitive.Root.displayName
|
||||
|
||||
@@ -1254,7 +1254,7 @@ export class InputResolver {
|
||||
|
||||
return JSON.parse(normalizedExpression)
|
||||
} catch (jsonError) {
|
||||
console.error('Error parsing JSON for loop:', jsonError)
|
||||
logger.error('Error parsing JSON for loop:', jsonError)
|
||||
// If JSON parsing fails, continue with expression evaluation
|
||||
}
|
||||
}
|
||||
@@ -1267,7 +1267,7 @@ export class InputResolver {
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error evaluating forEach items:', e)
|
||||
logger.error('Error evaluating forEach items:', e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1712,7 +1712,7 @@ export class InputResolver {
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error evaluating parallel distribution items:', e)
|
||||
logger.error('Error evaluating parallel distribution items:', e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -175,10 +175,7 @@ describe('Full Executor Test', () => {
|
||||
} else {
|
||||
expect(result).toBeDefined()
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Execution error:', error)
|
||||
// Log the error but don't fail the test - we want to see what happens
|
||||
}
|
||||
} catch (error) {}
|
||||
})
|
||||
|
||||
it('should test the executor getNextExecutionLayer method directly', async () => {
|
||||
|
||||
@@ -621,7 +621,7 @@ export function useCollaborativeWorkflow() {
|
||||
}
|
||||
|
||||
if (!blockConfig) {
|
||||
console.error(`Block type ${type} not found`)
|
||||
logger.error(`Block type ${type} not found`)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
'use client'
|
||||
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import type { TagSlot } from '@/lib/constants/knowledge'
|
||||
import type { TagSlot } from '@/lib/knowledge/consts'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('useKnowledgeBaseTagDefinitions')
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import Fuse from 'fuse.js'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { type ChunkData, type DocumentData, useKnowledgeStore } from '@/stores/knowledge/store'
|
||||
|
||||
const logger = createLogger('UseKnowledgeBase')
|
||||
|
||||
export function useKnowledgeBase(id: string) {
|
||||
const { getKnowledgeBase, getCachedKnowledgeBase, loadingKnowledgeBases } = useKnowledgeStore()
|
||||
|
||||
@@ -22,6 +25,7 @@ export function useKnowledgeBase(id: string) {
|
||||
} catch (err) {
|
||||
if (isMounted) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load knowledge base')
|
||||
logger.error(`Failed to load knowledge base ${id}:`, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -86,6 +90,7 @@ export function useKnowledgeBaseDocuments(
|
||||
} catch (err) {
|
||||
if (isMounted) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load documents')
|
||||
logger.error(`Failed to load documents for knowledge base ${knowledgeBaseId}:`, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -127,6 +132,7 @@ export function useKnowledgeBaseDocuments(
|
||||
})
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to refresh documents')
|
||||
logger.error(`Failed to refresh documents for knowledge base ${knowledgeBaseId}:`, err)
|
||||
}
|
||||
}, [
|
||||
knowledgeBaseId,
|
||||
@@ -141,6 +147,7 @@ export function useKnowledgeBaseDocuments(
|
||||
const updateDocumentLocal = useCallback(
|
||||
(documentId: string, updates: Partial<DocumentData>) => {
|
||||
updateDocument(knowledgeBaseId, documentId, updates)
|
||||
logger.info(`Updated document ${documentId} for knowledge base ${knowledgeBaseId}`)
|
||||
},
|
||||
[knowledgeBaseId, updateDocument]
|
||||
)
|
||||
@@ -204,10 +211,11 @@ export function useKnowledgeBasesList(workspaceId?: string) {
|
||||
retryTimeoutId = setTimeout(() => {
|
||||
if (isMounted) {
|
||||
loadData(attempt + 1)
|
||||
logger.warn(`Failed to load knowledge bases list, retrying... ${attempt + 1}`)
|
||||
}
|
||||
}, delay)
|
||||
} else {
|
||||
console.error('All retry attempts failed for knowledge bases list:', err)
|
||||
logger.error('All retry attempts failed for knowledge bases list:', err)
|
||||
setError(errorMessage)
|
||||
setRetryCount(maxRetries)
|
||||
}
|
||||
@@ -235,7 +243,7 @@ export function useKnowledgeBasesList(workspaceId?: string) {
|
||||
} catch (err) {
|
||||
const errorMessage = err instanceof Error ? err.message : 'Failed to refresh knowledge bases'
|
||||
setError(errorMessage)
|
||||
console.error('Error refreshing knowledge bases list:', err)
|
||||
logger.error('Error refreshing knowledge bases list:', err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,7 +265,7 @@ export function useKnowledgeBasesList(workspaceId?: string) {
|
||||
} catch (err) {
|
||||
const errorMessage = err instanceof Error ? err.message : 'Failed to refresh knowledge bases'
|
||||
setError(errorMessage)
|
||||
console.error('Error force refreshing knowledge bases list:', err)
|
||||
logger.error('Error force refreshing knowledge bases list:', err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -361,6 +369,7 @@ export function useDocumentChunks(
|
||||
} catch (err) {
|
||||
if (isMounted) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load chunks')
|
||||
logger.error(`Failed to load chunks for document ${documentId}:`, err)
|
||||
}
|
||||
} finally {
|
||||
if (isMounted) {
|
||||
@@ -559,6 +568,7 @@ export function useDocumentChunks(
|
||||
} catch (err) {
|
||||
if (isMounted) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load chunks')
|
||||
logger.error(`Failed to load chunks for document ${documentId}:`, err)
|
||||
}
|
||||
} finally {
|
||||
if (isMounted) {
|
||||
@@ -599,6 +609,7 @@ export function useDocumentChunks(
|
||||
|
||||
// Update loading state based on store
|
||||
if (!isStoreLoading && isLoading) {
|
||||
logger.info(`Chunks loaded for document ${documentId}`)
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [documentId, isStoreLoading, isLoading, initialLoadDone, serverSearchQuery, serverCurrentPage])
|
||||
@@ -629,6 +640,7 @@ export function useDocumentChunks(
|
||||
return fetchedChunks
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load page')
|
||||
logger.error(`Failed to load page for document ${documentId}:`, err)
|
||||
throw err
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
@@ -676,6 +688,7 @@ export function useDocumentChunks(
|
||||
return fetchedChunks
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to refresh chunks')
|
||||
logger.error(`Failed to refresh chunks for document ${documentId}:`, err)
|
||||
throw err
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
@@ -704,6 +717,7 @@ export function useDocumentChunks(
|
||||
return searchResults
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to search chunks')
|
||||
logger.error(`Failed to search chunks for document ${documentId}:`, err)
|
||||
throw err
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
'use client'
|
||||
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import type { TagSlot } from '@/lib/constants/knowledge'
|
||||
import type { TagSlot } from '@/lib/knowledge/consts'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('useTagDefinitions')
|
||||
|
||||
@@ -10,7 +10,7 @@ import { createAuthClient } from 'better-auth/react'
|
||||
import type { auth } from '@/lib/auth'
|
||||
import { env, getEnv } from '@/lib/env'
|
||||
import { isProd } from '@/lib/environment'
|
||||
import { SessionContext, type SessionHookResult } from '@/lib/session-context'
|
||||
import { SessionContext, type SessionHookResult } from '@/lib/session/session-context'
|
||||
|
||||
export function getBaseURL() {
|
||||
let baseURL
|
||||
|
||||
@@ -24,7 +24,7 @@ import { authorizeSubscriptionReference } from '@/lib/billing/authorization'
|
||||
import { handleNewUser } from '@/lib/billing/core/usage'
|
||||
import { syncSubscriptionUsageLimits } from '@/lib/billing/organization'
|
||||
import { getPlans } from '@/lib/billing/plans'
|
||||
import type { EnterpriseSubscriptionMetadata } from '@/lib/billing/types'
|
||||
import { handleManualEnterpriseSubscription } from '@/lib/billing/webhooks/enterprise'
|
||||
import {
|
||||
handleInvoiceFinalized,
|
||||
handleInvoicePaymentFailed,
|
||||
@@ -52,121 +52,6 @@ if (validStripeKey) {
|
||||
})
|
||||
}
|
||||
|
||||
function isEnterpriseMetadata(value: unknown): value is EnterpriseSubscriptionMetadata {
|
||||
return (
|
||||
!!value &&
|
||||
typeof (value as any).plan === 'string' &&
|
||||
(value as any).plan.toLowerCase() === 'enterprise'
|
||||
)
|
||||
}
|
||||
|
||||
async function handleManualEnterpriseSubscription(event: Stripe.Event) {
|
||||
const stripeSubscription = event.data.object as Stripe.Subscription
|
||||
|
||||
const metaPlan = (stripeSubscription.metadata?.plan as string | undefined)?.toLowerCase() || ''
|
||||
|
||||
if (metaPlan !== 'enterprise') {
|
||||
logger.info('[subscription.created] Skipping non-enterprise subscription', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
plan: metaPlan || 'unknown',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const stripeCustomerId = stripeSubscription.customer as string
|
||||
|
||||
if (!stripeCustomerId) {
|
||||
logger.error('[subscription.created] Missing Stripe customer ID', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
})
|
||||
throw new Error('Missing Stripe customer ID on subscription')
|
||||
}
|
||||
|
||||
const metadata = stripeSubscription.metadata || {}
|
||||
|
||||
const referenceId =
|
||||
typeof metadata.referenceId === 'string' && metadata.referenceId.length > 0
|
||||
? metadata.referenceId
|
||||
: null
|
||||
|
||||
if (!referenceId) {
|
||||
logger.error('[subscription.created] Unable to resolve referenceId', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
stripeCustomerId,
|
||||
})
|
||||
throw new Error('Unable to resolve referenceId for subscription')
|
||||
}
|
||||
|
||||
const firstItem = stripeSubscription.items?.data?.[0]
|
||||
const seats = typeof firstItem?.quantity === 'number' ? firstItem.quantity : null
|
||||
|
||||
if (!isEnterpriseMetadata(metadata)) {
|
||||
logger.error('[subscription.created] Invalid enterprise metadata shape', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
metadata,
|
||||
})
|
||||
throw new Error('Invalid enterprise metadata for subscription')
|
||||
}
|
||||
const enterpriseMetadata = metadata
|
||||
const metadataJson: Record<string, unknown> = { ...enterpriseMetadata }
|
||||
|
||||
const subscriptionRow = {
|
||||
id: crypto.randomUUID(),
|
||||
plan: 'enterprise',
|
||||
referenceId,
|
||||
stripeCustomerId,
|
||||
stripeSubscriptionId: stripeSubscription.id,
|
||||
status: stripeSubscription.status || null,
|
||||
periodStart: stripeSubscription.current_period_start
|
||||
? new Date(stripeSubscription.current_period_start * 1000)
|
||||
: null,
|
||||
periodEnd: stripeSubscription.current_period_end
|
||||
? new Date(stripeSubscription.current_period_end * 1000)
|
||||
: null,
|
||||
cancelAtPeriodEnd: stripeSubscription.cancel_at_period_end ?? null,
|
||||
seats,
|
||||
trialStart: stripeSubscription.trial_start
|
||||
? new Date(stripeSubscription.trial_start * 1000)
|
||||
: null,
|
||||
trialEnd: stripeSubscription.trial_end ? new Date(stripeSubscription.trial_end * 1000) : null,
|
||||
metadata: metadataJson,
|
||||
}
|
||||
|
||||
const existing = await db
|
||||
.select({ id: schema.subscription.id })
|
||||
.from(schema.subscription)
|
||||
.where(eq(schema.subscription.stripeSubscriptionId, stripeSubscription.id))
|
||||
.limit(1)
|
||||
|
||||
if (existing.length > 0) {
|
||||
await db
|
||||
.update(schema.subscription)
|
||||
.set({
|
||||
plan: subscriptionRow.plan,
|
||||
referenceId: subscriptionRow.referenceId,
|
||||
stripeCustomerId: subscriptionRow.stripeCustomerId,
|
||||
status: subscriptionRow.status,
|
||||
periodStart: subscriptionRow.periodStart,
|
||||
periodEnd: subscriptionRow.periodEnd,
|
||||
cancelAtPeriodEnd: subscriptionRow.cancelAtPeriodEnd,
|
||||
seats: subscriptionRow.seats,
|
||||
trialStart: subscriptionRow.trialStart,
|
||||
trialEnd: subscriptionRow.trialEnd,
|
||||
metadata: subscriptionRow.metadata,
|
||||
})
|
||||
.where(eq(schema.subscription.stripeSubscriptionId, stripeSubscription.id))
|
||||
} else {
|
||||
await db.insert(schema.subscription).values(subscriptionRow)
|
||||
}
|
||||
|
||||
logger.info('[subscription.created] Upserted subscription', {
|
||||
subscriptionId: subscriptionRow.id,
|
||||
referenceId: subscriptionRow.referenceId,
|
||||
plan: subscriptionRow.plan,
|
||||
status: subscriptionRow.status,
|
||||
})
|
||||
}
|
||||
|
||||
export const auth = betterAuth({
|
||||
baseURL: getBaseURL(),
|
||||
trustedOrigins: [
|
||||
@@ -1161,7 +1046,7 @@ export const auth = betterAuth({
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
console.error('Linear API error:', {
|
||||
logger.error('Linear API error:', {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
body: errorText,
|
||||
@@ -1172,12 +1057,12 @@ export const auth = betterAuth({
|
||||
const { data, errors } = await response.json()
|
||||
|
||||
if (errors) {
|
||||
console.error('GraphQL errors:', errors)
|
||||
logger.error('GraphQL errors:', errors)
|
||||
throw new Error(`GraphQL errors: ${JSON.stringify(errors)}`)
|
||||
}
|
||||
|
||||
if (!data?.viewer) {
|
||||
console.error('No viewer data in response:', data)
|
||||
logger.error('No viewer data in response:', data)
|
||||
throw new Error('No viewer data in response')
|
||||
}
|
||||
|
||||
@@ -1193,7 +1078,7 @@ export const auth = betterAuth({
|
||||
image: viewer.avatarUrl || null,
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error in getUserInfo:', error)
|
||||
logger.error('Error in getUserInfo:', error)
|
||||
throw error
|
||||
}
|
||||
},
|
||||
|
||||
@@ -31,9 +31,7 @@ export async function checkUsageStatus(userId: string): Promise<UsageData> {
|
||||
const statsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))
|
||||
const currentUsage =
|
||||
statsRecords.length > 0
|
||||
? Number.parseFloat(
|
||||
statsRecords[0].currentPeriodCost?.toString() || statsRecords[0].totalCost.toString()
|
||||
)
|
||||
? Number.parseFloat(statsRecords[0].currentPeriodCost?.toString())
|
||||
: 0
|
||||
|
||||
return {
|
||||
@@ -117,7 +115,7 @@ export async function checkUsageStatus(userId: string): Promise<UsageData> {
|
||||
// Fall back to minimum billing amount from Stripe subscription
|
||||
const orgSub = await getOrganizationSubscription(org.id)
|
||||
if (orgSub?.seats) {
|
||||
const { basePrice } = getPlanPricing(orgSub.plan, orgSub)
|
||||
const { basePrice } = getPlanPricing(orgSub.plan)
|
||||
orgCap = (orgSub.seats || 1) * basePrice
|
||||
} else {
|
||||
// If no subscription, use team default
|
||||
|
||||
@@ -2,12 +2,10 @@ import { and, eq } from 'drizzle-orm'
|
||||
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
|
||||
import { getUserUsageData } from '@/lib/billing/core/usage'
|
||||
import {
|
||||
getEnterpriseTierLimitPerSeat,
|
||||
getFreeTierLimit,
|
||||
getProTierLimit,
|
||||
getTeamTierLimitPerSeat,
|
||||
} from '@/lib/billing/subscriptions/utils'
|
||||
import type { EnterpriseSubscriptionMetadata } from '@/lib/billing/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { member, subscription, user } from '@/db/schema'
|
||||
@@ -43,11 +41,8 @@ export async function getOrganizationSubscription(organizationId: string) {
|
||||
/**
|
||||
* Get plan pricing information
|
||||
*/
|
||||
export function getPlanPricing(
|
||||
plan: string,
|
||||
subscription?: any
|
||||
): {
|
||||
basePrice: number // What they pay upfront via Stripe subscription (per seat for team/enterprise)
|
||||
export function getPlanPricing(plan: string): {
|
||||
basePrice: number // What they pay upfront via Stripe subscription
|
||||
} {
|
||||
switch (plan) {
|
||||
case 'free':
|
||||
@@ -55,25 +50,7 @@ export function getPlanPricing(
|
||||
case 'pro':
|
||||
return { basePrice: getProTierLimit() }
|
||||
case 'team':
|
||||
return { basePrice: getTeamTierLimitPerSeat() }
|
||||
case 'enterprise':
|
||||
// Enterprise uses per-seat pricing like Team plans
|
||||
// Custom per-seat price can be set in metadata
|
||||
if (subscription?.metadata) {
|
||||
const metadata: EnterpriseSubscriptionMetadata =
|
||||
typeof subscription.metadata === 'string'
|
||||
? JSON.parse(subscription.metadata)
|
||||
: subscription.metadata
|
||||
|
||||
const perSeatPrice = metadata.perSeatPrice
|
||||
? Number.parseFloat(String(metadata.perSeatPrice))
|
||||
: undefined
|
||||
if (perSeatPrice && perSeatPrice > 0 && !Number.isNaN(perSeatPrice)) {
|
||||
return { basePrice: perSeatPrice }
|
||||
}
|
||||
}
|
||||
// Default enterprise per-seat pricing
|
||||
return { basePrice: getEnterpriseTierLimitPerSeat() }
|
||||
return { basePrice: getTeamTierLimitPerSeat() } // Per-seat pricing
|
||||
default:
|
||||
return { basePrice: 0 }
|
||||
}
|
||||
@@ -103,7 +80,7 @@ export async function calculateUserOverage(userId: string): Promise<{
|
||||
}
|
||||
|
||||
const plan = subscription?.plan || 'free'
|
||||
const { basePrice } = getPlanPricing(plan, subscription)
|
||||
const { basePrice } = getPlanPricing(plan)
|
||||
const actualUsage = usageData.currentUsage
|
||||
|
||||
// Calculate overage: any usage beyond what they already paid for
|
||||
@@ -197,7 +174,7 @@ export async function getSimplifiedBillingSummary(
|
||||
.from(member)
|
||||
.where(eq(member.organizationId, organizationId))
|
||||
|
||||
const { basePrice: basePricePerSeat } = getPlanPricing(subscription.plan, subscription)
|
||||
const { basePrice: basePricePerSeat } = getPlanPricing(subscription.plan)
|
||||
// Use licensed seats from Stripe as source of truth
|
||||
const licensedSeats = subscription.seats || 1
|
||||
const totalBasePrice = basePricePerSeat * licensedSeats // Based on Stripe subscription
|
||||
@@ -270,7 +247,7 @@ export async function getSimplifiedBillingSummary(
|
||||
}
|
||||
|
||||
// Individual billing summary
|
||||
const { basePrice } = getPlanPricing(plan, subscription)
|
||||
const { basePrice } = getPlanPricing(plan)
|
||||
|
||||
// For team and enterprise plans, calculate total team usage instead of individual usage
|
||||
let currentUsage = usageData.currentUsage
|
||||
|
||||
@@ -131,35 +131,38 @@ export async function getOrganizationBillingData(
|
||||
const totalCurrentUsage = members.reduce((sum, member) => sum + member.currentUsage, 0)
|
||||
|
||||
// Get per-seat pricing for the plan
|
||||
const { basePrice: pricePerSeat } = getPlanPricing(subscription.plan, subscription)
|
||||
const { basePrice: pricePerSeat } = getPlanPricing(subscription.plan)
|
||||
|
||||
// Use Stripe subscription seats as source of truth
|
||||
// Ensure we always have at least 1 seat (protect against 0 or falsy values)
|
||||
const licensedSeats = Math.max(subscription.seats || 1, 1)
|
||||
|
||||
// Validate seat capacity - warn if members exceed licensed seats
|
||||
if (members.length > licensedSeats) {
|
||||
logger.warn('Organization has more members than licensed seats', {
|
||||
organizationId,
|
||||
licensedSeats,
|
||||
actualMembers: members.length,
|
||||
plan: subscription.plan,
|
||||
})
|
||||
// Calculate minimum billing amount
|
||||
let minimumBillingAmount: number
|
||||
let totalUsageLimit: number
|
||||
|
||||
if (subscription.plan === 'enterprise') {
|
||||
// Enterprise has fixed pricing set through custom Stripe product
|
||||
// Their usage limit is configured to match their monthly cost
|
||||
const configuredLimit = organizationData.orgUsageLimit
|
||||
? Number.parseFloat(organizationData.orgUsageLimit)
|
||||
: 0
|
||||
minimumBillingAmount = configuredLimit // For enterprise, this equals their fixed monthly cost
|
||||
totalUsageLimit = configuredLimit // Same as their monthly cost
|
||||
} else {
|
||||
// Team plan: Billing is based on licensed seats from Stripe
|
||||
minimumBillingAmount = licensedSeats * pricePerSeat
|
||||
|
||||
// Total usage limit: never below the minimum based on licensed seats
|
||||
const configuredLimit = organizationData.orgUsageLimit
|
||||
? Number.parseFloat(organizationData.orgUsageLimit)
|
||||
: null
|
||||
totalUsageLimit =
|
||||
configuredLimit !== null
|
||||
? Math.max(configuredLimit, minimumBillingAmount)
|
||||
: minimumBillingAmount
|
||||
}
|
||||
|
||||
// Billing is based on licensed seats from Stripe, not actual member count
|
||||
// This ensures organizations pay for their seat capacity regardless of utilization
|
||||
const minimumBillingAmount = licensedSeats * pricePerSeat
|
||||
|
||||
// Total usage limit: never below the minimum based on licensed seats
|
||||
const configuredLimit = organizationData.orgUsageLimit
|
||||
? Number.parseFloat(organizationData.orgUsageLimit)
|
||||
: null
|
||||
const totalUsageLimit =
|
||||
configuredLimit !== null
|
||||
? Math.max(configuredLimit, minimumBillingAmount)
|
||||
: minimumBillingAmount
|
||||
|
||||
const averageUsagePerMember = members.length > 0 ? totalCurrentUsage / members.length : 0
|
||||
|
||||
// Billing period comes from the organization's subscription
|
||||
@@ -213,8 +216,24 @@ export async function updateOrganizationUsageLimit(
|
||||
return { success: false, error: 'No active subscription found' }
|
||||
}
|
||||
|
||||
// Calculate minimum based on seats
|
||||
const { basePrice } = getPlanPricing(subscription.plan, subscription)
|
||||
// Enterprise plans have fixed usage limits that cannot be changed
|
||||
if (subscription.plan === 'enterprise') {
|
||||
return {
|
||||
success: false,
|
||||
error: 'Enterprise plans have fixed usage limits that cannot be changed',
|
||||
}
|
||||
}
|
||||
|
||||
// Only team plans can update their usage limits
|
||||
if (subscription.plan !== 'team') {
|
||||
return {
|
||||
success: false,
|
||||
error: 'Only team organizations can update usage limits',
|
||||
}
|
||||
}
|
||||
|
||||
// Team plans have minimum based on seats
|
||||
const { basePrice } = getPlanPricing(subscription.plan)
|
||||
const minimumLimit = Math.max(subscription.seats || 1, 1) * basePrice
|
||||
|
||||
// Validate new limit is not below minimum
|
||||
@@ -315,3 +334,33 @@ export async function getOrganizationBillingSummary(organizationId: string) {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a user is an owner or admin of a specific organization
|
||||
*
|
||||
* @param userId - The ID of the user to check
|
||||
* @param organizationId - The ID of the organization
|
||||
* @returns Promise<boolean> - True if the user is an owner or admin of the organization
|
||||
*/
|
||||
export async function isOrganizationOwnerOrAdmin(
|
||||
userId: string,
|
||||
organizationId: string
|
||||
): Promise<boolean> {
|
||||
try {
|
||||
const memberRecord = await db
|
||||
.select({ role: member.role })
|
||||
.from(member)
|
||||
.where(and(eq(member.userId, userId), eq(member.organizationId, organizationId)))
|
||||
.limit(1)
|
||||
|
||||
if (memberRecord.length === 0) {
|
||||
return false
|
||||
}
|
||||
|
||||
const userRole = memberRecord[0].role
|
||||
return ['owner', 'admin'].includes(userRole)
|
||||
} catch (error) {
|
||||
logger.error('Error checking organization ownership/admin status:', error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -157,14 +157,26 @@ export async function hasExceededCostLimit(userId: string): Promise<boolean> {
|
||||
|
||||
// Calculate usage limit
|
||||
let limit = getFreeTierLimit() // Default free tier limit
|
||||
|
||||
if (subscription) {
|
||||
limit = getPerUserMinimumLimit(subscription)
|
||||
logger.info('Using subscription-based limit', {
|
||||
userId,
|
||||
plan: subscription.plan,
|
||||
seats: subscription.seats || 1,
|
||||
limit,
|
||||
})
|
||||
// Team/Enterprise: Use organization limit
|
||||
if (subscription.plan === 'team' || subscription.plan === 'enterprise') {
|
||||
const { getUserUsageLimit } = await import('@/lib/billing/core/usage')
|
||||
limit = await getUserUsageLimit(userId)
|
||||
logger.info('Using organization limit', {
|
||||
userId,
|
||||
plan: subscription.plan,
|
||||
limit,
|
||||
})
|
||||
} else {
|
||||
// Pro/Free: Use individual limit
|
||||
limit = getPerUserMinimumLimit(subscription)
|
||||
logger.info('Using subscription-based limit', {
|
||||
userId,
|
||||
plan: subscription.plan,
|
||||
limit,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
logger.info('Using free tier limit', { userId, limit })
|
||||
}
|
||||
@@ -231,7 +243,14 @@ export async function getUserSubscriptionState(userId: string): Promise<UserSubs
|
||||
if (isProd && statsRecords.length > 0) {
|
||||
let limit = getFreeTierLimit() // Default free tier limit
|
||||
if (subscription) {
|
||||
limit = getPerUserMinimumLimit(subscription)
|
||||
// Team/Enterprise: Use organization limit
|
||||
if (subscription.plan === 'team' || subscription.plan === 'enterprise') {
|
||||
const { getUserUsageLimit } = await import('@/lib/billing/core/usage')
|
||||
limit = await getUserUsageLimit(userId)
|
||||
} else {
|
||||
// Pro/Free: Use individual limit
|
||||
limit = getPerUserMinimumLimit(subscription)
|
||||
}
|
||||
}
|
||||
|
||||
const currentCost = Number.parseFloat(
|
||||
|
||||
@@ -71,7 +71,7 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
|
||||
.limit(1)
|
||||
|
||||
const { getPlanPricing } = await import('@/lib/billing/core/billing')
|
||||
const { basePrice } = getPlanPricing(subscription.plan, subscription)
|
||||
const { basePrice } = getPlanPricing(subscription.plan)
|
||||
const minimum = (subscription.seats || 1) * basePrice
|
||||
|
||||
if (orgData.length > 0 && orgData[0].orgUsageLimit) {
|
||||
@@ -144,7 +144,7 @@ export async function getUserUsageLimitInfo(userId: string): Promise<UsageLimitI
|
||||
.limit(1)
|
||||
|
||||
const { getPlanPricing } = await import('@/lib/billing/core/billing')
|
||||
const { basePrice } = getPlanPricing(subscription.plan, subscription)
|
||||
const { basePrice } = getPlanPricing(subscription.plan)
|
||||
const minimum = (subscription.seats || 1) * basePrice
|
||||
|
||||
if (orgData.length > 0 && orgData[0].orgUsageLimit) {
|
||||
@@ -335,14 +335,14 @@ export async function getUserUsageLimit(userId: string): Promise<number> {
|
||||
if (orgData[0].orgUsageLimit) {
|
||||
const configured = Number.parseFloat(orgData[0].orgUsageLimit)
|
||||
const { getPlanPricing } = await import('@/lib/billing/core/billing')
|
||||
const { basePrice } = getPlanPricing(subscription.plan, subscription)
|
||||
const { basePrice } = getPlanPricing(subscription.plan)
|
||||
const minimum = (subscription.seats || 1) * basePrice
|
||||
return Math.max(configured, minimum)
|
||||
}
|
||||
|
||||
// If org hasn't set a custom limit, use minimum (seats × cost per seat)
|
||||
const { getPlanPricing } = await import('@/lib/billing/core/billing')
|
||||
const { basePrice } = getPlanPricing(subscription.plan, subscription)
|
||||
const { basePrice } = getPlanPricing(subscription.plan)
|
||||
return (subscription.seats || 1) * basePrice
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
export * from '@/lib/billing/calculations/usage-monitor'
|
||||
export * from '@/lib/billing/core/billing'
|
||||
export * from '@/lib/billing/core/organization-billing'
|
||||
export * from '@/lib/billing/core/organization'
|
||||
export * from '@/lib/billing/core/subscription'
|
||||
export {
|
||||
getHighestPrioritySubscription as getActiveSubscription,
|
||||
@@ -23,10 +23,6 @@ export {
|
||||
updateUserUsageLimit as updateUsageLimit,
|
||||
} from '@/lib/billing/core/usage'
|
||||
export * from '@/lib/billing/subscriptions/utils'
|
||||
export {
|
||||
canEditUsageLimit as canEditLimit,
|
||||
getMinimumUsageLimit as getMinimumLimit,
|
||||
getSubscriptionAllowance as getDefaultLimit,
|
||||
} from '@/lib/billing/subscriptions/utils'
|
||||
export { canEditUsageLimit as canEditLimit } from '@/lib/billing/subscriptions/utils'
|
||||
export * from '@/lib/billing/types'
|
||||
export * from '@/lib/billing/validation/seat-management'
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { checkEnterprisePlan, getSubscriptionAllowance } from '@/lib/billing/subscriptions/utils'
|
||||
|
||||
vi.mock('@/lib/env', () => ({
|
||||
env: {
|
||||
FREE_TIER_COST_LIMIT: 10,
|
||||
PRO_TIER_COST_LIMIT: 20,
|
||||
TEAM_TIER_COST_LIMIT: 40,
|
||||
ENTERPRISE_TIER_COST_LIMIT: 200,
|
||||
},
|
||||
isTruthy: (value: string | boolean | number | undefined) =>
|
||||
typeof value === 'string' ? value.toLowerCase() === 'true' || value === '1' : Boolean(value),
|
||||
getEnv: (variable: string) => process.env[variable],
|
||||
}))
|
||||
|
||||
describe('Subscription Utilities', () => {
|
||||
describe('checkEnterprisePlan', () => {
|
||||
it.concurrent('returns true for active enterprise subscription', () => {
|
||||
expect(checkEnterprisePlan({ plan: 'enterprise', status: 'active' })).toBeTruthy()
|
||||
})
|
||||
|
||||
it.concurrent('returns false for inactive enterprise subscription', () => {
|
||||
expect(checkEnterprisePlan({ plan: 'enterprise', status: 'canceled' })).toBeFalsy()
|
||||
})
|
||||
|
||||
it.concurrent('returns false when plan is not enterprise', () => {
|
||||
expect(checkEnterprisePlan({ plan: 'pro', status: 'active' })).toBeFalsy()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getSubscriptionAllowance', () => {
|
||||
it.concurrent('returns free-tier limit when subscription is null', () => {
|
||||
expect(getSubscriptionAllowance(null)).toBe(10)
|
||||
})
|
||||
|
||||
it.concurrent('returns free-tier limit when subscription is undefined', () => {
|
||||
expect(getSubscriptionAllowance(undefined)).toBe(10)
|
||||
})
|
||||
|
||||
it.concurrent('returns free-tier limit when subscription is not active', () => {
|
||||
expect(getSubscriptionAllowance({ plan: 'pro', status: 'canceled', seats: 1 })).toBe(10)
|
||||
})
|
||||
|
||||
it.concurrent('returns pro limit for active pro plan', () => {
|
||||
expect(getSubscriptionAllowance({ plan: 'pro', status: 'active', seats: 1 })).toBe(20)
|
||||
})
|
||||
|
||||
it.concurrent('returns team limit multiplied by seats', () => {
|
||||
expect(getSubscriptionAllowance({ plan: 'team', status: 'active', seats: 3 })).toBe(3 * 40)
|
||||
})
|
||||
|
||||
it.concurrent('returns enterprise limit using perSeatPrice metadata', () => {
|
||||
const sub = {
|
||||
plan: 'enterprise',
|
||||
status: 'active',
|
||||
seats: 10,
|
||||
metadata: { perSeatPrice: 150 },
|
||||
}
|
||||
expect(getSubscriptionAllowance(sub)).toBe(10 * 150)
|
||||
})
|
||||
|
||||
it.concurrent('returns enterprise limit using perSeatPrice as string', () => {
|
||||
const sub = {
|
||||
plan: 'enterprise',
|
||||
status: 'active',
|
||||
seats: 8,
|
||||
metadata: { perSeatPrice: '250' },
|
||||
}
|
||||
expect(getSubscriptionAllowance(sub)).toBe(8 * 250)
|
||||
})
|
||||
|
||||
it.concurrent('falls back to default enterprise tier when metadata missing', () => {
|
||||
const sub = { plan: 'enterprise', status: 'active', seats: 2, metadata: {} }
|
||||
expect(getSubscriptionAllowance(sub)).toBe(2 * 200)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -4,7 +4,6 @@ import {
|
||||
DEFAULT_PRO_TIER_COST_LIMIT,
|
||||
DEFAULT_TEAM_TIER_COST_LIMIT,
|
||||
} from '@/lib/billing/constants'
|
||||
import type { EnterpriseSubscriptionMetadata } from '@/lib/billing/types'
|
||||
import { env } from '@/lib/env'
|
||||
|
||||
/**
|
||||
@@ -47,51 +46,10 @@ export function checkTeamPlan(subscription: any): boolean {
|
||||
return subscription?.plan === 'team' && subscription?.status === 'active'
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the total subscription-level allowance (what the org/user gets for their base payment)
|
||||
* - Pro: Fixed amount per user
|
||||
* - Team: Seats * base price (pooled for the org)
|
||||
* - Enterprise: Seats * per-seat price (pooled, with optional custom pricing in metadata)
|
||||
* @param subscription The subscription object
|
||||
* @returns The total subscription allowance in dollars
|
||||
*/
|
||||
export function getSubscriptionAllowance(subscription: any): number {
|
||||
if (!subscription || subscription.status !== 'active') {
|
||||
return getFreeTierLimit()
|
||||
}
|
||||
|
||||
const seats = subscription.seats || 1
|
||||
|
||||
if (subscription.plan === 'pro') {
|
||||
return getProTierLimit()
|
||||
}
|
||||
if (subscription.plan === 'team') {
|
||||
return seats * getTeamTierLimitPerSeat()
|
||||
}
|
||||
if (subscription.plan === 'enterprise') {
|
||||
const metadata = subscription.metadata as EnterpriseSubscriptionMetadata | undefined
|
||||
|
||||
// Enterprise uses per-seat pricing (pooled like Team)
|
||||
// Custom per-seat price can be set in metadata
|
||||
let perSeatPrice = getEnterpriseTierLimitPerSeat()
|
||||
if (metadata?.perSeatPrice) {
|
||||
const parsed = Number.parseFloat(String(metadata.perSeatPrice))
|
||||
if (parsed > 0 && !Number.isNaN(parsed)) {
|
||||
perSeatPrice = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return seats * perSeatPrice
|
||||
}
|
||||
|
||||
return getFreeTierLimit()
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the minimum usage limit for an individual user (used for validation)
|
||||
* - Pro: User's plan minimum
|
||||
* - Team: 0 (pooled model, no individual minimums)
|
||||
* - Enterprise: 0 (pooled model, no individual minimums)
|
||||
* Only applicable for plans with individual limits (Free/Pro)
|
||||
* Team and Enterprise plans use organization-level limits instead
|
||||
* @param subscription The subscription object
|
||||
* @returns The per-user minimum limit in dollars
|
||||
*/
|
||||
@@ -100,27 +58,15 @@ export function getPerUserMinimumLimit(subscription: any): number {
|
||||
return getFreeTierLimit()
|
||||
}
|
||||
|
||||
const seats = subscription.seats || 1
|
||||
|
||||
if (subscription.plan === 'pro') {
|
||||
return getProTierLimit()
|
||||
}
|
||||
if (subscription.plan === 'team') {
|
||||
// For team plans, return the total pooled limit (seats * cost per seat)
|
||||
// This becomes the user's individual limit representing their share of the team pool
|
||||
return seats * getTeamTierLimitPerSeat()
|
||||
}
|
||||
if (subscription.plan === 'enterprise') {
|
||||
// For enterprise plans, return the total pooled limit (seats * cost per seat)
|
||||
// This becomes the user's individual limit representing their share of the enterprise pool
|
||||
let perSeatPrice = getEnterpriseTierLimitPerSeat()
|
||||
if (subscription.metadata?.perSeatPrice) {
|
||||
const parsed = Number.parseFloat(String(subscription.metadata.perSeatPrice))
|
||||
if (parsed > 0 && !Number.isNaN(parsed)) {
|
||||
perSeatPrice = parsed
|
||||
}
|
||||
}
|
||||
return seats * perSeatPrice
|
||||
|
||||
if (subscription.plan === 'team' || subscription.plan === 'enterprise') {
|
||||
// Team and Enterprise don't have individual limits - they use organization limits
|
||||
// This function should not be called for these plans
|
||||
// Returning 0 to indicate no individual minimum
|
||||
return 0
|
||||
}
|
||||
|
||||
return getFreeTierLimit()
|
||||
@@ -128,7 +74,8 @@ export function getPerUserMinimumLimit(subscription: any): number {
|
||||
|
||||
/**
|
||||
* Check if a user can edit their usage limits based on their subscription
|
||||
* Free plan users cannot edit limits, paid plan users can
|
||||
* Free and Enterprise plans cannot edit limits
|
||||
* Pro and Team plans can increase their limits
|
||||
* @param subscription The subscription object
|
||||
* @returns Whether the user can edit their usage limits
|
||||
*/
|
||||
@@ -137,19 +84,7 @@ export function canEditUsageLimit(subscription: any): boolean {
|
||||
return false // Free plan users cannot edit limits
|
||||
}
|
||||
|
||||
return (
|
||||
subscription.plan === 'pro' ||
|
||||
subscription.plan === 'team' ||
|
||||
subscription.plan === 'enterprise'
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the minimum allowed usage limit for a subscription
|
||||
* This prevents users from setting limits below their plan's base amount
|
||||
* @param subscription The subscription object
|
||||
* @returns The minimum allowed usage limit in dollars
|
||||
*/
|
||||
export function getMinimumUsageLimit(subscription: any): number {
|
||||
return getPerUserMinimumLimit(subscription)
|
||||
// Only Pro and Team plans can edit limits
|
||||
// Enterprise has fixed limits that match their monthly cost
|
||||
return subscription.plan === 'pro' || subscription.plan === 'team'
|
||||
}
|
||||
|
||||
@@ -5,15 +5,15 @@
|
||||
|
||||
export interface EnterpriseSubscriptionMetadata {
|
||||
plan: 'enterprise'
|
||||
// Custom per-seat pricing (defaults to DEFAULT_ENTERPRISE_TIER_COST_LIMIT)
|
||||
// The referenceId must be provided in Stripe metadata to link to the organization
|
||||
// This gets stored in the subscription.referenceId column
|
||||
referenceId: string
|
||||
perSeatPrice?: number
|
||||
|
||||
// Maximum allowed seats (defaults to subscription.seats)
|
||||
maxSeats?: number
|
||||
|
||||
// Whether seats are fixed and cannot be changed
|
||||
fixedSeats?: boolean
|
||||
// The fixed monthly price for this enterprise customer (as string from Stripe metadata)
|
||||
// This will be used to set the organization's usage limit
|
||||
monthlyPrice: string
|
||||
// Number of seats for invitation limits (not for billing) (as string from Stripe metadata)
|
||||
// We set Stripe quantity to 1 and use this for actual seat count
|
||||
seats: string
|
||||
}
|
||||
|
||||
export interface UsageData {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { and, count, eq } from 'drizzle-orm'
|
||||
import { getOrganizationSubscription } from '@/lib/billing/core/billing'
|
||||
import type { EnterpriseSubscriptionMetadata } from '@/lib/billing/types'
|
||||
import { quickValidateEmail } from '@/lib/email/validation'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
@@ -67,26 +66,9 @@ export async function validateSeatAvailability(
|
||||
const currentSeats = memberCount[0]?.count || 0
|
||||
|
||||
// Determine seat limits based on subscription
|
||||
let maxSeats = subscription.seats || 1
|
||||
|
||||
// For enterprise plans, check metadata for custom seat allowances
|
||||
if (subscription.plan === 'enterprise' && subscription.metadata) {
|
||||
try {
|
||||
const metadata: EnterpriseSubscriptionMetadata =
|
||||
typeof subscription.metadata === 'string'
|
||||
? JSON.parse(subscription.metadata)
|
||||
: subscription.metadata
|
||||
if (metadata.maxSeats && typeof metadata.maxSeats === 'number') {
|
||||
maxSeats = metadata.maxSeats
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Failed to parse enterprise subscription metadata', {
|
||||
organizationId,
|
||||
metadata: subscription.metadata,
|
||||
error,
|
||||
})
|
||||
}
|
||||
}
|
||||
// Team: seats from Stripe subscription quantity
|
||||
// Enterprise: seats from metadata (stored in subscription.seats)
|
||||
const maxSeats = subscription.seats || 1
|
||||
|
||||
const availableSeats = Math.max(0, maxSeats - currentSeats)
|
||||
const canInvite = availableSeats >= additionalSeats
|
||||
@@ -162,24 +144,11 @@ export async function getOrganizationSeatInfo(
|
||||
const currentSeats = memberCount[0]?.count || 0
|
||||
|
||||
// Determine seat limits
|
||||
let maxSeats = subscription.seats || 1
|
||||
let canAddSeats = true
|
||||
const maxSeats = subscription.seats || 1
|
||||
|
||||
if (subscription.plan === 'enterprise' && subscription.metadata) {
|
||||
try {
|
||||
const metadata: EnterpriseSubscriptionMetadata =
|
||||
typeof subscription.metadata === 'string'
|
||||
? JSON.parse(subscription.metadata)
|
||||
: subscription.metadata
|
||||
if (metadata.maxSeats && typeof metadata.maxSeats === 'number') {
|
||||
maxSeats = metadata.maxSeats
|
||||
}
|
||||
// Enterprise plans might have fixed seat counts
|
||||
canAddSeats = !metadata.fixedSeats
|
||||
} catch (error) {
|
||||
logger.warn('Failed to parse enterprise subscription metadata', { organizationId, error })
|
||||
}
|
||||
}
|
||||
// Enterprise plans have fixed seats (can't self-serve changes)
|
||||
// Team plans can add seats through Stripe
|
||||
const canAddSeats = subscription.plan !== 'enterprise'
|
||||
|
||||
const availableSeats = Math.max(0, maxSeats - currentSeats)
|
||||
|
||||
|
||||
251
apps/sim/lib/billing/webhooks/enterprise.ts
Normal file
251
apps/sim/lib/billing/webhooks/enterprise.ts
Normal file
@@ -0,0 +1,251 @@
|
||||
import { eq } from 'drizzle-orm'
|
||||
import type Stripe from 'stripe'
|
||||
import {
|
||||
getEmailSubject,
|
||||
renderEnterpriseSubscriptionEmail,
|
||||
} from '@/components/emails/render-email'
|
||||
import { sendEmail } from '@/lib/email/mailer'
|
||||
import { getFromEmailAddress } from '@/lib/email/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { organization, subscription, user } from '@/db/schema'
|
||||
import type { EnterpriseSubscriptionMetadata } from '../types'
|
||||
|
||||
const logger = createLogger('BillingEnterprise')
|
||||
|
||||
function isEnterpriseMetadata(value: unknown): value is EnterpriseSubscriptionMetadata {
|
||||
return (
|
||||
!!value &&
|
||||
typeof value === 'object' &&
|
||||
'plan' in value &&
|
||||
'referenceId' in value &&
|
||||
'monthlyPrice' in value &&
|
||||
'seats' in value &&
|
||||
typeof value.plan === 'string' &&
|
||||
value.plan.toLowerCase() === 'enterprise' &&
|
||||
typeof value.referenceId === 'string' &&
|
||||
typeof value.monthlyPrice === 'string' &&
|
||||
typeof value.seats === 'string'
|
||||
)
|
||||
}
|
||||
|
||||
export async function handleManualEnterpriseSubscription(event: Stripe.Event) {
|
||||
const stripeSubscription = event.data.object as Stripe.Subscription
|
||||
|
||||
const metaPlan = (stripeSubscription.metadata?.plan as string | undefined)?.toLowerCase() || ''
|
||||
|
||||
if (metaPlan !== 'enterprise') {
|
||||
logger.info('[subscription.created] Skipping non-enterprise subscription', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
plan: metaPlan || 'unknown',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const stripeCustomerId = stripeSubscription.customer as string
|
||||
|
||||
if (!stripeCustomerId) {
|
||||
logger.error('[subscription.created] Missing Stripe customer ID', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
})
|
||||
throw new Error('Missing Stripe customer ID on subscription')
|
||||
}
|
||||
|
||||
const metadata = stripeSubscription.metadata || {}
|
||||
|
||||
const referenceId =
|
||||
typeof metadata.referenceId === 'string' && metadata.referenceId.length > 0
|
||||
? metadata.referenceId
|
||||
: null
|
||||
|
||||
if (!referenceId) {
|
||||
logger.error('[subscription.created] Unable to resolve referenceId', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
stripeCustomerId,
|
||||
})
|
||||
throw new Error('Unable to resolve referenceId for subscription')
|
||||
}
|
||||
|
||||
if (!isEnterpriseMetadata(metadata)) {
|
||||
logger.error('[subscription.created] Invalid enterprise metadata shape', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
metadata,
|
||||
})
|
||||
throw new Error('Invalid enterprise metadata for subscription')
|
||||
}
|
||||
const enterpriseMetadata = metadata
|
||||
const metadataJson: Record<string, unknown> = { ...enterpriseMetadata }
|
||||
|
||||
// Extract and parse seats and monthly price from metadata (they come as strings from Stripe)
|
||||
const seats = Number.parseInt(enterpriseMetadata.seats, 10)
|
||||
const monthlyPrice = Number.parseFloat(enterpriseMetadata.monthlyPrice)
|
||||
|
||||
if (!seats || seats <= 0 || Number.isNaN(seats)) {
|
||||
logger.error('[subscription.created] Invalid or missing seats in enterprise metadata', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
seatsRaw: enterpriseMetadata.seats,
|
||||
seatsParsed: seats,
|
||||
})
|
||||
throw new Error('Enterprise subscription must include valid seats in metadata')
|
||||
}
|
||||
|
||||
if (!monthlyPrice || monthlyPrice <= 0 || Number.isNaN(monthlyPrice)) {
|
||||
logger.error('[subscription.created] Invalid or missing monthlyPrice in enterprise metadata', {
|
||||
subscriptionId: stripeSubscription.id,
|
||||
monthlyPriceRaw: enterpriseMetadata.monthlyPrice,
|
||||
monthlyPriceParsed: monthlyPrice,
|
||||
})
|
||||
throw new Error('Enterprise subscription must include valid monthlyPrice in metadata')
|
||||
}
|
||||
|
||||
const subscriptionRow = {
|
||||
id: crypto.randomUUID(),
|
||||
plan: 'enterprise',
|
||||
referenceId,
|
||||
stripeCustomerId,
|
||||
stripeSubscriptionId: stripeSubscription.id,
|
||||
status: stripeSubscription.status || null,
|
||||
periodStart: stripeSubscription.current_period_start
|
||||
? new Date(stripeSubscription.current_period_start * 1000)
|
||||
: null,
|
||||
periodEnd: stripeSubscription.current_period_end
|
||||
? new Date(stripeSubscription.current_period_end * 1000)
|
||||
: null,
|
||||
cancelAtPeriodEnd: stripeSubscription.cancel_at_period_end ?? null,
|
||||
seats,
|
||||
trialStart: stripeSubscription.trial_start
|
||||
? new Date(stripeSubscription.trial_start * 1000)
|
||||
: null,
|
||||
trialEnd: stripeSubscription.trial_end ? new Date(stripeSubscription.trial_end * 1000) : null,
|
||||
metadata: metadataJson,
|
||||
}
|
||||
|
||||
const existing = await db
|
||||
.select({ id: subscription.id })
|
||||
.from(subscription)
|
||||
.where(eq(subscription.stripeSubscriptionId, stripeSubscription.id))
|
||||
.limit(1)
|
||||
|
||||
if (existing.length > 0) {
|
||||
await db
|
||||
.update(subscription)
|
||||
.set({
|
||||
plan: subscriptionRow.plan,
|
||||
referenceId: subscriptionRow.referenceId,
|
||||
stripeCustomerId: subscriptionRow.stripeCustomerId,
|
||||
status: subscriptionRow.status,
|
||||
periodStart: subscriptionRow.periodStart,
|
||||
periodEnd: subscriptionRow.periodEnd,
|
||||
cancelAtPeriodEnd: subscriptionRow.cancelAtPeriodEnd,
|
||||
seats: subscriptionRow.seats,
|
||||
trialStart: subscriptionRow.trialStart,
|
||||
trialEnd: subscriptionRow.trialEnd,
|
||||
metadata: subscriptionRow.metadata,
|
||||
})
|
||||
.where(eq(subscription.stripeSubscriptionId, stripeSubscription.id))
|
||||
} else {
|
||||
await db.insert(subscription).values(subscriptionRow)
|
||||
}
|
||||
|
||||
// Update the organization's usage limit to match the monthly price
|
||||
// The referenceId for enterprise plans is the organization ID
|
||||
try {
|
||||
await db
|
||||
.update(organization)
|
||||
.set({
|
||||
orgUsageLimit: monthlyPrice.toFixed(2),
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(eq(organization.id, referenceId))
|
||||
|
||||
logger.info('[subscription.created] Updated organization usage limit', {
|
||||
organizationId: referenceId,
|
||||
usageLimit: monthlyPrice,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('[subscription.created] Failed to update organization usage limit', {
|
||||
organizationId: referenceId,
|
||||
usageLimit: monthlyPrice,
|
||||
error,
|
||||
})
|
||||
// Don't throw - the subscription was created successfully, just log the error
|
||||
}
|
||||
|
||||
logger.info('[subscription.created] Upserted enterprise subscription', {
|
||||
subscriptionId: subscriptionRow.id,
|
||||
referenceId: subscriptionRow.referenceId,
|
||||
plan: subscriptionRow.plan,
|
||||
status: subscriptionRow.status,
|
||||
monthlyPrice,
|
||||
seats,
|
||||
note: 'Seats from metadata, Stripe quantity set to 1',
|
||||
})
|
||||
|
||||
try {
|
||||
const userDetails = await db
|
||||
.select({
|
||||
id: user.id,
|
||||
name: user.name,
|
||||
email: user.email,
|
||||
})
|
||||
.from(user)
|
||||
.where(eq(user.stripeCustomerId, stripeCustomerId))
|
||||
.limit(1)
|
||||
|
||||
const orgDetails = await db
|
||||
.select({
|
||||
id: organization.id,
|
||||
name: organization.name,
|
||||
})
|
||||
.from(organization)
|
||||
.where(eq(organization.id, referenceId))
|
||||
.limit(1)
|
||||
|
||||
if (userDetails.length > 0 && orgDetails.length > 0) {
|
||||
const user = userDetails[0]
|
||||
const org = orgDetails[0]
|
||||
|
||||
const html = await renderEnterpriseSubscriptionEmail(user.name || user.email, user.email)
|
||||
|
||||
const emailResult = await sendEmail({
|
||||
to: user.email,
|
||||
subject: getEmailSubject('enterprise-subscription'),
|
||||
html,
|
||||
from: getFromEmailAddress(),
|
||||
emailType: 'transactional',
|
||||
})
|
||||
|
||||
if (emailResult.success) {
|
||||
logger.info('[subscription.created] Enterprise subscription email sent successfully', {
|
||||
userId: user.id,
|
||||
email: user.email,
|
||||
organizationId: org.id,
|
||||
subscriptionId: subscriptionRow.id,
|
||||
})
|
||||
} else {
|
||||
logger.warn('[subscription.created] Failed to send enterprise subscription email', {
|
||||
userId: user.id,
|
||||
email: user.email,
|
||||
error: emailResult.message,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
logger.warn(
|
||||
'[subscription.created] Could not find user or organization for email notification',
|
||||
{
|
||||
userFound: userDetails.length > 0,
|
||||
orgFound: orgDetails.length > 0,
|
||||
stripeCustomerId,
|
||||
referenceId,
|
||||
}
|
||||
)
|
||||
}
|
||||
} catch (emailError) {
|
||||
logger.error('[subscription.created] Error sending enterprise subscription email', {
|
||||
error: emailError,
|
||||
stripeCustomerId,
|
||||
referenceId,
|
||||
subscriptionId: subscriptionRow.id,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -197,6 +197,7 @@ export async function handleInvoicePaymentFailed(event: Stripe.Event) {
|
||||
|
||||
/**
|
||||
* Handle base invoice finalized → create a separate overage-only invoice
|
||||
* Note: Enterprise plans no longer have overages
|
||||
*/
|
||||
export async function handleInvoiceFinalized(event: Stripe.Event) {
|
||||
try {
|
||||
@@ -215,14 +216,22 @@ export async function handleInvoiceFinalized(event: Stripe.Event) {
|
||||
if (records.length === 0) return
|
||||
const sub = records[0]
|
||||
|
||||
// Always reset usage at cycle end for all plans
|
||||
await resetUsageForSubscription({ plan: sub.plan, referenceId: sub.referenceId })
|
||||
|
||||
// Enterprise plans have no overages - skip overage invoice creation
|
||||
if (sub.plan === 'enterprise') {
|
||||
return
|
||||
}
|
||||
|
||||
const stripe = requireStripeClient()
|
||||
const periodEnd =
|
||||
invoice.lines?.data?.[0]?.period?.end || invoice.period_end || Math.floor(Date.now() / 1000)
|
||||
const billingPeriod = new Date(periodEnd * 1000).toISOString().slice(0, 7)
|
||||
|
||||
// Compute overage
|
||||
// Compute overage (only for team and pro plans)
|
||||
let totalOverage = 0
|
||||
if (sub.plan === 'team' || sub.plan === 'enterprise') {
|
||||
if (sub.plan === 'team') {
|
||||
const members = await db
|
||||
.select({ userId: member.userId })
|
||||
.from(member)
|
||||
@@ -235,19 +244,16 @@ export async function handleInvoiceFinalized(event: Stripe.Event) {
|
||||
}
|
||||
|
||||
const { getPlanPricing } = await import('@/lib/billing/core/billing')
|
||||
const { basePrice } = getPlanPricing(sub.plan, sub)
|
||||
const { basePrice } = getPlanPricing(sub.plan)
|
||||
const baseSubscriptionAmount = (sub.seats || 1) * basePrice
|
||||
totalOverage = Math.max(0, totalTeamUsage - baseSubscriptionAmount)
|
||||
} else {
|
||||
const usage = await getUserUsageData(sub.referenceId)
|
||||
const { getPlanPricing } = await import('@/lib/billing/core/billing')
|
||||
const { basePrice } = getPlanPricing(sub.plan, sub)
|
||||
const { basePrice } = getPlanPricing(sub.plan)
|
||||
totalOverage = Math.max(0, usage.currentUsage - basePrice)
|
||||
}
|
||||
|
||||
// Always reset usage at cycle end, regardless of whether overage > 0
|
||||
await resetUsageForSubscription({ plan: sub.plan, referenceId: sub.referenceId })
|
||||
|
||||
if (totalOverage <= 0) return
|
||||
|
||||
const customerId = String(invoice.customer)
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
/**
|
||||
* Knowledge base and document constants
|
||||
*/
|
||||
|
||||
// Tag slot configuration by field type
|
||||
// Each field type maps to specific database columns
|
||||
export const TAG_SLOT_CONFIG = {
|
||||
text: {
|
||||
slots: ['tag1', 'tag2', 'tag3', 'tag4', 'tag5', 'tag6', 'tag7'] as const,
|
||||
maxSlots: 7,
|
||||
},
|
||||
// Future field types would be added here with their own database columns
|
||||
// date: {
|
||||
// slots: ['tag8', 'tag9'] as const,
|
||||
// maxSlots: 2,
|
||||
// },
|
||||
// number: {
|
||||
// slots: ['tag10', 'tag11'] as const,
|
||||
// maxSlots: 2,
|
||||
// },
|
||||
} as const
|
||||
|
||||
// Currently supported field types
|
||||
export const SUPPORTED_FIELD_TYPES = Object.keys(TAG_SLOT_CONFIG) as Array<
|
||||
keyof typeof TAG_SLOT_CONFIG
|
||||
>
|
||||
|
||||
// All tag slots (for backward compatibility)
|
||||
export const TAG_SLOTS = TAG_SLOT_CONFIG.text.slots
|
||||
|
||||
// Maximum number of tag slots for text type (for backward compatibility)
|
||||
export const MAX_TAG_SLOTS = TAG_SLOT_CONFIG.text.maxSlots
|
||||
|
||||
// Type for tag slot names
|
||||
export type TagSlot = (typeof TAG_SLOTS)[number]
|
||||
|
||||
// Helper function to get available slots for a field type
|
||||
export function getSlotsForFieldType(fieldType: string): readonly string[] {
|
||||
const config = TAG_SLOT_CONFIG[fieldType as keyof typeof TAG_SLOT_CONFIG]
|
||||
if (!config) {
|
||||
return [] // Return empty array for unsupported field types - system will naturally handle this
|
||||
}
|
||||
return config.slots
|
||||
}
|
||||
|
||||
// Helper function to get max slots for a field type
|
||||
export function getMaxSlotsForFieldType(fieldType: string): number {
|
||||
const config = TAG_SLOT_CONFIG[fieldType as keyof typeof TAG_SLOT_CONFIG]
|
||||
if (!config) {
|
||||
return 0 // Return 0 for unsupported field types
|
||||
}
|
||||
return config.maxSlots
|
||||
}
|
||||
@@ -139,6 +139,17 @@ export const env = createEnv({
|
||||
RATE_LIMIT_ENTERPRISE_SYNC: z.string().optional().default('150'), // Enterprise tier sync API executions per minute
|
||||
RATE_LIMIT_ENTERPRISE_ASYNC: z.string().optional().default('1000'), // Enterprise tier async API executions per minute
|
||||
|
||||
// Knowledge Base Processing Configuration - Shared across all processing methods
|
||||
KB_CONFIG_MAX_DURATION: z.number().optional().default(300), // Max processing duration in s
|
||||
KB_CONFIG_MAX_ATTEMPTS: z.number().optional().default(3), // Max retry attempts
|
||||
KB_CONFIG_RETRY_FACTOR: z.number().optional().default(2), // Retry backoff factor
|
||||
KB_CONFIG_MIN_TIMEOUT: z.number().optional().default(1000), // Min timeout in ms
|
||||
KB_CONFIG_MAX_TIMEOUT: z.number().optional().default(10000), // Max timeout in ms
|
||||
KB_CONFIG_CONCURRENCY_LIMIT: z.number().optional().default(20), // Queue concurrency limit
|
||||
KB_CONFIG_BATCH_SIZE: z.number().optional().default(20), // Processing batch size
|
||||
KB_CONFIG_DELAY_BETWEEN_BATCHES: z.number().optional().default(100), // Delay between batches in ms
|
||||
KB_CONFIG_DELAY_BETWEEN_DOCUMENTS: z.number().optional().default(50), // Delay between documents in ms
|
||||
|
||||
// Real-time Communication
|
||||
SOCKET_SERVER_URL: z.string().url().optional(), // WebSocket server URL for real-time features
|
||||
SOCKET_PORT: z.number().optional(), // Port for WebSocket server
|
||||
|
||||
@@ -1,139 +1,108 @@
|
||||
import { createReadStream, existsSync } from 'fs'
|
||||
import { Readable } from 'stream'
|
||||
import csvParser from 'csv-parser'
|
||||
import { existsSync, readFileSync } from 'fs'
|
||||
import * as Papa from 'papaparse'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('CsvParser')
|
||||
|
||||
const PARSE_OPTIONS = {
|
||||
header: true,
|
||||
skipEmptyLines: true,
|
||||
transformHeader: (header: string) => sanitizeTextForUTF8(String(header)),
|
||||
transform: (value: string) => sanitizeTextForUTF8(String(value || '')),
|
||||
}
|
||||
|
||||
export class CsvParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
// Validate input
|
||||
if (!filePath) {
|
||||
return reject(new Error('No file path provided'))
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if (!existsSync(filePath)) {
|
||||
return reject(new Error(`File not found: ${filePath}`))
|
||||
}
|
||||
|
||||
const results: Record<string, any>[] = []
|
||||
const headers: string[] = []
|
||||
|
||||
createReadStream(filePath)
|
||||
.on('error', (error: Error) => {
|
||||
logger.error('CSV stream error:', error)
|
||||
reject(new Error(`Failed to read CSV file: ${error.message}`))
|
||||
})
|
||||
.pipe(csvParser())
|
||||
.on('headers', (headerList: string[]) => {
|
||||
headers.push(...headerList)
|
||||
})
|
||||
.on('data', (data: Record<string, any>) => {
|
||||
results.push(data)
|
||||
})
|
||||
.on('end', () => {
|
||||
// Convert CSV data to a formatted string representation
|
||||
let content = ''
|
||||
|
||||
// Add headers
|
||||
if (headers.length > 0) {
|
||||
const cleanHeaders = headers.map((h) => sanitizeTextForUTF8(String(h)))
|
||||
content += `${cleanHeaders.join(', ')}\n`
|
||||
}
|
||||
|
||||
// Add rows
|
||||
results.forEach((row) => {
|
||||
const cleanValues = Object.values(row).map((v) =>
|
||||
sanitizeTextForUTF8(String(v || ''))
|
||||
)
|
||||
content += `${cleanValues.join(', ')}\n`
|
||||
})
|
||||
|
||||
resolve({
|
||||
content: sanitizeTextForUTF8(content),
|
||||
metadata: {
|
||||
rowCount: results.length,
|
||||
headers: headers,
|
||||
rawData: results,
|
||||
},
|
||||
})
|
||||
})
|
||||
.on('error', (error: Error) => {
|
||||
logger.error('CSV parsing error:', error)
|
||||
reject(new Error(`Failed to parse CSV file: ${error.message}`))
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('CSV general error:', error)
|
||||
reject(new Error(`Failed to process CSV file: ${(error as Error).message}`))
|
||||
try {
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
})
|
||||
|
||||
if (!existsSync(filePath)) {
|
||||
throw new Error(`File not found: ${filePath}`)
|
||||
}
|
||||
|
||||
const fileContent = readFileSync(filePath, 'utf8')
|
||||
|
||||
const parseResult = Papa.parse(fileContent, PARSE_OPTIONS)
|
||||
|
||||
if (parseResult.errors && parseResult.errors.length > 0) {
|
||||
const errorMessages = parseResult.errors.map((err) => err.message).join(', ')
|
||||
logger.error('CSV parsing errors:', parseResult.errors)
|
||||
throw new Error(`Failed to parse CSV file: ${errorMessages}`)
|
||||
}
|
||||
|
||||
const results = parseResult.data as Record<string, any>[]
|
||||
const headers = parseResult.meta.fields || []
|
||||
|
||||
let content = ''
|
||||
|
||||
if (headers.length > 0) {
|
||||
const cleanHeaders = headers.map((h) => sanitizeTextForUTF8(String(h)))
|
||||
content += `${cleanHeaders.join(', ')}\n`
|
||||
}
|
||||
|
||||
results.forEach((row) => {
|
||||
const cleanValues = Object.values(row).map((v) => sanitizeTextForUTF8(String(v || '')))
|
||||
content += `${cleanValues.join(', ')}\n`
|
||||
})
|
||||
|
||||
return {
|
||||
content: sanitizeTextForUTF8(content),
|
||||
metadata: {
|
||||
rowCount: results.length,
|
||||
headers: headers,
|
||||
rawData: results,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('CSV general error:', error)
|
||||
throw new Error(`Failed to process CSV file: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
|
||||
async parseBuffer(buffer: Buffer): Promise<FileParseResult> {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
logger.info('Parsing buffer, size:', buffer.length)
|
||||
try {
|
||||
logger.info('Parsing buffer, size:', buffer.length)
|
||||
|
||||
const results: Record<string, any>[] = []
|
||||
const headers: string[] = []
|
||||
const fileContent = buffer.toString('utf8')
|
||||
|
||||
// Create a readable stream from the buffer
|
||||
const bufferStream = new Readable()
|
||||
bufferStream.push(buffer)
|
||||
bufferStream.push(null) // Signal the end of the stream
|
||||
const parseResult = Papa.parse(fileContent, PARSE_OPTIONS)
|
||||
|
||||
bufferStream
|
||||
.on('error', (error: Error) => {
|
||||
logger.error('CSV buffer stream error:', error)
|
||||
reject(new Error(`Failed to read CSV buffer: ${error.message}`))
|
||||
})
|
||||
.pipe(csvParser())
|
||||
.on('headers', (headerList: string[]) => {
|
||||
headers.push(...headerList)
|
||||
})
|
||||
.on('data', (data: Record<string, any>) => {
|
||||
results.push(data)
|
||||
})
|
||||
.on('end', () => {
|
||||
// Convert CSV data to a formatted string representation
|
||||
let content = ''
|
||||
|
||||
// Add headers
|
||||
if (headers.length > 0) {
|
||||
const cleanHeaders = headers.map((h) => sanitizeTextForUTF8(String(h)))
|
||||
content += `${cleanHeaders.join(', ')}\n`
|
||||
}
|
||||
|
||||
// Add rows
|
||||
results.forEach((row) => {
|
||||
const cleanValues = Object.values(row).map((v) =>
|
||||
sanitizeTextForUTF8(String(v || ''))
|
||||
)
|
||||
content += `${cleanValues.join(', ')}\n`
|
||||
})
|
||||
|
||||
resolve({
|
||||
content: sanitizeTextForUTF8(content),
|
||||
metadata: {
|
||||
rowCount: results.length,
|
||||
headers: headers,
|
||||
rawData: results,
|
||||
},
|
||||
})
|
||||
})
|
||||
.on('error', (error: Error) => {
|
||||
logger.error('CSV parsing error:', error)
|
||||
reject(new Error(`Failed to parse CSV buffer: ${error.message}`))
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('CSV buffer parsing error:', error)
|
||||
reject(new Error(`Failed to process CSV buffer: ${(error as Error).message}`))
|
||||
if (parseResult.errors && parseResult.errors.length > 0) {
|
||||
const errorMessages = parseResult.errors.map((err) => err.message).join(', ')
|
||||
logger.error('CSV parsing errors:', parseResult.errors)
|
||||
throw new Error(`Failed to parse CSV buffer: ${errorMessages}`)
|
||||
}
|
||||
})
|
||||
|
||||
const results = parseResult.data as Record<string, any>[]
|
||||
const headers = parseResult.meta.fields || []
|
||||
|
||||
let content = ''
|
||||
|
||||
if (headers.length > 0) {
|
||||
const cleanHeaders = headers.map((h) => sanitizeTextForUTF8(String(h)))
|
||||
content += `${cleanHeaders.join(', ')}\n`
|
||||
}
|
||||
|
||||
results.forEach((row) => {
|
||||
const cleanValues = Object.values(row).map((v) => sanitizeTextForUTF8(String(v || '')))
|
||||
content += `${cleanValues.join(', ')}\n`
|
||||
})
|
||||
|
||||
return {
|
||||
content: sanitizeTextForUTF8(content),
|
||||
metadata: {
|
||||
rowCount: results.length,
|
||||
headers: headers,
|
||||
rawData: results,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('CSV buffer parsing error:', error)
|
||||
throw new Error(`Failed to process CSV buffer: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,19 +9,16 @@ const logger = createLogger('DocParser')
|
||||
export class DocParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
// Validate input
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if (!existsSync(filePath)) {
|
||||
throw new Error(`File not found: ${filePath}`)
|
||||
}
|
||||
|
||||
logger.info(`Parsing DOC file: ${filePath}`)
|
||||
|
||||
// Read the file
|
||||
const buffer = await readFile(filePath)
|
||||
return this.parseBuffer(buffer)
|
||||
} catch (error) {
|
||||
@@ -38,45 +35,37 @@ export class DocParser implements FileParser {
|
||||
throw new Error('Empty buffer provided')
|
||||
}
|
||||
|
||||
// Try to dynamically import the word extractor
|
||||
let WordExtractor
|
||||
let parseOfficeAsync
|
||||
try {
|
||||
WordExtractor = (await import('word-extractor')).default
|
||||
const officeParser = await import('officeparser')
|
||||
parseOfficeAsync = officeParser.parseOfficeAsync
|
||||
} catch (importError) {
|
||||
logger.warn('word-extractor not available, using fallback extraction')
|
||||
logger.warn('officeparser not available, using fallback extraction')
|
||||
return this.fallbackExtraction(buffer)
|
||||
}
|
||||
|
||||
try {
|
||||
const extractor = new WordExtractor()
|
||||
const extracted = await extractor.extract(buffer)
|
||||
const result = await parseOfficeAsync(buffer)
|
||||
|
||||
const content = sanitizeTextForUTF8(extracted.getBody())
|
||||
const headers = extracted.getHeaders()
|
||||
const footers = extracted.getFooters()
|
||||
|
||||
// Combine body with headers/footers if they exist
|
||||
let fullContent = content
|
||||
if (headers?.trim()) {
|
||||
fullContent = `${sanitizeTextForUTF8(headers)}\n\n${fullContent}`
|
||||
}
|
||||
if (footers?.trim()) {
|
||||
fullContent = `${fullContent}\n\n${sanitizeTextForUTF8(footers)}`
|
||||
if (!result) {
|
||||
throw new Error('officeparser returned no result')
|
||||
}
|
||||
|
||||
logger.info('DOC parsing completed successfully')
|
||||
const resultString = typeof result === 'string' ? result : String(result)
|
||||
|
||||
const content = sanitizeTextForUTF8(resultString.trim())
|
||||
|
||||
logger.info('DOC parsing completed successfully with officeparser')
|
||||
|
||||
return {
|
||||
content: fullContent.trim(),
|
||||
content: content,
|
||||
metadata: {
|
||||
hasHeaders: !!headers?.trim(),
|
||||
hasFooters: !!footers?.trim(),
|
||||
characterCount: fullContent.length,
|
||||
extractionMethod: 'word-extractor',
|
||||
characterCount: content.length,
|
||||
extractionMethod: 'officeparser',
|
||||
},
|
||||
}
|
||||
} catch (extractError) {
|
||||
logger.warn('word-extractor failed, using fallback:', extractError)
|
||||
logger.warn('officeparser failed, using fallback:', extractError)
|
||||
return this.fallbackExtraction(buffer)
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -85,25 +74,16 @@ export class DocParser implements FileParser {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fallback extraction method for when word-extractor is not available
|
||||
* This is a very basic extraction that looks for readable text in the binary
|
||||
*/
|
||||
private fallbackExtraction(buffer: Buffer): FileParseResult {
|
||||
logger.info('Using fallback text extraction for DOC file')
|
||||
|
||||
// Convert buffer to string and try to extract readable text
|
||||
// This is very basic and won't work well for complex DOC files
|
||||
const text = buffer.toString('utf8', 0, Math.min(buffer.length, 100000)) // Limit to first 100KB
|
||||
const text = buffer.toString('utf8', 0, Math.min(buffer.length, 100000))
|
||||
|
||||
// Extract sequences of printable ASCII characters
|
||||
const readableText = text
|
||||
.match(/[\x20-\x7E\s]{4,}/g) // Find sequences of 4+ printable characters
|
||||
.match(/[\x20-\x7E\s]{4,}/g)
|
||||
?.filter(
|
||||
(chunk) =>
|
||||
chunk.trim().length > 10 && // Minimum length
|
||||
/[a-zA-Z]/.test(chunk) && // Must contain letters
|
||||
!/^[\x00-\x1F]*$/.test(chunk) // Not just control characters
|
||||
chunk.trim().length > 10 && /[a-zA-Z]/.test(chunk) && !/^[\x00-\x1F]*$/.test(chunk)
|
||||
)
|
||||
.join(' ')
|
||||
.replace(/\s+/g, ' ')
|
||||
@@ -118,8 +98,7 @@ export class DocParser implements FileParser {
|
||||
metadata: {
|
||||
extractionMethod: 'fallback',
|
||||
characterCount: content.length,
|
||||
warning:
|
||||
'Basic text extraction used. For better results, install word-extractor package or convert to DOCX format.',
|
||||
warning: 'Basic text extraction used. For better results, convert to DOCX format.',
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,15 +14,12 @@ interface MammothResult {
|
||||
export class DocxParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
// Validate input
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Read the file
|
||||
const buffer = await readFile(filePath)
|
||||
|
||||
// Use parseBuffer for consistent implementation
|
||||
return this.parseBuffer(buffer)
|
||||
} catch (error) {
|
||||
logger.error('DOCX file error:', error)
|
||||
@@ -34,10 +31,8 @@ export class DocxParser implements FileParser {
|
||||
try {
|
||||
logger.info('Parsing buffer, size:', buffer.length)
|
||||
|
||||
// Extract text with mammoth
|
||||
const result = await mammoth.extractRawText({ buffer })
|
||||
|
||||
// Extract HTML for metadata (optional - won't fail if this fails)
|
||||
let htmlResult: MammothResult = { value: '', messages: [] }
|
||||
try {
|
||||
htmlResult = await mammoth.convertToHtml({ buffer })
|
||||
|
||||
283
apps/sim/lib/file-parsers/html-parser.ts
Normal file
283
apps/sim/lib/file-parsers/html-parser.ts
Normal file
@@ -0,0 +1,283 @@
|
||||
import { readFile } from 'fs/promises'
|
||||
import * as cheerio from 'cheerio'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('HtmlParser')
|
||||
|
||||
export class HtmlParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
const buffer = await readFile(filePath)
|
||||
return this.parseBuffer(buffer)
|
||||
} catch (error) {
|
||||
logger.error('HTML file error:', error)
|
||||
throw new Error(`Failed to parse HTML file: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
|
||||
async parseBuffer(buffer: Buffer): Promise<FileParseResult> {
|
||||
try {
|
||||
logger.info('Parsing HTML buffer, size:', buffer.length)
|
||||
|
||||
const htmlContent = buffer.toString('utf-8')
|
||||
const $ = cheerio.load(htmlContent)
|
||||
|
||||
// Extract meta information before removing tags
|
||||
const title = $('title').text().trim()
|
||||
const metaDescription = $('meta[name="description"]').attr('content') || ''
|
||||
|
||||
$('script, style, noscript, meta, link, iframe, object, embed, svg').remove()
|
||||
|
||||
$.root()
|
||||
.contents()
|
||||
.filter(function () {
|
||||
return this.type === 'comment'
|
||||
})
|
||||
.remove()
|
||||
|
||||
const content = this.extractStructuredText($)
|
||||
|
||||
const sanitizedContent = sanitizeTextForUTF8(content)
|
||||
|
||||
const characterCount = sanitizedContent.length
|
||||
const wordCount = sanitizedContent.split(/\s+/).filter((word) => word.length > 0).length
|
||||
const estimatedTokenCount = Math.ceil(characterCount / 4)
|
||||
|
||||
const headings = this.extractHeadings($)
|
||||
|
||||
const links = this.extractLinks($)
|
||||
|
||||
return {
|
||||
content: sanitizedContent,
|
||||
metadata: {
|
||||
title,
|
||||
metaDescription,
|
||||
characterCount,
|
||||
wordCount,
|
||||
tokenCount: estimatedTokenCount,
|
||||
headings,
|
||||
links: links.slice(0, 50),
|
||||
hasImages: $('img').length > 0,
|
||||
imageCount: $('img').length,
|
||||
hasTable: $('table').length > 0,
|
||||
tableCount: $('table').length,
|
||||
hasList: $('ul, ol').length > 0,
|
||||
listCount: $('ul, ol').length,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('HTML buffer parsing error:', error)
|
||||
throw new Error(`Failed to parse HTML buffer: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract structured text content preserving document hierarchy
|
||||
*/
|
||||
private extractStructuredText($: cheerio.CheerioAPI): string {
|
||||
const contentParts: string[] = []
|
||||
|
||||
const rootElement = $('body').length > 0 ? $('body') : $.root()
|
||||
|
||||
this.processElement($, rootElement, contentParts, 0)
|
||||
|
||||
return contentParts.join('\n').trim()
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively process elements to extract text with structure
|
||||
*/
|
||||
private processElement(
|
||||
$: cheerio.CheerioAPI,
|
||||
element: cheerio.Cheerio<any>,
|
||||
contentParts: string[],
|
||||
depth: number
|
||||
): void {
|
||||
element.contents().each((_, node) => {
|
||||
if (node.type === 'text') {
|
||||
const text = $(node).text().trim()
|
||||
if (text) {
|
||||
contentParts.push(text)
|
||||
}
|
||||
} else if (node.type === 'tag') {
|
||||
const $node = $(node)
|
||||
const tagName = node.tagName?.toLowerCase()
|
||||
|
||||
switch (tagName) {
|
||||
case 'h1':
|
||||
case 'h2':
|
||||
case 'h3':
|
||||
case 'h4':
|
||||
case 'h5':
|
||||
case 'h6': {
|
||||
const headingText = $node.text().trim()
|
||||
if (headingText) {
|
||||
contentParts.push(`\n${headingText}\n`)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'p': {
|
||||
const paragraphText = $node.text().trim()
|
||||
if (paragraphText) {
|
||||
contentParts.push(`${paragraphText}\n`)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'br':
|
||||
contentParts.push('\n')
|
||||
break
|
||||
|
||||
case 'hr':
|
||||
contentParts.push('\n---\n')
|
||||
break
|
||||
|
||||
case 'li': {
|
||||
const listItemText = $node.text().trim()
|
||||
if (listItemText) {
|
||||
const indent = ' '.repeat(Math.min(depth, 3))
|
||||
contentParts.push(`${indent}• ${listItemText}`)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'ul':
|
||||
case 'ol':
|
||||
contentParts.push('\n')
|
||||
this.processElement($, $node, contentParts, depth + 1)
|
||||
contentParts.push('\n')
|
||||
break
|
||||
|
||||
case 'table':
|
||||
this.processTable($, $node, contentParts)
|
||||
break
|
||||
|
||||
case 'blockquote': {
|
||||
const quoteText = $node.text().trim()
|
||||
if (quoteText) {
|
||||
contentParts.push(`\n> ${quoteText}\n`)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'pre':
|
||||
case 'code': {
|
||||
const codeText = $node.text().trim()
|
||||
if (codeText) {
|
||||
contentParts.push(`\n\`\`\`\n${codeText}\n\`\`\`\n`)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'div':
|
||||
case 'section':
|
||||
case 'article':
|
||||
case 'main':
|
||||
case 'aside':
|
||||
case 'nav':
|
||||
case 'header':
|
||||
case 'footer':
|
||||
this.processElement($, $node, contentParts, depth)
|
||||
break
|
||||
|
||||
case 'a': {
|
||||
const linkText = $node.text().trim()
|
||||
const href = $node.attr('href')
|
||||
if (linkText) {
|
||||
if (href?.startsWith('http')) {
|
||||
contentParts.push(`${linkText} (${href})`)
|
||||
} else {
|
||||
contentParts.push(linkText)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'img': {
|
||||
const alt = $node.attr('alt')
|
||||
if (alt) {
|
||||
contentParts.push(`[Image: ${alt}]`)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
default:
|
||||
this.processElement($, $node, contentParts, depth)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Process table elements to extract structured data
|
||||
*/
|
||||
private processTable(
|
||||
$: cheerio.CheerioAPI,
|
||||
table: cheerio.Cheerio<any>,
|
||||
contentParts: string[]
|
||||
): void {
|
||||
contentParts.push('\n[Table]')
|
||||
|
||||
table.find('tr').each((_, row) => {
|
||||
const $row = $(row)
|
||||
const cells: string[] = []
|
||||
|
||||
$row.find('td, th').each((_, cell) => {
|
||||
const cellText = $(cell).text().trim()
|
||||
cells.push(cellText || '')
|
||||
})
|
||||
|
||||
if (cells.length > 0) {
|
||||
contentParts.push(`| ${cells.join(' | ')} |`)
|
||||
}
|
||||
})
|
||||
|
||||
contentParts.push('[/Table]\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract heading structure for metadata
|
||||
*/
|
||||
private extractHeadings($: cheerio.CheerioAPI): Array<{ level: number; text: string }> {
|
||||
const headings: Array<{ level: number; text: string }> = []
|
||||
|
||||
$('h1, h2, h3, h4, h5, h6').each((_, element) => {
|
||||
const $element = $(element)
|
||||
const tagName = element.tagName?.toLowerCase()
|
||||
const level = Number.parseInt(tagName?.charAt(1) || '1', 10)
|
||||
const text = $element.text().trim()
|
||||
|
||||
if (text) {
|
||||
headings.push({ level, text })
|
||||
}
|
||||
})
|
||||
|
||||
return headings
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract links from the document
|
||||
*/
|
||||
private extractLinks($: cheerio.CheerioAPI): Array<{ text: string; href: string }> {
|
||||
const links: Array<{ text: string; href: string }> = []
|
||||
|
||||
$('a[href]').each((_, element) => {
|
||||
const $element = $(element)
|
||||
const href = $element.attr('href')
|
||||
const text = $element.text().trim()
|
||||
|
||||
if (href && text && href.startsWith('http')) {
|
||||
links.push({ text, href })
|
||||
}
|
||||
})
|
||||
|
||||
return links
|
||||
}
|
||||
}
|
||||
@@ -51,6 +51,23 @@ const mockMdParseFile = vi.fn().mockResolvedValue({
|
||||
},
|
||||
})
|
||||
|
||||
const mockPptxParseFile = vi.fn().mockResolvedValue({
|
||||
content: 'Parsed PPTX content',
|
||||
metadata: {
|
||||
slideCount: 5,
|
||||
extractionMethod: 'officeparser',
|
||||
},
|
||||
})
|
||||
|
||||
const mockHtmlParseFile = vi.fn().mockResolvedValue({
|
||||
content: 'Parsed HTML content',
|
||||
metadata: {
|
||||
title: 'Test HTML Document',
|
||||
headingCount: 3,
|
||||
linkCount: 2,
|
||||
},
|
||||
})
|
||||
|
||||
const createMockModule = () => {
|
||||
const mockParsers: Record<string, FileParser> = {
|
||||
pdf: { parseFile: mockPdfParseFile },
|
||||
@@ -58,6 +75,10 @@ const createMockModule = () => {
|
||||
docx: { parseFile: mockDocxParseFile },
|
||||
txt: { parseFile: mockTxtParseFile },
|
||||
md: { parseFile: mockMdParseFile },
|
||||
pptx: { parseFile: mockPptxParseFile },
|
||||
ppt: { parseFile: mockPptxParseFile },
|
||||
html: { parseFile: mockHtmlParseFile },
|
||||
htm: { parseFile: mockHtmlParseFile },
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -143,6 +164,18 @@ describe('File Parsers', () => {
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/file-parsers/pptx-parser', () => ({
|
||||
PptxParser: vi.fn().mockImplementation(() => ({
|
||||
parseFile: mockPptxParseFile,
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/file-parsers/html-parser', () => ({
|
||||
HtmlParser: vi.fn().mockImplementation(() => ({
|
||||
parseFile: mockHtmlParseFile,
|
||||
})),
|
||||
}))
|
||||
|
||||
global.console = {
|
||||
...console,
|
||||
log: vi.fn(),
|
||||
@@ -261,6 +294,82 @@ describe('File Parsers', () => {
|
||||
|
||||
const { parseFile } = await import('@/lib/file-parsers/index')
|
||||
const result = await parseFile('/test/files/document.md')
|
||||
|
||||
expect(result).toEqual(expectedResult)
|
||||
})
|
||||
|
||||
it('should parse PPTX files successfully', async () => {
|
||||
const expectedResult = {
|
||||
content: 'Parsed PPTX content',
|
||||
metadata: {
|
||||
slideCount: 5,
|
||||
extractionMethod: 'officeparser',
|
||||
},
|
||||
}
|
||||
|
||||
mockPptxParseFile.mockResolvedValueOnce(expectedResult)
|
||||
mockExistsSync.mockReturnValue(true)
|
||||
|
||||
const { parseFile } = await import('@/lib/file-parsers/index')
|
||||
const result = await parseFile('/test/files/presentation.pptx')
|
||||
|
||||
expect(result).toEqual(expectedResult)
|
||||
})
|
||||
|
||||
it('should parse PPT files successfully', async () => {
|
||||
const expectedResult = {
|
||||
content: 'Parsed PPTX content',
|
||||
metadata: {
|
||||
slideCount: 5,
|
||||
extractionMethod: 'officeparser',
|
||||
},
|
||||
}
|
||||
|
||||
mockPptxParseFile.mockResolvedValueOnce(expectedResult)
|
||||
mockExistsSync.mockReturnValue(true)
|
||||
|
||||
const { parseFile } = await import('@/lib/file-parsers/index')
|
||||
const result = await parseFile('/test/files/presentation.ppt')
|
||||
|
||||
expect(result).toEqual(expectedResult)
|
||||
})
|
||||
|
||||
it('should parse HTML files successfully', async () => {
|
||||
const expectedResult = {
|
||||
content: 'Parsed HTML content',
|
||||
metadata: {
|
||||
title: 'Test HTML Document',
|
||||
headingCount: 3,
|
||||
linkCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
mockHtmlParseFile.mockResolvedValueOnce(expectedResult)
|
||||
mockExistsSync.mockReturnValue(true)
|
||||
|
||||
const { parseFile } = await import('@/lib/file-parsers/index')
|
||||
const result = await parseFile('/test/files/document.html')
|
||||
|
||||
expect(result).toEqual(expectedResult)
|
||||
})
|
||||
|
||||
it('should parse HTM files successfully', async () => {
|
||||
const expectedResult = {
|
||||
content: 'Parsed HTML content',
|
||||
metadata: {
|
||||
title: 'Test HTML Document',
|
||||
headingCount: 3,
|
||||
linkCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
mockHtmlParseFile.mockResolvedValueOnce(expectedResult)
|
||||
mockExistsSync.mockReturnValue(true)
|
||||
|
||||
const { parseFile } = await import('@/lib/file-parsers/index')
|
||||
const result = await parseFile('/test/files/document.htm')
|
||||
|
||||
expect(result).toEqual(expectedResult)
|
||||
})
|
||||
|
||||
it('should throw error for unsupported file types', async () => {
|
||||
@@ -292,6 +401,10 @@ describe('File Parsers', () => {
|
||||
expect(isSupportedFileType('docx')).toBe(true)
|
||||
expect(isSupportedFileType('txt')).toBe(true)
|
||||
expect(isSupportedFileType('md')).toBe(true)
|
||||
expect(isSupportedFileType('pptx')).toBe(true)
|
||||
expect(isSupportedFileType('ppt')).toBe(true)
|
||||
expect(isSupportedFileType('html')).toBe(true)
|
||||
expect(isSupportedFileType('htm')).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for unsupported file types', async () => {
|
||||
@@ -308,6 +421,8 @@ describe('File Parsers', () => {
|
||||
expect(isSupportedFileType('CSV')).toBe(true)
|
||||
expect(isSupportedFileType('TXT')).toBe(true)
|
||||
expect(isSupportedFileType('MD')).toBe(true)
|
||||
expect(isSupportedFileType('PPTX')).toBe(true)
|
||||
expect(isSupportedFileType('HTML')).toBe(true)
|
||||
})
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
|
||||
@@ -7,7 +7,6 @@ import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('FileParser')
|
||||
|
||||
// Lazy-loaded parsers to avoid initialization issues
|
||||
let parserInstances: Record<string, FileParser> | null = null
|
||||
|
||||
/**
|
||||
@@ -18,25 +17,20 @@ function getParserInstances(): Record<string, FileParser> {
|
||||
parserInstances = {}
|
||||
|
||||
try {
|
||||
// Import parsers only when needed - with try/catch for each one
|
||||
try {
|
||||
logger.info('Attempting to load PDF parser...')
|
||||
try {
|
||||
// First try to use the pdf-parse library
|
||||
// Import the PdfParser using ES module import to avoid test file access
|
||||
const { PdfParser } = require('@/lib/file-parsers/pdf-parser')
|
||||
parserInstances.pdf = new PdfParser()
|
||||
logger.info('PDF parser loaded successfully')
|
||||
} catch (pdfParseError) {
|
||||
// If that fails, fallback to our raw PDF parser
|
||||
logger.error('Failed to load primary PDF parser:', pdfParseError)
|
||||
} catch (pdfLibError) {
|
||||
logger.error('Failed to load primary PDF parser:', pdfLibError)
|
||||
logger.info('Falling back to raw PDF parser')
|
||||
parserInstances.pdf = new RawPdfParser()
|
||||
logger.info('Raw PDF parser loaded successfully')
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to load any PDF parser:', error)
|
||||
// Create a simple fallback that just returns the file size and a message
|
||||
parserInstances.pdf = {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
const buffer = await readFile(filePath)
|
||||
@@ -100,10 +94,26 @@ function getParserInstances(): Record<string, FileParser> {
|
||||
try {
|
||||
const { XlsxParser } = require('@/lib/file-parsers/xlsx-parser')
|
||||
parserInstances.xlsx = new XlsxParser()
|
||||
parserInstances.xls = new XlsxParser() // Both xls and xlsx use the same parser
|
||||
parserInstances.xls = new XlsxParser()
|
||||
} catch (error) {
|
||||
logger.error('Failed to load XLSX parser:', error)
|
||||
}
|
||||
|
||||
try {
|
||||
const { PptxParser } = require('@/lib/file-parsers/pptx-parser')
|
||||
parserInstances.pptx = new PptxParser()
|
||||
parserInstances.ppt = new PptxParser()
|
||||
} catch (error) {
|
||||
logger.error('Failed to load PPTX parser:', error)
|
||||
}
|
||||
|
||||
try {
|
||||
const { HtmlParser } = require('@/lib/file-parsers/html-parser')
|
||||
parserInstances.html = new HtmlParser()
|
||||
parserInstances.htm = new HtmlParser()
|
||||
} catch (error) {
|
||||
logger.error('Failed to load HTML parser:', error)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error loading file parsers:', error)
|
||||
}
|
||||
@@ -119,12 +129,10 @@ function getParserInstances(): Record<string, FileParser> {
|
||||
*/
|
||||
export async function parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
// Validate input
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if (!existsSync(filePath)) {
|
||||
throw new Error(`File not found: ${filePath}`)
|
||||
}
|
||||
@@ -158,7 +166,6 @@ export async function parseFile(filePath: string): Promise<FileParseResult> {
|
||||
*/
|
||||
export async function parseBuffer(buffer: Buffer, extension: string): Promise<FileParseResult> {
|
||||
try {
|
||||
// Validate input
|
||||
if (!buffer || buffer.length === 0) {
|
||||
throw new Error('Empty buffer provided')
|
||||
}
|
||||
@@ -182,7 +189,6 @@ export async function parseBuffer(buffer: Buffer, extension: string): Promise<Fi
|
||||
logger.info('Using parser for extension:', normalizedExtension)
|
||||
const parser = parsers[normalizedExtension]
|
||||
|
||||
// Check if parser supports buffer parsing
|
||||
if (parser.parseBuffer) {
|
||||
return await parser.parseBuffer(buffer)
|
||||
}
|
||||
@@ -207,5 +213,4 @@ export function isSupportedFileType(extension: string): extension is SupportedFi
|
||||
}
|
||||
}
|
||||
|
||||
// Type exports
|
||||
export type { FileParseResult, FileParser, SupportedFileType }
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { readFile } from 'fs/promises'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('MdParser')
|
||||
@@ -7,15 +8,12 @@ const logger = createLogger('MdParser')
|
||||
export class MdParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
// Validate input
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Read the file
|
||||
const buffer = await readFile(filePath)
|
||||
|
||||
// Use parseBuffer for consistent implementation
|
||||
return this.parseBuffer(buffer)
|
||||
} catch (error) {
|
||||
logger.error('MD file error:', error)
|
||||
@@ -27,14 +25,14 @@ export class MdParser implements FileParser {
|
||||
try {
|
||||
logger.info('Parsing buffer, size:', buffer.length)
|
||||
|
||||
// Extract content
|
||||
const result = buffer.toString('utf-8')
|
||||
const content = sanitizeTextForUTF8(result)
|
||||
|
||||
return {
|
||||
content: result,
|
||||
content,
|
||||
metadata: {
|
||||
characterCount: result.length,
|
||||
tokenCount: result.length / 4,
|
||||
characterCount: content.length,
|
||||
tokenCount: Math.floor(content.length / 4),
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
import { readFile } from 'fs/promises'
|
||||
// @ts-ignore
|
||||
import * as pdfParseLib from 'pdf-parse/lib/pdf-parse.js'
|
||||
import { PDFDocument } from 'pdf-lib'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { RawPdfParser } from './raw-pdf-parser'
|
||||
|
||||
const logger = createLogger('PdfParser')
|
||||
const rawPdfParser = new RawPdfParser()
|
||||
|
||||
export class PdfParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
logger.info('Starting to parse file:', filePath)
|
||||
|
||||
// Make sure we're only parsing the provided file path
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Read the file
|
||||
logger.info('Reading file...')
|
||||
const dataBuffer = await readFile(filePath)
|
||||
logger.info('File read successfully, size:', dataBuffer.length)
|
||||
@@ -32,93 +31,66 @@ export class PdfParser implements FileParser {
|
||||
try {
|
||||
logger.info('Starting to parse buffer, size:', dataBuffer.length)
|
||||
|
||||
// Try to parse with pdf-parse library first
|
||||
try {
|
||||
logger.info('Attempting to parse with pdf-parse library...')
|
||||
logger.info('Attempting to parse with pdf-lib library...')
|
||||
|
||||
// Parse PDF with direct function call to avoid test file access
|
||||
logger.info('Starting PDF parsing...')
|
||||
const data = await pdfParseLib.default(dataBuffer)
|
||||
logger.info('PDF parsed successfully with pdf-parse, pages:', data.numpages)
|
||||
const pdfDoc = await PDFDocument.load(dataBuffer)
|
||||
const pages = pdfDoc.getPages()
|
||||
const pageCount = pages.length
|
||||
|
||||
logger.info('PDF parsed successfully with pdf-lib, pages:', pageCount)
|
||||
|
||||
const metadata: Record<string, any> = {
|
||||
pageCount,
|
||||
}
|
||||
|
||||
try {
|
||||
const title = pdfDoc.getTitle()
|
||||
const author = pdfDoc.getAuthor()
|
||||
const subject = pdfDoc.getSubject()
|
||||
const creator = pdfDoc.getCreator()
|
||||
const producer = pdfDoc.getProducer()
|
||||
const creationDate = pdfDoc.getCreationDate()
|
||||
const modificationDate = pdfDoc.getModificationDate()
|
||||
|
||||
if (title) metadata.title = title
|
||||
if (author) metadata.author = author
|
||||
if (subject) metadata.subject = subject
|
||||
if (creator) metadata.creator = creator
|
||||
if (producer) metadata.producer = producer
|
||||
if (creationDate) metadata.creationDate = creationDate.toISOString()
|
||||
if (modificationDate) metadata.modificationDate = modificationDate.toISOString()
|
||||
} catch (metadataError) {
|
||||
logger.warn('Could not extract PDF metadata:', metadataError)
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'pdf-lib loaded successfully, but text extraction requires fallback to raw parser'
|
||||
)
|
||||
const rawResult = await rawPdfParser.parseBuffer(dataBuffer)
|
||||
|
||||
return {
|
||||
content: data.text,
|
||||
content: rawResult.content,
|
||||
metadata: {
|
||||
pageCount: data.numpages,
|
||||
info: data.info,
|
||||
version: data.version,
|
||||
...rawResult.metadata,
|
||||
...metadata,
|
||||
source: 'pdf-lib + raw-parser',
|
||||
},
|
||||
}
|
||||
} catch (pdfParseError: unknown) {
|
||||
logger.error('PDF-parse library failed:', pdfParseError)
|
||||
} catch (pdfLibError: unknown) {
|
||||
logger.error('PDF-lib library failed:', pdfLibError)
|
||||
|
||||
// Fallback to manual text extraction
|
||||
logger.info('Falling back to manual text extraction...')
|
||||
|
||||
// Extract basic PDF info from raw content
|
||||
const rawContent = dataBuffer.toString('utf-8', 0, Math.min(10000, dataBuffer.length))
|
||||
|
||||
let version = 'Unknown'
|
||||
let pageCount = 0
|
||||
|
||||
// Try to extract PDF version
|
||||
const versionMatch = rawContent.match(/%PDF-(\d+\.\d+)/)
|
||||
if (versionMatch?.[1]) {
|
||||
version = versionMatch[1]
|
||||
}
|
||||
|
||||
// Try to get page count
|
||||
const pageMatches = rawContent.match(/\/Type\s*\/Page\b/g)
|
||||
if (pageMatches) {
|
||||
pageCount = pageMatches.length
|
||||
}
|
||||
|
||||
// Try to extract text by looking for text-related operators in the PDF
|
||||
let extractedText = ''
|
||||
|
||||
// Look for text in the PDF content using common patterns
|
||||
const textMatches = rawContent.match(/BT[\s\S]*?ET/g)
|
||||
if (textMatches && textMatches.length > 0) {
|
||||
extractedText = textMatches
|
||||
.map((textBlock) => {
|
||||
// Extract text objects (Tj, TJ) from the text block
|
||||
const textObjects = textBlock.match(/\([^)]*\)\s*Tj|\[[^\]]*\]\s*TJ/g)
|
||||
if (textObjects) {
|
||||
return textObjects
|
||||
.map((obj) => {
|
||||
// Clean up text objects
|
||||
return (
|
||||
obj
|
||||
.replace(
|
||||
/\(([^)]*)\)\s*Tj|\[([^\]]*)\]\s*TJ/g,
|
||||
(match, p1, p2) => p1 || p2 || ''
|
||||
)
|
||||
// Clean up PDF escape sequences
|
||||
.replace(/\\(\d{3}|[()\\])/g, '')
|
||||
.replace(/\\\\/g, '\\')
|
||||
.replace(/\\\(/g, '(')
|
||||
.replace(/\\\)/g, ')')
|
||||
)
|
||||
})
|
||||
.join(' ')
|
||||
}
|
||||
return ''
|
||||
})
|
||||
.join('\n')
|
||||
}
|
||||
|
||||
// If we couldn't extract text or the text is too short, return a fallback message
|
||||
if (!extractedText || extractedText.length < 50) {
|
||||
extractedText = `This PDF contains ${pageCount} page(s) but text extraction was not successful.`
|
||||
}
|
||||
logger.info('Falling back to raw PDF parser...')
|
||||
const rawResult = await rawPdfParser.parseBuffer(dataBuffer)
|
||||
|
||||
return {
|
||||
content: extractedText,
|
||||
...rawResult,
|
||||
metadata: {
|
||||
pageCount,
|
||||
version,
|
||||
...rawResult.metadata,
|
||||
fallback: true,
|
||||
error: (pdfParseError as Error).message || 'Unknown error',
|
||||
source: 'raw-parser-only',
|
||||
error: (pdfLibError as Error).message || 'Unknown error',
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
106
apps/sim/lib/file-parsers/pptx-parser.ts
Normal file
106
apps/sim/lib/file-parsers/pptx-parser.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import { existsSync } from 'fs'
|
||||
import { readFile } from 'fs/promises'
|
||||
import type { FileParseResult, FileParser } from '@/lib/file-parsers/types'
|
||||
import { sanitizeTextForUTF8 } from '@/lib/file-parsers/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('PptxParser')
|
||||
|
||||
export class PptxParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
if (!existsSync(filePath)) {
|
||||
throw new Error(`File not found: ${filePath}`)
|
||||
}
|
||||
|
||||
logger.info(`Parsing PowerPoint file: ${filePath}`)
|
||||
|
||||
const buffer = await readFile(filePath)
|
||||
return this.parseBuffer(buffer)
|
||||
} catch (error) {
|
||||
logger.error('PowerPoint file parsing error:', error)
|
||||
throw new Error(`Failed to parse PowerPoint file: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
|
||||
async parseBuffer(buffer: Buffer): Promise<FileParseResult> {
|
||||
try {
|
||||
logger.info('Parsing PowerPoint buffer, size:', buffer.length)
|
||||
|
||||
if (!buffer || buffer.length === 0) {
|
||||
throw new Error('Empty buffer provided')
|
||||
}
|
||||
|
||||
let parseOfficeAsync
|
||||
try {
|
||||
const officeParser = await import('officeparser')
|
||||
parseOfficeAsync = officeParser.parseOfficeAsync
|
||||
} catch (importError) {
|
||||
logger.warn('officeparser not available, using fallback extraction')
|
||||
return this.fallbackExtraction(buffer)
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await parseOfficeAsync(buffer)
|
||||
|
||||
if (!result || typeof result !== 'string') {
|
||||
throw new Error('officeparser returned invalid result')
|
||||
}
|
||||
|
||||
const content = sanitizeTextForUTF8(result.trim())
|
||||
|
||||
logger.info('PowerPoint parsing completed successfully with officeparser')
|
||||
|
||||
return {
|
||||
content: content,
|
||||
metadata: {
|
||||
characterCount: content.length,
|
||||
extractionMethod: 'officeparser',
|
||||
},
|
||||
}
|
||||
} catch (extractError) {
|
||||
logger.warn('officeparser failed, using fallback:', extractError)
|
||||
return this.fallbackExtraction(buffer)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('PowerPoint buffer parsing error:', error)
|
||||
throw new Error(`Failed to parse PowerPoint buffer: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
|
||||
private fallbackExtraction(buffer: Buffer): FileParseResult {
|
||||
logger.info('Using fallback text extraction for PowerPoint file')
|
||||
|
||||
const text = buffer.toString('utf8', 0, Math.min(buffer.length, 200000))
|
||||
|
||||
const readableText = text
|
||||
.match(/[\x20-\x7E\s]{4,}/g)
|
||||
?.filter(
|
||||
(chunk) =>
|
||||
chunk.trim().length > 10 &&
|
||||
/[a-zA-Z]/.test(chunk) &&
|
||||
!/^[\x00-\x1F]*$/.test(chunk) &&
|
||||
!/^[^\w\s]*$/.test(chunk)
|
||||
)
|
||||
.join(' ')
|
||||
.replace(/\s+/g, ' ')
|
||||
.trim()
|
||||
|
||||
const content = readableText
|
||||
? sanitizeTextForUTF8(readableText)
|
||||
: 'Unable to extract text from PowerPoint file. Please ensure the file contains readable text content.'
|
||||
|
||||
return {
|
||||
content,
|
||||
metadata: {
|
||||
extractionMethod: 'fallback',
|
||||
characterCount: content.length,
|
||||
warning: 'Basic text extraction used',
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,14 +6,9 @@ import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('RawPdfParser')
|
||||
|
||||
// Promisify zlib functions
|
||||
const inflateAsync = promisify(zlib.inflate)
|
||||
const unzipAsync = promisify(zlib.unzip)
|
||||
|
||||
/**
|
||||
* A simple PDF parser that extracts readable text from a PDF file.
|
||||
* This is used as a fallback when the pdf-parse library fails.
|
||||
*/
|
||||
export class RawPdfParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
@@ -23,7 +18,6 @@ export class RawPdfParser implements FileParser {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Read the file
|
||||
logger.info('Reading file...')
|
||||
const dataBuffer = await readFile(filePath)
|
||||
logger.info('File read successfully, size:', dataBuffer.length)
|
||||
@@ -46,31 +40,22 @@ export class RawPdfParser implements FileParser {
|
||||
try {
|
||||
logger.info('Starting to parse buffer, size:', dataBuffer.length)
|
||||
|
||||
// Instead of trying to parse the binary PDF data directly,
|
||||
// we'll extract only the text sections that are readable
|
||||
|
||||
// First convert to string but only for pattern matching, not for display
|
||||
const rawContent = dataBuffer.toString('utf-8')
|
||||
|
||||
// Extract basic PDF info
|
||||
let version = 'Unknown'
|
||||
let pageCount = 0
|
||||
|
||||
// Try to extract PDF version
|
||||
const versionMatch = rawContent.match(/%PDF-(\d+\.\d+)/)
|
||||
if (versionMatch?.[1]) {
|
||||
version = versionMatch[1]
|
||||
}
|
||||
|
||||
// Count pages using multiple methods for redundancy
|
||||
// Method 1: Count "/Type /Page" occurrences (most reliable)
|
||||
const typePageMatches = rawContent.match(/\/Type\s*\/Page\b/gi)
|
||||
if (typePageMatches) {
|
||||
pageCount = typePageMatches.length
|
||||
logger.info('Found page count using /Type /Page:', pageCount)
|
||||
}
|
||||
|
||||
// Method 2: Look for "/Page" dictionary references
|
||||
if (pageCount === 0) {
|
||||
const pageMatches = rawContent.match(/\/Page\s*\//gi)
|
||||
if (pageMatches) {
|
||||
@@ -79,19 +64,15 @@ export class RawPdfParser implements FileParser {
|
||||
}
|
||||
}
|
||||
|
||||
// Method 3: Look for "/Pages" object references
|
||||
if (pageCount === 0) {
|
||||
const pagesObjMatches = rawContent.match(/\/Pages\s+\d+\s+\d+\s+R/gi)
|
||||
if (pagesObjMatches && pagesObjMatches.length > 0) {
|
||||
// Extract the object reference
|
||||
const pagesObjRef = pagesObjMatches[0].match(/\/Pages\s+(\d+)\s+\d+\s+R/i)
|
||||
if (pagesObjRef?.[1]) {
|
||||
const objNum = pagesObjRef[1]
|
||||
// Find the referenced object
|
||||
const objRegex = new RegExp(`${objNum}\\s+0\\s+obj[\\s\\S]*?endobj`, 'i')
|
||||
const objMatch = rawContent.match(objRegex)
|
||||
if (objMatch) {
|
||||
// Look for /Count within the Pages object
|
||||
const countMatch = objMatch[0].match(/\/Count\s+(\d+)/i)
|
||||
if (countMatch?.[1]) {
|
||||
pageCount = Number.parseInt(countMatch[1], 10)
|
||||
@@ -102,50 +83,40 @@ export class RawPdfParser implements FileParser {
|
||||
}
|
||||
}
|
||||
|
||||
// Method 4: Count trailer references to get an approximate count
|
||||
if (pageCount === 0) {
|
||||
const trailerMatches = rawContent.match(/trailer/gi)
|
||||
if (trailerMatches) {
|
||||
// This is just a rough estimate, not accurate
|
||||
pageCount = Math.max(1, Math.ceil(trailerMatches.length / 2))
|
||||
logger.info('Estimated page count using trailer references:', pageCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Default to at least 1 page if we couldn't find any
|
||||
if (pageCount === 0) {
|
||||
pageCount = 1
|
||||
logger.info('Defaulting to 1 page as no count was found')
|
||||
}
|
||||
|
||||
// Extract text content using text markers commonly found in PDFs
|
||||
let extractedText = ''
|
||||
|
||||
// Method 1: Extract text between BT (Begin Text) and ET (End Text) markers
|
||||
const textMatches = rawContent.match(/BT[\s\S]*?ET/g)
|
||||
if (textMatches && textMatches.length > 0) {
|
||||
logger.info('Found', textMatches.length, 'text blocks')
|
||||
|
||||
extractedText = textMatches
|
||||
.map((textBlock) => {
|
||||
// Extract text objects (Tj, TJ) from the text block
|
||||
const textObjects = textBlock.match(/(\([^)]*\)|\[[^\]]*\])\s*(Tj|TJ)/g)
|
||||
if (textObjects && textObjects.length > 0) {
|
||||
return textObjects
|
||||
.map((obj) => {
|
||||
// Clean up text objects
|
||||
let text = ''
|
||||
if (obj.includes('Tj')) {
|
||||
// Handle Tj operator (simple string)
|
||||
const match = obj.match(/\(([^)]*)\)\s*Tj/)
|
||||
if (match?.[1]) {
|
||||
text = match[1]
|
||||
}
|
||||
} else if (obj.includes('TJ')) {
|
||||
// Handle TJ operator (array of strings and positioning)
|
||||
const match = obj.match(/\[(.*)\]\s*TJ/)
|
||||
if (match?.[1]) {
|
||||
// Extract only the string parts from the array
|
||||
const parts = match[1].match(/\([^)]*\)/g)
|
||||
if (parts) {
|
||||
text = parts.map((p) => p.slice(1, -1)).join(' ')
|
||||
@@ -153,7 +124,6 @@ export class RawPdfParser implements FileParser {
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up PDF escape sequences
|
||||
return text
|
||||
.replace(/\\(\d{3})/g, (_, octal) =>
|
||||
String.fromCharCode(Number.parseInt(octal, 8))
|
||||
@@ -170,50 +140,42 @@ export class RawPdfParser implements FileParser {
|
||||
.trim()
|
||||
}
|
||||
|
||||
// Try to extract metadata from XML
|
||||
let metadataText = ''
|
||||
const xmlMatch = rawContent.match(/<x:xmpmeta[\s\S]*?<\/x:xmpmeta>/)
|
||||
if (xmlMatch) {
|
||||
const xmlContent = xmlMatch[0]
|
||||
logger.info('Found XML metadata')
|
||||
|
||||
// Extract document title
|
||||
const titleMatch = xmlContent.match(/<dc:title>[\s\S]*?<rdf:li[^>]*>(.*?)<\/rdf:li>/i)
|
||||
if (titleMatch?.[1]) {
|
||||
const title = titleMatch[1].replace(/<[^>]+>/g, '').trim()
|
||||
metadataText += `Document Title: ${title}\n\n`
|
||||
}
|
||||
|
||||
// Extract creator/author
|
||||
const creatorMatch = xmlContent.match(/<dc:creator>[\s\S]*?<rdf:li[^>]*>(.*?)<\/rdf:li>/i)
|
||||
if (creatorMatch?.[1]) {
|
||||
const creator = creatorMatch[1].replace(/<[^>]+>/g, '').trim()
|
||||
metadataText += `Author: ${creator}\n`
|
||||
}
|
||||
|
||||
// Extract creation date
|
||||
const dateMatch = xmlContent.match(/<xmp:CreateDate>(.*?)<\/xmp:CreateDate>/i)
|
||||
if (dateMatch?.[1]) {
|
||||
metadataText += `Created: ${dateMatch[1].trim()}\n`
|
||||
}
|
||||
|
||||
// Extract producer
|
||||
const producerMatch = xmlContent.match(/<pdf:Producer>(.*?)<\/pdf:Producer>/i)
|
||||
if (producerMatch?.[1]) {
|
||||
metadataText += `Producer: ${producerMatch[1].trim()}\n`
|
||||
}
|
||||
}
|
||||
|
||||
// Try to extract actual text content from content streams
|
||||
if (!extractedText || extractedText.length < 100 || extractedText.includes('/Type /Page')) {
|
||||
logger.info('Trying advanced text extraction from content streams')
|
||||
|
||||
// Find content stream references
|
||||
const contentRefs = rawContent.match(/\/Contents\s+\[?\s*(\d+)\s+\d+\s+R\s*\]?/g)
|
||||
if (contentRefs && contentRefs.length > 0) {
|
||||
logger.info('Found', contentRefs.length, 'content stream references')
|
||||
|
||||
// Extract object numbers from content references
|
||||
const objNumbers = contentRefs
|
||||
.map((ref) => {
|
||||
const match = ref.match(/\/Contents\s+\[?\s*(\d+)\s+\d+\s+R\s*\]?/)
|
||||
@@ -223,7 +185,6 @@ export class RawPdfParser implements FileParser {
|
||||
|
||||
logger.info('Content stream object numbers:', objNumbers)
|
||||
|
||||
// Try to find those objects in the content
|
||||
if (objNumbers.length > 0) {
|
||||
let textFromStreams = ''
|
||||
|
||||
@@ -232,12 +193,10 @@ export class RawPdfParser implements FileParser {
|
||||
const objMatch = rawContent.match(objRegex)
|
||||
|
||||
if (objMatch) {
|
||||
// Look for stream content within the object
|
||||
const streamMatch = objMatch[0].match(/stream\r?\n([\s\S]*?)\r?\nendstream/)
|
||||
if (streamMatch?.[1]) {
|
||||
const streamContent = streamMatch[1]
|
||||
|
||||
// Look for text operations in the stream (Tj, TJ, etc.)
|
||||
const textFragments = streamContent.match(/\([^)]+\)\s*Tj|\[[^\]]*\]\s*TJ/g)
|
||||
if (textFragments && textFragments.length > 0) {
|
||||
const extractedFragments = textFragments
|
||||
@@ -290,35 +249,27 @@ export class RawPdfParser implements FileParser {
|
||||
}
|
||||
}
|
||||
|
||||
// Try to decompress PDF streams
|
||||
// This is especially helpful for PDFs with compressed content
|
||||
if (!extractedText || extractedText.length < 100) {
|
||||
logger.info('Trying to decompress PDF streams')
|
||||
|
||||
// Find compressed streams (FlateDecode)
|
||||
const compressedStreams = rawContent.match(
|
||||
/\/Filter\s*\/FlateDecode[\s\S]*?stream[\s\S]*?endstream/g
|
||||
)
|
||||
if (compressedStreams && compressedStreams.length > 0) {
|
||||
logger.info('Found', compressedStreams.length, 'compressed streams')
|
||||
|
||||
// Process each stream
|
||||
const decompressedContents = await Promise.all(
|
||||
compressedStreams.map(async (stream) => {
|
||||
try {
|
||||
// Extract stream content between stream and endstream
|
||||
const streamMatch = stream.match(/stream\r?\n([\s\S]*?)\r?\nendstream/)
|
||||
if (!streamMatch || !streamMatch[1]) return ''
|
||||
|
||||
const compressedData = Buffer.from(streamMatch[1], 'binary')
|
||||
|
||||
// Try different decompression methods
|
||||
try {
|
||||
// Try inflate (most common)
|
||||
const decompressed = await inflateAsync(compressedData)
|
||||
const content = decompressed.toString('utf-8')
|
||||
|
||||
// Check if it contains readable text
|
||||
const readable = content.replace(/[^\x20-\x7E\r\n]/g, ' ').trim()
|
||||
if (
|
||||
readable.length > 50 &&
|
||||
@@ -329,12 +280,10 @@ export class RawPdfParser implements FileParser {
|
||||
return readable
|
||||
}
|
||||
} catch (_inflateErr) {
|
||||
// Try unzip as fallback
|
||||
try {
|
||||
const decompressed = await unzipAsync(compressedData)
|
||||
const content = decompressed.toString('utf-8')
|
||||
|
||||
// Check if it contains readable text
|
||||
const readable = content.replace(/[^\x20-\x7E\r\n]/g, ' ').trim()
|
||||
if (
|
||||
readable.length > 50 &&
|
||||
@@ -345,12 +294,10 @@ export class RawPdfParser implements FileParser {
|
||||
return readable
|
||||
}
|
||||
} catch (_unzipErr) {
|
||||
// Both methods failed, continue to next stream
|
||||
return ''
|
||||
}
|
||||
}
|
||||
} catch (_error) {
|
||||
// Error processing this stream, skip it
|
||||
return ''
|
||||
}
|
||||
|
||||
@@ -358,7 +305,6 @@ export class RawPdfParser implements FileParser {
|
||||
})
|
||||
)
|
||||
|
||||
// Filter out empty results and combine
|
||||
const decompressedText = decompressedContents
|
||||
.filter((text) => text && text.length > 0)
|
||||
.join('\n\n')
|
||||
@@ -370,26 +316,19 @@ export class RawPdfParser implements FileParser {
|
||||
}
|
||||
}
|
||||
|
||||
// Method 2: Look for text stream data
|
||||
if (!extractedText || extractedText.length < 50) {
|
||||
logger.info('Trying alternative text extraction method with streams')
|
||||
|
||||
// Find text streams
|
||||
const streamMatches = rawContent.match(/stream[\s\S]*?endstream/g)
|
||||
if (streamMatches && streamMatches.length > 0) {
|
||||
logger.info('Found', streamMatches.length, 'streams')
|
||||
|
||||
// Process each stream to look for text content
|
||||
const textContent = streamMatches
|
||||
.map((stream) => {
|
||||
// Remove 'stream' and 'endstream' markers
|
||||
const content = stream.replace(/^stream\r?\n|\r?\nendstream$/g, '')
|
||||
|
||||
// Look for readable ASCII text (more strict heuristic)
|
||||
// Only keep ASCII printable characters
|
||||
const readable = content.replace(/[^\x20-\x7E\r\n]/g, ' ').trim()
|
||||
|
||||
// Only keep content that looks like real text (has spaces, periods, etc.)
|
||||
if (
|
||||
readable.length > 20 &&
|
||||
readable.includes(' ') &&
|
||||
@@ -400,7 +339,7 @@ export class RawPdfParser implements FileParser {
|
||||
}
|
||||
return ''
|
||||
})
|
||||
.filter((text) => text.length > 0 && text.split(' ').length > 5) // Must have at least 5 words
|
||||
.filter((text) => text.length > 0 && text.split(' ').length > 5)
|
||||
.join('\n\n')
|
||||
|
||||
if (textContent.length > 0) {
|
||||
@@ -409,22 +348,17 @@ export class RawPdfParser implements FileParser {
|
||||
}
|
||||
}
|
||||
|
||||
// Method 3: Look for object streams
|
||||
if (!extractedText || extractedText.length < 50) {
|
||||
logger.info('Trying object streams for text')
|
||||
|
||||
// Find object stream content
|
||||
const objMatches = rawContent.match(/\d+\s+\d+\s+obj[\s\S]*?endobj/g)
|
||||
if (objMatches && objMatches.length > 0) {
|
||||
logger.info('Found', objMatches.length, 'objects')
|
||||
|
||||
// Process objects looking for text content
|
||||
const textContent = objMatches
|
||||
.map((obj) => {
|
||||
// Find readable text in the object - only keep ASCII printable characters
|
||||
const readable = obj.replace(/[^\x20-\x7E\r\n]/g, ' ').trim()
|
||||
|
||||
// Only include if it looks like actual text (strict heuristic)
|
||||
if (
|
||||
readable.length > 50 &&
|
||||
readable.includes(' ') &&
|
||||
@@ -445,8 +379,6 @@ export class RawPdfParser implements FileParser {
|
||||
}
|
||||
}
|
||||
|
||||
// If what we extracted is just PDF structure information rather than readable text,
|
||||
// provide a clearer message
|
||||
if (
|
||||
extractedText &&
|
||||
(extractedText.includes('endobj') ||
|
||||
@@ -459,53 +391,41 @@ export class RawPdfParser implements FileParser {
|
||||
)
|
||||
extractedText = metadataText
|
||||
} else if (metadataText && !extractedText.includes('Document Title:')) {
|
||||
// Prepend metadata to extracted text if available
|
||||
extractedText = metadataText + (extractedText ? `\n\n${extractedText}` : '')
|
||||
}
|
||||
|
||||
// Validate that the extracted text looks meaningful
|
||||
// Count how many recognizable words/characters it contains
|
||||
const validCharCount = (extractedText || '').replace(/[^\x20-\x7E\r\n]/g, '').length
|
||||
const totalCharCount = (extractedText || '').length
|
||||
const validRatio = validCharCount / (totalCharCount || 1)
|
||||
|
||||
// Check for common PDF artifacts that indicate binary corruption
|
||||
const hasBinaryArtifacts =
|
||||
extractedText &&
|
||||
(extractedText.includes('\\u') ||
|
||||
extractedText.includes('\\x') ||
|
||||
extractedText.includes('\0') ||
|
||||
/[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\xFF]{10,}/g.test(extractedText) ||
|
||||
validRatio < 0.7) // Less than 70% valid characters
|
||||
validRatio < 0.7)
|
||||
|
||||
// Check if the content looks like gibberish
|
||||
const looksLikeGibberish =
|
||||
extractedText &&
|
||||
// Too many special characters
|
||||
(extractedText.replace(/[a-zA-Z0-9\s.,:'"()[\]{}]/g, '').length / extractedText.length >
|
||||
0.3 ||
|
||||
// Not enough spaces (real text has spaces between words)
|
||||
extractedText.split(' ').length < extractedText.length / 20)
|
||||
|
||||
// If no text was extracted, or if it's binary/gibberish,
|
||||
// provide a helpful message instead
|
||||
if (!extractedText || extractedText.length < 50 || hasBinaryArtifacts || looksLikeGibberish) {
|
||||
logger.info('Could not extract meaningful text, providing fallback message')
|
||||
logger.info('Valid character ratio:', validRatio)
|
||||
logger.info('Has binary artifacts:', hasBinaryArtifacts)
|
||||
logger.info('Looks like gibberish:', looksLikeGibberish)
|
||||
|
||||
// Start with metadata if available
|
||||
if (metadataText) {
|
||||
extractedText = `${metadataText}\n`
|
||||
} else {
|
||||
extractedText = ''
|
||||
}
|
||||
|
||||
// Add basic PDF info
|
||||
extractedText += `This is a PDF document with ${pageCount} page(s) and version ${version}.\n\n`
|
||||
|
||||
// Try to find a title in the PDF structure that we might have missed
|
||||
const titleInStructure =
|
||||
rawContent.match(/title\s*:\s*([^\n]+)/i) ||
|
||||
rawContent.match(/Microsoft Word -\s*([^\n]+)/i)
|
||||
|
||||
@@ -8,15 +8,12 @@ const logger = createLogger('TxtParser')
|
||||
export class TxtParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
// Validate input
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Read the file
|
||||
const buffer = await readFile(filePath)
|
||||
|
||||
// Use parseBuffer for consistent implementation
|
||||
return this.parseBuffer(buffer)
|
||||
} catch (error) {
|
||||
logger.error('TXT file error:', error)
|
||||
@@ -28,7 +25,6 @@ export class TxtParser implements FileParser {
|
||||
try {
|
||||
logger.info('Parsing buffer, size:', buffer.length)
|
||||
|
||||
// Extract content and sanitize for UTF-8 storage
|
||||
const rawContent = buffer.toString('utf-8')
|
||||
const result = sanitizeTextForUTF8(rawContent)
|
||||
|
||||
|
||||
@@ -8,4 +8,16 @@ export interface FileParser {
|
||||
parseBuffer?(buffer: Buffer): Promise<FileParseResult>
|
||||
}
|
||||
|
||||
export type SupportedFileType = 'pdf' | 'csv' | 'doc' | 'docx' | 'txt' | 'md' | 'xlsx' | 'xls'
|
||||
export type SupportedFileType =
|
||||
| 'pdf'
|
||||
| 'csv'
|
||||
| 'doc'
|
||||
| 'docx'
|
||||
| 'txt'
|
||||
| 'md'
|
||||
| 'xlsx'
|
||||
| 'xls'
|
||||
| 'html'
|
||||
| 'htm'
|
||||
| 'pptx'
|
||||
| 'ppt'
|
||||
|
||||
@@ -9,19 +9,16 @@ const logger = createLogger('XlsxParser')
|
||||
export class XlsxParser implements FileParser {
|
||||
async parseFile(filePath: string): Promise<FileParseResult> {
|
||||
try {
|
||||
// Validate input
|
||||
if (!filePath) {
|
||||
throw new Error('No file path provided')
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if (!existsSync(filePath)) {
|
||||
throw new Error(`File not found: ${filePath}`)
|
||||
}
|
||||
|
||||
logger.info(`Parsing XLSX file: ${filePath}`)
|
||||
|
||||
// Read the workbook
|
||||
const workbook = XLSX.readFile(filePath)
|
||||
return this.processWorkbook(workbook)
|
||||
} catch (error) {
|
||||
@@ -38,7 +35,6 @@ export class XlsxParser implements FileParser {
|
||||
throw new Error('Empty buffer provided')
|
||||
}
|
||||
|
||||
// Read the workbook from buffer
|
||||
const workbook = XLSX.read(buffer, { type: 'buffer' })
|
||||
return this.processWorkbook(workbook)
|
||||
} catch (error) {
|
||||
@@ -53,25 +49,20 @@ export class XlsxParser implements FileParser {
|
||||
let content = ''
|
||||
let totalRows = 0
|
||||
|
||||
// Process each worksheet
|
||||
for (const sheetName of sheetNames) {
|
||||
const worksheet = workbook.Sheets[sheetName]
|
||||
|
||||
// Convert to array of objects
|
||||
const sheetData = XLSX.utils.sheet_to_json(worksheet, { header: 1 })
|
||||
sheets[sheetName] = sheetData
|
||||
totalRows += sheetData.length
|
||||
|
||||
// Add sheet content to the overall content string (clean sheet name)
|
||||
const cleanSheetName = sanitizeTextForUTF8(sheetName)
|
||||
content += `Sheet: ${cleanSheetName}\n`
|
||||
content += `=${'='.repeat(cleanSheetName.length + 6)}\n\n`
|
||||
|
||||
if (sheetData.length > 0) {
|
||||
// Process each row
|
||||
sheetData.forEach((row: unknown, rowIndex: number) => {
|
||||
if (Array.isArray(row) && row.length > 0) {
|
||||
// Convert row to string, handling undefined/null values and cleaning non-UTF8 characters
|
||||
const rowString = row
|
||||
.map((cell) => {
|
||||
if (cell === null || cell === undefined) {
|
||||
@@ -93,7 +84,6 @@ export class XlsxParser implements FileParser {
|
||||
|
||||
logger.info(`XLSX parsing completed: ${sheetNames.length} sheets, ${totalRows} total rows`)
|
||||
|
||||
// Final cleanup of the entire content to ensure UTF-8 compatibility
|
||||
const cleanContent = sanitizeTextForUTF8(content).trim()
|
||||
|
||||
return {
|
||||
|
||||
24
apps/sim/lib/knowledge/consts.ts
Normal file
24
apps/sim/lib/knowledge/consts.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
export const TAG_SLOT_CONFIG = {
|
||||
text: {
|
||||
slots: ['tag1', 'tag2', 'tag3', 'tag4', 'tag5', 'tag6', 'tag7'] as const,
|
||||
maxSlots: 7,
|
||||
},
|
||||
} as const
|
||||
|
||||
export const SUPPORTED_FIELD_TYPES = Object.keys(TAG_SLOT_CONFIG) as Array<
|
||||
keyof typeof TAG_SLOT_CONFIG
|
||||
>
|
||||
|
||||
export const TAG_SLOTS = TAG_SLOT_CONFIG.text.slots
|
||||
|
||||
export const MAX_TAG_SLOTS = TAG_SLOT_CONFIG.text.maxSlots
|
||||
|
||||
export type TagSlot = (typeof TAG_SLOTS)[number]
|
||||
|
||||
export function getSlotsForFieldType(fieldType: string): readonly string[] {
|
||||
const config = TAG_SLOT_CONFIG[fieldType as keyof typeof TAG_SLOT_CONFIG]
|
||||
if (!config) {
|
||||
return []
|
||||
}
|
||||
return config.slots
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
import crypto, { randomUUID } from 'crypto'
|
||||
import { tasks } from '@trigger.dev/sdk'
|
||||
import { and, asc, desc, eq, inArray, isNull, sql } from 'drizzle-orm'
|
||||
import { getSlotsForFieldType, type TAG_SLOT_CONFIG } from '@/lib/constants/knowledge'
|
||||
import { generateEmbeddings } from '@/lib/embeddings/utils'
|
||||
import { env } from '@/lib/env'
|
||||
import { getSlotsForFieldType, type TAG_SLOT_CONFIG } from '@/lib/knowledge/consts'
|
||||
import { processDocument } from '@/lib/knowledge/documents/document-processor'
|
||||
import { getNextAvailableSlot } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -17,8 +17,8 @@ import type { DocumentSortField, SortOrder } from './types'
|
||||
const logger = createLogger('DocumentService')
|
||||
|
||||
const TIMEOUTS = {
|
||||
OVERALL_PROCESSING: 600000,
|
||||
EMBEDDINGS_API: 180000,
|
||||
OVERALL_PROCESSING: (env.KB_CONFIG_MAX_DURATION || 300) * 1000,
|
||||
EMBEDDINGS_API: (env.KB_CONFIG_MAX_TIMEOUT || 10000) * 18,
|
||||
} as const
|
||||
|
||||
/**
|
||||
@@ -38,17 +38,17 @@ function withTimeout<T>(
|
||||
}
|
||||
|
||||
const PROCESSING_CONFIG = {
|
||||
maxConcurrentDocuments: 4,
|
||||
batchSize: 10,
|
||||
delayBetweenBatches: 200,
|
||||
delayBetweenDocuments: 100,
|
||||
maxConcurrentDocuments: Math.max(1, Math.floor((env.KB_CONFIG_CONCURRENCY_LIMIT || 20) / 5)) || 4,
|
||||
batchSize: Math.max(1, Math.floor((env.KB_CONFIG_BATCH_SIZE || 20) / 2)) || 10,
|
||||
delayBetweenBatches: (env.KB_CONFIG_DELAY_BETWEEN_BATCHES || 100) * 2,
|
||||
delayBetweenDocuments: (env.KB_CONFIG_DELAY_BETWEEN_DOCUMENTS || 50) * 2,
|
||||
}
|
||||
|
||||
const REDIS_PROCESSING_CONFIG = {
|
||||
maxConcurrentDocuments: 12,
|
||||
batchSize: 20,
|
||||
delayBetweenBatches: 100,
|
||||
delayBetweenDocuments: 50,
|
||||
maxConcurrentDocuments: env.KB_CONFIG_CONCURRENCY_LIMIT || 20,
|
||||
batchSize: env.KB_CONFIG_BATCH_SIZE || 20,
|
||||
delayBetweenBatches: env.KB_CONFIG_DELAY_BETWEEN_BATCHES || 100,
|
||||
delayBetweenDocuments: env.KB_CONFIG_DELAY_BETWEEN_DOCUMENTS || 50,
|
||||
}
|
||||
|
||||
let documentQueue: DocumentProcessingQueue | null = null
|
||||
@@ -59,8 +59,8 @@ export function getDocumentQueue(): DocumentProcessingQueue {
|
||||
const config = redisClient ? REDIS_PROCESSING_CONFIG : PROCESSING_CONFIG
|
||||
documentQueue = new DocumentProcessingQueue({
|
||||
maxConcurrent: config.maxConcurrentDocuments,
|
||||
retryDelay: 2000,
|
||||
maxRetries: 5,
|
||||
retryDelay: env.KB_CONFIG_MIN_TIMEOUT || 1000,
|
||||
maxRetries: env.KB_CONFIG_MAX_ATTEMPTS || 3,
|
||||
})
|
||||
}
|
||||
return documentQueue
|
||||
|
||||
@@ -4,7 +4,7 @@ import {
|
||||
getSlotsForFieldType,
|
||||
SUPPORTED_FIELD_TYPES,
|
||||
type TAG_SLOT_CONFIG,
|
||||
} from '@/lib/constants/knowledge'
|
||||
} from '@/lib/knowledge/consts'
|
||||
import type { BulkTagDefinitionsData, DocumentTagDefinition } from '@/lib/knowledge/tags/types'
|
||||
import type {
|
||||
CreateTagDefinitionData,
|
||||
|
||||
@@ -7,6 +7,7 @@ vi.mock('@/db', () => ({
|
||||
where: vi.fn(),
|
||||
limit: vi.fn(),
|
||||
innerJoin: vi.fn(),
|
||||
leftJoin: vi.fn(),
|
||||
orderBy: vi.fn(),
|
||||
},
|
||||
}))
|
||||
@@ -17,6 +18,7 @@ vi.mock('@/db/schema', () => ({
|
||||
userId: 'user_id',
|
||||
entityType: 'entity_type',
|
||||
entityId: 'entity_id',
|
||||
id: 'permission_id',
|
||||
},
|
||||
permissionTypeEnum: {
|
||||
enumValues: ['admin', 'write', 'read'] as const,
|
||||
@@ -25,23 +27,18 @@ vi.mock('@/db/schema', () => ({
|
||||
id: 'user_id',
|
||||
email: 'user_email',
|
||||
name: 'user_name',
|
||||
image: 'user_image',
|
||||
},
|
||||
workspace: {
|
||||
id: 'workspace_id',
|
||||
name: 'workspace_name',
|
||||
ownerId: 'workspace_owner_id',
|
||||
},
|
||||
member: {
|
||||
userId: 'member_user_id',
|
||||
organizationId: 'member_organization_id',
|
||||
role: 'member_role',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('drizzle-orm', () => ({
|
||||
and: vi.fn().mockReturnValue('and-condition'),
|
||||
eq: vi.fn().mockReturnValue('eq-condition'),
|
||||
or: vi.fn().mockReturnValue('or-condition'),
|
||||
}))
|
||||
|
||||
import {
|
||||
@@ -50,8 +47,6 @@ import {
|
||||
getUsersWithPermissions,
|
||||
hasAdminPermission,
|
||||
hasWorkspaceAdminAccess,
|
||||
isOrganizationAdminForWorkspace,
|
||||
isOrganizationOwnerOrAdmin,
|
||||
} from '@/lib/permissions/utils'
|
||||
import { db } from '@/db'
|
||||
|
||||
@@ -124,11 +119,64 @@ describe('Permission Utils', () => {
|
||||
|
||||
expect(result).toBe('admin')
|
||||
})
|
||||
|
||||
it('should return write permission when user only has write access', async () => {
|
||||
const mockResults = [{ permissionType: 'write' as PermissionType }]
|
||||
const chain = createMockChain(mockResults)
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await getUserEntityPermissions('user123', 'workspace', 'workspace456')
|
||||
|
||||
expect(result).toBe('write')
|
||||
})
|
||||
|
||||
it('should prioritize write over read permissions', async () => {
|
||||
const mockResults = [
|
||||
{ permissionType: 'read' as PermissionType },
|
||||
{ permissionType: 'write' as PermissionType },
|
||||
]
|
||||
const chain = createMockChain(mockResults)
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await getUserEntityPermissions('user123', 'workspace', 'workspace456')
|
||||
|
||||
expect(result).toBe('write')
|
||||
})
|
||||
|
||||
it('should work with workflow entity type', async () => {
|
||||
const mockResults = [{ permissionType: 'admin' as PermissionType }]
|
||||
const chain = createMockChain(mockResults)
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await getUserEntityPermissions('user123', 'workflow', 'workflow789')
|
||||
|
||||
expect(result).toBe('admin')
|
||||
})
|
||||
|
||||
it('should work with organization entity type', async () => {
|
||||
const mockResults = [{ permissionType: 'read' as PermissionType }]
|
||||
const chain = createMockChain(mockResults)
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await getUserEntityPermissions('user123', 'organization', 'org456')
|
||||
|
||||
expect(result).toBe('read')
|
||||
})
|
||||
|
||||
it('should handle generic entity types', async () => {
|
||||
const mockResults = [{ permissionType: 'write' as PermissionType }]
|
||||
const chain = createMockChain(mockResults)
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await getUserEntityPermissions('user123', 'custom_entity', 'entity123')
|
||||
|
||||
expect(result).toBe('write')
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasAdminPermission', () => {
|
||||
it('should return true when user has admin permission for workspace', async () => {
|
||||
const chain = createMockChain([{ permissionType: 'admin' }])
|
||||
const chain = createMockChain([{ id: 'perm1' }])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await hasAdminPermission('admin-user', 'workspace123')
|
||||
@@ -144,6 +192,42 @@ describe('Permission Utils', () => {
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when user has write permission but not admin', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await hasAdminPermission('write-user', 'workspace123')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when user has read permission but not admin', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await hasAdminPermission('read-user', 'workspace123')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle non-existent workspace', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await hasAdminPermission('user123', 'non-existent-workspace')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle empty user ID', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await hasAdminPermission('', 'workspace123')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getUsersWithPermissions', () => {
|
||||
@@ -162,7 +246,6 @@ describe('Permission Utils', () => {
|
||||
userId: 'user1',
|
||||
email: 'alice@example.com',
|
||||
name: 'Alice Smith',
|
||||
image: 'https://example.com/alice.jpg',
|
||||
permissionType: 'admin' as PermissionType,
|
||||
},
|
||||
]
|
||||
@@ -177,43 +260,66 @@ describe('Permission Utils', () => {
|
||||
userId: 'user1',
|
||||
email: 'alice@example.com',
|
||||
name: 'Alice Smith',
|
||||
image: 'https://example.com/alice.jpg',
|
||||
permissionType: 'admin',
|
||||
},
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('isOrganizationAdminForWorkspace', () => {
|
||||
it('should return false when workspace does not exist', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
it('should return multiple users with different permission levels', async () => {
|
||||
const mockUsersResults = [
|
||||
{
|
||||
userId: 'user1',
|
||||
email: 'admin@example.com',
|
||||
name: 'Admin User',
|
||||
permissionType: 'admin' as PermissionType,
|
||||
},
|
||||
{
|
||||
userId: 'user2',
|
||||
email: 'writer@example.com',
|
||||
name: 'Writer User',
|
||||
permissionType: 'write' as PermissionType,
|
||||
},
|
||||
{
|
||||
userId: 'user3',
|
||||
email: 'reader@example.com',
|
||||
name: 'Reader User',
|
||||
permissionType: 'read' as PermissionType,
|
||||
},
|
||||
]
|
||||
|
||||
const result = await isOrganizationAdminForWorkspace('user123', 'workspace456')
|
||||
const usersChain = createMockChain(mockUsersResults)
|
||||
mockDb.select.mockReturnValue(usersChain)
|
||||
|
||||
expect(result).toBe(false)
|
||||
const result = await getUsersWithPermissions('workspace456')
|
||||
|
||||
expect(result).toHaveLength(3)
|
||||
expect(result[0].permissionType).toBe('admin')
|
||||
expect(result[1].permissionType).toBe('write')
|
||||
expect(result[2].permissionType).toBe('read')
|
||||
})
|
||||
|
||||
it('should return false when user has no organization memberships', async () => {
|
||||
// Mock workspace exists, but user has no org memberships
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([{ ownerId: 'workspace-owner-123' }])
|
||||
}
|
||||
return createMockChain([]) // No memberships
|
||||
})
|
||||
it('should handle users with empty names', async () => {
|
||||
const mockUsersResults = [
|
||||
{
|
||||
userId: 'user1',
|
||||
email: 'test@example.com',
|
||||
name: '',
|
||||
permissionType: 'read' as PermissionType,
|
||||
},
|
||||
]
|
||||
|
||||
const result = await isOrganizationAdminForWorkspace('user123', 'workspace456')
|
||||
const usersChain = createMockChain(mockUsersResults)
|
||||
mockDb.select.mockReturnValue(usersChain)
|
||||
|
||||
expect(result).toBe(false)
|
||||
const result = await getUsersWithPermissions('workspace123')
|
||||
|
||||
expect(result[0].name).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasWorkspaceAdminAccess', () => {
|
||||
it('should return true when user has direct admin permission', async () => {
|
||||
const chain = createMockChain([{ permissionType: 'admin' }])
|
||||
it('should return true when user owns the workspace', async () => {
|
||||
const chain = createMockChain([{ ownerId: 'user123' }])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await hasWorkspaceAdminAccess('user123', 'workspace456')
|
||||
@@ -221,7 +327,22 @@ describe('Permission Utils', () => {
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when user has neither direct nor organization admin access', async () => {
|
||||
it('should return true when user has direct admin permission', async () => {
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([{ ownerId: 'other-user' }])
|
||||
}
|
||||
return createMockChain([{ id: 'perm1' }])
|
||||
})
|
||||
|
||||
const result = await hasWorkspaceAdminAccess('user123', 'workspace456')
|
||||
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when workspace does not exist', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
@@ -229,51 +350,137 @@ describe('Permission Utils', () => {
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when user has no admin access', async () => {
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([{ ownerId: 'other-user' }])
|
||||
}
|
||||
return createMockChain([])
|
||||
})
|
||||
|
||||
const result = await hasWorkspaceAdminAccess('user123', 'workspace456')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when user has write permission but not admin', async () => {
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([{ ownerId: 'other-user' }])
|
||||
}
|
||||
return createMockChain([])
|
||||
})
|
||||
|
||||
const result = await hasWorkspaceAdminAccess('user123', 'workspace456')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when user has read permission but not admin', async () => {
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([{ ownerId: 'other-user' }])
|
||||
}
|
||||
return createMockChain([])
|
||||
})
|
||||
|
||||
const result = await hasWorkspaceAdminAccess('user123', 'workspace456')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle empty workspace ID', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await hasWorkspaceAdminAccess('user123', '')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle empty user ID', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await hasWorkspaceAdminAccess('', 'workspace456')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isOrganizationOwnerOrAdmin', () => {
|
||||
it('should return true when user is owner of organization', async () => {
|
||||
const chain = createMockChain([{ role: 'owner' }])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await isOrganizationOwnerOrAdmin('user123', 'org456')
|
||||
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when user is admin of organization', async () => {
|
||||
const chain = createMockChain([{ role: 'admin' }])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await isOrganizationOwnerOrAdmin('user123', 'org456')
|
||||
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when user is regular member of organization', async () => {
|
||||
const chain = createMockChain([{ role: 'member' }])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await isOrganizationOwnerOrAdmin('user123', 'org456')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when user is not member of organization', async () => {
|
||||
describe('Edge Cases and Security Tests', () => {
|
||||
it('should handle SQL injection attempts in user IDs', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await isOrganizationOwnerOrAdmin('user123', 'org456')
|
||||
const result = await getUserEntityPermissions(
|
||||
"'; DROP TABLE users; --",
|
||||
'workspace',
|
||||
'workspace123'
|
||||
)
|
||||
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle very long entity IDs', async () => {
|
||||
const longEntityId = 'a'.repeat(1000)
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await getUserEntityPermissions('user123', 'workspace', longEntityId)
|
||||
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle unicode characters in entity names', async () => {
|
||||
const chain = createMockChain([{ permissionType: 'read' as PermissionType }])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await getUserEntityPermissions('user123', '📝workspace', '🏢org-id')
|
||||
|
||||
expect(result).toBe('read')
|
||||
})
|
||||
|
||||
it('should verify permission hierarchy ordering is consistent', () => {
|
||||
const permissionOrder: Record<PermissionType, number> = { admin: 3, write: 2, read: 1 }
|
||||
|
||||
expect(permissionOrder.admin).toBeGreaterThan(permissionOrder.write)
|
||||
expect(permissionOrder.write).toBeGreaterThan(permissionOrder.read)
|
||||
})
|
||||
|
||||
it('should handle workspace ownership checks with null owner IDs', async () => {
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([{ ownerId: null }])
|
||||
}
|
||||
return createMockChain([])
|
||||
})
|
||||
|
||||
const result = await hasWorkspaceAdminAccess('user123', 'workspace456')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
it('should handle null user ID correctly when owner ID is different', async () => {
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
throw new Error('Database error')
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([{ ownerId: 'other-user' }])
|
||||
}
|
||||
return createMockChain([])
|
||||
})
|
||||
|
||||
const result = await isOrganizationOwnerOrAdmin('user123', 'org456')
|
||||
const result = await hasWorkspaceAdminAccess(null as any, 'workspace456')
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
@@ -289,27 +496,121 @@ describe('Permission Utils', () => {
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('should return direct admin workspaces', async () => {
|
||||
const mockDirectWorkspaces = [
|
||||
{ id: 'ws1', name: 'Workspace 1', ownerId: 'owner1' },
|
||||
{ id: 'ws2', name: 'Workspace 2', ownerId: 'owner2' },
|
||||
it('should return owned workspaces', async () => {
|
||||
const mockWorkspaces = [
|
||||
{ id: 'ws1', name: 'My Workspace 1', ownerId: 'user123' },
|
||||
{ id: 'ws2', name: 'My Workspace 2', ownerId: 'user123' },
|
||||
]
|
||||
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain(mockDirectWorkspaces) // direct admin workspaces
|
||||
return createMockChain(mockWorkspaces) // Owned workspaces
|
||||
}
|
||||
return createMockChain([]) // no organization memberships
|
||||
return createMockChain([]) // No admin workspaces
|
||||
})
|
||||
|
||||
const result = await getManageableWorkspaces('user123')
|
||||
|
||||
expect(result).toEqual([
|
||||
{ id: 'ws1', name: 'Workspace 1', ownerId: 'owner1', accessType: 'direct' },
|
||||
{ id: 'ws2', name: 'Workspace 2', ownerId: 'owner2', accessType: 'direct' },
|
||||
{ id: 'ws1', name: 'My Workspace 1', ownerId: 'user123', accessType: 'owner' },
|
||||
{ id: 'ws2', name: 'My Workspace 2', ownerId: 'user123', accessType: 'owner' },
|
||||
])
|
||||
})
|
||||
|
||||
it('should return workspaces with direct admin permissions', async () => {
|
||||
const mockAdminWorkspaces = [{ id: 'ws1', name: 'Shared Workspace', ownerId: 'other-user' }]
|
||||
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([]) // No owned workspaces
|
||||
}
|
||||
return createMockChain(mockAdminWorkspaces) // Admin workspaces
|
||||
})
|
||||
|
||||
const result = await getManageableWorkspaces('user123')
|
||||
|
||||
expect(result).toEqual([
|
||||
{ id: 'ws1', name: 'Shared Workspace', ownerId: 'other-user', accessType: 'direct' },
|
||||
])
|
||||
})
|
||||
|
||||
it('should combine owned and admin workspaces without duplicates', async () => {
|
||||
const mockOwnedWorkspaces = [
|
||||
{ id: 'ws1', name: 'My Workspace', ownerId: 'user123' },
|
||||
{ id: 'ws2', name: 'Another Workspace', ownerId: 'user123' },
|
||||
]
|
||||
const mockAdminWorkspaces = [
|
||||
{ id: 'ws1', name: 'My Workspace', ownerId: 'user123' }, // Duplicate (should be filtered)
|
||||
{ id: 'ws3', name: 'Shared Workspace', ownerId: 'other-user' },
|
||||
]
|
||||
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain(mockOwnedWorkspaces) // Owned workspaces
|
||||
}
|
||||
return createMockChain(mockAdminWorkspaces) // Admin workspaces
|
||||
})
|
||||
|
||||
const result = await getManageableWorkspaces('user123')
|
||||
|
||||
expect(result).toHaveLength(3)
|
||||
expect(result).toEqual([
|
||||
{ id: 'ws1', name: 'My Workspace', ownerId: 'user123', accessType: 'owner' },
|
||||
{ id: 'ws2', name: 'Another Workspace', ownerId: 'user123', accessType: 'owner' },
|
||||
{ id: 'ws3', name: 'Shared Workspace', ownerId: 'other-user', accessType: 'direct' },
|
||||
])
|
||||
})
|
||||
|
||||
it('should handle empty workspace names', async () => {
|
||||
const mockWorkspaces = [{ id: 'ws1', name: '', ownerId: 'user123' }]
|
||||
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain(mockWorkspaces)
|
||||
}
|
||||
return createMockChain([])
|
||||
})
|
||||
|
||||
const result = await getManageableWorkspaces('user123')
|
||||
|
||||
expect(result[0].name).toBe('')
|
||||
})
|
||||
|
||||
it('should handle multiple admin permissions for same workspace', async () => {
|
||||
const mockAdminWorkspaces = [
|
||||
{ id: 'ws1', name: 'Shared Workspace', ownerId: 'other-user' },
|
||||
{ id: 'ws1', name: 'Shared Workspace', ownerId: 'other-user' }, // Duplicate
|
||||
]
|
||||
|
||||
let callCount = 0
|
||||
mockDb.select.mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return createMockChain([]) // No owned workspaces
|
||||
}
|
||||
return createMockChain(mockAdminWorkspaces) // Admin workspaces with duplicates
|
||||
})
|
||||
|
||||
const result = await getManageableWorkspaces('user123')
|
||||
|
||||
expect(result).toHaveLength(2) // Should include duplicates from admin permissions
|
||||
})
|
||||
|
||||
it('should handle empty user ID gracefully', async () => {
|
||||
const chain = createMockChain([])
|
||||
mockDb.select.mockReturnValue(chain)
|
||||
|
||||
const result = await getManageableWorkspaces('')
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { db } from '@/db'
|
||||
import { member, permissions, type permissionTypeEnum, user, workspace } from '@/db/schema'
|
||||
import { permissions, type permissionTypeEnum, user, workspace } from '@/db/schema'
|
||||
|
||||
export type PermissionType = (typeof permissionTypeEnum.enumValues)[number]
|
||||
|
||||
@@ -32,7 +32,6 @@ export async function getUserEntityPermissions(
|
||||
return null
|
||||
}
|
||||
|
||||
// If multiple permissions exist (legacy data), return the highest one
|
||||
const permissionOrder: Record<PermissionType, number> = { admin: 3, write: 2, read: 1 }
|
||||
const highestPermission = result.reduce((highest, current) => {
|
||||
return permissionOrder[current.permissionType] > permissionOrder[highest.permissionType]
|
||||
@@ -46,13 +45,13 @@ export async function getUserEntityPermissions(
|
||||
/**
|
||||
* Check if a user has admin permission for a specific workspace
|
||||
*
|
||||
* @param userId - The ID of the user to check permissions for
|
||||
* @param workspaceId - The ID of the workspace to check admin permission for
|
||||
* @param userId - The ID of the user to check
|
||||
* @param workspaceId - The ID of the workspace to check
|
||||
* @returns Promise<boolean> - True if the user has admin permission for the workspace, false otherwise
|
||||
*/
|
||||
export async function hasAdminPermission(userId: string, workspaceId: string): Promise<boolean> {
|
||||
const result = await db
|
||||
.select()
|
||||
.select({ id: permissions.id })
|
||||
.from(permissions)
|
||||
.where(
|
||||
and(
|
||||
@@ -73,13 +72,19 @@ export async function hasAdminPermission(userId: string, workspaceId: string): P
|
||||
* @param workspaceId - The ID of the workspace to retrieve user permissions for.
|
||||
* @returns A promise that resolves to an array of user objects, each containing user details and their permission type.
|
||||
*/
|
||||
export async function getUsersWithPermissions(workspaceId: string) {
|
||||
export async function getUsersWithPermissions(workspaceId: string): Promise<
|
||||
Array<{
|
||||
userId: string
|
||||
email: string
|
||||
name: string
|
||||
permissionType: PermissionType
|
||||
}>
|
||||
> {
|
||||
const usersWithPermissions = await db
|
||||
.select({
|
||||
userId: user.id,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
image: user.image,
|
||||
permissionType: permissions.permissionType,
|
||||
})
|
||||
.from(permissions)
|
||||
@@ -87,141 +92,71 @@ export async function getUsersWithPermissions(workspaceId: string) {
|
||||
.where(and(eq(permissions.entityType, 'workspace'), eq(permissions.entityId, workspaceId)))
|
||||
.orderBy(user.email)
|
||||
|
||||
// Since each user has only one permission, we can use the results directly
|
||||
return usersWithPermissions.map((row) => ({
|
||||
userId: row.userId,
|
||||
email: row.email,
|
||||
name: row.name,
|
||||
image: row.image,
|
||||
permissionType: row.permissionType,
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a user is an admin or owner of any organization that has access to a workspace
|
||||
* Check if a user has admin access to a specific workspace
|
||||
*
|
||||
* @param userId - The ID of the user to check
|
||||
* @param workspaceId - The ID of the workspace
|
||||
* @returns Promise<boolean> - True if the user is an organization admin with access to the workspace
|
||||
*/
|
||||
export async function isOrganizationAdminForWorkspace(
|
||||
userId: string,
|
||||
workspaceId: string
|
||||
): Promise<boolean> {
|
||||
try {
|
||||
// Get the workspace owner
|
||||
const workspaceRecord = await db
|
||||
.select({ ownerId: workspace.ownerId })
|
||||
.from(workspace)
|
||||
.where(eq(workspace.id, workspaceId))
|
||||
.limit(1)
|
||||
|
||||
if (workspaceRecord.length === 0) {
|
||||
return false
|
||||
}
|
||||
|
||||
const workspaceOwnerId = workspaceRecord[0].ownerId
|
||||
|
||||
// Check if the user is an admin/owner of any organization that the workspace owner belongs to
|
||||
const orgMemberships = await db
|
||||
.select({
|
||||
organizationId: member.organizationId,
|
||||
role: member.role,
|
||||
})
|
||||
.from(member)
|
||||
.where(
|
||||
and(
|
||||
eq(member.userId, userId),
|
||||
// Only admin and owner roles can manage workspace permissions
|
||||
eq(member.role, 'admin') // We'll also check for 'owner' separately
|
||||
)
|
||||
)
|
||||
|
||||
// Also check for owner role
|
||||
const ownerMemberships = await db
|
||||
.select({
|
||||
organizationId: member.organizationId,
|
||||
role: member.role,
|
||||
})
|
||||
.from(member)
|
||||
.where(and(eq(member.userId, userId), eq(member.role, 'owner')))
|
||||
|
||||
const allOrgMemberships = [...orgMemberships, ...ownerMemberships]
|
||||
|
||||
if (allOrgMemberships.length === 0) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the workspace owner is a member of any of these organizations
|
||||
for (const membership of allOrgMemberships) {
|
||||
const workspaceOwnerInOrg = await db
|
||||
.select()
|
||||
.from(member)
|
||||
.where(
|
||||
and(
|
||||
eq(member.userId, workspaceOwnerId),
|
||||
eq(member.organizationId, membership.organizationId)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (workspaceOwnerInOrg.length > 0) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
} catch (error) {
|
||||
console.error('Error checking organization admin status for workspace:', error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a user has admin permissions (either direct workspace admin or organization admin)
|
||||
*
|
||||
* @param userId - The ID of the user to check permissions for
|
||||
* @param workspaceId - The ID of the workspace to check admin permission for
|
||||
* @returns Promise<boolean> - True if the user has admin permission for the workspace, false otherwise
|
||||
* @param workspaceId - The ID of the workspace to check
|
||||
* @returns Promise<boolean> - True if the user has admin access to the workspace, false otherwise
|
||||
*/
|
||||
export async function hasWorkspaceAdminAccess(
|
||||
userId: string,
|
||||
workspaceId: string
|
||||
): Promise<boolean> {
|
||||
// Check direct workspace admin permission
|
||||
const directAdmin = await hasAdminPermission(userId, workspaceId)
|
||||
if (directAdmin) {
|
||||
const workspaceResult = await db
|
||||
.select({ ownerId: workspace.ownerId })
|
||||
.from(workspace)
|
||||
.where(eq(workspace.id, workspaceId))
|
||||
.limit(1)
|
||||
|
||||
if (workspaceResult.length === 0) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (workspaceResult[0].ownerId === userId) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check organization admin permission
|
||||
const orgAdmin = await isOrganizationAdminForWorkspace(userId, workspaceId)
|
||||
return orgAdmin
|
||||
return await hasAdminPermission(userId, workspaceId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all workspaces that a user can manage (either as direct admin or organization admin)
|
||||
* Get a list of workspaces that the user has access to
|
||||
*
|
||||
* @param userId - The ID of the user
|
||||
* @returns Promise<Array<{id: string, name: string, ownerId: string}>> - Array of workspaces the user can manage
|
||||
* @param userId - The ID of the user to check
|
||||
* @returns Promise<Array<{
|
||||
* id: string
|
||||
* name: string
|
||||
* ownerId: string
|
||||
* accessType: 'direct' | 'owner'
|
||||
* }>> - A list of workspaces that the user has access to
|
||||
*/
|
||||
export async function getManageableWorkspaces(userId: string): Promise<
|
||||
Array<{
|
||||
id: string
|
||||
name: string
|
||||
ownerId: string
|
||||
accessType: 'direct' | 'organization'
|
||||
accessType: 'direct' | 'owner'
|
||||
}>
|
||||
> {
|
||||
const manageableWorkspaces: Array<{
|
||||
id: string
|
||||
name: string
|
||||
ownerId: string
|
||||
accessType: 'direct' | 'organization'
|
||||
}> = []
|
||||
const ownedWorkspaces = await db
|
||||
.select({
|
||||
id: workspace.id,
|
||||
name: workspace.name,
|
||||
ownerId: workspace.ownerId,
|
||||
})
|
||||
.from(workspace)
|
||||
.where(eq(workspace.ownerId, userId))
|
||||
|
||||
// Get workspaces where user has direct admin permissions
|
||||
const directWorkspaces = await db
|
||||
const adminWorkspaces = await db
|
||||
.select({
|
||||
id: workspace.id,
|
||||
name: workspace.name,
|
||||
@@ -237,86 +172,13 @@ export async function getManageableWorkspaces(userId: string): Promise<
|
||||
)
|
||||
)
|
||||
|
||||
directWorkspaces.forEach((ws) => {
|
||||
manageableWorkspaces.push({
|
||||
...ws,
|
||||
accessType: 'direct',
|
||||
})
|
||||
})
|
||||
const ownedSet = new Set(ownedWorkspaces.map((w) => w.id))
|
||||
const combined = [
|
||||
...ownedWorkspaces.map((ws) => ({ ...ws, accessType: 'owner' as const })),
|
||||
...adminWorkspaces
|
||||
.filter((ws) => !ownedSet.has(ws.id))
|
||||
.map((ws) => ({ ...ws, accessType: 'direct' as const })),
|
||||
]
|
||||
|
||||
// Get workspaces where user has organization admin access
|
||||
// First, get organizations where the user is admin/owner
|
||||
const adminOrgs = await db
|
||||
.select({ organizationId: member.organizationId })
|
||||
.from(member)
|
||||
.where(
|
||||
and(
|
||||
eq(member.userId, userId)
|
||||
// Check for both admin and owner roles
|
||||
)
|
||||
)
|
||||
|
||||
// Get all organization workspaces for these orgs
|
||||
for (const org of adminOrgs) {
|
||||
// Get all members of this organization
|
||||
const orgMembers = await db
|
||||
.select({ userId: member.userId })
|
||||
.from(member)
|
||||
.where(eq(member.organizationId, org.organizationId))
|
||||
|
||||
// Get workspaces owned by org members
|
||||
const orgWorkspaces = await db
|
||||
.select({
|
||||
id: workspace.id,
|
||||
name: workspace.name,
|
||||
ownerId: workspace.ownerId,
|
||||
})
|
||||
.from(workspace)
|
||||
.where(
|
||||
// Find workspaces owned by any org member
|
||||
eq(workspace.ownerId, orgMembers.length > 0 ? orgMembers[0].userId : 'none')
|
||||
)
|
||||
|
||||
// Add these workspaces if not already included
|
||||
orgWorkspaces.forEach((ws) => {
|
||||
if (!manageableWorkspaces.find((existing) => existing.id === ws.id)) {
|
||||
manageableWorkspaces.push({
|
||||
...ws,
|
||||
accessType: 'organization',
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return manageableWorkspaces
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a user is an owner or admin of a specific organization
|
||||
*
|
||||
* @param userId - The ID of the user to check
|
||||
* @param organizationId - The ID of the organization
|
||||
* @returns Promise<boolean> - True if the user is an owner or admin of the organization
|
||||
*/
|
||||
export async function isOrganizationOwnerOrAdmin(
|
||||
userId: string,
|
||||
organizationId: string
|
||||
): Promise<boolean> {
|
||||
try {
|
||||
const memberRecord = await db
|
||||
.select({ role: member.role })
|
||||
.from(member)
|
||||
.where(and(eq(member.userId, userId), eq(member.organizationId, organizationId)))
|
||||
.limit(1)
|
||||
|
||||
if (memberRecord.length === 0) {
|
||||
return false // User is not a member of the organization
|
||||
}
|
||||
|
||||
const userRole = memberRecord[0].role
|
||||
return ['owner', 'admin'].includes(userRole)
|
||||
} catch (error) {
|
||||
console.error('Error checking organization ownership/admin status:', error)
|
||||
return false
|
||||
}
|
||||
return combined
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
|
||||
|
||||
const logger = createLogger('SimAgentClient')
|
||||
|
||||
// Base URL for the sim-agent service
|
||||
const SIM_AGENT_BASE_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
|
||||
|
||||
export interface SimAgentRequest {
|
||||
@@ -45,7 +44,6 @@ class SimAgentClient {
|
||||
try {
|
||||
const url = `${this.baseUrl}${endpoint}`
|
||||
|
||||
// Use provided API key or try to get it from environment
|
||||
const requestHeaders: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
...headers,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('SimAgentUtils')
|
||||
|
||||
const azureApiKey = env.AZURE_OPENAI_API_KEY
|
||||
const azureEndpoint = env.AZURE_OPENAI_ENDPOINT
|
||||
@@ -52,7 +55,7 @@ export async function generateChatTitle(message: string): Promise<string | null>
|
||||
const title = response.choices[0]?.message?.content?.trim() || null
|
||||
return title
|
||||
} catch (error) {
|
||||
console.error('Error generating chat title:', error)
|
||||
logger.error('Error generating chat title:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
@@ -154,9 +154,8 @@ export function useSubscriptionUpgrade() {
|
||||
} catch (error) {
|
||||
logger.error('Failed to initiate subscription upgrade:', error)
|
||||
|
||||
// Log detailed error information for debugging
|
||||
if (error instanceof Error) {
|
||||
console.error('Detailed error:', {
|
||||
logger.error('Detailed error:', {
|
||||
message: error.message,
|
||||
stack: error.stack,
|
||||
cause: error.cause,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user