Files
sim/apps/sim/lib/guardrails/validate_hallucination.ts
Vikhyath Mondreti be3cdcf981 Merge pull request #3179 from simstudioai/improvement/file-download-timeouts
improvement(timeouts): files/base64 should use max timeouts + auth centralization
2026-02-10 15:57:06 -08:00

319 lines
8.8 KiB
TypeScript

import { db } from '@sim/db'
import { account } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq } from 'drizzle-orm'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { refreshTokenIfNeeded } from '@/app/api/auth/oauth/utils'
import { executeProviderRequest } from '@/providers'
import { getProviderFromModel } from '@/providers/utils'
// Module-level logger shared by every function in this file; created with the name 'HallucinationValidator'.
const logger = createLogger('HallucinationValidator')
/**
 * Outcome of a hallucination check, as returned by `validateHallucination`.
 */
export interface HallucinationValidationResult {
/** `true` only when the LLM confidence score is greater than or equal to the requested threshold. */
passed: boolean
/** Human-readable failure reason; left undefined when the check passes. */
error?: string
/** Confidence score (0-10) reported by the LLM; absent when validation stopped before a score was produced. */
score?: number
/** The LLM's explanation for the score; absent whenever `score` is absent. */
reasoning?: string
}
/**
 * Everything `validateHallucination` needs to run one check: the text to
 * evaluate, where to look for supporting context, how to reach the scoring
 * model, and how strict to be.
 */
export interface HallucinationValidationInput {
/** Text to evaluate; it is also used verbatim as the knowledge base search query. */
userInput: string
/** ID of the knowledge base searched for supporting context. */
knowledgeBaseId: string
/**
 * Minimum score required to pass (the check is `score >= threshold`).
 * NOTE(review): the default mentioned below is not applied anywhere in this
 * file - presumably the caller supplies it; verify.
 */
threshold: number // 0-10 confidence scale, default 3 (scores below 3 fail)
/**
 * Forwarded as `topK` to the knowledge base search request.
 * NOTE(review): the default mentioned below is not applied anywhere in this
 * file - presumably the caller supplies it; verify.
 */
topK: number // Number of chunks to retrieve, default 10
/** Model identifier; used to resolve the provider and sent with the scoring request. */
model: string
/**
 * API key sent with the scoring request. For the `vertex` provider with a
 * `vertexCredential`, a refreshed access token is used instead when one is obtained.
 */
apiKey?: string
/** Provider-specific settings; all except `vertexCredential` are forwarded unchanged to the scoring request. */
providerCredentials?: {
azureEndpoint?: string
azureApiVersion?: string
vertexProject?: string
vertexLocation?: string
/** ID of the `account` row whose token is refreshed and used as the API key for `vertex` models. */
vertexCredential?: string
bedrockAccessKeyId?: string
bedrockSecretKey?: string
bedrockRegion?: string
}
/** Forwarded in the body of the knowledge base search request. */
workflowId?: string
/** Forwarded to the scoring request. */
workspaceId?: string
/** Credentials relayed to the knowledge base search request as `Cookie` / `Authorization` headers. */
authHeaders?: {
cookie?: string
authorization?: string
}
/** Correlation ID prefixed to every log message emitted during the check. */
requestId: string
}
/**
 * Fetches context chunks for a query from the knowledge base search API
 * (`POST {baseUrl}/api/knowledge/search`).
 *
 * Best-effort by design: a non-OK HTTP status, a network or parse failure, or
 * a payload without a usable `data.results` array is logged and reported as an
 * empty list rather than thrown.
 *
 * @param knowledgeBaseId - Knowledge base to search
 * @param query - Search text (only its first 100 characters are logged)
 * @param topK - Number of results requested from the search API
 * @param requestId - Correlation ID prefixed to every log message
 * @param workflowId - Optional workflow ID forwarded in the request body
 * @param authHeaders - Optional cookie / authorization values forwarded as request headers
 * @returns The non-blank string `content` of each result, in response order
 */
