fix(mothership): Use heartbeat mechanism for chat locks (#4286)

This commit is contained in:
Theodore Li
2026-04-24 11:36:50 -07:00
committed by GitHub
parent ccb5f1e690
commit 04f1d015f3
7 changed files with 256 additions and 3 deletions

View File

@@ -210,6 +210,7 @@ export function createSSEStream(params: StreamingOrchestrationParams): ReadableS
const abortPoller = startAbortPoller(streamId, abortController, {
requestId,
chatId,
})
publisher.startKeepalive()

View File

@@ -0,0 +1,120 @@
/**
* @vitest-environment node
*/
import { redisConfigMock, redisConfigMockFns } from '@sim/testing'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
const { mockHasAbortMarker, mockClearAbortMarker, mockWriteAbortMarker } = vi.hoisted(() => ({
mockHasAbortMarker: vi.fn().mockResolvedValue(false),
mockClearAbortMarker: vi.fn().mockResolvedValue(undefined),
mockWriteAbortMarker: vi.fn().mockResolvedValue(undefined),
}))
vi.mock('@/lib/core/config/redis', () => redisConfigMock)
vi.mock('@/lib/copilot/request/session/buffer', () => ({
hasAbortMarker: mockHasAbortMarker,
clearAbortMarker: mockClearAbortMarker,
writeAbortMarker: mockWriteAbortMarker,
}))
vi.mock('@/lib/copilot/request/otel', () => ({
withCopilotSpan: (_span: unknown, _attrs: unknown, fn: (span: unknown) => unknown) =>
fn({ setAttribute: vi.fn() }),
}))
import { startAbortPoller } from '@/lib/copilot/request/session/abort'
describe('startAbortPoller heartbeat', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.useFakeTimers()
mockHasAbortMarker.mockResolvedValue(false)
redisConfigMockFns.mockExtendLock.mockResolvedValue(true)
})
afterEach(() => {
vi.useRealTimers()
})
it('extends the chat stream lock approximately every heartbeat interval', async () => {
const controller = new AbortController()
const streamId = 'stream-heartbeat-1'
const chatId = 'chat-heartbeat-1'
const interval = startAbortPoller(streamId, controller, { chatId })
try {
await vi.advanceTimersByTimeAsync(15_000)
expect(redisConfigMockFns.mockExtendLock).not.toHaveBeenCalled()
await vi.advanceTimersByTimeAsync(6_000)
expect(redisConfigMockFns.mockExtendLock).toHaveBeenCalledTimes(1)
expect(redisConfigMockFns.mockExtendLock).toHaveBeenLastCalledWith(
`copilot:chat-stream-lock:${chatId}`,
streamId,
60
)
await vi.advanceTimersByTimeAsync(20_000)
expect(redisConfigMockFns.mockExtendLock).toHaveBeenCalledTimes(2)
await vi.advanceTimersByTimeAsync(20_000)
expect(redisConfigMockFns.mockExtendLock).toHaveBeenCalledTimes(3)
} finally {
clearInterval(interval)
}
})
it('does not extend the lock when no chatId is passed (backward compat)', async () => {
const controller = new AbortController()
const interval = startAbortPoller('stream-no-chat', controller, {})
try {
await vi.advanceTimersByTimeAsync(90_000)
expect(redisConfigMockFns.mockExtendLock).not.toHaveBeenCalled()
} finally {
clearInterval(interval)
}
})
it('retries on the next tick when extendLock throws (no 20s backoff)', async () => {
const controller = new AbortController()
const streamId = 'stream-retry'
const chatId = 'chat-retry'
redisConfigMockFns.mockExtendLock.mockRejectedValueOnce(new Error('redis down'))
const interval = startAbortPoller(streamId, controller, { chatId })
try {
await vi.advanceTimersByTimeAsync(20_000)
expect(redisConfigMockFns.mockExtendLock).toHaveBeenCalledTimes(1)
await vi.advanceTimersByTimeAsync(1_000)
expect(redisConfigMockFns.mockExtendLock).toHaveBeenCalledTimes(2)
} finally {
clearInterval(interval)
}
})
it('stops heartbeating after ownership is lost', async () => {
const controller = new AbortController()
const streamId = 'stream-lost'
const chatId = 'chat-lost'
redisConfigMockFns.mockExtendLock.mockResolvedValueOnce(false)
const interval = startAbortPoller(streamId, controller, { chatId })
try {
await vi.advanceTimersByTimeAsync(21_000)
expect(redisConfigMockFns.mockExtendLock).toHaveBeenCalledTimes(1)
await vi.advanceTimersByTimeAsync(60_000)
expect(redisConfigMockFns.mockExtendLock).toHaveBeenCalledTimes(1)
} finally {
clearInterval(interval)
}
})
})

