Mirror of https://github.com/simstudioai/sim.git (synced 2026-04-28 03:00:29 -04:00)
feat(azure-openai): allow usage of azure-openai for knowledgebase uploads and wand generation (#1056)

* feat(azure-openai): allow usage of azure-openai for knowledgebase uploads
* feat(azure-openai): added azure-openai for kb and wand
* added embeddings utils, added the ability to use mistral through Azure
* fix(oauth): gdrive picker race condition, token route cleanup
* fix test
* feat(mailer): consolidated all emailing to mailer service, added support for Azure ACS (#1054)
* fix batch invitation email template
* cleanup
* improvement(emails): add help template instead of doing it inline
* remove fallback version

---------

Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
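For reference, a minimal sketch of the environment configuration this change reads. The variable names come from the env schema hunk at the end of this diff; the values below are placeholders, not defaults. When the shared Azure credentials are absent, the code falls back to OPENAI_API_KEY.

AZURE_OPENAI_API_KEY=<shared-azure-openai-key>
AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com
AZURE_OPENAI_API_VERSION=2024-12-01-preview
KB_OPENAI_MODEL_NAME=<embedding-deployment-name>   # knowledge-base embeddings
WAND_OPENAI_MODEL_NAME=<chat-deployment-name>      # wand generation
OCR_AZURE_ENDPOINT=<azure-mistral-ocr-endpoint>    # Azure Mistral OCR for PDF parsing
OCR_AZURE_MODEL_NAME=<ocr-model-name>
OPENAI_API_KEY=<openai-api-key>                    # fallback when Azure is not configured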
@@ -4,15 +4,50 @@
*
* @vitest-environment node
*/
import { describe, expect, it, vi } from 'vitest'
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('drizzle-orm')
vi.mock('@/lib/logs/console/logger')
vi.mock('@/lib/logs/console/logger', () => ({
createLogger: vi.fn(() => ({
info: vi.fn(),
debug: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
})),
}))
vi.mock('@/db')
vi.mock('@/lib/documents/utils', () => ({
retryWithExponentialBackoff: (fn: any) => fn(),
}))

import { handleTagAndVectorSearch, handleTagOnlySearch, handleVectorOnlySearch } from './utils'
vi.stubGlobal(
'fetch',
vi.fn().mockResolvedValue({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
})
)

vi.mock('@/lib/env', () => ({
env: {},
isTruthy: (value: string | boolean | number | undefined) =>
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
}))

import {
generateSearchEmbedding,
handleTagAndVectorSearch,
handleTagOnlySearch,
handleVectorOnlySearch,
} from './utils'

describe('Knowledge Search Utils', () => {
beforeEach(() => {
vi.clearAllMocks()
})

describe('handleTagOnlySearch', () => {
it('should throw error when no filters provided', async () => {
const params = {
@@ -140,4 +175,251 @@ describe('Knowledge Search Utils', () => {
expect(params.distanceThreshold).toBe(0.8)
})
})

describe('generateSearchEmbedding', () => {
it('should use Azure OpenAI when KB-specific config is provided', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
AZURE_OPENAI_API_KEY: 'test-azure-key',
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)

const result = await generateSearchEmbedding('test query')

expect(fetchSpy).toHaveBeenCalledWith(
'https://test.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-12-01-preview',
expect.objectContaining({
headers: expect.objectContaining({
'api-key': 'test-azure-key',
}),
})
)
expect(result).toEqual([0.1, 0.2, 0.3])

// Clean up
Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should fallback to OpenAI when no KB Azure config provided', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)

const result = await generateSearchEmbedding('test query')

expect(fetchSpy).toHaveBeenCalledWith(
'https://api.openai.com/v1/embeddings',
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer test-openai-key',
}),
})
)
expect(result).toEqual([0.1, 0.2, 0.3])

// Clean up
Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should use default API version when not provided in Azure config', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
AZURE_OPENAI_API_KEY: 'test-azure-key',
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
KB_OPENAI_MODEL_NAME: 'custom-embedding-model',
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)

await generateSearchEmbedding('test query')

expect(fetchSpy).toHaveBeenCalledWith(
expect.stringContaining('api-version='),
expect.any(Object)
)

// Clean up
Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should use custom model name when provided in Azure config', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
AZURE_OPENAI_API_KEY: 'test-azure-key',
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
KB_OPENAI_MODEL_NAME: 'custom-embedding-model',
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)

await generateSearchEmbedding('test query', 'text-embedding-3-small')

expect(fetchSpy).toHaveBeenCalledWith(
'https://test.openai.azure.com/openai/deployments/custom-embedding-model/embeddings?api-version=2024-12-01-preview',
expect.any(Object)
)

// Clean up
Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should throw error when no API configuration provided', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])

await expect(generateSearchEmbedding('test query')).rejects.toThrow(
'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured'
)
})

it('should handle Azure OpenAI API errors properly', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
AZURE_OPENAI_API_KEY: 'test-azure-key',
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: false,
status: 404,
statusText: 'Not Found',
text: async () => 'Deployment not found',
} as any)

await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')

// Clean up
Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should handle OpenAI API errors properly', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: false,
status: 429,
statusText: 'Too Many Requests',
text: async () => 'Rate limit exceeded',
} as any)

await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')

// Clean up
Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should include correct request body for Azure OpenAI', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
AZURE_OPENAI_API_KEY: 'test-azure-key',
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)

await generateSearchEmbedding('test query')

expect(fetchSpy).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
body: JSON.stringify({
input: ['test query'],
encoding_format: 'float',
}),
})
)

// Clean up
Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should include correct request body for OpenAI', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)

await generateSearchEmbedding('test query', 'text-embedding-3-small')

expect(fetchSpy).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
body: JSON.stringify({
input: ['test query'],
model: 'text-embedding-3-small',
encoding_format: 'float',
}),
})
)

// Clean up
Object.keys(env).forEach((key) => delete (env as any)[key])
})
})
})

@@ -1,22 +1,10 @@
import { and, eq, inArray, sql } from 'drizzle-orm'
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { embedding } from '@/db/schema'

const logger = createLogger('KnowledgeSearchUtils')

export class APIError extends Error {
public status: number

constructor(message: string, status: number) {
super(message)
this.name = 'APIError'
this.status = status
}
}

export interface SearchResult {
id: string
content: string
@@ -41,61 +29,8 @@ export interface SearchParams {
distanceThreshold?: number
}

export async function generateSearchEmbedding(query: string): Promise<number[]> {
const openaiApiKey = env.OPENAI_API_KEY
if (!openaiApiKey) {
throw new Error('OPENAI_API_KEY not configured')
}

try {
const embedding = await retryWithExponentialBackoff(
async () => {
const response = await fetch('https://api.openai.com/v1/embeddings', {
method: 'POST',
headers: {
Authorization: `Bearer ${openaiApiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
input: query,
model: 'text-embedding-3-small',
encoding_format: 'float',
}),
})

if (!response.ok) {
const errorText = await response.text()
const error = new APIError(
`OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`,
response.status
)
throw error
}

const data = await response.json()

if (!data.data || !Array.isArray(data.data) || data.data.length === 0) {
throw new Error('Invalid response format from OpenAI embeddings API')
}

return data.data[0].embedding
},
{
maxRetries: 5,
initialDelayMs: 1000,
maxDelayMs: 30000,
backoffMultiplier: 2,
}
)

return embedding
} catch (error) {
logger.error('Failed to generate search embedding:', error)
throw new Error(
`Embedding generation failed: ${error instanceof Error ? error.message : 'Unknown error'}`
)
}
}
// Use shared embedding utility
export { generateSearchEmbedding } from '@/lib/embeddings/utils'

function getTagFilters(filters: Record<string, string>, embedding: any) {
return Object.entries(filters).map(([key, value]) => {

@@ -252,5 +252,76 @@ describe('Knowledge Utils', () => {

expect(result.length).toBe(2)
})

it('should use Azure OpenAI when Azure config is provided', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
AZURE_OPENAI_API_KEY: 'test-azure-key',
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2], index: 0 }],
}),
} as any)

await generateEmbeddings(['test text'])

expect(fetchSpy).toHaveBeenCalledWith(
'https://test.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-12-01-preview',
expect.objectContaining({
headers: expect.objectContaining({
'api-key': 'test-azure-key',
}),
})
)

Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should fallback to OpenAI when no Azure config provided', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])
Object.assign(env, {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2], index: 0 }],
}),
} as any)

await generateEmbeddings(['test text'])

expect(fetchSpy).toHaveBeenCalledWith(
'https://api.openai.com/v1/embeddings',
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer test-openai-key',
}),
})
)

Object.keys(env).forEach((key) => delete (env as any)[key])
})

it('should throw error when no API configuration provided', async () => {
const { env } = await import('@/lib/env')
Object.keys(env).forEach((key) => delete (env as any)[key])

await expect(generateEmbeddings(['test text'])).rejects.toThrow(
'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured'
)
})
})
})

@@ -1,8 +1,7 @@
import crypto from 'crypto'
import { and, eq, isNull } from 'drizzle-orm'
import { processDocument } from '@/lib/documents/document-processor'
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
import { env } from '@/lib/env'
import { generateEmbeddings } from '@/lib/embeddings/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { getUserEntityPermissions } from '@/lib/permissions/utils'
import { db } from '@/db'
@@ -10,22 +9,11 @@ import { document, embedding, knowledgeBase } from '@/db/schema'

const logger = createLogger('KnowledgeUtils')

// Timeout constants (in milliseconds)
const TIMEOUTS = {
OVERALL_PROCESSING: 150000, // 150 seconds (2.5 minutes)
EMBEDDINGS_API: 60000, // 60 seconds per batch
} as const

class APIError extends Error {
public status: number

constructor(message: string, status: number) {
super(message)
this.name = 'APIError'
this.status = status
}
}

/**
* Create a timeout wrapper for async operations
*/
@@ -110,18 +98,6 @@ export interface EmbeddingData {
updatedAt: Date
}

interface OpenAIEmbeddingResponse {
data: Array<{
embedding: number[]
index: number
}>
model: string
usage: {
prompt_tokens: number
total_tokens: number
}
}

export interface KnowledgeBaseAccessResult {
hasAccess: true
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId'>
@@ -405,87 +381,8 @@ export async function checkChunkAccess(
}
}

/**
* Generate embeddings using OpenAI API with retry logic for rate limiting
*/
export async function generateEmbeddings(
texts: string[],
embeddingModel = 'text-embedding-3-small'
): Promise<number[][]> {
const openaiApiKey = env.OPENAI_API_KEY
if (!openaiApiKey) {
throw new Error('OPENAI_API_KEY not configured')
}

try {
const batchSize = 100
const allEmbeddings: number[][] = []

for (let i = 0; i < texts.length; i += batchSize) {
const batch = texts.slice(i, i + batchSize)

logger.info(
`Generating embeddings for batch ${Math.floor(i / batchSize) + 1} (${batch.length} texts)`
)

const batchEmbeddings = await retryWithExponentialBackoff(
async () => {
const controller = new AbortController()
const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.EMBEDDINGS_API)

try {
const response = await fetch('https://api.openai.com/v1/embeddings', {
method: 'POST',
headers: {
Authorization: `Bearer ${openaiApiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
input: batch,
model: embeddingModel,
encoding_format: 'float',
}),
signal: controller.signal,
})

clearTimeout(timeoutId)

if (!response.ok) {
const errorText = await response.text()
const error = new APIError(
`OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`,
response.status
)
throw error
}

const data: OpenAIEmbeddingResponse = await response.json()
return data.data.map((item) => item.embedding)
} catch (error) {
clearTimeout(timeoutId)
if (error instanceof Error && error.name === 'AbortError') {
throw new Error('OpenAI API request timed out')
}
throw error
}
},
{
maxRetries: 5,
initialDelayMs: 1000,
maxDelayMs: 60000, // Max 1 minute delay for embeddings
backoffMultiplier: 2,
}
)

allEmbeddings.push(...batchEmbeddings)
}

return allEmbeddings
} catch (error) {
logger.error('Failed to generate embeddings:', error)
throw error
}
}
// Export for external use
export { generateEmbeddings }

/**
* Process a document asynchronously with full error handling

@@ -1,6 +1,6 @@
import { unstable_noStore as noStore } from 'next/cache'
import { type NextRequest, NextResponse } from 'next/server'
import OpenAI from 'openai'
import OpenAI, { AzureOpenAI } from 'openai'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'

@@ -10,14 +10,32 @@ export const maxDuration = 60

const logger = createLogger('WandGenerateAPI')

const openai = env.OPENAI_API_KEY
? new OpenAI({
apiKey: env.OPENAI_API_KEY,
})
: null
const azureApiKey = env.AZURE_OPENAI_API_KEY
const azureEndpoint = env.AZURE_OPENAI_ENDPOINT
const azureApiVersion = env.AZURE_OPENAI_API_VERSION
const wandModelName = env.WAND_OPENAI_MODEL_NAME || 'gpt-4o'
const openaiApiKey = env.OPENAI_API_KEY

if (!env.OPENAI_API_KEY) {
logger.warn('OPENAI_API_KEY not found. Wand generation API will not function.')
const useWandAzure = azureApiKey && azureEndpoint && azureApiVersion

const client = useWandAzure
? new AzureOpenAI({
apiKey: azureApiKey,
apiVersion: azureApiVersion,
endpoint: azureEndpoint,
})
: openaiApiKey
? new OpenAI({
apiKey: openaiApiKey,
})
: null

if (!useWandAzure && !openaiApiKey) {
logger.warn(
'Neither Azure OpenAI nor OpenAI API key found. Wand generation API will not function.'
)
} else {
logger.info(`Using ${useWandAzure ? 'Azure OpenAI' : 'OpenAI'} for wand generation`)
}

interface ChatMessage {
@@ -32,14 +50,12 @@ interface RequestBody {
history?: ChatMessage[]
}

// The endpoint is now generic - system prompts come from wand configs

export async function POST(req: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
logger.info(`[${requestId}] Received wand generation request`)

if (!openai) {
logger.error(`[${requestId}] OpenAI client not initialized. Missing API key.`)
if (!client) {
logger.error(`[${requestId}] AI client not initialized. Missing API key.`)
return NextResponse.json(
{ success: false, error: 'Wand generation service is not configured.' },
{ status: 503 }
@@ -74,16 +90,19 @@ export async function POST(req: NextRequest) {
// Add the current user prompt
messages.push({ role: 'user', content: prompt })

logger.debug(`[${requestId}] Calling OpenAI API for wand generation`, {
stream,
historyLength: history.length,
})
logger.debug(
`[${requestId}] Calling ${useWandAzure ? 'Azure OpenAI' : 'OpenAI'} API for wand generation`,
{
stream,
historyLength: history.length,
}
)

// For streaming responses
if (stream) {
try {
const streamCompletion = await openai?.chat.completions.create({
model: 'gpt-4o',
const streamCompletion = await client.chat.completions.create({
model: useWandAzure ? wandModelName : 'gpt-4o',
messages: messages,
temperature: 0.3,
max_tokens: 10000,
@@ -141,8 +160,8 @@ export async function POST(req: NextRequest) {
}

// For non-streaming responses
const completion = await openai?.chat.completions.create({
model: 'gpt-4o',
const completion = await client.chat.completions.create({
model: useWandAzure ? wandModelName : 'gpt-4o',
messages: messages,
temperature: 0.3,
max_tokens: 10000,
@@ -151,9 +170,11 @@ export async function POST(req: NextRequest) {
const generatedContent = completion.choices[0]?.message?.content?.trim()

if (!generatedContent) {
logger.error(`[${requestId}] OpenAI response was empty or invalid.`)
logger.error(
`[${requestId}] ${useWandAzure ? 'Azure OpenAI' : 'OpenAI'} response was empty or invalid.`
)
return NextResponse.json(
{ success: false, error: 'Failed to generate content. OpenAI response was empty.' },
{ success: false, error: 'Failed to generate content. AI response was empty.' },
{ status: 500 }
)
}
@@ -171,7 +192,9 @@ export async function POST(req: NextRequest) {

if (error instanceof OpenAI.APIError) {
status = error.status || 500
logger.error(`[${requestId}] OpenAI API Error: ${status} - ${error.message}`)
logger.error(
`[${requestId}] ${useWandAzure ? 'Azure OpenAI' : 'OpenAI'} API Error: ${status} - ${error.message}`
)

if (status === 401) {
clientErrorMessage = 'Authentication failed. Please check your API key configuration.'
@@ -181,6 +204,10 @@ export async function POST(req: NextRequest) {
clientErrorMessage =
'The wand generation service is currently unavailable. Please try again later.'
}
} else if (useWandAzure && error.message?.includes('DeploymentNotFound')) {
clientErrorMessage =
'Azure OpenAI deployment not found. Please check your model deployment configuration.'
status = 404
}

return NextResponse.json(

@@ -9,10 +9,9 @@ import { mistralParserTool } from '@/tools/mistral/parser'

const logger = createLogger('DocumentProcessor')

// Timeout constants (in milliseconds)
const TIMEOUTS = {
FILE_DOWNLOAD: 60000, // 60 seconds
MISTRAL_OCR_API: 90000, // 90 seconds
FILE_DOWNLOAD: 60000,
MISTRAL_OCR_API: 90000,
} as const

type S3Config = {
@@ -27,20 +26,19 @@ type BlobConfig = {
connectionString?: string
}

function getKBConfig(): S3Config | BlobConfig {
const getKBConfig = (): S3Config | BlobConfig => {
const provider = getStorageProvider()
if (provider === 'blob') {
return {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
}
}
return {
bucket: S3_KB_CONFIG.bucket,
region: S3_KB_CONFIG.region,
}
return provider === 'blob'
? {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
}
: {
bucket: S3_KB_CONFIG.bucket,
region: S3_KB_CONFIG.region,
}
}

class APIError extends Error {
@@ -53,9 +51,6 @@ class APIError extends Error {
}
}

/**
* Process a document by parsing it and chunking the content
*/
export async function processDocument(
fileUrl: string,
filename: string,
@@ -79,29 +74,23 @@ export async function processDocument(
logger.info(`Processing document: ${filename}`)

try {
// Parse the document
const { content, processingMethod, cloudUrl } = await parseDocument(fileUrl, filename, mimeType)

// Create chunker and process content
const chunker = new TextChunker({
chunkSize,
overlap: chunkOverlap,
minChunkSize,
})
const parseResult = await parseDocument(fileUrl, filename, mimeType)
const { content, processingMethod } = parseResult
const cloudUrl = 'cloudUrl' in parseResult ? parseResult.cloudUrl : undefined

const chunker = new TextChunker({ chunkSize, overlap: chunkOverlap, minChunkSize })
const chunks = await chunker.chunk(content)

// Calculate metadata
const characterCount = content.length
const tokenCount = chunks.reduce((sum: number, chunk: Chunk) => sum + chunk.tokenCount, 0)
const tokenCount = chunks.reduce((sum, chunk) => sum + chunk.tokenCount, 0)

logger.info(`Document processed successfully: ${chunks.length} chunks, ${tokenCount} tokens`)
logger.info(`Document processed: ${chunks.length} chunks, ${tokenCount} tokens`)

return {
chunks,
metadata: {
filename,
fileSize: content.length, // Using content length as file size approximation
fileSize: characterCount,
mimeType,
chunkCount: chunks.length,
tokenCount,
@@ -116,9 +105,6 @@ export async function processDocument(
}
}

/**
* Parse a document from a URL or file path
*/
async function parseDocument(
fileUrl: string,
filename: string,
@@ -128,283 +114,286 @@ async function parseDocument(
processingMethod: 'file-parser' | 'mistral-ocr'
cloudUrl?: string
}> {
// Check if we should use Mistral OCR for PDFs
const shouldUseMistralOCR = mimeType === 'application/pdf' && env.MISTRAL_API_KEY
const isPDF = mimeType === 'application/pdf'
const hasAzureMistralOCR =
env.AZURE_OPENAI_API_KEY && env.OCR_AZURE_ENDPOINT && env.OCR_AZURE_MODEL_NAME
const hasMistralOCR = env.MISTRAL_API_KEY

if (shouldUseMistralOCR) {
logger.info(`Using Mistral OCR for PDF: ${filename}`)
return await parseWithMistralOCR(fileUrl, filename, mimeType)
// Check Azure Mistral OCR configuration

if (isPDF && hasAzureMistralOCR) {
logger.info(`Using Azure Mistral OCR: ${filename}`)
return parseWithAzureMistralOCR(fileUrl, filename, mimeType)
}

// Use standard file parser
logger.info(`Using file parser for: ${filename}`)
return await parseWithFileParser(fileUrl, filename, mimeType)
if (isPDF && hasMistralOCR) {
logger.info(`Using Mistral OCR: ${filename}`)
return parseWithMistralOCR(fileUrl, filename, mimeType)
}

logger.info(`Using file parser: ${filename}`)
return parseWithFileParser(fileUrl, filename, mimeType)
}

/**
* Parse document using Mistral OCR
*/
async function parseWithMistralOCR(
fileUrl: string,
filename: string,
mimeType: string
): Promise<{
content: string
processingMethod: 'file-parser' | 'mistral-ocr'
cloudUrl?: string
}> {
const mistralApiKey = env.MISTRAL_API_KEY
if (!mistralApiKey) {
throw new Error('Mistral API key is required for OCR processing')
async function handleFileForOCR(fileUrl: string, filename: string, mimeType: string) {
if (fileUrl.startsWith('https://')) {
return { httpsUrl: fileUrl }
}

let httpsUrl = fileUrl
let cloudUrl: string | undefined
logger.info(`Uploading "${filename}" to cloud storage for OCR`)

// If the URL is not HTTPS, we need to upload to cloud storage first
if (!fileUrl.startsWith('https://')) {
logger.info(`Uploading "${filename}" to cloud storage for Mistral OCR access`)
const buffer = await downloadFileWithTimeout(fileUrl)
const kbConfig = getKBConfig()

// Download the file content with timeout
const controller = new AbortController()
const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.FILE_DOWNLOAD)
validateCloudConfig(kbConfig)

try {
const response = await fetch(fileUrl, { signal: controller.signal })
clearTimeout(timeoutId)
try {
const cloudResult = await uploadFile(buffer, filename, mimeType, kbConfig as any)
const httpsUrl = await getPresignedUrlWithConfig(cloudResult.key, kbConfig as any, 900)
logger.info(`Successfully uploaded for OCR: ${cloudResult.key}`)
return { httpsUrl, cloudUrl: httpsUrl }
} catch (uploadError) {
const message = uploadError instanceof Error ? uploadError.message : 'Unknown error'
throw new Error(`Cloud upload failed: ${message}. Cloud upload is required for OCR.`)
}
}

if (!response.ok) {
throw new Error(`Failed to download file for cloud upload: ${response.statusText}`)
}
async function downloadFileWithTimeout(fileUrl: string): Promise<Buffer> {
const controller = new AbortController()
const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.FILE_DOWNLOAD)

const buffer = Buffer.from(await response.arrayBuffer())
try {
const response = await fetch(fileUrl, { signal: controller.signal })
clearTimeout(timeoutId)

// Always upload to cloud storage for Mistral OCR, even in development
const kbConfig = getKBConfig()
const provider = getStorageProvider()

if (provider === 'blob') {
const blobConfig = kbConfig as BlobConfig
if (
!blobConfig.containerName ||
(!blobConfig.connectionString && (!blobConfig.accountName || !blobConfig.accountKey))
) {
throw new Error(
'Azure Blob configuration missing for PDF processing with Mistral OCR. Set AZURE_CONNECTION_STRING or both AZURE_ACCOUNT_NAME + AZURE_ACCOUNT_KEY, and AZURE_KB_CONTAINER_NAME.'
)
}
} else {
const s3Config = kbConfig as S3Config
if (!s3Config.bucket || !s3Config.region) {
throw new Error(
'S3 configuration missing for PDF processing with Mistral OCR. Set AWS_REGION and S3_KB_BUCKET_NAME environment variables.'
)
}
}

try {
// Upload to cloud storage
const cloudResult = await uploadFile(buffer, filename, mimeType, kbConfig as any)
// Generate presigned URL with 15 minutes expiration
httpsUrl = await getPresignedUrlWithConfig(cloudResult.key, kbConfig as any, 900)
cloudUrl = httpsUrl
logger.info(`Successfully uploaded to cloud storage for Mistral OCR: ${cloudResult.key}`)
} catch (uploadError) {
logger.error('Failed to upload to cloud storage for Mistral OCR:', uploadError)
throw new Error(
`Cloud upload failed: ${uploadError instanceof Error ? uploadError.message : 'Unknown error'}. Cloud upload is required for PDF processing with Mistral OCR.`
)
}
} catch (error) {
clearTimeout(timeoutId)
if (error instanceof Error && error.name === 'AbortError') {
throw new Error('File download timed out for Mistral OCR processing')
}
throw error
if (!response.ok) {
throw new Error(`Failed to download file: ${response.statusText}`)
}

return Buffer.from(await response.arrayBuffer())
} catch (error) {
clearTimeout(timeoutId)
if (error instanceof Error && error.name === 'AbortError') {
throw new Error('File download timed out')
}
throw error
}
}

async function downloadFileForBase64(fileUrl: string): Promise<Buffer> {
// Handle different URL types for Azure Mistral OCR base64 requirement
if (fileUrl.startsWith('data:')) {
// Extract base64 data from data URI
const [, base64Data] = fileUrl.split(',')
if (!base64Data) {
throw new Error('Invalid data URI format')
}
return Buffer.from(base64Data, 'base64')
}
if (fileUrl.startsWith('http')) {
// Download from HTTP(S) URL
return downloadFileWithTimeout(fileUrl)
}
// Local file - read it
const fs = await import('fs/promises')
return fs.readFile(fileUrl)
}

function validateCloudConfig(kbConfig: S3Config | BlobConfig) {
const provider = getStorageProvider()

if (provider === 'blob') {
const config = kbConfig as BlobConfig
if (
!config.containerName ||
(!config.connectionString && (!config.accountName || !config.accountKey))
) {
throw new Error(
'Azure Blob configuration missing. Set AZURE_CONNECTION_STRING or AZURE_ACCOUNT_NAME + AZURE_ACCOUNT_KEY + AZURE_KB_CONTAINER_NAME'
)
}
} else {
const config = kbConfig as S3Config
if (!config.bucket || !config.region) {
throw new Error('S3 configuration missing. Set AWS_REGION and S3_KB_BUCKET_NAME')
}
}
}

function processOCRContent(result: any, filename: string): string {
if (!result.success) {
throw new Error(`OCR processing failed: ${result.error || 'Unknown error'}`)
}

const content = result.output?.content || ''
if (!content.trim()) {
throw new Error('OCR returned empty content')
}

logger.info(`OCR completed: ${filename}`)
return content
}

function validateOCRConfig(
apiKey?: string,
endpoint?: string,
modelName?: string,
service = 'OCR'
) {
if (!apiKey) throw new Error(`${service} API key required`)
if (!endpoint) throw new Error(`${service} endpoint required`)
if (!modelName) throw new Error(`${service} model name required`)
}

function extractPageContent(pages: any[]): string {
if (!pages?.length) return ''

return pages
.map((page) => page?.markdown || '')
.filter(Boolean)
.join('\n\n')
}

async function makeOCRRequest(endpoint: string, headers: Record<string, string>, body: any) {
const controller = new AbortController()
const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.MISTRAL_OCR_API)

try {
const response = await fetch(endpoint, {
method: 'POST',
headers,
body: JSON.stringify(body),
signal: controller.signal,
})

clearTimeout(timeoutId)

if (!response.ok) {
const errorText = await response.text()
throw new APIError(
`OCR failed: ${response.status} ${response.statusText} - ${errorText}`,
response.status
)
}

return response
} catch (error) {
clearTimeout(timeoutId)
if (error instanceof Error && error.name === 'AbortError') {
throw new Error('OCR API request timed out')
}
throw error
}
}

async function parseWithAzureMistralOCR(fileUrl: string, filename: string, mimeType: string) {
validateOCRConfig(
env.AZURE_OPENAI_API_KEY,
env.OCR_AZURE_ENDPOINT,
env.OCR_AZURE_MODEL_NAME,
'Azure Mistral OCR'
)

// Azure Mistral OCR accepts data URIs with base64 content
const fileBuffer = await downloadFileForBase64(fileUrl)
const base64Data = fileBuffer.toString('base64')
const dataUri = `data:${mimeType};base64,${base64Data}`

try {
const response = await retryWithExponentialBackoff(
() =>
makeOCRRequest(
env.OCR_AZURE_ENDPOINT!,
{
'Content-Type': 'application/json',
Authorization: `Bearer ${env.AZURE_OPENAI_API_KEY}`,
},
{
model: env.OCR_AZURE_MODEL_NAME,
document: {
type: 'document_url',
document_url: dataUri,
},
include_image_base64: false,
}
),
{ maxRetries: 3, initialDelayMs: 1000, maxDelayMs: 10000 }
)

const ocrResult = await response.json()
const content = extractPageContent(ocrResult.pages) || JSON.stringify(ocrResult, null, 2)

if (!content.trim()) {
throw new Error('Azure Mistral OCR returned empty content')
}

logger.info(`Azure Mistral OCR completed: ${filename}`)
return { content, processingMethod: 'mistral-ocr' as const, cloudUrl: undefined }
} catch (error) {
logger.error(`Azure Mistral OCR failed for ${filename}:`, {
message: error instanceof Error ? error.message : String(error),
})

return env.MISTRAL_API_KEY
? parseWithMistralOCR(fileUrl, filename, mimeType)
: parseWithFileParser(fileUrl, filename, mimeType)
}
}

async function parseWithMistralOCR(fileUrl: string, filename: string, mimeType: string) {
if (!env.MISTRAL_API_KEY) {
throw new Error('Mistral API key required')
}

if (!mistralParserTool.request?.body) {
throw new Error('Mistral parser tool not properly configured')
throw new Error('Mistral parser tool not configured')
}

const requestBody = mistralParserTool.request.body({
filePath: httpsUrl,
apiKey: mistralApiKey,
resultType: 'text',
})
const { httpsUrl, cloudUrl } = await handleFileForOCR(fileUrl, filename, mimeType)
const params = { filePath: httpsUrl, apiKey: env.MISTRAL_API_KEY, resultType: 'text' as const }

try {
const response = await retryWithExponentialBackoff(
async () => {
const url =
typeof mistralParserTool.request!.url === 'function'
? mistralParserTool.request!.url({
filePath: httpsUrl,
apiKey: mistralApiKey,
resultType: 'text',
})
? mistralParserTool.request!.url(params)
: mistralParserTool.request!.url

const headers =
typeof mistralParserTool.request!.headers === 'function'
? mistralParserTool.request!.headers({
filePath: httpsUrl,
apiKey: mistralApiKey,
resultType: 'text',
})
? mistralParserTool.request!.headers(params)
: mistralParserTool.request!.headers

const controller = new AbortController()
const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.MISTRAL_OCR_API)

try {
const method =
typeof mistralParserTool.request!.method === 'function'
? mistralParserTool.request!.method(requestBody as any)
: mistralParserTool.request!.method

const res = await fetch(url, {
method,
headers,
body: JSON.stringify(requestBody),
signal: controller.signal,
})

clearTimeout(timeoutId)

if (!res.ok) {
const errorText = await res.text()
throw new APIError(
`Mistral OCR failed: ${res.status} ${res.statusText} - ${errorText}`,
res.status
)
}

return res
} catch (error) {
clearTimeout(timeoutId)
if (error instanceof Error && error.name === 'AbortError') {
throw new Error('Mistral OCR API request timed out')
}
throw error
}
const requestBody = mistralParserTool.request!.body!(params)
return makeOCRRequest(url, headers as Record<string, string>, requestBody)
},
{
maxRetries: 3,
initialDelayMs: 1000,
maxDelayMs: 10000,
}
{ maxRetries: 3, initialDelayMs: 1000, maxDelayMs: 10000 }
)

const result = await mistralParserTool.transformResponse!(response, {
filePath: httpsUrl,
apiKey: mistralApiKey,
resultType: 'text',
})
const result = await mistralParserTool.transformResponse!(response, params)
const content = processOCRContent(result, filename)

if (!result.success) {
throw new Error(`Mistral OCR processing failed: ${result.error || 'Unknown error'}`)
}

const content = result.output?.content || ''
if (!content.trim()) {
throw new Error('Mistral OCR returned empty content')
}

logger.info(`Mistral OCR completed successfully for ${filename}`)
return {
content,
processingMethod: 'mistral-ocr',
cloudUrl,
}
return { content, processingMethod: 'mistral-ocr' as const, cloudUrl }
} catch (error) {
logger.error(`Mistral OCR failed for ${filename}:`, {
message: error instanceof Error ? error.message : String(error),
stack: error instanceof Error ? error.stack : undefined,
name: error instanceof Error ? error.name : 'Unknown',
})

// Fall back to file parser
logger.info(`Falling back to file parser for ${filename}`)
return await parseWithFileParser(fileUrl, filename, mimeType)
logger.info(`Falling back to file parser: ${filename}`)
return parseWithFileParser(fileUrl, filename, mimeType)
}
}

/**
* Parse document using standard file parser
*/
async function parseWithFileParser(
fileUrl: string,
filename: string,
mimeType: string
): Promise<{
content: string
processingMethod: 'file-parser' | 'mistral-ocr'
cloudUrl?: string
}> {
async function parseWithFileParser(fileUrl: string, filename: string, mimeType: string) {
try {
let content: string

if (fileUrl.startsWith('data:')) {
logger.info(`Processing data URI for: ${filename}`)

try {
const [header, base64Data] = fileUrl.split(',')
if (!base64Data) {
throw new Error('Invalid data URI format')
}

if (header.includes('base64')) {
const buffer = Buffer.from(base64Data, 'base64')
content = buffer.toString('utf8')
} else {
content = decodeURIComponent(base64Data)
}

if (mimeType === 'text/plain') {
logger.info(`Data URI processed successfully for text content: ${filename}`)
} else {
const extension = filename.split('.').pop()?.toLowerCase() || 'txt'
const buffer = Buffer.from(base64Data, 'base64')
const result = await parseBuffer(buffer, extension)
content = result.content
}
} catch (error) {
throw new Error(
`Failed to process data URI: ${error instanceof Error ? error.message : 'Unknown error'}`
)
}
} else if (fileUrl.startsWith('http://') || fileUrl.startsWith('https://')) {
const controller = new AbortController()
const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.FILE_DOWNLOAD)

try {
const response = await fetch(fileUrl, { signal: controller.signal })
clearTimeout(timeoutId)

if (!response.ok) {
throw new Error(`Failed to download file: ${response.status} ${response.statusText}`)
}

const buffer = Buffer.from(await response.arrayBuffer())

const extension = filename.split('.').pop()?.toLowerCase() || ''
if (!extension) {
throw new Error(`Could not determine file extension from filename: ${filename}`)
}

const result = await parseBuffer(buffer, extension)
content = result.content
} catch (error) {
clearTimeout(timeoutId)
if (error instanceof Error && error.name === 'AbortError') {
throw new Error('File download timed out')
}
throw error
}
content = await parseDataURI(fileUrl, filename, mimeType)
} else if (fileUrl.startsWith('http')) {
content = await parseHttpFile(fileUrl, filename)
} else {
// Parse local file
const result = await parseFile(fileUrl)
content = result.content
}
@@ -413,12 +402,39 @@ async function parseWithFileParser(
throw new Error('File parser returned empty content')
}

return {
content,
processingMethod: 'file-parser',
}
return { content, processingMethod: 'file-parser' as const, cloudUrl: undefined }
} catch (error) {
logger.error(`File parser failed for ${filename}:`, error)
throw error
}
}

async function parseDataURI(fileUrl: string, filename: string, mimeType: string): Promise<string> {
const [header, base64Data] = fileUrl.split(',')
if (!base64Data) {
throw new Error('Invalid data URI format')
}

if (mimeType === 'text/plain') {
return header.includes('base64')
? Buffer.from(base64Data, 'base64').toString('utf8')
: decodeURIComponent(base64Data)
}

const extension = filename.split('.').pop()?.toLowerCase() || 'txt'
const buffer = Buffer.from(base64Data, 'base64')
const result = await parseBuffer(buffer, extension)
return result.content
}

async function parseHttpFile(fileUrl: string, filename: string): Promise<string> {
const buffer = await downloadFileWithTimeout(fileUrl)

const extension = filename.split('.').pop()?.toLowerCase()
if (!extension) {
throw new Error(`Could not determine file extension: ${filename}`)
}

const result = await parseBuffer(buffer, extension)
return result.content
}

apps/sim/lib/embeddings/utils.ts (new file, 148 lines)
@@ -0,0 +1,148 @@
import { isRetryableError, retryWithExponentialBackoff } from '@/lib/documents/utils'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'

const logger = createLogger('EmbeddingUtils')

export class EmbeddingAPIError extends Error {
public status: number

constructor(message: string, status: number) {
super(message)
this.name = 'EmbeddingAPIError'
this.status = status
}
}

interface EmbeddingConfig {
useAzure: boolean
apiUrl: string
headers: Record<string, string>
modelName: string
}

function getEmbeddingConfig(embeddingModel = 'text-embedding-3-small'): EmbeddingConfig {
const azureApiKey = env.AZURE_OPENAI_API_KEY
const azureEndpoint = env.AZURE_OPENAI_ENDPOINT
const azureApiVersion = env.AZURE_OPENAI_API_VERSION
const kbModelName = env.KB_OPENAI_MODEL_NAME || embeddingModel
const openaiApiKey = env.OPENAI_API_KEY

const useAzure = !!(azureApiKey && azureEndpoint)

if (!useAzure && !openaiApiKey) {
throw new Error(
'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured'
)
}

const apiUrl = useAzure
? `${azureEndpoint}/openai/deployments/${kbModelName}/embeddings?api-version=${azureApiVersion}`
: 'https://api.openai.com/v1/embeddings'

const headers: Record<string, string> = useAzure
? {
'api-key': azureApiKey!,
'Content-Type': 'application/json',
}
: {
Authorization: `Bearer ${openaiApiKey!}`,
'Content-Type': 'application/json',
}

return {
useAzure,
apiUrl,
headers,
modelName: useAzure ? kbModelName : embeddingModel,
}
}

async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Promise<number[][]> {
return retryWithExponentialBackoff(
async () => {
const requestBody = config.useAzure
? {
input: inputs,
encoding_format: 'float',
}
: {
input: inputs,
model: config.modelName,
encoding_format: 'float',
}

const response = await fetch(config.apiUrl, {
method: 'POST',
headers: config.headers,
body: JSON.stringify(requestBody),
})

if (!response.ok) {
const errorText = await response.text()
throw new EmbeddingAPIError(
`Embedding API failed: ${response.status} ${response.statusText} - ${errorText}`,
response.status
)
}

const data = await response.json()
return data.data.map((item: any) => item.embedding)
},
{
maxRetries: 3,
initialDelayMs: 1000,
maxDelayMs: 10000,
retryCondition: (error: any) => {
if (error instanceof EmbeddingAPIError) {
return error.status === 429 || error.status >= 500
}
return isRetryableError(error)
},
}
)
}

/**
* Generate embeddings for multiple texts with batching
*/
export async function generateEmbeddings(
texts: string[],
embeddingModel = 'text-embedding-3-small'
): Promise<number[][]> {
const config = getEmbeddingConfig(embeddingModel)

logger.info(`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for embeddings generation`)

const batchSize = 100
const allEmbeddings: number[][] = []

for (let i = 0; i < texts.length; i += batchSize) {
const batch = texts.slice(i, i + batchSize)
const batchEmbeddings = await callEmbeddingAPI(batch, config)
allEmbeddings.push(...batchEmbeddings)

logger.info(
`Generated embeddings for batch ${Math.floor(i / batchSize) + 1}/${Math.ceil(texts.length / batchSize)}`
)
}

return allEmbeddings
}

/**
* Generate embedding for a single search query
*/
export async function generateSearchEmbedding(
query: string,
embeddingModel = 'text-embedding-3-small'
): Promise<number[]> {
const config = getEmbeddingConfig(embeddingModel)

logger.info(
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for search embedding generation`
)

const embeddings = await callEmbeddingAPI([query], config)
return embeddings[0]
}
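A brief usage sketch for the new shared helpers above (the import path comes from this file; the input strings are illustrative). Azure OpenAI is picked automatically when AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT are set; otherwise the call goes to the OpenAI embeddings endpoint.

import { generateEmbeddings, generateSearchEmbedding } from '@/lib/embeddings/utils'

// Embed document chunks; inputs are batched internally, 100 per request
const chunkVectors = await generateEmbeddings(['first chunk text', 'second chunk text'])

// Embed a single search query (model defaults to text-embedding-3-small
// unless KB_OPENAI_MODEL_NAME / Azure config overrides it)
const queryVector = await generateSearchEmbedding('example search query')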
@@ -66,9 +66,14 @@ export const env = createEnv({
ELEVENLABS_API_KEY: z.string().min(1).optional(), // ElevenLabs API key for text-to-speech in deployed chat
SERPER_API_KEY: z.string().min(1).optional(), // Serper API key for online search

// Azure OpenAI Configuration
AZURE_OPENAI_ENDPOINT: z.string().url().optional(), // Azure OpenAI service endpoint
AZURE_OPENAI_API_VERSION: z.string().optional(), // Azure OpenAI API version
// Azure Configuration - Shared credentials with feature-specific models
AZURE_OPENAI_ENDPOINT: z.string().url().optional(), // Shared Azure OpenAI service endpoint
AZURE_OPENAI_API_VERSION: z.string().optional(), // Shared Azure OpenAI API version
AZURE_OPENAI_API_KEY: z.string().min(1).optional(), // Shared Azure OpenAI API key
KB_OPENAI_MODEL_NAME: z.string().optional(), // Knowledge base OpenAI model name (works with both regular OpenAI and Azure OpenAI)
WAND_OPENAI_MODEL_NAME: z.string().optional(), // Wand generation OpenAI model name (works with both regular OpenAI and Azure OpenAI)
OCR_AZURE_ENDPOINT: z.string().url().optional(), // Azure Mistral OCR service endpoint
OCR_AZURE_MODEL_NAME: z.string().optional(), // Azure Mistral OCR model name for document processing

// Monitoring & Analytics
TELEMETRY_ENDPOINT: z.string().url().optional(), // Custom telemetry/analytics endpoint
