Mirror of https://github.com/simstudioai/sim.git, synced 2026-02-18 10:22:00 -05:00
* fix(azure): add azure-anthropic support to router, evaluator, copilot, and tokenization
* added azure anthropic values to env
* fix(azure): make anthropic-version configurable for azure-anthropic provider
* fix(azure): thread provider credentials through guardrails and fix translate missing bedrockAccessKeyId
* updated guardrails
* ack'd PR comments
* fix(azure): unify credential passing pattern across all LLM handlers
  - Pass all provider credentials unconditionally in router, evaluator (matching agent pattern)
  - Remove conditional if-branching on providerId for credential fields
  - Thread workspaceId through guardrails → hallucination validator for BYOK key resolution
  - Remove getApiKey() from hallucination validator, let executeProviderRequest handle it
  - Resolve vertex OAuth credentials in hallucination validator matching agent handler pattern

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
342 lines · 8.6 KiB · TypeScript
/**
 * Token estimation and accurate counting functions for different providers
 */

import { createLogger } from '@sim/logger'
import { encodingForModel, type Tiktoken } from 'js-tiktoken'
import { MIN_TEXT_LENGTH_FOR_ESTIMATION, TOKENIZATION_CONFIG } from '@/lib/tokenization/constants'
import type { TokenEstimate } from '@/lib/tokenization/types'
import { getProviderConfig } from '@/lib/tokenization/utils'

const logger = createLogger('TokenizationEstimators')

const encodingCache = new Map<string, Tiktoken>()

/**
 * Get or create a cached encoding for a model
 */
function getEncoding(modelName: string): Tiktoken {
  if (encodingCache.has(modelName)) {
    return encodingCache.get(modelName)!
  }

  try {
    const encoding = encodingForModel(modelName as Parameters<typeof encodingForModel>[0])
    encodingCache.set(modelName, encoding)
    return encoding
  } catch (error) {
    logger.warn(`Failed to get encoding for model ${modelName}, falling back to cl100k_base`)
    const encoding = encodingForModel('gpt-4')
    encodingCache.set(modelName, encoding)
    return encoding
  }
}

if (typeof process !== 'undefined') {
  process.on('beforeExit', () => {
    clearEncodingCache()
  })
}

/**
 * Get accurate token count for text using tiktoken
 * This is the exact count OpenAI's API will use
 */
export function getAccurateTokenCount(text: string, modelName = 'text-embedding-3-small'): number {
  if (!text || text.length === 0) {
    return 0
  }

  try {
    const encoding = getEncoding(modelName)
    const tokens = encoding.encode(text)
    return tokens.length
  } catch (error) {
    logger.error('Error counting tokens with tiktoken:', error)
    return Math.ceil(text.length / 4)
  }
}
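
// Illustrative usage sketch (not part of the original file; the input string and
// model name are assumptions for the example):
//
//   const tokenCount = getAccurateTokenCount('Some document text', 'text-embedding-3-small')
//   // tokenCount is the exact tiktoken count; if the encoding cannot be loaded,
//   // the function falls back to the rough chars/4 heuristic.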

/**
 * Get individual tokens as strings for visualization
 * Returns an array of token strings that can be displayed with colors
 */
export function getTokenStrings(text: string, modelName = 'text-embedding-3-small'): string[] {
  if (!text || text.length === 0) {
    return []
  }

  try {
    const encoding = getEncoding(modelName)
    const tokenIds = encoding.encode(text)

    const textChars = [...text]
    const result: string[] = []
    let prevCharCount = 0

    for (let i = 0; i < tokenIds.length; i++) {
      // Decode the prefix of tokens seen so far and diff the character count to find
      // how many characters this token covers, then slice those characters from the
      // original text so multi-byte sequences are not mangled by partial decodes.
      const decoded = encoding.decode(tokenIds.slice(0, i + 1))
      const currentCharCount = [...decoded].length
      const tokenCharCount = currentCharCount - prevCharCount

      const tokenStr = textChars.slice(prevCharCount, prevCharCount + tokenCharCount).join('')
      result.push(tokenStr)
      prevCharCount = currentCharCount
    }

    return result
  } catch (error) {
    logger.error('Error getting token strings:', error)
    return text.split(/(\s+)/).filter((s) => s.length > 0)
  }
}
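
// Illustrative usage sketch (not part of the original file; the input string and the
// exact split shown are assumptions):
//
//   getTokenStrings('Tokenization is fun')
//   // -> an array like ['Token', 'ization', ' is', ' fun'], one entry per tiktoken
//   //    token, whose pieces join back to the original text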

/**
 * Truncate text to a maximum token count
 * Useful for handling texts that exceed model limits
 */
export function truncateToTokenLimit(
  text: string,
  maxTokens: number,
  modelName = 'text-embedding-3-small'
): string {
  if (!text || maxTokens <= 0) {
    return ''
  }

  try {
    const encoding = getEncoding(modelName)
    const tokens = encoding.encode(text)

    if (tokens.length <= maxTokens) {
      return text
    }

    const truncatedTokens = tokens.slice(0, maxTokens)
    const truncatedText = encoding.decode(truncatedTokens)

    logger.warn(
      `Truncated text from ${tokens.length} to ${maxTokens} tokens (${text.length} to ${truncatedText.length} chars)`
    )

    return truncatedText
  } catch (error) {
    logger.error('Error truncating text:', error)
    const maxChars = maxTokens * 4
    return text.slice(0, maxChars)
  }
}
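
// Illustrative usage sketch (not part of the original file; the 8191-token limit is an
// assumption for the example and is not defined by this module):
//
//   const safeInput = truncateToTokenLimit(longDocument, 8191, 'text-embedding-3-small')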

/**
 * Batch texts by token count to stay within API limits
 * Returns array of batches where each batch's total tokens <= maxTokensPerBatch
 */
export function batchByTokenLimit(
  texts: string[],
  maxTokensPerBatch: number,
  modelName = 'text-embedding-3-small'
): string[][] {
  const batches: string[][] = []
  let currentBatch: string[] = []
  let currentTokenCount = 0

  for (const text of texts) {
    const tokenCount = getAccurateTokenCount(text, modelName)

    if (tokenCount > maxTokensPerBatch) {
      if (currentBatch.length > 0) {
        batches.push(currentBatch)
        currentBatch = []
        currentTokenCount = 0
      }

      const truncated = truncateToTokenLimit(text, maxTokensPerBatch, modelName)
      batches.push([truncated])
      continue
    }

    if (currentBatch.length > 0 && currentTokenCount + tokenCount > maxTokensPerBatch) {
      batches.push(currentBatch)
      currentBatch = [text]
      currentTokenCount = tokenCount
    } else {
      currentBatch.push(text)
      currentTokenCount += tokenCount
    }
  }

  if (currentBatch.length > 0) {
    batches.push(currentBatch)
  }

  return batches
}
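
// Illustrative usage sketch (not part of the original file; the chunk array and the
// 8000-token limit are assumptions):
//
//   const batches = batchByTokenLimit(documentChunks, 8000)
//   // Each inner array stays within the limit; a single chunk that by itself exceeds
//   // the limit is truncated and emitted as its own batch.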

/**
 * Clean up cached encodings (call when shutting down)
 */
export function clearEncodingCache(): void {
  encodingCache.clear()
  logger.info('Cleared tiktoken encoding cache')
}

/**
 * Estimates token count for text using provider-specific heuristics
 */
export function estimateTokenCount(text: string, providerId?: string): TokenEstimate {
  if (!text || text.length < MIN_TEXT_LENGTH_FOR_ESTIMATION) {
    return {
      count: 0,
      confidence: 'high',
      provider: providerId || 'unknown',
      method: 'fallback',
    }
  }

  const effectiveProviderId = providerId || TOKENIZATION_CONFIG.defaults.provider
  const config = getProviderConfig(effectiveProviderId)

  let estimatedTokens: number

  switch (effectiveProviderId) {
    case 'openai':
    case 'azure-openai':
      estimatedTokens = estimateOpenAITokens(text)
      break
    case 'anthropic':
    case 'azure-anthropic':
      estimatedTokens = estimateAnthropicTokens(text)
      break
    case 'google':
      estimatedTokens = estimateGoogleTokens(text)
      break
    default:
      estimatedTokens = estimateGenericTokens(text, config.avgCharsPerToken)
  }

  return {
    count: Math.max(1, Math.round(estimatedTokens)),
    confidence: config.confidence,
    provider: effectiveProviderId,
    method: 'heuristic',
  }
}
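
// Illustrative usage sketch (not part of the original file; the input text is an
// assumption, the provider id matches one of the switch cases above):
//
//   const estimate = estimateTokenCount('Summarize the quarterly report in two sentences.', 'azure-anthropic')
//   // -> { count, confidence, provider: 'azure-anthropic', method: 'heuristic' },
//   //    provided the text is at least MIN_TEXT_LENGTH_FOR_ESTIMATION characters long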

/**
 * OpenAI-specific token estimation using BPE characteristics
 */
function estimateOpenAITokens(text: string): number {
  const words = text.trim().split(/\s+/)
  let tokenCount = 0

  for (const word of words) {
    if (word.length === 0) continue

    if (word.length <= 4) {
      tokenCount += 1
    } else if (word.length <= 8) {
      tokenCount += Math.ceil(word.length / 4.5)
    } else {
      tokenCount += Math.ceil(word.length / 4)
    }

    const punctuationCount = (word.match(/[.,!?;:"'()[\]{}<>]/g) || []).length
    tokenCount += punctuationCount * 0.5
  }

  const newlineCount = (text.match(/\n/g) || []).length
  tokenCount += newlineCount * 0.5

  return tokenCount
}
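
// Worked example of the heuristic above (not part of the original file): a 20-character
// word such as 'internationalization' falls in the > 8 branch and contributes
// Math.ceil(20 / 4) = 5 tokens, while a short word like 'the' contributes exactly 1.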

/**
 * Anthropic Claude-specific token estimation
 */
function estimateAnthropicTokens(text: string): number {
  const words = text.trim().split(/\s+/)
  let tokenCount = 0

  for (const word of words) {
    if (word.length === 0) continue

    if (word.length <= 4) {
      tokenCount += 1
    } else if (word.length <= 8) {
      tokenCount += Math.ceil(word.length / 5)
    } else {
      tokenCount += Math.ceil(word.length / 4.5)
    }
  }

  const newlineCount = (text.match(/\n/g) || []).length
  tokenCount += newlineCount * 0.3

  return tokenCount
}

/**
 * Google Gemini-specific token estimation
 */
function estimateGoogleTokens(text: string): number {
  const words = text.trim().split(/\s+/)
  let tokenCount = 0

  for (const word of words) {
    if (word.length === 0) continue

    if (word.length <= 5) {
      tokenCount += 1
    } else if (word.length <= 10) {
      tokenCount += Math.ceil(word.length / 6)
    } else {
      tokenCount += Math.ceil(word.length / 5)
    }
  }

  return tokenCount
}

/**
 * Generic token estimation fallback
 */
function estimateGenericTokens(text: string, avgCharsPerToken: number): number {
  const charCount = text.trim().length
  return Math.ceil(charCount / avgCharsPerToken)
}

/**
 * Estimates tokens for input content including context
 */
export function estimateInputTokens(
  systemPrompt?: string,
  context?: string,
  messages?: Array<{ role: string; content: string }>,
  providerId?: string
): TokenEstimate {
  let totalText = ''

  if (systemPrompt) {
    totalText += `${systemPrompt}\n`
  }

  if (context) {
    totalText += `${context}\n`
  }

  if (messages) {
    for (const message of messages) {
      totalText += `${message.role}: ${message.content}\n`
    }
  }

  return estimateTokenCount(totalText, providerId)
}
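
// Illustrative usage sketch (not part of the original file; the prompt, message, and
// provider values are assumptions):
//
//   const inputEstimate = estimateInputTokens(
//     'You are a helpful assistant.',
//     undefined,
//     [{ role: 'user', content: 'Explain BPE tokenization.' }],
//     'openai'
//   )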

/**
 * Estimates tokens for output content
 */
export function estimateOutputTokens(content: string, providerId?: string): TokenEstimate {
  return estimateTokenCount(content, providerId)
}