View File

@@ -5,7 +5,7 @@ import { AbortBackend } from '@/lib/copilot/generated/trace-attribute-values-v1'
import { TraceAttr } from '@/lib/copilot/generated/trace-attributes-v1'
import { TraceSpan } from '@/lib/copilot/generated/trace-spans-v1'
import { withCopilotSpan } from '@/lib/copilot/request/otel'
import { acquireLock, getRedisClient, releaseLock } from '@/lib/core/config/redis'
import { acquireLock, extendLock, getRedisClient, releaseLock } from '@/lib/core/config/redis'
import { AbortReason } from './abort-reason'
import { clearAbortMarker, hasAbortMarker, writeAbortMarker } from './buffer'
@@ -18,7 +18,22 @@ const pendingChatStreams = new Map<
>()
const DEFAULT_ABORT_POLL_MS = 1000
const CHAT_STREAM_LOCK_TTL_SECONDS = 2 * 60 * 60
/**
* TTL for the per-chat stream lock. Kept short so that if the Sim pod
* holding the lock dies (SIGKILL, OOM, a SIGTERM drain that doesn't
* reach the release path), the lock self-heals inside a minute rather
* than stranding the chat for hours. A live stream keeps the lock alive
* via `CHAT_STREAM_LOCK_HEARTBEAT_INTERVAL_MS` heartbeats.
*/
const CHAT_STREAM_LOCK_TTL_SECONDS = 60
/**
* Heartbeat cadence for extending the per-chat stream lock. Set to a
* third of the TTL so one missed beat still leaves room for recovery
* before the lock expires under a still-live stream.
*/
const CHAT_STREAM_LOCK_HEARTBEAT_INTERVAL_MS = 20_000
function registerPendingChatStream(chatId: string, streamId: string): void {
let resolve!: () => void
@@ -262,10 +277,14 @@ const pollingStreams = new Set<string>()
export function startAbortPoller(
streamId: string,
abortController: AbortController,
options?: { pollMs?: number; requestId?: string }
options?: { pollMs?: number; requestId?: string; chatId?: string }
): ReturnType<typeof setInterval> {
const pollMs = options?.pollMs ?? DEFAULT_ABORT_POLL_MS
const requestId = options?.requestId
const chatId = options?.chatId
let lastHeartbeatAt = Date.now()
let heartbeatOwnershipLost = false
return setInterval(() => {
if (pollingStreams.has(streamId)) return
@@ -287,6 +306,33 @@ export function startAbortPoller(
} finally {
pollingStreams.delete(streamId)
}
if (!chatId || heartbeatOwnershipLost) return
if (Date.now() - lastHeartbeatAt < CHAT_STREAM_LOCK_HEARTBEAT_INTERVAL_MS) return
try {
const owned = await extendLock(
getChatStreamLockKey(chatId),
streamId,
CHAT_STREAM_LOCK_TTL_SECONDS
)
lastHeartbeatAt = Date.now()
if (!owned) {
heartbeatOwnershipLost = true
logger.warn('Lost ownership of chat stream lock — stopping heartbeat', {
chatId,
streamId,
...(requestId ? { requestId } : {}),
})
}
} catch (error) {
logger.warn('Failed to extend chat stream lock TTL', {
chatId,
streamId,
...(requestId ? { requestId } : {}),
error: toError(error).message,
})
}
})()
}, pollMs)
}

View File

