refactor(backend/copilot): use credit_db() adapter pattern for credit operations

Replace inline _get_credits/_spend_credits helpers with the centralized
credit_db() accessor from db_accessors.py, consistent with workspace_db(),
chat_db(), and other DB accessors. Add module-level get_credits() and
spend_credits() to credit.py so the accessor can return the module directly.
This commit is contained in:
Zamil Majdy
2026-03-13 22:16:54 +07:00
parent 7ef530c672
commit cb7d271472
4 changed files with 69 additions and 77 deletions

View File

@@ -8,9 +8,8 @@ from pydantic_core import PydanticUndefined
from backend.blocks._base import AnyBlockSchema
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
from backend.data import db
from backend.data.credit import UsageTransactionMetadata
from backend.data.db_accessors import workspace_db
from backend.data.db_accessors import credit_db, workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import block_usage_cost
@@ -24,36 +23,6 @@ from .utils import match_credentials_to_requirements
logger = logging.getLogger(__name__)
async def _get_credits(user_id: str) -> int:
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
if db.is_connected():
from backend.data.credit import get_user_credit_model
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
else:
from backend.util.clients import get_database_manager_async_client
return await get_database_manager_async_client().get_credits(user_id)
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
if db.is_connected():
from backend.data.credit import get_user_credit_model
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
else:
from backend.util.clients import get_database_manager_async_client
return await get_database_manager_async_client().spend_credits(
user_id, cost, metadata
)
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
@@ -152,7 +121,7 @@ async def execute_block(
cost, cost_filter = block_usage_cost(block, input_data)
has_cost = cost > 0
if has_cost:
balance = await _get_credits(user_id)
balance = await credit_db().get_credits(user_id)
if balance < cost:
return ErrorResponse(
message=(
@@ -173,7 +142,7 @@ async def execute_block(
# Charge credits for block execution
if has_cost:
try:
await _spend_credits(
await credit_db().spend_credits(
user_id=user_id,
cost=cost,
metadata=UsageTransactionMetadata(
@@ -216,7 +185,7 @@ async def execute_block(
except Exception as e:
logger.error("Unexpected error executing block: %s", e, exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {e}",
message=f"Failed to execute block: {str(e)}",
error=str(e),
session_id=session_id,
)

View File

@@ -42,6 +42,23 @@ def _patch_workspace():
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
def _patch_credit_db(
get_credits_return: int = 100,
spend_credits_side_effect: Any = None,
):
"""Patch credit_db to return a mock credit accessor."""
mock_credit = MagicMock()
mock_credit.get_credits = AsyncMock(return_value=get_credits_return)
if spend_credits_side_effect is not None:
mock_credit.spend_credits = AsyncMock(side_effect=spend_credits_side_effect)
else:
mock_credit.spend_credits = AsyncMock()
return (
patch("backend.copilot.tools.helpers.credit_db", return_value=mock_credit),
mock_credit,
)
# ---------------------------------------------------------------------------
# Credit charging tests
# ---------------------------------------------------------------------------
@@ -52,7 +69,7 @@ class TestExecuteBlockCreditCharging:
async def test_charges_credits_when_cost_is_positive(self):
"""Block with cost > 0 should call spend_credits after execution."""
block = _make_block()
mock_spend = AsyncMock()
credit_patch, mock_credit = _patch_credit_db(get_credits_return=100)
with (
_patch_workspace(),
@@ -60,16 +77,7 @@ class TestExecuteBlockCreditCharging:
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {"key": "val"}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=100,
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=mock_spend,
),
credit_patch,
):
result = await execute_block(
block=block,
@@ -83,14 +91,15 @@ class TestExecuteBlockCreditCharging:
assert isinstance(result, BlockOutputResponse)
assert result.success is True
mock_spend.assert_awaited_once()
call_kwargs = mock_spend.call_args.kwargs
mock_credit.spend_credits.assert_awaited_once()
call_kwargs = mock_credit.spend_credits.call_args.kwargs
assert call_kwargs["cost"] == 10
assert call_kwargs["metadata"].reason == "copilot_block_execution"
async def test_returns_error_when_insufficient_credits_before_exec(self):
"""Pre-execution check should return ErrorResponse when balance < cost."""
block = _make_block()
credit_patch, mock_credit = _patch_credit_db(get_credits_return=5)
with (
_patch_workspace(),
@@ -98,11 +107,7 @@ class TestExecuteBlockCreditCharging:
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=5, # balance < cost (10)
),
credit_patch,
):
result = await execute_block(
block=block,
@@ -120,6 +125,7 @@ class TestExecuteBlockCreditCharging:
async def test_no_charge_when_cost_is_zero(self):
"""Block with cost 0 should not call spend_credits."""
block = _make_block()
credit_patch, mock_credit = _patch_credit_db()
with (
_patch_workspace(),
@@ -127,12 +133,7 @@ class TestExecuteBlockCreditCharging:
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(0, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
) as mock_get_credits,
patch(
"backend.copilot.tools.helpers._spend_credits",
) as mock_spend_credits,
credit_patch,
):
result = await execute_block(
block=block,
@@ -147,14 +148,20 @@ class TestExecuteBlockCreditCharging:
assert isinstance(result, BlockOutputResponse)
assert result.success is True
# Credit functions should not be called at all for zero-cost blocks
mock_get_credits.assert_not_awaited()
mock_spend_credits.assert_not_awaited()
mock_credit.get_credits.assert_not_awaited()
mock_credit.spend_credits.assert_not_awaited()
async def test_returns_output_on_post_exec_insufficient_balance(self):
"""If charging fails after execution, output is still returned (block already ran)."""
from backend.util.exceptions import InsufficientBalanceError
block = _make_block()
credit_patch, mock_credit = _patch_credit_db(
get_credits_return=15,
spend_credits_side_effect=InsufficientBalanceError(
"Low balance", _USER, 5, 10
),
)
with (
_patch_workspace(),
@@ -162,18 +169,7 @@ class TestExecuteBlockCreditCharging:
"backend.copilot.tools.helpers.block_usage_cost",
return_value=(10, {}),
),
patch(
"backend.copilot.tools.helpers._get_credits",
new_callable=AsyncMock,
return_value=15, # passes pre-check
),
patch(
"backend.copilot.tools.helpers._spend_credits",
new_callable=AsyncMock,
side_effect=InsufficientBalanceError(
"Low balance", _USER, 5, 10
), # fails during actual charge (race with concurrent spend)
),
credit_patch,
):
result = await execute_block(
block=block,
@@ -238,7 +234,7 @@ _TEST_USER_ID = "test-user-coerce"
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_nested_list():
"""JSON string list[list[str]] (Google Sheets CSV import case)."""
"""JSON string -> list[list[str]] (Google Sheets CSV import case)."""
block = _make_coerce_block(
"sheets-write",
"Google Sheets Write",
@@ -280,7 +276,7 @@ async def test_coerce_json_string_to_nested_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_list():
"""JSON string list[str]."""
"""JSON string -> list[str]."""
block = _make_coerce_block(
"list-block",
"List Block",
@@ -312,7 +308,7 @@ async def test_coerce_json_string_to_list():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_json_string_to_dict():
"""JSON string dict[str, str]."""
"""JSON string -> dict[str, str]."""
block = _make_coerce_block(
"dict-block",
"Dict Block",
@@ -378,7 +374,7 @@ async def test_no_coercion_when_type_matches():
@pytest.mark.asyncio(loop_scope="session")
async def test_coerce_string_to_int():
"""String number int."""
"""String number -> int."""
block = _make_coerce_block(
"int-block",
"Int Block",

View File

@@ -1220,6 +1220,20 @@ async def get_user_credit_model(user_id: str) -> UserCreditBase:
return BetaUserCredit(settings.config.num_user_credits_refill)
async def get_credits(user_id: str) -> int:
"""Get the current credit balance for a user."""
credit_model = await get_user_credit_model(user_id)
return await credit_model.get_credits(user_id)
async def spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
"""Spend credits for a user."""
credit_model = await get_user_credit_model(user_id)
return await credit_model.spend_credits(user_id, cost, metadata)
def get_block_costs() -> dict[str, list["BlockCost"]]:
return {block().id: costs for block, costs in BLOCK_COSTS.items()}

View File

@@ -129,3 +129,16 @@ def review_db():
review_db = get_database_manager_async_client()
return review_db
def credit_db():
if db.is_connected():
from backend.data import credit as _credit_db
credit_db = _credit_db
else:
from backend.util.clients import get_database_manager_async_client
credit_db = get_database_manager_async_client()
return credit_db