async function queryKnowledgeBase(
  knowledgeBaseId: string,
  query: string,
  topK: number,
  requestId: string,
  workflowId?: string,
  authHeaders?: { cookie?: string; authorization?: string }
): Promise<string[]> {
  try {
    logger.info(`[${requestId}] Querying knowledge base`, {
      knowledgeBaseId,
      query: query.substring(0, 100),
      topK,
    })

    // Call the knowledge base search API directly
    const searchUrl = `${getBaseUrl()}/api/knowledge/search`
    const response = await fetch(searchUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        ...(authHeaders?.cookie ? { Cookie: authHeaders.cookie } : {}),
        ...(authHeaders?.authorization ? { Authorization: authHeaders.authorization } : {}),
      },
      body: JSON.stringify({
        knowledgeBaseIds: [knowledgeBaseId],
        query,
        topK,
        workflowId,
      }),
    })

    if (!response.ok) {
      logger.error(`[${requestId}] Knowledge base query failed`, {
        status: response.status,
      })
      return []
    }

    const payload = (await response.json()) as { data?: { results?: unknown } } | null
    const rawResults = payload?.data?.results

    // Inspect each entry on its own so one malformed result (null, missing or
    // non-string content) cannot discard the valid chunks next to it.
    const chunks: string[] = []
    for (const entry of Array.isArray(rawResults) ? rawResults : []) {
      const content: unknown = (entry as { content?: unknown } | null | undefined)?.content
      // Whitespace-only content carries no context, so it is skipped like ''.
      if (typeof content === 'string' && content.trim().length > 0) {
        chunks.push(content)
      }
    }

    logger.info(`[${requestId}] Retrieved ${chunks.length} chunks from knowledge base`)
    return chunks
  } catch (error: unknown) {
    // Anything can be thrown; only Error instances are guaranteed a message.
    logger.error(`[${requestId}] Error querying knowledge base`, {
      error: error instanceof Error ? error.message : String(error),
    })
    return []
  }
}
/**
 * Parses the JSON payload out of a raw LLM reply.
 *
 * The first attempt is the same as before: the contents of a fenced code
 * block when one wraps a JSON object, otherwise the reply as-is. If that does
 * not parse, the span from the first `{` to the last `}` is tried, which
 * recovers replies where the model wrapped the JSON in prose. When nothing
 * parses, the error from the first attempt is rethrown.
 */
function parseScoreJson(content: string): unknown {
  const fenced = content.includes('```')
    ? content.match(/```(?:json)?\s*(\{[\s\S]*?\})\s*```/)
    : null
  const primary = fenced ? fenced[1] : content

  try {
    return JSON.parse(primary)
  } catch (primaryError) {
    const start = content.indexOf('{')
    const end = content.lastIndexOf('}')
    if (start === -1 || end <= start) {
      throw primaryError
    }
    const candidate = content.slice(start, end + 1)
    if (candidate === primary) {
      throw primaryError
    }
    try {
      return JSON.parse(candidate)
    } catch {
      throw primaryError
    }
  }
}

/**
 * Use an LLM to score confidence based on RAG context
 * Returns a confidence score from 0-10 where:
 * - 0 = full hallucination (completely unsupported)
 * - 10 = fully grounded (completely supported)
 *
 * @param userInput - Text whose grounding is being rated
 * @param ragContext - Context chunks; joined with `---` separators into the prompt
 * @param model - Model identifier; also used to resolve the provider
 * @param apiKey - Provider API key; for `vertex` with a `vertexCredential` it is
 *   replaced by a refreshed access token when one is obtained
 * @param providerCredentials - Provider-specific settings forwarded to the request
 * @param workspaceId - Workspace ID forwarded to the request
 * @param requestId - Correlation ID prefixed to every log message
 * @returns The score (0-10) and the model's reasoning, or 'No reasoning provided'
 *   when the reply carries no non-empty string reasoning
 * @throws Error with message `Failed to score confidence: <cause>` when the
 *   provider call fails, the reply is a stream, or the reply is not a JSON
 *   object with a numeric `score` between 0 and 10
 */
