feat(byok): byok for hosted model capabilities (#2574)

* feat(byok): byok for hosted model capabilities

* fix type

* add ignore lint

* accidentally added feature flags

* centralize byok fetch for LLM calls

* remove feature flags ts

* fix tests

* update docs
This commit is contained in:
Vikhyath Mondreti
2025-12-24 18:20:54 -08:00
committed by GitHub
parent 40a6bf5c8c
commit 47a259b428
35 changed files with 9656 additions and 181 deletions

View File

@@ -104,6 +104,10 @@ The model breakdown shows:
Pricing shown reflects rates as of September 10, 2025. Check provider documentation for current pricing.
</Callout>
## Bring Your Own Key (BYOK)
You can use your own API keys for hosted models (OpenAI, Anthropic, Google, Mistral) in **Settings → BYOK** to pay base prices. Keys are encrypted and apply workspace-wide.
## Cost Optimization Strategies
- **Model Selection**: Choose models based on task complexity. Simple tasks can use GPT-4.1-nano while complex reasoning might need o1 or Claude Opus.

View File

@@ -100,7 +100,12 @@ export async function PUT(
try {
const validatedData = UpdateChunkSchema.parse(body)
const updatedChunk = await updateChunk(chunkId, validatedData, requestId)
const updatedChunk = await updateChunk(
chunkId,
validatedData,
requestId,
accessCheck.knowledgeBase?.workspaceId
)
logger.info(
`[${requestId}] Chunk updated: ${chunkId} in document ${documentId} in knowledge base ${knowledgeBaseId}`

View File

@@ -184,7 +184,8 @@ export async function POST(
documentId,
docTags,
validatedData,
requestId
requestId,
accessCheck.knowledgeBase?.workspaceId
)
let cost = null

View File

@@ -183,11 +183,11 @@ export async function POST(request: NextRequest) {
)
}
// Generate query embedding only if query is provided
const workspaceId = accessChecks.find((ac) => ac?.hasAccess)?.knowledgeBase?.workspaceId
const hasQuery = validatedData.query && validatedData.query.trim().length > 0
// Start embedding generation early and await when needed
const queryEmbeddingPromise = hasQuery
? generateSearchEmbedding(validatedData.query!)
? generateSearchEmbedding(validatedData.query!, undefined, workspaceId)
: Promise.resolve(null)
// Check if any requested knowledge bases were not accessible

View File

@@ -99,7 +99,7 @@ export interface EmbeddingData {
export interface KnowledgeBaseAccessResult {
hasAccess: true
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId'>
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId'>
}
export interface KnowledgeBaseAccessDenied {
@@ -113,7 +113,7 @@ export type KnowledgeBaseAccessCheck = KnowledgeBaseAccessResult | KnowledgeBase
export interface DocumentAccessResult {
hasAccess: true
document: DocumentData
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId'>
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId'>
}
export interface DocumentAccessDenied {
@@ -128,7 +128,7 @@ export interface ChunkAccessResult {
hasAccess: true
chunk: EmbeddingData
document: DocumentData
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId'>
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId'>
}
export interface ChunkAccessDenied {

View File

@@ -7,7 +7,6 @@ import { createLogger } from '@/lib/logs/console/logger'
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
import type { StreamingExecution } from '@/executor/types'
import { executeProviderRequest } from '@/providers'
import { getApiKey } from '@/providers/utils'
const logger = createLogger('ProvidersAPI')
@@ -80,23 +79,20 @@ export async function POST(request: NextRequest) {
verbosity,
})
let finalApiKey: string
let finalApiKey: string | undefined = apiKey
try {
if (provider === 'vertex' && vertexCredential) {
finalApiKey = await resolveVertexCredential(requestId, vertexCredential)
} else {
finalApiKey = getApiKey(provider, model, apiKey)
}
} catch (error) {
logger.error(`[${requestId}] Failed to get API key:`, {
logger.error(`[${requestId}] Failed to resolve Vertex credential:`, {
provider,
model,
error: error instanceof Error ? error.message : String(error),
hasProvidedApiKey: !!apiKey,
hasVertexCredential: !!vertexCredential,
})
return NextResponse.json(
{ error: error instanceof Error ? error.message : 'API key error' },
{ error: error instanceof Error ? error.message : 'Credential error' },
{ status: 400 }
)
}
@@ -108,7 +104,6 @@ export async function POST(request: NextRequest) {
hasApiKey: !!finalApiKey,
})
// Execute provider request directly with the managed key
const response = await executeProviderRequest(provider, {
model,
systemPrompt,

View File

@@ -1,5 +1,6 @@
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getBYOKKey } from '@/lib/api-key/byok'
import { checkHybridAuth } from '@/lib/auth/hybrid'
import { SEARCH_TOOL_COST } from '@/lib/billing/constants'
import { env } from '@/lib/core/config/env'
@@ -10,6 +11,7 @@ const logger = createLogger('search')
const SearchRequestSchema = z.object({
query: z.string().min(1),
workspaceId: z.string().optional(),
})
export const maxDuration = 60
@@ -39,8 +41,20 @@ export async function POST(request: NextRequest) {
const body = await request.json()
const validated = SearchRequestSchema.parse(body)
if (!env.EXA_API_KEY) {
logger.error(`[${requestId}] EXA_API_KEY not configured`)
let exaApiKey = env.EXA_API_KEY
let isBYOK = false
if (validated.workspaceId) {
const byokResult = await getBYOKKey(validated.workspaceId, 'exa')
if (byokResult) {
exaApiKey = byokResult.apiKey
isBYOK = true
logger.info(`[${requestId}] Using workspace BYOK key for Exa search`)
}
}
if (!exaApiKey) {
logger.error(`[${requestId}] No Exa API key available`)
return NextResponse.json(
{ success: false, error: 'Search service not configured' },
{ status: 503 }
@@ -50,6 +64,7 @@ export async function POST(request: NextRequest) {
logger.info(`[${requestId}] Executing search`, {
userId,
query: validated.query,
isBYOK,
})
const result = await executeTool('exa_search', {
@@ -57,7 +72,7 @@ export async function POST(request: NextRequest) {
type: 'auto',
useAutoprompt: true,
highlights: true,
apiKey: env.EXA_API_KEY,
apiKey: exaApiKey,
})
if (!result.success) {
@@ -85,7 +100,7 @@ export async function POST(request: NextRequest) {
const cost = {
input: 0,
output: 0,
total: SEARCH_TOOL_COST,
total: isBYOK ? 0 : SEARCH_TOOL_COST,
tokens: {
input: 0,
output: 0,
@@ -104,6 +119,7 @@ export async function POST(request: NextRequest) {
userId,
resultCount: results.length,
cost: cost.total,
isBYOK,
})
return NextResponse.json({

View File

@@ -3,6 +3,7 @@ import { userStats, workflow } from '@sim/db/schema'
import { eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import OpenAI, { AzureOpenAI } from 'openai'
import { getBYOKKey } from '@/lib/api-key/byok'
import { getSession } from '@/lib/auth'
import { logModelUsage } from '@/lib/billing/core/usage-log'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
@@ -75,7 +76,8 @@ async function updateUserStatsForWand(
completion_tokens?: number
total_tokens?: number
},
requestId: string
requestId: string,
isBYOK = false
): Promise<void> {
if (!isBillingEnabled) {
logger.debug(`[${requestId}] Billing is disabled, skipping wand usage cost update`)
@@ -93,21 +95,24 @@ async function updateUserStatsForWand(
const completionTokens = usage.completion_tokens || 0
const modelName = useWandAzure ? wandModelName : 'gpt-4o'
const pricing = getModelPricing(modelName)
let costToStore = 0
const costMultiplier = getCostMultiplier()
let modelCost = 0
if (!isBYOK) {
const pricing = getModelPricing(modelName)
const costMultiplier = getCostMultiplier()
let modelCost = 0
if (pricing) {
const inputCost = (promptTokens / 1000000) * pricing.input
const outputCost = (completionTokens / 1000000) * pricing.output
modelCost = inputCost + outputCost
} else {
modelCost = (promptTokens / 1000000) * 0.005 + (completionTokens / 1000000) * 0.015
if (pricing) {
const inputCost = (promptTokens / 1000000) * pricing.input
const outputCost = (completionTokens / 1000000) * pricing.output
modelCost = inputCost + outputCost
} else {
modelCost = (promptTokens / 1000000) * 0.005 + (completionTokens / 1000000) * 0.015
}
costToStore = modelCost * costMultiplier
}
const costToStore = modelCost * costMultiplier
await db
.update(userStats)
.set({
@@ -122,6 +127,7 @@ async function updateUserStatsForWand(
userId,
tokensUsed: totalTokens,
costAdded: costToStore,
isBYOK,
})
await logModelUsage({
@@ -149,14 +155,6 @@ export async function POST(req: NextRequest) {
return NextResponse.json({ success: false, error: 'Unauthorized' }, { status: 401 })
}
if (!client) {
logger.error(`[${requestId}] AI client not initialized. Missing API key.`)
return NextResponse.json(
{ success: false, error: 'Wand generation service is not configured.' },
{ status: 503 }
)
}
try {
const body = (await req.json()) as RequestBody
@@ -170,6 +168,7 @@ export async function POST(req: NextRequest) {
)
}
let workspaceId: string | null = null
if (workflowId) {
const [workflowRecord] = await db
.select({ workspaceId: workflow.workspaceId, userId: workflow.userId })
@@ -182,6 +181,8 @@ export async function POST(req: NextRequest) {
return NextResponse.json({ success: false, error: 'Workflow not found' }, { status: 404 })
}
workspaceId = workflowRecord.workspaceId
if (workflowRecord.workspaceId) {
const permission = await verifyWorkspaceMembership(
session.user.id,
@@ -199,6 +200,28 @@ export async function POST(req: NextRequest) {
}
}
let isBYOK = false
let activeClient = client
let byokApiKey: string | null = null
if (workspaceId && !useWandAzure) {
const byokResult = await getBYOKKey(workspaceId, 'openai')
if (byokResult) {
isBYOK = true
byokApiKey = byokResult.apiKey
activeClient = new OpenAI({ apiKey: byokResult.apiKey })
logger.info(`[${requestId}] Using BYOK OpenAI key for wand generation`)
}
}
if (!activeClient) {
logger.error(`[${requestId}] AI client not initialized. Missing API key.`)
return NextResponse.json(
{ success: false, error: 'Wand generation service is not configured.' },
{ status: 503 }
)
}
const finalSystemPrompt =
systemPrompt ||
'You are a helpful AI assistant. Generate content exactly as requested by the user.'
@@ -241,7 +264,7 @@ export async function POST(req: NextRequest) {
if (useWandAzure) {
headers['api-key'] = azureApiKey!
} else {
headers.Authorization = `Bearer ${openaiApiKey}`
headers.Authorization = `Bearer ${byokApiKey || openaiApiKey}`
}
logger.debug(`[${requestId}] Making streaming request to: ${apiUrl}`)
@@ -310,7 +333,7 @@ export async function POST(req: NextRequest) {
logger.info(`[${requestId}] Received [DONE] signal`)
if (finalUsage) {
await updateUserStatsForWand(session.user.id, finalUsage, requestId)
await updateUserStatsForWand(session.user.id, finalUsage, requestId, isBYOK)
}
controller.enqueue(
@@ -395,7 +418,7 @@ export async function POST(req: NextRequest) {
}
}
const completion = await client.chat.completions.create({
const completion = await activeClient.chat.completions.create({
model: useWandAzure ? wandModelName : 'gpt-4o',
messages: messages,
temperature: 0.3,
@@ -417,7 +440,7 @@ export async function POST(req: NextRequest) {
logger.info(`[${requestId}] Wand generation successful`)
if (completion.usage) {
await updateUserStatsForWand(session.user.id, completion.usage, requestId)
await updateUserStatsForWand(session.user.id, completion.usage, requestId, isBYOK)
}
return NextResponse.json({ success: true, content: generatedContent })

View File

@@ -0,0 +1,256 @@
import { db } from '@sim/db'
import { workspace, workspaceBYOKKeys } from '@sim/db/schema'
import { and, eq } from 'drizzle-orm'
import { nanoid } from 'nanoid'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption'
import { generateRequestId } from '@/lib/core/utils/request'
import { createLogger } from '@/lib/logs/console/logger'
import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils'
const logger = createLogger('WorkspaceBYOKKeysAPI')
const VALID_PROVIDERS = ['openai', 'anthropic', 'google', 'mistral', 'exa'] as const
const UpsertKeySchema = z.object({
providerId: z.enum(VALID_PROVIDERS),
apiKey: z.string().min(1, 'API key is required'),
})
const DeleteKeySchema = z.object({
providerId: z.enum(VALID_PROVIDERS),
})
function maskApiKey(key: string): string {
if (key.length <= 8) {
return '•'.repeat(8)
}
if (key.length <= 12) {
return `${key.slice(0, 4)}...${key.slice(-4)}`
}
return `${key.slice(0, 6)}...${key.slice(-4)}`
}
export async function GET(request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
const requestId = generateRequestId()
const workspaceId = (await params).id
try {
const session = await getSession()
if (!session?.user?.id) {
logger.warn(`[${requestId}] Unauthorized BYOK keys access attempt`)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const userId = session.user.id
const ws = await db.select().from(workspace).where(eq(workspace.id, workspaceId)).limit(1)
if (!ws.length) {
return NextResponse.json({ error: 'Workspace not found' }, { status: 404 })
}
const permission = await getUserEntityPermissions(userId, 'workspace', workspaceId)
if (!permission) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const byokKeys = await db
.select({
id: workspaceBYOKKeys.id,
providerId: workspaceBYOKKeys.providerId,
encryptedApiKey: workspaceBYOKKeys.encryptedApiKey,
createdBy: workspaceBYOKKeys.createdBy,
createdAt: workspaceBYOKKeys.createdAt,
updatedAt: workspaceBYOKKeys.updatedAt,
})
.from(workspaceBYOKKeys)
.where(eq(workspaceBYOKKeys.workspaceId, workspaceId))
.orderBy(workspaceBYOKKeys.providerId)
const formattedKeys = await Promise.all(
byokKeys.map(async (key) => {
try {
const { decrypted } = await decryptSecret(key.encryptedApiKey)
return {
id: key.id,
providerId: key.providerId,
maskedKey: maskApiKey(decrypted),
createdBy: key.createdBy,
createdAt: key.createdAt,
updatedAt: key.updatedAt,
}
} catch (error) {
logger.error(`[${requestId}] Failed to decrypt BYOK key for provider ${key.providerId}`, {
error,
})
return {
id: key.id,
providerId: key.providerId,
maskedKey: '••••••••',
createdBy: key.createdBy,
createdAt: key.createdAt,
updatedAt: key.updatedAt,
}
}
})
)
return NextResponse.json({ keys: formattedKeys })
} catch (error: unknown) {
logger.error(`[${requestId}] BYOK keys GET error`, error)
return NextResponse.json(
{ error: error instanceof Error ? error.message : 'Failed to load BYOK keys' },
{ status: 500 }
)
}
}
export async function POST(request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
const requestId = generateRequestId()
const workspaceId = (await params).id
try {
const session = await getSession()
if (!session?.user?.id) {
logger.warn(`[${requestId}] Unauthorized BYOK key creation attempt`)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const userId = session.user.id
const permission = await getUserEntityPermissions(userId, 'workspace', workspaceId)
if (permission !== 'admin') {
return NextResponse.json(
{ error: 'Only workspace admins can manage BYOK keys' },
{ status: 403 }
)
}
const body = await request.json()
const { providerId, apiKey } = UpsertKeySchema.parse(body)
const { encrypted } = await encryptSecret(apiKey)
const existingKey = await db
.select()
.from(workspaceBYOKKeys)
.where(
and(
eq(workspaceBYOKKeys.workspaceId, workspaceId),
eq(workspaceBYOKKeys.providerId, providerId)
)
)
.limit(1)
if (existingKey.length > 0) {
await db
.update(workspaceBYOKKeys)
.set({
encryptedApiKey: encrypted,
updatedAt: new Date(),
})
.where(eq(workspaceBYOKKeys.id, existingKey[0].id))
logger.info(`[${requestId}] Updated BYOK key for ${providerId} in workspace ${workspaceId}`)
return NextResponse.json({
success: true,
key: {
id: existingKey[0].id,
providerId,
maskedKey: maskApiKey(apiKey),
updatedAt: new Date(),
},
})
}
const [newKey] = await db
.insert(workspaceBYOKKeys)
.values({
id: nanoid(),
workspaceId,
providerId,
encryptedApiKey: encrypted,
createdBy: userId,
createdAt: new Date(),
updatedAt: new Date(),
})
.returning({
id: workspaceBYOKKeys.id,
providerId: workspaceBYOKKeys.providerId,
createdAt: workspaceBYOKKeys.createdAt,
})
logger.info(`[${requestId}] Created BYOK key for ${providerId} in workspace ${workspaceId}`)
return NextResponse.json({
success: true,
key: {
...newKey,
maskedKey: maskApiKey(apiKey),
},
})
} catch (error: unknown) {
logger.error(`[${requestId}] BYOK key POST error`, error)
if (error instanceof z.ZodError) {
return NextResponse.json({ error: error.errors[0].message }, { status: 400 })
}
return NextResponse.json(
{ error: error instanceof Error ? error.message : 'Failed to save BYOK key' },
{ status: 500 }
)
}
}
export async function DELETE(
request: NextRequest,
{ params }: { params: Promise<{ id: string }> }
) {
const requestId = generateRequestId()
const workspaceId = (await params).id
try {
const session = await getSession()
if (!session?.user?.id) {
logger.warn(`[${requestId}] Unauthorized BYOK key deletion attempt`)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const userId = session.user.id
const permission = await getUserEntityPermissions(userId, 'workspace', workspaceId)
if (permission !== 'admin') {
return NextResponse.json(
{ error: 'Only workspace admins can manage BYOK keys' },
{ status: 403 }
)
}
const body = await request.json()
const { providerId } = DeleteKeySchema.parse(body)
const result = await db
.delete(workspaceBYOKKeys)
.where(
and(
eq(workspaceBYOKKeys.workspaceId, workspaceId),
eq(workspaceBYOKKeys.providerId, providerId)
)
)
logger.info(`[${requestId}] Deleted BYOK key for ${providerId} from workspace ${workspaceId}`)
return NextResponse.json({ success: true })
} catch (error: unknown) {
logger.error(`[${requestId}] BYOK key DELETE error`, error)
if (error instanceof z.ZodError) {
return NextResponse.json({ error: error.errors[0].message }, { status: 400 })
}
return NextResponse.json(
{ error: error instanceof Error ? error.message : 'Failed to delete BYOK key' },
{ status: 500 }
)
}
}

View File

@@ -0,0 +1,316 @@
'use client'
import { useState } from 'react'
import { Eye, EyeOff } from 'lucide-react'
import { useParams } from 'next/navigation'
import {
Button,
Input as EmcnInput,
Modal,
ModalBody,
ModalContent,
ModalFooter,
ModalHeader,
Trash,
} from '@/components/emcn'
import { AnthropicIcon, ExaAIIcon, GeminiIcon, MistralIcon, OpenAIIcon } from '@/components/icons'
import { Skeleton } from '@/components/ui'
import { createLogger } from '@/lib/logs/console/logger'
import {
type BYOKKey,
type BYOKProviderId,
useBYOKKeys,
useDeleteBYOKKey,
useUpsertBYOKKey,
} from '@/hooks/queries/byok-keys'
const logger = createLogger('BYOKSettings')
const PROVIDERS: {
id: BYOKProviderId
name: string
icon: React.ComponentType<{ className?: string }>
description: string
placeholder: string
}[] = [
{
id: 'openai',
name: 'OpenAI',
icon: OpenAIIcon,
description: 'LLM calls and Knowledge Base embeddings',
placeholder: 'sk-...',
},
{
id: 'anthropic',
name: 'Anthropic',
icon: AnthropicIcon,
description: 'LLM calls',
placeholder: 'sk-ant-...',
},
{
id: 'google',
name: 'Google',
icon: GeminiIcon,
description: 'LLM calls',
placeholder: 'Enter your API key',
},
{
id: 'mistral',
name: 'Mistral',
icon: MistralIcon,
description: 'LLM calls and Knowledge Base OCR',
placeholder: 'Enter your API key',
},
{
id: 'exa',
name: 'Exa',
icon: ExaAIIcon,
description: 'Web Search block',
placeholder: 'Enter your API key',
},
]
function BYOKKeySkeleton() {
return (
<div className='flex items-center justify-between gap-[12px] rounded-[8px] border p-[12px]'>
<div className='flex items-center gap-[12px]'>
<Skeleton className='h-[32px] w-[32px] rounded-[6px]' />
<div className='flex flex-col gap-[4px]'>
<Skeleton className='h-[16px] w-[80px]' />
<Skeleton className='h-[14px] w-[160px]' />
</div>
</div>
<Skeleton className='h-[32px] w-[80px] rounded-[6px]' />
</div>
)
}
export function BYOK() {
const params = useParams()
const workspaceId = (params?.workspaceId as string) || ''
const { data: keys = [], isLoading } = useBYOKKeys(workspaceId)
const upsertKey = useUpsertBYOKKey()
const deleteKey = useDeleteBYOKKey()
const [editingProvider, setEditingProvider] = useState<BYOKProviderId | null>(null)
const [apiKeyInput, setApiKeyInput] = useState('')
const [showApiKey, setShowApiKey] = useState(false)
const [error, setError] = useState<string | null>(null)
const [deleteConfirmProvider, setDeleteConfirmProvider] = useState<BYOKProviderId | null>(null)
const getKeyForProvider = (providerId: BYOKProviderId): BYOKKey | undefined => {
return keys.find((k) => k.providerId === providerId)
}
const handleSave = async () => {
if (!editingProvider || !apiKeyInput.trim()) return
setError(null)
try {
await upsertKey.mutateAsync({
workspaceId,
providerId: editingProvider,
apiKey: apiKeyInput.trim(),
})
setEditingProvider(null)
setApiKeyInput('')
setShowApiKey(false)
} catch (err) {
const message = err instanceof Error ? err.message : 'Failed to save API key'
setError(message)
logger.error('Failed to save BYOK key', { error: err })
}
}
const handleDelete = async () => {
if (!deleteConfirmProvider) return
try {
await deleteKey.mutateAsync({
workspaceId,
providerId: deleteConfirmProvider,
})
setDeleteConfirmProvider(null)
} catch (err) {
logger.error('Failed to delete BYOK key', { error: err })
}
}
const openEditModal = (providerId: BYOKProviderId) => {
setEditingProvider(providerId)
setApiKeyInput('')
setShowApiKey(false)
setError(null)
}
return (
<>
<div className='flex h-full flex-col gap-[16px]'>
<p className='text-[13px] text-[var(--text-secondary)]'>
Use your own API keys for hosted model providers.
</p>
<div className='min-h-0 flex-1 overflow-y-auto'>
{isLoading ? (
<div className='flex flex-col gap-[8px]'>
{PROVIDERS.map((p) => (
<BYOKKeySkeleton key={p.id} />
))}
</div>
) : (
<div className='flex flex-col gap-[8px]'>
{PROVIDERS.map((provider) => {
const existingKey = getKeyForProvider(provider.id)
const Icon = provider.icon
return (
<div
key={provider.id}
className='flex items-center justify-between gap-[12px] rounded-[8px] border p-[12px]'
>
<div className='flex items-center gap-[12px]'>
<div className='flex h-[32px] w-[32px] items-center justify-center rounded-[6px] bg-[var(--surface-3)]'>
<Icon className='h-[18px] w-[18px]' />
</div>
<div className='flex flex-col gap-[2px]'>
<span className='font-medium text-[14px]'>{provider.name}</span>
<span className='text-[12px] text-[var(--text-tertiary)]'>
{provider.description}
</span>
{existingKey && (
<span className='font-mono text-[11px] text-[var(--text-muted)]'>
{existingKey.maskedKey}
</span>
)}
</div>
</div>
<div className='flex items-center gap-[6px]'>
{existingKey && (
<Button
variant='ghost'
className='h-9 w-9'
onClick={() => setDeleteConfirmProvider(provider.id)}
>
<Trash />
</Button>
)}
<Button variant='default' onClick={() => openEditModal(provider.id)}>
{existingKey ? 'Update' : 'Add Key'}
</Button>
</div>
</div>
)
})}
</div>
)}
</div>
</div>
<Modal
open={!!editingProvider}
onOpenChange={(open) => {
if (!open) {
setEditingProvider(null)
setApiKeyInput('')
setShowApiKey(false)
setError(null)
}
}}
>
<ModalContent className='w-[420px]'>
<ModalHeader>
{editingProvider && (
<>
{getKeyForProvider(editingProvider) ? 'Update' : 'Add'}{' '}
{PROVIDERS.find((p) => p.id === editingProvider)?.name} API Key
</>
)}
</ModalHeader>
<ModalBody>
<p className='text-[12px] text-[var(--text-tertiary)]'>
This key will be used for all {PROVIDERS.find((p) => p.id === editingProvider)?.name}{' '}
requests in this workspace. Your key is encrypted and stored securely.
</p>
<div className='mt-[12px] flex flex-col gap-[8px]'>
<div className='relative'>
<EmcnInput
type={showApiKey ? 'text' : 'password'}
value={apiKeyInput}
onChange={(e) => {
setApiKeyInput(e.target.value)
if (error) setError(null)
}}
placeholder={PROVIDERS.find((p) => p.id === editingProvider)?.placeholder}
className='h-9 pr-[36px]'
autoFocus
/>
<Button
variant='ghost'
className='-translate-y-1/2 absolute top-1/2 right-[4px] h-[28px] w-[28px] p-0'
onClick={() => setShowApiKey(!showApiKey)}
>
{showApiKey ? (
<EyeOff className='h-[14px] w-[14px]' />
) : (
<Eye className='h-[14px] w-[14px]' />
)}
</Button>
</div>
{error && (
<p className='text-[11px] text-[var(--text-error)] leading-tight'>{error}</p>
)}
</div>
</ModalBody>
<ModalFooter>
<Button
variant='default'
onClick={() => {
setEditingProvider(null)
setApiKeyInput('')
setShowApiKey(false)
setError(null)
}}
>
Cancel
</Button>
<Button
variant='primary'
onClick={handleSave}
disabled={!apiKeyInput.trim() || upsertKey.isPending}
>
{upsertKey.isPending ? 'Saving...' : 'Save'}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
<Modal open={!!deleteConfirmProvider} onOpenChange={() => setDeleteConfirmProvider(null)}>
<ModalContent className='w-[400px]'>
<ModalHeader>Delete API Key</ModalHeader>
<ModalBody>
<p className='text-[12px] text-[var(--text-tertiary)]'>
Are you sure you want to delete the{' '}
<span className='font-medium text-[var(--text-primary)]'>
{PROVIDERS.find((p) => p.id === deleteConfirmProvider)?.name}
</span>{' '}
API key? This workspace will revert to using platform keys with the 2x multiplier.
</p>
</ModalBody>
<ModalFooter>
<Button variant='default' onClick={() => setDeleteConfirmProvider(null)}>
Cancel
</Button>
<Button variant='primary' onClick={handleDelete} disabled={deleteKey.isPending}>
{deleteKey.isPending ? 'Deleting...' : 'Delete'}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
</>
)
}

View File

@@ -1,4 +1,5 @@
export { ApiKeys } from './api-keys/api-keys'
export { BYOK } from './byok/byok'
export { Copilot } from './copilot/copilot'
export { CustomTools } from './custom-tools/custom-tools'
export { EnvironmentVariables } from './environment/environment'

View File

@@ -4,7 +4,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import * as DialogPrimitive from '@radix-ui/react-dialog'
import * as VisuallyHidden from '@radix-ui/react-visually-hidden'
import { useQueryClient } from '@tanstack/react-query'
import { Files, LogIn, Settings, User, Users, Wrench } from 'lucide-react'
import { Files, KeySquare, LogIn, Settings, User, Users, Wrench } from 'lucide-react'
import {
Card,
Connections,
@@ -30,6 +30,7 @@ import { isHosted } from '@/lib/core/config/feature-flags'
import { getUserRole } from '@/lib/workspaces/organization'
import {
ApiKeys,
BYOK,
Copilot,
CustomTools,
EnvironmentVariables,
@@ -62,6 +63,7 @@ type SettingsSection =
| 'template-profile'
| 'integrations'
| 'apikeys'
| 'byok'
| 'files'
| 'subscription'
| 'team'
@@ -114,6 +116,13 @@ const allNavigationItems: NavigationItem[] = [
{ id: 'mcp', label: 'MCPs', icon: McpIcon, section: 'tools' },
{ id: 'environment', label: 'Environment', icon: FolderCode, section: 'system' },
{ id: 'apikeys', label: 'API Keys', icon: Key, section: 'system' },
{
id: 'byok',
label: 'BYOK',
icon: KeySquare,
section: 'system',
requiresHosted: true,
},
{
id: 'copilot',
label: 'Copilot Keys',
@@ -456,6 +465,7 @@ export function SettingsModal({ open, onOpenChange }: SettingsModalProps) {
{isBillingEnabled && activeSection === 'subscription' && <Subscription />}
{isBillingEnabled && activeSection === 'team' && <TeamManagement />}
{activeSection === 'sso' && <SSO />}
{activeSection === 'byok' && <BYOK />}
{activeSection === 'copilot' && <Copilot />}
{activeSection === 'mcp' && <MCP initialServerId={pendingMcpServerId} />}
{activeSection === 'custom-tools' && <CustomTools />}

View File

@@ -26,7 +26,7 @@ import { collectBlockData } from '@/executor/utils/block-data'
import { buildAPIUrl, buildAuthHeaders, extractAPIErrorMessage } from '@/executor/utils/http'
import { stringifyJSON } from '@/executor/utils/json'
import { executeProviderRequest } from '@/providers'
import { getApiKey, getProviderFromModel, transformBlockTool } from '@/providers/utils'
import { getProviderFromModel, transformBlockTool } from '@/providers/utils'
import type { SerializedBlock } from '@/serializer/types'
import { executeTool } from '@/tools'
import { getTool, getToolAsync } from '@/tools/utils'
@@ -1006,15 +1006,13 @@ export class AgentBlockHandler implements BlockHandler {
responseFormat: any,
providerStartTime: number
) {
let finalApiKey: string
let finalApiKey: string | undefined = providerRequest.apiKey
if (providerId === 'vertex' && providerRequest.vertexCredential) {
finalApiKey = await this.resolveVertexCredential(
providerRequest.vertexCredential,
ctx.workflowId
)
} else {
finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
}
const { blockData, blockNameMapping } = collectBlockData(ctx)
@@ -1033,7 +1031,7 @@ export class AgentBlockHandler implements BlockHandler {
vertexLocation: providerRequest.vertexLocation,
responseFormat: providerRequest.responseFormat,
workflowId: providerRequest.workflowId,
workspaceId: providerRequest.workspaceId,
workspaceId: ctx.workspaceId,
stream: providerRequest.stream,
messages: 'messages' in providerRequest ? providerRequest.messages : undefined,
environmentVariables: ctx.environmentVariables || {},
@@ -1111,20 +1109,6 @@ export class AgentBlockHandler implements BlockHandler {
return this.createMinimalStreamingExecution(response.body!)
}
private getApiKey(providerId: string, model: string, inputApiKey: string): string {
try {
return getApiKey(providerId, model, inputApiKey)
} catch (error) {
logger.error('Failed to get API key:', {
provider: providerId,
model,
error: error instanceof Error ? error.message : String(error),
hasProvidedApiKey: !!inputApiKey,
})
throw new Error(error instanceof Error ? error.message : 'API key error')
}
}
/**
* Resolves a Vertex AI OAuth credential to an access token
*/

View File

@@ -388,21 +388,6 @@ describe('EvaluatorBlockHandler', () => {
})
})
it('should throw error when API key is missing for non-hosted models', async () => {
const inputs = {
content: 'Test content',
metrics: [{ name: 'score', description: 'Score', range: { min: 0, max: 10 } }],
model: 'gpt-4o',
// No apiKey provided
}
mockGetProviderFromModel.mockReturnValue('openai')
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
/API key is required/
)
})
it('should handle Vertex AI models with OAuth credential', async () => {
const inputs = {
content: 'Test content to evaluate',

View File

@@ -8,7 +8,7 @@ import { BlockType, DEFAULTS, EVALUATOR, HTTP } from '@/executor/constants'
import type { BlockHandler, ExecutionContext } from '@/executor/types'
import { buildAPIUrl, extractAPIErrorMessage } from '@/executor/utils/http'
import { isJSONString, parseJSON, stringifyJSON } from '@/executor/utils/json'
import { calculateCost, getApiKey, getProviderFromModel } from '@/providers/utils'
import { calculateCost, getProviderFromModel } from '@/providers/utils'
import type { SerializedBlock } from '@/serializer/types'
const logger = createLogger('EvaluatorBlockHandler')
@@ -35,11 +35,9 @@ export class EvaluatorBlockHandler implements BlockHandler {
}
const providerId = getProviderFromModel(evaluatorConfig.model)
let finalApiKey: string
let finalApiKey: string | undefined = evaluatorConfig.apiKey
if (providerId === 'vertex' && evaluatorConfig.vertexCredential) {
finalApiKey = await this.resolveVertexCredential(evaluatorConfig.vertexCredential)
} else {
finalApiKey = this.getApiKey(providerId, evaluatorConfig.model, evaluatorConfig.apiKey)
}
const processedContent = this.processContent(inputs.content)
@@ -117,6 +115,7 @@ export class EvaluatorBlockHandler implements BlockHandler {
temperature: EVALUATOR.DEFAULT_TEMPERATURE,
apiKey: finalApiKey,
workflowId: ctx.workflowId,
workspaceId: ctx.workspaceId,
}
if (providerId === 'vertex') {
@@ -275,20 +274,6 @@ export class EvaluatorBlockHandler implements BlockHandler {
return DEFAULTS.EXECUTION_TIME
}
private getApiKey(providerId: string, model: string, inputApiKey: string): string {
try {
return getApiKey(providerId, model, inputApiKey)
} catch (error) {
logger.error('Failed to get API key:', {
provider: providerId,
model,
error: error instanceof Error ? error.message : String(error),
hasProvidedApiKey: !!inputApiKey,
})
throw new Error(error instanceof Error ? error.message : 'API key error')
}
}
/**
* Resolves a Vertex AI OAuth credential to an access token
*/

View File

@@ -265,20 +265,6 @@ describe('RouterBlockHandler', () => {
})
})
it('should throw error when API key is missing for non-hosted models', async () => {
const inputs = {
prompt: 'Test without API key',
model: 'gpt-4o',
// No apiKey provided
}
mockGetProviderFromModel.mockReturnValue('openai')
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
/API key is required/
)
})
it('should handle Vertex AI models with OAuth credential', async () => {
const inputs = {
prompt: 'Choose the best option.',

View File

@@ -8,7 +8,7 @@ import { generateRouterPrompt } from '@/blocks/blocks/router'
import type { BlockOutput } from '@/blocks/types'
import { BlockType, DEFAULTS, HTTP, isAgentBlockType, ROUTER } from '@/executor/constants'
import type { BlockHandler, ExecutionContext } from '@/executor/types'
import { calculateCost, getApiKey, getProviderFromModel } from '@/providers/utils'
import { calculateCost, getProviderFromModel } from '@/providers/utils'
import type { SerializedBlock } from '@/serializer/types'
const logger = createLogger('RouterBlockHandler')
@@ -47,11 +47,9 @@ export class RouterBlockHandler implements BlockHandler {
const messages = [{ role: 'user', content: routerConfig.prompt }]
const systemPrompt = generateRouterPrompt(routerConfig.prompt, targetBlocks)
let finalApiKey: string
let finalApiKey: string | undefined = routerConfig.apiKey
if (providerId === 'vertex' && routerConfig.vertexCredential) {
finalApiKey = await this.resolveVertexCredential(routerConfig.vertexCredential)
} else {
finalApiKey = this.getApiKey(providerId, routerConfig.model, routerConfig.apiKey)
}
const providerRequest: Record<string, any> = {
@@ -62,6 +60,7 @@ export class RouterBlockHandler implements BlockHandler {
temperature: ROUTER.INFERENCE_TEMPERATURE,
apiKey: finalApiKey,
workflowId: ctx.workflowId,
workspaceId: ctx.workspaceId,
}
if (providerId === 'vertex') {
@@ -178,20 +177,6 @@ export class RouterBlockHandler implements BlockHandler {
})
}
private getApiKey(providerId: string, model: string, inputApiKey: string): string {
try {
return getApiKey(providerId, model, inputApiKey)
} catch (error) {
logger.error('Failed to get API key:', {
provider: providerId,
model,
error: error instanceof Error ? error.message : String(error),
hasProvidedApiKey: !!inputApiKey,
})
throw new Error(error instanceof Error ? error.message : 'API key error')
}
}
/**
* Resolves a Vertex AI OAuth credential to an access token
*/

View File

@@ -19,7 +19,7 @@ const sleep = async (ms: number, options: SleepOptions = {}): Promise<boolean> =
}
return new Promise((resolve) => {
// biome-ignore lint/style/useConst: Variable is assigned after closure definitions that reference it
// biome-ignore lint/style/useConst: needs to be declared before cleanup() but assigned later
let mainTimeoutId: NodeJS.Timeout | undefined
let checkIntervalId: NodeJS.Timeout | undefined
let resolved = false

View File

@@ -0,0 +1,105 @@
import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
import { createLogger } from '@/lib/logs/console/logger'
import { API_ENDPOINTS } from '@/stores/constants'
const logger = createLogger('BYOKKeysQueries')
export type BYOKProviderId = 'openai' | 'anthropic' | 'google' | 'mistral' | 'exa'
export interface BYOKKey {
id: string
providerId: BYOKProviderId
maskedKey: string
createdBy: string | null
createdAt: string
updatedAt: string
}
export const byokKeysKeys = {
all: ['byok-keys'] as const,
workspace: (workspaceId: string) => [...byokKeysKeys.all, 'workspace', workspaceId] as const,
}
async function fetchBYOKKeys(workspaceId: string): Promise<BYOKKey[]> {
const response = await fetch(API_ENDPOINTS.WORKSPACE_BYOK_KEYS(workspaceId))
if (!response.ok) {
throw new Error(`Failed to load BYOK keys: ${response.statusText}`)
}
const { keys } = await response.json()
return keys
}
export function useBYOKKeys(workspaceId: string) {
return useQuery({
queryKey: byokKeysKeys.workspace(workspaceId),
queryFn: () => fetchBYOKKeys(workspaceId),
enabled: !!workspaceId,
staleTime: 60 * 1000,
placeholderData: keepPreviousData,
})
}
interface UpsertBYOKKeyParams {
workspaceId: string
providerId: BYOKProviderId
apiKey: string
}
export function useUpsertBYOKKey() {
const queryClient = useQueryClient()
return useMutation({
mutationFn: async ({ workspaceId, providerId, apiKey }: UpsertBYOKKeyParams) => {
const response = await fetch(API_ENDPOINTS.WORKSPACE_BYOK_KEYS(workspaceId), {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ providerId, apiKey }),
})
if (!response.ok) {
const data = await response.json().catch(() => ({}))
throw new Error(data.error || `Failed to save BYOK key: ${response.statusText}`)
}
logger.info(`Saved BYOK key for ${providerId} in workspace ${workspaceId}`)
return await response.json()
},
onSuccess: (_data, variables) => {
queryClient.invalidateQueries({
queryKey: byokKeysKeys.workspace(variables.workspaceId),
})
},
})
}
interface DeleteBYOKKeyParams {
workspaceId: string
providerId: BYOKProviderId
}
export function useDeleteBYOKKey() {
const queryClient = useQueryClient()
return useMutation({
mutationFn: async ({ workspaceId, providerId }: DeleteBYOKKeyParams) => {
const response = await fetch(API_ENDPOINTS.WORKSPACE_BYOK_KEYS(workspaceId), {
method: 'DELETE',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ providerId }),
})
if (!response.ok) {
const data = await response.json().catch(() => ({}))
throw new Error(data.error || `Failed to delete BYOK key: ${response.statusText}`)
}
logger.info(`Deleted BYOK key for ${providerId} from workspace ${workspaceId}`)
return await response.json()
},
onSuccess: (_data, variables) => {
queryClient.invalidateQueries({
queryKey: byokKeysKeys.workspace(variables.workspaceId),
})
},
})
}

View File

@@ -0,0 +1,121 @@
import { db } from '@sim/db'
import { workspaceBYOKKeys } from '@sim/db/schema'
import { and, eq } from 'drizzle-orm'
import { decryptSecret } from '@/lib/core/security/encryption'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('BYOKKeys')
export type BYOKProviderId = 'openai' | 'anthropic' | 'google' | 'mistral' | 'exa'
export interface BYOKKeyResult {
apiKey: string
isBYOK: true
}
export async function getBYOKKey(
workspaceId: string | undefined | null,
providerId: BYOKProviderId
): Promise<BYOKKeyResult | null> {
if (!workspaceId) {
return null
}
try {
const result = await db
.select({ encryptedApiKey: workspaceBYOKKeys.encryptedApiKey })
.from(workspaceBYOKKeys)
.where(
and(
eq(workspaceBYOKKeys.workspaceId, workspaceId),
eq(workspaceBYOKKeys.providerId, providerId)
)
)
.limit(1)
if (!result.length) {
return null
}
const { decrypted } = await decryptSecret(result[0].encryptedApiKey)
return { apiKey: decrypted, isBYOK: true }
} catch (error) {
logger.error('Failed to get BYOK key', { workspaceId, providerId, error })
return null
}
}
export async function getApiKeyWithBYOK(
provider: string,
model: string,
workspaceId: string | undefined | null,
userProvidedKey?: string
): Promise<{ apiKey: string; isBYOK: boolean }> {
const { isHosted } = await import('@/lib/core/config/feature-flags')
const { useProvidersStore } = await import('@/stores/providers/store')
const isOllamaModel =
provider === 'ollama' || useProvidersStore.getState().providers.ollama.models.includes(model)
if (isOllamaModel) {
return { apiKey: 'empty', isBYOK: false }
}
const isVllmModel =
provider === 'vllm' || useProvidersStore.getState().providers.vllm.models.includes(model)
if (isVllmModel) {
return { apiKey: userProvidedKey || 'empty', isBYOK: false }
}
const isOpenAIModel = provider === 'openai'
const isClaudeModel = provider === 'anthropic'
const isGeminiModel = provider === 'google'
const isMistralModel = provider === 'mistral'
const byokProviderId = isGeminiModel ? 'google' : (provider as BYOKProviderId)
if (
isHosted &&
workspaceId &&
(isOpenAIModel || isClaudeModel || isGeminiModel || isMistralModel)
) {
const { getHostedModels } = await import('@/providers/models')
const hostedModels = getHostedModels()
const isModelHosted = hostedModels.some((m) => m.toLowerCase() === model.toLowerCase())
logger.debug('BYOK check', { provider, model, workspaceId, isHosted, isModelHosted })
if (isModelHosted || isMistralModel) {
const byokResult = await getBYOKKey(workspaceId, byokProviderId)
if (byokResult) {
logger.info('Using BYOK key', { provider, model, workspaceId })
return byokResult
}
logger.debug('No BYOK key found, falling back', { provider, model, workspaceId })
if (isModelHosted) {
try {
const { getRotatingApiKey } = await import('@/lib/core/config/api-keys')
const serverKey = getRotatingApiKey(isGeminiModel ? 'gemini' : provider)
return { apiKey: serverKey, isBYOK: false }
} catch (_error) {
if (userProvidedKey) {
return { apiKey: userProvidedKey, isBYOK: false }
}
throw new Error(`No API key available for ${provider} ${model}`)
}
}
}
}
if (!userProvidedKey) {
logger.debug('BYOK not applicable, no user key provided', {
provider,
model,
workspaceId,
isHosted,
})
throw new Error(`API key is required for ${provider} ${model}`)
}
return { apiKey: userProvidedKey, isBYOK: false }
}

View File

@@ -66,7 +66,7 @@ export interface LogFixedUsageParams {
* Log a model usage charge (token-based)
*/
export async function logModelUsage(params: LogModelUsageParams): Promise<void> {
if (!isBillingEnabled) {
if (!isBillingEnabled || params.cost <= 0) {
return
}
@@ -108,7 +108,7 @@ export async function logModelUsage(params: LogModelUsageParams): Promise<void>
* Log a fixed charge (flat fee like base execution charge or search)
*/
export async function logFixedUsage(params: LogFixedUsageParams): Promise<void> {
if (!isBillingEnabled) {
if (!isBillingEnabled || params.cost <= 0) {
return
}

View File

@@ -174,8 +174,11 @@ export const knowledgeBaseServerTool: BaseServerTool<KnowledgeBaseArgs, Knowledg
const topK = args.topK || 5
// Generate embedding for the query
const queryEmbedding = await generateSearchEmbedding(args.query)
const queryEmbedding = await generateSearchEmbedding(
args.query,
undefined,
kb.workspaceId
)
const queryVector = JSON.stringify(queryEmbedding)
// Get search strategy

View File

@@ -94,11 +94,11 @@ export async function createChunk(
documentId: string,
docTags: Record<string, string | number | boolean | Date | null>,
chunkData: CreateChunkData,
requestId: string
requestId: string,
workspaceId?: string | null
): Promise<ChunkData> {
// Generate embedding for the content first (outside transaction for performance)
logger.info(`[${requestId}] Generating embedding for manual chunk`)
const embeddings = await generateEmbeddings([chunkData.content])
const embeddings = await generateEmbeddings([chunkData.content], undefined, workspaceId)
// Calculate accurate token count
const tokenCount = estimateTokenCount(chunkData.content, 'openai')
@@ -285,7 +285,8 @@ export async function updateChunk(
content?: string
enabled?: boolean
},
requestId: string
requestId: string,
workspaceId?: string | null
): Promise<ChunkData> {
const dbUpdateData: {
updatedAt: Date
@@ -327,8 +328,7 @@ export async function updateChunk(
if (content !== currentChunk[0].content) {
logger.info(`[${requestId}] Content changed, regenerating embedding for chunk ${chunkId}`)
// Generate new embedding for the updated content
const embeddings = await generateEmbeddings([content])
const embeddings = await generateEmbeddings([content], undefined, workspaceId)
// Calculate accurate token count
const tokenCount = estimateTokenCount(content, 'openai')

View File

@@ -1,3 +1,4 @@
import { getBYOKKey } from '@/lib/api-key/byok'
import { type Chunk, JsonYamlChunker, StructuredDataChunker, TextChunker } from '@/lib/chunkers'
import { env } from '@/lib/core/config/env'
import { parseBuffer, parseFile } from '@/lib/file-parsers'
@@ -131,6 +132,17 @@ export async function processDocument(
}
}
async function getMistralApiKey(workspaceId?: string | null): Promise<string | null> {
if (workspaceId) {
const byokResult = await getBYOKKey(workspaceId, 'mistral')
if (byokResult) {
logger.info('Using workspace BYOK key for Mistral OCR')
return byokResult.apiKey
}
}
return env.MISTRAL_API_KEY || null
}
async function parseDocument(
fileUrl: string,
filename: string,
@@ -146,7 +158,9 @@ async function parseDocument(
const isPDF = mimeType === 'application/pdf'
const hasAzureMistralOCR =
env.OCR_AZURE_API_KEY && env.OCR_AZURE_ENDPOINT && env.OCR_AZURE_MODEL_NAME
const hasMistralOCR = env.MISTRAL_API_KEY
const mistralApiKey = await getMistralApiKey(workspaceId)
const hasMistralOCR = !!mistralApiKey
if (isPDF && (hasAzureMistralOCR || hasMistralOCR)) {
if (hasAzureMistralOCR) {
@@ -156,7 +170,7 @@ async function parseDocument(
if (hasMistralOCR) {
logger.info(`Using Mistral OCR: ${filename}`)
return parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
return parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId, mistralApiKey)
}
}
@@ -360,9 +374,18 @@ async function parseWithAzureMistralOCR(
message: error instanceof Error ? error.message : String(error),
})
return env.MISTRAL_API_KEY
? parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
: parseWithFileParser(fileUrl, filename, mimeType)
const fallbackMistralKey = await getMistralApiKey(workspaceId)
if (fallbackMistralKey) {
return parseWithMistralOCR(
fileUrl,
filename,
mimeType,
userId,
workspaceId,
fallbackMistralKey
)
}
return parseWithFileParser(fileUrl, filename, mimeType)
}
}
@@ -371,9 +394,11 @@ async function parseWithMistralOCR(
filename: string,
mimeType: string,
userId?: string,
workspaceId?: string | null
workspaceId?: string | null,
mistralApiKey?: string | null
) {
if (!env.MISTRAL_API_KEY) {
const apiKey = mistralApiKey || env.MISTRAL_API_KEY
if (!apiKey) {
throw new Error('Mistral API key required')
}
@@ -388,7 +413,7 @@ async function parseWithMistralOCR(
userId,
workspaceId
)
const params = { filePath: httpsUrl, apiKey: env.MISTRAL_API_KEY, resultType: 'text' as const }
const params = { filePath: httpsUrl, apiKey, resultType: 'text' as const }
try {
const response = await retryWithExponentialBackoff(

View File

@@ -484,7 +484,7 @@ export async function processDocumentAsync(
const batchNum = Math.floor(i / batchSize) + 1
logger.info(`[${documentId}] Processing embedding batch ${batchNum}/${totalBatches}`)
const batchEmbeddings = await generateEmbeddings(batch)
const batchEmbeddings = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
embeddings.push(...batchEmbeddings)
}
}

View File

@@ -1,3 +1,4 @@
import { getBYOKKey } from '@/lib/api-key/byok'
import { env } from '@/lib/core/config/env'
import { isRetryableError, retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils'
import { createLogger } from '@/lib/logs/console/logger'
@@ -24,40 +25,53 @@ interface EmbeddingConfig {
modelName: string
}
function getEmbeddingConfig(embeddingModel = 'text-embedding-3-small'): EmbeddingConfig {
async function getEmbeddingConfig(
embeddingModel = 'text-embedding-3-small',
workspaceId?: string | null
): Promise<EmbeddingConfig> {
const azureApiKey = env.AZURE_OPENAI_API_KEY
const azureEndpoint = env.AZURE_OPENAI_ENDPOINT
const azureApiVersion = env.AZURE_OPENAI_API_VERSION
const kbModelName = env.KB_OPENAI_MODEL_NAME || embeddingModel
const openaiApiKey = env.OPENAI_API_KEY
const useAzure = !!(azureApiKey && azureEndpoint)
if (!useAzure && !openaiApiKey) {
if (useAzure) {
return {
useAzure: true,
apiUrl: `${azureEndpoint}/openai/deployments/${kbModelName}/embeddings?api-version=${azureApiVersion}`,
headers: {
'api-key': azureApiKey!,
'Content-Type': 'application/json',
},
modelName: kbModelName,
}
}
let openaiApiKey = env.OPENAI_API_KEY
if (workspaceId) {
const byokResult = await getBYOKKey(workspaceId, 'openai')
if (byokResult) {
logger.info('Using workspace BYOK key for OpenAI embeddings')
openaiApiKey = byokResult.apiKey
}
}
if (!openaiApiKey) {
throw new Error(
'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured'
)
}
const apiUrl = useAzure
? `${azureEndpoint}/openai/deployments/${kbModelName}/embeddings?api-version=${azureApiVersion}`
: 'https://api.openai.com/v1/embeddings'
const headers: Record<string, string> = useAzure
? {
'api-key': azureApiKey!,
'Content-Type': 'application/json',
}
: {
Authorization: `Bearer ${openaiApiKey!}`,
'Content-Type': 'application/json',
}
return {
useAzure,
apiUrl,
headers,
modelName: useAzure ? kbModelName : embeddingModel,
useAzure: false,
apiUrl: 'https://api.openai.com/v1/embeddings',
headers: {
Authorization: `Bearer ${openaiApiKey}`,
'Content-Type': 'application/json',
},
modelName: embeddingModel,
}
}
@@ -112,9 +126,10 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
*/
export async function generateEmbeddings(
texts: string[],
embeddingModel = 'text-embedding-3-small'
embeddingModel = 'text-embedding-3-small',
workspaceId?: string | null
): Promise<number[][]> {
const config = getEmbeddingConfig(embeddingModel)
const config = await getEmbeddingConfig(embeddingModel, workspaceId)
logger.info(
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for embeddings generation (${texts.length} texts)`
@@ -163,9 +178,10 @@ export async function generateEmbeddings(
*/
export async function generateSearchEmbedding(
query: string,
embeddingModel = 'text-embedding-3-small'
embeddingModel = 'text-embedding-3-small',
workspaceId?: string | null
): Promise<number[]> {
const config = getEmbeddingConfig(embeddingModel)
const config = await getEmbeddingConfig(embeddingModel, workspaceId)
logger.info(
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for search embedding generation`

View File

@@ -1,3 +1,4 @@
import { getApiKeyWithBYOK } from '@/lib/api-key/byok'
import { getCostMultiplier } from '@/lib/core/config/feature-flags'
import { createLogger } from '@/lib/logs/console/logger'
import type { StreamingExecution } from '@/executor/types'
@@ -48,7 +49,32 @@ export async function executeProviderRequest(
if (!provider.executeRequest) {
throw new Error(`Provider ${providerId} does not implement executeRequest`)
}
const sanitizedRequest = sanitizeRequest(request)
let resolvedRequest = sanitizeRequest(request)
let isBYOK = false
if (request.workspaceId) {
try {
const result = await getApiKeyWithBYOK(
providerId,
request.model,
request.workspaceId,
request.apiKey
)
resolvedRequest = { ...resolvedRequest, apiKey: result.apiKey }
isBYOK = result.isBYOK
} catch (error) {
logger.error('Failed to resolve API key:', {
provider: providerId,
model: request.model,
error: error instanceof Error ? error.message : String(error),
})
throw error
}
}
resolvedRequest.isBYOK = isBYOK
const sanitizedRequest = resolvedRequest
if (sanitizedRequest.responseFormat) {
if (
@@ -88,7 +114,8 @@ export async function executeProviderRequest(
const { input: promptTokens = 0, output: completionTokens = 0 } = response.tokens
const useCachedInput = !!request.context && request.context.length > 0
if (shouldBillModelUsage(response.model)) {
const shouldBill = shouldBillModelUsage(response.model) && !isBYOK
if (shouldBill) {
const costMultiplier = getCostMultiplier()
response.cost = calculateCost(
response.model,
@@ -109,9 +136,13 @@ export async function executeProviderRequest(
updatedAt: new Date().toISOString(),
},
}
logger.debug(
`Not billing model usage for ${response.model} - user provided API key or not hosted model`
)
if (isBYOK) {
logger.debug(`Not billing model usage for ${response.model} - workspace BYOK key used`)
} else {
logger.debug(
`Not billing model usage for ${response.model} - user provided API key or not hosted model`
)
}
}
}

View File

@@ -134,12 +134,12 @@ export interface Message {
export interface ProviderRequest {
model: string
systemPrompt: string
systemPrompt?: string
context?: string
tools?: ProviderToolConfig[]
temperature?: number
maxTokens?: number
apiKey: string
apiKey?: string
messages?: Message[]
responseFormat?: {
name: string
@@ -158,6 +158,7 @@ export interface ProviderRequest {
blockData?: Record<string, any>
blockNameMapping?: Record<string, string>
isCopilotRequest?: boolean
isBYOK?: boolean
azureEndpoint?: string
azureApiVersion?: string
vertexProject?: string

View File

@@ -4,6 +4,7 @@ export const API_ENDPOINTS = {
WORKFLOWS: '/api/workflows',
WORKSPACE_PERMISSIONS: (id: string) => `/api/workspaces/${id}/permissions`,
WORKSPACE_ENVIRONMENT: (id: string) => `/api/workspaces/${id}/environment`,
WORKSPACE_BYOK_KEYS: (id: string) => `/api/workspaces/${id}/byok-keys`,
}
export const COPILOT_TOOL_DISPLAY_NAMES: Record<string, string> = {

View File

@@ -25,6 +25,7 @@ export const searchTool: ToolConfig<SearchParams, SearchResponse> = {
}),
body: (params) => ({
query: params.query,
workspaceId: params._context?.workspaceId,
}),
},

View File

@@ -2,6 +2,11 @@ import type { ToolResponse } from '@/tools/types'
export interface SearchParams {
query: string
_context?: {
workflowId?: string
workspaceId?: string
executionId?: string
}
}
export interface SearchResponse extends ToolResponse {

View File

@@ -0,0 +1,14 @@
CREATE TABLE "workspace_byok_keys" (
"id" text PRIMARY KEY NOT NULL,
"workspace_id" text NOT NULL,
"provider_id" text NOT NULL,
"encrypted_api_key" text NOT NULL,
"created_by" text,
"created_at" timestamp DEFAULT now() NOT NULL,
"updated_at" timestamp DEFAULT now() NOT NULL
);
--> statement-breakpoint
ALTER TABLE "workspace_byok_keys" ADD CONSTRAINT "workspace_byok_keys_workspace_id_workspace_id_fk" FOREIGN KEY ("workspace_id") REFERENCES "public"."workspace"("id") ON DELETE cascade ON UPDATE no action;--> statement-breakpoint
ALTER TABLE "workspace_byok_keys" ADD CONSTRAINT "workspace_byok_keys_created_by_user_id_fk" FOREIGN KEY ("created_by") REFERENCES "public"."user"("id") ON DELETE set null ON UPDATE no action;--> statement-breakpoint
CREATE UNIQUE INDEX "workspace_byok_provider_unique" ON "workspace_byok_keys" USING btree ("workspace_id","provider_id");--> statement-breakpoint
CREATE INDEX "workspace_byok_workspace_idx" ON "workspace_byok_keys" USING btree ("workspace_id");

File diff suppressed because it is too large Load Diff

View File

@@ -925,6 +925,13 @@
"when": 1766529613309,
"tag": "0132_dazzling_leech",
"breakpoints": true
},
{
"idx": 133,
"version": "7",
"when": 1766607372265,
"tag": "0133_smiling_cargill",
"breakpoints": true
}
]
}

View File

@@ -420,6 +420,28 @@ export const workspaceEnvironment = pgTable(
})
)
export const workspaceBYOKKeys = pgTable(
'workspace_byok_keys',
{
id: text('id').primaryKey(),
workspaceId: text('workspace_id')
.notNull()
.references(() => workspace.id, { onDelete: 'cascade' }),
providerId: text('provider_id').notNull(),
encryptedApiKey: text('encrypted_api_key').notNull(),
createdBy: text('created_by').references(() => user.id, { onDelete: 'set null' }),
createdAt: timestamp('created_at').notNull().defaultNow(),
updatedAt: timestamp('updated_at').notNull().defaultNow(),
},
(table) => ({
workspaceProviderUnique: uniqueIndex('workspace_byok_provider_unique').on(
table.workspaceId,
table.providerId
),
workspaceIdx: index('workspace_byok_workspace_idx').on(table.workspaceId),
})
)
export const settings = pgTable('settings', {
id: text('id').primaryKey(), // Use the user id as the key
userId: text('user_id')