Compare commits

...

5 Commits

Author SHA1 Message Date
Swifty
402bec4595 pr comments 2026-01-08 14:43:07 +01:00
Swifty
1c8cba9c5f fix more linting issues 2026-01-08 13:30:51 +01:00
Swifty
072c647baa fix addintal formatting issues 2026-01-08 13:24:30 +01:00
Swifty
d5f490b85d Merge branch 'dev' into swiftyos/fix-linting-errors 2026-01-08 13:07:05 +01:00
Swifty
6686de1701 fix linter errors 2026-01-08 13:04:02 +01:00
20 changed files with 536 additions and 467 deletions

View File

@@ -3,6 +3,7 @@ from datetime import UTC, datetime
from os import getenv from os import getenv
import pytest import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
@@ -49,13 +50,13 @@ async def setup_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup) # 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0] username = user.email.split("@")[0]
await prisma.profile.create( await prisma.profile.create(
data={ data=ProfileCreateInput(
"userId": user.id, userId=user.id,
"username": username, username=username,
"name": f"Test User {username}", name=f"Test User {username}",
"description": "Test user profile", description="Test user profile",
"links": [], # Required field - empty array for test profiles links=[], # Required field - empty array for test profiles
} )
) )
# 2. Create a test graph with agent input -> agent output # 2. Create a test graph with agent input -> agent output
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup) # 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0] username = user.email.split("@")[0]
await prisma.profile.create( await prisma.profile.create(
data={ data=ProfileCreateInput(
"userId": user.id, userId=user.id,
"username": username, username=username,
"name": f"Test User {username}", name=f"Test User {username}",
"description": "Test user profile for LLM tests", description="Test user profile for LLM tests",
"links": [], # Required field - empty array for test profiles links=[], # Required field - empty array for test profiles
} )
) )
# 2. Create test OpenAI credentials for the user # 2. Create test OpenAI credentials for the user
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup) # 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0] username = user.email.split("@")[0]
await prisma.profile.create( await prisma.profile.create(
data={ data=ProfileCreateInput(
"userId": user.id, userId=user.id,
"username": username, username=username,
"name": f"Test User {username}", name=f"Test User {username}",
"description": "Test user profile for Firecrawl tests", description="Test user profile for Firecrawl tests",
"links": [], # Required field - empty array for test profiles links=[], # Required field - empty array for test profiles
} )
) )
# NOTE: We deliberately do NOT create Firecrawl credentials for this user # NOTE: We deliberately do NOT create Firecrawl credentials for this user

View File

@@ -817,18 +817,16 @@ async def add_store_agent_to_library(
# Create LibraryAgent entry # Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create( added_agent = await prisma.models.LibraryAgent.prisma().create(
data={ data=prisma.types.LibraryAgentCreateInput(
"User": {"connect": {"id": user_id}}, User={"connect": {"id": user_id}},
"AgentGraph": { AgentGraph={
"connect": { "connect": {
"graphVersionId": {"id": graph.id, "version": graph.version} "graphVersionId": {"id": graph.id, "version": graph.version}
} }
}, },
"isCreatedByUser": False, isCreatedByUser=False,
"settings": SafeJson( settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()),
_initialize_graph_settings(graph_model).model_dump() ),
),
},
include=library_agent_include( include=library_agent_include(
user_id, include_nodes=False, include_executions=False user_id, include_nodes=False, include_executions=False
), ),

View File

