mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-28 03:00:29 -04:00
fix(mothership): Use heartbeat mechanism for chat locks (#4286)
This commit is contained in:
@@ -210,6 +210,7 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
|
||||
|
||||
const abortPoller = startAbortPoller(streamId, abortController, {
|
||||
requestId,
|
||||
chatId,
|
||||
})
|
||||
publisher.startKeepalive()
|
||||
|
||||
|
||||
120
apps/sim/lib/copilot/request/session/abort.test.ts
Normal file
120
apps/sim/lib/copilot/request/session/abort.test.ts
Normal file
@@ -0,0 +1,120 @@
|
||||
/**
|
||||
* @vitest-environment node
|
||||
*/
|
||||
|
||||
import { redisConfigMock, redisConfigMockFns } from '@sim/testing'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { mockHasAbortMarker, mockClearAbortMarker, mockWriteAbortMarker } = vi.hoisted(() => ({
|
||||
mockHasAbortMarker: vi.fn().mockResolvedValue(false),
|
||||
mockClearAbortMarker: vi.fn().mockResolvedValue(undefined),
|
||||
mockWriteAbortMarker: vi.fn().mockResolvedValue(undefined),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/core/config/redis', () => redisConfigMock)
|
||||
vi.mock('@/lib/copilot/request/session/buffer', () => ({
|
||||
hasAbortMarker: mockHasAbortMarker,
|
||||
clearAbortMarker: mockClearAbortMarker,
|
||||
writeAbortMarker: mockWriteAbortMarker,
|
||||
}))
|
||||
vi.mock('@/lib/copilot/request/otel', () => ({
|
||||
withCopilotSpan: (_span: unknown, _attrs: unknown, fn: (span: unknown) => unknown) =>
|
||||
fn({ setAttribute: vi.fn() }),
|
||||
}))
|
||||
|
||||
import { startAbortPoller } from '@/lib/copilot/request/session/abort'
|
||||
|
||||
/**
 * Exercises the chat-stream-lock heartbeat that `startAbortPoller` runs when
 * it is given a `chatId`. All timing is driven by fake timers; the mocked
 * `extendLock` is the only observable.
 */
describe('startAbortPoller heartbeat', () => {
  const extendLockMock = redisConfigMockFns.mockExtendLock

  /**
   * Starts a poller with a fresh AbortController, runs `scenario`, and always
   * clears the interval afterwards, even when an assertion fails.
   */
  const withPoller = async (
    streamId: string,
    options: { chatId?: string },
    scenario: () => Promise<void>
  ): Promise<void> => {
    const handle = startAbortPoller(streamId, new AbortController(), options)
    try {
      await scenario()
    } finally {
      clearInterval(handle)
    }
  }

  beforeEach(() => {
    vi.clearAllMocks()
    vi.useFakeTimers()
    // Defaults: no abort marker is ever seen, and the lock is still owned.
    mockHasAbortMarker.mockResolvedValue(false)
    extendLockMock.mockResolvedValue(true)
  })

  afterEach(() => {
    vi.useRealTimers()
  })

  it('extends the chat stream lock approximately every heartbeat interval', async () => {
    const streamId = 'stream-heartbeat-1'
    const chatId = 'chat-heartbeat-1'

    await withPoller(streamId, { chatId }, async () => {
      // 15s in: no heartbeat yet.
      await vi.advanceTimersByTimeAsync(15_000)
      expect(extendLockMock).not.toHaveBeenCalled()

      // 21s in: exactly one heartbeat, addressed to this chat's lock key.
      await vi.advanceTimersByTimeAsync(6_000)
      expect(extendLockMock).toHaveBeenCalledTimes(1)
      expect(extendLockMock).toHaveBeenLastCalledWith(
        `copilot:chat-stream-lock:${chatId}`,
        streamId,
        60
      )

      // Each further 20s window yields one more heartbeat.
      await vi.advanceTimersByTimeAsync(20_000)
      expect(extendLockMock).toHaveBeenCalledTimes(2)

      await vi.advanceTimersByTimeAsync(20_000)
      expect(extendLockMock).toHaveBeenCalledTimes(3)
    })
  })

  it('does not extend the lock when no chatId is passed (backward compat)', async () => {
    await withPoller('stream-no-chat', {}, async () => {
      await vi.advanceTimersByTimeAsync(90_000)
      expect(extendLockMock).not.toHaveBeenCalled()
    })
  })

  it('retries on the next tick when extendLock throws (no 20s backoff)', async () => {
    // Only the first heartbeat fails; later calls fall back to the default.
    extendLockMock.mockRejectedValueOnce(new Error('redis down'))

    await withPoller('stream-retry', { chatId: 'chat-retry' }, async () => {
      await vi.advanceTimersByTimeAsync(20_000)
      expect(extendLockMock).toHaveBeenCalledTimes(1)

      // One second later the heartbeat is attempted again.
      await vi.advanceTimersByTimeAsync(1_000)
      expect(extendLockMock).toHaveBeenCalledTimes(2)
    })
  })

  it('stops heartbeating after ownership is lost', async () => {
    // First heartbeat reports that the lock is no longer ours.
    extendLockMock.mockResolvedValueOnce(false)

    await withPoller('stream-lost', { chatId: 'chat-lost' }, async () => {
      await vi.advanceTimersByTimeAsync(21_000)
      expect(extendLockMock).toHaveBeenCalledTimes(1)

      // A further minute of ticks must not produce another extendLock call.
      await vi.advanceTimersByTimeAsync(60_000)
      expect(extendLockMock).toHaveBeenCalledTimes(1)
    })
  })
})
|
||||
@@ -5,7 +5,7 @@ import { AbortBackend } from '@/lib/copilot/generated/trace-attribute-values-v1'
|
||||
import { TraceAttr } from '@/lib/copilot/generated/trace-attributes-v1'
|
||||
import { TraceSpan } from '@/lib/copilot/generated/trace-spans-v1'
|
||||
import { withCopilotSpan } from '@/lib/copilot/request/otel'
|
||||
import { acquireLock, getRedisClient, releaseLock } from '@/lib/core/config/redis'
|
||||
import { acquireLock, extendLock, getRedisClient, releaseLock } from '@/lib/core/config/redis'
|
||||
import { AbortReason } from './abort-reason'
|
||||
import { clearAbortMarker, hasAbortMarker, writeAbortMarker } from './buffer'
|
||||
|
||||
@@ -18,7 +18,22 @@ const pendingChatStreams = new Map<
|
||||
>()
|
||||
|
||||
const DEFAULT_ABORT_POLL_MS = 1000
|
||||
const CHAT_STREAM_LOCK_TTL_SECONDS = 2 * 60 * 60
|
||||
|
||||
/**
|
||||
* TTL for the per-chat stream lock. Kept short so that if the Sim pod
|
||||
* holding the lock dies (SIGKILL, OOM, a SIGTERM drain that doesn't
|
||||
* reach the release path), the lock self-heals inside a minute rather
|
||||
* than stranding the chat for hours. A live stream keeps the lock alive
|
||||
* via `CHAT_STREAM_LOCK_HEARTBEAT_INTERVAL_MS` heartbeats.
|
||||
*/
|
||||
const CHAT_STREAM_LOCK_TTL_SECONDS = 60
|
||||
|
||||
/**
|
||||
* Heartbeat cadence for extending the per-chat stream lock. Set to a
|
||||
* third of the TTL so one missed beat still leaves room for recovery
|
||||
* before the lock expires under a still-live stream.
|
||||
*/
|
||||
const CHAT_STREAM_LOCK_HEARTBEAT_INTERVAL_MS = 20_000
|
||||
|
||||
function registerPendingChatStream(chatId: string, streamId: string): void {
|
||||
let resolve!: () => void
|
||||
@@ -262,10 +277,14 @@ const pollingStreams = new Set<string>()
|
||||
export function startAbortPoller(
|
||||
streamId: string,
|
||||
abortController: AbortController,
|
||||
options?: { pollMs?: number; requestId?: string }
|
||||
options?: { pollMs?: number; requestId?: string; chatId?: string }
|
||||
): ReturnType<typeof setInterval> {
|
||||
const pollMs = options?.pollMs ?? DEFAULT_ABORT_POLL_MS
|
||||
const requestId = options?.requestId
|
||||
const chatId = options?.chatId
|
||||
|
||||
let lastHeartbeatAt = Date.now()
|
||||
let heartbeatOwnershipLost = false
|
||||
|
||||
return setInterval(() => {
|
||||
if (pollingStreams.has(streamId)) return
|
||||
@@ -287,6 +306,33 @@ export function startAbortPoller(
|
||||
} finally {
|
||||
pollingStreams.delete(streamId)
|
||||
}
|
||||
|
||||
if (!chatId || heartbeatOwnershipLost) return
|
||||
if (Date.now() - lastHeartbeatAt < CHAT_STREAM_LOCK_HEARTBEAT_INTERVAL_MS) return
|
||||
|
||||
try {
|
||||
const owned = await extendLock(
|
||||
getChatStreamLockKey(chatId),
|
||||
streamId,
|
||||
CHAT_STREAM_LOCK_TTL_SECONDS
|
||||
)
|
||||
lastHeartbeatAt = Date.now()
|
||||
if (!owned) {
|
||||
heartbeatOwnershipLost = true
|
||||
logger.warn('Lost ownership of chat stream lock — stopping heartbeat', {
|
||||
chatId,
|
||||
streamId,
|
||||
...(requestId ? { requestId } : {}),
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Failed to extend chat stream lock TTL', {
|
||||
chatId,
|
||||
streamId,
|
||||
...(requestId ? { requestId } : {}),
|
||||
error: toError(error).message,
|
||||
})
|
||||
}
|
||||
})()
|
||||
}, pollMs)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ vi.mock('ioredis', () => ({
|
||||
|
||||
import {
|
||||
closeRedisConnection,
|
||||
extendLock,
|
||||
getRedisClient,
|
||||
onRedisReconnect,
|
||||
resetForTesting,
|
||||
@@ -120,6 +121,48 @@ describe('redis config', () => {
|
||||
})
|
||||
})
|
||||
|
||||
/**
 * Unit tests for `extendLock`, driven through the mocked Redis `eval`.
 */
describe('extendLock', () => {
  const lockKey = 'copilot:chat-stream-lock:chat-1'
  const value = 'stream-abc'
  const ttlSeconds = 60

  it('returns true when the caller still owns the lock and EXPIRE succeeds', async () => {
    mockRedisInstance.eval.mockResolvedValueOnce(1)

    const extended = await extendLock(lockKey, value, ttlSeconds)

    expect(extended).toBe(true)
    // One key (the lock), followed by the owner value and TTL as script args.
    expect(mockRedisInstance.eval).toHaveBeenCalledWith(
      expect.stringContaining('expire'),
      1,
      lockKey,
      value,
      ttlSeconds
    )
  })

  it('returns false when the value does not match (lock owned by another)', async () => {
    mockRedisInstance.eval.mockResolvedValueOnce(0)

    const extended = await extendLock(lockKey, value, ttlSeconds)

    expect(extended).toBe(false)
  })

  it('returns true as a no-op when Redis is unavailable', async () => {
    vi.resetModules()
    vi.doMock('@/lib/core/config/env', () =>
      createEnvMock({ REDIS_URL: undefined as unknown as string })
    )

    try {
      const { extendLock: extendLockNoRedis } = await import('@/lib/core/config/redis')

      const extended = await extendLockNoRedis(lockKey, value, ttlSeconds)

      expect(extended).toBe(true)
    } finally {
      // Always unregister the env mock: if the assertion above fails, a
      // leftover doMock would leak into later dynamic imports in this file.
      vi.doUnmock('@/lib/core/config/env')
    }
  })
})
|
||||
|
||||
describe('retryStrategy', () => {
|
||||
function captureRetryStrategy(): (times: number) => number {
|
||||
let capturedConfig: Record<string, unknown> = {}
|
||||
|
||||
@@ -136,6 +136,21 @@ else
|
||||
end
|
||||
`
|
||||
|
||||
/**
|
||||
* Lua script for safe lock TTL extension.
|
||||
* Only refreshes the expiry if the value matches (ownership verification),
|
||||
* so a stale heartbeat from a prior owner cannot extend a lock currently
|
||||
* held by someone else after a TTL eviction.
|
||||
* Returns 1 if the TTL was extended, 0 if not (value mismatch or key gone).
|
||||
*/
|
||||
const EXTEND_LOCK_SCRIPT = `
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("expire", KEYS[1], ARGV[2])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`
|
||||
|
||||
/**
|
||||
* Acquire a distributed lock using Redis SET NX.
|
||||
* Returns true if lock acquired, false if already held.
|
||||
@@ -175,6 +190,29 @@ export async function releaseLock(lockKey: string, value: string): Promise<boole
|
||||
return result === 1
|
||||
}
|
||||
|
||||
/**
 * Extend the TTL of a distributed lock if still owned by the caller.
 * Returns true if the caller still owns the lock and the TTL was refreshed,
 * false if the lock has been taken over by another owner or has expired.
 *
 * When Redis is not available, returns true (no-op) to match the behavior
 * of `acquireLock` / `releaseLock`: single-replica deployments without
 * Redis never held a real lock, so heartbeat success is implicit.
 *
 * @param lockKey - Redis key of the lock.
 * @param value - Owner token; the TTL is refreshed only if the key currently holds this value.
 * @param expirySeconds - New TTL in whole seconds; must be a positive integer.
 * @returns Whether the lock is still owned by `value` and its TTL was refreshed.
 * @throws RangeError if `expirySeconds` is not a positive integer.
 */
export async function extendLock(
  lockKey: string,
  value: string,
  expirySeconds: number
): Promise<boolean> {
  // Redis EXPIRE with a non-positive timeout deletes the key instead of
  // extending it, which would silently drop the lock while reporting success;
  // a fractional timeout is rejected by Redis with an opaque error. Fail fast
  // here, before any Redis access, so misuse surfaces in every environment.
  if (!Number.isInteger(expirySeconds) || expirySeconds <= 0) {
    throw new RangeError(
      `extendLock: expirySeconds must be a positive integer, got ${expirySeconds}`
    )
  }

  const redis = getRedisClient()
  if (!redis) {
    return true
  }

  // The script replies 1 only when the stored value matched and EXPIRE applied.
  const result = await redis.eval(EXTEND_LOCK_SCRIPT, 1, lockKey, value, expirySeconds)
  return result === 1
}
|
||||
|
||||
/**
|
||||
* Close the Redis connection.
|
||||
* Use for graceful shutdown.
|
||||
|
||||
@@ -17,6 +17,7 @@ export const redisConfigMockFns = {
|
||||
mockOnRedisReconnect: vi.fn(),
|
||||
mockAcquireLock: vi.fn().mockResolvedValue(true),
|
||||
mockReleaseLock: vi.fn().mockResolvedValue(true),
|
||||
mockExtendLock: vi.fn().mockResolvedValue(true),
|
||||
mockCloseRedisConnection: vi.fn().mockResolvedValue(undefined),
|
||||
mockResetForTesting: vi.fn(),
|
||||
}
|
||||
@@ -34,6 +35,7 @@ export const redisConfigMock = {
|
||||
onRedisReconnect: redisConfigMockFns.mockOnRedisReconnect,
|
||||
acquireLock: redisConfigMockFns.mockAcquireLock,
|
||||
releaseLock: redisConfigMockFns.mockReleaseLock,
|
||||
extendLock: redisConfigMockFns.mockExtendLock,
|
||||
closeRedisConnection: redisConfigMockFns.mockCloseRedisConnection,
|
||||
resetForTesting: redisConfigMockFns.mockResetForTesting,
|
||||
}
|
||||
|
||||
@@ -56,6 +56,9 @@ export function createMockRedis() {
|
||||
exec: vi.fn().mockResolvedValue([]),
|
||||
})),
|
||||
|
||||
// Scripting
|
||||
eval: vi.fn().mockResolvedValue(0),
|
||||
|
||||
// Connection
|
||||
ping: vi.fn().mockResolvedValue('PONG'),
|
||||
quit: vi.fn().mockResolvedValue('OK'),
|
||||
|
||||
Reference in New Issue
Block a user