refactor(copilot): extract _run_stream_attempt to flatten retry loop

Move the entire streaming attempt (client lifecycle, query, message
iteration, response dispatch, post-stream cleanup) into an inner async
generator. The retry loop becomes a clean `for`/`else` with a single
`try`/`except`, reducing nesting by two to three levels.
This commit is contained in:
Zamil Majdy
2026-03-14 20:35:42 +07:00
parent 4cc0bbf472
commit 384b261e7f

View File

@@ -793,7 +793,6 @@ async def stream_chat_completion_sdk(
message_id = str(uuid.uuid4())
stream_id = str(uuid.uuid4())
stream_completed = False
ended_with_stream_error = False
e2b_sandbox = None
use_resume = False
@@ -1041,9 +1040,315 @@ async def stream_chat_completion_sdk(
if attachments.hint:
query_message = f"{query_message}\n\n{attachments.hint}"
_stream_error: Exception | None = None
_tried_compaction = False
async def _run_stream_attempt() -> AsyncGenerator[StreamBaseResponse, None]:
    """Run one SDK streaming attempt.

    Opens a ``ClaudeSDKClient``, sends the query, iterates SDK
    messages with heartbeat timeouts, dispatches adapter responses,
    and performs post-stream cleanup (safety-net flush, stopped-by-
    user handling).

    Yields stream events. On stream error the exception propagates
    to the caller so the retry loop can rollback and retry.
    """
    # Per-attempt accumulators, recreated on every call so a failed
    # attempt cannot leak partial state into the next retry.
    assistant_response = ChatMessage(role="assistant", content="")
    accumulated_tool_calls: list[dict[str, Any]] = []
    has_appended_assistant = False  # assistant_response already in session.messages
    has_tool_results = False  # tool output seen since last assistant turn
    stream_completed = False  # set when adapter emits StreamFinish
    async with ClaudeSDKClient(options=options) as client:
        logger.info(
            "%s Sending query — resume=%s, total_msgs=%d, "
            "query_len=%d, attached_files=%d, image_blocks=%d",
            log_prefix,
            use_resume,
            len(session.messages),
            len(query_message),
            len(file_ids) if file_ids else 0,
            len(attachments.image_blocks),
        )
        compaction.reset_for_query()
        if was_compacted:
            for ev in compaction.emit_pre_query(session):
                yield ev
        if attachments.image_blocks:
            # Multimodal query: image blocks + text written straight to the
            # SDK transport as a raw user message (bypasses client.query(),
            # which presumably only accepts text — hence the SLF001 waivers).
            content_blocks: list[dict[str, Any]] = [
                *attachments.image_blocks,
                {"type": "text", "text": query_message},
            ]
            user_msg = {
                "type": "user",
                "message": {"role": "user", "content": content_blocks},
                "parent_tool_use_id": None,
                "session_id": session_id,
            }
            assert client._transport is not None  # noqa: SLF001
            await client._transport.write(  # noqa: SLF001
                json.dumps(user_msg) + "\n"
            )
            # Transcript records current_message (not the expanded
            # query_message with attachment hints).
            transcript_builder.append_user(
                content=[
                    *attachments.image_blocks,
                    {"type": "text", "text": current_message},
                ]
            )
        else:
            await client.query(query_message, session_id=session_id)
            transcript_builder.append_user(content=current_message)
        async for sdk_msg in _iter_sdk_messages(client):
            # None is the heartbeat sentinel: refresh the distributed lock
            # and keep the SSE connection alive without emitting content.
            if sdk_msg is None:
                await lock.refresh()
                for ev in compaction.emit_start_if_ready():
                    yield ev
                yield StreamHeartbeat()
                continue
            logger.info(
                "%s Received: %s %s (unresolved=%d, current=%d, resolved=%d)",
                log_prefix,
                type(sdk_msg).__name__,
                getattr(sdk_msg, "subtype", ""),
                len(adapter.current_tool_calls)
                - len(adapter.resolved_tool_calls),
                len(adapter.current_tool_calls),
                len(adapter.resolved_tool_calls),
            )
            # Surface API-level errors carried on an AssistantMessage
            # (e.g. invalid_request) — logged only, not fatal here.
            sdk_error = getattr(sdk_msg, "error", None)
            if isinstance(sdk_msg, AssistantMessage) and sdk_error:
                logger.error(
                    "[SDK] [%s] AssistantMessage has error=%s, "
                    "content_blocks=%d, content_preview=%s",
                    session_id[:12],
                    sdk_error,
                    len(sdk_msg.content),
                    str(sdk_msg.content)[:500],
                )
            # Wait for PostToolUse hook stash (race-condition fix).
            # Skip for parallel tool continuations (all ToolUseBlocks).
            is_parallel_continuation = isinstance(
                sdk_msg, AssistantMessage
            ) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
            if (
                adapter.has_unresolved_tool_calls
                and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
                and not is_parallel_continuation
            ):
                if await wait_for_stash(timeout=0.5):
                    # Yield once so the stashed results are processed
                    # before this message is converted.
                    await asyncio.sleep(0)
                else:
                    logger.warning(
                        "%s Timed out waiting for PostToolUse "
                        "hook stash (%d unresolved tool calls)",
                        log_prefix,
                        len(adapter.current_tool_calls)
                        - len(adapter.resolved_tool_calls),
                    )
            if isinstance(sdk_msg, ResultMessage):
                logger.info(
                    "%s Received: ResultMessage %s "
                    "(unresolved=%d, current=%d, resolved=%d)",
                    log_prefix,
                    sdk_msg.subtype,
                    len(adapter.current_tool_calls)
                    - len(adapter.resolved_tool_calls),
                    len(adapter.current_tool_calls),
                    len(adapter.resolved_tool_calls),
                )
                if sdk_msg.subtype in (
                    "error",
                    "error_during_execution",
                ):
                    logger.error(
                        "%s SDK execution failed with error: %s",
                        log_prefix,
                        sdk_msg.result or "(no error message provided)",
                    )
            # Sync TranscriptBuilder with CLI after compaction: when the
            # SDK just finished compacting, replace our entries with the
            # compacted transcript read off disk.
            compact_result = await compaction.emit_end_if_ready(session)
            for ev in compact_result.events:
                yield ev
            entries_replaced = False
            if compact_result.just_ended:
                compacted = await asyncio.to_thread(
                    read_compacted_entries,
                    compact_result.transcript_path,
                )
                if compacted is not None:
                    transcript_builder.replace_entries(
                        compacted, log_prefix=log_prefix
                    )
                    entries_replaced = True
            # --- Dispatch adapter responses for this SDK message ---
            for response in adapter.convert_message(sdk_msg):
                if isinstance(response, StreamStart):
                    # StreamStart was already sent by the caller; skip.
                    continue
                if isinstance(
                    response,
                    (
                        StreamToolInputAvailable,
                        StreamToolOutputAvailable,
                    ),
                ):
                    extra = ""
                    if isinstance(response, StreamToolOutputAvailable):
                        out_len = len(str(response.output))
                        extra = f", output_len={out_len}"
                    logger.info(
                        "%s Tool event: %s, tool=%s%s",
                        log_prefix,
                        type(response).__name__,
                        getattr(response, "toolName", "N/A"),
                        extra,
                    )
                if isinstance(response, StreamError):
                    logger.error(
                        "%s Sending error to frontend: %s (code=%s)",
                        log_prefix,
                        response.errorText,
                        response.code,
                    )
                # Forward the event to the SSE consumer first, then
                # mirror it into session.messages below.
                yield response
                if isinstance(response, StreamTextDelta):
                    delta = response.delta or ""
                    if has_tool_results and has_appended_assistant:
                        # Text after tool results starts a NEW assistant
                        # turn rather than extending the previous one.
                        assistant_response = ChatMessage(
                            role="assistant", content=delta
                        )
                        accumulated_tool_calls = []
                        has_appended_assistant = False
                        has_tool_results = False
                        session.messages.append(assistant_response)
                        has_appended_assistant = True
                    else:
                        assistant_response.content = (
                            assistant_response.content or ""
                        ) + delta
                        if not has_appended_assistant:
                            session.messages.append(assistant_response)
                            has_appended_assistant = True
                elif isinstance(response, StreamToolInputAvailable):
                    # Record the tool call in OpenAI-style tool_calls shape.
                    accumulated_tool_calls.append(
                        {
                            "id": response.toolCallId,
                            "type": "function",
                            "function": {
                                "name": response.toolName,
                                "arguments": json.dumps(response.input or {}),
                            },
                        }
                    )
                    assistant_response.tool_calls = accumulated_tool_calls
                    if not has_appended_assistant:
                        session.messages.append(assistant_response)
                        has_appended_assistant = True
                elif isinstance(response, StreamToolOutputAvailable):
                    content = (
                        response.output
                        if isinstance(response.output, str)
                        else json.dumps(response.output, ensure_ascii=False)
                    )
                    session.messages.append(
                        ChatMessage(
                            role="tool",
                            content=content,
                            tool_call_id=response.toolCallId,
                        )
                    )
                    # Skip when compaction just replaced our entries —
                    # the compacted transcript already contains them.
                    if not entries_replaced:
                        transcript_builder.append_tool_result(
                            tool_use_id=response.toolCallId,
                            content=content,
                        )
                    has_tool_results = True
                elif isinstance(response, StreamFinish):
                    stream_completed = True
            # Append assistant entry AFTER convert_message so
            # stashed tool results come first (required API order).
            if isinstance(sdk_msg, AssistantMessage) and not entries_replaced:
                transcript_builder.append_assistant(
                    content_blocks=_format_sdk_content_blocks(sdk_msg.content),
                    model=sdk_msg.model,
                )
            if stream_completed:
                break
        # --- Post-stream processing (only on success) ---
        # NOTE(review): reconstructed from a diff render; placement inside
        # the `async with` (client still open) matches the pre-refactor
        # structure — confirm against the real file.
        if adapter.has_unresolved_tool_calls:
            # Safety net: flush unresolved tools so frontend spinners stop.
            logger.warning(
                "%s %d unresolved tool(s) after stream — flushing",
                log_prefix,
                len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
            )
            safety_responses: list[StreamBaseResponse] = []
            adapter._flush_unresolved_tool_calls(safety_responses)
            for response in safety_responses:
                if isinstance(
                    response,
                    (StreamToolInputAvailable, StreamToolOutputAvailable),
                ):
                    logger.info(
                        "%s Safety flush: %s, tool=%s",
                        log_prefix,
                        type(response).__name__,
                        getattr(response, "toolName", "N/A"),
                    )
                if isinstance(response, StreamToolOutputAvailable):
                    transcript_builder.append_tool_result(
                        tool_use_id=response.toolCallId,
                        content=(
                            response.output
                            if isinstance(response.output, str)
                            else json.dumps(response.output, ensure_ascii=False)
                        ),
                    )
                yield response
        if not stream_completed:
            # No ResultMessage means the user stopped the stream: close any
            # open text block and record a synthetic assistant message.
            logger.info(
                "%s Stream ended without ResultMessage (stopped by user)",
                log_prefix,
            )
            closing_responses: list[StreamBaseResponse] = []
            adapter._end_text_if_open(closing_responses)
            for r in closing_responses:
                yield r
            session.messages.append(
                ChatMessage(
                    role="assistant",
                    content=f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user",
                )
            )
            # Flush a dangling assistant response that never got appended.
            if (
                assistant_response.content or assistant_response.tool_calls
            ) and not has_appended_assistant:
                session.messages.append(assistant_response)
# ---------------------------------------------------------------
# Retry loop: original → compacted → no transcript
# ---------------------------------------------------------------
ended_with_stream_error = False
stream_err: Exception | None = None
for _attempt in range(_MAX_STREAM_ATTEMPTS):
if _attempt > 0:
logger.info(
@@ -1052,8 +1357,6 @@ async def stream_chat_completion_sdk(
_attempt + 1,
_MAX_STREAM_ATTEMPTS,
)
_stream_error = None
stream_completed = False
(
transcript_builder,
@@ -1092,360 +1395,46 @@ async def stream_chat_completion_sdk(
message_id=message_id, session_id=session_id
)
# Save message count so we can rollback partial messages on retry
_pre_attempt_msg_count = len(session.messages)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
ended_with_stream_error = False
async with ClaudeSDKClient(options=options) as client:
logger.info(
"%s Sending query — resume=%s, total_msgs=%d, "
"query_len=%d, attached_files=%d, image_blocks=%d",
try:
async for event in _run_stream_attempt():
yield event
break # Stream completed — exit retry loop
except asyncio.CancelledError:
logger.warning(
"%s Streaming cancelled (attempt %d/%d)",
log_prefix,
use_resume,
len(session.messages),
len(query_message),
len(file_ids) if file_ids else 0,
len(attachments.image_blocks),
_attempt + 1,
_MAX_STREAM_ATTEMPTS,
)
raise
except Exception as stream_err:
logger.warning(
"%s Stream error (attempt %d/%d): %s",
log_prefix,
_attempt + 1,
_MAX_STREAM_ATTEMPTS,
stream_err,
exc_info=True,
)
compaction.reset_for_query()
if was_compacted:
for ev in compaction.emit_pre_query(session):
yield ev
if attachments.image_blocks:
# Build multimodal content: image blocks + text
content_blocks: list[dict[str, Any]] = [
*attachments.image_blocks,
{"type": "text", "text": query_message},
]
user_msg = {
"type": "user",
"message": {"role": "user", "content": content_blocks},
"parent_tool_use_id": None,
"session_id": session_id,
}
assert client._transport is not None # noqa: SLF001
await client._transport.write( # noqa: SLF001
json.dumps(user_msg) + "\n"
)
transcript_builder.append_user(
content=[
*attachments.image_blocks,
{"type": "text", "text": current_message},
]
)
else:
await client.query(query_message, session_id=session_id)
transcript_builder.append_user(content=current_message)
try:
async for sdk_msg in _iter_sdk_messages(client):
# Heartbeat sentinel — refresh lock and keep SSE alive
if sdk_msg is None:
await lock.refresh()
for ev in compaction.emit_start_if_ready():
yield ev
yield StreamHeartbeat()
continue
logger.info(
"%s Received: %s %s "
"(unresolved=%d, current=%d, resolved=%d)",
log_prefix,
type(sdk_msg).__name__,
getattr(sdk_msg, "subtype", ""),
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
)
# Log AssistantMessage API errors (e.g. invalid_request)
sdk_error = getattr(sdk_msg, "error", None)
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
logger.error(
"[SDK] [%s] AssistantMessage has error=%s, "
"content_blocks=%d, content_preview=%s",
session_id[:12],
sdk_error,
len(sdk_msg.content),
str(sdk_msg.content)[:500],
)
# Wait for PostToolUse hook stash (race-condition fix).
# Skip for parallel tool continuations (all ToolUseBlocks).
is_parallel_continuation = isinstance(
sdk_msg, AssistantMessage
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
if (
adapter.has_unresolved_tool_calls
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
and not is_parallel_continuation
):
if await wait_for_stash(timeout=0.5):
await asyncio.sleep(0)
else:
logger.warning(
"%s Timed out waiting for "
"PostToolUse hook stash "
"(%d unresolved tool calls)",
log_prefix,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"%s Received: ResultMessage %s "
"(unresolved=%d, current=%d, resolved=%d)",
log_prefix,
sdk_msg.subtype,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
)
if sdk_msg.subtype in (
"error",
"error_during_execution",
):
logger.error(
"%s SDK execution failed with error: %s",
log_prefix,
sdk_msg.result or "(no error message provided)",
)
# Emit compaction end if SDK finished compacting.
# Sync TranscriptBuilder with the CLI's active context.
compact_result = await compaction.emit_end_if_ready(session)
for ev in compact_result.events:
yield ev
entries_replaced = False
if compact_result.just_ended:
compacted = await asyncio.to_thread(
read_compacted_entries,
compact_result.transcript_path,
)
if compacted is not None:
transcript_builder.replace_entries(
compacted, log_prefix=log_prefix
)
entries_replaced = True
# --- Dispatch adapter responses ---
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
if isinstance(
response,
(
StreamToolInputAvailable,
StreamToolOutputAvailable,
),
):
extra = ""
if isinstance(response, StreamToolOutputAvailable):
out_len = len(str(response.output))
extra = f", output_len={out_len}"
logger.info(
"%s Tool event: %s, tool=%s%s",
log_prefix,
type(response).__name__,
getattr(response, "toolName", "N/A"),
extra,
)
if isinstance(response, StreamError):
logger.error(
"%s Sending error to frontend: %s (code=%s)",
log_prefix,
response.errorText,
response.code,
)
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(
response.input or {}
),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
content = (
response.output
if isinstance(response.output, str)
else json.dumps(response.output, ensure_ascii=False)
)
session.messages.append(
ChatMessage(
role="tool",
content=content,
tool_call_id=response.toolCallId,
)
)
if not entries_replaced:
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
# Append assistant entry AFTER convert_message so
# stashed tool results come first (required API order).
if (
isinstance(sdk_msg, AssistantMessage)
and not entries_replaced
):
transcript_builder.append_assistant(
content_blocks=_format_sdk_content_blocks(
sdk_msg.content
),
model=sdk_msg.model,
)
if stream_completed:
break
except asyncio.CancelledError:
logger.warning(
"%s Streaming loop cancelled (asyncio.CancelledError)",
log_prefix,
)
raise
except Exception as stream_err:
_stream_error = stream_err
logger.warning(
"%s Stream error (attempt %d/%d): %s",
log_prefix,
_attempt + 1,
_MAX_STREAM_ATTEMPTS,
stream_err,
exc_info=True,
)
# On error, rollback partial messages and retry
if _stream_error is not None:
session.messages = session.messages[:_pre_attempt_msg_count]
continue
# Safety net: flush unresolved tools so frontend stops spinners
if adapter.has_unresolved_tool_calls:
logger.warning(
"%s %d unresolved tool(s) after stream loop — "
"flushing as safety net",
log_prefix,
len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
)
safety_responses: list[StreamBaseResponse] = []
adapter._flush_unresolved_tool_calls(safety_responses)
for response in safety_responses:
if isinstance(
response,
(StreamToolInputAvailable, StreamToolOutputAvailable),
):
logger.info(
"%s Safety flush: %s, tool=%s",
log_prefix,
type(response).__name__,
getattr(response, "toolName", "N/A"),
)
if isinstance(response, StreamToolOutputAvailable):
transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=(
response.output
if isinstance(response.output, str)
else json.dumps(response.output, ensure_ascii=False)
),
)
yield response
# Stream ended without ResultMessage → stopped by user
if not stream_completed and not ended_with_stream_error:
logger.info(
"%s Stream ended without ResultMessage (stopped by user)",
log_prefix,
)
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
for r in closing_responses:
yield r
session.messages.append(
ChatMessage(
role="assistant",
content=f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user",
)
)
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
break # Stream completed — exit retry loop
# All retry attempts exhausted — surface error to the user.
if _stream_error is not None:
else:
# All retry attempts exhausted (loop ended without break)
transcript_caused_error = True
ended_with_stream_error = True
logger.error(
"%s All %d query attempts exhausted: %s",
log_prefix,
_MAX_STREAM_ATTEMPTS,
_stream_error,
stream_err,
)
yield StreamError(
errorText=f"SDK stream error: {_stream_error}",
errorText=f"SDK stream error: {stream_err}",
code="all_attempts_exhausted",
)
# Transcript upload is handled exclusively in the finally block
# to avoid double-uploads (the success path used to upload the
# old resume file, then the finally block overwrote it with the
# stop hook content — which could be smaller after compaction).
if ended_with_stream_error:
logger.warning(
"%s Stream ended with SDK error after %d messages",