mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(backend/copilot): use credit_db() adapter pattern for credit operations
Replace inline _get_credits/_spend_credits helpers with the centralized credit_db() accessor from db_accessors.py, consistent with workspace_db(), chat_db(), and other DB accessors. Add module-level get_credits() and spend_credits() to credit.py so the accessor can return the module directly.
This commit is contained in:
@@ -8,9 +8,8 @@ from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.constants import COPILOT_NODE_PREFIX, COPILOT_SESSION_PREFIX
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.db_accessors import credit_db, workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.utils import block_usage_cost
|
||||
@@ -24,36 +23,6 @@ from .utils import match_credentials_to_requirements
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_credits(user_id: str) -> int:
|
||||
"""Get user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||
if db.is_connected():
|
||||
from backend.data.credit import get_user_credit_model
|
||||
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.get_credits(user_id)
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
return await get_database_manager_async_client().get_credits(user_id)
|
||||
|
||||
|
||||
async def _spend_credits(
|
||||
user_id: str, cost: int, metadata: UsageTransactionMetadata
|
||||
) -> int:
|
||||
"""Spend user credits using the adapter pattern (RPC when Prisma unavailable)."""
|
||||
if db.is_connected():
|
||||
from backend.data.credit import get_user_credit_model
|
||||
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.spend_credits(user_id, cost, metadata)
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
return await get_database_manager_async_client().spend_credits(
|
||||
user_id, cost, metadata
|
||||
)
|
||||
|
||||
|
||||
def get_inputs_from_schema(
|
||||
input_schema: dict[str, Any],
|
||||
exclude_fields: set[str] | None = None,
|
||||
@@ -152,7 +121,7 @@ async def execute_block(
|
||||
cost, cost_filter = block_usage_cost(block, input_data)
|
||||
has_cost = cost > 0
|
||||
if has_cost:
|
||||
balance = await _get_credits(user_id)
|
||||
balance = await credit_db().get_credits(user_id)
|
||||
if balance < cost:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -173,7 +142,7 @@ async def execute_block(
|
||||
# Charge credits for block execution
|
||||
if has_cost:
|
||||
try:
|
||||
await _spend_credits(
|
||||
await credit_db().spend_credits(
|
||||
user_id=user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -216,7 +185,7 @@ async def execute_block(
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error executing block: %s", e, exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {e}",
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -42,6 +42,23 @@ def _patch_workspace():
|
||||
return patch("backend.copilot.tools.helpers.workspace_db", return_value=mock_ws_db)
|
||||
|
||||
|
||||
def _patch_credit_db(
|
||||
get_credits_return: int = 100,
|
||||
spend_credits_side_effect: Any = None,
|
||||
):
|
||||
"""Patch credit_db to return a mock credit accessor."""
|
||||
mock_credit = MagicMock()
|
||||
mock_credit.get_credits = AsyncMock(return_value=get_credits_return)
|
||||
if spend_credits_side_effect is not None:
|
||||
mock_credit.spend_credits = AsyncMock(side_effect=spend_credits_side_effect)
|
||||
else:
|
||||
mock_credit.spend_credits = AsyncMock()
|
||||
return (
|
||||
patch("backend.copilot.tools.helpers.credit_db", return_value=mock_credit),
|
||||
mock_credit,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Credit charging tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -52,7 +69,7 @@ class TestExecuteBlockCreditCharging:
|
||||
async def test_charges_credits_when_cost_is_positive(self):
|
||||
"""Block with cost > 0 should call spend_credits after execution."""
|
||||
block = _make_block()
|
||||
mock_spend = AsyncMock()
|
||||
credit_patch, mock_credit = _patch_credit_db(get_credits_return=100)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
@@ -60,16 +77,7 @@ class TestExecuteBlockCreditCharging:
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {"key": "val"}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=100,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=mock_spend,
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
@@ -83,14 +91,15 @@ class TestExecuteBlockCreditCharging:
|
||||
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
mock_spend.assert_awaited_once()
|
||||
call_kwargs = mock_spend.call_args.kwargs
|
||||
mock_credit.spend_credits.assert_awaited_once()
|
||||
call_kwargs = mock_credit.spend_credits.call_args.kwargs
|
||||
assert call_kwargs["cost"] == 10
|
||||
assert call_kwargs["metadata"].reason == "copilot_block_execution"
|
||||
|
||||
async def test_returns_error_when_insufficient_credits_before_exec(self):
|
||||
"""Pre-execution check should return ErrorResponse when balance < cost."""
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db(get_credits_return=5)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
@@ -98,11 +107,7 @@ class TestExecuteBlockCreditCharging:
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=5, # balance < cost (10)
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
@@ -120,6 +125,7 @@ class TestExecuteBlockCreditCharging:
|
||||
async def test_no_charge_when_cost_is_zero(self):
|
||||
"""Block with cost 0 should not call spend_credits."""
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db()
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
@@ -127,12 +133,7 @@ class TestExecuteBlockCreditCharging:
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(0, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
) as mock_get_credits,
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
) as mock_spend_credits,
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
@@ -147,14 +148,20 @@ class TestExecuteBlockCreditCharging:
|
||||
assert isinstance(result, BlockOutputResponse)
|
||||
assert result.success is True
|
||||
# Credit functions should not be called at all for zero-cost blocks
|
||||
mock_get_credits.assert_not_awaited()
|
||||
mock_spend_credits.assert_not_awaited()
|
||||
mock_credit.get_credits.assert_not_awaited()
|
||||
mock_credit.spend_credits.assert_not_awaited()
|
||||
|
||||
async def test_returns_output_on_post_exec_insufficient_balance(self):
|
||||
"""If charging fails after execution, output is still returned (block already ran)."""
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
block = _make_block()
|
||||
credit_patch, mock_credit = _patch_credit_db(
|
||||
get_credits_return=15,
|
||||
spend_credits_side_effect=InsufficientBalanceError(
|
||||
"Low balance", _USER, 5, 10
|
||||
),
|
||||
)
|
||||
|
||||
with (
|
||||
_patch_workspace(),
|
||||
@@ -162,18 +169,7 @@ class TestExecuteBlockCreditCharging:
|
||||
"backend.copilot.tools.helpers.block_usage_cost",
|
||||
return_value=(10, {}),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._get_credits",
|
||||
new_callable=AsyncMock,
|
||||
return_value=15, # passes pre-check
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.helpers._spend_credits",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InsufficientBalanceError(
|
||||
"Low balance", _USER, 5, 10
|
||||
), # fails during actual charge (race with concurrent spend)
|
||||
),
|
||||
credit_patch,
|
||||
):
|
||||
result = await execute_block(
|
||||
block=block,
|
||||
@@ -238,7 +234,7 @@ _TEST_USER_ID = "test-user-coerce"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_nested_list():
|
||||
"""JSON string → list[list[str]] (Google Sheets CSV import case)."""
|
||||
"""JSON string -> list[list[str]] (Google Sheets CSV import case)."""
|
||||
block = _make_coerce_block(
|
||||
"sheets-write",
|
||||
"Google Sheets Write",
|
||||
@@ -280,7 +276,7 @@ async def test_coerce_json_string_to_nested_list():
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_list():
|
||||
"""JSON string → list[str]."""
|
||||
"""JSON string -> list[str]."""
|
||||
block = _make_coerce_block(
|
||||
"list-block",
|
||||
"List Block",
|
||||
@@ -312,7 +308,7 @@ async def test_coerce_json_string_to_list():
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_json_string_to_dict():
|
||||
"""JSON string → dict[str, str]."""
|
||||
"""JSON string -> dict[str, str]."""
|
||||
block = _make_coerce_block(
|
||||
"dict-block",
|
||||
"Dict Block",
|
||||
@@ -378,7 +374,7 @@ async def test_no_coercion_when_type_matches():
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_coerce_string_to_int():
|
||||
"""String number → int."""
|
||||
"""String number -> int."""
|
||||
block = _make_coerce_block(
|
||||
"int-block",
|
||||
"Int Block",
|
||||
|
||||
@@ -1220,6 +1220,20 @@ async def get_user_credit_model(user_id: str) -> UserCreditBase:
|
||||
return BetaUserCredit(settings.config.num_user_credits_refill)
|
||||
|
||||
|
||||
async def get_credits(user_id: str) -> int:
|
||||
"""Get the current credit balance for a user."""
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
async def spend_credits(
|
||||
user_id: str, cost: int, metadata: UsageTransactionMetadata
|
||||
) -> int:
|
||||
"""Spend credits for a user."""
|
||||
credit_model = await get_user_credit_model(user_id)
|
||||
return await credit_model.spend_credits(user_id, cost, metadata)
|
||||
|
||||
|
||||
def get_block_costs() -> dict[str, list["BlockCost"]]:
|
||||
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
|
||||
|
||||
|
||||
@@ -129,3 +129,16 @@ def review_db():
|
||||
review_db = get_database_manager_async_client()
|
||||
|
||||
return review_db
|
||||
|
||||
|
||||
def credit_db():
|
||||
if db.is_connected():
|
||||
from backend.data import credit as _credit_db
|
||||
|
||||
credit_db = _credit_db
|
||||
else:
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
credit_db = get_database_manager_async_client()
|
||||
|
||||
return credit_db
|
||||
|
||||
Reference in New Issue
Block a user