mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-06 03:00:16 -04:00
* improvement(processing): reduce redundant DB queries in execution preprocessing * improvement(processing): add defensive ID check for prefetched workflow record * improvement(processing): fix type safety in execution error logging Replace `as any` cast in non-SSE error path with proper `buildTraceSpans()` transformation, matching the SSE error path. Remove redundant `as any` cast in preprocessing.ts where the types already align. * improvement(processing): replace `as any` casts with proper types in logging - logger.ts: cast JSONB cost column to `WorkflowExecutionLog['cost']` instead of `any` in both `completeWorkflowExecution` and `getWorkflowExecution` - logger.ts: replace `(orgUsageBefore as any)?.toString?.()` with `String()` since COALESCE guarantees a non-null SQL aggregate value - logging-session.ts: cast JSONB cost to `AccumulatedCost` (the local interface) instead of `any` in `loadExistingCost` * improvement(processing): use exported HighestPrioritySubscription type in usage.ts Replace inline `Awaited<ReturnType<typeof getHighestPrioritySubscription>>` with the already-exported `HighestPrioritySubscription` type alias. * improvement(processing): replace remaining `as any` casts with proper types - preprocessing.ts: use exported `HighestPrioritySubscription` type instead of redeclaring via `Awaited<ReturnType<...>>` - deploy/route.ts, status/route.ts: cast `hasWorkflowChanged` args to `WorkflowState` instead of `any` (JSONB + object literal narrowing) - state/route.ts: type block sanitization and save with `BlockState` and `WorkflowState` instead of `any` - search-suggestions.ts: remove 8 unnecessary `as any` casts on `'date'` literal that already satisfies the `Suggestion['category']` union * fix(processing): prevent double-billing race in LoggingSession completion When executeWorkflowCore throws, its catch block fire-and-forgets safeCompleteWithError, then re-throws. 
The caller's catch block also fire-and-forgets safeCompleteWithError on the same LoggingSession. Both check this.completed (still false) before either's async DB write resolves, so both proceed to completeWorkflowExecution which uses additive SQL for billing — doubling the charged cost on every failed execution. Fix: add a synchronous `completing` flag set immediately before the async work begins. This blocks concurrent callers at the guard check. On failure, the flag is reset so the safe* fallback path (completeWithCostOnlyLog) can still attempt recovery. * fix(processing): unblock error responses and isolate run-count failures Remove unnecessary `await waitForCompletion()` from non-SSE and SSE error paths where no `markAsFailed()` follows — these were blocking error responses on log persistence for no reason. Wrap `updateWorkflowRunCounts` in its own try/catch so a run-count DB failure cannot prevent session completion, billing, and trace span persistence. * improvement(processing): remove dead setupExecutor method The method body was just a debug log with an `any` parameter — logging now works entirely through trace spans with no executor integration. * remove logger.debug * fix(processing): guard completionPromise as write-once (singleton promise) Prevent concurrent safeComplete* calls from overwriting completionPromise with a no-op. The guard now lives at the assignment site — if a completion is already in-flight, return its promise instead of starting a new one. This ensures waitForCompletion() always awaits the real work. * improvement(processing): remove empty else/catch blocks left by debug log cleanup * fix(processing): enforce waitForCompletion inside markAsFailed to prevent completion races Move waitForCompletion() into markAsFailed() so every call site is automatically safe against in-flight fire-and-forget completions. Remove the now-redundant external waitForCompletion() calls in route.ts. 
* fix(processing): reset completing flag on fallback failure, clean up empty catch - completeWithCostOnlyLog now resets this.completing = false when the fallback itself fails, preventing a permanently stuck session - Use _disconnectError in MCP test-connection to signal intentional ignore * fix(processing): restore disconnect error logging in MCP test-connection Revert unrelated debug log removal — this file isn't part of the processing improvements and the log aids connection leak detection. * fix(processing): address audit findings across branch - preprocessing.ts: use undefined (not null) for failed subscription fetch so getUserUsageLimit does a fresh lookup instead of silently falling back to free-tier limits - deployed/route.ts: log warning on loadDeployedWorkflowState failure instead of silently swallowing the error - schedule-execution.ts: remove dead successLog parameter and all call-site arguments left over from logger.debug cleanup - mcp/middleware.ts: drop unused error binding in empty catch - audit/log.ts, wand.ts: promote logger.debug to logger.warn in catch blocks where these are the only failure signal * revert: undo unnecessary subscription null→undefined change getHighestPrioritySubscription never throws (it catches internally and returns null), so the catch block in preprocessExecution is dead code. The null vs undefined distinction doesn't matter and the coercions added unnecessary complexity. * improvement(processing): remove dead try/catch around getHighestPrioritySubscription getHighestPrioritySubscription catches internally and returns null on error, so the wrapping try/catch was unreachable dead code. * improvement(processing): remove dead getSnapshotByHash method No longer called after createSnapshotWithDeduplication was refactored to use a single upsert instead of select-then-insert. ---------
491 lines
14 KiB
TypeScript
491 lines
14 KiB
TypeScript
import { db } from '@sim/db'
import { document, embedding } from '@sim/db/schema'
import { and, eq, inArray, isNull, type SQL, sql } from 'drizzle-orm'
import type { StructuredFilter } from '@/lib/knowledge/types'
|
|
|
|
export async function getDocumentNamesByIds(
|
|
documentIds: string[]
|
|
): Promise<Record<string, string>> {
|
|
if (documentIds.length === 0) {
|
|
return {}
|
|
}
|
|
|
|
const uniqueIds = [...new Set(documentIds)]
|
|
const documents = await db
|
|
.select({
|
|
id: document.id,
|
|
filename: document.filename,
|
|
})
|
|
.from(document)
|
|
.where(and(inArray(document.id, uniqueIds), isNull(document.deletedAt)))
|
|
|
|
const documentNameMap: Record<string, string> = {}
|
|
documents.forEach((doc) => {
|
|
documentNameMap[doc.id] = doc.filename
|
|
})
|
|
|
|
return documentNameMap
|
|
}
|
|
|
|
/**
 * A single chunk-level hit returned by the search handlers in this module.
 * Mirrors the columns selected from the `embedding` table (see
 * `getSearchResultFields`) plus a computed `distance` column.
 */
