mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
fix linter errors
This commit is contained in:
@@ -5,6 +5,8 @@ from os import getenv
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from prisma.types import ProfileCreateInput
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
@@ -49,13 +51,13 @@ async def setup_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create a test graph with agent input -> agent output
|
||||
@@ -172,13 +174,13 @@ async def setup_llm_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for LLM tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create test OpenAI credentials for the user
|
||||
@@ -332,13 +334,13 @@ async def setup_firecrawl_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for Firecrawl tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||
|
||||
@@ -817,18 +817,16 @@ async def add_store_agent_to_library(
|
||||
|
||||
# Create LibraryAgent entry
|
||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"User": {"connect": {"id": user_id}},
|
||||
"AgentGraph": {
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
User={"connect": {"id": user_id}},
|
||||
AgentGraph={
|
||||
"connect": {
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"settings": SafeJson(
|
||||
_initialize_graph_settings(graph_model).model_dump()
|
||||
),
|
||||
},
|
||||
isCreatedByUser=False,
|
||||
settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()),
|
||||
),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
|
||||
@@ -27,6 +27,13 @@ from prisma.models import OAuthApplication as PrismaOAuthApplication
|
||||
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||
from prisma.models import User as PrismaUser
|
||||
from prisma.types import (
|
||||
OAuthAccessTokenCreateInput,
|
||||
OAuthApplicationCreateInput,
|
||||
OAuthAuthorizationCodeCreateInput,
|
||||
OAuthRefreshTokenCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.api.rest_api import app
|
||||
|
||||
@@ -48,11 +55,11 @@ def test_user_id() -> str:
|
||||
async def test_user(server, test_user_id: str):
|
||||
"""Create a test user in the database."""
|
||||
await PrismaUser.prisma().create(
|
||||
data={
|
||||
"id": test_user_id,
|
||||
"email": f"oauth-test-{test_user_id}@example.com",
|
||||
"name": "OAuth Test User",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=test_user_id,
|
||||
email=f"oauth-test-{test_user_id}@example.com",
|
||||
name="OAuth Test User",
|
||||
)
|
||||
)
|
||||
|
||||
yield test_user_id
|
||||
@@ -77,22 +84,22 @@ async def test_oauth_app(test_user: str):
|
||||
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
||||
|
||||
await PrismaOAuthApplication.prisma().create(
|
||||
data={
|
||||
"id": app_id,
|
||||
"name": "Test OAuth App",
|
||||
"description": "Test application for integration tests",
|
||||
"clientId": client_id,
|
||||
"clientSecret": client_secret_hash,
|
||||
"clientSecretSalt": client_secret_salt,
|
||||
"redirectUris": [
|
||||
data=OAuthApplicationCreateInput(
|
||||
id=app_id,
|
||||
name="Test OAuth App",
|
||||
description="Test application for integration tests",
|
||||
clientId=client_id,
|
||||
clientSecret=client_secret_hash,
|
||||
clientSecretSalt=client_secret_salt,
|
||||
redirectUris=[
|
||||
"https://example.com/callback",
|
||||
"http://localhost:3000/callback",
|
||||
],
|
||||
"grantTypes": ["authorization_code", "refresh_token"],
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
|
||||
"ownerId": test_user,
|
||||
"isActive": True,
|
||||
}
|
||||
grantTypes=["authorization_code", "refresh_token"],
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
|
||||
ownerId=test_user,
|
||||
isActive=True,
|
||||
)
|
||||
)
|
||||
|
||||
yield {
|
||||
@@ -296,19 +303,19 @@ async def inactive_oauth_app(test_user: str):
|
||||
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
||||
|
||||
await PrismaOAuthApplication.prisma().create(
|
||||
data={
|
||||
"id": app_id,
|
||||
"name": "Inactive OAuth App",
|
||||
"description": "Inactive test application",
|
||||
"clientId": client_id,
|
||||
"clientSecret": client_secret_hash,
|
||||
"clientSecretSalt": client_secret_salt,
|
||||
"redirectUris": ["https://example.com/callback"],
|
||||
"grantTypes": ["authorization_code", "refresh_token"],
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
||||
"ownerId": test_user,
|
||||
"isActive": False, # Inactive!
|
||||
}
|
||||
data=OAuthApplicationCreateInput(
|
||||
id=app_id,
|
||||
name="Inactive OAuth App",
|
||||
description="Inactive test application",
|
||||
clientId=client_id,
|
||||
clientSecret=client_secret_hash,
|
||||
clientSecretSalt=client_secret_salt,
|
||||
redirectUris=["https://example.com/callback"],
|
||||
grantTypes=["authorization_code", "refresh_token"],
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||
ownerId=test_user,
|
||||
isActive=False, # Inactive!
|
||||
)
|
||||
)
|
||||
|
||||
yield {
|
||||
@@ -699,14 +706,14 @@ async def test_token_authorization_code_expired(
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
await PrismaOAuthAuthorizationCode.prisma().create(
|
||||
data={
|
||||
"code": expired_code,
|
||||
"applicationId": test_oauth_app["id"],
|
||||
"userId": test_user,
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
||||
"redirectUri": test_oauth_app["redirect_uri"],
|
||||
"expiresAt": now - timedelta(hours=1), # Already expired
|
||||
}
|
||||
data=OAuthAuthorizationCodeCreateInput(
|
||||
code=expired_code,
|
||||
applicationId=test_oauth_app["id"],
|
||||
userId=test_user,
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||
redirectUri=test_oauth_app["redirect_uri"],
|
||||
expiresAt=now - timedelta(hours=1), # Already expired
|
||||
)
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
@@ -942,13 +949,13 @@ async def test_token_refresh_expired(
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
await PrismaOAuthRefreshToken.prisma().create(
|
||||
data={
|
||||
"token": expired_token_hash,
|
||||
"applicationId": test_oauth_app["id"],
|
||||
"userId": test_user,
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
||||
"expiresAt": now - timedelta(days=1), # Already expired
|
||||
}
|
||||
data=OAuthRefreshTokenCreateInput(
|
||||
token=expired_token_hash,
|
||||
applicationId=test_oauth_app["id"],
|
||||
userId=test_user,
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||
expiresAt=now - timedelta(days=1), # Already expired
|
||||
)
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
@@ -980,14 +987,14 @@ async def test_token_refresh_revoked(
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
await PrismaOAuthRefreshToken.prisma().create(
|
||||
data={
|
||||
"token": revoked_token_hash,
|
||||
"applicationId": test_oauth_app["id"],
|
||||
"userId": test_user,
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
||||
"expiresAt": now + timedelta(days=30), # Not expired
|
||||
"revokedAt": now - timedelta(hours=1), # But revoked
|
||||
}
|
||||
data=OAuthRefreshTokenCreateInput(
|
||||
token=revoked_token_hash,
|
||||
applicationId=test_oauth_app["id"],
|
||||
userId=test_user,
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||
expiresAt=now + timedelta(days=30), # Not expired
|
||||
revokedAt=now - timedelta(hours=1), # But revoked
|
||||
)
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
@@ -1013,19 +1020,19 @@ async def other_oauth_app(test_user: str):
|
||||
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
||||
|
||||
await PrismaOAuthApplication.prisma().create(
|
||||
data={
|
||||
"id": app_id,
|
||||
"name": "Other OAuth App",
|
||||
"description": "Second test application",
|
||||
"clientId": client_id,
|
||||
"clientSecret": client_secret_hash,
|
||||
"clientSecretSalt": client_secret_salt,
|
||||
"redirectUris": ["https://other.example.com/callback"],
|
||||
"grantTypes": ["authorization_code", "refresh_token"],
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
||||
"ownerId": test_user,
|
||||
"isActive": True,
|
||||
}
|
||||
data=OAuthApplicationCreateInput(
|
||||
id=app_id,
|
||||
name="Other OAuth App",
|
||||
description="Second test application",
|
||||
clientId=client_id,
|
||||
clientSecret=client_secret_hash,
|
||||
clientSecretSalt=client_secret_salt,
|
||||
redirectUris=["https://other.example.com/callback"],
|
||||
grantTypes=["authorization_code", "refresh_token"],
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||
ownerId=test_user,
|
||||
isActive=True,
|
||||
)
|
||||
)
|
||||
|
||||
yield {
|
||||
@@ -1052,13 +1059,13 @@ async def test_token_refresh_wrong_application(
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
await PrismaOAuthRefreshToken.prisma().create(
|
||||
data={
|
||||
"token": token_hash,
|
||||
"applicationId": test_oauth_app["id"], # Belongs to test_oauth_app
|
||||
"userId": test_user,
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
||||
"expiresAt": now + timedelta(days=30),
|
||||
}
|
||||
data=OAuthRefreshTokenCreateInput(
|
||||
token=token_hash,
|
||||
applicationId=test_oauth_app["id"], # Belongs to test_oauth_app
|
||||
userId=test_user,
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||
expiresAt=now + timedelta(days=30),
|
||||
)
|
||||
)
|
||||
|
||||
# Try to use it with `other_oauth_app`
|
||||
@@ -1267,19 +1274,19 @@ async def test_validate_access_token_fails_when_app_disabled(
|
||||
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
||||
|
||||
await PrismaOAuthApplication.prisma().create(
|
||||
data={
|
||||
"id": app_id,
|
||||
"name": "App To Be Disabled",
|
||||
"description": "Test app for disabled validation",
|
||||
"clientId": client_id,
|
||||
"clientSecret": client_secret_hash,
|
||||
"clientSecretSalt": client_secret_salt,
|
||||
"redirectUris": ["https://example.com/callback"],
|
||||
"grantTypes": ["authorization_code"],
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
||||
"ownerId": test_user,
|
||||
"isActive": True,
|
||||
}
|
||||
data=OAuthApplicationCreateInput(
|
||||
id=app_id,
|
||||
name="App To Be Disabled",
|
||||
description="Test app for disabled validation",
|
||||
clientId=client_id,
|
||||
clientSecret=client_secret_hash,
|
||||
clientSecretSalt=client_secret_salt,
|
||||
redirectUris=["https://example.com/callback"],
|
||||
grantTypes=["authorization_code"],
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||
ownerId=test_user,
|
||||
isActive=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Create an access token directly in the database
|
||||
@@ -1288,13 +1295,13 @@ async def test_validate_access_token_fails_when_app_disabled(
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
await PrismaOAuthAccessToken.prisma().create(
|
||||
data={
|
||||
"token": token_hash,
|
||||
"applicationId": app_id,
|
||||
"userId": test_user,
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
||||
"expiresAt": now + timedelta(hours=1),
|
||||
}
|
||||
data=OAuthAccessTokenCreateInput(
|
||||
token=token_hash,
|
||||
applicationId=app_id,
|
||||
userId=test_user,
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||
expiresAt=now + timedelta(hours=1),
|
||||
)
|
||||
)
|
||||
|
||||
# Token should be valid while app is active
|
||||
@@ -1561,19 +1568,19 @@ async def test_revoke_token_from_different_app_fails_silently(
|
||||
)
|
||||
|
||||
await PrismaOAuthApplication.prisma().create(
|
||||
data={
|
||||
"id": app2_id,
|
||||
"name": "Second Test OAuth App",
|
||||
"description": "Second test application for cross-app revocation test",
|
||||
"clientId": app2_client_id,
|
||||
"clientSecret": app2_client_secret_hash,
|
||||
"clientSecretSalt": app2_client_secret_salt,
|
||||
"redirectUris": ["https://other-app.com/callback"],
|
||||
"grantTypes": ["authorization_code", "refresh_token"],
|
||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
|
||||
"ownerId": test_user,
|
||||
"isActive": True,
|
||||
}
|
||||
data=OAuthApplicationCreateInput(
|
||||
id=app2_id,
|
||||
name="Second Test OAuth App",
|
||||
description="Second test application for cross-app revocation test",
|
||||
clientId=app2_client_id,
|
||||
clientSecret=app2_client_secret_hash,
|
||||
clientSecretSalt=app2_client_secret_salt,
|
||||
redirectUris=["https://other-app.com/callback"],
|
||||
grantTypes=["authorization_code", "refresh_token"],
|
||||
scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
|
||||
ownerId=test_user,
|
||||
isActive=True,
|
||||
)
|
||||
)
|
||||
|
||||
# App 2 tries to revoke App 1's access token
|
||||
|
||||
@@ -249,7 +249,9 @@ async def log_search_term(search_query: str):
|
||||
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
try:
|
||||
await prisma.models.SearchTerms.prisma().create(
|
||||
data={"searchTerm": search_query, "createdDate": date}
|
||||
data=prisma.types.SearchTermsCreateInput(
|
||||
searchTerm=search_query, createdDate=date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Fail silently here so that logging search terms doesn't break the app
|
||||
@@ -1456,11 +1458,9 @@ async def _approve_sub_agent(
|
||||
# Create new version if no matching version found
|
||||
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
|
||||
await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||
data={
|
||||
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
|
||||
"version": next_version,
|
||||
"storeListingId": listing.id,
|
||||
}
|
||||
data=_create_sub_agent_version_data(
|
||||
sub_graph, heading, main_agent_name, next_version, listing.id
|
||||
)
|
||||
)
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
||||
@@ -1468,10 +1468,14 @@ async def _approve_sub_agent(
|
||||
|
||||
|
||||
def _create_sub_agent_version_data(
|
||||
sub_graph: prisma.models.AgentGraph, heading: str, main_agent_name: str
|
||||
sub_graph: prisma.models.AgentGraph,
|
||||
heading: str,
|
||||
main_agent_name: str,
|
||||
version: typing.Optional[int] = None,
|
||||
store_listing_id: typing.Optional[str] = None,
|
||||
) -> prisma.types.StoreListingVersionCreateInput:
|
||||
"""Create store listing version data for a sub-agent"""
|
||||
return prisma.types.StoreListingVersionCreateInput(
|
||||
data = prisma.types.StoreListingVersionCreateInput(
|
||||
agentGraphId=sub_graph.id,
|
||||
agentGraphVersion=sub_graph.version,
|
||||
name=sub_graph.name or heading,
|
||||
@@ -1486,6 +1490,11 @@ def _create_sub_agent_version_data(
|
||||
imageUrls=[], # Sub-agents don't need images
|
||||
categories=[], # Sub-agents don't need categories
|
||||
)
|
||||
if version is not None:
|
||||
data["version"] = version
|
||||
if store_listing_id is not None:
|
||||
data["storeListingId"] = store_listing_id
|
||||
return data
|
||||
|
||||
|
||||
async def review_store_submission(
|
||||
|
||||
@@ -42,6 +42,7 @@ from urllib.parse import urlparse
|
||||
import click
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission
|
||||
from prisma.types import OAuthApplicationCreateInput
|
||||
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
@@ -147,7 +148,7 @@ def format_sql_insert(creds: dict) -> str:
|
||||
|
||||
sql = f"""
|
||||
-- ============================================================
|
||||
-- OAuth Application: {creds['name']}
|
||||
-- OAuth Application: {creds["name"]}
|
||||
-- Generated: {now_iso} UTC
|
||||
-- ============================================================
|
||||
|
||||
@@ -167,14 +168,14 @@ INSERT INTO "OAuthApplication" (
|
||||
"isActive"
|
||||
)
|
||||
VALUES (
|
||||
'{creds['id']}',
|
||||
'{creds["id"]}',
|
||||
NOW(),
|
||||
NOW(),
|
||||
'{creds['name']}',
|
||||
{f"'{creds['description']}'" if creds['description'] else 'NULL'},
|
||||
'{creds['client_id']}',
|
||||
'{creds['client_secret_hash']}',
|
||||
'{creds['client_secret_salt']}',
|
||||
'{creds["name"]}',
|
||||
{f"'{creds['description']}'" if creds["description"] else "NULL"},
|
||||
'{creds["client_id"]}',
|
||||
'{creds["client_secret_hash"]}',
|
||||
'{creds["client_secret_salt"]}',
|
||||
ARRAY{redirect_uris_pg}::TEXT[],
|
||||
ARRAY{grant_types_pg}::TEXT[],
|
||||
ARRAY{scopes_pg}::"APIKeyPermission"[],
|
||||
@@ -186,8 +187,8 @@ VALUES (
|
||||
-- ⚠️ IMPORTANT: Save these credentials securely!
|
||||
-- ============================================================
|
||||
--
|
||||
-- Client ID: {creds['client_id']}
|
||||
-- Client Secret: {creds['client_secret_plaintext']}
|
||||
-- Client ID: {creds["client_id"]}
|
||||
-- Client Secret: {creds["client_secret_plaintext"]}
|
||||
--
|
||||
-- ⚠️ The client secret is shown ONLY ONCE!
|
||||
-- ⚠️ Store it securely and share only with the application developer.
|
||||
@@ -200,7 +201,7 @@ VALUES (
|
||||
-- To verify the application was created:
|
||||
-- SELECT "clientId", name, scopes, "redirectUris", "isActive"
|
||||
-- FROM "OAuthApplication"
|
||||
-- WHERE "clientId" = '{creds['client_id']}';
|
||||
-- WHERE "clientId" = '{creds["client_id"]}';
|
||||
"""
|
||||
return sql
|
||||
|
||||
@@ -834,19 +835,19 @@ async def create_test_app_in_db(
|
||||
|
||||
# Insert into database
|
||||
app = await OAuthApplication.prisma().create(
|
||||
data={
|
||||
"id": creds["id"],
|
||||
"name": creds["name"],
|
||||
"description": creds["description"],
|
||||
"clientId": creds["client_id"],
|
||||
"clientSecret": creds["client_secret_hash"],
|
||||
"clientSecretSalt": creds["client_secret_salt"],
|
||||
"redirectUris": creds["redirect_uris"],
|
||||
"grantTypes": creds["grant_types"],
|
||||
"scopes": creds["scopes"],
|
||||
"ownerId": owner_id,
|
||||
"isActive": True,
|
||||
}
|
||||
data=OAuthApplicationCreateInput(
|
||||
id=creds["id"],
|
||||
name=creds["name"],
|
||||
description=creds["description"],
|
||||
clientId=creds["client_id"],
|
||||
clientSecret=creds["client_secret_hash"],
|
||||
clientSecretSalt=creds["client_secret_salt"],
|
||||
redirectUris=creds["redirect_uris"],
|
||||
grantTypes=creds["grant_types"],
|
||||
scopes=creds["scopes"],
|
||||
ownerId=owner_id,
|
||||
isActive=True,
|
||||
)
|
||||
)
|
||||
|
||||
click.echo(f"✓ Created test OAuth application: {app.clientId}")
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Literal, Optional
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
|
||||
from pydantic import Field
|
||||
|
||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||
@@ -82,17 +82,17 @@ async def create_api_key(
|
||||
generated_key = keysmith.generate_key()
|
||||
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
data=APIKeyCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
head=generated_key.head,
|
||||
tail=generated_key.tail,
|
||||
hash=generated_key.hash,
|
||||
salt=generated_key.salt,
|
||||
permissions=[p for p in permissions],
|
||||
description=description,
|
||||
userId=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
|
||||
@@ -22,7 +22,12 @@ from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
||||
from prisma.models import OAuthApplication as PrismaOAuthApplication
|
||||
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||
from prisma.types import OAuthApplicationUpdateInput
|
||||
from prisma.types import (
|
||||
OAuthAccessTokenCreateInput,
|
||||
OAuthApplicationUpdateInput,
|
||||
OAuthAuthorizationCodeCreateInput,
|
||||
OAuthRefreshTokenCreateInput,
|
||||
)
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from .base import APIAuthorizationInfo
|
||||
@@ -359,17 +364,17 @@ async def create_authorization_code(
|
||||
expires_at = now + AUTHORIZATION_CODE_TTL
|
||||
|
||||
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"code": code,
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
"redirectUri": redirect_uri,
|
||||
"codeChallenge": code_challenge,
|
||||
"codeChallengeMethod": code_challenge_method,
|
||||
}
|
||||
data=OAuthAuthorizationCodeCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
code=code,
|
||||
expiresAt=expires_at,
|
||||
applicationId=application_id,
|
||||
userId=user_id,
|
||||
scopes=[s for s in scopes],
|
||||
redirectUri=redirect_uri,
|
||||
codeChallenge=code_challenge,
|
||||
codeChallengeMethod=code_challenge_method,
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthAuthorizationCodeInfo.from_db(saved_code)
|
||||
@@ -490,14 +495,14 @@ async def create_access_token(
|
||||
expires_at = now + ACCESS_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthAccessToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
}
|
||||
data=OAuthAccessTokenCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
token=token_hash, # SHA256 hash for direct lookup
|
||||
expiresAt=expires_at,
|
||||
applicationId=application_id,
|
||||
userId=user_id,
|
||||
scopes=[s for s in scopes],
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
@@ -607,14 +612,14 @@ async def create_refresh_token(
|
||||
expires_at = now + REFRESH_TOKEN_TTL
|
||||
|
||||
saved_token = await PrismaOAuthRefreshToken.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"token": token_hash, # SHA256 hash for direct lookup
|
||||
"expiresAt": expires_at,
|
||||
"applicationId": application_id,
|
||||
"userId": user_id,
|
||||
"scopes": [s for s in scopes],
|
||||
}
|
||||
data=OAuthRefreshTokenCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
token=token_hash, # SHA256 hash for direct lookup
|
||||
expiresAt=expires_at,
|
||||
applicationId=application_id,
|
||||
userId=user_id,
|
||||
scopes=[s for s in scopes],
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||
|
||||
@@ -11,6 +11,11 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import (
|
||||
UserBalanceCreateInput,
|
||||
UserBalanceUpsertInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -21,11 +26,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for ceiling tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
name=f"Test User {user_id[:8]}",
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -33,7 +38,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||
update={"balance": 0},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -107,15 +115,15 @@ async def test_ceiling_balance_clamps_when_would_exceed(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Balance should be clamped to ceiling
|
||||
assert (
|
||||
final_balance == 1000
|
||||
), f"Balance should be clamped to 1000, got {final_balance}"
|
||||
assert final_balance == 1000, (
|
||||
f"Balance should be clamped to 1000, got {final_balance}"
|
||||
)
|
||||
|
||||
# Verify with get_credits too
|
||||
stored_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
stored_balance == 1000
|
||||
), f"Stored balance should be 1000, got {stored_balance}"
|
||||
assert stored_balance == 1000, (
|
||||
f"Stored balance should be 1000, got {stored_balance}"
|
||||
)
|
||||
|
||||
# Verify transaction shows the clamped amount
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
@@ -164,9 +172,9 @@ async def test_ceiling_balance_allows_when_under_threshold(server: SpinTestServe
|
||||
|
||||
# Verify with get_credits too
|
||||
stored_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
stored_balance == 500
|
||||
), f"Stored balance should be 500, got {stored_balance}"
|
||||
assert stored_balance == 500, (
|
||||
f"Stored balance should be 500, got {stored_balance}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
@@ -14,6 +14,11 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import (
|
||||
UserBalanceCreateInput,
|
||||
UserBalanceUpsertInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
@@ -28,11 +33,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user with initial balance."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
name=f"Test User {user_id[:8]}",
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -41,7 +46,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
# Ensure UserBalance record exists
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||
update={"balance": 0},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -108,9 +116,9 @@ async def test_concurrent_spends_same_user(server: SpinTestServer):
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE}
|
||||
)
|
||||
assert (
|
||||
len(transactions) == 10
|
||||
), f"Expected 10 transactions, got {len(transactions)}"
|
||||
assert len(transactions) == 10, (
|
||||
f"Expected 10 transactions, got {len(transactions)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -321,9 +329,9 @@ async def test_onboarding_reward_idempotency(server: SpinTestServer):
|
||||
"transactionKey": f"REWARD-{user_id}-WELCOME",
|
||||
}
|
||||
)
|
||||
assert (
|
||||
len(transactions) == 1
|
||||
), f"Expected 1 reward transaction, got {len(transactions)}"
|
||||
assert len(transactions) == 1, (
|
||||
f"Expected 1 reward transaction, got {len(transactions)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -342,10 +350,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
|
||||
# First, set balance near max
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100),
|
||||
update={"balance": max_int - 100},
|
||||
),
|
||||
)
|
||||
|
||||
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
|
||||
@@ -358,9 +366,9 @@ async def test_integer_overflow_protection(server: SpinTestServer):
|
||||
|
||||
# Balance should be clamped to max_int, not overflowed
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance == max_int
|
||||
), f"Balance should be clamped to {max_int}, got {final_balance}"
|
||||
assert final_balance == max_int, (
|
||||
f"Balance should be clamped to {max_int}, got {final_balance}"
|
||||
)
|
||||
|
||||
# Verify transaction was created with clamped amount
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
@@ -371,9 +379,9 @@ async def test_integer_overflow_protection(server: SpinTestServer):
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
assert len(transactions) > 0, "Transaction should be created"
|
||||
assert (
|
||||
transactions[0].runningBalance == max_int
|
||||
), "Transaction should show clamped balance"
|
||||
assert transactions[0].runningBalance == max_int, (
|
||||
"Transaction should show clamped balance"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -432,9 +440,9 @@ async def test_high_concurrency_stress(server: SpinTestServer):
|
||||
|
||||
# Verify final balance
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance == expected_balance
|
||||
), f"Expected {expected_balance}, got {final_balance}"
|
||||
assert final_balance == expected_balance, (
|
||||
f"Expected {expected_balance}, got {final_balance}"
|
||||
)
|
||||
assert final_balance >= 0, "Balance went negative!"
|
||||
|
||||
finally:
|
||||
@@ -507,7 +515,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
|
||||
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
|
||||
print("\nExecution order by start time:")
|
||||
for i, (label, timing) in enumerate(sorted_timings):
|
||||
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
|
||||
print(f" {i + 1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
|
||||
|
||||
# Check for overlap (true concurrency) vs serialization
|
||||
overlaps = []
|
||||
@@ -533,9 +541,9 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
|
||||
print(f"Successful: {len(successful)}, Failed: {len(failed)}")
|
||||
|
||||
# All should succeed since 150 - (10 + 20 + 30) = 90 > 0
|
||||
assert (
|
||||
len(successful) == 3
|
||||
), f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
|
||||
assert len(successful) == 3, (
|
||||
f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
|
||||
)
|
||||
assert final_balance == 90, f"Expected balance 90, got {final_balance}"
|
||||
|
||||
# Check transaction timestamps to confirm database-level serialization
|
||||
@@ -546,7 +554,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
|
||||
print("\nDatabase transaction order (by createdAt):")
|
||||
for i, tx in enumerate(transactions):
|
||||
print(
|
||||
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
|
||||
f" {i + 1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
|
||||
)
|
||||
|
||||
# Verify running balances are chronologically consistent (ordered by createdAt)
|
||||
@@ -575,38 +583,38 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
|
||||
|
||||
# Verify all balances are valid intermediate states
|
||||
for balance in actual_balances:
|
||||
assert (
|
||||
balance in expected_possible_balances
|
||||
), f"Invalid balance {balance}, expected one of {expected_possible_balances}"
|
||||
assert balance in expected_possible_balances, (
|
||||
f"Invalid balance {balance}, expected one of {expected_possible_balances}"
|
||||
)
|
||||
|
||||
# Final balance should always be 90 (150 - 60)
|
||||
assert (
|
||||
min(actual_balances) == 90
|
||||
), f"Final balance should be 90, got {min(actual_balances)}"
|
||||
assert min(actual_balances) == 90, (
|
||||
f"Final balance should be 90, got {min(actual_balances)}"
|
||||
)
|
||||
|
||||
# The final transaction should always have balance 90
|
||||
# The other transactions should have valid intermediate balances
|
||||
assert (
|
||||
90 in actual_balances
|
||||
), f"Final balance 90 should be in actual_balances: {actual_balances}"
|
||||
assert 90 in actual_balances, (
|
||||
f"Final balance 90 should be in actual_balances: {actual_balances}"
|
||||
)
|
||||
|
||||
# All balances should be >= 90 (the final state)
|
||||
assert all(
|
||||
balance >= 90 for balance in actual_balances
|
||||
), f"All balances should be >= 90, got {actual_balances}"
|
||||
assert all(balance >= 90 for balance in actual_balances), (
|
||||
f"All balances should be >= 90, got {actual_balances}"
|
||||
)
|
||||
|
||||
# CRITICAL: Transactions are atomic but can complete in any order
|
||||
# What matters is that all running balances are valid intermediate states
|
||||
# Each balance should be between 90 (final) and 140 (after first transaction)
|
||||
for balance in actual_balances:
|
||||
assert (
|
||||
90 <= balance <= 140
|
||||
), f"Balance {balance} is outside valid range [90, 140]"
|
||||
assert 90 <= balance <= 140, (
|
||||
f"Balance {balance} is outside valid range [90, 140]"
|
||||
)
|
||||
|
||||
# Final balance (minimum) should always be 90
|
||||
assert (
|
||||
min(actual_balances) == 90
|
||||
), f"Final balance should be 90, got {min(actual_balances)}"
|
||||
assert min(actual_balances) == 90, (
|
||||
f"Final balance should be 90, got {min(actual_balances)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -707,7 +715,7 @@ async def test_prove_database_locking_behavior(server: SpinTestServer):
|
||||
|
||||
for i, result in enumerate(sorted_results):
|
||||
print(
|
||||
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
|
||||
f" {i + 1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
|
||||
)
|
||||
|
||||
# Check if any operations overlapped at the database level
|
||||
@@ -722,9 +730,9 @@ async def test_prove_database_locking_behavior(server: SpinTestServer):
|
||||
print(f"\n💰 Final balance: {final_balance}")
|
||||
|
||||
if len(successful) == 3:
|
||||
assert (
|
||||
final_balance == 0
|
||||
), f"If all succeeded, balance should be 0, got {final_balance}"
|
||||
assert final_balance == 0, (
|
||||
f"If all succeeded, balance should be 0, got {final_balance}"
|
||||
)
|
||||
print(
|
||||
"✅ CONCLUSION: Database row locking causes requests to WAIT and execute serially"
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ which would have caught the CreditTransactionType enum casting bug.
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserCreateInput
|
||||
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
@@ -29,12 +30,12 @@ async def cleanup_test_user():
|
||||
# Create the user first
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"topUpConfig": SafeJson({}),
|
||||
"timezone": "UTC",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
topUpConfig=SafeJson({}),
|
||||
timezone="UTC",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# User might already exist, that's fine
|
||||
|
||||
@@ -12,6 +12,12 @@ import pytest
|
||||
import stripe
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||
from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
UserBalanceCreateInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -35,32 +41,32 @@ async def setup_test_user_with_topup():
|
||||
|
||||
# Create user
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=REFUND_TEST_USER_ID,
|
||||
email=f"{REFUND_TEST_USER_ID}@example.com",
|
||||
name="Refund Test User",
|
||||
)
|
||||
)
|
||||
|
||||
# Create user balance
|
||||
await UserBalance.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
}
|
||||
data=UserBalanceCreateInput(
|
||||
userId=REFUND_TEST_USER_ID,
|
||||
balance=1000, # $10
|
||||
)
|
||||
)
|
||||
|
||||
# Create a top-up transaction that can be refunded
|
||||
topup_tx = await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
}
|
||||
data=CreditTransactionCreateInput(
|
||||
userId=REFUND_TEST_USER_ID,
|
||||
amount=1000,
|
||||
type=CreditTransactionType.TOP_UP,
|
||||
transactionKey="pi_test_12345",
|
||||
runningBalance=1000,
|
||||
isActive=True,
|
||||
metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
)
|
||||
)
|
||||
|
||||
return topup_tx
|
||||
@@ -93,12 +99,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
|
||||
|
||||
# Create refund request record (simulating webhook flow)
|
||||
await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
}
|
||||
data=CreditRefundRequestCreateInput(
|
||||
userId=REFUND_TEST_USER_ID,
|
||||
amount=500,
|
||||
transactionKey=topup_tx.transactionKey, # Should match the original transaction
|
||||
reason="Test refund",
|
||||
)
|
||||
)
|
||||
|
||||
# Call deduct_credits
|
||||
@@ -109,9 +115,9 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
assert user_balance is not None
|
||||
assert (
|
||||
user_balance.balance == 500
|
||||
), f"Expected balance 500, got {user_balance.balance}"
|
||||
assert user_balance.balance == 500, (
|
||||
f"Expected balance 500, got {user_balance.balance}"
|
||||
)
|
||||
|
||||
# Verify refund transaction was created
|
||||
refund_tx = await CreditTransaction.prisma().find_first(
|
||||
@@ -205,9 +211,9 @@ async def test_handle_dispute_with_sufficient_balance(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
assert user_balance is not None
|
||||
assert (
|
||||
user_balance.balance == 1000
|
||||
), f"Balance should remain 1000, got {user_balance.balance}"
|
||||
assert user_balance.balance == 1000, (
|
||||
f"Balance should remain 1000, got {user_balance.balance}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user()
|
||||
@@ -286,12 +292,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
||||
refund_requests = []
|
||||
for i in range(5):
|
||||
req = await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
}
|
||||
data=CreditRefundRequestCreateInput(
|
||||
userId=REFUND_TEST_USER_ID,
|
||||
amount=100, # $1 each
|
||||
transactionKey=topup_tx.transactionKey,
|
||||
reason=f"Test refund {i}",
|
||||
)
|
||||
)
|
||||
refund_requests.append(req)
|
||||
|
||||
@@ -332,9 +338,9 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
||||
print(f"DEBUG: Final balance = {user_balance.balance}, expected = 500")
|
||||
|
||||
# With atomic implementation, all 5 refunds should process correctly
|
||||
assert (
|
||||
user_balance.balance == 500
|
||||
), f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"
|
||||
assert user_balance.balance == 500, (
|
||||
f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"
|
||||
)
|
||||
|
||||
# Verify all refund transactions exist
|
||||
refund_txs = await CreditTransaction.prisma().find_many(
|
||||
@@ -343,9 +349,9 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
||||
"type": CreditTransactionType.REFUND,
|
||||
}
|
||||
)
|
||||
assert (
|
||||
len(refund_txs) == 5
|
||||
), f"Expected 5 refund transactions, got {len(refund_txs)}"
|
||||
assert len(refund_txs) == 5, (
|
||||
f"Expected 5 refund transactions, got {len(refund_txs)}"
|
||||
)
|
||||
|
||||
running_balances: set[int] = {
|
||||
tx.runningBalance for tx in refund_txs if tx.runningBalance is not None
|
||||
@@ -353,20 +359,20 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
||||
|
||||
# Verify all balances are valid intermediate states
|
||||
for balance in running_balances:
|
||||
assert (
|
||||
500 <= balance <= 1000
|
||||
), f"Invalid balance {balance}, should be between 500 and 1000"
|
||||
assert 500 <= balance <= 1000, (
|
||||
f"Invalid balance {balance}, should be between 500 and 1000"
|
||||
)
|
||||
|
||||
# Final balance should be present
|
||||
assert (
|
||||
500 in running_balances
|
||||
), f"Final balance 500 should be in {running_balances}"
|
||||
assert 500 in running_balances, (
|
||||
f"Final balance 500 should be in {running_balances}"
|
||||
)
|
||||
|
||||
# All balances should be unique and form a valid sequence
|
||||
sorted_balances = sorted(running_balances, reverse=True)
|
||||
assert (
|
||||
len(sorted_balances) == 5
|
||||
), f"Expected 5 unique balances, got {len(sorted_balances)}"
|
||||
assert len(sorted_balances) == 5, (
|
||||
f"Expected 5 unique balances, got {len(sorted_balances)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user()
|
||||
|
||||
@@ -3,6 +3,11 @@ from datetime import datetime, timedelta, timezone
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction, UserBalance
|
||||
from prisma.types import (
|
||||
CreditTransactionCreateInput,
|
||||
UserBalanceCreateInput,
|
||||
UserBalanceUpsertInput,
|
||||
)
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
@@ -23,10 +28,10 @@ async def disable_test_user_transactions():
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0),
|
||||
update={"balance": 0, "updatedAt": old_date},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -140,23 +145,23 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Manually create a transaction with month 1 timestamp to establish history
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
}
|
||||
data=CreditTransactionCreateInput(
|
||||
userId=DEFAULT_USER_ID,
|
||||
amount=100,
|
||||
type=CreditTransactionType.TOP_UP,
|
||||
runningBalance=1100,
|
||||
isActive=True,
|
||||
createdAt=month1, # Set specific timestamp
|
||||
)
|
||||
)
|
||||
|
||||
# Update user balance to match
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100),
|
||||
update={"balance": 1100},
|
||||
),
|
||||
)
|
||||
|
||||
# Now test month 2 behavior
|
||||
@@ -175,14 +180,14 @@ async def test_block_credit_reset(server: SpinTestServer):
|
||||
|
||||
# Create a month 2 transaction to update the last transaction time
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
}
|
||||
data=CreditTransactionCreateInput(
|
||||
userId=DEFAULT_USER_ID,
|
||||
amount=-700, # Spent 700 to get to 400
|
||||
type=CreditTransactionType.USAGE,
|
||||
runningBalance=400,
|
||||
isActive=True,
|
||||
createdAt=month2,
|
||||
)
|
||||
)
|
||||
|
||||
# Move to month 3
|
||||
|
||||
@@ -12,6 +12,11 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import (
|
||||
UserBalanceCreateInput,
|
||||
UserBalanceUpsertInput,
|
||||
UserCreateInput,
|
||||
)
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
||||
from backend.util.test import SpinTestServer
|
||||
@@ -21,11 +26,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for underflow tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
name=f"Test User {user_id[:8]}",
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -33,7 +38,10 @@ async def create_test_user(user_id: str) -> None:
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||
update={"balance": 0},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -66,14 +74,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
initial_balance_target = POSTGRES_INT_MIN + 100
|
||||
|
||||
# Use direct database update to set the balance close to underflow
|
||||
from prisma.models import UserBalance
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(
|
||||
userId=user_id, balance=initial_balance_target
|
||||
),
|
||||
update={"balance": initial_balance_target},
|
||||
),
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -82,9 +90,7 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
|
||||
# Test 2: Apply amount that should cause underflow
|
||||
print("\n=== Test 2: Testing underflow protection ===")
|
||||
test_amount = (
|
||||
-200
|
||||
) # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
|
||||
test_amount = -200 # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
|
||||
expected_without_protection = current_balance + test_amount
|
||||
print(f"Current balance: {current_balance}")
|
||||
print(f"Test amount: {test_amount}")
|
||||
@@ -101,19 +107,19 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
print(f"Actual result: {balance_result}")
|
||||
|
||||
# Check if underflow protection worked
|
||||
assert (
|
||||
balance_result == POSTGRES_INT_MIN
|
||||
), f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"
|
||||
assert balance_result == POSTGRES_INT_MIN, (
|
||||
f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"
|
||||
)
|
||||
|
||||
# Test 3: Edge case - exactly at POSTGRES_INT_MIN
|
||||
print("\n=== Test 3: Testing exact POSTGRES_INT_MIN boundary ===")
|
||||
# Set balance to exactly POSTGRES_INT_MIN
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN),
|
||||
update={"balance": POSTGRES_INT_MIN},
|
||||
),
|
||||
)
|
||||
|
||||
edge_balance = await credit_system.get_credits(user_id)
|
||||
@@ -128,9 +134,9 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
)
|
||||
print(f"After subtracting 1: {edge_result}")
|
||||
|
||||
assert (
|
||||
edge_result == POSTGRES_INT_MIN
|
||||
), f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"
|
||||
assert edge_result == POSTGRES_INT_MIN, (
|
||||
f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -147,15 +153,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
||||
# Set up balance close to underflow threshold to test the protection
|
||||
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
|
||||
# This should trigger underflow protection
|
||||
from prisma.models import UserBalance
|
||||
|
||||
test_balance = POSTGRES_INT_MIN + 1000
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=test_balance),
|
||||
update={"balance": test_balance},
|
||||
),
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
@@ -176,18 +180,18 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
||||
)
|
||||
|
||||
# Balance should be clamped to POSTGRES_INT_MIN, not the calculated underflow value
|
||||
assert (
|
||||
final_balance == POSTGRES_INT_MIN
|
||||
), f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
|
||||
assert (
|
||||
final_balance > expected_without_protection
|
||||
), f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
|
||||
assert final_balance == POSTGRES_INT_MIN, (
|
||||
f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
|
||||
)
|
||||
assert final_balance > expected_without_protection, (
|
||||
f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
|
||||
)
|
||||
|
||||
# Verify with get_credits too
|
||||
stored_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
stored_balance == POSTGRES_INT_MIN
|
||||
), f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
|
||||
assert stored_balance == POSTGRES_INT_MIN, (
|
||||
f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
|
||||
)
|
||||
|
||||
# Verify transaction was created with the underflow-protected balance
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
@@ -195,9 +199,9 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
assert len(transactions) > 0, "Refund transaction should be created"
|
||||
assert (
|
||||
transactions[0].runningBalance == POSTGRES_INT_MIN
|
||||
), f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
|
||||
assert transactions[0].runningBalance == POSTGRES_INT_MIN, (
|
||||
f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -212,15 +216,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
||||
|
||||
try:
|
||||
# Set up balance close to underflow threshold
|
||||
from prisma.models import UserBalance
|
||||
|
||||
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
|
||||
update={"balance": initial_balance},
|
||||
),
|
||||
)
|
||||
|
||||
# Apply multiple refunds that would cumulatively underflow
|
||||
@@ -238,12 +240,12 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
||||
expected_balance_1 = (
|
||||
initial_balance + refund_amount
|
||||
) # Should be POSTGRES_INT_MIN + 200
|
||||
assert (
|
||||
balance_1 == expected_balance_1
|
||||
), f"First refund should result in {expected_balance_1}, got {balance_1}"
|
||||
assert (
|
||||
balance_1 >= POSTGRES_INT_MIN
|
||||
), f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
|
||||
assert balance_1 == expected_balance_1, (
|
||||
f"First refund should result in {expected_balance_1}, got {balance_1}"
|
||||
)
|
||||
assert balance_1 >= POSTGRES_INT_MIN, (
|
||||
f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
|
||||
)
|
||||
|
||||
# Second refund: (POSTGRES_INT_MIN + 200) + (-300) = POSTGRES_INT_MIN - 100 (would underflow)
|
||||
balance_2, _ = await credit_system._add_transaction(
|
||||
@@ -254,9 +256,9 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
||||
)
|
||||
|
||||
# Should be clamped to minimum due to underflow protection
|
||||
assert (
|
||||
balance_2 == POSTGRES_INT_MIN
|
||||
), f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
|
||||
assert balance_2 == POSTGRES_INT_MIN, (
|
||||
f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
|
||||
)
|
||||
|
||||
# Third refund: Should stay at minimum
|
||||
balance_3, _ = await credit_system._add_transaction(
|
||||
@@ -267,15 +269,15 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
||||
)
|
||||
|
||||
# Should still be at minimum
|
||||
assert (
|
||||
balance_3 == POSTGRES_INT_MIN
|
||||
), f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
|
||||
assert balance_3 == POSTGRES_INT_MIN, (
|
||||
f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
|
||||
)
|
||||
|
||||
# Final balance check
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance == POSTGRES_INT_MIN
|
||||
), f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
|
||||
assert final_balance == POSTGRES_INT_MIN, (
|
||||
f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -295,10 +297,10 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
||||
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
data=UserBalanceUpsertInput(
|
||||
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
|
||||
update={"balance": initial_balance},
|
||||
),
|
||||
)
|
||||
|
||||
async def large_refund(amount: int, label: str):
|
||||
@@ -327,35 +329,35 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, tuple):
|
||||
balance, _ = result
|
||||
assert (
|
||||
balance >= POSTGRES_INT_MIN
|
||||
), f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
|
||||
assert balance >= POSTGRES_INT_MIN, (
|
||||
f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
|
||||
)
|
||||
valid_results.append(balance)
|
||||
elif isinstance(result, str) and "FAILED" in result:
|
||||
# Some operations might fail due to validation, that's okay
|
||||
pass
|
||||
else:
|
||||
# Unexpected exception
|
||||
assert not isinstance(
|
||||
result, Exception
|
||||
), f"Unexpected exception in result {i}: {result}"
|
||||
assert not isinstance(result, Exception), (
|
||||
f"Unexpected exception in result {i}: {result}"
|
||||
)
|
||||
|
||||
# At least one operation should succeed
|
||||
assert (
|
||||
len(valid_results) > 0
|
||||
), f"At least one refund should succeed, got results: {results}"
|
||||
assert len(valid_results) > 0, (
|
||||
f"At least one refund should succeed, got results: {results}"
|
||||
)
|
||||
|
||||
# All successful results should be >= POSTGRES_INT_MIN
|
||||
for balance in valid_results:
|
||||
assert (
|
||||
balance >= POSTGRES_INT_MIN
|
||||
), f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
|
||||
assert balance >= POSTGRES_INT_MIN, (
|
||||
f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
|
||||
)
|
||||
|
||||
# Final balance should be valid and at or above POSTGRES_INT_MIN
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance >= POSTGRES_INT_MIN
|
||||
), f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
|
||||
assert final_balance >= POSTGRES_INT_MIN, (
|
||||
f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
@@ -14,6 +14,7 @@ import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
from prisma.types import UserBalanceCreateInput, UserCreateInput
|
||||
|
||||
from backend.data.credit import UsageTransactionMetadata, UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
@@ -24,11 +25,11 @@ async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for migration tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
data=UserCreateInput(
|
||||
id=user_id,
|
||||
email=f"test-{user_id}@example.com",
|
||||
name=f"Test User {user_id[:8]}",
|
||||
)
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
@@ -60,9 +61,9 @@ async def test_user_balance_migration_complete(server: SpinTestServer):
|
||||
# User.balance should not exist or should be None/0 if it exists
|
||||
user_balance_attr = getattr(user, "balance", None)
|
||||
if user_balance_attr is not None:
|
||||
assert (
|
||||
user_balance_attr == 0 or user_balance_attr is None
|
||||
), f"User.balance should be 0 or None, got {user_balance_attr}"
|
||||
assert user_balance_attr == 0 or user_balance_attr is None, (
|
||||
f"User.balance should be 0 or None, got {user_balance_attr}"
|
||||
)
|
||||
|
||||
# 2. Perform various credit operations using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
@@ -87,9 +88,9 @@ async def test_user_balance_migration_complete(server: SpinTestServer):
|
||||
# 3. Verify UserBalance table has correct values
|
||||
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
|
||||
assert user_balance is not None
|
||||
assert (
|
||||
user_balance.balance == 700
|
||||
), f"UserBalance should be 700, got {user_balance.balance}"
|
||||
assert user_balance.balance == 700, (
|
||||
f"UserBalance should be 700, got {user_balance.balance}"
|
||||
)
|
||||
|
||||
# 4. CRITICAL: Verify User.balance is NEVER updated during operations
|
||||
user_after = await User.prisma().find_unique(where={"id": user_id})
|
||||
@@ -97,15 +98,15 @@ async def test_user_balance_migration_complete(server: SpinTestServer):
|
||||
user_balance_after = getattr(user_after, "balance", None)
|
||||
if user_balance_after is not None:
|
||||
# If User.balance exists, it should still be 0 (never updated)
|
||||
assert (
|
||||
user_balance_after == 0 or user_balance_after is None
|
||||
), f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"
|
||||
assert user_balance_after == 0 or user_balance_after is None, (
|
||||
f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"
|
||||
)
|
||||
|
||||
# 5. Verify get_credits always returns UserBalance value, not User.balance
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance == user_balance.balance
|
||||
), f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"
|
||||
assert final_balance == user_balance.balance, (
|
||||
f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -121,14 +122,14 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
|
||||
try:
|
||||
# Create UserBalance with specific value
|
||||
await UserBalance.prisma().create(
|
||||
data={"userId": user_id, "balance": 5000} # $50
|
||||
data=UserBalanceCreateInput(userId=user_id, balance=5000) # $50
|
||||
)
|
||||
|
||||
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
|
||||
balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
balance == 5000
|
||||
), f"Expected get_credits to return 5000 from UserBalance, got {balance}"
|
||||
assert balance == 5000, (
|
||||
f"Expected get_credits to return 5000 from UserBalance, got {balance}"
|
||||
)
|
||||
|
||||
# Verify all operations use UserBalance using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
@@ -143,9 +144,9 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
|
||||
# Verify UserBalance table has the correct value
|
||||
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
|
||||
assert user_balance is not None
|
||||
assert (
|
||||
user_balance.balance == 6000
|
||||
), f"UserBalance should be 6000, got {user_balance.balance}"
|
||||
assert user_balance.balance == 6000, (
|
||||
f"UserBalance should be 6000, got {user_balance.balance}"
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
@@ -160,7 +161,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
|
||||
|
||||
try:
|
||||
# Set initial balance in UserBalance
|
||||
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
|
||||
await UserBalance.prisma().create(
|
||||
data=UserBalanceCreateInput(userId=user_id, balance=1000)
|
||||
)
|
||||
|
||||
# Run concurrent operations to ensure they all use UserBalance atomic operations
|
||||
async def concurrent_spend(amount: int, label: str):
|
||||
@@ -196,9 +199,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
|
||||
# Verify UserBalance has correct value
|
||||
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
|
||||
assert user_balance is not None
|
||||
assert (
|
||||
user_balance.balance == 400
|
||||
), f"UserBalance should be 400, got {user_balance.balance}"
|
||||
assert user_balance.balance == 400, (
|
||||
f"UserBalance should be 400, got {user_balance.balance}"
|
||||
)
|
||||
|
||||
# Critical: If User.balance exists and was used, it might have wrong value
|
||||
try:
|
||||
|
||||
@@ -28,6 +28,7 @@ from prisma.models import (
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
AgentGraphExecutionUpdateManyMutationInput,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
@@ -709,18 +710,18 @@ async def create_graph_execution(
|
||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||
"""
|
||||
result = await AgentGraphExecution.prisma().create(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
data=AgentGraphExecutionCreateInput(
|
||||
agentGraphId=graph_id,
|
||||
agentGraphVersion=graph_version,
|
||||
executionStatus=ExecutionStatus.INCOMPLETE,
|
||||
inputs=SafeJson(inputs),
|
||||
credentialInputs=(
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
),
|
||||
"nodesInputMasks": (
|
||||
nodesInputMasks=(
|
||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||
),
|
||||
"NodeExecutions": {
|
||||
NodeExecutions={
|
||||
"create": [
|
||||
AgentNodeExecutionCreateInput(
|
||||
agentNodeId=node_id,
|
||||
@@ -736,10 +737,10 @@ async def create_graph_execution(
|
||||
for node_id, node_input in starting_nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
},
|
||||
userId=user_id,
|
||||
agentPresetId=preset_id,
|
||||
parentGraphExecutionId=parent_graph_exec_id,
|
||||
),
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -831,10 +832,10 @@ async def upsert_execution_output(
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
data: AgentNodeExecutionInputOutputCreateInput = {
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
data = AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
if output_data is not None:
|
||||
data["data"] = SafeJson(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
@@ -974,14 +975,12 @@ async def update_node_execution_status(
|
||||
f"Invalid status transition: {status} has no valid source statuses"
|
||||
)
|
||||
|
||||
where_clause: Any = {
|
||||
"id": node_exec_id,
|
||||
"executionStatus": {"in": [s.value for s in allowed_from]},
|
||||
}
|
||||
if res := await AgentNodeExecution.prisma().update(
|
||||
where=cast(
|
||||
AgentNodeExecutionWhereUniqueInput,
|
||||
{
|
||||
"id": node_exec_id,
|
||||
"executionStatus": {"in": [s.value for s in allowed_from]},
|
||||
},
|
||||
),
|
||||
where=where_clause,
|
||||
data=_get_update_status_data(status, execution_data, stats),
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
):
|
||||
|
||||
@@ -10,7 +10,11 @@ from typing import Optional
|
||||
|
||||
from prisma.enums import ReviewStatus
|
||||
from prisma.models import PendingHumanReview
|
||||
from prisma.types import PendingHumanReviewUpdateInput
|
||||
from prisma.types import (
|
||||
PendingHumanReviewCreateInput,
|
||||
PendingHumanReviewUpdateInput,
|
||||
PendingHumanReviewUpsertInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.executions.review.model import (
|
||||
@@ -66,20 +70,20 @@ async def get_or_create_human_review(
|
||||
# Upsert - get existing or create new review
|
||||
review = await PendingHumanReview.prisma().upsert(
|
||||
where={"nodeExecId": node_exec_id},
|
||||
data={
|
||||
"create": {
|
||||
"userId": user_id,
|
||||
"nodeExecId": node_exec_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(input_data),
|
||||
"instructions": message,
|
||||
"editable": editable,
|
||||
"status": ReviewStatus.WAITING,
|
||||
},
|
||||
"update": {}, # Do nothing on update - keep existing review as is
|
||||
},
|
||||
data=PendingHumanReviewUpsertInput(
|
||||
create=PendingHumanReviewCreateInput(
|
||||
userId=user_id,
|
||||
nodeExecId=node_exec_id,
|
||||
graphExecId=graph_exec_id,
|
||||
graphId=graph_id,
|
||||
graphVersion=graph_version,
|
||||
payload=SafeJson(input_data),
|
||||
instructions=message,
|
||||
editable=editable,
|
||||
status=ReviewStatus.WAITING,
|
||||
),
|
||||
update={}, # Do nothing on update - keep existing review as is
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -7,7 +7,11 @@ import prisma
|
||||
import pydantic
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
from prisma.types import (
|
||||
UserOnboardingCreateInput,
|
||||
UserOnboardingUpdateInput,
|
||||
UserOnboardingUpsertInput,
|
||||
)
|
||||
|
||||
from backend.api.features.store.model import StoreAgentDetails
|
||||
from backend.api.model import OnboardingNotificationPayload
|
||||
@@ -110,12 +114,13 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
if data.onboardingAgentExecutionId is not None:
|
||||
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
|
||||
|
||||
create_input = UserOnboardingCreateInput(userId=user_id, **update)
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
data=UserOnboardingUpsertInput(
|
||||
create=create_input,
|
||||
update=update,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user