fix linter errors

This commit is contained in:
Swifty
2026-01-08 13:04:02 +01:00
parent fc25e008b3
commit 6686de1701
17 changed files with 572 additions and 509 deletions

View File

@@ -5,6 +5,8 @@ from os import getenv
import pytest
from pydantic import SecretStr
from prisma.types import ProfileCreateInput
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
@@ -49,13 +51,13 @@ async def setup_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile",
links=[], # Required field - empty array for test profiles
)
)
# 2. Create a test graph with agent input -> agent output
@@ -172,13 +174,13 @@ async def setup_llm_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for LLM tests",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for LLM tests",
links=[], # Required field - empty array for test profiles
)
)
# 2. Create test OpenAI credentials for the user
@@ -332,13 +334,13 @@ async def setup_firecrawl_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for Firecrawl tests",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for Firecrawl tests",
links=[], # Required field - empty array for test profiles
)
)
# NOTE: We deliberately do NOT create Firecrawl credentials for this user

View File

@@ -817,18 +817,16 @@ async def add_store_agent_to_library(
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"User": {"connect": {"id": user_id}},
"AgentGraph": {
data=prisma.types.LibraryAgentCreateInput(
User={"connect": {"id": user_id}},
AgentGraph={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
"isCreatedByUser": False,
"settings": SafeJson(
_initialize_graph_settings(graph_model).model_dump()
),
},
isCreatedByUser=False,
settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()),
),
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
),

View File

