mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-08 22:48:14 -05:00
feat(byok): byok for hosted model capabilities (#2574)
* feat(byok): byok for hosted model capabilities * fix type * add ignore lint * accidentally added feature flags * centralize byok fetch for LLM calls * remove feature flags ts * fix tests * update docs
This commit is contained in:
committed by
GitHub
parent
40a6bf5c8c
commit
47a259b428
@@ -104,6 +104,10 @@ The model breakdown shows:
|
||||
Pricing shown reflects rates as of September 10, 2025. Check provider documentation for current pricing.
|
||||
</Callout>
|
||||
|
||||
## Bring Your Own Key (BYOK)
|
||||
|
||||
Add your own API keys for hosted models (OpenAI, Anthropic, Google, Mistral) under **Settings → BYOK** to pay base prices. Keys are encrypted and apply to the entire workspace.
|
||||
|
||||
## Cost Optimization Strategies
|
||||
|
||||
- **Model Selection**: Choose models based on task complexity. Simple tasks can use GPT-4.1-nano while complex reasoning might need o1 or Claude Opus.
|
||||
|
||||
@@ -100,7 +100,12 @@ export async function PUT(
|
||||
try {
|
||||
const validatedData = UpdateChunkSchema.parse(body)
|
||||
|
||||
const updatedChunk = await updateChunk(chunkId, validatedData, requestId)
|
||||
const updatedChunk = await updateChunk(
|
||||
chunkId,
|
||||
validatedData,
|
||||
requestId,
|
||||
accessCheck.knowledgeBase?.workspaceId
|
||||
)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Chunk updated: ${chunkId} in document ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
|
||||
@@ -184,7 +184,8 @@ export async function POST(
|
||||
documentId,
|
||||
docTags,
|
||||
validatedData,
|
||||
requestId
|
||||
requestId,
|
||||
accessCheck.knowledgeBase?.workspaceId
|
||||
)
|
||||
|
||||
let cost = null
|
||||
|
||||
@@ -183,11 +183,11 @@ export async function POST(request: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
// Generate query embedding only if query is provided
|
||||
const workspaceId = accessChecks.find((ac) => ac?.hasAccess)?.knowledgeBase?.workspaceId
|
||||
|
||||
const hasQuery = validatedData.query && validatedData.query.trim().length > 0
|
||||
// Start embedding generation early and await when needed
|
||||
const queryEmbeddingPromise = hasQuery
|
||||
? generateSearchEmbedding(validatedData.query!)
|
||||
? generateSearchEmbedding(validatedData.query!, undefined, workspaceId)
|
||||
: Promise.resolve(null)
|
||||
|
||||
// Check if any requested knowledge bases were not accessible
|
||||
|
||||
@@ -99,7 +99,7 @@ export interface EmbeddingData {
|
||||
|
||||
export interface KnowledgeBaseAccessResult {
|
||||
hasAccess: true
|
||||
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId'>
|
||||
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId'>
|
||||
}
|
||||
|
||||
export interface KnowledgeBaseAccessDenied {
|
||||
@@ -113,7 +113,7 @@ export type KnowledgeBaseAccessCheck = KnowledgeBaseAccessResult | KnowledgeBase
|
||||
export interface DocumentAccessResult {
|
||||
hasAccess: true
|
||||
document: DocumentData
|
||||
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId'>
|
||||
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId'>
|
||||
}
|
||||
|
||||
export interface DocumentAccessDenied {
|
||||
@@ -128,7 +128,7 @@ export interface ChunkAccessResult {
|
||||
hasAccess: true
|
||||
chunk: EmbeddingData
|
||||
document: DocumentData
|
||||
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId'>
|
||||
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId'>
|
||||
}
|
||||
|
||||
export interface ChunkAccessDenied {
|
||||
|
||||
@@ -7,7 +7,6 @@ import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
import { executeProviderRequest } from '@/providers'
|
||||
import { getApiKey } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('ProvidersAPI')
|
||||
|
||||
@@ -80,23 +79,20 @@ export async function POST(request: NextRequest) {
|
||||
verbosity,
|
||||
})
|
||||
|
||||
let finalApiKey: string
|
||||
let finalApiKey: string | undefined = apiKey
|
||||
try {
|
||||
if (provider === 'vertex' && vertexCredential) {
|
||||
finalApiKey = await resolveVertexCredential(requestId, vertexCredential)
|
||||
} else {
|
||||
finalApiKey = getApiKey(provider, model, apiKey)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Failed to get API key:`, {
|
||||
logger.error(`[${requestId}] Failed to resolve Vertex credential:`, {
|
||||
provider,
|
||||
model,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
hasProvidedApiKey: !!apiKey,
|
||||
hasVertexCredential: !!vertexCredential,
|
||||
})
|
||||
return NextResponse.json(
|
||||
{ error: error instanceof Error ? error.message : 'API key error' },
|
||||
{ error: error instanceof Error ? error.message : 'Credential error' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
@@ -108,7 +104,6 @@ export async function POST(request: NextRequest) {
|
||||
hasApiKey: !!finalApiKey,
|
||||
})
|
||||
|
||||
// Execute provider request directly with the managed key
|
||||
const response = await executeProviderRequest(provider, {
|
||||
model,
|
||||
systemPrompt,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getBYOKKey } from '@/lib/api-key/byok'
|
||||
import { checkHybridAuth } from '@/lib/auth/hybrid'
|
||||
import { SEARCH_TOOL_COST } from '@/lib/billing/constants'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
@@ -10,6 +11,7 @@ const logger = createLogger('search')
|
||||
|
||||
const SearchRequestSchema = z.object({
|
||||
query: z.string().min(1),
|
||||
workspaceId: z.string().optional(),
|
||||
})
|
||||
|
||||
export const maxDuration = 60
|
||||
@@ -39,8 +41,20 @@ export async function POST(request: NextRequest) {
|
||||
const body = await request.json()
|
||||
const validated = SearchRequestSchema.parse(body)
|
||||
|
||||
if (!env.EXA_API_KEY) {
|
||||
logger.error(`[${requestId}] EXA_API_KEY not configured`)
|
||||
let exaApiKey = env.EXA_API_KEY
|
||||
let isBYOK = false
|
||||
|
||||
if (validated.workspaceId) {
|
||||
const byokResult = await getBYOKKey(validated.workspaceId, 'exa')
|
||||
if (byokResult) {
|
||||
exaApiKey = byokResult.apiKey
|
||||
isBYOK = true
|
||||
logger.info(`[${requestId}] Using workspace BYOK key for Exa search`)
|
||||
}
|
||||
}
|
||||
|
||||
if (!exaApiKey) {
|
||||
logger.error(`[${requestId}] No Exa API key available`)
|
||||
return NextResponse.json(
|
||||
{ success: false, error: 'Search service not configured' },
|
||||
{ status: 503 }
|
||||
@@ -50,6 +64,7 @@ export async function POST(request: NextRequest) {
|
||||
logger.info(`[${requestId}] Executing search`, {
|
||||
userId,
|
||||
query: validated.query,
|
||||
isBYOK,
|
||||
})
|
||||
|
||||
const result = await executeTool('exa_search', {
|
||||
@@ -57,7 +72,7 @@ export async function POST(request: NextRequest) {
|
||||
type: 'auto',
|
||||
useAutoprompt: true,
|
||||
highlights: true,
|
||||
apiKey: env.EXA_API_KEY,
|
||||
apiKey: exaApiKey,
|
||||
})
|
||||
|
||||
if (!result.success) {
|
||||
@@ -85,7 +100,7 @@ export async function POST(request: NextRequest) {
|
||||
const cost = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
total: SEARCH_TOOL_COST,
|
||||
total: isBYOK ? 0 : SEARCH_TOOL_COST,
|
||||
tokens: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
@@ -104,6 +119,7 @@ export async function POST(request: NextRequest) {
|
||||
userId,
|
||||
resultCount: results.length,
|
||||
cost: cost.total,
|
||||
isBYOK,
|
||||
})
|
||||
|
||||
return NextResponse.json({
|
||||
|
||||
@@ -3,6 +3,7 @@ import { userStats, workflow } from '@sim/db/schema'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import { getBYOKKey } from '@/lib/api-key/byok'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { logModelUsage } from '@/lib/billing/core/usage-log'
|
||||
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
|
||||
@@ -75,7 +76,8 @@ async function updateUserStatsForWand(
|
||||
completion_tokens?: number
|
||||
total_tokens?: number
|
||||
},
|
||||
requestId: string
|
||||
requestId: string,
|
||||
isBYOK = false
|
||||
): Promise<void> {
|
||||
if (!isBillingEnabled) {
|
||||
logger.debug(`[${requestId}] Billing is disabled, skipping wand usage cost update`)
|
||||
@@ -93,21 +95,24 @@ async function updateUserStatsForWand(
|
||||
const completionTokens = usage.completion_tokens || 0
|
||||
|
||||
const modelName = useWandAzure ? wandModelName : 'gpt-4o'
|
||||
const pricing = getModelPricing(modelName)
|
||||
let costToStore = 0
|
||||
|
||||
const costMultiplier = getCostMultiplier()
|
||||
let modelCost = 0
|
||||
if (!isBYOK) {
|
||||
const pricing = getModelPricing(modelName)
|
||||
const costMultiplier = getCostMultiplier()
|
||||
let modelCost = 0
|
||||
|
||||
if (pricing) {
|
||||
const inputCost = (promptTokens / 1000000) * pricing.input
|
||||
const outputCost = (completionTokens / 1000000) * pricing.output
|
||||
modelCost = inputCost + outputCost
|
||||
} else {
|
||||
modelCost = (promptTokens / 1000000) * 0.005 + (completionTokens / 1000000) * 0.015
|
||||
if (pricing) {
|
||||
const inputCost = (promptTokens / 1000000) * pricing.input
|
||||
const outputCost = (completionTokens / 1000000) * pricing.output
|
||||
modelCost = inputCost + outputCost
|
||||
} else {
|
||||
modelCost = (promptTokens / 1000000) * 0.005 + (completionTokens / 1000000) * 0.015
|
||||
}
|
||||
|
||||
costToStore = modelCost * costMultiplier
|
||||
}
|
||||
|
||||
const costToStore = modelCost * costMultiplier
|
||||
|
||||
await db
|
||||
.update(userStats)
|
||||
.set({
|
||||
@@ -122,6 +127,7 @@ async function updateUserStatsForWand(
|
||||
userId,
|
||||
tokensUsed: totalTokens,
|
||||
costAdded: costToStore,
|
||||
isBYOK,
|
||||
})
|
||||
|
||||
await logModelUsage({
|
||||
@@ -149,14 +155,6 @@ export async function POST(req: NextRequest) {
|
||||
return NextResponse.json({ success: false, error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
if (!client) {
|
||||
logger.error(`[${requestId}] AI client not initialized. Missing API key.`)
|
||||
return NextResponse.json(
|
||||
{ success: false, error: 'Wand generation service is not configured.' },
|
||||
{ status: 503 }
|
||||
)
|
||||
}
|
||||
|
||||
try {
|
||||
const body = (await req.json()) as RequestBody
|
||||
|
||||
@@ -170,6 +168,7 @@ export async function POST(req: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
let workspaceId: string | null = null
|
||||
if (workflowId) {
|
||||
const [workflowRecord] = await db
|
||||
.select({ workspaceId: workflow.workspaceId, userId: workflow.userId })
|
||||
@@ -182,6 +181,8 @@ export async function POST(req: NextRequest) {
|
||||
return NextResponse.json({ success: false, error: 'Workflow not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
workspaceId = workflowRecord.workspaceId
|
||||
|
||||
if (workflowRecord.workspaceId) {
|
||||
const permission = await verifyWorkspaceMembership(
|
||||
session.user.id,
|
||||
@@ -199,6 +200,28 @@ export async function POST(req: NextRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
let isBYOK = false
|
||||
let activeClient = client
|
||||
let byokApiKey: string | null = null
|
||||
|
||||
if (workspaceId && !useWandAzure) {
|
||||
const byokResult = await getBYOKKey(workspaceId, 'openai')
|
||||
if (byokResult) {
|
||||
isBYOK = true
|
||||
byokApiKey = byokResult.apiKey
|
||||
activeClient = new OpenAI({ apiKey: byokResult.apiKey })
|
||||
logger.info(`[${requestId}] Using BYOK OpenAI key for wand generation`)
|
||||
}
|
||||
}
|
||||
|
||||
if (!activeClient) {
|
||||
logger.error(`[${requestId}] AI client not initialized. Missing API key.`)
|
||||
return NextResponse.json(
|
||||
{ success: false, error: 'Wand generation service is not configured.' },
|
||||
{ status: 503 }
|
||||
)
|
||||
}
|
||||
|
||||
const finalSystemPrompt =
|
||||
systemPrompt ||
|
||||
'You are a helpful AI assistant. Generate content exactly as requested by the user.'
|
||||
@@ -241,7 +264,7 @@ export async function POST(req: NextRequest) {
|
||||
if (useWandAzure) {
|
||||
headers['api-key'] = azureApiKey!
|
||||
} else {
|
||||
headers.Authorization = `Bearer ${openaiApiKey}`
|
||||
headers.Authorization = `Bearer ${byokApiKey || openaiApiKey}`
|
||||
}
|
||||
|
||||
logger.debug(`[${requestId}] Making streaming request to: ${apiUrl}`)
|
||||
@@ -310,7 +333,7 @@ export async function POST(req: NextRequest) {
|
||||
logger.info(`[${requestId}] Received [DONE] signal`)
|
||||
|
||||
if (finalUsage) {
|
||||
await updateUserStatsForWand(session.user.id, finalUsage, requestId)
|
||||
await updateUserStatsForWand(session.user.id, finalUsage, requestId, isBYOK)
|
||||
}
|
||||
|
||||
controller.enqueue(
|
||||
@@ -395,7 +418,7 @@ export async function POST(req: NextRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
const completion = await client.chat.completions.create({
|
||||
const completion = await activeClient.chat.completions.create({
|
||||
model: useWandAzure ? wandModelName : 'gpt-4o',
|
||||
messages: messages,
|
||||
temperature: 0.3,
|
||||
@@ -417,7 +440,7 @@ export async function POST(req: NextRequest) {
|
||||
logger.info(`[${requestId}] Wand generation successful`)
|
||||
|
||||
if (completion.usage) {
|
||||
await updateUserStatsForWand(session.user.id, completion.usage, requestId)
|
||||
await updateUserStatsForWand(session.user.id, completion.usage, requestId, isBYOK)
|
||||
}
|
||||
|
||||
return NextResponse.json({ success: true, content: generatedContent })
|
||||
|
||||
256
apps/sim/app/api/workspaces/[id]/byok-keys/route.ts
Normal file
256
apps/sim/app/api/workspaces/[id]/byok-keys/route.ts
Normal file
@@ -0,0 +1,256 @@
|
||||
import { db } from '@sim/db'
|
||||
import { workspace, workspaceBYOKKeys } from '@sim/db/schema'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { nanoid } from 'nanoid'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils'
|
||||
|
||||
const logger = createLogger('WorkspaceBYOKKeysAPI')

/** Providers for which a workspace may store its own API key. */
const VALID_PROVIDERS = ['openai', 'anthropic', 'google', 'mistral', 'exa'] as const

/**
 * Request body for creating or replacing a BYOK key.
 *
 * The key is trimmed before the length check, so a whitespace-only value is
 * rejected with "API key is required" and surrounding whitespace (e.g. from a
 * copy/paste) is never encrypted and stored as part of the key.
 */
const UpsertKeySchema = z.object({
  providerId: z.enum(VALID_PROVIDERS),
  apiKey: z.string().trim().min(1, 'API key is required'),
})

/** Request body for deleting the BYOK key of a single provider. */
const DeleteKeySchema = z.object({
  providerId: z.enum(VALID_PROVIDERS),
})
|
||||
|
||||
/**
 * Produces a display-safe version of an API key.
 *
 * - Keys of 8 characters or fewer are fully masked with eight bullets.
 * - Longer keys show a short prefix and suffix joined by `...`, but never more
 *   than half of the key's characters in total. The prefix is capped at 6 and
 *   the suffix at 4 characters, so keys of 20+ characters render as
 *   `first6...last4`.
 *
 * The half-length cap exists because a fixed prefix/suffix reveals almost the
 * entire secret when the key is short (e.g. 8 of 9 characters).
 *
 * @param key - The plaintext API key.
 * @returns The masked representation for display.
 */
function maskApiKey(key: string): string {
  const MAX_PREFIX = 6
  const MAX_SUFFIX = 4

  if (key.length <= 8) {
    return '•'.repeat(8)
  }

  // Total number of characters that may be shown in clear text.
  const revealBudget = Math.floor(key.length / 2)
  // key.length >= 9 guarantees revealBudget >= 4, so suffixLength >= 2 and
  // slice(-suffixLength) never degenerates into slice(-0).
  const suffixLength = Math.min(MAX_SUFFIX, Math.floor(revealBudget / 2))
  const prefixLength = Math.min(MAX_PREFIX, revealBudget - suffixLength)

  return `${key.slice(0, prefixLength)}...${key.slice(-suffixLength)}`
}
|
||||
|
||||
/**
 * Lists the BYOK keys stored for a workspace, with every key masked for display.
 *
 * Responds with 401 when there is no session or the user has no permission on
 * the workspace, 404 when no workspace row matches the id, and 500 on any
 * unexpected error. A key that fails to decrypt is still listed, using a fully
 * masked placeholder instead of a partial preview.
 */
export async function GET(request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
  const requestId = generateRequestId()
  const { id: workspaceId } = await params

  try {
    const session = await getSession()
    const userId = session?.user?.id
    if (!userId) {
      logger.warn(`[${requestId}] Unauthorized BYOK keys access attempt`)
      return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
    }

    const workspaceRows = await db
      .select()
      .from(workspace)
      .where(eq(workspace.id, workspaceId))
      .limit(1)
    if (workspaceRows.length === 0) {
      return NextResponse.json({ error: 'Workspace not found' }, { status: 404 })
    }

    const permission = await getUserEntityPermissions(userId, 'workspace', workspaceId)
    if (!permission) {
      return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
    }

    const storedKeys = await db
      .select({
        id: workspaceBYOKKeys.id,
        providerId: workspaceBYOKKeys.providerId,
        encryptedApiKey: workspaceBYOKKeys.encryptedApiKey,
        createdBy: workspaceBYOKKeys.createdBy,
        createdAt: workspaceBYOKKeys.createdAt,
        updatedAt: workspaceBYOKKeys.updatedAt,
      })
      .from(workspaceBYOKKeys)
      .where(eq(workspaceBYOKKeys.workspaceId, workspaceId))
      .orderBy(workspaceBYOKKeys.providerId)

    const formattedKeys = await Promise.all(
      storedKeys.map(async (row) => {
        // Decrypt only to derive the masked preview; the plaintext never leaves this scope.
        let maskedKey: string
        try {
          const { decrypted } = await decryptSecret(row.encryptedApiKey)
          maskedKey = maskApiKey(decrypted)
        } catch (error) {
          logger.error(
            `[${requestId}] Failed to decrypt BYOK key for provider ${row.providerId}`,
            { error }
          )
          maskedKey = '••••••••'
        }

        return {
          id: row.id,
          providerId: row.providerId,
          maskedKey,
          createdBy: row.createdBy,
          createdAt: row.createdAt,
          updatedAt: row.updatedAt,
        }
      })
    )

    return NextResponse.json({ keys: formattedKeys })
  } catch (error: unknown) {
    logger.error(`[${requestId}] BYOK keys GET error`, error)
    const message = error instanceof Error ? error.message : 'Failed to load BYOK keys'
    return NextResponse.json({ error: message }, { status: 500 })
  }
}
|
||||
|
||||
/**
 * Creates or replaces the workspace's BYOK key for a single provider.
 *
 * Only callers whose workspace permission is `admin` may use this endpoint.
 * The key is encrypted before it is written; the response contains only a
 * masked preview.
 *
 * Responses:
 * - 401 when there is no session
 * - 403 when the caller is not a workspace admin
 * - 400 when the body is not valid JSON or fails schema validation
 * - 500 on any other error
 */
export async function POST(request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
  const requestId = generateRequestId()
  const workspaceId = (await params).id

  try {
    const session = await getSession()
    if (!session?.user?.id) {
      logger.warn(`[${requestId}] Unauthorized BYOK key creation attempt`)
      return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
    }

    const userId = session.user.id

    const permission = await getUserEntityPermissions(userId, 'workspace', workspaceId)
    if (permission !== 'admin') {
      return NextResponse.json(
        { error: 'Only workspace admins can manage BYOK keys' },
        { status: 403 }
      )
    }

    // A body that is not valid JSON is a client error; handle it here so it is
    // not reported as a 500 by the generic catch below.
    let body: unknown
    try {
      body = await request.json()
    } catch {
      logger.warn(`[${requestId}] BYOK key POST received a malformed JSON body`)
      return NextResponse.json({ error: 'Invalid JSON body' }, { status: 400 })
    }

    const { providerId, apiKey } = UpsertKeySchema.parse(body)

    const { encrypted } = await encryptSecret(apiKey)

    // Single timestamp so the stored row and the response agree exactly.
    const now = new Date()

    // Only the id is needed to decide between update and insert.
    const [existingKey] = await db
      .select({ id: workspaceBYOKKeys.id })
      .from(workspaceBYOKKeys)
      .where(
        and(
          eq(workspaceBYOKKeys.workspaceId, workspaceId),
          eq(workspaceBYOKKeys.providerId, providerId)
        )
      )
      .limit(1)

    if (existingKey) {
      await db
        .update(workspaceBYOKKeys)
        .set({
          encryptedApiKey: encrypted,
          updatedAt: now,
        })
        .where(eq(workspaceBYOKKeys.id, existingKey.id))

      logger.info(`[${requestId}] Updated BYOK key for ${providerId} in workspace ${workspaceId}`)

      return NextResponse.json({
        success: true,
        key: {
          id: existingKey.id,
          providerId,
          maskedKey: maskApiKey(apiKey),
          updatedAt: now,
        },
      })
    }

    const [newKey] = await db
      .insert(workspaceBYOKKeys)
      .values({
        id: nanoid(),
        workspaceId,
        providerId,
        encryptedApiKey: encrypted,
        createdBy: userId,
        createdAt: now,
        updatedAt: now,
      })
      .returning({
        id: workspaceBYOKKeys.id,
        providerId: workspaceBYOKKeys.providerId,
        createdAt: workspaceBYOKKeys.createdAt,
      })

    logger.info(`[${requestId}] Created BYOK key for ${providerId} in workspace ${workspaceId}`)

    return NextResponse.json({
      success: true,
      key: {
        ...newKey,
        maskedKey: maskApiKey(apiKey),
      },
    })
  } catch (error: unknown) {
    logger.error(`[${requestId}] BYOK key POST error`, error)
    if (error instanceof z.ZodError) {
      return NextResponse.json({ error: error.errors[0].message }, { status: 400 })
    }
    return NextResponse.json(
      { error: error instanceof Error ? error.message : 'Failed to save BYOK key' },
      { status: 500 }
    )
  }
}
|
||||
|
||||
/**
 * Removes the workspace's BYOK key for a single provider.
 *
 * Only callers whose workspace permission is `admin` may use this endpoint.
 * The request succeeds whether or not a key was stored for the provider.
 *
 * Responses:
 * - 401 when there is no session
 * - 403 when the caller is not a workspace admin
 * - 400 when the body is not valid JSON or fails schema validation
 * - 500 on any other error
 */
export async function DELETE(
  request: NextRequest,
  { params }: { params: Promise<{ id: string }> }
) {
  const requestId = generateRequestId()
  const workspaceId = (await params).id

  try {
    const session = await getSession()
    if (!session?.user?.id) {
      logger.warn(`[${requestId}] Unauthorized BYOK key deletion attempt`)
      return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
    }

    const userId = session.user.id

    const permission = await getUserEntityPermissions(userId, 'workspace', workspaceId)
    if (permission !== 'admin') {
      return NextResponse.json(
        { error: 'Only workspace admins can manage BYOK keys' },
        { status: 403 }
      )
    }

    // A body that is not valid JSON is a client error; handle it here so it is
    // not reported as a 500 by the generic catch below.
    let body: unknown
    try {
      body = await request.json()
    } catch {
      logger.warn(`[${requestId}] BYOK key DELETE received a malformed JSON body`)
      return NextResponse.json({ error: 'Invalid JSON body' }, { status: 400 })
    }

    const { providerId } = DeleteKeySchema.parse(body)

    await db
      .delete(workspaceBYOKKeys)
      .where(
        and(
          eq(workspaceBYOKKeys.workspaceId, workspaceId),
          eq(workspaceBYOKKeys.providerId, providerId)
        )
      )

    logger.info(`[${requestId}] Deleted BYOK key for ${providerId} from workspace ${workspaceId}`)

    return NextResponse.json({ success: true })
  } catch (error: unknown) {
    logger.error(`[${requestId}] BYOK key DELETE error`, error)
    if (error instanceof z.ZodError) {
      return NextResponse.json({ error: error.errors[0].message }, { status: 400 })
    }
    return NextResponse.json(
      { error: error instanceof Error ? error.message : 'Failed to delete BYOK key' },
      { status: 500 }
    )
  }
}
|
||||
@@ -0,0 +1,316 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { Eye, EyeOff } from 'lucide-react'
|
||||
import { useParams } from 'next/navigation'
|
||||
import {
|
||||
Button,
|
||||
Input as EmcnInput,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
Trash,
|
||||
} from '@/components/emcn'
|
||||
import { AnthropicIcon, ExaAIIcon, GeminiIcon, MistralIcon, OpenAIIcon } from '@/components/icons'
|
||||
import { Skeleton } from '@/components/ui'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import {
|
||||
type BYOKKey,
|
||||
type BYOKProviderId,
|
||||
useBYOKKeys,
|
||||
useDeleteBYOKKey,
|
||||
useUpsertBYOKKey,
|
||||
} from '@/hooks/queries/byok-keys'
|
||||
|
||||
const logger = createLogger('BYOKSettings')

/** Display metadata for one provider row in the BYOK settings panel. */
interface BYOKProviderConfig {
  id: BYOKProviderId
  name: string
  icon: React.ComponentType<{ className?: string }>
  description: string
  placeholder: string
}

/** Input placeholder for providers whose rows show no key prefix hint. */
const GENERIC_KEY_PLACEHOLDER = 'Enter your API key'

/** Providers offered in the panel, in the order their rows are rendered. */
const PROVIDERS: BYOKProviderConfig[] = [
  {
    id: 'openai',
    name: 'OpenAI',
    icon: OpenAIIcon,
    description: 'LLM calls and Knowledge Base embeddings',
    placeholder: 'sk-...',
  },
  {
    id: 'anthropic',
    name: 'Anthropic',
    icon: AnthropicIcon,
    description: 'LLM calls',
    placeholder: 'sk-ant-...',
  },
  {
    id: 'google',
    name: 'Google',
    icon: GeminiIcon,
    description: 'LLM calls',
    placeholder: GENERIC_KEY_PLACEHOLDER,
  },
  {
    id: 'mistral',
    name: 'Mistral',
    icon: MistralIcon,
    description: 'LLM calls and Knowledge Base OCR',
    placeholder: GENERIC_KEY_PLACEHOLDER,
  },
  {
    id: 'exa',
    name: 'Exa',
    icon: ExaAIIcon,
    description: 'Web Search block',
    placeholder: GENERIC_KEY_PLACEHOLDER,
  },
]
|
||||
|
||||
/**
 * Loading placeholder with the same outline as a provider row: an icon tile,
 * two text lines, and an action button.
 */
function BYOKKeySkeleton() {
  const textLines = (
    <div className='flex flex-col gap-[4px]'>
      <Skeleton className='h-[16px] w-[80px]' />
      <Skeleton className='h-[14px] w-[160px]' />
    </div>
  )

  const leadingSection = (
    <div className='flex items-center gap-[12px]'>
      <Skeleton className='h-[32px] w-[32px] rounded-[6px]' />
      {textLines}
    </div>
  )

  return (
    <div className='flex items-center justify-between gap-[12px] rounded-[8px] border p-[12px]'>
      {leadingSection}
      <Skeleton className='h-[32px] w-[80px] rounded-[6px]' />
    </div>
  )
}
|
||||
|
||||
/**
 * Settings panel for a workspace's bring-your-own API keys.
 *
 * Renders one row per entry in `PROVIDERS` with the masked key (when one is
 * stored), an add/update modal, and a delete confirmation modal. The workspace
 * id is read from the route params.
 */
export function BYOK() {
  const routeParams = useParams()
  const workspaceId = (routeParams?.workspaceId as string) || ''

  const { data: keys = [], isLoading } = useBYOKKeys(workspaceId)
  const upsertKey = useUpsertBYOKKey()
  const deleteKey = useDeleteBYOKKey()

  const [editingProvider, setEditingProvider] = useState<BYOKProviderId | null>(null)
  const [apiKeyInput, setApiKeyInput] = useState('')
  const [showApiKey, setShowApiKey] = useState(false)
  const [error, setError] = useState<string | null>(null)

  const [deleteConfirmProvider, setDeleteConfirmProvider] = useState<BYOKProviderId | null>(null)

  const getKeyForProvider = (providerId: BYOKProviderId): BYOKKey | undefined =>
    keys.find((k) => k.providerId === providerId)

  // Provider metadata for whichever modal is currently open (undefined when closed).
  const editingProviderInfo = PROVIDERS.find((p) => p.id === editingProvider)
  const deletingProviderInfo = PROVIDERS.find((p) => p.id === deleteConfirmProvider)

  /** Clears the edit form and points the edit modal at `nextProvider` (null closes it). */
  const resetEditForm = (nextProvider: BYOKProviderId | null) => {
    setEditingProvider(nextProvider)
    setApiKeyInput('')
    setShowApiKey(false)
    setError(null)
  }

  const openEditModal = (providerId: BYOKProviderId) => {
    resetEditForm(providerId)
  }

  const closeEditModal = () => {
    resetEditForm(null)
  }

  const handleSave = async () => {
    const trimmedKey = apiKeyInput.trim()
    if (!editingProvider || !trimmedKey) return

    setError(null)
    try {
      await upsertKey.mutateAsync({
        workspaceId,
        providerId: editingProvider,
        apiKey: trimmedKey,
      })
      setEditingProvider(null)
      setApiKeyInput('')
      setShowApiKey(false)
    } catch (err) {
      const message = err instanceof Error ? err.message : 'Failed to save API key'
      setError(message)
      logger.error('Failed to save BYOK key', { error: err })
    }
  }

  const handleDelete = async () => {
    if (!deleteConfirmProvider) return

    try {
      await deleteKey.mutateAsync({
        workspaceId,
        providerId: deleteConfirmProvider,
      })
      setDeleteConfirmProvider(null)
    } catch (err) {
      logger.error('Failed to delete BYOK key', { error: err })
    }
  }

  /** Renders the row for one provider: icon, labels, masked key, and actions. */
  const renderProviderRow = (provider: (typeof PROVIDERS)[number]) => {
    const existingKey = getKeyForProvider(provider.id)
    const Icon = provider.icon

    return (
      <div
        key={provider.id}
        className='flex items-center justify-between gap-[12px] rounded-[8px] border p-[12px]'
      >
        <div className='flex items-center gap-[12px]'>
          <div className='flex h-[32px] w-[32px] items-center justify-center rounded-[6px] bg-[var(--surface-3)]'>
            <Icon className='h-[18px] w-[18px]' />
          </div>
          <div className='flex flex-col gap-[2px]'>
            <span className='font-medium text-[14px]'>{provider.name}</span>
            <span className='text-[12px] text-[var(--text-tertiary)]'>
              {provider.description}
            </span>
            {existingKey && (
              <span className='font-mono text-[11px] text-[var(--text-muted)]'>
                {existingKey.maskedKey}
              </span>
            )}
          </div>
        </div>

        <div className='flex items-center gap-[6px]'>
          {existingKey && (
            <Button
              variant='ghost'
              className='h-9 w-9'
              onClick={() => setDeleteConfirmProvider(provider.id)}
            >
              <Trash />
            </Button>
          )}
          <Button variant='default' onClick={() => openEditModal(provider.id)}>
            {existingKey ? 'Update' : 'Add Key'}
          </Button>
        </div>
      </div>
    )
  }

  const VisibilityIcon = showApiKey ? EyeOff : Eye
  const isSaveDisabled = !apiKeyInput.trim() || upsertKey.isPending

  return (
    <>
      <div className='flex h-full flex-col gap-[16px]'>
        <p className='text-[13px] text-[var(--text-secondary)]'>
          Use your own API keys for hosted model providers.
        </p>

        <div className='min-h-0 flex-1 overflow-y-auto'>
          <div className='flex flex-col gap-[8px]'>
            {isLoading
              ? PROVIDERS.map((p) => <BYOKKeySkeleton key={p.id} />)
              : PROVIDERS.map(renderProviderRow)}
          </div>
        </div>
      </div>

      <Modal
        open={!!editingProvider}
        onOpenChange={(open) => {
          if (!open) {
            closeEditModal()
          }
        }}
      >
        <ModalContent className='w-[420px]'>
          <ModalHeader>
            {editingProvider && (
              <>
                {getKeyForProvider(editingProvider) ? 'Update' : 'Add'}{' '}
                {editingProviderInfo?.name} API Key
              </>
            )}
          </ModalHeader>
          <ModalBody>
            <p className='text-[12px] text-[var(--text-tertiary)]'>
              This key will be used for all {editingProviderInfo?.name}{' '}
              requests in this workspace. Your key is encrypted and stored securely.
            </p>

            <div className='mt-[12px] flex flex-col gap-[8px]'>
              <div className='relative'>
                <EmcnInput
                  type={showApiKey ? 'text' : 'password'}
                  value={apiKeyInput}
                  onChange={(e) => {
                    setApiKeyInput(e.target.value)
                    if (error) setError(null)
                  }}
                  placeholder={editingProviderInfo?.placeholder}
                  className='h-9 pr-[36px]'
                  autoFocus
                />
                <Button
                  variant='ghost'
                  className='-translate-y-1/2 absolute top-1/2 right-[4px] h-[28px] w-[28px] p-0'
                  onClick={() => setShowApiKey(!showApiKey)}
                >
                  <VisibilityIcon className='h-[14px] w-[14px]' />
                </Button>
              </div>
              {error && (
                <p className='text-[11px] text-[var(--text-error)] leading-tight'>{error}</p>
              )}
            </div>
          </ModalBody>

          <ModalFooter>
            <Button variant='default' onClick={closeEditModal}>
              Cancel
            </Button>
            <Button variant='primary' onClick={handleSave} disabled={isSaveDisabled}>
              {upsertKey.isPending ? 'Saving...' : 'Save'}
            </Button>
          </ModalFooter>
        </ModalContent>
      </Modal>

      <Modal open={!!deleteConfirmProvider} onOpenChange={() => setDeleteConfirmProvider(null)}>
        <ModalContent className='w-[400px]'>
          <ModalHeader>Delete API Key</ModalHeader>
          <ModalBody>
            <p className='text-[12px] text-[var(--text-tertiary)]'>
              Are you sure you want to delete the{' '}
              <span className='font-medium text-[var(--text-primary)]'>
                {deletingProviderInfo?.name}
              </span>{' '}
              API key? This workspace will revert to using platform keys with the 2x multiplier.
            </p>
          </ModalBody>
          <ModalFooter>
            <Button variant='default' onClick={() => setDeleteConfirmProvider(null)}>
              Cancel
            </Button>
            <Button variant='primary' onClick={handleDelete} disabled={deleteKey.isPending}>
              {deleteKey.isPending ? 'Deleting...' : 'Delete'}
            </Button>
          </ModalFooter>
        </ModalContent>
      </Modal>
    </>
  )
}
|
||||
@@ -1,4 +1,5 @@
|
||||
export { ApiKeys } from './api-keys/api-keys'
|
||||
export { BYOK } from './byok/byok'
|
||||
export { Copilot } from './copilot/copilot'
|
||||
export { CustomTools } from './custom-tools/custom-tools'
|
||||
export { EnvironmentVariables } from './environment/environment'
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import * as DialogPrimitive from '@radix-ui/react-dialog'
|
||||
import * as VisuallyHidden from '@radix-ui/react-visually-hidden'
|
||||
import { useQueryClient } from '@tanstack/react-query'
|
||||
import { Files, LogIn, Settings, User, Users, Wrench } from 'lucide-react'
|
||||
import { Files, KeySquare, LogIn, Settings, User, Users, Wrench } from 'lucide-react'
|
||||
import {
|
||||
Card,
|
||||
Connections,
|
||||
@@ -30,6 +30,7 @@ import { isHosted } from '@/lib/core/config/feature-flags'
|
||||
import { getUserRole } from '@/lib/workspaces/organization'
|
||||
import {
|
||||
ApiKeys,
|
||||
BYOK,
|
||||
Copilot,
|
||||
CustomTools,
|
||||
EnvironmentVariables,
|
||||
@@ -62,6 +63,7 @@ type SettingsSection =
|
||||
| 'template-profile'
|
||||
| 'integrations'
|
||||
| 'apikeys'
|
||||
| 'byok'
|
||||
| 'files'
|
||||
| 'subscription'
|
||||
| 'team'
|
||||
@@ -114,6 +116,13 @@ const allNavigationItems: NavigationItem[] = [
|
||||
{ id: 'mcp', label: 'MCPs', icon: McpIcon, section: 'tools' },
|
||||
{ id: 'environment', label: 'Environment', icon: FolderCode, section: 'system' },
|
||||
{ id: 'apikeys', label: 'API Keys', icon: Key, section: 'system' },
|
||||
{
|
||||
id: 'byok',
|
||||
label: 'BYOK',
|
||||
icon: KeySquare,
|
||||
section: 'system',
|
||||
requiresHosted: true,
|
||||
},
|
||||
{
|
||||
id: 'copilot',
|
||||
label: 'Copilot Keys',
|
||||
@@ -456,6 +465,7 @@ export function SettingsModal({ open, onOpenChange }: SettingsModalProps) {
|
||||
{isBillingEnabled && activeSection === 'subscription' && <Subscription />}
|
||||
{isBillingEnabled && activeSection === 'team' && <TeamManagement />}
|
||||
{activeSection === 'sso' && <SSO />}
|
||||
{activeSection === 'byok' && <BYOK />}
|
||||
{activeSection === 'copilot' && <Copilot />}
|
||||
{activeSection === 'mcp' && <MCP initialServerId={pendingMcpServerId} />}
|
||||
{activeSection === 'custom-tools' && <CustomTools />}
|
||||
|
||||
@@ -26,7 +26,7 @@ import { collectBlockData } from '@/executor/utils/block-data'
|
||||
import { buildAPIUrl, buildAuthHeaders, extractAPIErrorMessage } from '@/executor/utils/http'
|
||||
import { stringifyJSON } from '@/executor/utils/json'
|
||||
import { executeProviderRequest } from '@/providers'
|
||||
import { getApiKey, getProviderFromModel, transformBlockTool } from '@/providers/utils'
|
||||
import { getProviderFromModel, transformBlockTool } from '@/providers/utils'
|
||||
import type { SerializedBlock } from '@/serializer/types'
|
||||
import { executeTool } from '@/tools'
|
||||
import { getTool, getToolAsync } from '@/tools/utils'
|
||||
@@ -1006,15 +1006,13 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
responseFormat: any,
|
||||
providerStartTime: number
|
||||
) {
|
||||
let finalApiKey: string
|
||||
let finalApiKey: string | undefined = providerRequest.apiKey
|
||||
|
||||
if (providerId === 'vertex' && providerRequest.vertexCredential) {
|
||||
finalApiKey = await this.resolveVertexCredential(
|
||||
providerRequest.vertexCredential,
|
||||
ctx.workflowId
|
||||
)
|
||||
} else {
|
||||
finalApiKey = this.getApiKey(providerId, model, providerRequest.apiKey)
|
||||
}
|
||||
|
||||
const { blockData, blockNameMapping } = collectBlockData(ctx)
|
||||
@@ -1033,7 +1031,7 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
vertexLocation: providerRequest.vertexLocation,
|
||||
responseFormat: providerRequest.responseFormat,
|
||||
workflowId: providerRequest.workflowId,
|
||||
workspaceId: providerRequest.workspaceId,
|
||||
workspaceId: ctx.workspaceId,
|
||||
stream: providerRequest.stream,
|
||||
messages: 'messages' in providerRequest ? providerRequest.messages : undefined,
|
||||
environmentVariables: ctx.environmentVariables || {},
|
||||
@@ -1111,20 +1109,6 @@ export class AgentBlockHandler implements BlockHandler {
|
||||
return this.createMinimalStreamingExecution(response.body!)
|
||||
}
|
||||
|
||||
private getApiKey(providerId: string, model: string, inputApiKey: string): string {
|
||||
try {
|
||||
return getApiKey(providerId, model, inputApiKey)
|
||||
} catch (error) {
|
||||
logger.error('Failed to get API key:', {
|
||||
provider: providerId,
|
||||
model,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
hasProvidedApiKey: !!inputApiKey,
|
||||
})
|
||||
throw new Error(error instanceof Error ? error.message : 'API key error')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a Vertex AI OAuth credential to an access token
|
||||
*/
|
||||
|
||||
@@ -388,21 +388,6 @@ describe('EvaluatorBlockHandler', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw error when API key is missing for non-hosted models', async () => {
|
||||
const inputs = {
|
||||
content: 'Test content',
|
||||
metrics: [{ name: 'score', description: 'Score', range: { min: 0, max: 10 } }],
|
||||
model: 'gpt-4o',
|
||||
// No apiKey provided
|
||||
}
|
||||
|
||||
mockGetProviderFromModel.mockReturnValue('openai')
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
/API key is required/
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle Vertex AI models with OAuth credential', async () => {
|
||||
const inputs = {
|
||||
content: 'Test content to evaluate',
|
||||
|
||||
@@ -8,7 +8,7 @@ import { BlockType, DEFAULTS, EVALUATOR, HTTP } from '@/executor/constants'
|
||||
import type { BlockHandler, ExecutionContext } from '@/executor/types'
|
||||
import { buildAPIUrl, extractAPIErrorMessage } from '@/executor/utils/http'
|
||||
import { isJSONString, parseJSON, stringifyJSON } from '@/executor/utils/json'
|
||||
import { calculateCost, getApiKey, getProviderFromModel } from '@/providers/utils'
|
||||
import { calculateCost, getProviderFromModel } from '@/providers/utils'
|
||||
import type { SerializedBlock } from '@/serializer/types'
|
||||
|
||||
const logger = createLogger('EvaluatorBlockHandler')
|
||||
@@ -35,11 +35,9 @@ export class EvaluatorBlockHandler implements BlockHandler {
|
||||
}
|
||||
const providerId = getProviderFromModel(evaluatorConfig.model)
|
||||
|
||||
let finalApiKey: string
|
||||
let finalApiKey: string | undefined = evaluatorConfig.apiKey
|
||||
if (providerId === 'vertex' && evaluatorConfig.vertexCredential) {
|
||||
finalApiKey = await this.resolveVertexCredential(evaluatorConfig.vertexCredential)
|
||||
} else {
|
||||
finalApiKey = this.getApiKey(providerId, evaluatorConfig.model, evaluatorConfig.apiKey)
|
||||
}
|
||||
|
||||
const processedContent = this.processContent(inputs.content)
|
||||
@@ -117,6 +115,7 @@ export class EvaluatorBlockHandler implements BlockHandler {
|
||||
temperature: EVALUATOR.DEFAULT_TEMPERATURE,
|
||||
apiKey: finalApiKey,
|
||||
workflowId: ctx.workflowId,
|
||||
workspaceId: ctx.workspaceId,
|
||||
}
|
||||
|
||||
if (providerId === 'vertex') {
|
||||
@@ -275,20 +274,6 @@ export class EvaluatorBlockHandler implements BlockHandler {
|
||||
return DEFAULTS.EXECUTION_TIME
|
||||
}
|
||||
|
||||
private getApiKey(providerId: string, model: string, inputApiKey: string): string {
|
||||
try {
|
||||
return getApiKey(providerId, model, inputApiKey)
|
||||
} catch (error) {
|
||||
logger.error('Failed to get API key:', {
|
||||
provider: providerId,
|
||||
model,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
hasProvidedApiKey: !!inputApiKey,
|
||||
})
|
||||
throw new Error(error instanceof Error ? error.message : 'API key error')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a Vertex AI OAuth credential to an access token
|
||||
*/
|
||||
|
||||
@@ -265,20 +265,6 @@ describe('RouterBlockHandler', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should throw error when API key is missing for non-hosted models', async () => {
|
||||
const inputs = {
|
||||
prompt: 'Test without API key',
|
||||
model: 'gpt-4o',
|
||||
// No apiKey provided
|
||||
}
|
||||
|
||||
mockGetProviderFromModel.mockReturnValue('openai')
|
||||
|
||||
await expect(handler.execute(mockContext, mockBlock, inputs)).rejects.toThrow(
|
||||
/API key is required/
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle Vertex AI models with OAuth credential', async () => {
|
||||
const inputs = {
|
||||
prompt: 'Choose the best option.',
|
||||
|
||||
@@ -8,7 +8,7 @@ import { generateRouterPrompt } from '@/blocks/blocks/router'
|
||||
import type { BlockOutput } from '@/blocks/types'
|
||||
import { BlockType, DEFAULTS, HTTP, isAgentBlockType, ROUTER } from '@/executor/constants'
|
||||
import type { BlockHandler, ExecutionContext } from '@/executor/types'
|
||||
import { calculateCost, getApiKey, getProviderFromModel } from '@/providers/utils'
|
||||
import { calculateCost, getProviderFromModel } from '@/providers/utils'
|
||||
import type { SerializedBlock } from '@/serializer/types'
|
||||
|
||||
const logger = createLogger('RouterBlockHandler')
|
||||
@@ -47,11 +47,9 @@ export class RouterBlockHandler implements BlockHandler {
|
||||
const messages = [{ role: 'user', content: routerConfig.prompt }]
|
||||
const systemPrompt = generateRouterPrompt(routerConfig.prompt, targetBlocks)
|
||||
|
||||
let finalApiKey: string
|
||||
let finalApiKey: string | undefined = routerConfig.apiKey
|
||||
if (providerId === 'vertex' && routerConfig.vertexCredential) {
|
||||
finalApiKey = await this.resolveVertexCredential(routerConfig.vertexCredential)
|
||||
} else {
|
||||
finalApiKey = this.getApiKey(providerId, routerConfig.model, routerConfig.apiKey)
|
||||
}
|
||||
|
||||
const providerRequest: Record<string, any> = {
|
||||
@@ -62,6 +60,7 @@ export class RouterBlockHandler implements BlockHandler {
|
||||
temperature: ROUTER.INFERENCE_TEMPERATURE,
|
||||
apiKey: finalApiKey,
|
||||
workflowId: ctx.workflowId,
|
||||
workspaceId: ctx.workspaceId,
|
||||
}
|
||||
|
||||
if (providerId === 'vertex') {
|
||||
@@ -178,20 +177,6 @@ export class RouterBlockHandler implements BlockHandler {
|
||||
})
|
||||
}
|
||||
|
||||
private getApiKey(providerId: string, model: string, inputApiKey: string): string {
|
||||
try {
|
||||
return getApiKey(providerId, model, inputApiKey)
|
||||
} catch (error) {
|
||||
logger.error('Failed to get API key:', {
|
||||
provider: providerId,
|
||||
model,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
hasProvidedApiKey: !!inputApiKey,
|
||||
})
|
||||
throw new Error(error instanceof Error ? error.message : 'API key error')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a Vertex AI OAuth credential to an access token
|
||||
*/
|
||||
|
||||
@@ -19,7 +19,7 @@ const sleep = async (ms: number, options: SleepOptions = {}): Promise<boolean> =
|
||||
}
|
||||
|
||||
return new Promise((resolve) => {
|
||||
// biome-ignore lint/style/useConst: Variable is assigned after closure definitions that reference it
|
||||
// biome-ignore lint/style/useConst: needs to be declared before cleanup() but assigned later
|
||||
let mainTimeoutId: NodeJS.Timeout | undefined
|
||||
let checkIntervalId: NodeJS.Timeout | undefined
|
||||
let resolved = false
|
||||
|
||||
105
apps/sim/hooks/queries/byok-keys.ts
Normal file
105
apps/sim/hooks/queries/byok-keys.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import { keepPreviousData, useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { API_ENDPOINTS } from '@/stores/constants'
|
||||
|
||||
const logger = createLogger('BYOKKeysQueries')
|
||||
|
||||
export type BYOKProviderId = 'openai' | 'anthropic' | 'google' | 'mistral' | 'exa'
|
||||
|
||||
export interface BYOKKey {
|
||||
id: string
|
||||
providerId: BYOKProviderId
|
||||
maskedKey: string
|
||||
createdBy: string | null
|
||||
createdAt: string
|
||||
updatedAt: string
|
||||
}
|
||||
|
||||
/**
 * Cache-key factory for BYOK key queries.
 *
 * - `all` is the root key shared by every BYOK query.
 * - `workspace(id)` extends the root key with `'workspace'` and the workspace id.
 */
export const byokKeysKeys = {
  all: ['byok-keys'] as const,
  workspace(workspaceId: string) {
    return [...byokKeysKeys.all, 'workspace', workspaceId] as const
  },
}
|
||||
|
||||
/**
 * Loads the BYOK keys stored for a workspace.
 *
 * @param workspaceId - Workspace whose keys are requested.
 * @returns The `keys` array from the response body, or an empty array when the
 *   body does not carry an array under `keys`.
 * @throws Error when the endpoint responds with a non-OK status.
 */
async function fetchBYOKKeys(workspaceId: string): Promise<BYOKKey[]> {
  const response = await fetch(API_ENDPOINTS.WORKSPACE_BYOK_KEYS(workspaceId))
  if (!response.ok) {
    throw new Error(`Failed to load BYOK keys: ${response.statusText}`)
  }

  const payload: { keys?: unknown } | null = await response.json()
  const keys = payload?.keys

  // A query function must never resolve to `undefined`; fall back to an empty
  // list when the body is null or has no `keys` array.
  return Array.isArray(keys) ? (keys as BYOKKey[]) : []
}
|
||||
|
||||
/**
 * Query hook for the BYOK keys of a workspace.
 *
 * The query only runs when `workspaceId` is non-empty, uses a `staleTime` of
 * 60 000, and passes `keepPreviousData` as `placeholderData`.
 *
 * @param workspaceId - Workspace whose keys should be loaded.
 */
export function useBYOKKeys(workspaceId: string) {
  const hasWorkspace = !!workspaceId
  const staleAfter = 60 * 1000

  return useQuery({
    queryKey: byokKeysKeys.workspace(workspaceId),
    queryFn: () => fetchBYOKKeys(workspaceId),
    enabled: hasWorkspace,
    staleTime: staleAfter,
    placeholderData: keepPreviousData,
  })
}
|
||||
|
||||
interface UpsertBYOKKeyParams {
|
||||
workspaceId: string
|
||||
providerId: BYOKProviderId
|
||||
apiKey: string
|
||||
}
|
||||
|
||||
/**
 * Mutation hook that saves a BYOK key for a workspace.
 *
 * Sends a POST with `{ providerId, apiKey }` to the workspace BYOK endpoint and,
 * on success, invalidates that workspace's BYOK key query.
 *
 * The mutation rejects with the server-provided `error` message when one is
 * present in the response body, otherwise with a message built from the status text.
 */
export function useUpsertBYOKKey() {
  const queryClient = useQueryClient()

  const saveKey = async ({ workspaceId, providerId, apiKey }: UpsertBYOKKeyParams) => {
    const response = await fetch(API_ENDPOINTS.WORKSPACE_BYOK_KEYS(workspaceId), {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ providerId, apiKey }),
    })

    if (response.ok) {
      logger.info(`Saved BYOK key for ${providerId} in workspace ${workspaceId}`)
      return await response.json()
    }

    // The error body may not be JSON; treat an unparsable body as empty.
    const data = await response.json().catch(() => ({}))
    throw new Error(data.error || `Failed to save BYOK key: ${response.statusText}`)
  }

  return useMutation({
    mutationFn: saveKey,
    onSuccess: (_data, variables) => {
      const queryKey = byokKeysKeys.workspace(variables.workspaceId)
      queryClient.invalidateQueries({ queryKey })
    },
  })
}
|
||||
|
||||
interface DeleteBYOKKeyParams {
|
||||
workspaceId: string
|
||||
providerId: BYOKProviderId
|
||||
}
|
||||
|
||||
/**
 * Mutation hook that removes a workspace's BYOK key for one provider.
 *
 * Sends a DELETE with `{ providerId }` to the workspace BYOK endpoint and,
 * on success, invalidates that workspace's BYOK key query.
 *
 * The mutation rejects with the server-provided `error` message when one is
 * present in the response body, otherwise with a message built from the status text.
 */
export function useDeleteBYOKKey() {
  const queryClient = useQueryClient()

  const removeKey = async ({ workspaceId, providerId }: DeleteBYOKKeyParams) => {
    const response = await fetch(API_ENDPOINTS.WORKSPACE_BYOK_KEYS(workspaceId), {
      method: 'DELETE',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ providerId }),
    })

    if (response.ok) {
      logger.info(`Deleted BYOK key for ${providerId} from workspace ${workspaceId}`)
      return await response.json()
    }

    // The error body may not be JSON; treat an unparsable body as empty.
    const data = await response.json().catch(() => ({}))
    throw new Error(data.error || `Failed to delete BYOK key: ${response.statusText}`)
  }

  return useMutation({
    mutationFn: removeKey,
    onSuccess: (_data, variables) => {
      const queryKey = byokKeysKeys.workspace(variables.workspaceId)
      queryClient.invalidateQueries({ queryKey })
    },
  })
}
|
||||
121
apps/sim/lib/api-key/byok.ts
Normal file
121
apps/sim/lib/api-key/byok.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
import { db } from '@sim/db'
|
||||
import { workspaceBYOKKeys } from '@sim/db/schema'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { decryptSecret } from '@/lib/core/security/encryption'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('BYOKKeys')
|
||||
|
||||
export type BYOKProviderId = 'openai' | 'anthropic' | 'google' | 'mistral' | 'exa'
|
||||
|
||||
export interface BYOKKeyResult {
|
||||
apiKey: string
|
||||
isBYOK: true
|
||||
}
|
||||
|
||||
/**
 * Looks up and decrypts the BYOK key stored for a workspace/provider pair.
 *
 * @param workspaceId - Workspace to look in; a falsy value short-circuits to `null`.
 * @param providerId - Provider whose key is wanted.
 * @returns The decrypted key tagged with `isBYOK: true`, or `null` when there is
 *   no workspace id, no stored row, or the lookup/decryption throws (the error
 *   is logged and not rethrown).
 */
export async function getBYOKKey(
  workspaceId: string | undefined | null,
  providerId: BYOKProviderId
): Promise<BYOKKeyResult | null> {
  if (!workspaceId) return null

  try {
    const matchesWorkspaceAndProvider = and(
      eq(workspaceBYOKKeys.workspaceId, workspaceId),
      eq(workspaceBYOKKeys.providerId, providerId)
    )

    const [row] = await db
      .select({ encryptedApiKey: workspaceBYOKKeys.encryptedApiKey })
      .from(workspaceBYOKKeys)
      .where(matchesWorkspaceAndProvider)
      .limit(1)

    if (!row) return null

    const { decrypted } = await decryptSecret(row.encryptedApiKey)
    return { apiKey: decrypted, isBYOK: true }
  } catch (error) {
    logger.error('Failed to get BYOK key', { workspaceId, providerId, error })
    return null
  }
}
|
||||
|
||||
/**
 * Resolves the API key to use for a provider/model call, preferring a workspace
 * BYOK key where one applies.
 *
 * Resolution order, as implemented below:
 * 1. Ollama (by provider id or by presence in the store's Ollama model list): `'empty'`.
 * 2. vLLM (same detection): the user-provided key, or `'empty'`.
 * 3. On hosted deployments with a workspace id and an openai / anthropic /
 *    google / mistral provider: if the model is in the hosted-model list (or the
 *    provider is mistral), the workspace BYOK key when one exists; otherwise,
 *    for hosted models only, a rotating server key, then the user-provided key.
 * 4. Otherwise the user-provided key.
 *
 * @param provider - Provider id.
 * @param model - Model name; compared case-insensitively against hosted models.
 * @param workspaceId - Workspace used for the BYOK lookup; may be absent.
 * @param userProvidedKey - Key supplied by the caller, if any.
 * @returns The resolved key and whether it came from workspace BYOK storage.
 * @throws Error when no key can be resolved.
 */
export async function getApiKeyWithBYOK(
  provider: string,
  model: string,
  workspaceId: string | undefined | null,
  userProvidedKey?: string
): Promise<{ apiKey: string; isBYOK: boolean }> {
  const { isHosted } = await import('@/lib/core/config/feature-flags')
  const { useProvidersStore } = await import('@/stores/providers/store')

  // True when the provider store lists `model` under the given local provider.
  const listedUnder = (local: 'ollama' | 'vllm'): boolean =>
    useProvidersStore.getState().providers[local].models.includes(model)

  if (provider === 'ollama' || listedUnder('ollama')) {
    return { apiKey: 'empty', isBYOK: false }
  }

  if (provider === 'vllm' || listedUnder('vllm')) {
    return { apiKey: userProvidedKey || 'empty', isBYOK: false }
  }

  const isGemini = provider === 'google'
  const isMistral = provider === 'mistral'
  const supportsBYOK =
    provider === 'openai' || provider === 'anthropic' || isGemini || isMistral

  if (isHosted && workspaceId && supportsBYOK) {
    const { getHostedModels } = await import('@/providers/models')
    const isModelHosted = getHostedModels().some(
      (hosted) => hosted.toLowerCase() === model.toLowerCase()
    )

    logger.debug('BYOK check', { provider, model, workspaceId, isHosted, isModelHosted })

    if (isModelHosted || isMistral) {
      // `provider` is one of the four BYOK-capable ids inside this branch.
      const byokResult = await getBYOKKey(workspaceId, provider as BYOKProviderId)
      if (byokResult) {
        logger.info('Using BYOK key', { provider, model, workspaceId })
        return byokResult
      }

      logger.debug('No BYOK key found, falling back', { provider, model, workspaceId })

      if (isModelHosted) {
        try {
          const { getRotatingApiKey } = await import('@/lib/core/config/api-keys')
          // The rotating-key helper is called with 'gemini' for the google provider.
          const serverKey = getRotatingApiKey(isGemini ? 'gemini' : provider)
          return { apiKey: serverKey, isBYOK: false }
        } catch (_error) {
          // No server key could be obtained: the caller's key is the last resort.
          if (userProvidedKey) {
            return { apiKey: userProvidedKey, isBYOK: false }
          }
          throw new Error(`No API key available for ${provider} ${model}`)
        }
      }
    }
  }

  if (userProvidedKey) {
    return { apiKey: userProvidedKey, isBYOK: false }
  }

  logger.debug('BYOK not applicable, no user key provided', {
    provider,
    model,
    workspaceId,
    isHosted,
  })
  throw new Error(`API key is required for ${provider} ${model}`)
}
|
||||
@@ -66,7 +66,7 @@ export interface LogFixedUsageParams {
|
||||
* Log a model usage charge (token-based)
|
||||
*/
|
||||
export async function logModelUsage(params: LogModelUsageParams): Promise<void> {
|
||||
if (!isBillingEnabled) {
|
||||
if (!isBillingEnabled || params.cost <= 0) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ export async function logModelUsage(params: LogModelUsageParams): Promise<void>
|
||||
* Log a fixed charge (flat fee like base execution charge or search)
|
||||
*/
|
||||
export async function logFixedUsage(params: LogFixedUsageParams): Promise<void> {
|
||||
if (!isBillingEnabled) {
|
||||
if (!isBillingEnabled || params.cost <= 0) {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -174,8 +174,11 @@ export const knowledgeBaseServerTool: BaseServerTool<KnowledgeBaseArgs, Knowledg
|
||||
|
||||
const topK = args.topK || 5
|
||||
|
||||
// Generate embedding for the query
|
||||
const queryEmbedding = await generateSearchEmbedding(args.query)
|
||||
const queryEmbedding = await generateSearchEmbedding(
|
||||
args.query,
|
||||
undefined,
|
||||
kb.workspaceId
|
||||
)
|
||||
const queryVector = JSON.stringify(queryEmbedding)
|
||||
|
||||
// Get search strategy
|
||||
|
||||
@@ -94,11 +94,11 @@ export async function createChunk(
|
||||
documentId: string,
|
||||
docTags: Record<string, string | number | boolean | Date | null>,
|
||||
chunkData: CreateChunkData,
|
||||
requestId: string
|
||||
requestId: string,
|
||||
workspaceId?: string | null
|
||||
): Promise<ChunkData> {
|
||||
// Generate embedding for the content first (outside transaction for performance)
|
||||
logger.info(`[${requestId}] Generating embedding for manual chunk`)
|
||||
const embeddings = await generateEmbeddings([chunkData.content])
|
||||
const embeddings = await generateEmbeddings([chunkData.content], undefined, workspaceId)
|
||||
|
||||
// Calculate accurate token count
|
||||
const tokenCount = estimateTokenCount(chunkData.content, 'openai')
|
||||
@@ -285,7 +285,8 @@ export async function updateChunk(
|
||||
content?: string
|
||||
enabled?: boolean
|
||||
},
|
||||
requestId: string
|
||||
requestId: string,
|
||||
workspaceId?: string | null
|
||||
): Promise<ChunkData> {
|
||||
const dbUpdateData: {
|
||||
updatedAt: Date
|
||||
@@ -327,8 +328,7 @@ export async function updateChunk(
|
||||
if (content !== currentChunk[0].content) {
|
||||
logger.info(`[${requestId}] Content changed, regenerating embedding for chunk ${chunkId}`)
|
||||
|
||||
// Generate new embedding for the updated content
|
||||
const embeddings = await generateEmbeddings([content])
|
||||
const embeddings = await generateEmbeddings([content], undefined, workspaceId)
|
||||
|
||||
// Calculate accurate token count
|
||||
const tokenCount = estimateTokenCount(content, 'openai')
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { getBYOKKey } from '@/lib/api-key/byok'
|
||||
import { type Chunk, JsonYamlChunker, StructuredDataChunker, TextChunker } from '@/lib/chunkers'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { parseBuffer, parseFile } from '@/lib/file-parsers'
|
||||
@@ -131,6 +132,17 @@ export async function processDocument(
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Picks the Mistral API key to use for OCR.
 *
 * @param workspaceId - Workspace whose BYOK key takes precedence, if given.
 * @returns The workspace's BYOK Mistral key when one is found; otherwise
 *   `env.MISTRAL_API_KEY`, or `null` when that is unset or empty.
 */
async function getMistralApiKey(workspaceId?: string | null): Promise<string | null> {
  const byokResult = workspaceId ? await getBYOKKey(workspaceId, 'mistral') : null

  if (!byokResult) {
    return env.MISTRAL_API_KEY || null
  }

  logger.info('Using workspace BYOK key for Mistral OCR')
  return byokResult.apiKey
}
|
||||
|
||||
async function parseDocument(
|
||||
fileUrl: string,
|
||||
filename: string,
|
||||
@@ -146,7 +158,9 @@ async function parseDocument(
|
||||
const isPDF = mimeType === 'application/pdf'
|
||||
const hasAzureMistralOCR =
|
||||
env.OCR_AZURE_API_KEY && env.OCR_AZURE_ENDPOINT && env.OCR_AZURE_MODEL_NAME
|
||||
const hasMistralOCR = env.MISTRAL_API_KEY
|
||||
|
||||
const mistralApiKey = await getMistralApiKey(workspaceId)
|
||||
const hasMistralOCR = !!mistralApiKey
|
||||
|
||||
if (isPDF && (hasAzureMistralOCR || hasMistralOCR)) {
|
||||
if (hasAzureMistralOCR) {
|
||||
@@ -156,7 +170,7 @@ async function parseDocument(
|
||||
|
||||
if (hasMistralOCR) {
|
||||
logger.info(`Using Mistral OCR: ${filename}`)
|
||||
return parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
|
||||
return parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId, mistralApiKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,9 +374,18 @@ async function parseWithAzureMistralOCR(
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
|
||||
return env.MISTRAL_API_KEY
|
||||
? parseWithMistralOCR(fileUrl, filename, mimeType, userId, workspaceId)
|
||||
: parseWithFileParser(fileUrl, filename, mimeType)
|
||||
const fallbackMistralKey = await getMistralApiKey(workspaceId)
|
||||
if (fallbackMistralKey) {
|
||||
return parseWithMistralOCR(
|
||||
fileUrl,
|
||||
filename,
|
||||
mimeType,
|
||||
userId,
|
||||
workspaceId,
|
||||
fallbackMistralKey
|
||||
)
|
||||
}
|
||||
return parseWithFileParser(fileUrl, filename, mimeType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -371,9 +394,11 @@ async function parseWithMistralOCR(
|
||||
filename: string,
|
||||
mimeType: string,
|
||||
userId?: string,
|
||||
workspaceId?: string | null
|
||||
workspaceId?: string | null,
|
||||
mistralApiKey?: string | null
|
||||
) {
|
||||
if (!env.MISTRAL_API_KEY) {
|
||||
const apiKey = mistralApiKey || env.MISTRAL_API_KEY
|
||||
if (!apiKey) {
|
||||
throw new Error('Mistral API key required')
|
||||
}
|
||||
|
||||
@@ -388,7 +413,7 @@ async function parseWithMistralOCR(
|
||||
userId,
|
||||
workspaceId
|
||||
)
|
||||
const params = { filePath: httpsUrl, apiKey: env.MISTRAL_API_KEY, resultType: 'text' as const }
|
||||
const params = { filePath: httpsUrl, apiKey, resultType: 'text' as const }
|
||||
|
||||
try {
|
||||
const response = await retryWithExponentialBackoff(
|
||||
|
||||
@@ -484,7 +484,7 @@ export async function processDocumentAsync(
|
||||
const batchNum = Math.floor(i / batchSize) + 1
|
||||
|
||||
logger.info(`[${documentId}] Processing embedding batch ${batchNum}/${totalBatches}`)
|
||||
const batchEmbeddings = await generateEmbeddings(batch)
|
||||
const batchEmbeddings = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
|
||||
embeddings.push(...batchEmbeddings)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { getBYOKKey } from '@/lib/api-key/byok'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { isRetryableError, retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
@@ -24,40 +25,53 @@ interface EmbeddingConfig {
|
||||
modelName: string
|
||||
}
|
||||
|
||||
function getEmbeddingConfig(embeddingModel = 'text-embedding-3-small'): EmbeddingConfig {
|
||||
async function getEmbeddingConfig(
|
||||
embeddingModel = 'text-embedding-3-small',
|
||||
workspaceId?: string | null
|
||||
): Promise<EmbeddingConfig> {
|
||||
const azureApiKey = env.AZURE_OPENAI_API_KEY
|
||||
const azureEndpoint = env.AZURE_OPENAI_ENDPOINT
|
||||
const azureApiVersion = env.AZURE_OPENAI_API_VERSION
|
||||
const kbModelName = env.KB_OPENAI_MODEL_NAME || embeddingModel
|
||||
const openaiApiKey = env.OPENAI_API_KEY
|
||||
|
||||
const useAzure = !!(azureApiKey && azureEndpoint)
|
||||
|
||||
if (!useAzure && !openaiApiKey) {
|
||||
if (useAzure) {
|
||||
return {
|
||||
useAzure: true,
|
||||
apiUrl: `${azureEndpoint}/openai/deployments/${kbModelName}/embeddings?api-version=${azureApiVersion}`,
|
||||
headers: {
|
||||
'api-key': azureApiKey!,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
modelName: kbModelName,
|
||||
}
|
||||
}
|
||||
|
||||
let openaiApiKey = env.OPENAI_API_KEY
|
||||
|
||||
if (workspaceId) {
|
||||
const byokResult = await getBYOKKey(workspaceId, 'openai')
|
||||
if (byokResult) {
|
||||
logger.info('Using workspace BYOK key for OpenAI embeddings')
|
||||
openaiApiKey = byokResult.apiKey
|
||||
}
|
||||
}
|
||||
|
||||
if (!openaiApiKey) {
|
||||
throw new Error(
|
||||
'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured'
|
||||
)
|
||||
}
|
||||
|
||||
const apiUrl = useAzure
|
||||
? `${azureEndpoint}/openai/deployments/${kbModelName}/embeddings?api-version=${azureApiVersion}`
|
||||
: 'https://api.openai.com/v1/embeddings'
|
||||
|
||||
const headers: Record<string, string> = useAzure
|
||||
? {
|
||||
'api-key': azureApiKey!,
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
: {
|
||||
Authorization: `Bearer ${openaiApiKey!}`,
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
return {
|
||||
useAzure,
|
||||
apiUrl,
|
||||
headers,
|
||||
modelName: useAzure ? kbModelName : embeddingModel,
|
||||
useAzure: false,
|
||||
apiUrl: 'https://api.openai.com/v1/embeddings',
|
||||
headers: {
|
||||
Authorization: `Bearer ${openaiApiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
modelName: embeddingModel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,9 +126,10 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
|
||||
*/
|
||||
export async function generateEmbeddings(
|
||||
texts: string[],
|
||||
embeddingModel = 'text-embedding-3-small'
|
||||
embeddingModel = 'text-embedding-3-small',
|
||||
workspaceId?: string | null
|
||||
): Promise<number[][]> {
|
||||
const config = getEmbeddingConfig(embeddingModel)
|
||||
const config = await getEmbeddingConfig(embeddingModel, workspaceId)
|
||||
|
||||
logger.info(
|
||||
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for embeddings generation (${texts.length} texts)`
|
||||
@@ -163,9 +178,10 @@ export async function generateEmbeddings(
|
||||
*/
|
||||
export async function generateSearchEmbedding(
|
||||
query: string,
|
||||
embeddingModel = 'text-embedding-3-small'
|
||||
embeddingModel = 'text-embedding-3-small',
|
||||
workspaceId?: string | null
|
||||
): Promise<number[]> {
|
||||
const config = getEmbeddingConfig(embeddingModel)
|
||||
const config = await getEmbeddingConfig(embeddingModel, workspaceId)
|
||||
|
||||
logger.info(
|
||||
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for search embedding generation`
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { getApiKeyWithBYOK } from '@/lib/api-key/byok'
|
||||
import { getCostMultiplier } from '@/lib/core/config/feature-flags'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { StreamingExecution } from '@/executor/types'
|
||||
@@ -48,7 +49,32 @@ export async function executeProviderRequest(
|
||||
if (!provider.executeRequest) {
|
||||
throw new Error(`Provider ${providerId} does not implement executeRequest`)
|
||||
}
|
||||
const sanitizedRequest = sanitizeRequest(request)
|
||||
|
||||
let resolvedRequest = sanitizeRequest(request)
|
||||
let isBYOK = false
|
||||
|
||||
if (request.workspaceId) {
|
||||
try {
|
||||
const result = await getApiKeyWithBYOK(
|
||||
providerId,
|
||||
request.model,
|
||||
request.workspaceId,
|
||||
request.apiKey
|
||||
)
|
||||
resolvedRequest = { ...resolvedRequest, apiKey: result.apiKey }
|
||||
isBYOK = result.isBYOK
|
||||
} catch (error) {
|
||||
logger.error('Failed to resolve API key:', {
|
||||
provider: providerId,
|
||||
model: request.model,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
resolvedRequest.isBYOK = isBYOK
|
||||
const sanitizedRequest = resolvedRequest
|
||||
|
||||
if (sanitizedRequest.responseFormat) {
|
||||
if (
|
||||
@@ -88,7 +114,8 @@ export async function executeProviderRequest(
|
||||
const { input: promptTokens = 0, output: completionTokens = 0 } = response.tokens
|
||||
const useCachedInput = !!request.context && request.context.length > 0
|
||||
|
||||
if (shouldBillModelUsage(response.model)) {
|
||||
const shouldBill = shouldBillModelUsage(response.model) && !isBYOK
|
||||
if (shouldBill) {
|
||||
const costMultiplier = getCostMultiplier()
|
||||
response.cost = calculateCost(
|
||||
response.model,
|
||||
@@ -109,9 +136,13 @@ export async function executeProviderRequest(
|
||||
updatedAt: new Date().toISOString(),
|
||||
},
|
||||
}
|
||||
logger.debug(
|
||||
`Not billing model usage for ${response.model} - user provided API key or not hosted model`
|
||||
)
|
||||
if (isBYOK) {
|
||||
logger.debug(`Not billing model usage for ${response.model} - workspace BYOK key used`)
|
||||
} else {
|
||||
logger.debug(
|
||||
`Not billing model usage for ${response.model} - user provided API key or not hosted model`
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -134,12 +134,12 @@ export interface Message {
|
||||
|
||||
export interface ProviderRequest {
|
||||
model: string
|
||||
systemPrompt: string
|
||||
systemPrompt?: string
|
||||
context?: string
|
||||
tools?: ProviderToolConfig[]
|
||||
temperature?: number
|
||||
maxTokens?: number
|
||||
apiKey: string
|
||||
apiKey?: string
|
||||
messages?: Message[]
|
||||
responseFormat?: {
|
||||
name: string
|
||||
@@ -158,6 +158,7 @@ export interface ProviderRequest {
|
||||
blockData?: Record<string, any>
|
||||
blockNameMapping?: Record<string, string>
|
||||
isCopilotRequest?: boolean
|
||||
isBYOK?: boolean
|
||||
azureEndpoint?: string
|
||||
azureApiVersion?: string
|
||||
vertexProject?: string
|
||||
|
||||
@@ -4,6 +4,7 @@ export const API_ENDPOINTS = {
|
||||
WORKFLOWS: '/api/workflows',
|
||||
WORKSPACE_PERMISSIONS: (id: string) => `/api/workspaces/${id}/permissions`,
|
||||
WORKSPACE_ENVIRONMENT: (id: string) => `/api/workspaces/${id}/environment`,
|
||||
WORKSPACE_BYOK_KEYS: (id: string) => `/api/workspaces/${id}/byok-keys`,
|
||||
}
|
||||
|
||||
export const COPILOT_TOOL_DISPLAY_NAMES: Record<string, string> = {
|
||||
|
||||
@@ -25,6 +25,7 @@ export const searchTool: ToolConfig<SearchParams, SearchResponse> = {
|
||||
}),
|
||||
body: (params) => ({
|
||||
query: params.query,
|
||||
workspaceId: params._context?.workspaceId,
|
||||
}),
|
||||
},
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ import type { ToolResponse } from '@/tools/types'
|
||||
|
||||
export interface SearchParams {
|
||||
query: string
|
||||
_context?: {
|
||||
workflowId?: string
|
||||
workspaceId?: string
|
||||
executionId?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface SearchResponse extends ToolResponse {
|
||||
|
||||
14
packages/db/migrations/0133_smiling_cargill.sql
Normal file
14
packages/db/migrations/0133_smiling_cargill.sql
Normal file
@@ -0,0 +1,14 @@
|
||||
CREATE TABLE "workspace_byok_keys" (
|
||||
"id" text PRIMARY KEY NOT NULL,
|
||||
"workspace_id" text NOT NULL,
|
||||
"provider_id" text NOT NULL,
|
||||
"encrypted_api_key" text NOT NULL,
|
||||
"created_by" text,
|
||||
"created_at" timestamp DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
ALTER TABLE "workspace_byok_keys" ADD CONSTRAINT "workspace_byok_keys_workspace_id_workspace_id_fk" FOREIGN KEY ("workspace_id") REFERENCES "public"."workspace"("id") ON DELETE cascade ON UPDATE no action;--> statement-breakpoint
|
||||
ALTER TABLE "workspace_byok_keys" ADD CONSTRAINT "workspace_byok_keys_created_by_user_id_fk" FOREIGN KEY ("created_by") REFERENCES "public"."user"("id") ON DELETE set null ON UPDATE no action;--> statement-breakpoint
|
||||
CREATE UNIQUE INDEX "workspace_byok_provider_unique" ON "workspace_byok_keys" USING btree ("workspace_id","provider_id");--> statement-breakpoint
|
||||
CREATE INDEX "workspace_byok_workspace_idx" ON "workspace_byok_keys" USING btree ("workspace_id");
|
||||
8571
packages/db/migrations/meta/0133_snapshot.json
Normal file
8571
packages/db/migrations/meta/0133_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -925,6 +925,13 @@
|
||||
"when": 1766529613309,
|
||||
"tag": "0132_dazzling_leech",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 133,
|
||||
"version": "7",
|
||||
"when": 1766607372265,
|
||||
"tag": "0133_smiling_cargill",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -420,6 +420,28 @@ export const workspaceEnvironment = pgTable(
|
||||
})
|
||||
)
|
||||
|
||||
export const workspaceBYOKKeys = pgTable(
|
||||
'workspace_byok_keys',
|
||||
{
|
||||
id: text('id').primaryKey(),
|
||||
workspaceId: text('workspace_id')
|
||||
.notNull()
|
||||
.references(() => workspace.id, { onDelete: 'cascade' }),
|
||||
providerId: text('provider_id').notNull(),
|
||||
encryptedApiKey: text('encrypted_api_key').notNull(),
|
||||
createdBy: text('created_by').references(() => user.id, { onDelete: 'set null' }),
|
||||
createdAt: timestamp('created_at').notNull().defaultNow(),
|
||||
updatedAt: timestamp('updated_at').notNull().defaultNow(),
|
||||
},
|
||||
(table) => ({
|
||||
workspaceProviderUnique: uniqueIndex('workspace_byok_provider_unique').on(
|
||||
table.workspaceId,
|
||||
table.providerId
|
||||
),
|
||||
workspaceIdx: index('workspace_byok_workspace_idx').on(table.workspaceId),
|
||||
})
|
||||
)
|
||||
|
||||
export const settings = pgTable('settings', {
|
||||
id: text('id').primaryKey(), // Use the user id as the key
|
||||
userId: text('user_id')
|
||||
|
||||
Reference in New Issue
Block a user