@@ -27,6 +27,13 @@ from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.models import User as PrismaUser from prisma.models import User as PrismaUser
from prisma.types import (
OAuthAccessTokenCreateInput,
OAuthApplicationCreateInput,
OAuthAuthorizationCodeCreateInput,
OAuthRefreshTokenCreateInput,
UserCreateInput,
)
from backend.api.rest_api import app from backend.api.rest_api import app
@@ -48,11 +55,11 @@ def test_user_id() -> str:
async def test_user(server, test_user_id: str): async def test_user(server, test_user_id: str):
"""Create a test user in the database.""" """Create a test user in the database."""
await PrismaUser.prisma().create( await PrismaUser.prisma().create(
data={ data=UserCreateInput(
"id": test_user_id, id=test_user_id,
"email": f"oauth-test-{test_user_id}@example.com", email=f"oauth-test-{test_user_id}@example.com",
"name": "OAuth Test User", name="OAuth Test User",
} )
) )
yield test_user_id yield test_user_id
@@ -77,22 +84,22 @@ async def test_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext) client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create( await PrismaOAuthApplication.prisma().create(
data={ data=OAuthApplicationCreateInput(
"id": app_id, id=app_id,
"name": "Test OAuth App", name="Test OAuth App",
"description": "Test application for integration tests", description="Test application for integration tests",
"clientId": client_id, clientId=client_id,
"clientSecret": client_secret_hash, clientSecret=client_secret_hash,
"clientSecretSalt": client_secret_salt, clientSecretSalt=client_secret_salt,
"redirectUris": [ redirectUris=[
"https://example.com/callback", "https://example.com/callback",
"http://localhost:3000/callback", "http://localhost:3000/callback",
], ],
"grantTypes": ["authorization_code", "refresh_token"], grantTypes=["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
"ownerId": test_user, ownerId=test_user,
"isActive": True, isActive=True,
} )
) )
yield { yield {
@@ -296,19 +303,19 @@ async def inactive_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext) client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create( await PrismaOAuthApplication.prisma().create(
data={ data=OAuthApplicationCreateInput(
"id": app_id, id=app_id,
"name": "Inactive OAuth App", name="Inactive OAuth App",
"description": "Inactive test application", description="Inactive test application",
"clientId": client_id, clientId=client_id,
"clientSecret": client_secret_hash, clientSecret=client_secret_hash,
"clientSecretSalt": client_secret_salt, clientSecretSalt=client_secret_salt,
"redirectUris": ["https://example.com/callback"], redirectUris=["https://example.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"], grantTypes=["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user, ownerId=test_user,
"isActive": False, # Inactive! isActive=False, # Inactive!
} )
) )
yield { yield {
@@ -699,14 +706,14 @@ async def test_token_authorization_code_expired(
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
await PrismaOAuthAuthorizationCode.prisma().create( await PrismaOAuthAuthorizationCode.prisma().create(
data={ data=OAuthAuthorizationCodeCreateInput(
"code": expired_code, code=expired_code,
"applicationId": test_oauth_app["id"], applicationId=test_oauth_app["id"],
"userId": test_user, userId=test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH],
"redirectUri": test_oauth_app["redirect_uri"], redirectUri=test_oauth_app["redirect_uri"],
"expiresAt": now - timedelta(hours=1), # Already expired expiresAt=now - timedelta(hours=1), # Already expired
} )
) )
response = await client.post( response = await client.post(
@@ -942,13 +949,13 @@ async def test_token_refresh_expired(
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create( await PrismaOAuthRefreshToken.prisma().create(
data={ data=OAuthRefreshTokenCreateInput(
"token": expired_token_hash, token=expired_token_hash,
"applicationId": test_oauth_app["id"], applicationId=test_oauth_app["id"],
"userId": test_user, userId=test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now - timedelta(days=1), # Already expired expiresAt=now - timedelta(days=1), # Already expired
} )
) )
response = await client.post( response = await client.post(
@@ -980,14 +987,14 @@ async def test_token_refresh_revoked(
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create( await PrismaOAuthRefreshToken.prisma().create(
data={ data=OAuthRefreshTokenCreateInput(
"token": revoked_token_hash, token=revoked_token_hash,
"applicationId": test_oauth_app["id"], applicationId=test_oauth_app["id"],
"userId": test_user, userId=test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(days=30), # Not expired expiresAt=now + timedelta(days=30), # Not expired
"revokedAt": now - timedelta(hours=1), # But revoked revokedAt=now - timedelta(hours=1), # But revoked
} )
) )
response = await client.post( response = await client.post(
@@ -1013,19 +1020,19 @@ async def other_oauth_app(test_user: str):
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext) client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create( await PrismaOAuthApplication.prisma().create(
data={ data=OAuthApplicationCreateInput(
"id": app_id, id=app_id,
"name": "Other OAuth App", name="Other OAuth App",
"description": "Second test application", description="Second test application",
"clientId": client_id, clientId=client_id,
"clientSecret": client_secret_hash, clientSecret=client_secret_hash,
"clientSecretSalt": client_secret_salt, clientSecretSalt=client_secret_salt,
"redirectUris": ["https://other.example.com/callback"], redirectUris=["https://other.example.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"], grantTypes=["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user, ownerId=test_user,
"isActive": True, isActive=True,
} )
) )
yield { yield {
@@ -1052,13 +1059,13 @@ async def test_token_refresh_wrong_application(
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
await PrismaOAuthRefreshToken.prisma().create( await PrismaOAuthRefreshToken.prisma().create(
data={ data=OAuthRefreshTokenCreateInput(
"token": token_hash, token=token_hash,
"applicationId": test_oauth_app["id"], # Belongs to test_oauth_app applicationId=test_oauth_app["id"], # Belongs to test_oauth_app
"userId": test_user, userId=test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(days=30), expiresAt=now + timedelta(days=30),
} )
) )
# Try to use it with `other_oauth_app` # Try to use it with `other_oauth_app`
@@ -1267,19 +1274,19 @@ async def test_validate_access_token_fails_when_app_disabled(
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext) client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
await PrismaOAuthApplication.prisma().create( await PrismaOAuthApplication.prisma().create(
data={ data=OAuthApplicationCreateInput(
"id": app_id, id=app_id,
"name": "App To Be Disabled", name="App To Be Disabled",
"description": "Test app for disabled validation", description="Test app for disabled validation",
"clientId": client_id, clientId=client_id,
"clientSecret": client_secret_hash, clientSecret=client_secret_hash,
"clientSecretSalt": client_secret_salt, clientSecretSalt=client_secret_salt,
"redirectUris": ["https://example.com/callback"], redirectUris=["https://example.com/callback"],
"grantTypes": ["authorization_code"], grantTypes=["authorization_code"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH],
"ownerId": test_user, ownerId=test_user,
"isActive": True, isActive=True,
} )
) )
# Create an access token directly in the database # Create an access token directly in the database
@@ -1288,13 +1295,13 @@ async def test_validate_access_token_fails_when_app_disabled(
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
await PrismaOAuthAccessToken.prisma().create( await PrismaOAuthAccessToken.prisma().create(
data={ data=OAuthAccessTokenCreateInput(
"token": token_hash, token=token_hash,
"applicationId": app_id, applicationId=app_id,
"userId": test_user, userId=test_user,
"scopes": [APIKeyPermission.EXECUTE_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH],
"expiresAt": now + timedelta(hours=1), expiresAt=now + timedelta(hours=1),
} )
) )
# Token should be valid while app is active # Token should be valid while app is active
@@ -1561,19 +1568,19 @@ async def test_revoke_token_from_different_app_fails_silently(
) )
await PrismaOAuthApplication.prisma().create( await PrismaOAuthApplication.prisma().create(
data={ data=OAuthApplicationCreateInput(
"id": app2_id, id=app2_id,
"name": "Second Test OAuth App", name="Second Test OAuth App",
"description": "Second test application for cross-app revocation test", description="Second test application for cross-app revocation test",
"clientId": app2_client_id, clientId=app2_client_id,
"clientSecret": app2_client_secret_hash, clientSecret=app2_client_secret_hash,
"clientSecretSalt": app2_client_secret_salt, clientSecretSalt=app2_client_secret_salt,
"redirectUris": ["https://other-app.com/callback"], redirectUris=["https://other-app.com/callback"],
"grantTypes": ["authorization_code", "refresh_token"], grantTypes=["authorization_code", "refresh_token"],
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH], scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
"ownerId": test_user, ownerId=test_user,
"isActive": True, isActive=True,
} )
) )
# App 2 tries to revoke App 1's access token # App 2 tries to revoke App 1's access token

View File

@@ -249,7 +249,9 @@ async def log_search_term(search_query: str):
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
try: try:
await prisma.models.SearchTerms.prisma().create( await prisma.models.SearchTerms.prisma().create(
data={"searchTerm": search_query, "createdDate": date} data=prisma.types.SearchTermsCreateInput(
searchTerm=search_query, createdDate=date
)
) )
except Exception as e: except Exception as e:
# Fail silently here so that logging search terms doesn't break the app # Fail silently here so that logging search terms doesn't break the app
@@ -1456,11 +1458,9 @@ async def _approve_sub_agent(
# Create new version if no matching version found # Create new version if no matching version found
next_version = max((v.version for v in listing.Versions or []), default=0) + 1 next_version = max((v.version for v in listing.Versions or []), default=0) + 1
await prisma.models.StoreListingVersion.prisma(tx).create( await prisma.models.StoreListingVersion.prisma(tx).create(
data={ data=_create_sub_agent_version_data(
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name), sub_graph, heading, main_agent_name, next_version, listing.id
"version": next_version, )
"storeListingId": listing.id,
}
) )
await prisma.models.StoreListing.prisma(tx).update( await prisma.models.StoreListing.prisma(tx).update(
where={"id": listing.id}, data={"hasApprovedVersion": True} where={"id": listing.id}, data={"hasApprovedVersion": True}
@@ -1468,10 +1468,14 @@ async def _approve_sub_agent(
def _create_sub_agent_version_data( def _create_sub_agent_version_data(
sub_graph: prisma.models.AgentGraph, heading: str, main_agent_name: str sub_graph: prisma.models.AgentGraph,
heading: str,
main_agent_name: str,
version: typing.Optional[int] = None,
store_listing_id: typing.Optional[str] = None,
) -> prisma.types.StoreListingVersionCreateInput: ) -> prisma.types.StoreListingVersionCreateInput:
"""Create store listing version data for a sub-agent""" """Create store listing version data for a sub-agent"""
return prisma.types.StoreListingVersionCreateInput( data = prisma.types.StoreListingVersionCreateInput(
agentGraphId=sub_graph.id, agentGraphId=sub_graph.id,
agentGraphVersion=sub_graph.version, agentGraphVersion=sub_graph.version,
name=sub_graph.name or heading, name=sub_graph.name or heading,
@@ -1486,6 +1490,11 @@ def _create_sub_agent_version_data(
imageUrls=[], # Sub-agents don't need images imageUrls=[], # Sub-agents don't need images
categories=[], # Sub-agents don't need categories categories=[], # Sub-agents don't need categories
) )
if version is not None:
data["version"] = version
if store_listing_id is not None:
data["storeListingId"] = store_listing_id
return data
async def review_store_submission( async def review_store_submission(

View File

@@ -42,6 +42,7 @@ from urllib.parse import urlparse
import click import click
from autogpt_libs.api_key.keysmith import APIKeySmith from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission from prisma.enums import APIKeyPermission
from prisma.types import OAuthApplicationCreateInput
keysmith = APIKeySmith() keysmith = APIKeySmith()
@@ -147,7 +148,7 @@ def format_sql_insert(creds: dict) -> str:
sql = f""" sql = f"""
-- ============================================================ -- ============================================================
-- OAuth Application: {creds['name']} -- OAuth Application: {creds["name"]}
-- Generated: {now_iso} UTC -- Generated: {now_iso} UTC
-- ============================================================ -- ============================================================
@@ -167,14 +168,14 @@ INSERT INTO "OAuthApplication" (
"isActive" "isActive"
) )
VALUES ( VALUES (
'{creds['id']}', '{creds["id"]}',
NOW(), NOW(),
NOW(), NOW(),
'{creds['name']}', '{creds["name"]}',
{f"'{creds['description']}'" if creds['description'] else 'NULL'}, {f"'{creds['description']}'" if creds["description"] else "NULL"},
'{creds['client_id']}', '{creds["client_id"]}',
'{creds['client_secret_hash']}', '{creds["client_secret_hash"]}',
'{creds['client_secret_salt']}', '{creds["client_secret_salt"]}',
ARRAY{redirect_uris_pg}::TEXT[], ARRAY{redirect_uris_pg}::TEXT[],
ARRAY{grant_types_pg}::TEXT[], ARRAY{grant_types_pg}::TEXT[],
ARRAY{scopes_pg}::"APIKeyPermission"[], ARRAY{scopes_pg}::"APIKeyPermission"[],
@@ -186,8 +187,8 @@ VALUES (
-- ⚠️ IMPORTANT: Save these credentials securely! -- ⚠️ IMPORTANT: Save these credentials securely!
-- ============================================================ -- ============================================================
-- --
-- Client ID: {creds['client_id']} -- Client ID: {creds["client_id"]}
-- Client Secret: {creds['client_secret_plaintext']} -- Client Secret: {creds["client_secret_plaintext"]}
-- --
-- ⚠️ The client secret is shown ONLY ONCE! -- ⚠️ The client secret is shown ONLY ONCE!
-- ⚠️ Store it securely and share only with the application developer. -- ⚠️ Store it securely and share only with the application developer.
@@ -200,7 +201,7 @@ VALUES (
-- To verify the application was created: -- To verify the application was created:
-- SELECT "clientId", name, scopes, "redirectUris", "isActive" -- SELECT "clientId", name, scopes, "redirectUris", "isActive"
-- FROM "OAuthApplication" -- FROM "OAuthApplication"
-- WHERE "clientId" = '{creds['client_id']}'; -- WHERE "clientId" = '{creds["client_id"]}';
""" """
return sql return sql
@@ -834,19 +835,19 @@ async def create_test_app_in_db(
# Insert into database # Insert into database
app = await OAuthApplication.prisma().create( app = await OAuthApplication.prisma().create(
data={ data=OAuthApplicationCreateInput(
"id": creds["id"], id=creds["id"],
"name": creds["name"], name=creds["name"],
"description": creds["description"], description=creds["description"],
"clientId": creds["client_id"], clientId=creds["client_id"],
"clientSecret": creds["client_secret_hash"], clientSecret=creds["client_secret_hash"],
"clientSecretSalt": creds["client_secret_salt"], clientSecretSalt=creds["client_secret_salt"],
"redirectUris": creds["redirect_uris"], redirectUris=creds["redirect_uris"],
"grantTypes": creds["grant_types"], grantTypes=creds["grant_types"],
"scopes": creds["scopes"], scopes=creds["scopes"],
"ownerId": owner_id, ownerId=owner_id,
"isActive": True, isActive=True,
} )
) )
click.echo(f"✓ Created test OAuth application: {app.clientId}") click.echo(f"✓ Created test OAuth application: {app.clientId}")

View File

@@ -6,7 +6,7 @@ from typing import Literal, Optional
from autogpt_libs.api_key.keysmith import APIKeySmith from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
from pydantic import Field from pydantic import Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH from backend.data.includes import MAX_USER_API_KEYS_FETCH
@@ -82,17 +82,17 @@ async def create_api_key(
generated_key = keysmith.generate_key() generated_key = keysmith.generate_key()
saved_key_obj = await PrismaAPIKey.prisma().create( saved_key_obj = await PrismaAPIKey.prisma().create(
data={ data=APIKeyCreateInput(
"id": str(uuid.uuid4()), id=str(uuid.uuid4()),
"name": name, name=name,
"head": generated_key.head, head=generated_key.head,
"tail": generated_key.tail, tail=generated_key.tail,
"hash": generated_key.hash, hash=generated_key.hash,
"salt": generated_key.salt, salt=generated_key.salt,
"permissions": [p for p in permissions], permissions=[p for p in permissions],
"description": description, description=description,
"userId": user_id, userId=user_id,
} )
) )
return APIKeyInfo.from_db(saved_key_obj), generated_key.key return APIKeyInfo.from_db(saved_key_obj), generated_key.key

View File

@@ -22,7 +22,12 @@ from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
from prisma.models import OAuthApplication as PrismaOAuthApplication from prisma.models import OAuthApplication as PrismaOAuthApplication
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
from prisma.types import OAuthApplicationUpdateInput from prisma.types import (
OAuthAccessTokenCreateInput,
OAuthApplicationUpdateInput,
OAuthAuthorizationCodeCreateInput,
OAuthRefreshTokenCreateInput,
)
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr
from .base import APIAuthorizationInfo from .base import APIAuthorizationInfo
@@ -359,17 +364,17 @@ async def create_authorization_code(
expires_at = now + AUTHORIZATION_CODE_TTL expires_at = now + AUTHORIZATION_CODE_TTL
saved_code = await PrismaOAuthAuthorizationCode.prisma().create( saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
data={ data=OAuthAuthorizationCodeCreateInput(
"id": str(uuid.uuid4()), id=str(uuid.uuid4()),
"code": code, code=code,
"expiresAt": expires_at, expiresAt=expires_at,
"applicationId": application_id, applicationId=application_id,
"userId": user_id, userId=user_id,
"scopes": [s for s in scopes], scopes=[s for s in scopes],
"redirectUri": redirect_uri, redirectUri=redirect_uri,
"codeChallenge": code_challenge, codeChallenge=code_challenge,
"codeChallengeMethod": code_challenge_method, codeChallengeMethod=code_challenge_method,
} )
) )
return OAuthAuthorizationCodeInfo.from_db(saved_code) return OAuthAuthorizationCodeInfo.from_db(saved_code)
@@ -490,14 +495,14 @@ async def create_access_token(
expires_at = now + ACCESS_TOKEN_TTL expires_at = now + ACCESS_TOKEN_TTL
saved_token = await PrismaOAuthAccessToken.prisma().create( saved_token = await PrismaOAuthAccessToken.prisma().create(
data={ data=OAuthAccessTokenCreateInput(
"id": str(uuid.uuid4()), id=str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup token=token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at, expiresAt=expires_at,
"applicationId": application_id, applicationId=application_id,
"userId": user_id, userId=user_id,
"scopes": [s for s in scopes], scopes=[s for s in scopes],
} )
) )
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token) return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
@@ -607,14 +612,14 @@ async def create_refresh_token(
expires_at = now + REFRESH_TOKEN_TTL expires_at = now + REFRESH_TOKEN_TTL
saved_token = await PrismaOAuthRefreshToken.prisma().create( saved_token = await PrismaOAuthRefreshToken.prisma().create(
data={ data=OAuthRefreshTokenCreateInput(
"id": str(uuid.uuid4()), id=str(uuid.uuid4()),
"token": token_hash, # SHA256 hash for direct lookup token=token_hash, # SHA256 hash for direct lookup
"expiresAt": expires_at, expiresAt=expires_at,
"applicationId": application_id, applicationId=application_id,
"userId": user_id, userId=user_id,
"scopes": [s for s in scopes], scopes=[s for s in scopes],
} )
) )
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token) return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)

View File

@@ -11,6 +11,7 @@ import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import UserCredit from backend.data.credit import UserCredit
from backend.util.json import SafeJson from backend.util.json import SafeJson
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests.""" """Create a test user for ceiling tests."""
try: try:
await User.prisma().create( await User.prisma().create(
data={ data=UserCreateInput(
"id": user_id, id=user_id,
"email": f"test-{user_id}@example.com", email=f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}", name=f"Test User {user_id[:8]}",
} )
) )
except UniqueViolationError: except UniqueViolationError:
# User already exists, continue # User already exists, continue
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}}, data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
) )

View File

@@ -14,6 +14,7 @@ import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError from backend.util.exceptions import InsufficientBalanceError
@@ -28,11 +29,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance.""" """Create a test user with initial balance."""
try: try:
await User.prisma().create( await User.prisma().create(
data={ data=UserCreateInput(
"id": user_id, id=user_id,
"email": f"test-{user_id}@example.com", email=f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}", name=f"Test User {user_id[:8]}",
} )
) )
except UniqueViolationError: except UniqueViolationError:
# User already exists, continue # User already exists, continue
@@ -41,7 +42,10 @@ async def create_test_user(user_id: str) -> None:
# Ensure UserBalance record exists # Ensure UserBalance record exists
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}}, data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
) )
@@ -342,10 +346,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
# First, set balance near max # First, set balance near max
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={ data=UserBalanceUpsertInput(
"create": {"userId": user_id, "balance": max_int - 100}, create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100),
"update": {"balance": max_int - 100}, update={"balance": max_int - 100},
}, ),
) )
# Try to add more than possible - should clamp to POSTGRES_INT_MAX # Try to add more than possible - should clamp to POSTGRES_INT_MAX
@@ -507,7 +511,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"]) sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
print("\nExecution order by start time:") print("\nExecution order by start time:")
for i, (label, timing) in enumerate(sorted_timings): for i, (label, timing) in enumerate(sorted_timings):
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}") print(f" {i + 1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
# Check for overlap (true concurrency) vs serialization # Check for overlap (true concurrency) vs serialization
overlaps = [] overlaps = []
@@ -546,7 +550,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
print("\nDatabase transaction order (by createdAt):") print("\nDatabase transaction order (by createdAt):")
for i, tx in enumerate(transactions): for i, tx in enumerate(transactions):
print( print(
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}" f" {i + 1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
) )
# Verify running balances are chronologically consistent (ordered by createdAt) # Verify running balances are chronologically consistent (ordered by createdAt)
@@ -707,7 +711,7 @@ async def test_prove_database_locking_behavior(server: SpinTestServer):
for i, result in enumerate(sorted_results): for i, result in enumerate(sorted_results):
print( print(
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s" f" {i + 1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
) )
# Check if any operations overlapped at the database level # Check if any operations overlapped at the database level

View File

@@ -8,6 +8,7 @@ which would have caught the CreditTransactionType enum casting bug.
import pytest import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserCreateInput
from backend.data.credit import ( from backend.data.credit import (
AutoTopUpConfig, AutoTopUpConfig,
@@ -29,12 +30,12 @@ async def cleanup_test_user():
# Create the user first # Create the user first
try: try:
await User.prisma().create( await User.prisma().create(
data={ data=UserCreateInput(
"id": user_id, id=user_id,
"email": f"test-{user_id}@example.com", email=f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}), topUpConfig=SafeJson({}),
"timezone": "UTC", timezone="UTC",
} )
) )
except Exception: except Exception:
# User might already exist, that's fine # User might already exist, that's fine

View File

@@ -12,6 +12,12 @@ import pytest
import stripe import stripe
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserCreateInput,
)
from backend.data.credit import UserCredit from backend.data.credit import UserCredit
from backend.util.json import SafeJson from backend.util.json import SafeJson
@@ -35,32 +41,32 @@ async def setup_test_user_with_topup():
# Create user # Create user
await User.prisma().create( await User.prisma().create(
data={ data=UserCreateInput(
"id": REFUND_TEST_USER_ID, id=REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com", email=f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User", name="Refund Test User",
} )
) )
# Create user balance # Create user balance
await UserBalance.prisma().create( await UserBalance.prisma().create(
data={ data=UserBalanceCreateInput(
"userId": REFUND_TEST_USER_ID, userId=REFUND_TEST_USER_ID,
"balance": 1000, # $10 balance=1000, # $10
} )
) )
# Create a top-up transaction that can be refunded # Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create( topup_tx = await CreditTransaction.prisma().create(
data={ data=CreditTransactionCreateInput(
"userId": REFUND_TEST_USER_ID, userId=REFUND_TEST_USER_ID,
"amount": 1000, amount=1000,
"type": CreditTransactionType.TOP_UP, type=CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345", transactionKey="pi_test_12345",
"runningBalance": 1000, runningBalance=1000,
"isActive": True, isActive=True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}), metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}),
} )
) )
return topup_tx return topup_tx
@@ -93,12 +99,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
# Create refund request record (simulating webhook flow) # Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create( await CreditRefundRequest.prisma().create(
data={ data=CreditRefundRequestCreateInput(
"userId": REFUND_TEST_USER_ID, userId=REFUND_TEST_USER_ID,
"amount": 500, amount=500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction transactionKey=topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund", reason="Test refund",
} )
) )
# Call deduct_credits # Call deduct_credits
@@ -286,12 +292,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
refund_requests = [] refund_requests = []
for i in range(5): for i in range(5):
req = await CreditRefundRequest.prisma().create( req = await CreditRefundRequest.prisma().create(
data={ data=CreditRefundRequestCreateInput(
"userId": REFUND_TEST_USER_ID, userId=REFUND_TEST_USER_ID,
"amount": 100, # $1 each amount=100, # $1 each
"transactionKey": topup_tx.transactionKey, transactionKey=topup_tx.transactionKey,
"reason": f"Test refund {i}", reason=f"Test refund {i}",
} )
) )
refund_requests.append(req) refund_requests.append(req)

View File

@@ -3,6 +3,11 @@ from datetime import datetime, timedelta, timezone
import pytest import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance from prisma.models import CreditTransaction, UserBalance
from prisma.types import (
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserBalanceUpsertInput,
)
from backend.blocks.llm import AITextGeneratorBlock from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block from backend.data.block import get_block
@@ -23,10 +28,10 @@ async def disable_test_user_transactions():
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID}, where={"userId": DEFAULT_USER_ID},
data={ data=UserBalanceUpsertInput(
"create": {"userId": DEFAULT_USER_ID, "balance": 0}, create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0),
"update": {"balance": 0, "updatedAt": old_date}, update={"balance": 0, "updatedAt": old_date},
}, ),
) )
@@ -140,23 +145,23 @@ async def test_block_credit_reset(server: SpinTestServer):
# Manually create a transaction with month 1 timestamp to establish history # Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create( await CreditTransaction.prisma().create(
data={ data=CreditTransactionCreateInput(
"userId": DEFAULT_USER_ID, userId=DEFAULT_USER_ID,
"amount": 100, amount=100,
"type": CreditTransactionType.TOP_UP, type=CreditTransactionType.TOP_UP,
"runningBalance": 1100, runningBalance=1100,
"isActive": True, isActive=True,
"createdAt": month1, # Set specific timestamp createdAt=month1, # Set specific timestamp
} )
) )
# Update user balance to match # Update user balance to match
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID}, where={"userId": DEFAULT_USER_ID},
data={ data=UserBalanceUpsertInput(
"create": {"userId": DEFAULT_USER_ID, "balance": 1100}, create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100),
"update": {"balance": 1100}, update={"balance": 1100},
}, ),
) )
# Now test month 2 behavior # Now test month 2 behavior
@@ -175,14 +180,14 @@ async def test_block_credit_reset(server: SpinTestServer):
# Create a month 2 transaction to update the last transaction time # Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create( await CreditTransaction.prisma().create(
data={ data=CreditTransactionCreateInput(
"userId": DEFAULT_USER_ID, userId=DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400 amount=-700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE, type=CreditTransactionType.USAGE,
"runningBalance": 400, runningBalance=400,
"isActive": True, isActive=True,
"createdAt": month2, createdAt=month2,
} )
) )
# Move to month 3 # Move to month 3