@@ -27,6 +27,13 @@ from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.models import User as PrismaUser
from prisma.types import (
OAuthAccessTokenCreateInput,
OAuthApplicationCreateInput,
OAuthAuthorizationCodeCreateInput,
OAuthRefreshTokenCreateInput,
UserCreateInput,
)
from backend.api.rest_api import app
@@ -48,11 +55,11 @@ def test_user_id() -> str:
async def test_user(server, test_user_id: str):
"""Create a test user in the database."""
await PrismaUser.prisma().create(
data={
"id": test_user_id,
"email": f"oauth-test-{test_user_id}@example.com",
"name": "OAuth Test User",
}
data=UserCreateInput(
id=test_user_id,
email=f"oauth-test-{test_user_id}@example.com",
name="OAuth Test User",
)
)
yield test_user_id
@@ -77,22 +84,22 @@ async def test_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create(
data={
"id": app_id,
"name": "Test OAuth App",
"description": "Test application for integration tests",
"clientId": client_id,
"clientSecret": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"redirectUris": [
data=OAuthApplicationCreateInput(
id=app_id,
name="Test OAuth App",
description="Test application for integration tests",
clientId=client_id,
clientSecret=client_secret_hash,
clientSecretSalt=client_secret_salt,
redirectUris=[
"https://example.com/callback",
"http://localhost:3000/callback",
],
"grantTypes": ["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
"ownerId": test_user,
"isActive": True,
}
grantTypes=["authorization_code", "refresh_token"],
scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
ownerId=test_user,
isActive=True,
)
)
yield {
@@ -296,19 +303,19 @@ async def inactive_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create(
data={
"id": app_id,
"name": "Inactive OAuth App",
"description": "Inactive test application",
"clientId": client_id,
"clientSecret": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"redirectUris": ["https://example.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user,
"isActive": False, # Inactive!
}
data=OAuthApplicationCreateInput(
id=app_id,
name="Inactive OAuth App",
description="Inactive test application",
clientId=client_id,
clientSecret=client_secret_hash,
clientSecretSalt=client_secret_salt,
redirectUris=["https://example.com/callback"],
grantTypes=["authorization_code", "refresh_token"],
scopes=[APIKeyPermission.EXECUTE_GRAPH],
ownerId=test_user,
isActive=False, # Inactive!
)
)
yield {
@@ -699,14 +706,14 @@ async def test_token_authorization_code_expired(
now = datetime.now(timezone.utc)
await PrismaOAuthAuthorizationCode.prisma().create(
data={
"code": expired_code,
"applicationId": test_oauth_app["id"],
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"redirectUri": test_oauth_app["redirect_uri"],
"expiresAt": now - timedelta(hours=1), # Already expired
}
data=OAuthAuthorizationCodeCreateInput(
code=expired_code,
applicationId=test_oauth_app["id"],
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
redirectUri=test_oauth_app["redirect_uri"],
expiresAt=now - timedelta(hours=1), # Already expired
)
)
response = await client.post(
@@ -942,13 +949,13 @@ async def test_token_refresh_expired(
now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create(
data={
"token": expired_token_hash,
"applicationId": test_oauth_app["id"],
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now - timedelta(days=1), # Already expired
}
data=OAuthRefreshTokenCreateInput(
token=expired_token_hash,
applicationId=test_oauth_app["id"],
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
expiresAt=now - timedelta(days=1), # Already expired
)
)
response = await client.post(
@@ -980,14 +987,14 @@ async def test_token_refresh_revoked(
now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create(
data={
"token": revoked_token_hash,
"applicationId": test_oauth_app["id"],
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(days=30), # Not expired
"revokedAt": now - timedelta(hours=1), # But revoked
}
data=OAuthRefreshTokenCreateInput(
token=revoked_token_hash,
applicationId=test_oauth_app["id"],
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
expiresAt=now + timedelta(days=30), # Not expired
revokedAt=now - timedelta(hours=1), # But revoked
)
)
response = await client.post(
@@ -1013,19 +1020,19 @@ async def other_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create(
data={
"id": app_id,
"name": "Other OAuth App",
"description": "Second test application",
"clientId": client_id,
"clientSecret": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"redirectUris": ["https://other.example.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user,
"isActive": True,
}
data=OAuthApplicationCreateInput(
id=app_id,
name="Other OAuth App",
description="Second test application",
clientId=client_id,
clientSecret=client_secret_hash,
clientSecretSalt=client_secret_salt,
redirectUris=["https://other.example.com/callback"],
grantTypes=["authorization_code", "refresh_token"],
scopes=[APIKeyPermission.EXECUTE_GRAPH],
ownerId=test_user,
isActive=True,
)
)
yield {
@@ -1052,13 +1059,13 @@ async def test_token_refresh_wrong_application(
now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create(
data={
"token": token_hash,
"applicationId": test_oauth_app["id"], # Belongs to test_oauth_app
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(days=30),
}
data=OAuthRefreshTokenCreateInput(
token=token_hash,
applicationId=test_oauth_app["id"], # Belongs to test_oauth_app
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
expiresAt=now + timedelta(days=30),
)
)
# Try to use it with `other_oauth_app`
@@ -1267,19 +1274,19 @@ async def test_validate_access_token_fails_when_app_disabled(
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create(
data={
"id": app_id,
"name": "App To Be Disabled",
"description": "Test app for disabled validation",
"clientId": client_id,
"clientSecret": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"redirectUris": ["https://example.com/callback"],
"grantTypes": ["authorization_code"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user,
"isActive": True,
}
data=OAuthApplicationCreateInput(
id=app_id,
name="App To Be Disabled",
description="Test app for disabled validation",
clientId=client_id,
clientSecret=client_secret_hash,
clientSecretSalt=client_secret_salt,
redirectUris=["https://example.com/callback"],
grantTypes=["authorization_code"],
scopes=[APIKeyPermission.EXECUTE_GRAPH],
ownerId=test_user,
isActive=True,
)
)
# Create an access token directly in the database
@@ -1288,13 +1295,13 @@ async def test_validate_access_token_fails_when_app_disabled(
now = datetime.now(timezone.utc)
await PrismaOAuthAccessToken.prisma().create(
data={
"token": token_hash,
"applicationId": app_id,
"userId": test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(hours=1),
}
data=OAuthAccessTokenCreateInput(
token=token_hash,
applicationId=app_id,
userId=test_user,
scopes=[APIKeyPermission.EXECUTE_GRAPH],
expiresAt=now + timedelta(hours=1),
)
)
# Token should be valid while app is active
@@ -1561,19 +1568,19 @@ async def test_revoke_token_from_different_app_fails_silently(
)
await PrismaOAuthApplication.prisma().create(
data={
"id": app2_id,
"name": "Second Test OAuth App",
"description": "Second test application for cross-app revocation test",
"clientId": app2_client_id,
"clientSecret": app2_client_secret_hash,
"clientSecretSalt": app2_client_secret_salt,
"redirectUris": ["https://other-app.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
"ownerId": test_user,
"isActive": True,
}
data=OAuthApplicationCreateInput(
id=app2_id,
name="Second Test OAuth App",
description="Second test application for cross-app revocation test",
clientId=app2_client_id,
clientSecret=app2_client_secret_hash,
clientSecretSalt=app2_client_secret_salt,
redirectUris=["https://other-app.com/callback"],
grantTypes=["authorization_code", "refresh_token"],
scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
ownerId=test_user,
isActive=True,
)
)
# App 2 tries to revoke App 1's access token

View File

@@ -249,7 +249,9 @@ async def log_search_term(search_query: str):
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
try:
await prisma.models.SearchTerms.prisma().create(
data={"searchTerm": search_query, "createdDate": date}
data=prisma.types.SearchTermsCreateInput(
searchTerm=search_query, createdDate=date
)
)
except Exception as e:
# Fail silently here so that logging search terms doesn't break the app
@@ -1456,11 +1458,9 @@ async def _approve_sub_agent(
# Create new version if no matching version found
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
await prisma.models.StoreListingVersion.prisma(tx).create(
data={
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
"version": next_version,
"storeListingId": listing.id,
}
data=_create_sub_agent_version_data(
sub_graph, heading, main_agent_name, next_version, listing.id
)
)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": listing.id}, data={"hasApprovedVersion": True}
@@ -1468,10 +1468,14 @@ async def _approve_sub_agent(
def _create_sub_agent_version_data(
sub_graph: prisma.models.AgentGraph, heading: str, main_agent_name: str
sub_graph: prisma.models.AgentGraph,
heading: str,
main_agent_name: str,
version: typing.Optional[int] = None,
store_listing_id: typing.Optional[str] = None,
) -> prisma.types.StoreListingVersionCreateInput:
"""Create store listing version data for a sub-agent"""
return prisma.types.StoreListingVersionCreateInput(
data = prisma.types.StoreListingVersionCreateInput(
agentGraphId=sub_graph.id,
agentGraphVersion=sub_graph.version,
name=sub_graph.name or heading,
@@ -1486,6 +1490,11 @@ def _create_sub_agent_version_data(
imageUrls=[], # Sub-agents don't need images
categories=[], # Sub-agents don't need categories
)
if version is not None:
data["version"] = version
if store_listing_id is not None:
data["storeListingId"] = store_listing_id
return data
async def review_store_submission(

View File

@@ -42,6 +42,7 @@ from urllib.parse import urlparse
import click
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission
from prisma.types import OAuthApplicationCreateInput
keysmith = APIKeySmith()
@@ -147,7 +148,7 @@ def format_sql_insert(creds: dict) -> str:
sql = f"""
-- ============================================================
-- OAuth Application: {creds['name']}
-- OAuth Application: {creds["name"]}
-- Generated: {now_iso} UTC
-- ============================================================
@@ -167,14 +168,14 @@ INSERT INTO "OAuthApplication" (
"isActive"
)
VALUES (
'{creds['id']}',
'{creds["id"]}',
NOW(),
NOW(),
'{creds['name']}',
{f"'{creds['description']}'" if creds['description'] else 'NULL'},
'{creds['client_id']}',
'{creds['client_secret_hash']}',
'{creds['client_secret_salt']}',
'{creds["name"]}',
{f"'{creds['description']}'" if creds["description"] else "NULL"},
'{creds["client_id"]}',
'{creds["client_secret_hash"]}',
'{creds["client_secret_salt"]}',
ARRAY{redirect_uris_pg}::TEXT[],
ARRAY{grant_types_pg}::TEXT[],
ARRAY{scopes_pg}::"APIKeyPermission"[],
@@ -186,8 +187,8 @@ VALUES (
-- ⚠️ IMPORTANT: Save these credentials securely!
-- ============================================================
--
-- Client ID: {creds['client_id']}
-- Client Secret: {creds['client_secret_plaintext']}
-- Client ID: {creds["client_id"]}
-- Client Secret: {creds["client_secret_plaintext"]}
--
-- ⚠️ The client secret is shown ONLY ONCE!
-- ⚠️ Store it securely and share only with the application developer.
@@ -200,7 +201,7 @@ VALUES (
-- To verify the application was created:
-- SELECT "clientId", name, scopes, "redirectUris", "isActive"
-- FROM "OAuthApplication"
-- WHERE "clientId" = '{creds['client_id']}';
-- WHERE "clientId" = '{creds["client_id"]}';
"""
return sql
@@ -834,19 +835,19 @@ async def create_test_app_in_db(
# Insert into database
app = await OAuthApplication.prisma().create(
data={
"id": creds["id"],
"name": creds["name"],
"description": creds["description"],
"clientId": creds["client_id"],
"clientSecret": creds["client_secret_hash"],
"clientSecretSalt": creds["client_secret_salt"],
"redirectUris": creds["redirect_uris"],
"grantTypes": creds["grant_types"],
"scopes": creds["scopes"],
"ownerId": owner_id,
"isActive": True,
}
data=OAuthApplicationCreateInput(
id=creds["id"],
name=creds["name"],
description=creds["description"],
clientId=creds["client_id"],
clientSecret=creds["client_secret_hash"],
clientSecretSalt=creds["client_secret_salt"],
redirectUris=creds["redirect_uris"],
grantTypes=creds["grant_types"],
scopes=creds["scopes"],
ownerId=owner_id,
isActive=True,
)
)
click.echo(f"✓ Created test OAuth application: {app.clientId}")

View File

@@ -6,7 +6,7 @@ from typing import Literal, Optional
from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
from pydantic import Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH
@@ -82,17 +82,17 @@ async def create_api_key(
generated_key = keysmith.generate_key()
saved_key_obj = await PrismaAPIKey.prisma().create(
data={
"id": str(uuid.uuid4()),
"name": name,
"head": generated_key.head,
"tail": generated_key.tail,
"hash": generated_key.hash,
"salt": generated_key.salt,
"permissions": [p for p in permissions],
"description": description,
"userId": user_id,
}
data=APIKeyCreateInput(
id=str(uuid.uuid4()),
name=name,
head=generated_key.head,
tail=generated_key.tail,
hash=generated_key.hash,
salt=generated_key.salt,
permissions=[p for p in permissions],
description=description,
userId=user_id,
)
)
return APIKeyInfo.from_db(saved_key_obj), generated_key.key

View File

@@ -22,7 +22,12 @@ from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.types import OAuthApplicationUpdateInput
from prisma.types import (
OAuthAccessTokenCreateInput,
OAuthApplicationUpdateInput,
OAuthAuthorizationCodeCreateInput,
OAuthRefreshTokenCreateInput,
)
from pydantic import BaseModel, Field, SecretStr
from .base import APIAuthorizationInfo
@@ -359,17 +364,17 @@ async def create_authorization_code(
expires_at = now + AUTHORIZATION_CODE_TTL
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
data={
"id": str(uuid.uuid4()),
"code": code,
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
"redirectUri": redirect_uri,
"codeChallenge": code_challenge,
"codeChallengeMethod": code_challenge_method,
}
data=OAuthAuthorizationCodeCreateInput(
id=str(uuid.uuid4()),
code=code,
expiresAt=expires_at,
applicationId=application_id,
userId=user_id,
scopes=[s for s in scopes],
redirectUri=redirect_uri,
codeChallenge=code_challenge,
codeChallengeMethod=code_challenge_method,
)
)
return OAuthAuthorizationCodeInfo.from_db(saved_code)
@@ -490,14 +495,14 @@ async def create_access_token(
expires_at = now + ACCESS_TOKEN_TTL
saved_token = await PrismaOAuthAccessToken.prisma().create(
data={
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
}
data=OAuthAccessTokenCreateInput(
id=str(uuid.uuid4()),
token=token_hash, # SHA256 hash for direct lookup
expiresAt=expires_at,
applicationId=application_id,
userId=user_id,
scopes=[s for s in scopes],
)
)
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
@@ -607,14 +612,14 @@ async def create_refresh_token(
expires_at = now + REFRESH_TOKEN_TTL
saved_token = await PrismaOAuthRefreshToken.prisma().create(
data={
"id": str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at,
"applicationId": application_id,
"userId": user_id,
"scopes": [s for s in scopes],
}
data=OAuthRefreshTokenCreateInput(
id=str(uuid.uuid4()),
token=token_hash, # SHA256 hash for direct lookup
expiresAt=expires_at,
applicationId=application_id,
userId=user_id,
scopes=[s for s in scopes],
)
)
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)

View File

@@ -11,6 +11,11 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import (
UserBalanceCreateInput,
UserBalanceUpsertInput,
UserCreateInput,
)
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
@@ -21,11 +26,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -33,7 +38,10 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)
@@ -107,15 +115,15 @@ async def test_ceiling_balance_clamps_when_would_exceed(server: SpinTestServer):
)
# Balance should be clamped to ceiling
assert (
final_balance == 1000
), f"Balance should be clamped to 1000, got {final_balance}"
assert final_balance == 1000, (
f"Balance should be clamped to 1000, got {final_balance}"
)
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == 1000
), f"Stored balance should be 1000, got {stored_balance}"
assert stored_balance == 1000, (
f"Stored balance should be 1000, got {stored_balance}"
)
# Verify transaction shows the clamped amount
transactions = await CreditTransaction.prisma().find_many(
@@ -164,9 +172,9 @@ async def test_ceiling_balance_allows_when_under_threshold(server: SpinTestServe
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == 500
), f"Stored balance should be 500, got {stored_balance}"
assert stored_balance == 500, (
f"Stored balance should be 500, got {stored_balance}"
)
finally:
await cleanup_test_user(user_id)

View File

@@ -14,6 +14,11 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import (
UserBalanceCreateInput,
UserBalanceUpsertInput,
UserCreateInput,
)
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError
@@ -28,11 +33,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -41,7 +46,10 @@ async def create_test_user(user_id: str) -> None:
# Ensure UserBalance record exists
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)
@@ -108,9 +116,9 @@ async def test_concurrent_spends_same_user(server: SpinTestServer):
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE}
)
assert (
len(transactions) == 10
), f"Expected 10 transactions, got {len(transactions)}"
assert len(transactions) == 10, (
f"Expected 10 transactions, got {len(transactions)}"
)
finally:
await cleanup_test_user(user_id)
@@ -321,9 +329,9 @@ async def test_onboarding_reward_idempotency(server: SpinTestServer):
"transactionKey": f"REWARD-{user_id}-WELCOME",
}
)
assert (
len(transactions) == 1
), f"Expected 1 reward transaction, got {len(transactions)}"
assert len(transactions) == 1, (
f"Expected 1 reward transaction, got {len(transactions)}"
)
finally:
await cleanup_test_user(user_id)
@@ -342,10 +350,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
# First, set balance near max
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": max_int - 100},
"update": {"balance": max_int - 100},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100),
update={"balance": max_int - 100},
),
)
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
@@ -358,9 +366,9 @@ async def test_integer_overflow_protection(server: SpinTestServer):
# Balance should be clamped to max_int, not overflowed
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == max_int
), f"Balance should be clamped to {max_int}, got {final_balance}"
assert final_balance == max_int, (
f"Balance should be clamped to {max_int}, got {final_balance}"
)
# Verify transaction was created with clamped amount
transactions = await CreditTransaction.prisma().find_many(
@@ -371,9 +379,9 @@ async def test_integer_overflow_protection(server: SpinTestServer):
order={"createdAt": "desc"},
)
assert len(transactions) > 0, "Transaction should be created"
assert (
transactions[0].runningBalance == max_int
), "Transaction should show clamped balance"
assert transactions[0].runningBalance == max_int, (
"Transaction should show clamped balance"
)
finally:
await cleanup_test_user(user_id)
@@ -432,9 +440,9 @@ async def test_high_concurrency_stress(server: SpinTestServer):
# Verify final balance
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == expected_balance
), f"Expected {expected_balance}, got {final_balance}"
assert final_balance == expected_balance, (
f"Expected {expected_balance}, got {final_balance}"
)
assert final_balance >= 0, "Balance went negative!"
finally:
@@ -507,7 +515,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
print("\nExecution order by start time:")
for i, (label, timing) in enumerate(sorted_timings):
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
print(f" {i + 1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
# Check for overlap (true concurrency) vs serialization
overlaps = []
@@ -533,9 +541,9 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
print(f"Successful: {len(successful)}, Failed: {len(failed)}")
# All should succeed since 150 - (10 + 20 + 30) = 90 > 0
assert (
len(successful) == 3
), f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
assert len(successful) == 3, (
f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
)
assert final_balance == 90, f"Expected balance 90, got {final_balance}"
# Check transaction timestamps to confirm database-level serialization
@@ -546,7 +554,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
print("\nDatabase transaction order (by createdAt):")
for i, tx in enumerate(transactions):
print(
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
f" {i + 1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
)
# Verify running balances are chronologically consistent (ordered by createdAt)
@@ -575,38 +583,38 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
# Verify all balances are valid intermediate states
for balance in actual_balances:
assert (
balance in expected_possible_balances
), f"Invalid balance {balance}, expected one of {expected_possible_balances}"
assert balance in expected_possible_balances, (
f"Invalid balance {balance}, expected one of {expected_possible_balances}"
)
# Final balance should always be 90 (150 - 60)
assert (
min(actual_balances) == 90
), f"Final balance should be 90, got {min(actual_balances)}"
assert min(actual_balances) == 90, (
f"Final balance should be 90, got {min(actual_balances)}"
)
# The final transaction should always have balance 90
# The other transactions should have valid intermediate balances
assert (
90 in actual_balances
), f"Final balance 90 should be in actual_balances: {actual_balances}"
assert 90 in actual_balances, (
f"Final balance 90 should be in actual_balances: {actual_balances}"
)
# All balances should be >= 90 (the final state)
assert all(
balance >= 90 for balance in actual_balances
), f"All balances should be >= 90, got {actual_balances}"
assert all(balance >= 90 for balance in actual_balances), (
f"All balances should be >= 90, got {actual_balances}"
)
# CRITICAL: Transactions are atomic but can complete in any order
# What matters is that all running balances are valid intermediate states
# Each balance should be between 90 (final) and 140 (after first transaction)
for balance in actual_balances:
assert (
90 <= balance <= 140
), f"Balance {balance} is outside valid range [90, 140]"
assert 90 <= balance <= 140, (
f"Balance {balance} is outside valid range [90, 140]"
)
# Final balance (minimum) should always be 90
assert (
min(actual_balances) == 90
), f"Final balance should be 90, got {min(actual_balances)}"
assert min(actual_balances) == 90, (
f"Final balance should be 90, got {min(actual_balances)}"
)
finally:
await cleanup_test_user(user_id)
@@ -707,7 +715,7 @@ async def test_prove_database_locking_behavior(server: SpinTestServer):
for i, result in enumerate(sorted_results):
print(
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
f" {i + 1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
)
# Check if any operations overlapped at the database level
@@ -722,9 +730,9 @@ async def test_prove_database_locking_behavior(server: SpinTestServer):
print(f"\n💰 Final balance: {final_balance}")
if len(successful) == 3:
assert (
final_balance == 0
), f"If all succeeded, balance should be 0, got {final_balance}"
assert final_balance == 0, (
f"If all succeeded, balance should be 0, got {final_balance}"
)
print(
"✅ CONCLUSION: Database row locking causes requests to WAIT and execute serially"
)

View File

@@ -8,6 +8,7 @@ which would have caught the CreditTransactionType enum casting bug.
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserCreateInput
from backend.data.credit import (
AutoTopUpConfig,
@@ -29,12 +30,12 @@ async def cleanup_test_user():
# Create the user first
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
topUpConfig=SafeJson({}),
timezone="UTC",
)
)
except Exception:
# User might already exist, that's fine

View File

@@ -12,6 +12,12 @@ import pytest
import stripe
from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserCreateInput,
)
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
@@ -35,32 +41,32 @@ async def setup_test_user_with_topup():
# Create user
await User.prisma().create(
data={
"id": REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User",
}
data=UserCreateInput(
id=REFUND_TEST_USER_ID,
email=f"{REFUND_TEST_USER_ID}@example.com",
name="Refund Test User",
)
)
# Create user balance
await UserBalance.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"balance": 1000, # $10
}
data=UserBalanceCreateInput(
userId=REFUND_TEST_USER_ID,
balance=1000, # $10
)
)
# Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 1000,
"type": CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345",
"runningBalance": 1000,
"isActive": True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
}
data=CreditTransactionCreateInput(
userId=REFUND_TEST_USER_ID,
amount=1000,
type=CreditTransactionType.TOP_UP,
transactionKey="pi_test_12345",
runningBalance=1000,
isActive=True,
metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}),
)
)
return topup_tx
@@ -93,12 +99,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
# Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund",
}
data=CreditRefundRequestCreateInput(
userId=REFUND_TEST_USER_ID,
amount=500,
transactionKey=topup_tx.transactionKey, # Should match the original transaction
reason="Test refund",
)
)
# Call deduct_credits
@@ -109,9 +115,9 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 500
), f"Expected balance 500, got {user_balance.balance}"
assert user_balance.balance == 500, (
f"Expected balance 500, got {user_balance.balance}"
)
# Verify refund transaction was created
refund_tx = await CreditTransaction.prisma().find_first(
@@ -205,9 +211,9 @@ async def test_handle_dispute_with_sufficient_balance(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 1000
), f"Balance should remain 1000, got {user_balance.balance}"
assert user_balance.balance == 1000, (
f"Balance should remain 1000, got {user_balance.balance}"
)
finally:
await cleanup_test_user()
@@ -286,12 +292,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
refund_requests = []
for i in range(5):
req = await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 100, # $1 each
"transactionKey": topup_tx.transactionKey,
"reason": f"Test refund {i}",
}
data=CreditRefundRequestCreateInput(
userId=REFUND_TEST_USER_ID,
amount=100, # $1 each
transactionKey=topup_tx.transactionKey,
reason=f"Test refund {i}",
)
)
refund_requests.append(req)
@@ -332,9 +338,9 @@ async def test_concurrent_refunds(server: SpinTestServer):
print(f"DEBUG: Final balance = {user_balance.balance}, expected = 500")
# With atomic implementation, all 5 refunds should process correctly
assert (
user_balance.balance == 500
), f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"
assert user_balance.balance == 500, (
f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"
)
# Verify all refund transactions exist
refund_txs = await CreditTransaction.prisma().find_many(
@@ -343,9 +349,9 @@ async def test_concurrent_refunds(server: SpinTestServer):
"type": CreditTransactionType.REFUND,
}
)
assert (
len(refund_txs) == 5
), f"Expected 5 refund transactions, got {len(refund_txs)}"
assert len(refund_txs) == 5, (
f"Expected 5 refund transactions, got {len(refund_txs)}"
)
running_balances: set[int] = {
tx.runningBalance for tx in refund_txs if tx.runningBalance is not None
@@ -353,20 +359,20 @@ async def test_concurrent_refunds(server: SpinTestServer):
# Verify all balances are valid intermediate states
for balance in running_balances:
assert (
500 <= balance <= 1000
), f"Invalid balance {balance}, should be between 500 and 1000"
assert 500 <= balance <= 1000, (
f"Invalid balance {balance}, should be between 500 and 1000"
)
# Final balance should be present
assert (
500 in running_balances
), f"Final balance 500 should be in {running_balances}"
assert 500 in running_balances, (
f"Final balance 500 should be in {running_balances}"
)
# All balances should be unique and form a valid sequence
sorted_balances = sorted(running_balances, reverse=True)
assert (
len(sorted_balances) == 5
), f"Expected 5 unique balances, got {len(sorted_balances)}"
assert len(sorted_balances) == 5, (
f"Expected 5 unique balances, got {len(sorted_balances)}"
)
finally:
await cleanup_test_user()

View File

@@ -3,6 +3,11 @@ from datetime import datetime, timedelta, timezone
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance
from prisma.types import (
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserBalanceUpsertInput,
)
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
@@ -23,10 +28,10 @@ async def disable_test_user_transactions():
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
"update": {"balance": 0, "updatedAt": old_date},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0),
update={"balance": 0, "updatedAt": old_date},
),
)
@@ -140,23 +145,23 @@ async def test_block_credit_reset(server: SpinTestServer):
# Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": 100,
"type": CreditTransactionType.TOP_UP,
"runningBalance": 1100,
"isActive": True,
"createdAt": month1, # Set specific timestamp
}
data=CreditTransactionCreateInput(
userId=DEFAULT_USER_ID,
amount=100,
type=CreditTransactionType.TOP_UP,
runningBalance=1100,
isActive=True,
createdAt=month1, # Set specific timestamp
)
)
# Update user balance to match
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
"update": {"balance": 1100},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100),
update={"balance": 1100},
),
)
# Now test month 2 behavior
@@ -175,14 +180,14 @@ async def test_block_credit_reset(server: SpinTestServer):
# Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE,
"runningBalance": 400,
"isActive": True,
"createdAt": month2,
}
data=CreditTransactionCreateInput(
userId=DEFAULT_USER_ID,
amount=-700, # Spent 700 to get to 400
type=CreditTransactionType.USAGE,
runningBalance=400,
isActive=True,
createdAt=month2,
)
)
# Move to month 3

