fix(backend/chat): Fix cleanup race condition and move to outer finally

- Use session-specific temp dir (/tmp/copilot-{session_id}) as SDK cwd
  to prevent concurrent sessions from deleting each other's tool-result
  files during cleanup
- Move _cleanup_sdk_tool_results() to outer finally block so it runs
  even when the outer except Exception fires
- Clean up the temp cwd directory after each session
- Remove unnecessary inner try/finally nesting
This commit is contained in:
Zamil Majdy
2026-02-11 05:13:02 +04:00
parent 1926127ddd
commit d7ef70469e

View File

@@ -44,20 +44,35 @@ config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
# SDK tool-results glob pattern — clean these up after each query
_SDK_TOOL_RESULTS_GLOB = os.path.expanduser("~/.claude/projects/*/tool-results/*")
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK tool-result files for a specific session working directory.
def _cleanup_sdk_tool_results() -> None:
"""Remove SDK tool-result files to prevent disk accumulation."""
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
We clean only the specific cwd's results to avoid race conditions between
concurrent sessions.
"""
import glob as _glob
import shutil
for path in _glob.glob(_SDK_TOOL_RESULTS_GLOB):
# SDK encodes the cwd path by replacing '/' with '-'
encoded_cwd = cwd.replace("/", "-")
project_dir = os.path.expanduser(f"~/.claude/projects/{encoded_cwd}")
results_glob = os.path.join(project_dir, "tool-results", "*")
for path in _glob.glob(results_glob):
try:
os.remove(path)
except OSError:
pass
# Also clean up the temp cwd directory itself
if cwd.startswith("/tmp/copilot-"):
try:
shutil.rmtree(cwd, ignore_errors=True)
except OSError:
pass
async def _compress_conversation_history(
session: ChatSession,
@@ -222,6 +237,10 @@ async def stream_chat_completion_sdk(
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = f"/tmp/copilot-{session_id}"
os.makedirs(sdk_cwd, exist_ok=True)
try:
try:
@@ -234,127 +253,121 @@ async def stream_chat_completion_sdk(
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
allowed_tools=COPILOT_TOOL_NAMES,
hooks=create_security_hooks(user_id), # type: ignore[arg-type]
cwd="/tmp",
cwd=sdk_cwd,
)
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
try:
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Build query with conversation history context.
# Compress history first to handle long conversations.
query_message = current_message
if len(session.messages) > 1:
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
logger.info(
f"[SDK] Sending query: {current_message[:80]!r}"
f" ({len(session.messages)} msgs in session)"
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
await client.query(query_message, session_id=session_id)
yield StreamFinish()
return
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
# Build query with conversation history context.
# Compress history first to handle long conversations.
query_message = current_message
if len(session.messages) > 1:
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
logger.info(
f"[SDK] Sending query: {current_message[:80]!r}"
f" ({len(session.messages)} msgs in session)"
)
await client.query(query_message, session_id=session_id)
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(
response.input or {}
),
},
}
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
assistant_response.tool_calls = accumulated_tool_calls
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
has_tool_results = True
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
elif isinstance(response, StreamFinish):
stream_completed = True
if stream_completed:
break
if stream_completed:
break
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
finally:
_cleanup_sdk_tool_results()
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
except ImportError:
logger.warning(
@@ -385,6 +398,8 @@ async def stream_chat_completion_sdk(
code="sdk_error",
)
yield StreamFinish()
finally:
_cleanup_sdk_tool_results(sdk_cwd)
async def _update_title_async(