feat(knowledge): add Live sync option to KB connectors + fix embedding billing (#3959)

* feat(knowledge): add Live sync option to KB connector modal for Max/Enterprise users

Adds a "Live" (every 5 min) sync frequency option gated to Max and Enterprise plan users.
Includes client-side badge + disabled state, shared sync intervals constant, and server-side
plan validation on both POST and PATCH connector routes.

* fix(knowledge): record embedding usage cost for KB document processing

Adds billing tracking to the KB embedding pipeline, which was previously
generating OpenAI API calls with no cost recorded. Token counts are now
captured from the actual API response and recorded via recordUsage after
successful embedding insertion. BYOK workspaces are excluded from billing.
Applies to all execution paths: direct, BullMQ, and Trigger.dev.

* fix(knowledge): simplify embedding billing — use calculateCost, return modelName

- Use calculateCost() from @/providers/utils instead of inline formula, consistent
  with how LLM billing works throughout the platform
- Return modelName from GenerateEmbeddingsResult so billing uses the actual model
  (handles custom Azure deployments) instead of a hardcoded fallback string
- Fix docs-chunker.ts empty-path fallback to satisfy full GenerateEmbeddingsResult type

* fix(knowledge): remove dev bypass from hasLiveSyncAccess

* chore(knowledge): rename sync-intervals to consts, fix stale TSDoc comment

* improvement(knowledge): extract MaxBadge component, capture billing config once per document

* fix(knowledge): add knowledge-base to usage_log_source enum, fix docs-chunker type

* fix(knowledge): generate migration for knowledge-base usage_log_source enum value

* fix(knowledge): add knowledge-base to usage_log_source enum via drizzle-kit

* fix(knowledge): fix search embedding test mocks, parallelize billing lookups

* fix(knowledge): warn when embedding model has no pricing entry

* fix(knowledge): call checkAndBillOverageThreshold after embedding usage
This commit is contained in:
Waleed
2026-04-04 16:49:42 -07:00
committed by GitHub
parent 7971a64e63
commit ce53275e9d
18 changed files with 14873 additions and 93 deletions

View File

@@ -13,6 +13,7 @@ import { z } from 'zod'
import { decryptApiKey } from '@/lib/api-key/crypto'
import { AuditAction, AuditResourceType, recordAudit } from '@/lib/audit/log'
import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid'
import { hasLiveSyncAccess } from '@/lib/billing/core/subscription'
import { generateRequestId } from '@/lib/core/utils/request'
import { deleteDocumentStorageFiles } from '@/lib/knowledge/documents/service'
import { cleanupUnusedTagDefinitions } from '@/lib/knowledge/tags/service'
@@ -116,6 +117,20 @@ export async function PATCH(request: NextRequest, { params }: RouteParams) {
)
}
if (
parsed.data.syncIntervalMinutes !== undefined &&
parsed.data.syncIntervalMinutes > 0 &&
parsed.data.syncIntervalMinutes < 60
) {
const canUseLiveSync = await hasLiveSyncAccess(auth.userId)
if (!canUseLiveSync) {
return NextResponse.json(
{ error: 'Live sync requires a Max or Enterprise plan' },
{ status: 403 }
)
}
}
if (parsed.data.sourceConfig !== undefined) {
const existingRows = await db
.select()

View File

@@ -7,6 +7,7 @@ import { z } from 'zod'
import { encryptApiKey } from '@/lib/api-key/crypto'
import { AuditAction, AuditResourceType, recordAudit } from '@/lib/audit/log'
import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid'
import { hasLiveSyncAccess } from '@/lib/billing/core/subscription'
import { generateRequestId } from '@/lib/core/utils/request'
import { dispatchSync } from '@/lib/knowledge/connectors/sync-engine'
import { allocateTagSlots } from '@/lib/knowledge/constants'
@@ -97,6 +98,16 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
const { connectorType, credentialId, apiKey, sourceConfig, syncIntervalMinutes } = parsed.data
if (syncIntervalMinutes > 0 && syncIntervalMinutes < 60) {
const canUseLiveSync = await hasLiveSyncAccess(auth.userId)
if (!canUseLiveSync) {
return NextResponse.json(
{ error: 'Live sync requires a Max or Enterprise plan' },
{ status: 403 }
)
}
}
const connectorConfig = CONNECTOR_REGISTRY[connectorType]
if (!connectorConfig) {
return NextResponse.json(

View File

@@ -5,6 +5,7 @@
* @vitest-environment node
*/
import { createEnvMock, databaseMock, loggerMock } from '@sim/testing'
import { mockNextFetchResponse } from '@sim/testing/mocks'
import { beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('drizzle-orm')
@@ -14,16 +15,6 @@ vi.mock('@/lib/knowledge/documents/utils', () => ({
retryWithExponentialBackoff: (fn: any) => fn(),
}))
vi.stubGlobal(
'fetch',
vi.fn().mockResolvedValue({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
})
)
vi.mock('@/lib/core/config/env', () => createEnvMock())
import {
@@ -178,17 +169,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})
const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})
const result = await generateSearchEmbedding('test query')
expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
'https://test.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-12-01-preview',
expect.objectContaining({
headers: expect.objectContaining({
@@ -209,17 +199,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})
const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})
const result = await generateSearchEmbedding('test query')
expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
'https://api.openai.com/v1/embeddings',
expect.objectContaining({
headers: expect.objectContaining({
@@ -243,17 +232,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})
const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})
await generateSearchEmbedding('test query')
expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
expect.stringContaining('api-version='),
expect.any(Object)
)
@@ -273,17 +261,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})
const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})
await generateSearchEmbedding('test query', 'text-embedding-3-small')
expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
'https://test.openai.azure.com/openai/deployments/custom-embedding-model/embeddings?api-version=2024-12-01-preview',
expect.any(Object)
)
@@ -311,13 +298,12 @@ describe('Knowledge Search Utils', () => {
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
})
const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
mockNextFetchResponse({
ok: false,
status: 404,
statusText: 'Not Found',
text: async () => 'Deployment not found',
} as any)
text: 'Deployment not found',
})
await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')
@@ -332,13 +318,12 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})
const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
mockNextFetchResponse({
ok: false,
status: 429,
statusText: 'Too Many Requests',
text: async () => 'Rate limit exceeded',
} as any)
text: 'Rate limit exceeded',
})
await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')
@@ -356,17 +341,16 @@ describe('Knowledge Search Utils', () => {
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
})
const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})
await generateSearchEmbedding('test query')
expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
body: JSON.stringify({
@@ -387,17 +371,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})
const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})
await generateSearchEmbedding('test query', 'text-embedding-3-small')
expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
body: JSON.stringify({

View File

@@ -77,6 +77,7 @@ vi.stubGlobal(
{ embedding: [0.1, 0.2], index: 0 },
{ embedding: [0.3, 0.4], index: 1 },
],
usage: { prompt_tokens: 2, total_tokens: 2 },
}),
})
)
@@ -294,7 +295,7 @@ describe('Knowledge Utils', () => {
it.concurrent('should return same length as input', async () => {
const result = await generateEmbeddings(['a', 'b'])
expect(result.length).toBe(2)
expect(result.embeddings.length).toBe(2)
})
it('should use Azure OpenAI when Azure config is provided', async () => {
@@ -313,6 +314,7 @@ describe('Knowledge Utils', () => {
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2], index: 0 }],
usage: { prompt_tokens: 1, total_tokens: 1 },
}),
} as any)
@@ -342,6 +344,7 @@ describe('Knowledge Utils', () => {
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2], index: 0 }],
usage: { prompt_tokens: 1, total_tokens: 1 },
}),
} as any)