View File

@@ -12,6 +12,7 @@ import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MIN, UserCredit from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer from backend.util.test import SpinTestServer
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests.""" """Create a test user for underflow tests."""
try: try:
await User.prisma().create( await User.prisma().create(
data={ data=UserCreateInput(
"id": user_id, id=user_id,
"email": f"test-{user_id}@example.com", email=f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}", name=f"Test User {user_id[:8]}",
} )
) )
except UniqueViolationError: except UniqueViolationError:
# User already exists, continue # User already exists, continue
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}}, data=UserBalanceUpsertInput(
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
) )
@@ -66,14 +70,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
initial_balance_target = POSTGRES_INT_MIN + 100 initial_balance_target = POSTGRES_INT_MIN + 100
# Use direct database update to set the balance close to underflow # Use direct database update to set the balance close to underflow
from prisma.models import UserBalance
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={ data=UserBalanceUpsertInput(
"create": {"userId": user_id, "balance": initial_balance_target}, create=UserBalanceCreateInput(
"update": {"balance": initial_balance_target}, userId=user_id, balance=initial_balance_target
}, ),
update={"balance": initial_balance_target},
),
) )
current_balance = await credit_system.get_credits(user_id) current_balance = await credit_system.get_credits(user_id)
@@ -110,10 +114,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
# Set balance to exactly POSTGRES_INT_MIN # Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={ data=UserBalanceUpsertInput(
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN}, create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN),
"update": {"balance": POSTGRES_INT_MIN}, update={"balance": POSTGRES_INT_MIN},
}, ),
) )
edge_balance = await credit_system.get_credits(user_id) edge_balance = await credit_system.get_credits(user_id)
@@ -147,15 +151,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
# Set up balance close to underflow threshold to test the protection # Set up balance close to underflow threshold to test the protection
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000 # Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
# This should trigger underflow protection # This should trigger underflow protection
from prisma.models import UserBalance
test_balance = POSTGRES_INT_MIN + 1000 test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={ data=UserBalanceUpsertInput(
"create": {"userId": user_id, "balance": test_balance}, create=UserBalanceCreateInput(userId=user_id, balance=test_balance),
"update": {"balance": test_balance}, update={"balance": test_balance},
}, ),
) )
current_balance = await credit_system.get_credits(user_id) current_balance = await credit_system.get_credits(user_id)
@@ -212,15 +214,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
try: try:
# Set up balance close to underflow threshold # Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={ data=UserBalanceUpsertInput(
"create": {"userId": user_id, "balance": initial_balance}, create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
"update": {"balance": initial_balance}, update={"balance": initial_balance},
}, ),
) )
# Apply multiple refunds that would cumulatively underflow # Apply multiple refunds that would cumulatively underflow
@@ -295,10 +295,10 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={ data=UserBalanceUpsertInput(
"create": {"userId": user_id, "balance": initial_balance}, create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
"update": {"balance": initial_balance}, update={"balance": initial_balance},
}, ),
) )
async def large_refund(amount: int, label: str): async def large_refund(amount: int, label: str):

View File

@@ -14,6 +14,7 @@ import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserCreateInput
from backend.data.credit import UsageTransactionMetadata, UserCredit from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson from backend.util.json import SafeJson
@@ -24,11 +25,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests.""" """Create a test user for migration tests."""
try: try:
await User.prisma().create( await User.prisma().create(
data={ data=UserCreateInput(
"id": user_id, id=user_id,
"email": f"test-{user_id}@example.com", email=f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}", name=f"Test User {user_id[:8]}",
} )
) )
except UniqueViolationError: except UniqueViolationError:
# User already exists, continue # User already exists, continue
@@ -121,7 +122,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
try: try:
# Create UserBalance with specific value # Create UserBalance with specific value
await UserBalance.prisma().create( await UserBalance.prisma().create(
data={"userId": user_id, "balance": 5000} # $50 data=UserBalanceCreateInput(userId=user_id, balance=5000) # $50
) )
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value # Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
@@ -160,7 +161,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
try: try:
# Set initial balance in UserBalance # Set initial balance in UserBalance
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000}) await UserBalance.prisma().create(
data=UserBalanceCreateInput(userId=user_id, balance=1000)
)
# Run concurrent operations to ensure they all use UserBalance atomic operations # Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str): async def concurrent_spend(amount: int, label: str):