View File

@@ -12,6 +12,11 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import (
UserBalanceCreateInput,
UserBalanceUpsertInput,
UserCreateInput,
)
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer
@@ -21,11 +26,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -33,7 +38,10 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
)
@@ -66,14 +74,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
initial_balance_target = POSTGRES_INT_MIN + 100
# Use direct database update to set the balance close to underflow
from prisma.models import UserBalance
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance_target},
"update": {"balance": initial_balance_target},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(
userId=user_id, balance=initial_balance_target
),
update={"balance": initial_balance_target},
),
)
current_balance = await credit_system.get_credits(user_id)
@@ -82,9 +90,7 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
# Test 2: Apply amount that should cause underflow
print("\n=== Test 2: Testing underflow protection ===")
test_amount = (
-200
) # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
test_amount = -200 # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
expected_without_protection = current_balance + test_amount
print(f"Current balance: {current_balance}")
print(f"Test amount: {test_amount}")
@@ -101,19 +107,19 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
print(f"Actual result: {balance_result}")
# Check if underflow protection worked
assert (
balance_result == POSTGRES_INT_MIN
), f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"
assert balance_result == POSTGRES_INT_MIN, (
f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"
)
# Test 3: Edge case - exactly at POSTGRES_INT_MIN
print("\n=== Test 3: Testing exact POSTGRES_INT_MIN boundary ===")
# Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
"update": {"balance": POSTGRES_INT_MIN},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN),
update={"balance": POSTGRES_INT_MIN},
),
)
edge_balance = await credit_system.get_credits(user_id)
@@ -128,9 +134,9 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
)
print(f"After subtracting 1: {edge_result}")
assert (
edge_result == POSTGRES_INT_MIN
), f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"
assert edge_result == POSTGRES_INT_MIN, (
f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"
)
finally:
await cleanup_test_user(user_id)
@@ -147,15 +153,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
# Set up balance close to underflow threshold to test the protection
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
# This should trigger underflow protection
from prisma.models import UserBalance
test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": test_balance},
"update": {"balance": test_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=test_balance),
update={"balance": test_balance},
),
)
current_balance = await credit_system.get_credits(user_id)
@@ -176,18 +180,18 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
)
# Balance should be clamped to POSTGRES_INT_MIN, not the calculated underflow value
assert (
final_balance == POSTGRES_INT_MIN
), f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
assert (
final_balance > expected_without_protection
), f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
assert final_balance == POSTGRES_INT_MIN, (
f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
)
assert final_balance > expected_without_protection, (
f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
)
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == POSTGRES_INT_MIN
), f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
assert stored_balance == POSTGRES_INT_MIN, (
f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
)
# Verify transaction was created with the underflow-protected balance
transactions = await CreditTransaction.prisma().find_many(
@@ -195,9 +199,9 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
order={"createdAt": "desc"},
)
assert len(transactions) > 0, "Refund transaction should be created"
assert (
transactions[0].runningBalance == POSTGRES_INT_MIN
), f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
assert transactions[0].runningBalance == POSTGRES_INT_MIN, (
f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
)
finally:
await cleanup_test_user(user_id)
@@ -212,15 +216,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
update={"balance": initial_balance},
),
)
# Apply multiple refunds that would cumulatively underflow
@@ -238,12 +240,12 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
expected_balance_1 = (
initial_balance + refund_amount
) # Should be POSTGRES_INT_MIN + 200
assert (
balance_1 == expected_balance_1
), f"First refund should result in {expected_balance_1}, got {balance_1}"
assert (
balance_1 >= POSTGRES_INT_MIN
), f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
assert balance_1 == expected_balance_1, (
f"First refund should result in {expected_balance_1}, got {balance_1}"
)
assert balance_1 >= POSTGRES_INT_MIN, (
f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
)
# Second refund: (POSTGRES_INT_MIN + 200) + (-300) = POSTGRES_INT_MIN - 100 (would underflow)
balance_2, _ = await credit_system._add_transaction(
@@ -254,9 +256,9 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
)
# Should be clamped to minimum due to underflow protection
assert (
balance_2 == POSTGRES_INT_MIN
), f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
assert balance_2 == POSTGRES_INT_MIN, (
f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
)
# Third refund: Should stay at minimum
balance_3, _ = await credit_system._add_transaction(
@@ -267,15 +269,15 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
)
# Should still be at minimum
assert (
balance_3 == POSTGRES_INT_MIN
), f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
assert balance_3 == POSTGRES_INT_MIN, (
f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
)
# Final balance check
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == POSTGRES_INT_MIN
), f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
assert final_balance == POSTGRES_INT_MIN, (
f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
)
finally:
await cleanup_test_user(user_id)
@@ -295,10 +297,10 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
update={"balance": initial_balance},
),
)
async def large_refund(amount: int, label: str):
@@ -327,35 +329,35 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
for i, result in enumerate(results):
if isinstance(result, tuple):
balance, _ = result
assert (
balance >= POSTGRES_INT_MIN
), f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
assert balance >= POSTGRES_INT_MIN, (
f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
)
valid_results.append(balance)
elif isinstance(result, str) and "FAILED" in result:
# Some operations might fail due to validation, that's okay
pass
else:
# Unexpected exception
assert not isinstance(
result, Exception
), f"Unexpected exception in result {i}: {result}"
assert not isinstance(result, Exception), (
f"Unexpected exception in result {i}: {result}"
)
# At least one operation should succeed
assert (
len(valid_results) > 0
), f"At least one refund should succeed, got results: {results}"
assert len(valid_results) > 0, (
f"At least one refund should succeed, got results: {results}"
)
# All successful results should be >= POSTGRES_INT_MIN
for balance in valid_results:
assert (
balance >= POSTGRES_INT_MIN
), f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
assert balance >= POSTGRES_INT_MIN, (
f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
)
# Final balance should be valid and at or above POSTGRES_INT_MIN
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance >= POSTGRES_INT_MIN
), f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
assert final_balance >= POSTGRES_INT_MIN, (
f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
)
finally:
await cleanup_test_user(user_id)

