Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-07 05:15:09 -05:00)

Compare commits

3 Commits: dev ... pwuts/secr

| Author | SHA1 | Date |
|---|---|---|
|  | 9a41d5c1b8 |  |
|  | 86df73abbe |  |
|  | df6f78f74c |  |
@@ -45,10 +45,7 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
return await PrismaChatSession.prisma().create(data=data)


async def update_chat_session(

@@ -266,12 +266,24 @@ async def stream_chat_post(

"""
import asyncio
import time

stream_start_time = time.perf_counter()
logger.info(
f"[TIMING] stream_chat_post STARTED for session={session_id}, "
f"message_len={len(request.message)}"
)

session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] stream_chat_post session validated in "
f"{(time.perf_counter() - stream_start_time)*1000:.1f}ms"
)

# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
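
The hunks above and below repeat one instrumentation pattern: capture `time.perf_counter()` before an awaited call, then `logger.info` a `[TIMING]` line with the elapsed milliseconds. A minimal sketch of how that boilerplate could be factored into a reusable helper is shown below; `log_timing` is a hypothetical name for illustration, not something this diff adds.

```python
# Hypothetical helper illustrating the repeated [TIMING] pattern in this diff.
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def log_timing(label: str, **fields):
    """Log the wall-clock duration of the wrapped block in milliseconds."""
    start = time.perf_counter()
    try:
        yield
    finally:
        extra = ", ".join(f"{k}={v}" for k, v in fields.items())
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(f"[TIMING] {label} took {elapsed_ms:.1f}ms" + (f", {extra}" if extra else ""))


# Usage sketch:
# with log_timing("stream_registry.create_task", task_id=task_id):
#     await stream_registry.create_task(...)
```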
@@ -280,14 +292,37 @@ async def stream_chat_post(
tool_name="chat",
operation_id=operation_id,
)
logger.info(
f"[TIMING][routes] stream_registry.create_task completed in "
f"{(time.perf_counter() - task_create_start)*1000:.1f}ms, "
f"task_id={task_id}, session={session_id}"
)

# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
import time as time_module

gen_start_time = time_module.perf_counter()
logger.info(
f"[TIMING][routes] run_ai_generation STARTED, task_id={task_id}, "
f"session={session_id}"
)
first_chunk_time = None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING][routes] StreamStart published at "
f"{(time_module.perf_counter() - gen_start_time)*1000:.1f}ms, "
f"task_id={task_id}"
)

logger.info(
f"[TIMING][routes] Calling chat_service.stream_chat_completion, "
f"task_id={task_id}"
)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
@@ -296,54 +331,134 @@ async def stream_chat_post(
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
logger.info(
f"[TIMING][routes] FIRST AI CHUNK received at "
f"{(first_chunk_time - gen_start_time)*1000:.1f}ms, "
f"chunk_type={type(chunk).__name__}, task_id={task_id}"
)
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)

gen_end_time = time_module.perf_counter()
logger.info(
f"[TIMING][routes] run_ai_generation COMPLETED, "
f"total_time={(gen_end_time - gen_start_time)*1000:.1f}ms, "
f"time_to_first_chunk={f'{(first_chunk_time - gen_start_time)*1000:.1f}ms' if first_chunk_time else 'N/A'}, "
f"chunk_count={chunk_count}, task_id={task_id}"
)

# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
logger.error(
f"Error in background AI generation for session {session_id}: {e}"
f"[TIMING][routes] Error in run_ai_generation for session {session_id} "
f"after {(time_module.perf_counter() - gen_start_time)*1000:.1f}ms: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")

# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(
f"[TIMING][routes] Background task started, total setup time="
f"{(time.perf_counter() - stream_start_time)*1000:.1f}ms, task_id={task_id}"
)

# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module

event_gen_start = time_module.perf_counter()
logger.info(f"[TIMING][routes] event_generator STARTED, task_id={task_id}")
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscribe_start = time_module.perf_counter()
logger.info(
f"[TIMING][routes] Calling subscribe_to_task, task_id={task_id}"
)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
logger.info(
f"[TIMING][routes] subscribe_to_task completed in "
f"{(time_module.perf_counter() - subscribe_start)*1000:.1f}ms, "
f"queue_obtained={subscriber_queue is not None}, task_id={task_id}"
)

if subscriber_queue is None:
logger.info(
f"[TIMING][routes] subscriber_queue is None, yielding finish, "
f"task_id={task_id}"
)
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return

# Read from the subscriber queue and yield to SSE
logger.info(
f"[TIMING][routes] Starting to read from subscriber_queue, "
f"task_id={task_id}"
)
while True:
try:
queue_wait_start = time_module.perf_counter()
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
queue_wait_time = (
time_module.perf_counter() - queue_wait_start
) * 1000
chunks_yielded += 1

if not first_chunk_yielded:
first_chunk_yielded = True
logger.info(
f"[TIMING][routes] FIRST CHUNK from queue at "
f"{(time_module.perf_counter() - event_gen_start)*1000:.1f}ms, "
f"chunk_type={type(chunk).__name__}, "
f"queue_wait={queue_wait_time:.1f}ms, task_id={task_id}"
)
elif chunks_yielded % 50 == 0:
logger.info(
f"[TIMING][routes] Chunk #{chunks_yielded} yielded, "
f"chunk_type={type(chunk).__name__}, task_id={task_id}"
)

yield chunk.to_sse()

# Check for finish signal
if isinstance(chunk, StreamFinish):
logger.info(
f"[TIMING][routes] StreamFinish received, total chunks={chunks_yielded}, "
f"total_time={(time_module.perf_counter() - event_gen_start)*1000:.1f}ms, "
f"task_id={task_id}"
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
logger.info(
f"[TIMING][routes] Heartbeat timeout, sending heartbeat, "
f"chunks_so_far={chunks_yielded}, task_id={task_id}"
)
yield StreamHeartbeat().to_sse()

except GeneratorExit:
logger.info(
f"[TIMING][routes] GeneratorExit (client disconnected), "
f"chunks_yielded={chunks_yielded}, task_id={task_id}"
)
pass # Client disconnected - background task continues
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
logger.error(
f"[TIMING][routes] Error in event_generator for task {task_id} "
f"after {(time_module.perf_counter() - event_gen_start)*1000:.1f}ms: {e}"
)
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
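
In the hunk above, `run_ai_generation` keeps producing chunks into the stream registry regardless of the SSE connection, while `event_generator` only drains a subscriber queue and emits a heartbeat when the queue stays idle. A stripped-down sketch of that consumer side, using a plain `asyncio.Queue` and `None` as a stand-in finish sentinel (the real code uses `StreamFinish` chunks and `to_sse()` serialization):

```python
import asyncio
from typing import AsyncGenerator


async def sse_from_queue(queue: asyncio.Queue) -> AsyncGenerator[str, None]:
    """Yield SSE frames from a queue fed by an independent producer task."""
    while True:
        try:
            chunk = await asyncio.wait_for(queue.get(), timeout=30.0)
        except asyncio.TimeoutError:
            yield ": heartbeat\n\n"  # keep the connection alive while idle
            continue
        if chunk is None:  # sentinel standing in for a StreamFinish chunk
            break
        yield f"data: {chunk}\n\n"
    yield "data: [DONE]\n\n"
```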
@@ -357,6 +472,11 @@ async def stream_chat_post(
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
logger.info(
f"[TIMING][routes] event_generator FINISHED, total_time="
f"{(time_module.perf_counter() - event_gen_start)*1000:.1f}ms, "
f"chunks_yielded={chunks_yielded}, task_id={task_id}"
)
yield "data: [DONE]\n\n"

return StreamingResponse(

@@ -371,20 +371,25 @@ async def stream_chat_completion(
ValueError: If max_context_messages is exceeded

"""
completion_start = time.monotonic()
logger.info(
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
f"[TIMING][service] stream_chat_completion STARTED, session={session_id}, "
f"message_len={len(message) if message else 0}, is_user_message={is_user_message}"
)

# Only fetch from Redis if session not provided (initial call)
if session is None:
fetch_start = time.monotonic()
session = await get_chat_session(session_id, user_id)
logger.info(
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
f"[TIMING][service] get_chat_session took "
f"{(time.monotonic() - fetch_start)*1000:.1f}ms, "
f"session={session.session_id if session else 'None'}, "
f"message_count={len(session.messages) if session else 0}"
)
else:
logger.info(
f"Using provided session object: {session.session_id}, "
f"[TIMING][service] Using provided session object: {session.session_id}, "
f"message_count={len(session.messages)}"
)

@@ -412,11 +417,16 @@ async def stream_chat_completion(
message_length=len(message),
)

upsert_start = time.monotonic()
logger.info(
f"Upserting session: {session.session_id} with user id {session.user_id}, "
f"[TIMING][service] Upserting session: {session.session_id}, "
f"message_count={len(session.messages)}"
)
session = await upsert_chat_session(session)
logger.info(
f"[TIMING][service] upsert_chat_session took "
f"{(time.monotonic() - upsert_start)*1000:.1f}ms, session={session_id}"
)
assert session, "Session not found"

# Generate title for new sessions on first user message (non-blocking)
@@ -454,7 +464,12 @@ async def stream_chat_completion(
asyncio.create_task(_update_title())

# Build system prompt with business understanding
prompt_start = time.monotonic()
system_prompt, understanding = await _build_system_prompt(user_id)
logger.info(
f"[TIMING][service] _build_system_prompt took "
f"{(time.monotonic() - prompt_start)*1000:.1f}ms, session={session_id}"
)

# Initialize variables for streaming
assistant_response = ChatMessage(
@@ -483,9 +498,17 @@ async def stream_chat_completion(
text_block_id = str(uuid_module.uuid4())

# Yield message start
logger.info(
f"[TIMING][service] Setup complete, yielding StreamStart at "
f"{(time.monotonic() - completion_start)*1000:.1f}ms, session={session_id}"
)
yield StreamStart(messageId=message_id)

try:
logger.info(
f"[TIMING][service] Calling _stream_chat_chunks at "
f"{(time.monotonic() - completion_start)*1000:.1f}ms, session={session_id}"
)
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
@@ -893,9 +916,15 @@ async def _stream_chat_chunks(
SSE formatted JSON response objects

"""
import time as time_module

stream_chunks_start = time_module.perf_counter()
model = config.model

logger.info("Starting pure chat stream")
logger.info(
f"[TIMING][service] _stream_chat_chunks STARTED, session={session.session_id}, "
f"message_count={len(session.messages)}"
)

messages = session.to_openai_messages()
if system_prompt:
@@ -906,12 +935,18 @@ async def _stream_chat_chunks(
messages = [system_message] + messages

# Apply context window management
context_start = time_module.perf_counter()
context_result = await _manage_context_window(
messages=messages,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)
logger.info(
f"[TIMING][service] _manage_context_window took "
f"{(time_module.perf_counter() - context_start)*1000:.1f}ms, "
f"session={session.session_id}"
)

if context_result.error:
if "System prompt dropped" in context_result.error:
@@ -947,8 +982,10 @@ async def _stream_chat_chunks(
while retry_count <= MAX_RETRIES:
try:
logger.info(
f"Creating OpenAI chat completion stream..."
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
f"[TIMING][service] Creating OpenAI chat completion stream "
f"at {(time_module.perf_counter() - stream_chunks_start)*1000:.1f}ms"
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}, "
f"session={session.session_id}"
)

# Build extra_body for OpenRouter tracing and PostHog analytics
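
The hunk above shows the `while retry_count <= MAX_RETRIES` loop that wraps stream creation. A rough sketch of that retry shape is below; the backoff policy and the blanket `except Exception` are assumptions for illustration, not taken from this diff.

```python
import asyncio

MAX_RETRIES = 3  # placeholder value


async def create_stream_with_retry(create_stream):
    """Retry an async factory a bounded number of times before giving up."""
    retry_count = 0
    while True:
        try:
            return await create_stream()
        except Exception:
            if retry_count >= MAX_RETRIES:
                raise
            retry_count += 1
            await asyncio.sleep(2 ** retry_count)  # simple exponential backoff
```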
@@ -965,6 +1002,7 @@ async def _stream_chat_chunks(
:128
] # OpenRouter limit

api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create(
model=model,
messages=cast(list[ChatCompletionMessageParam], messages),
@@ -974,6 +1012,11 @@ async def _stream_chat_chunks(
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
extra_body=extra_body,
)
logger.info(
f"[TIMING][service] OpenAI client.chat.completions.create returned "
f"(stream object) in {(time_module.perf_counter() - api_call_start)*1000:.1f}ms, "
f"session={session.session_id}"
)

# Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = []
@@ -984,10 +1027,13 @@ async def _stream_chat_chunks(

# Track if we've started the text block
text_started = False
first_content_chunk = True
chunk_count = 0

# Process the stream
chunk: ChatCompletionChunk
async for chunk in stream:
chunk_count += 1
if chunk.usage:
yield StreamUsage(
promptTokens=chunk.usage.prompt_tokens,
@@ -1010,6 +1056,15 @@ async def _stream_chat_chunks(
if not text_started and text_block_id:
yield StreamTextStart(id=text_block_id)
text_started = True
# Log timing for first content chunk
if first_content_chunk:
first_content_chunk = False
logger.info(
f"[TIMING][service] FIRST CONTENT CHUNK from OpenAI at "
f"{(time_module.perf_counter() - api_call_start)*1000:.1f}ms "
f"(since API call), chunk_count={chunk_count}, "
f"session={session.session_id}"
)
# Stream the text delta
text_response = StreamTextDelta(
id=text_block_id or "",
@@ -1066,7 +1121,13 @@ async def _stream_chat_chunks(
toolName=tool_calls[idx]["function"]["name"],
)
emitted_start_for_idx.add(idx)
logger.info(f"Stream complete. Finish reason: {finish_reason}")
stream_duration = (time_module.perf_counter() - api_call_start) * 1000
logger.info(
f"[TIMING][service] OpenAI stream COMPLETE, "
f"finish_reason={finish_reason}, duration={stream_duration:.1f}ms, "
f"chunk_count={chunk_count}, tool_calls={len(tool_calls)}, "
f"session={session.session_id}"
)

# Yield all accumulated tool calls after the stream is complete
# This ensures all tool call arguments have been fully received
@@ -1086,6 +1147,11 @@ async def _stream_chat_chunks(
# Re-raise to trigger retry logic in the parent function
raise

logger.info(
f"[TIMING][service] _stream_chat_chunks COMPLETED, "
f"total_time={(time_module.perf_counter() - stream_chunks_start)*1000:.1f}ms, "
f"session={session.session_id}"
)
yield StreamFinish()
return
except Exception as e:

@@ -104,6 +104,14 @@ async def create_task(
Returns:
The created ActiveTask instance (metadata only)
"""
import time

start_time = time.perf_counter()
logger.info(
f"[TIMING][stream_registry] create_task STARTED, task_id={task_id}, "
f"session={session_id}"
)

task = ActiveTask(
task_id=task_id,
session_id=session_id,
@@ -114,10 +122,17 @@ async def create_task(
)

# Store metadata in Redis
redis_start = time.perf_counter()
redis = await get_redis_async()
logger.info(
f"[TIMING][stream_registry] get_redis_async took "
f"{(time.perf_counter() - redis_start)*1000:.1f}ms, task_id={task_id}"
)

meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)

hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
@@ -131,12 +146,20 @@ async def create_task(
"created_at": task.created_at.isoformat(),
},
)
logger.info(
f"[TIMING][stream_registry] redis.hset took "
f"{(time.perf_counter() - hset_start)*1000:.1f}ms, task_id={task_id}"
)

await redis.expire(meta_key, config.stream_ttl)

# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)

logger.debug(f"Created task {task_id} for session {session_id}")
logger.info(
f"[TIMING][stream_registry] create_task COMPLETED in "
f"{(time.perf_counter() - start_time)*1000:.1f}ms, task_id={task_id}"
)

return task

@@ -156,6 +179,10 @@ async def publish_chunk(
Returns:
The Redis Stream message ID
"""
import time

start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json()
message_id = "0-0"

@@ -164,18 +191,34 @@ async def publish_chunk(
stream_key = _get_task_stream_key(task_id)

# Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()

# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)

total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
chunk_type
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
logger.info(
f"[TIMING][stream_registry] publish_chunk {chunk_type} in "
f"{total_time:.1f}ms (xadd={xadd_time:.1f}ms), "
f"task_id={task_id}, msg_id={message_id}"
)
except Exception as e:
logger.error(
f"Failed to publish chunk for task {task_id}: {e}",
f"[TIMING][stream_registry] Failed to publish chunk {chunk_type} "
f"for task {task_id} after {(time.perf_counter() - start_time)*1000:.1f}ms: {e}",
exc_info=True,
)

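
A condensed sketch of the XADD publish path above, assuming a `redis.asyncio` client; the key format, `maxlen`, and TTL values here are placeholders rather than the registry's actual configuration:

```python
import json

import redis.asyncio as redis_async


async def publish_chunk_example(r: redis_async.Redis, task_id: str, payload: dict) -> str:
    """Append one chunk to a capped Redis Stream and refresh its TTL."""
    stream_key = f"chat:stream:{task_id}"  # hypothetical key format
    raw_id = await r.xadd(
        stream_key,
        {"data": json.dumps(payload)},
        maxlen=1000,  # cap stream length, in the spirit of config.stream_max_length
        approximate=True,
    )
    await r.expire(stream_key, 3600)  # keep the stream around for reconnects
    return raw_id if isinstance(raw_id, str) else raw_id.decode()
```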
@@ -200,12 +243,28 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
import time

start_time = time.perf_counter()
logger.info(
f"[TIMING][stream_registry] subscribe_to_task STARTED, task_id={task_id}, "
f"last_message_id={last_message_id}"
)

redis_start = time.perf_counter()
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
logger.info(
f"[TIMING][stream_registry] Redis hgetall took "
f"{(time.perf_counter() - redis_start)*1000:.1f}ms, task_id={task_id}"
)

if not meta:
logger.debug(f"Task {task_id} not found in Redis")
logger.info(
f"[TIMING][stream_registry] Task {task_id} not found in Redis after "
f"{(time.perf_counter() - start_time)*1000:.1f}ms"
)
return None

# Note: Redis client uses decode_responses=True, so keys are strings
@@ -216,8 +275,8 @@ async def subscribe_to_task(
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
f"[TIMING][stream_registry] User {user_id} denied access to task "
f"{task_id} owned by {task_user_id}"
)
return None

@@ -225,7 +284,13 @@ async def subscribe_to_task(
stream_key = _get_task_stream_key(task_id)

# Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING][stream_registry] Redis xread (replay) took {xread_time:.1f}ms, "
f"task_id={task_id}, task_status={task_status}"
)

replayed_count = 0
replay_last_id = last_message_id
@@ -244,10 +309,17 @@ async def subscribe_to_task(
except Exception as e:
logger.warning(f"Failed to replay message: {e}")

logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
logger.info(
f"[TIMING][stream_registry] Replayed {replayed_count} messages, "
f"replay_last_id={replay_last_id}, task_id={task_id}"
)

# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
logger.info(
f"[TIMING][stream_registry] Task still running, starting _stream_listener, "
f"task_id={task_id}"
)
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
@@ -255,8 +327,17 @@ async def subscribe_to_task(
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
logger.info(
f"[TIMING][stream_registry] Task already {task_status}, adding StreamFinish, "
f"task_id={task_id}"
)
await subscriber_queue.put(StreamFinish())

logger.info(
f"[TIMING][stream_registry] subscribe_to_task COMPLETED in "
f"{(time.perf_counter() - start_time)*1000:.1f}ms, "
f"replayed={replayed_count}, task_id={task_id}"
)
return subscriber_queue


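
`subscribe_to_task` above replays whatever is already in the task's Redis Stream, then either starts a live listener (task still running) or enqueues a finish marker (task already done). A rough sketch of that replay-then-follow flow, with illustrative names and assuming a client created with `decode_responses=True`:

```python
import redis.asyncio as redis_async


async def replay_then_follow(r: redis_async.Redis, stream_key: str, last_id: str = "0-0"):
    """Yield payloads already in the stream, then block for live entries."""
    # Replay: everything after last_id that is already persisted
    for _, entries in (await r.xread({stream_key: last_id}, count=1000)) or []:
        for msg_id, fields in entries:
            last_id = msg_id
            yield fields["data"]
    # Follow: block up to 30s per call for new entries
    # (the real listener stops when it sees a StreamFinish chunk)
    while True:
        for _, entries in (await r.xread({stream_key: last_id}, block=30000, count=100)) or []:
            for msg_id, fields in entries:
                last_id = msg_id
                yield fields["data"]
```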
@@ -275,9 +356,20 @@ async def _stream_listener(
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
"""
import time

start_time = time.perf_counter()
logger.info(
f"[TIMING][stream_registry] _stream_listener STARTED, task_id={task_id}, "
f"last_replayed_id={last_replayed_id}"
)

queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
messages_delivered = 0
first_message_time = None
xread_count = 0

try:
redis = await get_redis_async()
@@ -287,9 +379,25 @@ async def _stream_listener(
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
xread_time = (time.perf_counter() - xread_start) * 1000

if messages:
logger.info(
f"[TIMING][stream_registry] _stream_listener xread #{xread_count} "
f"returned {sum(len(msgs) for _, msgs in messages)} messages in "
f"{xread_time:.1f}ms, task_id={task_id}"
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
logger.info(
f"[TIMING][stream_registry] _stream_listener xread #{xread_count} "
f"timeout after {xread_time:.1f}ms, task_id={task_id}"
)

if not messages:
# Timeout - check if task is still running
@@ -326,9 +434,17 @@ async def _stream_listener(
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
logger.info(
f"[TIMING][stream_registry] _stream_listener FIRST live message "
f"delivered at {(first_message_time - start_time)*1000:.1f}ms, "
f"chunk_type={type(chunk).__name__}, task_id={task_id}"
)
except asyncio.TimeoutError:
logger.warning(
f"Subscriber queue full for task {task_id}, "
f"[TIMING][stream_registry] Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
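
The `asyncio.TimeoutError` branch above bounds the `put` into the subscriber queue, so one stalled consumer cannot block the listener forever. A stripped-down version of that guard, with `QUEUE_PUT_TIMEOUT` as an assumed constant:

```python
import asyncio

QUEUE_PUT_TIMEOUT = 5.0  # seconds; placeholder value


async def deliver(queue: asyncio.Queue, chunk) -> bool:
    """Try to hand a chunk to a subscriber; report False if the queue stays full."""
    try:
        await asyncio.wait_for(queue.put(chunk), timeout=QUEUE_PUT_TIMEOUT)
        return True
    except asyncio.TimeoutError:
        # Caller can log an overflow error with recovery info for the client
        return False
```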
@@ -351,15 +467,27 @@ async def _stream_listener(

# Stop listening on finish
if isinstance(chunk, StreamFinish):
logger.info(
f"[TIMING][stream_registry] _stream_listener received StreamFinish, "
f"total_time={(time.perf_counter() - start_time)*1000:.1f}ms, "
f"messages_delivered={messages_delivered}, task_id={task_id}"
)
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")

except asyncio.CancelledError:
logger.debug(f"Stream listener cancelled for task {task_id}")
logger.info(
f"[TIMING][stream_registry] _stream_listener CANCELLED after "
f"{(time.perf_counter() - start_time)*1000:.1f}ms, "
f"messages_delivered={messages_delivered}, task_id={task_id}"
)
raise # Re-raise to propagate cancellation
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
logger.error(
f"[TIMING][stream_registry] _stream_listener ERROR after "
f"{(time.perf_counter() - start_time)*1000:.1f}ms: {e}, task_id={task_id}"
)
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
@@ -372,6 +500,12 @@ async def _stream_listener(
)
finally:
# Clean up listener task mapping on exit
logger.info(
f"[TIMING][stream_registry] _stream_listener FINISHED, "
f"total_time={(time.perf_counter() - start_time)*1000:.1f}ms, "
f"messages_delivered={messages_delivered}, xread_count={xread_count}, "
f"task_id={task_id}"
)
_listener_tasks.pop(queue_id, None)