@@ -15,6 +15,7 @@ vi.mock('ioredis', () => ({
import {
closeRedisConnection,
extendLock,
getRedisClient,
onRedisReconnect,
resetForTesting,
@@ -120,6 +121,48 @@ describe('redis config', () => {
})
})
describe('extendLock', () => {
const lockKey = 'copilot:chat-stream-lock:chat-1'
const value = 'stream-abc'
const ttlSeconds = 60
it('returns true when the caller still owns the lock and EXPIRE succeeds', async () => {
mockRedisInstance.eval.mockResolvedValueOnce(1)
const extended = await extendLock(lockKey, value, ttlSeconds)
expect(extended).toBe(true)
expect(mockRedisInstance.eval).toHaveBeenCalledWith(
expect.stringContaining('expire'),
1,
lockKey,
value,
ttlSeconds
)
})
it('returns false when the value does not match (lock owned by another)', async () => {
mockRedisInstance.eval.mockResolvedValueOnce(0)
const extended = await extendLock(lockKey, value, ttlSeconds)
expect(extended).toBe(false)
})
it('returns true as a no-op when Redis is unavailable', async () => {
vi.resetModules()
vi.doMock('@/lib/core/config/env', () =>
createEnvMock({ REDIS_URL: undefined as unknown as string })
)
const { extendLock: extendLockNoRedis } = await import('@/lib/core/config/redis')
const extended = await extendLockNoRedis(lockKey, value, ttlSeconds)
expect(extended).toBe(true)
vi.doUnmock('@/lib/core/config/env')
})
})
describe('retryStrategy', () => {
function captureRetryStrategy(): (times: number) => number {
let capturedConfig: Record<string, unknown> = {}

View File

@@ -136,6 +136,21 @@ else
end
`
/**
* Lua script for safe lock TTL extension.
* Only refreshes the expiry if the value matches (ownership verification),
* so a stale heartbeat from a prior owner cannot extend a lock currently
* held by someone else after a TTL eviction.
* Returns 1 if the TTL was extended, 0 if not (value mismatch or key gone).
*/
const EXTEND_LOCK_SCRIPT = `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end
`
/**
* Acquire a distributed lock using Redis SET NX.
* Returns true if lock acquired, false if already held.
@@ -175,6 +190,29 @@ export async function releaseLock(lockKey: string, value: string): Promise<boole
return result === 1
}
/**
* Extend the TTL of a distributed lock if still owned by the caller.
* Returns true if the caller still owns the lock and the TTL was refreshed,
* false if the lock has been taken over by another owner or has expired.
*
* When Redis is not available, returns true (no-op) to match the behavior
* of `acquireLock` / `releaseLock`: single-replica deployments without
* Redis never held a real lock, so heartbeat success is implicit.
*/
export async function extendLock(
lockKey: string,
value: string,
expirySeconds: number
): Promise<boolean> {
const redis = getRedisClient()
if (!redis) {
return true
}
const result = await redis.eval(EXTEND_LOCK_SCRIPT, 1, lockKey, value, expirySeconds)
return result === 1
}
/**
* Close the Redis connection.
* Use for graceful shutdown.

View File

@@ -17,6 +17,7 @@ export const redisConfigMockFns = {
mockOnRedisReconnect: vi.fn(),
mockAcquireLock: vi.fn().mockResolvedValue(true),
mockReleaseLock: vi.fn().mockResolvedValue(true),
mockExtendLock: vi.fn().mockResolvedValue(true),
mockCloseRedisConnection: vi.fn().mockResolvedValue(undefined),
mockResetForTesting: vi.fn(),
}
@@ -34,6 +35,7 @@ export const redisConfigMock = {
onRedisReconnect: redisConfigMockFns.mockOnRedisReconnect,
acquireLock: redisConfigMockFns.mockAcquireLock,
releaseLock: redisConfigMockFns.mockReleaseLock,
extendLock: redisConfigMockFns.mockExtendLock,
closeRedisConnection: redisConfigMockFns.mockCloseRedisConnection,
resetForTesting: redisConfigMockFns.mockResetForTesting,
}

View File

@@ -56,6 +56,9 @@ export function createMockRedis() {
exec: vi.fn().mockResolvedValue([]),
})),
// Scripting
eval: vi.fn().mockResolvedValue(0),
// Connection
ping: vi.fn().mockResolvedValue('PONG'),
quit: vi.fn().mockResolvedValue('OK'),