fix(backend/chat): Fix bugs and remove dead code in SDK integration

- Fix message accumulation bug: reset has_appended_assistant when
  creating new post-tool assistant message to prevent lost text deltas
- Fix hardcoded model in anthropic_fallback.py: use config.model instead
  of hardcoded "claude-sonnet-4-20250514"
- Fix _SDK_TOOL_RESULTS_DIR using hardcoded /root/ path: use expanduser
- Remove unused create_strict_security_hooks (~75 lines)
- Remove unused create_heartbeat/create_usage from response adapter
- Remove unused RAW_TOOL_NAMES from tool_adapter
- Extract _MAX_TOOL_ITERATIONS constant from magic number
This commit is contained in:
Zamil Majdy
2026-02-11 04:42:05 +04:00
parent 8b509e56de
commit 1926127ddd
5 changed files with 15 additions and 99 deletions

View File

@@ -11,6 +11,7 @@ import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast
from ..config import ChatConfig
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
@@ -27,6 +28,10 @@ from ..response_model import (
from .tool_adapter import get_tool_definitions, get_tool_handlers
logger = logging.getLogger(__name__)
config = ChatConfig()
# Maximum tool-call iterations before stopping to prevent infinite loops
_MAX_TOOL_ITERATIONS = 10
async def stream_with_anthropic(
@@ -72,14 +77,15 @@ async def stream_with_anthropic(
)
has_started_text = False
max_iterations = 10
accumulated_text = ""
accumulated_tool_calls: list[dict[str, Any]] = []
for _ in range(max_iterations):
for _ in range(_MAX_TOOL_ITERATIONS):
try:
async with client.messages.stream(
model="claude-sonnet-4-20250514",
model=(
config.model.split("/")[-1] if "/" in config.model else config.model
),
max_tokens=4096,
system=system_prompt,
messages=cast(Any, anthropic_messages),

View File

@@ -24,7 +24,6 @@ from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
@@ -32,7 +31,6 @@ from backend.api.features.chat.response_model import (
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
logger = logging.getLogger(__name__)
@@ -141,18 +139,6 @@ class SDKResponseAdapter:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat:
"""Create a heartbeat response."""
return StreamHeartbeat(toolCallId=tool_call_id)
def create_usage(self, prompt_tokens: int, completion_tokens: int) -> StreamUsage:
"""Create a usage statistics response."""
return StreamUsage(
promptTokens=prompt_tokens,
completionTokens=completion_tokens,
totalTokens=prompt_tokens + completion_tokens,
)
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""

View File

@@ -210,80 +210,3 @@ def create_security_hooks(user_id: str | None) -> dict[str, Any]:
except ImportError:
# Fallback for when SDK isn't available - return empty hooks
return {}
def create_strict_security_hooks(
user_id: str | None,
allowed_tools: list[str] | None = None,
) -> dict[str, Any]:
"""Create strict security hooks that only allow specific tools.
Args:
user_id: Current user ID
allowed_tools: List of allowed tool names (defaults to CoPilot tools)
Returns:
Hooks configuration dict
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
from .tool_adapter import RAW_TOOL_NAMES
tools_list = allowed_tools if allowed_tools is not None else RAW_TOOL_NAMES
allowed_set = set(tools_list)
async def strict_pre_tool_use(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Strict validation that only allows whitelisted tools."""
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Remove MCP prefix if present
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
if clean_name not in allowed_set:
logger.warning(f"Blocked non-whitelisted tool: {tool_name}")
return cast(
SyncHookJSONOutput,
{
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": (
f"Tool '{tool_name}' is not in the allowed list"
),
}
},
)
# Only run blocklist check for non-CoPilot tools; whitelisted
# MCP tools are already sandboxed by tool_adapter.
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
if not is_copilot_tool:
result = _validate_tool_access(clean_name, tool_input)
if result:
return cast(SyncHookJSONOutput, result)
result = _validate_user_isolation(clean_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(
f"[SDK Audit] Tool call: tool={tool_name}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
return {
"PreToolUse": [
HookMatcher(matcher="*", hooks=[strict_pre_tool_use]),
],
}
except ImportError:
return {}

View File

@@ -291,13 +291,17 @@ async def stream_chat_completion_sdk(
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
session.messages.append(assistant_response)
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""

View File

@@ -18,7 +18,7 @@ from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here)
_SDK_TOOL_RESULTS_DIR = "/root/.claude/"
_SDK_TOOL_RESULTS_DIR = os.path.expanduser("~/.claude/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
@@ -282,6 +282,3 @@ COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
]
# Also export the raw tool names for flexibility
RAW_TOOL_NAMES = list(TOOL_REGISTRY.keys())