View File

@@ -14,6 +14,7 @@ import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserCreateInput
from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
@@ -24,11 +25,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
data=UserCreateInput(
id=user_id,
email=f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}",
)
)
except UniqueViolationError:
# User already exists, continue
@@ -60,9 +61,9 @@ async def test_user_balance_migration_complete(server: SpinTestServer):
# User.balance should not exist or should be None/0 if it exists
user_balance_attr = getattr(user, "balance", None)
if user_balance_attr is not None:
assert (
user_balance_attr == 0 or user_balance_attr is None
), f"User.balance should be 0 or None, got {user_balance_attr}"
assert user_balance_attr == 0 or user_balance_attr is None, (
f"User.balance should be 0 or None, got {user_balance_attr}"
)
# 2. Perform various credit operations using internal method (bypasses Stripe)
await credit_system._add_transaction(
@@ -87,9 +88,9 @@ async def test_user_balance_migration_complete(server: SpinTestServer):
# 3. Verify UserBalance table has correct values
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 700
), f"UserBalance should be 700, got {user_balance.balance}"
assert user_balance.balance == 700, (
f"UserBalance should be 700, got {user_balance.balance}"
)
# 4. CRITICAL: Verify User.balance is NEVER updated during operations
user_after = await User.prisma().find_unique(where={"id": user_id})
@@ -97,15 +98,15 @@ async def test_user_balance_migration_complete(server: SpinTestServer):
user_balance_after = getattr(user_after, "balance", None)
if user_balance_after is not None:
# If User.balance exists, it should still be 0 (never updated)
assert (
user_balance_after == 0 or user_balance_after is None
), f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"
assert user_balance_after == 0 or user_balance_after is None, (
f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"
)
# 5. Verify get_credits always returns UserBalance value, not User.balance
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == user_balance.balance
), f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"
assert final_balance == user_balance.balance, (
f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"
)
finally:
await cleanup_test_user(user_id)
@@ -121,14 +122,14 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
try:
# Create UserBalance with specific value
await UserBalance.prisma().create(
data={"userId": user_id, "balance": 5000} # $50
data=UserBalanceCreateInput(userId=user_id, balance=5000) # $50
)
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
balance = await credit_system.get_credits(user_id)
assert (
balance == 5000
), f"Expected get_credits to return 5000 from UserBalance, got {balance}"
assert balance == 5000, (
f"Expected get_credits to return 5000 from UserBalance, got {balance}"
)
# Verify all operations use UserBalance using internal method (bypasses Stripe)
await credit_system._add_transaction(
@@ -143,9 +144,9 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
# Verify UserBalance table has the correct value
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 6000
), f"UserBalance should be 6000, got {user_balance.balance}"
assert user_balance.balance == 6000, (
f"UserBalance should be 6000, got {user_balance.balance}"
)
finally:
await cleanup_test_user(user_id)
@@ -160,7 +161,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
try:
# Set initial balance in UserBalance
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
await UserBalance.prisma().create(
data=UserBalanceCreateInput(userId=user_id, balance=1000)
)
# Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str):
@@ -196,9 +199,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
# Verify UserBalance has correct value
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 400
), f"UserBalance should be 400, got {user_balance.balance}"
assert user_balance.balance == 400, (
f"UserBalance should be 400, got {user_balance.balance}"
)
# Critical: If User.balance exists and was used, it might have wrong value
try:

