fix(backend/copilot): address review comments — top-level imports, bug fixes, refactoring

- Move all local imports to top-level (transcript.py, helpers.py,
  baseline/service.py) per project style rules
- Fix sub.get("text") bug: use an explicit `is not None` check instead of
  a default arg so None values in tool_result content blocks fall back
  to json.dumps
- Extract _flatten_assistant_content and _flatten_tool_result_content
  helpers from _transcript_to_messages to reduce nesting
- Use early returns in _get_credits/_spend_credits
- Extract _fetch_counters helper in rate_limit.py to DRY the Redis
  fetch pattern between get_usage_status and check_rate_limit
- Fix DRY violation: compute total_tokens once before StreamUsage,
  reuse in finally block for session persistence
- Fix chunk.choices guard: use early continue instead of ternary
  to prevent IndexError on empty list
- Make the error message in execute_block generic to avoid leaking internals
This commit is contained in:
Zamil Majdy
2026-03-13 19:49:28 +07:00
parent 8b970c4c3d
commit 90b7edf1f1
5 changed files with 96 additions and 94 deletions

View File

@@ -49,7 +49,11 @@ from backend.copilot.service import (
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
from backend.util.prompt import (
compress_context,
estimate_token_count,
estimate_token_count_str,
)
logger = logging.getLogger(__name__)
@@ -258,7 +262,9 @@ async def stream_chat_completion_baseline(
turn_prompt_tokens += chunk.usage.prompt_tokens or 0
turn_completion_tokens += chunk.usage.completion_tokens or 0
delta = chunk.choices[0].delta if chunk.choices else None
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if not delta:
continue
@@ -432,11 +438,6 @@ async def stream_chat_completion_baseline(
# Count the full message list (system + history + turn) since
# each API call sends the complete context window.
if turn_prompt_tokens == 0 and turn_completion_tokens == 0:
from backend.util.prompt import (
estimate_token_count,
estimate_token_count_str,
)
turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 0
)

View File

@@ -86,6 +86,19 @@ def _weekly_reset_time(now: datetime | None = None) -> datetime:
)
async def _fetch_counters(user_id: str, now: datetime) -> tuple[int, int]:
    """Fetch the daily and weekly token counters for a user from Redis.

    Both keys are read concurrently with ``asyncio.gather``. A counter that
    is missing (``None``) or otherwise falsy is treated as zero usage.

    This helper does NOT handle Redis failures itself: any exception raised
    while obtaining the client or reading the keys propagates to the caller,
    which is responsible for choosing the fallback behaviour.

    Args:
        user_id: User whose counters are read.
        now: Timestamp forwarded to the key builders to select the daily
            and weekly keys.

    Returns:
        A ``(daily_used, weekly_used)`` tuple of integers.
    """
    redis = await get_redis_async()
    daily_raw, weekly_raw = await asyncio.gather(
        redis.get(_daily_key(user_id, now=now)),
        redis.get(_weekly_key(user_id, now=now)),
    )
    # ``or 0`` maps a missing key (None) to zero before the int conversion.
    return int(daily_raw or 0), int(weekly_raw or 0)
async def get_usage_status(
user_id: str,
daily_token_limit: int,
@@ -102,20 +115,13 @@ async def get_usage_status(
CoPilotUsageStatus with current usage and limits.
"""
now = datetime.now(UTC)
daily_used = 0
weekly_used = 0
try:
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
daily_used = int(daily_raw or 0)
weekly_used = int(weekly_raw or 0)
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for usage status, returning zeros", exc_info=True
)
daily_used, weekly_used = 0, 0
return CoPilotUsageStatus(
daily=UsageWindow(
@@ -148,13 +154,7 @@ async def check_rate_limit(
"""
now = datetime.now(UTC)
try:
redis = await get_redis_async()
daily_raw, weekly_raw = await asyncio.gather(
redis.get(_daily_key(user_id, now=now)),
redis.get(_weekly_key(user_id, now=now)),
)
daily_used = int(daily_raw or 0)
weekly_used = int(weekly_raw or 0)
daily_used, weekly_used = await _fetch_counters(user_id, now)
except Exception:
logger.warning(
"Redis unavailable for rate limit check, allowing request", exc_info=True

View File

@@ -746,6 +746,7 @@ async def stream_chat_completion_sdk(
turn_completion_tokens = 0
turn_cache_read_tokens = 0
turn_cache_creation_tokens = 0
total_tokens = 0 # computed once before StreamUsage, reused in finally
turn_cost_usd: float | None = None
try:
@@ -1430,13 +1431,15 @@ async def stream_chat_completion_sdk(
# Session persistence of usage is in finally to stay consistent with
# rate-limit recording even if an exception interrupts between here
# and the finally block.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
# Compute total_tokens once; reused in the finally block for
# session persistence and rate-limit recording.
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
if total_tokens > 0:
yield StreamUsage(
promptTokens=turn_prompt_tokens,
completionTokens=turn_completion_tokens,
@@ -1512,13 +1515,8 @@ async def stream_chat_completion_sdk(
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
if turn_prompt_tokens > 0 or turn_completion_tokens > 0:
total_tokens = (
turn_prompt_tokens
+ turn_cache_read_tokens
+ turn_cache_creation_tokens
+ turn_completion_tokens
)
# total_tokens is computed once before StreamUsage yield above.
if total_tokens > 0:
if session is not None:
session.usage.append(
Usage(
@@ -1540,7 +1538,7 @@ async def stream_chat_completion_sdk(
total_tokens,
turn_cost_usd,
)
if user_id and (turn_prompt_tokens > 0 or turn_completion_tokens > 0):
if user_id and total_tokens > 0:
try:
await record_token_usage(
user_id=user_id,

View File

@@ -13,12 +13,17 @@ filesystem for self-hosted) — no DB column needed.
import logging
import os
import re
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
import openai
from backend.copilot.config import ChatConfig
from backend.util import json
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
@@ -200,8 +205,6 @@ def read_cli_session_file(sdk_cwd: str) -> str | None:
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory."""
import shutil
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
@@ -474,6 +477,49 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
COMPACT_THRESHOLD_BYTES = 400_000
def _flatten_assistant_content(blocks: list) -> str:
    """Flatten assistant content blocks into a single plain-text string.

    Handling per block:
      * ``str`` blocks are kept verbatim.
      * ``{"type": "text"}`` dicts contribute their ``text`` value; a missing
        or ``None`` text becomes an empty string and any other non-string
        value is converted with ``str``.
      * ``{"type": "tool_use"}`` dicts become ``[tool_use: <name>]`` with
        ``?`` used when ``name`` is absent.
      * Anything else is ignored.

    Args:
        blocks: Content blocks of one assistant message.

    Returns:
        The collected parts joined with newlines, or ``""`` when nothing
        was collected.
    """
    parts: list[str] = []
    for block in blocks:
        if isinstance(block, str):
            parts.append(block)
            continue
        if not isinstance(block, dict):
            continue
        block_type = block.get("type")
        if block_type == "text":
            # dict.get's default only applies when the key is absent; an
            # explicit ``"text": None`` would otherwise reach str.join and
            # raise TypeError.
            text = block.get("text")
            parts.append("" if text is None else str(text))
        elif block_type == "tool_use":
            parts.append(f"[tool_use: {block.get('name', '?')}]")
    # Joining an empty list already yields "", so no special case is needed.
    return "\n".join(parts)
def _flatten_tool_result_content(blocks: list) -> str:
    """Flatten tool_result and other content blocks into plain text.

    Handling per block:
      * ``str`` blocks are kept verbatim.
      * ``{"type": "text"}`` dicts contribute their ``text`` value; a missing
        or ``None`` text becomes an empty string.
      * ``{"type": "tool_result"}`` dicts contribute their ``content``:
          - a list is walked item by item: dict items yield ``str(text)``
            when ``text`` is not ``None`` and ``json.dumps(item)`` otherwise;
            non-dict items are converted with ``str``;
          - a missing or ``None`` content becomes an empty string;
          - any other value is converted with ``str``.
      * Anything else is ignored.

    Args:
        blocks: Content blocks of one non-assistant message.

    Returns:
        The collected parts joined with newlines, or ``""`` when nothing
        was collected.
    """
    str_parts: list[str] = []
    for block in blocks:
        if isinstance(block, str):
            str_parts.append(block)
            continue
        if not isinstance(block, dict):
            continue
        block_type = block.get("type")
        if block_type == "tool_result":
            inner = block.get("content")
            if isinstance(inner, list):
                for sub in inner:
                    if isinstance(sub, dict):
                        text = sub.get("text")
                        # Fall back to the serialised dict when there is no
                        # usable text (key absent or explicitly None).
                        str_parts.append(
                            str(text) if text is not None else json.dumps(sub)
                        )
                    else:
                        str_parts.append(str(sub))
            else:
                # Avoid emitting the literal string "None" for null content.
                str_parts.append("" if inner is None else str(inner))
        elif block_type == "text":
            text = block.get("text")
            # Same None guard as above: str(None) would leak "None" into
            # the flattened text.
            str_parts.append("" if text is None else str(text))
    # Joining an empty list already yields "", so no special case is needed.
    return "\n".join(str_parts)
def _transcript_to_messages(content: str) -> list[dict]:
"""Convert JSONL transcript entries to message dicts for compress_context."""
messages: list[dict] = []
@@ -492,37 +538,9 @@ def _transcript_to_messages(content: str) -> list[dict]:
msg_dict: dict = {"role": role}
raw_content = msg.get("content")
if role == "assistant" and isinstance(raw_content, list):
parts: list[str] = []
for block in raw_content:
if isinstance(block, dict):
if block.get("type") == "text":
parts.append(block.get("text", ""))
elif block.get("type") == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
elif isinstance(block, str):
parts.append(block)
msg_dict["content"] = "\n".join(parts) if parts else ""
msg_dict["content"] = _flatten_assistant_content(raw_content)
elif isinstance(raw_content, list):
str_parts: list[str] = []
for block in raw_content:
if isinstance(block, dict) and block.get("type") == "tool_result":
# Flatten tool_result content for summarisation;
# tool_use_id pairing is not preserved through LLM
# compaction — the compacted transcript uses fresh IDs.
inner = block.get("content", "")
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
str_parts.append(str(sub.get("text", json.dumps(sub))))
else:
str_parts.append(str(sub))
else:
str_parts.append(str(inner))
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, str):
str_parts.append(block)
msg_dict["content"] = "\n".join(str_parts) if str_parts else ""
msg_dict["content"] = _flatten_tool_result_content(raw_content)
else:
msg_dict["content"] = raw_content or ""
messages.append(msg_dict)
@@ -572,18 +590,12 @@ async def compact_transcript(
Returns the compacted JSONL string, or ``None`` on failure.
"""
from backend.copilot.config import ChatConfig
cfg = ChatConfig()
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
import openai
from backend.util.prompt import compress_context
try:
async with openai.AsyncOpenAI(
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0

View File

@@ -9,12 +9,13 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data import db
from backend.data.credit import UsageTransactionMetadata
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import BlockError, InsufficientBalanceError
from backend.util.type import coerce_inputs_to_schema
@@ -26,32 +27,22 @@ logger = logging.getLogger(__name__)
async def _get_credits(user_id: str) -> int:
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
if db.is_connected():
from backend.data.credit import get_user_credit_model
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
else:
from backend.util.clients import get_database_manager_async_client
if not db.is_connected():
return await get_database_manager_async_client().get_credits(user_id)
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
if db.is_connected():
from backend.data.credit import get_user_credit_model
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
else:
from backend.util.clients import get_database_manager_async_client
if not db.is_connected():
return await get_database_manager_async_client().spend_credits(
user_id, cost, metadata
)
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
def get_inputs_from_schema(
@@ -220,7 +211,7 @@ async def execute_block(
except Exception as e:
logger.error("Unexpected error executing block: %s", e, exc_info=True)
return ErrorResponse(
message="Failed to execute block",
message="An unexpected error occurred while executing the block",
error=str(e),
session_id=session_id,
)