async function scoreHallucinationWithLLM(
  userInput: string,
  ragContext: string[],
  model: string,
  apiKey: string | undefined,
  providerCredentials: HallucinationValidationInput['providerCredentials'],
  workspaceId: string | undefined,
  requestId: string
): Promise<{ score: number; reasoning: string }> {
  try {
    const contextText = ragContext.join('\n\n---\n\n')
    const systemPrompt = `You are a confidence scoring system. Your job is to evaluate how well a user's input is supported by the provided reference context from a knowledge base.
Score the input on a confidence scale from 0 to 10:
- 0-2: Full hallucination - completely unsupported by context, contradicts the context
- 3-4: Low confidence - mostly unsupported, significant claims not in context
- 5-6: Medium confidence - partially supported, some claims not in context
- 7-8: High confidence - mostly supported, minor details not in context
- 9-10: Very high confidence - fully supported by context, all claims verified
Respond ONLY with valid JSON in this exact format:
{
"score": <number between 0-10>,
"reasoning": "<brief explanation of your score>"
}
Do not include any other text, markdown formatting, or code blocks. Only output the raw JSON object. Be strict - only give high scores (7+) if the input is well-supported by the context.`
    const userPrompt = `Reference Context:
${contextText}
User Input to Evaluate:
${userInput}
Evaluate the consistency and provide your score and reasoning in JSON format.`

    logger.info(`[${requestId}] Calling LLM for hallucination scoring`, {
      model,
      contextChunks: ragContext.length,
    })

    const providerId = getProviderFromModel(model)

    // Vertex authenticates with an OAuth access token taken from the stored
    // credential; fall back to the supplied apiKey when no token is obtained.
    let finalApiKey: string | undefined = apiKey
    if (providerId === 'vertex' && providerCredentials?.vertexCredential) {
      const credential = await db.query.account.findFirst({
        where: eq(account.id, providerCredentials.vertexCredential),
      })
      if (credential) {
        const { accessToken } = await refreshTokenIfNeeded(
          requestId,
          credential,
          providerCredentials.vertexCredential
        )
        if (accessToken) {
          finalApiKey = accessToken
        }
      }
    }

    const response = await executeProviderRequest(providerId, {
      model,
      systemPrompt,
      messages: [
        {
          role: 'user',
          content: userPrompt,
        },
      ],
      temperature: 0.1, // Low temperature for consistent scoring
      apiKey: finalApiKey,
      azureEndpoint: providerCredentials?.azureEndpoint,
      azureApiVersion: providerCredentials?.azureApiVersion,
      vertexProject: providerCredentials?.vertexProject,
      vertexLocation: providerCredentials?.vertexLocation,
      bedrockAccessKeyId: providerCredentials?.bedrockAccessKeyId,
      bedrockSecretKey: providerCredentials?.bedrockSecretKey,
      bedrockRegion: providerCredentials?.bedrockRegion,
      workspaceId,
    })

    // Scoring needs the complete reply, so a streaming result is an error.
    if (response instanceof ReadableStream || ('stream' in response && 'execution' in response)) {
      throw new Error('Unexpected streaming response from LLM')
    }

    const content = response.content.trim()
    logger.debug(`[${requestId}] LLM response:`, { content })

    const parsed = parseScoreJson(content)
    // JSON such as `null` or `7` parses fine but is not the requested object.
    if (parsed === null || typeof parsed !== 'object') {
      throw new Error('Invalid score format from LLM')
    }

    const { score, reasoning: rawReasoning } = parsed as { score?: unknown; reasoning?: unknown }
    if (typeof score !== 'number' || score < 0 || score > 10) {
      throw new Error('Invalid score format from LLM')
    }

    // Only a non-empty string satisfies the declared return type.
    const reasoning =
      typeof rawReasoning === 'string' && rawReasoning.length > 0
        ? rawReasoning
        : 'No reasoning provided'

    logger.info(`[${requestId}] Confidence score: ${score}/10`, {
      reasoning,
    })

    return { score, reasoning }
  } catch (error: unknown) {
    // Anything can be thrown; only Error instances are guaranteed a message.
    const message = error instanceof Error ? error.message : String(error)
    logger.error(`[${requestId}] Error scoring with LLM`, {
      error: message,
    })
    throw new Error(`Failed to score confidence: ${message}`)
  }
}
/**
 * Guardrail entry point: decides whether `userInput` is grounded in a
 * knowledge base.
 *
 * Flow: reject blank input or a missing knowledge base ID, retrieve context
 * chunks, have the LLM rate the input against them (0-10), then pass only when
 * the rating is at least `threshold`. Errors raised inside that flow are
 * logged and returned as `{ passed: false, error }` instead of being thrown.
 *
 * @param input - Text to check plus retrieval, model and threshold settings
 * @returns The verdict; `score` and `reasoning` are present only when scoring completed
 */
export async function validateHallucination(
  input: HallucinationValidationInput
): Promise<HallucinationValidationResult> {
  const {
    requestId,
    userInput,
    knowledgeBaseId,
    topK,
    workflowId,
    authHeaders,
    model,
    apiKey,
    providerCredentials,
    workspaceId,
    threshold,
  } = input

  // Every early exit and the catch path share this failure shape.
  const fail = (error: string): HallucinationValidationResult => ({ passed: false, error })

  try {
    const trimmedInput = userInput ? userInput.trim() : ''
    if (trimmedInput.length === 0) {
      return fail('User input is required')
    }
    if (!knowledgeBaseId) {
      return fail('Knowledge base ID is required')
    }

    // Step 1: Query knowledge base with RAG
    const contextChunks = await queryKnowledgeBase(
      knowledgeBaseId,
      userInput,
      topK,
      requestId,
      workflowId,
      authHeaders
    )
    if (contextChunks.length === 0) {
      return fail('No relevant context found in knowledge base')
    }

    // Step 2: Use LLM to score confidence
    const { score, reasoning } = await scoreHallucinationWithLLM(
      userInput,
      contextChunks,
      model,
      apiKey,
      providerCredentials,
      workspaceId,
      requestId
    )
    logger.info(`[${requestId}] Confidence score: ${score}`, {
      reasoning,
      threshold,
    })

    // Step 3: Check against threshold. Lower scores = less confidence = fail validation
    const verdict: HallucinationValidationResult = {
      passed: score >= threshold,
      score,
      reasoning,
      error: undefined,
    }
    if (!verdict.passed) {
      verdict.error = `Low confidence: score ${score}/10 is below threshold ${threshold}`
    }
    return verdict
  } catch (err: any) {
    const reason = err.message
    logger.error(`[${requestId}] Hallucination validation error`, {
      error: reason,
    })
    return fail(`Validation error: ${reason}`)
  }
}