export interface SearchResult {
  /** Embedding row ID. */
  id: string
  /** Raw chunk text. */
  content: string
  /** Parent document of this chunk. */
  documentId: string
  /** Position of the chunk within its document. */
  chunkIndex: number
  // Text tags (7 slots)
  tag1: string | null
  tag2: string | null
  tag3: string | null
  tag4: string | null
  tag5: string | null
  tag6: string | null
  tag7: string | null
  // Number tags (5 slots)
  number1: number | null
  number2: number | null
  number3: number | null
  number4: number | null
  number5: number | null
  // Date tags (2 slots)
  date1: Date | null
  date2: Date | null
  // Boolean tags (3 slots)
  boolean1: boolean | null
  boolean2: boolean | null
  boolean3: boolean | null
  /** Cosine distance to the query vector; a literal 0 for tag-only searches. */
  distance: number
  /** Knowledge base the chunk belongs to. */
  knowledgeBaseId: string
}
|
|
|
|
/** Inputs shared by the tag-only / vector-only / tag+vector search handlers. */
export interface SearchParams {
  /** Knowledge bases to search across. */
  knowledgeBaseIds: string[]
  /** Maximum number of results to return overall. */
  topK: number
  /** Tag filters; required for tag-only and tag+vector searches. */
  structuredFilters?: StructuredFilter[]
  /** String accepted by Postgres' `::vector` cast; required for vector searches. */
  queryVector?: string
  /** Exclusive cosine-distance cutoff; required for vector searches. */
  distanceThreshold?: number
}

