mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
feat(knowledge): add Live sync option to KB connectors + fix embedding billing (#3959)
* feat(knowledge): add Live sync option to KB connector modal for Max/Enterprise users Adds a "Live" (every 5 min) sync frequency option gated to Max and Enterprise plan users. Includes client-side badge + disabled state, shared sync intervals constant, and server-side plan validation on both POST and PATCH connector routes. * fix(knowledge): record embedding usage cost for KB document processing Adds billing tracking to the KB embedding pipeline, which was previously generating OpenAI API calls with no cost recorded. Token counts are now captured from the actual API response and recorded via recordUsage after successful embedding insertion. BYOK workspaces are excluded from billing. Applies to all execution paths: direct, BullMQ, and Trigger.dev. * fix(knowledge): simplify embedding billing — use calculateCost, return modelName - Use calculateCost() from @/providers/utils instead of inline formula, consistent with how LLM billing works throughout the platform - Return modelName from GenerateEmbeddingsResult so billing uses the actual model (handles custom Azure deployments) instead of a hardcoded fallback string - Fix docs-chunker.ts empty-path fallback to satisfy full GenerateEmbeddingsResult type * fix(knowledge): remove dev bypass from hasLiveSyncAccess * chore(knowledge): rename sync-intervals to consts, fix stale TSDoc comment * improvement(knowledge): extract MaxBadge component, capture billing config once per document * fix(knowledge): add knowledge-base to usage_log_source enum, fix docs-chunker type * fix(knowledge): generate migration for knowledge-base usage_log_source enum value * fix(knowledge): add knowledge-base to usage_log_source enum via drizzle-kit * fix(knowledge): fix search embedding test mocks, parallelize billing lookups * fix(knowledge): warn when embedding model has no pricing entry * fix(knowledge): call checkAndBillOverageThreshold after embedding usage
This commit is contained in:
@@ -13,6 +13,7 @@ import { z } from 'zod'
|
||||
import { decryptApiKey } from '@/lib/api-key/crypto'
|
||||
import { AuditAction, AuditResourceType, recordAudit } from '@/lib/audit/log'
|
||||
import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid'
|
||||
import { hasLiveSyncAccess } from '@/lib/billing/core/subscription'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { deleteDocumentStorageFiles } from '@/lib/knowledge/documents/service'
|
||||
import { cleanupUnusedTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
@@ -116,6 +117,20 @@ export async function PATCH(request: NextRequest, { params }: RouteParams) {
|
||||
)
|
||||
}
|
||||
|
||||
if (
|
||||
parsed.data.syncIntervalMinutes !== undefined &&
|
||||
parsed.data.syncIntervalMinutes > 0 &&
|
||||
parsed.data.syncIntervalMinutes < 60
|
||||
) {
|
||||
const canUseLiveSync = await hasLiveSyncAccess(auth.userId)
|
||||
if (!canUseLiveSync) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Live sync requires a Max or Enterprise plan' },
|
||||
{ status: 403 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (parsed.data.sourceConfig !== undefined) {
|
||||
const existingRows = await db
|
||||
.select()
|
||||
|
||||
@@ -7,6 +7,7 @@ import { z } from 'zod'
|
||||
import { encryptApiKey } from '@/lib/api-key/crypto'
|
||||
import { AuditAction, AuditResourceType, recordAudit } from '@/lib/audit/log'
|
||||
import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid'
|
||||
import { hasLiveSyncAccess } from '@/lib/billing/core/subscription'
|
||||
import { generateRequestId } from '@/lib/core/utils/request'
|
||||
import { dispatchSync } from '@/lib/knowledge/connectors/sync-engine'
|
||||
import { allocateTagSlots } from '@/lib/knowledge/constants'
|
||||
@@ -97,6 +98,16 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
|
||||
|
||||
const { connectorType, credentialId, apiKey, sourceConfig, syncIntervalMinutes } = parsed.data
|
||||
|
||||
if (syncIntervalMinutes > 0 && syncIntervalMinutes < 60) {
|
||||
const canUseLiveSync = await hasLiveSyncAccess(auth.userId)
|
||||
if (!canUseLiveSync) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Live sync requires a Max or Enterprise plan' },
|
||||
{ status: 403 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const connectorConfig = CONNECTOR_REGISTRY[connectorType]
|
||||
if (!connectorConfig) {
|
||||
return NextResponse.json(
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { createEnvMock, databaseMock, loggerMock } from '@sim/testing'
|
||||
import { mockNextFetchResponse } from '@sim/testing/mocks'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('drizzle-orm')
|
||||
@@ -14,16 +15,6 @@ vi.mock('@/lib/knowledge/documents/utils', () => ({
|
||||
retryWithExponentialBackoff: (fn: any) => fn(),
|
||||
}))
|
||||
|
||||
vi.stubGlobal(
|
||||
'fetch',
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@/lib/core/config/env', () => createEnvMock())
|
||||
|
||||
import {
|
||||
@@ -178,17 +169,16 @@ describe('Knowledge Search Utils', () => {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
mockNextFetchResponse({
|
||||
json: {
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
usage: { prompt_tokens: 1, total_tokens: 1 },
|
||||
},
|
||||
})
|
||||
|
||||
const result = await generateSearchEmbedding('test query')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
|
||||
'https://test.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-12-01-preview',
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
@@ -209,17 +199,16 @@ describe('Knowledge Search Utils', () => {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
mockNextFetchResponse({
|
||||
json: {
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
usage: { prompt_tokens: 1, total_tokens: 1 },
|
||||
},
|
||||
})
|
||||
|
||||
const result = await generateSearchEmbedding('test query')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
|
||||
'https://api.openai.com/v1/embeddings',
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
@@ -243,17 +232,16 @@ describe('Knowledge Search Utils', () => {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
mockNextFetchResponse({
|
||||
json: {
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
usage: { prompt_tokens: 1, total_tokens: 1 },
|
||||
},
|
||||
})
|
||||
|
||||
await generateSearchEmbedding('test query')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
|
||||
expect.stringContaining('api-version='),
|
||||
expect.any(Object)
|
||||
)
|
||||
@@ -273,17 +261,16 @@ describe('Knowledge Search Utils', () => {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
mockNextFetchResponse({
|
||||
json: {
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
usage: { prompt_tokens: 1, total_tokens: 1 },
|
||||
},
|
||||
})
|
||||
|
||||
await generateSearchEmbedding('test query', 'text-embedding-3-small')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
|
||||
'https://test.openai.azure.com/openai/deployments/custom-embedding-model/embeddings?api-version=2024-12-01-preview',
|
||||
expect.any(Object)
|
||||
)
|
||||
@@ -311,13 +298,12 @@ describe('Knowledge Search Utils', () => {
|
||||
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
mockNextFetchResponse({
|
||||
ok: false,
|
||||
status: 404,
|
||||
statusText: 'Not Found',
|
||||
text: async () => 'Deployment not found',
|
||||
} as any)
|
||||
text: 'Deployment not found',
|
||||
})
|
||||
|
||||
await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')
|
||||
|
||||
@@ -332,13 +318,12 @@ describe('Knowledge Search Utils', () => {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
mockNextFetchResponse({
|
||||
ok: false,
|
||||
status: 429,
|
||||
statusText: 'Too Many Requests',
|
||||
text: async () => 'Rate limit exceeded',
|
||||
} as any)
|
||||
text: 'Rate limit exceeded',
|
||||
})
|
||||
|
||||
await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')
|
||||
|
||||
@@ -356,17 +341,16 @@ describe('Knowledge Search Utils', () => {
|
||||
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
mockNextFetchResponse({
|
||||
json: {
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
usage: { prompt_tokens: 1, total_tokens: 1 },
|
||||
},
|
||||
})
|
||||
|
||||
await generateSearchEmbedding('test query')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
@@ -387,17 +371,16 @@ describe('Knowledge Search Utils', () => {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
mockNextFetchResponse({
|
||||
json: {
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
usage: { prompt_tokens: 1, total_tokens: 1 },
|
||||
},
|
||||
})
|
||||
|
||||
await generateSearchEmbedding('test query', 'text-embedding-3-small')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
|
||||
@@ -77,6 +77,7 @@ vi.stubGlobal(
|
||||
{ embedding: [0.1, 0.2], index: 0 },
|
||||
{ embedding: [0.3, 0.4], index: 1 },
|
||||
],
|
||||
usage: { prompt_tokens: 2, total_tokens: 2 },
|
||||
}),
|
||||
})
|
||||
)
|
||||
@@ -294,7 +295,7 @@ describe('Knowledge Utils', () => {
|
||||
it.concurrent('should return same length as input', async () => {
|
||||
const result = await generateEmbeddings(['a', 'b'])
|
||||
|
||||
expect(result.length).toBe(2)
|
||||
expect(result.embeddings.length).toBe(2)
|
||||
})
|
||||
|
||||
it('should use Azure OpenAI when Azure config is provided', async () => {
|
||||
@@ -313,6 +314,7 @@ describe('Knowledge Utils', () => {
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2], index: 0 }],
|
||||
usage: { prompt_tokens: 1, total_tokens: 1 },
|
||||
}),
|
||||
} as any)
|
||||
|
||||
@@ -342,6 +344,7 @@ describe('Knowledge Utils', () => {
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2], index: 0 }],
|
||||
usage: { prompt_tokens: 1, total_tokens: 1 },
|
||||
}),
|
||||
} as any)
|
||||
|
||||
|
||||
@@ -19,26 +19,23 @@ import {
|
||||
ModalHeader,
|
||||
Tooltip,
|
||||
} from '@/components/emcn'
|
||||
import { getSubscriptionAccessState } from '@/lib/billing/client'
|
||||
import { consumeOAuthReturnContext } from '@/lib/credentials/client-state'
|
||||
import { getProviderIdFromServiceId, type OAuthProvider } from '@/lib/oauth'
|
||||
import { OAuthModal } from '@/app/workspace/[workspaceId]/components/oauth-modal'
|
||||
import { ConnectorSelectorField } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/add-connector-modal/components/connector-selector-field'
|
||||
import { SYNC_INTERVALS } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/consts'
|
||||
import { MaxBadge } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/max-badge'
|
||||
import { isBillingEnabled } from '@/app/workspace/[workspaceId]/settings/navigation'
|
||||
import { getDependsOnFields } from '@/blocks/utils'
|
||||
import { CONNECTOR_REGISTRY } from '@/connectors/registry'
|
||||
import type { ConnectorConfig, ConnectorConfigField } from '@/connectors/types'
|
||||
import { useCreateConnector } from '@/hooks/queries/kb/connectors'
|
||||
import { useOAuthCredentials } from '@/hooks/queries/oauth/oauth-credentials'
|
||||
import { useSubscriptionData } from '@/hooks/queries/subscription'
|
||||
import type { SelectorKey } from '@/hooks/selectors/types'
|
||||
import { useCredentialRefreshTriggers } from '@/hooks/use-credential-refresh-triggers'
|
||||
|
||||
const SYNC_INTERVALS = [
|
||||
{ label: 'Every hour', value: 60 },
|
||||
{ label: 'Every 6 hours', value: 360 },
|
||||
{ label: 'Daily', value: 1440 },
|
||||
{ label: 'Weekly', value: 10080 },
|
||||
{ label: 'Manual only', value: 0 },
|
||||
] as const
|
||||
|
||||
const CONNECTOR_ENTRIES = Object.entries(CONNECTOR_REGISTRY)
|
||||
|
||||
interface AddConnectorModalProps {
|
||||
@@ -75,6 +72,10 @@ export function AddConnectorModal({
|
||||
const { workspaceId } = useParams<{ workspaceId: string }>()
|
||||
const { mutate: createConnector, isPending: isCreating } = useCreateConnector()
|
||||
|
||||
const { data: subscriptionResponse } = useSubscriptionData({ enabled: isBillingEnabled })
|
||||
const subscriptionAccess = getSubscriptionAccessState(subscriptionResponse?.data)
|
||||
const hasMaxAccess = !isBillingEnabled || subscriptionAccess.hasUsableMaxAccess
|
||||
|
||||
const connectorConfig = selectedType ? CONNECTOR_REGISTRY[selectedType] : null
|
||||
const isApiKeyMode = connectorConfig?.auth.mode === 'apiKey'
|
||||
const connectorProviderId = useMemo(
|
||||
@@ -528,8 +529,13 @@ export function AddConnectorModal({
|
||||
onValueChange={(val) => setSyncInterval(Number(val))}
|
||||
>
|
||||
{SYNC_INTERVALS.map((interval) => (
|
||||
<ButtonGroupItem key={interval.value} value={String(interval.value)}>
|
||||
<ButtonGroupItem
|
||||
key={interval.value}
|
||||
value={String(interval.value)}
|
||||
disabled={interval.requiresMax && !hasMaxAccess}
|
||||
>
|
||||
{interval.label}
|
||||
{interval.requiresMax && !hasMaxAccess && <MaxBadge />}
|
||||
</ButtonGroupItem>
|
||||
))}
|
||||
</ButtonGroup>
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
export const SYNC_INTERVALS = [
|
||||
{ label: 'Live', value: 5, requiresMax: true },
|
||||
{ label: 'Every hour', value: 60, requiresMax: false },
|
||||
{ label: 'Every 6 hours', value: 360, requiresMax: false },
|
||||
{ label: 'Daily', value: 1440, requiresMax: false },
|
||||
{ label: 'Weekly', value: 10080, requiresMax: false },
|
||||
{ label: 'Manual only', value: 0, requiresMax: false },
|
||||
] as const
|
||||
@@ -21,6 +21,10 @@ import {
|
||||
ModalTabsTrigger,
|
||||
Skeleton,
|
||||
} from '@/components/emcn'
|
||||
import { getSubscriptionAccessState } from '@/lib/billing/client'
|
||||
import { SYNC_INTERVALS } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/consts'
|
||||
import { MaxBadge } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/max-badge'
|
||||
import { isBillingEnabled } from '@/app/workspace/[workspaceId]/settings/navigation'
|
||||
import { CONNECTOR_REGISTRY } from '@/connectors/registry'
|
||||
import type { ConnectorConfig } from '@/connectors/types'
|
||||
import type { ConnectorData } from '@/hooks/queries/kb/connectors'
|
||||
@@ -30,17 +34,10 @@ import {
|
||||
useRestoreConnectorDocument,
|
||||
useUpdateConnector,
|
||||
} from '@/hooks/queries/kb/connectors'
|
||||
import { useSubscriptionData } from '@/hooks/queries/subscription'
|
||||
|
||||
const logger = createLogger('EditConnectorModal')
|
||||
|
||||
const SYNC_INTERVALS = [
|
||||
{ label: 'Every hour', value: 60 },
|
||||
{ label: 'Every 6 hours', value: 360 },
|
||||
{ label: 'Daily', value: 1440 },
|
||||
{ label: 'Weekly', value: 10080 },
|
||||
{ label: 'Manual only', value: 0 },
|
||||
] as const
|
||||
|
||||
/** Keys injected by the sync engine — not user-editable */
|
||||
const INTERNAL_CONFIG_KEYS = new Set(['tagSlotMapping', 'disabledTagIds'])
|
||||
|
||||
@@ -76,6 +73,10 @@ export function EditConnectorModal({
|
||||
|
||||
const { mutate: updateConnector, isPending: isSaving } = useUpdateConnector()
|
||||
|
||||
const { data: subscriptionResponse } = useSubscriptionData({ enabled: isBillingEnabled })
|
||||
const subscriptionAccess = getSubscriptionAccessState(subscriptionResponse?.data)
|
||||
const hasMaxAccess = !isBillingEnabled || subscriptionAccess.hasUsableMaxAccess
|
||||
|
||||
const hasChanges = useMemo(() => {
|
||||
if (syncInterval !== connector.syncIntervalMinutes) return true
|
||||
for (const [key, value] of Object.entries(sourceConfig)) {
|
||||
@@ -146,6 +147,7 @@ export function EditConnectorModal({
|
||||
setSourceConfig={setSourceConfig}
|
||||
syncInterval={syncInterval}
|
||||
setSyncInterval={setSyncInterval}
|
||||
hasMaxAccess={hasMaxAccess}
|
||||
error={error}
|
||||
/>
|
||||
</ModalTabsContent>
|
||||
@@ -184,6 +186,7 @@ interface SettingsTabProps {
|
||||
setSourceConfig: React.Dispatch<React.SetStateAction<Record<string, string>>>
|
||||
syncInterval: number
|
||||
setSyncInterval: (v: number) => void
|
||||
hasMaxAccess: boolean
|
||||
error: string | null
|
||||
}
|
||||
|
||||
@@ -193,6 +196,7 @@ function SettingsTab({
|
||||
setSourceConfig,
|
||||
syncInterval,
|
||||
setSyncInterval,
|
||||
hasMaxAccess,
|
||||
error,
|
||||
}: SettingsTabProps) {
|
||||
return (
|
||||
@@ -234,8 +238,13 @@ function SettingsTab({
|
||||
onValueChange={(val) => setSyncInterval(Number(val))}
|
||||
>
|
||||
{SYNC_INTERVALS.map((interval) => (
|
||||
<ButtonGroupItem key={interval.value} value={String(interval.value)}>
|
||||
<ButtonGroupItem
|
||||
key={interval.value}
|
||||
value={String(interval.value)}
|
||||
disabled={interval.requiresMax && !hasMaxAccess}
|
||||
>
|
||||
{interval.label}
|
||||
{interval.requiresMax && !hasMaxAccess && <MaxBadge />}
|
||||
</ButtonGroupItem>
|
||||
))}
|
||||
</ButtonGroup>
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
export function MaxBadge() {
|
||||
return (
|
||||
<span className='ml-1 shrink-0 rounded-[3px] bg-[var(--surface-5)] px-1 py-[1px] font-medium text-[9px] text-[var(--text-icon)] uppercase tracking-wide'>
|
||||
Max
|
||||
</span>
|
||||
)
|
||||
}
|
||||
@@ -448,9 +448,11 @@ export async function hasInboxAccess(userId: string): Promise<boolean> {
|
||||
if (!isProd) {
|
||||
return true
|
||||
}
|
||||
const sub = await getHighestPrioritySubscription(userId)
|
||||
const [sub, billingStatus] = await Promise.all([
|
||||
getHighestPrioritySubscription(userId),
|
||||
getEffectiveBillingStatus(userId),
|
||||
])
|
||||
if (!sub) return false
|
||||
const billingStatus = await getEffectiveBillingStatus(userId)
|
||||
if (!hasUsableSubscriptionAccess(sub.status, billingStatus.billingBlocked)) return false
|
||||
return getPlanTierCredits(sub.plan) >= 25000 || checkEnterprisePlan(sub)
|
||||
} catch (error) {
|
||||
@@ -459,6 +461,30 @@ export async function hasInboxAccess(userId: string): Promise<boolean> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if user has access to live sync (every 5 minutes) for KB connectors
|
||||
* Returns true if:
|
||||
* - Self-hosted deployment, OR
|
||||
* - User has a Max plan (credits >= 25000) or enterprise plan
|
||||
*/
|
||||
export async function hasLiveSyncAccess(userId: string): Promise<boolean> {
|
||||
try {
|
||||
if (!isHosted) {
|
||||
return true
|
||||
}
|
||||
const [sub, billingStatus] = await Promise.all([
|
||||
getHighestPrioritySubscription(userId),
|
||||
getEffectiveBillingStatus(userId),
|
||||
])
|
||||
if (!sub) return false
|
||||
if (!hasUsableSubscriptionAccess(sub.status, billingStatus.billingBlocked)) return false
|
||||
return getPlanTierCredits(sub.plan) >= 25000 || checkEnterprisePlan(sub)
|
||||
} catch (error) {
|
||||
logger.error('Error checking live sync access', { error, userId })
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if user has exceeded their cost limit based on current period usage
|
||||
*/
|
||||
|
||||
@@ -21,6 +21,7 @@ export type UsageLogSource =
|
||||
| 'workspace-chat'
|
||||
| 'mcp_copilot'
|
||||
| 'mothership_block'
|
||||
| 'knowledge-base'
|
||||
|
||||
/**
|
||||
* Metadata for 'model' category charges
|
||||
|
||||
@@ -81,7 +81,8 @@ export class DocsChunker {
|
||||
const textChunks = await this.splitContent(markdownContent)
|
||||
|
||||
logger.info(`Generating embeddings for ${textChunks.length} chunks in ${relativePath}`)
|
||||
const embeddings = textChunks.length > 0 ? await generateEmbeddings(textChunks) : []
|
||||
const embeddings: number[][] =
|
||||
textChunks.length > 0 ? (await generateEmbeddings(textChunks)).embeddings : []
|
||||
const embeddingModel = 'text-embedding-3-small'
|
||||
|
||||
const chunks: DocChunk[] = []
|
||||
|
||||
@@ -110,7 +110,7 @@ export async function createChunk(
|
||||
workspaceId?: string | null
|
||||
): Promise<ChunkData> {
|
||||
logger.info(`[${requestId}] Generating embedding for manual chunk`)
|
||||
const embeddings = await generateEmbeddings([chunkData.content], undefined, workspaceId)
|
||||
const { embeddings } = await generateEmbeddings([chunkData.content], undefined, workspaceId)
|
||||
|
||||
// Calculate accurate token count
|
||||
const tokenCount = estimateTokenCount(chunkData.content, 'openai')
|
||||
@@ -359,7 +359,7 @@ export async function updateChunk(
|
||||
if (content !== currentChunk[0].content) {
|
||||
logger.info(`[${requestId}] Content changed, regenerating embedding for chunk ${chunkId}`)
|
||||
|
||||
const embeddings = await generateEmbeddings([content], undefined, workspaceId)
|
||||
const { embeddings } = await generateEmbeddings([content], undefined, workspaceId)
|
||||
|
||||
// Calculate accurate token count
|
||||
const tokenCount = estimateTokenCount(content, 'openai')
|
||||
|
||||
@@ -25,9 +25,11 @@ import {
|
||||
type SQL,
|
||||
sql,
|
||||
} from 'drizzle-orm'
|
||||
import { recordUsage } from '@/lib/billing/core/usage-log'
|
||||
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
|
||||
import { createBullMQJobData, isBullMQEnabled } from '@/lib/core/bullmq'
|
||||
import { env } from '@/lib/core/config/env'
|
||||
import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
|
||||
import { getCostMultiplier, isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
|
||||
import { enqueueWorkspaceDispatch } from '@/lib/core/workspace-dispatch'
|
||||
import { processDocument } from '@/lib/knowledge/documents/document-processor'
|
||||
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
|
||||
@@ -43,6 +45,7 @@ import type { ProcessedDocumentTags } from '@/lib/knowledge/types'
|
||||
import { deleteFile } from '@/lib/uploads/core/storage-service'
|
||||
import { extractStorageKey } from '@/lib/uploads/utils/file-utils'
|
||||
import type { DocumentProcessingPayload } from '@/background/knowledge-processing'
|
||||
import { calculateCost } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('DocumentService')
|
||||
|
||||
@@ -460,6 +463,10 @@ export async function processDocumentAsync(
|
||||
overlap: rawConfig?.overlap ?? 200,
|
||||
}
|
||||
|
||||
let totalEmbeddingTokens = 0
|
||||
let embeddingIsBYOK = false
|
||||
let embeddingModelName = 'text-embedding-3-small'
|
||||
|
||||
await withTimeout(
|
||||
(async () => {
|
||||
const processed = await processDocument(
|
||||
@@ -500,10 +507,20 @@ export async function processDocumentAsync(
|
||||
const batchNum = Math.floor(i / batchSize) + 1
|
||||
|
||||
logger.info(`[${documentId}] Processing embedding batch ${batchNum}/${totalBatches}`)
|
||||
const batchEmbeddings = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
|
||||
const {
|
||||
embeddings: batchEmbeddings,
|
||||
totalTokens: batchTokens,
|
||||
isBYOK,
|
||||
modelName,
|
||||
} = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
|
||||
for (const emb of batchEmbeddings) {
|
||||
embeddings.push(emb)
|
||||
}
|
||||
totalEmbeddingTokens += batchTokens
|
||||
if (i === 0) {
|
||||
embeddingIsBYOK = isBYOK
|
||||
embeddingModelName = modelName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -638,6 +655,45 @@ export async function processDocumentAsync(
|
||||
|
||||
const processingTime = Date.now() - startTime
|
||||
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)
|
||||
|
||||
if (!embeddingIsBYOK && totalEmbeddingTokens > 0 && kb[0].userId) {
|
||||
try {
|
||||
const costMultiplier = getCostMultiplier()
|
||||
const { total: cost } = calculateCost(
|
||||
embeddingModelName,
|
||||
totalEmbeddingTokens,
|
||||
0,
|
||||
false,
|
||||
costMultiplier
|
||||
)
|
||||
if (cost > 0) {
|
||||
await recordUsage({
|
||||
userId: kb[0].userId,
|
||||
workspaceId: kb[0].workspaceId ?? undefined,
|
||||
entries: [
|
||||
{
|
||||
category: 'model',
|
||||
source: 'knowledge-base',
|
||||
description: embeddingModelName,
|
||||
cost,
|
||||
metadata: { inputTokens: totalEmbeddingTokens, outputTokens: 0 },
|
||||
},
|
||||
],
|
||||
additionalStats: {
|
||||
totalTokensUsed: sql`total_tokens_used + ${totalEmbeddingTokens}`,
|
||||
},
|
||||
})
|
||||
await checkAndBillOverageThreshold(kb[0].userId)
|
||||
} else {
|
||||
logger.warn(
|
||||
`[${documentId}] Embedding model "${embeddingModelName}" has no pricing entry — billing skipped`,
|
||||
{ totalEmbeddingTokens, embeddingModelName }
|
||||
)
|
||||
}
|
||||
} catch (billingError) {
|
||||
logger.error(`[${documentId}] Failed to record embedding usage`, { error: billingError })
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
const processingTime = Date.now() - startTime
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
|
||||
|
||||
@@ -35,6 +35,7 @@ interface EmbeddingConfig {
|
||||
apiUrl: string
|
||||
headers: Record<string, string>
|
||||
modelName: string
|
||||
isBYOK: boolean
|
||||
}
|
||||
|
||||
interface EmbeddingResponseItem {
|
||||
@@ -71,16 +72,19 @@ async function getEmbeddingConfig(
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
modelName: kbModelName,
|
||||
isBYOK: false,
|
||||
}
|
||||
}
|
||||
|
||||
let openaiApiKey = env.OPENAI_API_KEY
|
||||
let isBYOK = false
|
||||
|
||||
if (workspaceId) {
|
||||
const byokResult = await getBYOKKey(workspaceId, 'openai')
|
||||
if (byokResult) {
|
||||
logger.info('Using workspace BYOK key for OpenAI embeddings')
|
||||
openaiApiKey = byokResult.apiKey
|
||||
isBYOK = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,12 +102,16 @@ async function getEmbeddingConfig(
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
modelName: embeddingModel,
|
||||
isBYOK,
|
||||
}
|
||||
}
|
||||
|
||||
const EMBEDDING_REQUEST_TIMEOUT_MS = 60_000
|
||||
|
||||
async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Promise<number[][]> {
|
||||
async function callEmbeddingAPI(
|
||||
inputs: string[],
|
||||
config: EmbeddingConfig
|
||||
): Promise<{ embeddings: number[][]; totalTokens: number }> {
|
||||
return retryWithExponentialBackoff(
|
||||
async () => {
|
||||
const useDimensions = supportsCustomDimensions(config.modelName)
|
||||
@@ -140,7 +148,10 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
|
||||
}
|
||||
|
||||
const data: EmbeddingAPIResponse = await response.json()
|
||||
return data.data.map((item) => item.embedding)
|
||||
return {
|
||||
embeddings: data.data.map((item) => item.embedding),
|
||||
totalTokens: data.usage.total_tokens,
|
||||
}
|
||||
},
|
||||
{
|
||||
maxRetries: 3,
|
||||
@@ -178,14 +189,23 @@ async function processWithConcurrency<T, R>(
|
||||
return results
|
||||
}
|
||||
|
||||
export interface GenerateEmbeddingsResult {
|
||||
embeddings: number[][]
|
||||
totalTokens: number
|
||||
isBYOK: boolean
|
||||
modelName: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings for multiple texts with token-aware batching and parallel processing
|
||||
* Generate embeddings for multiple texts with token-aware batching and parallel processing.
|
||||
* Returns embeddings alongside actual token count, model name, and whether a workspace BYOK key
|
||||
* was used (vs. the platform's shared key) — enabling callers to make correct billing decisions.
|
||||
*/
|
||||
export async function generateEmbeddings(
|
||||
texts: string[],
|
||||
embeddingModel = 'text-embedding-3-small',
|
||||
workspaceId?: string | null
|
||||
): Promise<number[][]> {
|
||||
): Promise<GenerateEmbeddingsResult> {
|
||||
const config = await getEmbeddingConfig(embeddingModel, workspaceId)
|
||||
|
||||
const batches = batchByTokenLimit(texts, MAX_TOKENS_PER_REQUEST, embeddingModel)
|
||||
@@ -204,13 +224,20 @@ export async function generateEmbeddings(
|
||||
)
|
||||
|
||||
const allEmbeddings: number[][] = []
|
||||
let totalTokens = 0
|
||||
for (const batch of batchResults) {
|
||||
for (const emb of batch) {
|
||||
for (const emb of batch.embeddings) {
|
||||
allEmbeddings.push(emb)
|
||||
}
|
||||
totalTokens += batch.totalTokens
|
||||
}
|
||||
|
||||
return allEmbeddings
|
||||
return {
|
||||
embeddings: allEmbeddings,
|
||||
totalTokens,
|
||||
isBYOK: config.isBYOK,
|
||||
modelName: config.modelName,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -227,6 +254,6 @@ export async function generateSearchEmbedding(
|
||||
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for search embedding generation`
|
||||
)
|
||||
|
||||
const embeddings = await callEmbeddingAPI([query], config)
|
||||
const { embeddings } = await callEmbeddingAPI([query], config)
|
||||
return embeddings[0]
|
||||
}
|
||||
|
||||
1
packages/db/migrations/0185_new_gravity.sql
Normal file
1
packages/db/migrations/0185_new_gravity.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TYPE "public"."usage_log_source" ADD VALUE 'knowledge-base';
|
||||
14618
packages/db/migrations/meta/0185_snapshot.json
Normal file
14618
packages/db/migrations/meta/0185_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1289,6 +1289,13 @@
|
||||
"when": 1775149654511,
|
||||
"tag": "0184_hard_thaddeus_ross",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 185,
|
||||
"version": "7",
|
||||
"when": 1775247973312,
|
||||
"tag": "0185_new_gravity",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -2273,6 +2273,7 @@ export const usageLogSourceEnum = pgEnum('usage_log_source', [
|
||||
'workspace-chat',
|
||||
'mcp_copilot',
|
||||
'mothership_block',
|
||||
'knowledge-base',
|
||||
])
|
||||
|
||||
export const usageLog = pgTable(
|
||||
|
||||
Reference in New Issue
Block a user