View File

@@ -28,6 +28,7 @@ from prisma.models import (
AgentNodeExecutionKeyValueData, AgentNodeExecutionKeyValueData,
) )
from prisma.types import ( from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionUpdateManyMutationInput, AgentGraphExecutionUpdateManyMutationInput,
AgentGraphExecutionWhereInput, AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput, AgentNodeExecutionCreateInput,
@@ -35,7 +36,6 @@ from prisma.types import (
AgentNodeExecutionKeyValueDataCreateInput, AgentNodeExecutionKeyValueDataCreateInput,
AgentNodeExecutionUpdateInput, AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput, AgentNodeExecutionWhereInput,
AgentNodeExecutionWhereUniqueInput,
) )
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
@@ -709,18 +709,18 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node. The id of the AgentGraphExecution and the list of ExecutionResult for each node.
""" """
result = await AgentGraphExecution.prisma().create( result = await AgentGraphExecution.prisma().create(
data={ data=AgentGraphExecutionCreateInput(
"agentGraphId": graph_id, agentGraphId=graph_id,
"agentGraphVersion": graph_version, agentGraphVersion=graph_version,
"executionStatus": ExecutionStatus.INCOMPLETE, executionStatus=ExecutionStatus.INCOMPLETE,
"inputs": SafeJson(inputs), inputs=SafeJson(inputs),
"credentialInputs": ( credentialInputs=(
SafeJson(credential_inputs) if credential_inputs else Json({}) SafeJson(credential_inputs) if credential_inputs else Json({})
), ),
"nodesInputMasks": ( nodesInputMasks=(
SafeJson(nodes_input_masks) if nodes_input_masks else Json({}) SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
), ),
"NodeExecutions": { NodeExecutions={
"create": [ "create": [
AgentNodeExecutionCreateInput( AgentNodeExecutionCreateInput(
agentNodeId=node_id, agentNodeId=node_id,
@@ -736,10 +736,10 @@ async def create_graph_execution(
for node_id, node_input in starting_nodes_input for node_id, node_input in starting_nodes_input
] ]
}, },
"userId": user_id, userId=user_id,
"agentPresetId": preset_id, agentPresetId=preset_id,
"parentGraphExecutionId": parent_graph_exec_id, parentGraphExecutionId=parent_graph_exec_id,
}, ),
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES, include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
) )
@@ -831,10 +831,10 @@ async def upsert_execution_output(
""" """
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output. Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
""" """
data: AgentNodeExecutionInputOutputCreateInput = { data = AgentNodeExecutionInputOutputCreateInput(
"name": output_name, name=output_name,
"referencedByOutputExecId": node_exec_id, referencedByOutputExecId=node_exec_id,
} )
if output_data is not None: if output_data is not None:
data["data"] = SafeJson(output_data) data["data"] = SafeJson(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data) await AgentNodeExecutionInputOutput.prisma().create(data=data)
@@ -964,6 +964,12 @@ async def update_node_execution_status(
execution_data: BlockInput | None = None, execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None, stats: dict[str, Any] | None = None,
) -> NodeExecutionResult: ) -> NodeExecutionResult:
"""
Update a node execution's status with validation of allowed transitions.
⚠️ Internal executor use only - no user_id check. Callers (executor/manager.py)
are responsible for validating user authorization before invoking this function.
"""
if status == ExecutionStatus.QUEUED and execution_data is None: if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.") raise ValueError("Execution data must be provided when queuing an execution.")
@@ -974,25 +980,27 @@ async def update_node_execution_status(
f"Invalid status transition: {status} has no valid source statuses" f"Invalid status transition: {status} has no valid source statuses"
) )
if res := await AgentNodeExecution.prisma().update( # Fetch current execution to validate status transition
where=cast( current = await AgentNodeExecution.prisma().find_unique(
AgentNodeExecutionWhereUniqueInput, where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
{ )
"id": node_exec_id, if not current:
"executionStatus": {"in": [s.value for s in allowed_from]}, raise ValueError(f"Execution {node_exec_id} not found.")
},
), # Validate current status allows transition to the new status
if current.executionStatus not in allowed_from:
# Return current state without updating if transition is not allowed
return NodeExecutionResult.from_db(current)
# Perform the update with only the unique identifier
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=_get_update_status_data(status, execution_data, stats), data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE, include=EXECUTION_RESULT_INCLUDE,
): )
return NodeExecutionResult.from_db(res) if not res:
raise ValueError(f"Failed to update execution {node_exec_id}.")
if res := await AgentNodeExecution.prisma().find_unique( return NodeExecutionResult.from_db(res)
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
):
return NodeExecutionResult.from_db(res)
raise ValueError(f"Execution {node_exec_id} not found.")
def _get_update_status_data( def _get_update_status_data(

View File

@@ -10,7 +10,11 @@ from typing import Optional
from prisma.enums import ReviewStatus from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview from prisma.models import PendingHumanReview
from prisma.types import PendingHumanReviewUpdateInput from prisma.types import (
PendingHumanReviewCreateInput,
PendingHumanReviewUpdateInput,
PendingHumanReviewUpsertInput,
)
from pydantic import BaseModel from pydantic import BaseModel
from backend.api.features.executions.review.model import ( from backend.api.features.executions.review.model import (
@@ -66,20 +70,20 @@ async def get_or_create_human_review(
# Upsert - get existing or create new review # Upsert - get existing or create new review
review = await PendingHumanReview.prisma().upsert( review = await PendingHumanReview.prisma().upsert(
where={"nodeExecId": node_exec_id}, where={"nodeExecId": node_exec_id},
data={ data=PendingHumanReviewUpsertInput(
"create": { create=PendingHumanReviewCreateInput(
"userId": user_id, userId=user_id,
"nodeExecId": node_exec_id, nodeExecId=node_exec_id,
"graphExecId": graph_exec_id, graphExecId=graph_exec_id,
"graphId": graph_id, graphId=graph_id,
"graphVersion": graph_version, graphVersion=graph_version,
"payload": SafeJson(input_data), payload=SafeJson(input_data),
"instructions": message, instructions=message,
"editable": editable, editable=editable,
"status": ReviewStatus.WAITING, status=ReviewStatus.WAITING,
}, ),
"update": {}, # Do nothing on update - keep existing review as is update={}, # Do nothing on update - keep existing review as is
}, ),
) )
logger.info( logger.info(

View File

@@ -7,7 +7,11 @@ import prisma
import pydantic import pydantic
from prisma.enums import OnboardingStep from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput from prisma.types import (
UserOnboardingCreateInput,
UserOnboardingUpdateInput,
UserOnboardingUpsertInput,
)
from backend.api.features.store.model import StoreAgentDetails from backend.api.features.store.model import StoreAgentDetails
from backend.api.model import OnboardingNotificationPayload from backend.api.model import OnboardingNotificationPayload
@@ -92,6 +96,7 @@ async def reset_user_onboarding(user_id: str):
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate): async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {} update: UserOnboardingUpdateInput = {}
# get_user_onboarding guarantees the record exists via upsert
onboarding = await get_user_onboarding(user_id) onboarding = await get_user_onboarding(user_id)
if data.walletShown: if data.walletShown:
update["walletShown"] = data.walletShown update["walletShown"] = data.walletShown
@@ -110,12 +115,14 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
if data.onboardingAgentExecutionId is not None: if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
# The create branch is never taken since get_user_onboarding ensures the record exists,
# but upsert requires a create payload so we provide a minimal one
return await UserOnboarding.prisma().upsert( return await UserOnboarding.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data={ data=UserOnboardingUpsertInput(
"create": {"userId": user_id, **update}, create=UserOnboardingCreateInput(userId=user_id),
"update": update, update=update,
}, ),
) )

View File

@@ -22,6 +22,7 @@ import random
from typing import Any, Dict, List from typing import Any, Dict, List
from faker import Faker from faker import Faker
from prisma.types import AgentBlockCreateInput
# Import API functions from the backend # Import API functions from the backend
from backend.api.features.library.db import create_library_agent, create_preset from backend.api.features.library.db import create_library_agent, create_preset
@@ -179,12 +180,12 @@ class TestDataCreator:
for block in blocks_to_create: for block in blocks_to_create:
try: try:
await prisma.agentblock.create( await prisma.agentblock.create(
data={ data=AgentBlockCreateInput(
"id": block.id, id=block.id,
"name": block.name, name=block.name,
"inputSchema": "{}", inputSchema="{}",
"outputSchema": "{}", outputSchema="{}",
} )
) )
except Exception as e: except Exception as e:
print(f"Error creating block {block.name}: {e}") print(f"Error creating block {block.name}: {e}")

View File

@@ -30,13 +30,19 @@ from prisma.types import (
AgentGraphCreateInput, AgentGraphCreateInput,
AgentNodeCreateInput, AgentNodeCreateInput,
AgentNodeLinkCreateInput, AgentNodeLinkCreateInput,
AgentPresetCreateInput,
AnalyticsDetailsCreateInput, AnalyticsDetailsCreateInput,
AnalyticsMetricsCreateInput, AnalyticsMetricsCreateInput,
APIKeyCreateInput,
CreditTransactionCreateInput, CreditTransactionCreateInput,
IntegrationWebhookCreateInput, IntegrationWebhookCreateInput,
LibraryAgentCreateInput,
ProfileCreateInput, ProfileCreateInput,
StoreListingCreateInput,
StoreListingReviewCreateInput, StoreListingReviewCreateInput,
StoreListingVersionCreateInput,
UserCreateInput, UserCreateInput,
UserOnboardingCreateInput,
) )
faker = Faker() faker = Faker()
@@ -172,14 +178,14 @@ async def main():
for _ in range(num_presets): # Create 1 AgentPreset per user for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs) graph = random.choice(agent_graphs)
preset = await db.agentpreset.create( preset = await db.agentpreset.create(
data={ data=AgentPresetCreateInput(
"name": faker.sentence(nb_words=3), name=faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200), description=faker.text(max_nb_chars=200),
"userId": user.id, userId=user.id,
"agentGraphId": graph.id, agentGraphId=graph.id,
"agentGraphVersion": graph.version, agentGraphVersion=graph.version,
"isActive": True, isActive=True,
} )
) )
agent_presets.append(preset) agent_presets.append(preset)
@@ -220,18 +226,18 @@ async def main():
) )
library_agent = await db.libraryagent.create( library_agent = await db.libraryagent.create(
data={ data=LibraryAgentCreateInput(
"userId": user.id, userId=user.id,
"agentGraphId": graph.id, agentGraphId=graph.id,
"agentGraphVersion": graph.version, agentGraphVersion=graph.version,
"creatorId": creator_profile.id if creator_profile else None, creatorId=creator_profile.id if creator_profile else None,
"imageUrl": get_image() if random.random() < 0.5 else None, imageUrl=get_image() if random.random() < 0.5 else None,
"useGraphIsActiveVersion": random.choice([True, False]), useGraphIsActiveVersion=random.choice([True, False]),
"isFavorite": random.choice([True, False]), isFavorite=random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]), isCreatedByUser=random.choice([True, False]),
"isArchived": random.choice([True, False]), isArchived=random.choice([True, False]),
"isDeleted": random.choice([True, False]), isDeleted=random.choice([True, False]),
} )
) )
library_agents.append(library_agent) library_agents.append(library_agent)
@@ -392,13 +398,13 @@ async def main():
user = random.choice(users) user = random.choice(users)
slug = faker.slug() slug = faker.slug()
listing = await db.storelisting.create( listing = await db.storelisting.create(
data={ data=StoreListingCreateInput(
"agentGraphId": graph.id, agentGraphId=graph.id,
"agentGraphVersion": graph.version, agentGraphVersion=graph.version,
"owningUserId": user.id, owningUserId=user.id,
"hasApprovedVersion": random.choice([True, False]), hasApprovedVersion=random.choice([True, False]),
"slug": slug, slug=slug,
} )
) )
store_listings.append(listing) store_listings.append(listing)
@@ -408,26 +414,26 @@ async def main():
for listing in store_listings: for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0] graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
version = await db.storelistingversion.create( version = await db.storelistingversion.create(
data={ data=StoreListingVersionCreateInput(
"agentGraphId": graph.id, agentGraphId=graph.id,
"agentGraphVersion": graph.version, agentGraphVersion=graph.version,
"name": graph.name or faker.sentence(nb_words=3), name=graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(), subHeading=faker.sentence(),
"videoUrl": get_video_url() if random.random() < 0.3 else None, videoUrl=get_video_url() if random.random() < 0.3 else None,
"imageUrls": [get_image() for _ in range(3)], imageUrls=[get_image() for _ in range(3)],
"description": faker.text(), description=faker.text(),
"categories": [faker.word() for _ in range(3)], categories=[faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]), isFeatured=random.choice([True, False]),
"isAvailable": True, isAvailable=True,
"storeListingId": listing.id, storeListingId=listing.id,
"submissionStatus": random.choice( submissionStatus=random.choice(
[ [
prisma.enums.SubmissionStatus.PENDING, prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED, prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED, prisma.enums.SubmissionStatus.REJECTED,
] ]
), ),
} )
) )
store_listing_versions.append(version) store_listing_versions.append(version)
@@ -469,51 +475,49 @@ async def main():
try: try:
await db.useronboarding.create( await db.useronboarding.create(
data={ data=UserOnboardingCreateInput(
"userId": user.id, userId=user.id,
"completedSteps": completed_steps, completedSteps=completed_steps,
"walletShown": random.choice([True, False]), walletShown=random.choice([True, False]),
"notified": ( notified=(
random.sample(completed_steps, k=min(3, len(completed_steps))) random.sample(completed_steps, k=min(3, len(completed_steps)))
if completed_steps if completed_steps
else [] else []
), ),
"rewardedFor": ( rewardedFor=(
random.sample(completed_steps, k=min(2, len(completed_steps))) random.sample(completed_steps, k=min(2, len(completed_steps)))
if completed_steps if completed_steps
else [] else []
), ),
"usageReason": ( usageReason=(
random.choice(["personal", "business", "research", "learning"]) random.choice(["personal", "business", "research", "learning"])
if random.random() < 0.7 if random.random() < 0.7
else None else None
), ),
"integrations": random.sample( integrations=random.sample(
["github", "google", "discord", "slack"], k=random.randint(0, 2) ["github", "google", "discord", "slack"], k=random.randint(0, 2)
), ),
"otherIntegrations": ( otherIntegrations=(faker.word() if random.random() < 0.2 else None),
faker.word() if random.random() < 0.2 else None selectedStoreListingVersionId=(
),
"selectedStoreListingVersionId": (
random.choice(store_listing_versions).id random.choice(store_listing_versions).id
if store_listing_versions and random.random() < 0.5 if store_listing_versions and random.random() < 0.5
else None else None
), ),
"onboardingAgentExecutionId": ( onboardingAgentExecutionId=(
random.choice(agent_graph_executions).id random.choice(agent_graph_executions).id
if agent_graph_executions and random.random() < 0.3 if agent_graph_executions and random.random() < 0.3
else None else None
), ),
"agentRuns": random.randint(0, 10), agentRuns=random.randint(0, 10),
} )
) )
except Exception as e: except Exception as e:
print(f"Error creating onboarding for user {user.id}: {e}") print(f"Error creating onboarding for user {user.id}: {e}")
# Try simpler version # Try simpler version
await db.useronboarding.create( await db.useronboarding.create(
data={ data=UserOnboardingCreateInput(
"userId": user.id, userId=user.id,
} )
) )
# Insert IntegrationWebhooks for some users # Insert IntegrationWebhooks for some users
@@ -544,20 +548,20 @@ async def main():
for user in users: for user in users:
api_key = APIKeySmith().generate_key() api_key = APIKeySmith().generate_key()
await db.apikey.create( await db.apikey.create(
data={ data=APIKeyCreateInput(
"name": faker.word(), name=faker.word(),
"head": api_key.head, head=api_key.head,
"tail": api_key.tail, tail=api_key.tail,
"hash": api_key.hash, hash=api_key.hash,
"salt": api_key.salt, salt=api_key.salt,
"status": prisma.enums.APIKeyStatus.ACTIVE, status=prisma.enums.APIKeyStatus.ACTIVE,
"permissions": [ permissions=[
prisma.enums.APIKeyPermission.EXECUTE_GRAPH, prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
prisma.enums.APIKeyPermission.READ_GRAPH, prisma.enums.APIKeyPermission.READ_GRAPH,
], ],
"description": faker.text(), description=faker.text(),
"userId": user.id, userId=user.id,
} )
) )
# Refresh materialized views # Refresh materialized views