// Use shared embedding utility (re-exported so callers have one import point)
export { generateSearchEmbedding } from '@/lib/knowledge/embeddings'
|
|
|
|
/** All valid tag slot keys */
|
|
const TAG_SLOT_KEYS = [
|
|
// Text tags (7 slots)
|
|
'tag1',
|
|
'tag2',
|
|
'tag3',
|
|
'tag4',
|
|
'tag5',
|
|
'tag6',
|
|
'tag7',
|
|
// Number tags (5 slots)
|
|
'number1',
|
|
'number2',
|
|
'number3',
|
|
'number4',
|
|
'number5',
|
|
// Date tags (2 slots)
|
|
'date1',
|
|
'date2',
|
|
// Boolean tags (3 slots)
|
|
'boolean1',
|
|
'boolean2',
|
|
'boolean3',
|
|
] as const
|
|
|
|
type TagSlotKey = (typeof TAG_SLOT_KEYS)[number]
|
|
|
|
function isTagSlotKey(key: string): key is TagSlotKey {
|
|
return TAG_SLOT_KEYS.includes(key as TagSlotKey)
|
|
}
|
|
|
|
/** Common fields selected for search results */
|
|
const getSearchResultFields = (distanceExpr: any) => ({
|
|
id: embedding.id,
|
|
content: embedding.content,
|
|
documentId: embedding.documentId,
|
|
chunkIndex: embedding.chunkIndex,
|
|
// Text tags
|
|
tag1: embedding.tag1,
|
|
tag2: embedding.tag2,
|
|
tag3: embedding.tag3,
|
|
tag4: embedding.tag4,
|
|
tag5: embedding.tag5,
|
|
tag6: embedding.tag6,
|
|
tag7: embedding.tag7,
|
|
// Number tags (5 slots)
|
|
number1: embedding.number1,
|
|
number2: embedding.number2,
|
|
number3: embedding.number3,
|
|
number4: embedding.number4,
|
|
number5: embedding.number5,
|
|
// Date tags (2 slots)
|
|
date1: embedding.date1,
|
|
date2: embedding.date2,
|
|
// Boolean tags (3 slots)
|
|
boolean1: embedding.boolean1,
|
|
boolean2: embedding.boolean2,
|
|
boolean3: embedding.boolean3,
|
|
distance: distanceExpr,
|
|
knowledgeBaseId: embedding.knowledgeBaseId,
|
|
})
|
|
|
|
/**
|
|
* Build a single SQL condition for a filter
|
|
*/
|
|
function buildFilterCondition(filter: StructuredFilter, embeddingTable: any) {
|
|
const { tagSlot, fieldType, operator, value, valueTo } = filter
|
|
|
|
if (!isTagSlotKey(tagSlot)) {
|
|
return null
|
|
}
|
|
|
|
const column = embeddingTable[tagSlot]
|
|
if (!column) return null
|
|
|
|
// Handle text operators
|
|
if (fieldType === 'text') {
|
|
const stringValue = String(value)
|
|
switch (operator) {
|
|
case 'eq':
|
|
return sql`LOWER(${column}) = LOWER(${stringValue})`
|
|
case 'neq':
|
|
return sql`LOWER(${column}) != LOWER(${stringValue})`
|
|
case 'contains':
|
|
return sql`LOWER(${column}) LIKE LOWER(${`%${stringValue}%`})`
|
|
case 'not_contains':
|
|
return sql`LOWER(${column}) NOT LIKE LOWER(${`%${stringValue}%`})`
|
|
case 'starts_with':
|
|
return sql`LOWER(${column}) LIKE LOWER(${`${stringValue}%`})`
|
|
case 'ends_with':
|
|
return sql`LOWER(${column}) LIKE LOWER(${`%${stringValue}`})`
|
|
default:
|
|
return sql`LOWER(${column}) = LOWER(${stringValue})`
|
|
}
|
|
}
|
|
|
|
// Handle number operators
|
|
if (fieldType === 'number') {
|
|
const numValue = typeof value === 'number' ? value : Number.parseFloat(String(value))
|
|
if (Number.isNaN(numValue)) return null
|
|
|
|
switch (operator) {
|
|
case 'eq':
|
|
return sql`${column} = ${numValue}`
|
|
case 'neq':
|
|
return sql`${column} != ${numValue}`
|
|
case 'gt':
|
|
return sql`${column} > ${numValue}`
|
|
case 'gte':
|
|
return sql`${column} >= ${numValue}`
|
|
case 'lt':
|
|
return sql`${column} < ${numValue}`
|
|
case 'lte':
|
|
return sql`${column} <= ${numValue}`
|
|
case 'between':
|
|
if (valueTo !== undefined) {
|
|
const numValueTo =
|
|
typeof valueTo === 'number' ? valueTo : Number.parseFloat(String(valueTo))
|
|
if (Number.isNaN(numValueTo)) return sql`${column} = ${numValue}`
|
|
return sql`${column} >= ${numValue} AND ${column} <= ${numValueTo}`
|
|
}
|
|
return sql`${column} = ${numValue}`
|
|
default:
|
|
return sql`${column} = ${numValue}`
|
|
}
|
|
}
|
|
|
|
// Handle date operators - expects YYYY-MM-DD format from frontend
|
|
if (fieldType === 'date') {
|
|
const dateStr = String(value)
|
|
// Validate YYYY-MM-DD format
|
|
if (!/^\d{4}-\d{2}-\d{2}$/.test(dateStr)) {
|
|
return null
|
|
}
|
|
|
|
switch (operator) {
|
|
case 'eq':
|
|
return sql`${column}::date = ${dateStr}::date`
|
|
case 'neq':
|
|
return sql`${column}::date != ${dateStr}::date`
|
|
case 'gt':
|
|
return sql`${column}::date > ${dateStr}::date`
|
|
case 'gte':
|
|
return sql`${column}::date >= ${dateStr}::date`
|
|
case 'lt':
|
|
return sql`${column}::date < ${dateStr}::date`
|
|
case 'lte':
|
|
return sql`${column}::date <= ${dateStr}::date`
|
|
case 'between':
|
|
if (valueTo !== undefined) {
|
|
const dateStrTo = String(valueTo)
|
|
if (!/^\d{4}-\d{2}-\d{2}$/.test(dateStrTo)) {
|
|
return sql`${column}::date = ${dateStr}::date`
|
|
}
|
|
return sql`${column}::date >= ${dateStr}::date AND ${column}::date <= ${dateStrTo}::date`
|
|
}
|
|
return sql`${column}::date = ${dateStr}::date`
|
|
default:
|
|
return sql`${column}::date = ${dateStr}::date`
|
|
}
|
|
}
|
|
|
|
// Handle boolean operators
|
|
if (fieldType === 'boolean') {
|
|
const boolValue = value === true || value === 'true'
|
|
switch (operator) {
|
|
case 'eq':
|
|
return sql`${column} = ${boolValue}`
|
|
case 'neq':
|
|
return sql`${column} != ${boolValue}`
|
|
default:
|
|
return sql`${column} = ${boolValue}`
|
|
}
|
|
}
|
|
|
|
// Fallback to equality
|
|
return sql`${column} = ${value}`
|
|
}
|
|
|
|
/**
|
|
* Build SQL conditions from structured filters with operator support
|
|
* - Same tag multiple times: OR logic
|
|
* - Different tags: AND logic
|
|
*/
|
|
function getStructuredTagFilters(filters: StructuredFilter[], embeddingTable: any) {
|
|
// Group filters by tagSlot
|
|
const filtersBySlot = new Map<string, StructuredFilter[]>()
|
|
for (const filter of filters) {
|
|
const slot = filter.tagSlot
|
|
if (!filtersBySlot.has(slot)) {
|
|
filtersBySlot.set(slot, [])
|
|
}
|
|
filtersBySlot.get(slot)!.push(filter)
|
|
}
|
|
|
|
// Build conditions: OR within same slot, AND across different slots
|
|
const conditions: ReturnType<typeof sql>[] = []
|
|
|
|
for (const [slot, slotFilters] of filtersBySlot) {
|
|
const slotConditions = slotFilters
|
|
.map((f) => buildFilterCondition(f, embeddingTable))
|
|
.filter((c): c is ReturnType<typeof sql> => c !== null)
|
|
|
|
if (slotConditions.length === 0) continue
|
|
|
|
if (slotConditions.length === 1) {
|
|
// Single condition for this slot
|
|
conditions.push(slotConditions[0])
|
|
} else {
|
|
// Multiple conditions for same slot - OR them together
|
|
conditions.push(sql`(${sql.join(slotConditions, sql` OR `)})`)
|
|
}
|
|
}
|
|
|
|
return conditions
|
|
}
|
|
|
|
export function getQueryStrategy(kbCount: number, topK: number) {
|
|
const useParallel = kbCount > 4 || (kbCount > 2 && topK > 50)
|
|
const distanceThreshold = kbCount > 3 ? 0.8 : 1.0
|
|
const parallelLimit = Math.ceil(topK / kbCount) + 5
|
|
|
|
return {
|
|
useParallel,
|
|
distanceThreshold,
|
|
parallelLimit,
|
|
singleQueryOptimized: kbCount <= 2,
|
|
}
|
|
}
|
|
|
|
async function executeTagFilterQuery(
|
|
knowledgeBaseIds: string[],
|
|
structuredFilters: StructuredFilter[]
|
|
): Promise<{ id: string }[]> {
|
|
const tagFilterConditions = getStructuredTagFilters(structuredFilters, embedding)
|
|
|
|
if (knowledgeBaseIds.length === 1) {
|
|
return await db
|
|
.select({ id: embedding.id })
|
|
.from(embedding)
|
|
.innerJoin(document, eq(embedding.documentId, document.id))
|
|
.where(
|
|
and(
|
|
eq(embedding.knowledgeBaseId, knowledgeBaseIds[0]),
|
|
eq(embedding.enabled, true),
|
|
isNull(document.deletedAt),
|
|
...tagFilterConditions
|
|
)
|
|
)
|
|
}
|
|
return await db
|
|
.select({ id: embedding.id })
|
|
.from(embedding)
|
|
.innerJoin(document, eq(embedding.documentId, document.id))
|
|
.where(
|
|
and(
|
|
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
|
|
eq(embedding.enabled, true),
|
|
isNull(document.deletedAt),
|
|
...tagFilterConditions
|
|
)
|
|
)
|
|
}
|
|
|
|
async function executeVectorSearchOnIds(
|
|
embeddingIds: string[],
|
|
queryVector: string,
|
|
topK: number,
|
|
distanceThreshold: number
|
|
): Promise<SearchResult[]> {
|
|
if (embeddingIds.length === 0) {
|
|
return []
|
|
}
|
|
|
|
return await db
|
|
.select(
|
|
getSearchResultFields(
|
|
sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance')
|
|
)
|
|
)
|
|
.from(embedding)
|
|
.innerJoin(document, eq(embedding.documentId, document.id))
|
|
.where(
|
|
and(
|
|
inArray(embedding.id, embeddingIds),
|
|
isNull(document.deletedAt),
|
|
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
|
|
)
|
|
)
|
|
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
|
|
.limit(topK)
|
|
}
|
|
|
|
export async function handleTagOnlySearch(params: SearchParams): Promise<SearchResult[]> {
|
|
const { knowledgeBaseIds, topK, structuredFilters } = params
|
|
|
|
if (!structuredFilters || structuredFilters.length === 0) {
|
|
throw new Error('Tag filters are required for tag-only search')
|
|
}
|
|
|
|
const strategy = getQueryStrategy(knowledgeBaseIds.length, topK)
|
|
const tagFilterConditions = getStructuredTagFilters(structuredFilters, embedding)
|
|
|
|
if (strategy.useParallel) {
|
|
// Parallel approach for many KBs
|
|
const parallelLimit = Math.ceil(topK / knowledgeBaseIds.length) + 5
|
|
|
|
const queryPromises = knowledgeBaseIds.map(async (kbId) => {
|
|
return await db
|
|
.select(getSearchResultFields(sql<number>`0`.as('distance')))
|
|
.from(embedding)
|
|
.innerJoin(document, eq(embedding.documentId, document.id))
|
|
.where(
|
|
and(
|
|
eq(embedding.knowledgeBaseId, kbId),
|
|
eq(embedding.enabled, true),
|
|
isNull(document.deletedAt),
|
|
...tagFilterConditions
|
|
)
|
|
)
|
|
.limit(parallelLimit)
|
|
})
|
|
|
|
const parallelResults = await Promise.all(queryPromises)
|
|
return parallelResults.flat().slice(0, topK)
|
|
}
|
|
// Single query for fewer KBs
|
|
return await db
|
|
.select(getSearchResultFields(sql<number>`0`.as('distance')))
|
|
.from(embedding)
|
|
.innerJoin(document, eq(embedding.documentId, document.id))
|
|
.where(
|
|
and(
|
|
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
|
|
eq(embedding.enabled, true),
|
|
isNull(document.deletedAt),
|
|
...tagFilterConditions
|
|
)
|
|
)
|
|
.limit(topK)
|
|
}
|
|
|
|
export async function handleVectorOnlySearch(params: SearchParams): Promise<SearchResult[]> {
|
|
const { knowledgeBaseIds, topK, queryVector, distanceThreshold } = params
|
|
|
|
if (!queryVector || !distanceThreshold) {
|
|
throw new Error('Query vector and distance threshold are required for vector-only search')
|
|
}
|
|
|
|
const strategy = getQueryStrategy(knowledgeBaseIds.length, topK)
|
|
|
|
const distanceExpr = sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance')
|
|
|
|
if (strategy.useParallel) {
|
|
// Parallel approach for many KBs
|
|
const parallelLimit = Math.ceil(topK / knowledgeBaseIds.length) + 5
|
|
|
|
const queryPromises = knowledgeBaseIds.map(async (kbId) => {
|
|
return await db
|
|
.select(getSearchResultFields(distanceExpr))
|
|
.from(embedding)
|
|
.innerJoin(document, eq(embedding.documentId, document.id))
|
|
.where(
|
|
and(
|
|
eq(embedding.knowledgeBaseId, kbId),
|
|
eq(embedding.enabled, true),
|
|
isNull(document.deletedAt),
|
|
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
|
|
)
|
|
)
|
|
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
|
|
.limit(parallelLimit)
|
|
})
|
|
|
|
const parallelResults = await Promise.all(queryPromises)
|
|
const allResults = parallelResults.flat()
|
|
return allResults.sort((a, b) => a.distance - b.distance).slice(0, topK)
|
|
}
|
|
// Single query for fewer KBs
|
|
return await db
|
|
.select(getSearchResultFields(distanceExpr))
|
|
.from(embedding)
|
|
.innerJoin(document, eq(embedding.documentId, document.id))
|
|
.where(
|
|
and(
|
|
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
|
|
eq(embedding.enabled, true),
|
|
isNull(document.deletedAt),
|
|
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
|
|
)
|
|
)
|
|
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
|
|
.limit(topK)
|
|
}
|
|
|
|
export async function handleTagAndVectorSearch(params: SearchParams): Promise<SearchResult[]> {
|
|
const { knowledgeBaseIds, topK, structuredFilters, queryVector, distanceThreshold } = params
|
|
|
|
if (!structuredFilters || structuredFilters.length === 0) {
|
|
throw new Error('Tag filters are required for tag and vector search')
|
|
}
|
|
if (!queryVector || !distanceThreshold) {
|
|
throw new Error('Query vector and distance threshold are required for tag and vector search')
|
|
}
|
|
|
|
// Step 1: Filter by tags first
|
|
const tagFilteredIds = await executeTagFilterQuery(knowledgeBaseIds, structuredFilters)
|
|
|
|
if (tagFilteredIds.length === 0) {
|
|
return []
|
|
}
|
|
|
|
// Step 2: Perform vector search only on tag-filtered results
|
|
return await executeVectorSearchOnIds(
|
|
tagFilteredIds.map((r) => r.id),
|
|
queryVector,
|
|
topK,
|
|
distanceThreshold
|
|
)
|
|
}
|