View File

@@ -28,6 +28,7 @@ from prisma.models import (
AgentNodeExecutionKeyValueData,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionUpdateManyMutationInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
@@ -709,18 +710,18 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.INCOMPLETE,
"inputs": SafeJson(inputs),
"credentialInputs": (
data=AgentGraphExecutionCreateInput(
agentGraphId=graph_id,
agentGraphVersion=graph_version,
executionStatus=ExecutionStatus.INCOMPLETE,
inputs=SafeJson(inputs),
credentialInputs=(
SafeJson(credential_inputs) if credential_inputs else Json({})
),
"nodesInputMasks": (
nodesInputMasks=(
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
),
"NodeExecutions": {
NodeExecutions={
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
@@ -736,10 +737,10 @@ async def create_graph_execution(
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
"parentGraphExecutionId": parent_graph_exec_id,
},
userId=user_id,
agentPresetId=preset_id,
parentGraphExecutionId=parent_graph_exec_id,
),
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
@@ -831,10 +832,10 @@ async def upsert_execution_output(
"""
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
"""
data: AgentNodeExecutionInputOutputCreateInput = {
"name": output_name,
"referencedByOutputExecId": node_exec_id,
}
data = AgentNodeExecutionInputOutputCreateInput(
name=output_name,
referencedByOutputExecId=node_exec_id,
)
if output_data is not None:
data["data"] = SafeJson(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data)
@@ -974,14 +975,12 @@ async def update_node_execution_status(
f"Invalid status transition: {status} has no valid source statuses"
)
where_clause: Any = {
"id": node_exec_id,
"executionStatus": {"in": [s.value for s in allowed_from]},
}
if res := await AgentNodeExecution.prisma().update(
where=cast(
AgentNodeExecutionWhereUniqueInput,
{
"id": node_exec_id,
"executionStatus": {"in": [s.value for s in allowed_from]},
},
),
where=where_clause,
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
):

View File

@@ -10,7 +10,11 @@ from typing import Optional
from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput
from prisma.types import (
PendingHumanReviewCreateInput,
PendingHumanReviewUpdateInput,
PendingHumanReviewUpsertInput,
)
from pydantic import BaseModel
from backend.api.features.executions.review.model import (
@@ -66,20 +70,20 @@ async def get_or_create_human_review(
# Upsert - get existing or create new review
review = await PendingHumanReview.prisma().upsert(
where={"nodeExecId": node_exec_id},
data={
"create": {
"userId": user_id,
"nodeExecId": node_exec_id,
"graphExecId": graph_exec_id,
"graphId": graph_id,
"graphVersion": graph_version,
"payload": SafeJson(input_data),
"instructions": message,
"editable": editable,
"status": ReviewStatus.WAITING,
},
"update": {}, # Do nothing on update - keep existing review as is
},
data=PendingHumanReviewUpsertInput(
create=PendingHumanReviewCreateInput(
userId=user_id,
nodeExecId=node_exec_id,
graphExecId=graph_exec_id,
graphId=graph_id,
graphVersion=graph_version,
payload=SafeJson(input_data),
instructions=message,
editable=editable,
status=ReviewStatus.WAITING,
),
update={}, # Do nothing on update - keep existing review as is
),
)
logger.info(

View File

@@ -7,7 +7,11 @@ import prisma
import pydantic
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from prisma.types import (
UserOnboardingCreateInput,
UserOnboardingUpdateInput,
UserOnboardingUpsertInput,
)
from backend.api.features.store.model import StoreAgentDetails
from backend.api.model import OnboardingNotificationPayload
@@ -110,12 +114,13 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
create_input = UserOnboardingCreateInput(userId=user_id, **update)
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update},
"update": update,
},
data=UserOnboardingUpsertInput(
create=create_input,
update=update,
),
)