View File

@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
import prisma.enums import prisma.enums
from faker import Faker from faker import Faker
from prisma import Json, Prisma from prisma import Json, Prisma
from prisma.types import CreditTransactionCreateInput, StoreListingReviewCreateInput
faker = Faker() faker = Faker()
@@ -166,16 +167,16 @@ async def main():
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0] score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
await db.storelistingreview.create( await db.storelistingreview.create(
data={ data=StoreListingReviewCreateInput(
"storeListingVersionId": version.id, storeListingVersionId=version.id,
"reviewByUserId": reviewer.id, reviewByUserId=reviewer.id,
"score": score, score=score,
"comments": ( comments=(
faker.text(max_nb_chars=200) faker.text(max_nb_chars=200)
if random.random() < 0.7 if random.random() < 0.7
else None else None
), ),
} )
) )
new_reviews_count += 1 new_reviews_count += 1
@@ -244,17 +245,17 @@ async def main():
) )
await db.credittransaction.create( await db.credittransaction.create(
data={ data=CreditTransactionCreateInput(
"userId": user.id, userId=user.id,
"amount": amount, amount=amount,
"type": transaction_type, type=transaction_type,
"metadata": Json( metadata=Json(
{ {
"source": "test_updater", "source": "test_updater",
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
} }
), ),
} )
) )
transaction_count += 1 transaction_count += 1