View File

@@ -19,26 +19,23 @@ import {
ModalHeader,
Tooltip,
} from '@/components/emcn'
import { getSubscriptionAccessState } from '@/lib/billing/client'
import { consumeOAuthReturnContext } from '@/lib/credentials/client-state'
import { getProviderIdFromServiceId, type OAuthProvider } from '@/lib/oauth'
import { OAuthModal } from '@/app/workspace/[workspaceId]/components/oauth-modal'
import { ConnectorSelectorField } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/add-connector-modal/components/connector-selector-field'
import { SYNC_INTERVALS } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/consts'
import { MaxBadge } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/max-badge'
import { isBillingEnabled } from '@/app/workspace/[workspaceId]/settings/navigation'
import { getDependsOnFields } from '@/blocks/utils'
import { CONNECTOR_REGISTRY } from '@/connectors/registry'
import type { ConnectorConfig, ConnectorConfigField } from '@/connectors/types'
import { useCreateConnector } from '@/hooks/queries/kb/connectors'
import { useOAuthCredentials } from '@/hooks/queries/oauth/oauth-credentials'
import { useSubscriptionData } from '@/hooks/queries/subscription'
import type { SelectorKey } from '@/hooks/selectors/types'
import { useCredentialRefreshTriggers } from '@/hooks/use-credential-refresh-triggers'
const SYNC_INTERVALS = [
{ label: 'Every hour', value: 60 },
{ label: 'Every 6 hours', value: 360 },
{ label: 'Daily', value: 1440 },
{ label: 'Weekly', value: 10080 },
{ label: 'Manual only', value: 0 },
] as const
const CONNECTOR_ENTRIES = Object.entries(CONNECTOR_REGISTRY)
interface AddConnectorModalProps {
@@ -75,6 +72,10 @@ export function AddConnectorModal({
const { workspaceId } = useParams<{ workspaceId: string }>()
const { mutate: createConnector, isPending: isCreating } = useCreateConnector()
const { data: subscriptionResponse } = useSubscriptionData({ enabled: isBillingEnabled })
const subscriptionAccess = getSubscriptionAccessState(subscriptionResponse?.data)
const hasMaxAccess = !isBillingEnabled || subscriptionAccess.hasUsableMaxAccess
const connectorConfig = selectedType ? CONNECTOR_REGISTRY[selectedType] : null
const isApiKeyMode = connectorConfig?.auth.mode === 'apiKey'
const connectorProviderId = useMemo(
@@ -528,8 +529,13 @@ export function AddConnectorModal({
onValueChange={(val) => setSyncInterval(Number(val))}
>
{SYNC_INTERVALS.map((interval) => (
<ButtonGroupItem key={interval.value} value={String(interval.value)}>
<ButtonGroupItem
key={interval.value}
value={String(interval.value)}
disabled={interval.requiresMax && !hasMaxAccess}
>
{interval.label}
{interval.requiresMax && !hasMaxAccess && <MaxBadge />}
</ButtonGroupItem>
))}
</ButtonGroup>

View File

@@ -0,0 +1,8 @@
/**
 * Sync-frequency options offered in the KB connector add/edit modals.
 *
 * `value` is the sync interval in minutes; `0` means manual-only (no
 * scheduled sync). Entries flagged `requiresMax: true` ("Live", every
 * 5 minutes) are plan-gated: the UI disables them and shows a MaxBadge
 * for non-Max users, and the POST/PATCH connector routes independently
 * reject intervals in the (0, 60) range via `hasLiveSyncAccess`.
 *
 * Declared `as const` so consumers get literal `value` / `requiresMax`
 * types rather than widened `number` / `boolean`.
 */
export const SYNC_INTERVALS = [
  { label: 'Live', value: 5, requiresMax: true },
  { label: 'Every hour', value: 60, requiresMax: false },
  { label: 'Every 6 hours', value: 360, requiresMax: false },
  { label: 'Daily', value: 1440, requiresMax: false },
  { label: 'Weekly', value: 10080, requiresMax: false },
  { label: 'Manual only', value: 0, requiresMax: false },
] as const

View File

@@ -21,6 +21,10 @@ import {
ModalTabsTrigger,
Skeleton,
} from '@/components/emcn'
import { getSubscriptionAccessState } from '@/lib/billing/client'
import { SYNC_INTERVALS } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/consts'
import { MaxBadge } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/max-badge'
import { isBillingEnabled } from '@/app/workspace/[workspaceId]/settings/navigation'
import { CONNECTOR_REGISTRY } from '@/connectors/registry'
import type { ConnectorConfig } from '@/connectors/types'
import type { ConnectorData } from '@/hooks/queries/kb/connectors'
@@ -30,17 +34,10 @@ import {
useRestoreConnectorDocument,
useUpdateConnector,
} from '@/hooks/queries/kb/connectors'
import { useSubscriptionData } from '@/hooks/queries/subscription'
const logger = createLogger('EditConnectorModal')
const SYNC_INTERVALS = [
{ label: 'Every hour', value: 60 },
{ label: 'Every 6 hours', value: 360 },
{ label: 'Daily', value: 1440 },
{ label: 'Weekly', value: 10080 },
{ label: 'Manual only', value: 0 },
] as const
/** Keys injected by the sync engine — not user-editable */
const INTERNAL_CONFIG_KEYS = new Set(['tagSlotMapping', 'disabledTagIds'])
@@ -76,6 +73,10 @@ export function EditConnectorModal({
const { mutate: updateConnector, isPending: isSaving } = useUpdateConnector()
const { data: subscriptionResponse } = useSubscriptionData({ enabled: isBillingEnabled })
const subscriptionAccess = getSubscriptionAccessState(subscriptionResponse?.data)
const hasMaxAccess = !isBillingEnabled || subscriptionAccess.hasUsableMaxAccess
const hasChanges = useMemo(() => {
if (syncInterval !== connector.syncIntervalMinutes) return true
for (const [key, value] of Object.entries(sourceConfig)) {
@@ -146,6 +147,7 @@ export function EditConnectorModal({
setSourceConfig={setSourceConfig}
syncInterval={syncInterval}
setSyncInterval={setSyncInterval}
hasMaxAccess={hasMaxAccess}
error={error}
/>
</ModalTabsContent>
@@ -184,6 +186,7 @@ interface SettingsTabProps {
setSourceConfig: React.Dispatch<React.SetStateAction<Record<string, string>>>
syncInterval: number
setSyncInterval: (v: number) => void
hasMaxAccess: boolean
error: string | null
}
@@ -193,6 +196,7 @@ function SettingsTab({
setSourceConfig,
syncInterval,
setSyncInterval,
hasMaxAccess,
error,
}: SettingsTabProps) {
return (
@@ -234,8 +238,13 @@ function SettingsTab({
onValueChange={(val) => setSyncInterval(Number(val))}
>
{SYNC_INTERVALS.map((interval) => (
<ButtonGroupItem key={interval.value} value={String(interval.value)}>
<ButtonGroupItem
key={interval.value}
value={String(interval.value)}
disabled={interval.requiresMax && !hasMaxAccess}
>
{interval.label}
{interval.requiresMax && !hasMaxAccess && <MaxBadge />}
</ButtonGroupItem>
))}
</ButtonGroup>

View File

@@ -0,0 +1,7 @@
export function MaxBadge() {
return (
<span className='ml-1 shrink-0 rounded-[3px] bg-[var(--surface-5)] px-1 py-[1px] font-medium text-[9px] text-[var(--text-icon)] uppercase tracking-wide'>
Max
</span>
)
}

View File

@@ -448,9 +448,11 @@ export async function hasInboxAccess(userId: string): Promise<boolean> {
if (!isProd) {
return true
}
const sub = await getHighestPrioritySubscription(userId)
const [sub, billingStatus] = await Promise.all([
getHighestPrioritySubscription(userId),
getEffectiveBillingStatus(userId),
])
if (!sub) return false
const billingStatus = await getEffectiveBillingStatus(userId)
if (!hasUsableSubscriptionAccess(sub.status, billingStatus.billingBlocked)) return false
return getPlanTierCredits(sub.plan) >= 25000 || checkEnterprisePlan(sub)
} catch (error) {
@@ -459,6 +461,30 @@ export async function hasInboxAccess(userId: string): Promise<boolean> {
}
}
/**
 * Determine whether a user may enable the "Live" (every 5 minutes)
 * sync interval on knowledge-base connectors.
 *
 * Access is granted when:
 * - the deployment is not hosted (self-hosted installs are unrestricted), or
 * - the user's highest-priority subscription is usable (active status and
 *   not billing-blocked) AND is either a Max-tier plan
 *   (plan-tier credits >= 25000) or an Enterprise plan.
 *
 * Any error during the lookups is logged and treated as "no access"
 * (fail closed).
 *
 * @param userId - id of the user whose plan is being checked
 * @returns true when live sync is permitted for this user
 */
export async function hasLiveSyncAccess(userId: string): Promise<boolean> {
  try {
    if (!isHosted) {
      return true
    }
    // Subscription and billing status are independent lookups — fetch in parallel.
    const [subscription, billing] = await Promise.all([
      getHighestPrioritySubscription(userId),
      getEffectiveBillingStatus(userId),
    ])
    if (
      !subscription ||
      !hasUsableSubscriptionAccess(subscription.status, billing.billingBlocked)
    ) {
      return false
    }
    return getPlanTierCredits(subscription.plan) >= 25000 || checkEnterprisePlan(subscription)
  } catch (error) {
    logger.error('Error checking live sync access', { error, userId })
    return false
  }
}
/**
* Check if user has exceeded their cost limit based on current period usage
*/

View File

@@ -21,6 +21,7 @@ export type UsageLogSource =
| 'workspace-chat'
| 'mcp_copilot'
| 'mothership_block'
| 'knowledge-base'
/**
* Metadata for 'model' category charges

View File

@@ -81,7 +81,8 @@ export class DocsChunker {
const textChunks = await this.splitContent(markdownContent)
logger.info(`Generating embeddings for ${textChunks.length} chunks in ${relativePath}`)
const embeddings = textChunks.length > 0 ? await generateEmbeddings(textChunks) : []
const embeddings: number[][] =
textChunks.length > 0 ? (await generateEmbeddings(textChunks)).embeddings : []
const embeddingModel = 'text-embedding-3-small'
const chunks: DocChunk[] = []

View File

@@ -110,7 +110,7 @@ export async function createChunk(
workspaceId?: string | null
): Promise<ChunkData> {
logger.info(`[${requestId}] Generating embedding for manual chunk`)
const embeddings = await generateEmbeddings([chunkData.content], undefined, workspaceId)
const { embeddings } = await generateEmbeddings([chunkData.content], undefined, workspaceId)
// Calculate accurate token count
const tokenCount = estimateTokenCount(chunkData.content, 'openai')
@@ -359,7 +359,7 @@ export async function updateChunk(
if (content !== currentChunk[0].content) {
logger.info(`[${requestId}] Content changed, regenerating embedding for chunk ${chunkId}`)
const embeddings = await generateEmbeddings([content], undefined, workspaceId)
const { embeddings } = await generateEmbeddings([content], undefined, workspaceId)
// Calculate accurate token count
const tokenCount = estimateTokenCount(content, 'openai')

View File

@@ -25,9 +25,11 @@ import {
type SQL,
sql,
} from 'drizzle-orm'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { createBullMQJobData, isBullMQEnabled } from '@/lib/core/bullmq'
import { env } from '@/lib/core/config/env'
import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { getCostMultiplier, isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { enqueueWorkspaceDispatch } from '@/lib/core/workspace-dispatch'
import { processDocument } from '@/lib/knowledge/documents/document-processor'
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
@@ -43,6 +45,7 @@ import type { ProcessedDocumentTags } from '@/lib/knowledge/types'
import { deleteFile } from '@/lib/uploads/core/storage-service'
import { extractStorageKey } from '@/lib/uploads/utils/file-utils'
import type { DocumentProcessingPayload } from '@/background/knowledge-processing'
import { calculateCost } from '@/providers/utils'
const logger = createLogger('DocumentService')
@@ -460,6 +463,10 @@ export async function processDocumentAsync(
overlap: rawConfig?.overlap ?? 200,
}
let totalEmbeddingTokens = 0
let embeddingIsBYOK = false
let embeddingModelName = 'text-embedding-3-small'
await withTimeout(
(async () => {
const processed = await processDocument(
@@ -500,10 +507,20 @@ export async function processDocumentAsync(
const batchNum = Math.floor(i / batchSize) + 1
logger.info(`[${documentId}] Processing embedding batch ${batchNum}/${totalBatches}`)
const batchEmbeddings = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
const {
embeddings: batchEmbeddings,
totalTokens: batchTokens,
isBYOK,
modelName,
} = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
for (const emb of batchEmbeddings) {
embeddings.push(emb)
}
totalEmbeddingTokens += batchTokens
if (i === 0) {
embeddingIsBYOK = isBYOK
embeddingModelName = modelName
}
}
}
@@ -638,6 +655,45 @@ export async function processDocumentAsync(
const processingTime = Date.now() - startTime
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)
if (!embeddingIsBYOK && totalEmbeddingTokens > 0 && kb[0].userId) {
try {
const costMultiplier = getCostMultiplier()
const { total: cost } = calculateCost(
embeddingModelName,
totalEmbeddingTokens,
0,
false,
costMultiplier
)
if (cost > 0) {
await recordUsage({
userId: kb[0].userId,
workspaceId: kb[0].workspaceId ?? undefined,
entries: [
{
category: 'model',
source: 'knowledge-base',
description: embeddingModelName,
cost,
metadata: { inputTokens: totalEmbeddingTokens, outputTokens: 0 },
},
],
additionalStats: {
totalTokensUsed: sql`total_tokens_used + ${totalEmbeddingTokens}`,
},
})
await checkAndBillOverageThreshold(kb[0].userId)
} else {
logger.warn(
`[${documentId}] Embedding model "${embeddingModelName}" has no pricing entry — billing skipped`,
{ totalEmbeddingTokens, embeddingModelName }
)
}
} catch (billingError) {
logger.error(`[${documentId}] Failed to record embedding usage`, { error: billingError })
}
}
} catch (error) {
const processingTime = Date.now() - startTime
const errorMessage = error instanceof Error ? error.message : 'Unknown error'

View File

@@ -35,6 +35,7 @@ interface EmbeddingConfig {
apiUrl: string
headers: Record<string, string>
modelName: string
isBYOK: boolean
}
interface EmbeddingResponseItem {
@@ -71,16 +72,19 @@ async function getEmbeddingConfig(
'Content-Type': 'application/json',
},
modelName: kbModelName,
isBYOK: false,
}
}
let openaiApiKey = env.OPENAI_API_KEY
let isBYOK = false
if (workspaceId) {
const byokResult = await getBYOKKey(workspaceId, 'openai')
if (byokResult) {
logger.info('Using workspace BYOK key for OpenAI embeddings')
openaiApiKey = byokResult.apiKey
isBYOK = true
}
}
@@ -98,12 +102,16 @@ async function getEmbeddingConfig(
'Content-Type': 'application/json',
},
modelName: embeddingModel,
isBYOK,
}
}
const EMBEDDING_REQUEST_TIMEOUT_MS = 60_000
async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Promise<number[][]> {
async function callEmbeddingAPI(
inputs: string[],
config: EmbeddingConfig
): Promise<{ embeddings: number[][]; totalTokens: number }> {
return retryWithExponentialBackoff(
async () => {
const useDimensions = supportsCustomDimensions(config.modelName)
@@ -140,7 +148,10 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
}
const data: EmbeddingAPIResponse = await response.json()
return data.data.map((item) => item.embedding)
return {
embeddings: data.data.map((item) => item.embedding),
totalTokens: data.usage.total_tokens,
}
},
{
maxRetries: 3,
@@ -178,14 +189,23 @@ async function processWithConcurrency<T, R>(
return results
}
export interface GenerateEmbeddingsResult {
embeddings: number[][]
totalTokens: number
isBYOK: boolean
modelName: string
}
/**
* Generate embeddings for multiple texts with token-aware batching and parallel processing
* Generate embeddings for multiple texts with token-aware batching and parallel processing.
* Returns embeddings alongside actual token count, model name, and whether a workspace BYOK key
* was used (vs. the platform's shared key) — enabling callers to make correct billing decisions.
*/
export async function generateEmbeddings(
texts: string[],
embeddingModel = 'text-embedding-3-small',
workspaceId?: string | null
): Promise<number[][]> {
): Promise<GenerateEmbeddingsResult> {
const config = await getEmbeddingConfig(embeddingModel, workspaceId)
const batches = batchByTokenLimit(texts, MAX_TOKENS_PER_REQUEST, embeddingModel)
@@ -204,13 +224,20 @@ export async function generateEmbeddings(
)
const allEmbeddings: number[][] = []
let totalTokens = 0
for (const batch of batchResults) {
for (const emb of batch) {
for (const emb of batch.embeddings) {
allEmbeddings.push(emb)
}
totalTokens += batch.totalTokens
}
return allEmbeddings
return {
embeddings: allEmbeddings,
totalTokens,
isBYOK: config.isBYOK,
modelName: config.modelName,
}
}
/**
@@ -227,6 +254,6 @@ export async function generateSearchEmbedding(
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for search embedding generation`
)
const embeddings = await callEmbeddingAPI([query], config)
const { embeddings } = await callEmbeddingAPI([query], config)
return embeddings[0]
}

View File

@@ -0,0 +1 @@
-- Register 'knowledge-base' as a valid usage_log_source enum value so that
-- embedding usage recorded by KB document processing (recordUsage with
-- source: 'knowledge-base') is accepted by the usage_log table.
ALTER TYPE "public"."usage_log_source" ADD VALUE 'knowledge-base';

File diff suppressed because it is too large Load Diff

View File

@@ -1289,6 +1289,13 @@
"when": 1775149654511,
"tag": "0184_hard_thaddeus_ross",
"breakpoints": true
},
{
"idx": 185,
"version": "7",
"when": 1775247973312,
"tag": "0185_new_gravity",
"breakpoints": true
}
]
}

View File

@@ -2273,6 +2273,7 @@ export const usageLogSourceEnum = pgEnum('usage_log_source', [
'workspace-chat',
'mcp_copilot',
'mothership_block',
'knowledge-base',
])
export const usageLog = pgTable(