mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-20 20:48:11 -05:00
Compare commits
5 Commits
testing-cl
...
swiftyos/f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
402bec4595 | ||
|
|
1c8cba9c5f | ||
|
|
072c647baa | ||
|
|
d5f490b85d | ||
|
|
6686de1701 |
@@ -3,6 +3,7 @@ from datetime import UTC, datetime
|
|||||||
from os import getenv
|
from os import getenv
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from prisma.types import ProfileCreateInput
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
@@ -49,13 +50,13 @@ async def setup_test_data():
|
|||||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||||
username = user.email.split("@")[0]
|
username = user.email.split("@")[0]
|
||||||
await prisma.profile.create(
|
await prisma.profile.create(
|
||||||
data={
|
data=ProfileCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"username": username,
|
username=username,
|
||||||
"name": f"Test User {username}",
|
name=f"Test User {username}",
|
||||||
"description": "Test user profile",
|
description="Test user profile",
|
||||||
"links": [], # Required field - empty array for test profiles
|
links=[], # Required field - empty array for test profiles
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Create a test graph with agent input -> agent output
|
# 2. Create a test graph with agent input -> agent output
|
||||||
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
|
|||||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||||
username = user.email.split("@")[0]
|
username = user.email.split("@")[0]
|
||||||
await prisma.profile.create(
|
await prisma.profile.create(
|
||||||
data={
|
data=ProfileCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"username": username,
|
username=username,
|
||||||
"name": f"Test User {username}",
|
name=f"Test User {username}",
|
||||||
"description": "Test user profile for LLM tests",
|
description="Test user profile for LLM tests",
|
||||||
"links": [], # Required field - empty array for test profiles
|
links=[], # Required field - empty array for test profiles
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Create test OpenAI credentials for the user
|
# 2. Create test OpenAI credentials for the user
|
||||||
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
|
|||||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||||
username = user.email.split("@")[0]
|
username = user.email.split("@")[0]
|
||||||
await prisma.profile.create(
|
await prisma.profile.create(
|
||||||
data={
|
data=ProfileCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"username": username,
|
username=username,
|
||||||
"name": f"Test User {username}",
|
name=f"Test User {username}",
|
||||||
"description": "Test user profile for Firecrawl tests",
|
description="Test user profile for Firecrawl tests",
|
||||||
"links": [], # Required field - empty array for test profiles
|
links=[], # Required field - empty array for test profiles
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||||
|
|||||||
@@ -817,18 +817,16 @@ async def add_store_agent_to_library(
|
|||||||
|
|
||||||
# Create LibraryAgent entry
|
# Create LibraryAgent entry
|
||||||
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
added_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||||
data={
|
data=prisma.types.LibraryAgentCreateInput(
|
||||||
"User": {"connect": {"id": user_id}},
|
User={"connect": {"id": user_id}},
|
||||||
"AgentGraph": {
|
AgentGraph={
|
||||||
"connect": {
|
"connect": {
|
||||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"isCreatedByUser": False,
|
isCreatedByUser=False,
|
||||||
"settings": SafeJson(
|
settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()),
|
||||||
_initialize_graph_settings(graph_model).model_dump()
|
),
|
||||||
),
|
|
||||||
},
|
|
||||||
include=library_agent_include(
|
include=library_agent_include(
|
||||||
user_id, include_nodes=False, include_executions=False
|
user_id, include_nodes=False, include_executions=False
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -27,6 +27,13 @@ from prisma.models import OAuthApplication as PrismaOAuthApplication
|
|||||||
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||||
from prisma.models import User as PrismaUser
|
from prisma.models import User as PrismaUser
|
||||||
|
from prisma.types import (
|
||||||
|
OAuthAccessTokenCreateInput,
|
||||||
|
OAuthApplicationCreateInput,
|
||||||
|
OAuthAuthorizationCodeCreateInput,
|
||||||
|
OAuthRefreshTokenCreateInput,
|
||||||
|
UserCreateInput,
|
||||||
|
)
|
||||||
|
|
||||||
from backend.api.rest_api import app
|
from backend.api.rest_api import app
|
||||||
|
|
||||||
@@ -48,11 +55,11 @@ def test_user_id() -> str:
|
|||||||
async def test_user(server, test_user_id: str):
|
async def test_user(server, test_user_id: str):
|
||||||
"""Create a test user in the database."""
|
"""Create a test user in the database."""
|
||||||
await PrismaUser.prisma().create(
|
await PrismaUser.prisma().create(
|
||||||
data={
|
data=UserCreateInput(
|
||||||
"id": test_user_id,
|
id=test_user_id,
|
||||||
"email": f"oauth-test-{test_user_id}@example.com",
|
email=f"oauth-test-{test_user_id}@example.com",
|
||||||
"name": "OAuth Test User",
|
name="OAuth Test User",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield test_user_id
|
yield test_user_id
|
||||||
@@ -77,22 +84,22 @@ async def test_oauth_app(test_user: str):
|
|||||||
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
||||||
|
|
||||||
await PrismaOAuthApplication.prisma().create(
|
await PrismaOAuthApplication.prisma().create(
|
||||||
data={
|
data=OAuthApplicationCreateInput(
|
||||||
"id": app_id,
|
id=app_id,
|
||||||
"name": "Test OAuth App",
|
name="Test OAuth App",
|
||||||
"description": "Test application for integration tests",
|
description="Test application for integration tests",
|
||||||
"clientId": client_id,
|
clientId=client_id,
|
||||||
"clientSecret": client_secret_hash,
|
clientSecret=client_secret_hash,
|
||||||
"clientSecretSalt": client_secret_salt,
|
clientSecretSalt=client_secret_salt,
|
||||||
"redirectUris": [
|
redirectUris=[
|
||||||
"https://example.com/callback",
|
"https://example.com/callback",
|
||||||
"http://localhost:3000/callback",
|
"http://localhost:3000/callback",
|
||||||
],
|
],
|
||||||
"grantTypes": ["authorization_code", "refresh_token"],
|
grantTypes=["authorization_code", "refresh_token"],
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
|
||||||
"ownerId": test_user,
|
ownerId=test_user,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
@@ -296,19 +303,19 @@ async def inactive_oauth_app(test_user: str):
|
|||||||
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
||||||
|
|
||||||
await PrismaOAuthApplication.prisma().create(
|
await PrismaOAuthApplication.prisma().create(
|
||||||
data={
|
data=OAuthApplicationCreateInput(
|
||||||
"id": app_id,
|
id=app_id,
|
||||||
"name": "Inactive OAuth App",
|
name="Inactive OAuth App",
|
||||||
"description": "Inactive test application",
|
description="Inactive test application",
|
||||||
"clientId": client_id,
|
clientId=client_id,
|
||||||
"clientSecret": client_secret_hash,
|
clientSecret=client_secret_hash,
|
||||||
"clientSecretSalt": client_secret_salt,
|
clientSecretSalt=client_secret_salt,
|
||||||
"redirectUris": ["https://example.com/callback"],
|
redirectUris=["https://example.com/callback"],
|
||||||
"grantTypes": ["authorization_code", "refresh_token"],
|
grantTypes=["authorization_code", "refresh_token"],
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||||
"ownerId": test_user,
|
ownerId=test_user,
|
||||||
"isActive": False, # Inactive!
|
isActive=False, # Inactive!
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
@@ -699,14 +706,14 @@ async def test_token_authorization_code_expired(
|
|||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
await PrismaOAuthAuthorizationCode.prisma().create(
|
await PrismaOAuthAuthorizationCode.prisma().create(
|
||||||
data={
|
data=OAuthAuthorizationCodeCreateInput(
|
||||||
"code": expired_code,
|
code=expired_code,
|
||||||
"applicationId": test_oauth_app["id"],
|
applicationId=test_oauth_app["id"],
|
||||||
"userId": test_user,
|
userId=test_user,
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||||
"redirectUri": test_oauth_app["redirect_uri"],
|
redirectUri=test_oauth_app["redirect_uri"],
|
||||||
"expiresAt": now - timedelta(hours=1), # Already expired
|
expiresAt=now - timedelta(hours=1), # Already expired
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -942,13 +949,13 @@ async def test_token_refresh_expired(
|
|||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
await PrismaOAuthRefreshToken.prisma().create(
|
await PrismaOAuthRefreshToken.prisma().create(
|
||||||
data={
|
data=OAuthRefreshTokenCreateInput(
|
||||||
"token": expired_token_hash,
|
token=expired_token_hash,
|
||||||
"applicationId": test_oauth_app["id"],
|
applicationId=test_oauth_app["id"],
|
||||||
"userId": test_user,
|
userId=test_user,
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||||
"expiresAt": now - timedelta(days=1), # Already expired
|
expiresAt=now - timedelta(days=1), # Already expired
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -980,14 +987,14 @@ async def test_token_refresh_revoked(
|
|||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
await PrismaOAuthRefreshToken.prisma().create(
|
await PrismaOAuthRefreshToken.prisma().create(
|
||||||
data={
|
data=OAuthRefreshTokenCreateInput(
|
||||||
"token": revoked_token_hash,
|
token=revoked_token_hash,
|
||||||
"applicationId": test_oauth_app["id"],
|
applicationId=test_oauth_app["id"],
|
||||||
"userId": test_user,
|
userId=test_user,
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||||
"expiresAt": now + timedelta(days=30), # Not expired
|
expiresAt=now + timedelta(days=30), # Not expired
|
||||||
"revokedAt": now - timedelta(hours=1), # But revoked
|
revokedAt=now - timedelta(hours=1), # But revoked
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -1013,19 +1020,19 @@ async def other_oauth_app(test_user: str):
|
|||||||
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
||||||
|
|
||||||
await PrismaOAuthApplication.prisma().create(
|
await PrismaOAuthApplication.prisma().create(
|
||||||
data={
|
data=OAuthApplicationCreateInput(
|
||||||
"id": app_id,
|
id=app_id,
|
||||||
"name": "Other OAuth App",
|
name="Other OAuth App",
|
||||||
"description": "Second test application",
|
description="Second test application",
|
||||||
"clientId": client_id,
|
clientId=client_id,
|
||||||
"clientSecret": client_secret_hash,
|
clientSecret=client_secret_hash,
|
||||||
"clientSecretSalt": client_secret_salt,
|
clientSecretSalt=client_secret_salt,
|
||||||
"redirectUris": ["https://other.example.com/callback"],
|
redirectUris=["https://other.example.com/callback"],
|
||||||
"grantTypes": ["authorization_code", "refresh_token"],
|
grantTypes=["authorization_code", "refresh_token"],
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||||
"ownerId": test_user,
|
ownerId=test_user,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
@@ -1052,13 +1059,13 @@ async def test_token_refresh_wrong_application(
|
|||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
await PrismaOAuthRefreshToken.prisma().create(
|
await PrismaOAuthRefreshToken.prisma().create(
|
||||||
data={
|
data=OAuthRefreshTokenCreateInput(
|
||||||
"token": token_hash,
|
token=token_hash,
|
||||||
"applicationId": test_oauth_app["id"], # Belongs to test_oauth_app
|
applicationId=test_oauth_app["id"], # Belongs to test_oauth_app
|
||||||
"userId": test_user,
|
userId=test_user,
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||||
"expiresAt": now + timedelta(days=30),
|
expiresAt=now + timedelta(days=30),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to use it with `other_oauth_app`
|
# Try to use it with `other_oauth_app`
|
||||||
@@ -1267,19 +1274,19 @@ async def test_validate_access_token_fails_when_app_disabled(
|
|||||||
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
client_secret_hash, client_secret_salt = keysmith.hash_key(client_secret_plaintext)
|
||||||
|
|
||||||
await PrismaOAuthApplication.prisma().create(
|
await PrismaOAuthApplication.prisma().create(
|
||||||
data={
|
data=OAuthApplicationCreateInput(
|
||||||
"id": app_id,
|
id=app_id,
|
||||||
"name": "App To Be Disabled",
|
name="App To Be Disabled",
|
||||||
"description": "Test app for disabled validation",
|
description="Test app for disabled validation",
|
||||||
"clientId": client_id,
|
clientId=client_id,
|
||||||
"clientSecret": client_secret_hash,
|
clientSecret=client_secret_hash,
|
||||||
"clientSecretSalt": client_secret_salt,
|
clientSecretSalt=client_secret_salt,
|
||||||
"redirectUris": ["https://example.com/callback"],
|
redirectUris=["https://example.com/callback"],
|
||||||
"grantTypes": ["authorization_code"],
|
grantTypes=["authorization_code"],
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||||
"ownerId": test_user,
|
ownerId=test_user,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create an access token directly in the database
|
# Create an access token directly in the database
|
||||||
@@ -1288,13 +1295,13 @@ async def test_validate_access_token_fails_when_app_disabled(
|
|||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
await PrismaOAuthAccessToken.prisma().create(
|
await PrismaOAuthAccessToken.prisma().create(
|
||||||
data={
|
data=OAuthAccessTokenCreateInput(
|
||||||
"token": token_hash,
|
token=token_hash,
|
||||||
"applicationId": app_id,
|
applicationId=app_id,
|
||||||
"userId": test_user,
|
userId=test_user,
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH],
|
||||||
"expiresAt": now + timedelta(hours=1),
|
expiresAt=now + timedelta(hours=1),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Token should be valid while app is active
|
# Token should be valid while app is active
|
||||||
@@ -1561,19 +1568,19 @@ async def test_revoke_token_from_different_app_fails_silently(
|
|||||||
)
|
)
|
||||||
|
|
||||||
await PrismaOAuthApplication.prisma().create(
|
await PrismaOAuthApplication.prisma().create(
|
||||||
data={
|
data=OAuthApplicationCreateInput(
|
||||||
"id": app2_id,
|
id=app2_id,
|
||||||
"name": "Second Test OAuth App",
|
name="Second Test OAuth App",
|
||||||
"description": "Second test application for cross-app revocation test",
|
description="Second test application for cross-app revocation test",
|
||||||
"clientId": app2_client_id,
|
clientId=app2_client_id,
|
||||||
"clientSecret": app2_client_secret_hash,
|
clientSecret=app2_client_secret_hash,
|
||||||
"clientSecretSalt": app2_client_secret_salt,
|
clientSecretSalt=app2_client_secret_salt,
|
||||||
"redirectUris": ["https://other-app.com/callback"],
|
redirectUris=["https://other-app.com/callback"],
|
||||||
"grantTypes": ["authorization_code", "refresh_token"],
|
grantTypes=["authorization_code", "refresh_token"],
|
||||||
"scopes": [APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
|
scopes=[APIKeyPermission.EXECUTE_GRAPH, APIKeyPermission.READ_GRAPH],
|
||||||
"ownerId": test_user,
|
ownerId=test_user,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# App 2 tries to revoke App 1's access token
|
# App 2 tries to revoke App 1's access token
|
||||||
|
|||||||
@@ -249,7 +249,9 @@ async def log_search_term(search_query: str):
|
|||||||
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
try:
|
try:
|
||||||
await prisma.models.SearchTerms.prisma().create(
|
await prisma.models.SearchTerms.prisma().create(
|
||||||
data={"searchTerm": search_query, "createdDate": date}
|
data=prisma.types.SearchTermsCreateInput(
|
||||||
|
searchTerm=search_query, createdDate=date
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Fail silently here so that logging search terms doesn't break the app
|
# Fail silently here so that logging search terms doesn't break the app
|
||||||
@@ -1456,11 +1458,9 @@ async def _approve_sub_agent(
|
|||||||
# Create new version if no matching version found
|
# Create new version if no matching version found
|
||||||
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
|
next_version = max((v.version for v in listing.Versions or []), default=0) + 1
|
||||||
await prisma.models.StoreListingVersion.prisma(tx).create(
|
await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||||
data={
|
data=_create_sub_agent_version_data(
|
||||||
**_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
|
sub_graph, heading, main_agent_name, next_version, listing.id
|
||||||
"version": next_version,
|
)
|
||||||
"storeListingId": listing.id,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
await prisma.models.StoreListing.prisma(tx).update(
|
await prisma.models.StoreListing.prisma(tx).update(
|
||||||
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
where={"id": listing.id}, data={"hasApprovedVersion": True}
|
||||||
@@ -1468,10 +1468,14 @@ async def _approve_sub_agent(
|
|||||||
|
|
||||||
|
|
||||||
def _create_sub_agent_version_data(
|
def _create_sub_agent_version_data(
|
||||||
sub_graph: prisma.models.AgentGraph, heading: str, main_agent_name: str
|
sub_graph: prisma.models.AgentGraph,
|
||||||
|
heading: str,
|
||||||
|
main_agent_name: str,
|
||||||
|
version: typing.Optional[int] = None,
|
||||||
|
store_listing_id: typing.Optional[str] = None,
|
||||||
) -> prisma.types.StoreListingVersionCreateInput:
|
) -> prisma.types.StoreListingVersionCreateInput:
|
||||||
"""Create store listing version data for a sub-agent"""
|
"""Create store listing version data for a sub-agent"""
|
||||||
return prisma.types.StoreListingVersionCreateInput(
|
data = prisma.types.StoreListingVersionCreateInput(
|
||||||
agentGraphId=sub_graph.id,
|
agentGraphId=sub_graph.id,
|
||||||
agentGraphVersion=sub_graph.version,
|
agentGraphVersion=sub_graph.version,
|
||||||
name=sub_graph.name or heading,
|
name=sub_graph.name or heading,
|
||||||
@@ -1486,6 +1490,11 @@ def _create_sub_agent_version_data(
|
|||||||
imageUrls=[], # Sub-agents don't need images
|
imageUrls=[], # Sub-agents don't need images
|
||||||
categories=[], # Sub-agents don't need categories
|
categories=[], # Sub-agents don't need categories
|
||||||
)
|
)
|
||||||
|
if version is not None:
|
||||||
|
data["version"] = version
|
||||||
|
if store_listing_id is not None:
|
||||||
|
data["storeListingId"] = store_listing_id
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
async def review_store_submission(
|
async def review_store_submission(
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from urllib.parse import urlparse
|
|||||||
import click
|
import click
|
||||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||||
from prisma.enums import APIKeyPermission
|
from prisma.enums import APIKeyPermission
|
||||||
|
from prisma.types import OAuthApplicationCreateInput
|
||||||
|
|
||||||
keysmith = APIKeySmith()
|
keysmith = APIKeySmith()
|
||||||
|
|
||||||
@@ -147,7 +148,7 @@ def format_sql_insert(creds: dict) -> str:
|
|||||||
|
|
||||||
sql = f"""
|
sql = f"""
|
||||||
-- ============================================================
|
-- ============================================================
|
||||||
-- OAuth Application: {creds['name']}
|
-- OAuth Application: {creds["name"]}
|
||||||
-- Generated: {now_iso} UTC
|
-- Generated: {now_iso} UTC
|
||||||
-- ============================================================
|
-- ============================================================
|
||||||
|
|
||||||
@@ -167,14 +168,14 @@ INSERT INTO "OAuthApplication" (
|
|||||||
"isActive"
|
"isActive"
|
||||||
)
|
)
|
||||||
VALUES (
|
VALUES (
|
||||||
'{creds['id']}',
|
'{creds["id"]}',
|
||||||
NOW(),
|
NOW(),
|
||||||
NOW(),
|
NOW(),
|
||||||
'{creds['name']}',
|
'{creds["name"]}',
|
||||||
{f"'{creds['description']}'" if creds['description'] else 'NULL'},
|
{f"'{creds['description']}'" if creds["description"] else "NULL"},
|
||||||
'{creds['client_id']}',
|
'{creds["client_id"]}',
|
||||||
'{creds['client_secret_hash']}',
|
'{creds["client_secret_hash"]}',
|
||||||
'{creds['client_secret_salt']}',
|
'{creds["client_secret_salt"]}',
|
||||||
ARRAY{redirect_uris_pg}::TEXT[],
|
ARRAY{redirect_uris_pg}::TEXT[],
|
||||||
ARRAY{grant_types_pg}::TEXT[],
|
ARRAY{grant_types_pg}::TEXT[],
|
||||||
ARRAY{scopes_pg}::"APIKeyPermission"[],
|
ARRAY{scopes_pg}::"APIKeyPermission"[],
|
||||||
@@ -186,8 +187,8 @@ VALUES (
|
|||||||
-- ⚠️ IMPORTANT: Save these credentials securely!
|
-- ⚠️ IMPORTANT: Save these credentials securely!
|
||||||
-- ============================================================
|
-- ============================================================
|
||||||
--
|
--
|
||||||
-- Client ID: {creds['client_id']}
|
-- Client ID: {creds["client_id"]}
|
||||||
-- Client Secret: {creds['client_secret_plaintext']}
|
-- Client Secret: {creds["client_secret_plaintext"]}
|
||||||
--
|
--
|
||||||
-- ⚠️ The client secret is shown ONLY ONCE!
|
-- ⚠️ The client secret is shown ONLY ONCE!
|
||||||
-- ⚠️ Store it securely and share only with the application developer.
|
-- ⚠️ Store it securely and share only with the application developer.
|
||||||
@@ -200,7 +201,7 @@ VALUES (
|
|||||||
-- To verify the application was created:
|
-- To verify the application was created:
|
||||||
-- SELECT "clientId", name, scopes, "redirectUris", "isActive"
|
-- SELECT "clientId", name, scopes, "redirectUris", "isActive"
|
||||||
-- FROM "OAuthApplication"
|
-- FROM "OAuthApplication"
|
||||||
-- WHERE "clientId" = '{creds['client_id']}';
|
-- WHERE "clientId" = '{creds["client_id"]}';
|
||||||
"""
|
"""
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
@@ -834,19 +835,19 @@ async def create_test_app_in_db(
|
|||||||
|
|
||||||
# Insert into database
|
# Insert into database
|
||||||
app = await OAuthApplication.prisma().create(
|
app = await OAuthApplication.prisma().create(
|
||||||
data={
|
data=OAuthApplicationCreateInput(
|
||||||
"id": creds["id"],
|
id=creds["id"],
|
||||||
"name": creds["name"],
|
name=creds["name"],
|
||||||
"description": creds["description"],
|
description=creds["description"],
|
||||||
"clientId": creds["client_id"],
|
clientId=creds["client_id"],
|
||||||
"clientSecret": creds["client_secret_hash"],
|
clientSecret=creds["client_secret_hash"],
|
||||||
"clientSecretSalt": creds["client_secret_salt"],
|
clientSecretSalt=creds["client_secret_salt"],
|
||||||
"redirectUris": creds["redirect_uris"],
|
redirectUris=creds["redirect_uris"],
|
||||||
"grantTypes": creds["grant_types"],
|
grantTypes=creds["grant_types"],
|
||||||
"scopes": creds["scopes"],
|
scopes=creds["scopes"],
|
||||||
"ownerId": owner_id,
|
ownerId=owner_id,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
click.echo(f"✓ Created test OAuth application: {app.clientId}")
|
click.echo(f"✓ Created test OAuth application: {app.clientId}")
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Literal, Optional
|
|||||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||||
from prisma.models import APIKey as PrismaAPIKey
|
from prisma.models import APIKey as PrismaAPIKey
|
||||||
from prisma.types import APIKeyWhereUniqueInput
|
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||||
@@ -82,17 +82,17 @@ async def create_api_key(
|
|||||||
generated_key = keysmith.generate_key()
|
generated_key = keysmith.generate_key()
|
||||||
|
|
||||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||||
data={
|
data=APIKeyCreateInput(
|
||||||
"id": str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
"name": name,
|
name=name,
|
||||||
"head": generated_key.head,
|
head=generated_key.head,
|
||||||
"tail": generated_key.tail,
|
tail=generated_key.tail,
|
||||||
"hash": generated_key.hash,
|
hash=generated_key.hash,
|
||||||
"salt": generated_key.salt,
|
salt=generated_key.salt,
|
||||||
"permissions": [p for p in permissions],
|
permissions=[p for p in permissions],
|
||||||
"description": description,
|
description=description,
|
||||||
"userId": user_id,
|
userId=user_id,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||||
|
|||||||
@@ -22,7 +22,12 @@ from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
|||||||
from prisma.models import OAuthApplication as PrismaOAuthApplication
|
from prisma.models import OAuthApplication as PrismaOAuthApplication
|
||||||
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
from prisma.models import OAuthAuthorizationCode as PrismaOAuthAuthorizationCode
|
||||||
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
from prisma.models import OAuthRefreshToken as PrismaOAuthRefreshToken
|
||||||
from prisma.types import OAuthApplicationUpdateInput
|
from prisma.types import (
|
||||||
|
OAuthAccessTokenCreateInput,
|
||||||
|
OAuthApplicationUpdateInput,
|
||||||
|
OAuthAuthorizationCodeCreateInput,
|
||||||
|
OAuthRefreshTokenCreateInput,
|
||||||
|
)
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from .base import APIAuthorizationInfo
|
from .base import APIAuthorizationInfo
|
||||||
@@ -359,17 +364,17 @@ async def create_authorization_code(
|
|||||||
expires_at = now + AUTHORIZATION_CODE_TTL
|
expires_at = now + AUTHORIZATION_CODE_TTL
|
||||||
|
|
||||||
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
|
saved_code = await PrismaOAuthAuthorizationCode.prisma().create(
|
||||||
data={
|
data=OAuthAuthorizationCodeCreateInput(
|
||||||
"id": str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
"code": code,
|
code=code,
|
||||||
"expiresAt": expires_at,
|
expiresAt=expires_at,
|
||||||
"applicationId": application_id,
|
applicationId=application_id,
|
||||||
"userId": user_id,
|
userId=user_id,
|
||||||
"scopes": [s for s in scopes],
|
scopes=[s for s in scopes],
|
||||||
"redirectUri": redirect_uri,
|
redirectUri=redirect_uri,
|
||||||
"codeChallenge": code_challenge,
|
codeChallenge=code_challenge,
|
||||||
"codeChallengeMethod": code_challenge_method,
|
codeChallengeMethod=code_challenge_method,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return OAuthAuthorizationCodeInfo.from_db(saved_code)
|
return OAuthAuthorizationCodeInfo.from_db(saved_code)
|
||||||
@@ -490,14 +495,14 @@ async def create_access_token(
|
|||||||
expires_at = now + ACCESS_TOKEN_TTL
|
expires_at = now + ACCESS_TOKEN_TTL
|
||||||
|
|
||||||
saved_token = await PrismaOAuthAccessToken.prisma().create(
|
saved_token = await PrismaOAuthAccessToken.prisma().create(
|
||||||
data={
|
data=OAuthAccessTokenCreateInput(
|
||||||
"id": str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
"token": token_hash, # SHA256 hash for direct lookup
|
token=token_hash, # SHA256 hash for direct lookup
|
||||||
"expiresAt": expires_at,
|
expiresAt=expires_at,
|
||||||
"applicationId": application_id,
|
applicationId=application_id,
|
||||||
"userId": user_id,
|
userId=user_id,
|
||||||
"scopes": [s for s in scopes],
|
scopes=[s for s in scopes],
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
|
return OAuthAccessToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||||
@@ -607,14 +612,14 @@ async def create_refresh_token(
|
|||||||
expires_at = now + REFRESH_TOKEN_TTL
|
expires_at = now + REFRESH_TOKEN_TTL
|
||||||
|
|
||||||
saved_token = await PrismaOAuthRefreshToken.prisma().create(
|
saved_token = await PrismaOAuthRefreshToken.prisma().create(
|
||||||
data={
|
data=OAuthRefreshTokenCreateInput(
|
||||||
"id": str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
"token": token_hash, # SHA256 hash for direct lookup
|
token=token_hash, # SHA256 hash for direct lookup
|
||||||
"expiresAt": expires_at,
|
expiresAt=expires_at,
|
||||||
"applicationId": application_id,
|
applicationId=application_id,
|
||||||
"userId": user_id,
|
userId=user_id,
|
||||||
"scopes": [s for s in scopes],
|
scopes=[s for s in scopes],
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
|
return OAuthRefreshToken.from_db(saved_token, plaintext_token=plaintext_token)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import pytest
|
|||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
from prisma.errors import UniqueViolationError
|
from prisma.errors import UniqueViolationError
|
||||||
from prisma.models import CreditTransaction, User, UserBalance
|
from prisma.models import CreditTransaction, User, UserBalance
|
||||||
|
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
|
||||||
|
|
||||||
from backend.data.credit import UserCredit
|
from backend.data.credit import UserCredit
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
|
|||||||
"""Create a test user for ceiling tests."""
|
"""Create a test user for ceiling tests."""
|
||||||
try:
|
try:
|
||||||
await User.prisma().create(
|
await User.prisma().create(
|
||||||
data={
|
data=UserCreateInput(
|
||||||
"id": user_id,
|
id=user_id,
|
||||||
"email": f"test-{user_id}@example.com",
|
email=f"test-{user_id}@example.com",
|
||||||
"name": f"Test User {user_id[:8]}",
|
name=f"Test User {user_id[:8]}",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
# User already exists, continue
|
# User already exists, continue
|
||||||
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
|
|||||||
|
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
data=UserBalanceUpsertInput(
|
||||||
|
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||||
|
update={"balance": 0},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import pytest
|
|||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
from prisma.errors import UniqueViolationError
|
from prisma.errors import UniqueViolationError
|
||||||
from prisma.models import CreditTransaction, User, UserBalance
|
from prisma.models import CreditTransaction, User, UserBalance
|
||||||
|
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
|
||||||
|
|
||||||
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
||||||
from backend.util.exceptions import InsufficientBalanceError
|
from backend.util.exceptions import InsufficientBalanceError
|
||||||
@@ -28,11 +29,11 @@ async def create_test_user(user_id: str) -> None:
|
|||||||
"""Create a test user with initial balance."""
|
"""Create a test user with initial balance."""
|
||||||
try:
|
try:
|
||||||
await User.prisma().create(
|
await User.prisma().create(
|
||||||
data={
|
data=UserCreateInput(
|
||||||
"id": user_id,
|
id=user_id,
|
||||||
"email": f"test-{user_id}@example.com",
|
email=f"test-{user_id}@example.com",
|
||||||
"name": f"Test User {user_id[:8]}",
|
name=f"Test User {user_id[:8]}",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
# User already exists, continue
|
# User already exists, continue
|
||||||
@@ -41,7 +42,10 @@ async def create_test_user(user_id: str) -> None:
|
|||||||
# Ensure UserBalance record exists
|
# Ensure UserBalance record exists
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
data=UserBalanceUpsertInput(
|
||||||
|
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||||
|
update={"balance": 0},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -342,10 +346,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
|
|||||||
# First, set balance near max
|
# First, set balance near max
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={
|
data=UserBalanceUpsertInput(
|
||||||
"create": {"userId": user_id, "balance": max_int - 100},
|
create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100),
|
||||||
"update": {"balance": max_int - 100},
|
update={"balance": max_int - 100},
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
|
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
|
||||||
@@ -507,7 +511,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
|
|||||||
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
|
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
|
||||||
print("\nExecution order by start time:")
|
print("\nExecution order by start time:")
|
||||||
for i, (label, timing) in enumerate(sorted_timings):
|
for i, (label, timing) in enumerate(sorted_timings):
|
||||||
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
|
print(f" {i + 1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
|
||||||
|
|
||||||
# Check for overlap (true concurrency) vs serialization
|
# Check for overlap (true concurrency) vs serialization
|
||||||
overlaps = []
|
overlaps = []
|
||||||
@@ -546,7 +550,7 @@ async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestSer
|
|||||||
print("\nDatabase transaction order (by createdAt):")
|
print("\nDatabase transaction order (by createdAt):")
|
||||||
for i, tx in enumerate(transactions):
|
for i, tx in enumerate(transactions):
|
||||||
print(
|
print(
|
||||||
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
|
f" {i + 1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify running balances are chronologically consistent (ordered by createdAt)
|
# Verify running balances are chronologically consistent (ordered by createdAt)
|
||||||
@@ -707,7 +711,7 @@ async def test_prove_database_locking_behavior(server: SpinTestServer):
|
|||||||
|
|
||||||
for i, result in enumerate(sorted_results):
|
for i, result in enumerate(sorted_results):
|
||||||
print(
|
print(
|
||||||
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
|
f" {i + 1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if any operations overlapped at the database level
|
# Check if any operations overlapped at the database level
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ which would have caught the CreditTransactionType enum casting bug.
|
|||||||
import pytest
|
import pytest
|
||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
from prisma.models import CreditTransaction, User, UserBalance
|
from prisma.models import CreditTransaction, User, UserBalance
|
||||||
|
from prisma.types import UserCreateInput
|
||||||
|
|
||||||
from backend.data.credit import (
|
from backend.data.credit import (
|
||||||
AutoTopUpConfig,
|
AutoTopUpConfig,
|
||||||
@@ -29,12 +30,12 @@ async def cleanup_test_user():
|
|||||||
# Create the user first
|
# Create the user first
|
||||||
try:
|
try:
|
||||||
await User.prisma().create(
|
await User.prisma().create(
|
||||||
data={
|
data=UserCreateInput(
|
||||||
"id": user_id,
|
id=user_id,
|
||||||
"email": f"test-{user_id}@example.com",
|
email=f"test-{user_id}@example.com",
|
||||||
"topUpConfig": SafeJson({}),
|
topUpConfig=SafeJson({}),
|
||||||
"timezone": "UTC",
|
timezone="UTC",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# User might already exist, that's fine
|
# User might already exist, that's fine
|
||||||
|
|||||||
@@ -12,6 +12,12 @@ import pytest
|
|||||||
import stripe
|
import stripe
|
||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||||
|
from prisma.types import (
|
||||||
|
CreditRefundRequestCreateInput,
|
||||||
|
CreditTransactionCreateInput,
|
||||||
|
UserBalanceCreateInput,
|
||||||
|
UserCreateInput,
|
||||||
|
)
|
||||||
|
|
||||||
from backend.data.credit import UserCredit
|
from backend.data.credit import UserCredit
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
@@ -35,32 +41,32 @@ async def setup_test_user_with_topup():
|
|||||||
|
|
||||||
# Create user
|
# Create user
|
||||||
await User.prisma().create(
|
await User.prisma().create(
|
||||||
data={
|
data=UserCreateInput(
|
||||||
"id": REFUND_TEST_USER_ID,
|
id=REFUND_TEST_USER_ID,
|
||||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
email=f"{REFUND_TEST_USER_ID}@example.com",
|
||||||
"name": "Refund Test User",
|
name="Refund Test User",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create user balance
|
# Create user balance
|
||||||
await UserBalance.prisma().create(
|
await UserBalance.prisma().create(
|
||||||
data={
|
data=UserBalanceCreateInput(
|
||||||
"userId": REFUND_TEST_USER_ID,
|
userId=REFUND_TEST_USER_ID,
|
||||||
"balance": 1000, # $10
|
balance=1000, # $10
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a top-up transaction that can be refunded
|
# Create a top-up transaction that can be refunded
|
||||||
topup_tx = await CreditTransaction.prisma().create(
|
topup_tx = await CreditTransaction.prisma().create(
|
||||||
data={
|
data=CreditTransactionCreateInput(
|
||||||
"userId": REFUND_TEST_USER_ID,
|
userId=REFUND_TEST_USER_ID,
|
||||||
"amount": 1000,
|
amount=1000,
|
||||||
"type": CreditTransactionType.TOP_UP,
|
type=CreditTransactionType.TOP_UP,
|
||||||
"transactionKey": "pi_test_12345",
|
transactionKey="pi_test_12345",
|
||||||
"runningBalance": 1000,
|
runningBalance=1000,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return topup_tx
|
return topup_tx
|
||||||
@@ -93,12 +99,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
|
|||||||
|
|
||||||
# Create refund request record (simulating webhook flow)
|
# Create refund request record (simulating webhook flow)
|
||||||
await CreditRefundRequest.prisma().create(
|
await CreditRefundRequest.prisma().create(
|
||||||
data={
|
data=CreditRefundRequestCreateInput(
|
||||||
"userId": REFUND_TEST_USER_ID,
|
userId=REFUND_TEST_USER_ID,
|
||||||
"amount": 500,
|
amount=500,
|
||||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
transactionKey=topup_tx.transactionKey, # Should match the original transaction
|
||||||
"reason": "Test refund",
|
reason="Test refund",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call deduct_credits
|
# Call deduct_credits
|
||||||
@@ -286,12 +292,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
|
|||||||
refund_requests = []
|
refund_requests = []
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
req = await CreditRefundRequest.prisma().create(
|
req = await CreditRefundRequest.prisma().create(
|
||||||
data={
|
data=CreditRefundRequestCreateInput(
|
||||||
"userId": REFUND_TEST_USER_ID,
|
userId=REFUND_TEST_USER_ID,
|
||||||
"amount": 100, # $1 each
|
amount=100, # $1 each
|
||||||
"transactionKey": topup_tx.transactionKey,
|
transactionKey=topup_tx.transactionKey,
|
||||||
"reason": f"Test refund {i}",
|
reason=f"Test refund {i}",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
refund_requests.append(req)
|
refund_requests.append(req)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,11 @@ from datetime import datetime, timedelta, timezone
|
|||||||
import pytest
|
import pytest
|
||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
from prisma.models import CreditTransaction, UserBalance
|
from prisma.models import CreditTransaction, UserBalance
|
||||||
|
from prisma.types import (
|
||||||
|
CreditTransactionCreateInput,
|
||||||
|
UserBalanceCreateInput,
|
||||||
|
UserBalanceUpsertInput,
|
||||||
|
)
|
||||||
|
|
||||||
from backend.blocks.llm import AITextGeneratorBlock
|
from backend.blocks.llm import AITextGeneratorBlock
|
||||||
from backend.data.block import get_block
|
from backend.data.block import get_block
|
||||||
@@ -23,10 +28,10 @@ async def disable_test_user_transactions():
|
|||||||
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
|
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": DEFAULT_USER_ID},
|
where={"userId": DEFAULT_USER_ID},
|
||||||
data={
|
data=UserBalanceUpsertInput(
|
||||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0),
|
||||||
"update": {"balance": 0, "updatedAt": old_date},
|
update={"balance": 0, "updatedAt": old_date},
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -140,23 +145,23 @@ async def test_block_credit_reset(server: SpinTestServer):
|
|||||||
|
|
||||||
# Manually create a transaction with month 1 timestamp to establish history
|
# Manually create a transaction with month 1 timestamp to establish history
|
||||||
await CreditTransaction.prisma().create(
|
await CreditTransaction.prisma().create(
|
||||||
data={
|
data=CreditTransactionCreateInput(
|
||||||
"userId": DEFAULT_USER_ID,
|
userId=DEFAULT_USER_ID,
|
||||||
"amount": 100,
|
amount=100,
|
||||||
"type": CreditTransactionType.TOP_UP,
|
type=CreditTransactionType.TOP_UP,
|
||||||
"runningBalance": 1100,
|
runningBalance=1100,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
"createdAt": month1, # Set specific timestamp
|
createdAt=month1, # Set specific timestamp
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update user balance to match
|
# Update user balance to match
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": DEFAULT_USER_ID},
|
where={"userId": DEFAULT_USER_ID},
|
||||||
data={
|
data=UserBalanceUpsertInput(
|
||||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100),
|
||||||
"update": {"balance": 1100},
|
update={"balance": 1100},
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now test month 2 behavior
|
# Now test month 2 behavior
|
||||||
@@ -175,14 +180,14 @@ async def test_block_credit_reset(server: SpinTestServer):
|
|||||||
|
|
||||||
# Create a month 2 transaction to update the last transaction time
|
# Create a month 2 transaction to update the last transaction time
|
||||||
await CreditTransaction.prisma().create(
|
await CreditTransaction.prisma().create(
|
||||||
data={
|
data=CreditTransactionCreateInput(
|
||||||
"userId": DEFAULT_USER_ID,
|
userId=DEFAULT_USER_ID,
|
||||||
"amount": -700, # Spent 700 to get to 400
|
amount=-700, # Spent 700 to get to 400
|
||||||
"type": CreditTransactionType.USAGE,
|
type=CreditTransactionType.USAGE,
|
||||||
"runningBalance": 400,
|
runningBalance=400,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
"createdAt": month2,
|
createdAt=month2,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move to month 3
|
# Move to month 3
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import pytest
|
|||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
from prisma.errors import UniqueViolationError
|
from prisma.errors import UniqueViolationError
|
||||||
from prisma.models import CreditTransaction, User, UserBalance
|
from prisma.models import CreditTransaction, User, UserBalance
|
||||||
|
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
|
||||||
|
|
||||||
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
||||||
from backend.util.test import SpinTestServer
|
from backend.util.test import SpinTestServer
|
||||||
@@ -21,11 +22,11 @@ async def create_test_user(user_id: str) -> None:
|
|||||||
"""Create a test user for underflow tests."""
|
"""Create a test user for underflow tests."""
|
||||||
try:
|
try:
|
||||||
await User.prisma().create(
|
await User.prisma().create(
|
||||||
data={
|
data=UserCreateInput(
|
||||||
"id": user_id,
|
id=user_id,
|
||||||
"email": f"test-{user_id}@example.com",
|
email=f"test-{user_id}@example.com",
|
||||||
"name": f"Test User {user_id[:8]}",
|
name=f"Test User {user_id[:8]}",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
# User already exists, continue
|
# User already exists, continue
|
||||||
@@ -33,7 +34,10 @@ async def create_test_user(user_id: str) -> None:
|
|||||||
|
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
data=UserBalanceUpsertInput(
|
||||||
|
create=UserBalanceCreateInput(userId=user_id, balance=0),
|
||||||
|
update={"balance": 0},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,14 +70,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
|||||||
initial_balance_target = POSTGRES_INT_MIN + 100
|
initial_balance_target = POSTGRES_INT_MIN + 100
|
||||||
|
|
||||||
# Use direct database update to set the balance close to underflow
|
# Use direct database update to set the balance close to underflow
|
||||||
from prisma.models import UserBalance
|
|
||||||
|
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={
|
data=UserBalanceUpsertInput(
|
||||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
create=UserBalanceCreateInput(
|
||||||
"update": {"balance": initial_balance_target},
|
userId=user_id, balance=initial_balance_target
|
||||||
},
|
),
|
||||||
|
update={"balance": initial_balance_target},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
current_balance = await credit_system.get_credits(user_id)
|
current_balance = await credit_system.get_credits(user_id)
|
||||||
@@ -110,10 +114,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
|||||||
# Set balance to exactly POSTGRES_INT_MIN
|
# Set balance to exactly POSTGRES_INT_MIN
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={
|
data=UserBalanceUpsertInput(
|
||||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN),
|
||||||
"update": {"balance": POSTGRES_INT_MIN},
|
update={"balance": POSTGRES_INT_MIN},
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
edge_balance = await credit_system.get_credits(user_id)
|
edge_balance = await credit_system.get_credits(user_id)
|
||||||
@@ -147,15 +151,13 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
|||||||
# Set up balance close to underflow threshold to test the protection
|
# Set up balance close to underflow threshold to test the protection
|
||||||
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
|
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
|
||||||
# This should trigger underflow protection
|
# This should trigger underflow protection
|
||||||
from prisma.models import UserBalance
|
|
||||||
|
|
||||||
test_balance = POSTGRES_INT_MIN + 1000
|
test_balance = POSTGRES_INT_MIN + 1000
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={
|
data=UserBalanceUpsertInput(
|
||||||
"create": {"userId": user_id, "balance": test_balance},
|
create=UserBalanceCreateInput(userId=user_id, balance=test_balance),
|
||||||
"update": {"balance": test_balance},
|
update={"balance": test_balance},
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
current_balance = await credit_system.get_credits(user_id)
|
current_balance = await credit_system.get_credits(user_id)
|
||||||
@@ -212,15 +214,13 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Set up balance close to underflow threshold
|
# Set up balance close to underflow threshold
|
||||||
from prisma.models import UserBalance
|
|
||||||
|
|
||||||
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={
|
data=UserBalanceUpsertInput(
|
||||||
"create": {"userId": user_id, "balance": initial_balance},
|
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
|
||||||
"update": {"balance": initial_balance},
|
update={"balance": initial_balance},
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply multiple refunds that would cumulatively underflow
|
# Apply multiple refunds that would cumulatively underflow
|
||||||
@@ -295,10 +295,10 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
|||||||
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
||||||
await UserBalance.prisma().upsert(
|
await UserBalance.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={
|
data=UserBalanceUpsertInput(
|
||||||
"create": {"userId": user_id, "balance": initial_balance},
|
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
|
||||||
"update": {"balance": initial_balance},
|
update={"balance": initial_balance},
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def large_refund(amount: int, label: str):
|
async def large_refund(amount: int, label: str):
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import pytest
|
|||||||
from prisma.enums import CreditTransactionType
|
from prisma.enums import CreditTransactionType
|
||||||
from prisma.errors import UniqueViolationError
|
from prisma.errors import UniqueViolationError
|
||||||
from prisma.models import CreditTransaction, User, UserBalance
|
from prisma.models import CreditTransaction, User, UserBalance
|
||||||
|
from prisma.types import UserBalanceCreateInput, UserCreateInput
|
||||||
|
|
||||||
from backend.data.credit import UsageTransactionMetadata, UserCredit
|
from backend.data.credit import UsageTransactionMetadata, UserCredit
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
@@ -24,11 +25,11 @@ async def create_test_user(user_id: str) -> None:
|
|||||||
"""Create a test user for migration tests."""
|
"""Create a test user for migration tests."""
|
||||||
try:
|
try:
|
||||||
await User.prisma().create(
|
await User.prisma().create(
|
||||||
data={
|
data=UserCreateInput(
|
||||||
"id": user_id,
|
id=user_id,
|
||||||
"email": f"test-{user_id}@example.com",
|
email=f"test-{user_id}@example.com",
|
||||||
"name": f"Test User {user_id[:8]}",
|
name=f"Test User {user_id[:8]}",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
# User already exists, continue
|
# User already exists, continue
|
||||||
@@ -121,7 +122,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
|
|||||||
try:
|
try:
|
||||||
# Create UserBalance with specific value
|
# Create UserBalance with specific value
|
||||||
await UserBalance.prisma().create(
|
await UserBalance.prisma().create(
|
||||||
data={"userId": user_id, "balance": 5000} # $50
|
data=UserBalanceCreateInput(userId=user_id, balance=5000) # $50
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
|
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
|
||||||
@@ -160,7 +161,9 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Set initial balance in UserBalance
|
# Set initial balance in UserBalance
|
||||||
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
|
await UserBalance.prisma().create(
|
||||||
|
data=UserBalanceCreateInput(userId=user_id, balance=1000)
|
||||||
|
)
|
||||||
|
|
||||||
# Run concurrent operations to ensure they all use UserBalance atomic operations
|
# Run concurrent operations to ensure they all use UserBalance atomic operations
|
||||||
async def concurrent_spend(amount: int, label: str):
|
async def concurrent_spend(amount: int, label: str):
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from prisma.models import (
|
|||||||
AgentNodeExecutionKeyValueData,
|
AgentNodeExecutionKeyValueData,
|
||||||
)
|
)
|
||||||
from prisma.types import (
|
from prisma.types import (
|
||||||
|
AgentGraphExecutionCreateInput,
|
||||||
AgentGraphExecutionUpdateManyMutationInput,
|
AgentGraphExecutionUpdateManyMutationInput,
|
||||||
AgentGraphExecutionWhereInput,
|
AgentGraphExecutionWhereInput,
|
||||||
AgentNodeExecutionCreateInput,
|
AgentNodeExecutionCreateInput,
|
||||||
@@ -35,7 +36,6 @@ from prisma.types import (
|
|||||||
AgentNodeExecutionKeyValueDataCreateInput,
|
AgentNodeExecutionKeyValueDataCreateInput,
|
||||||
AgentNodeExecutionUpdateInput,
|
AgentNodeExecutionUpdateInput,
|
||||||
AgentNodeExecutionWhereInput,
|
AgentNodeExecutionWhereInput,
|
||||||
AgentNodeExecutionWhereUniqueInput,
|
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
|
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
@@ -709,18 +709,18 @@ async def create_graph_execution(
|
|||||||
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
|
||||||
"""
|
"""
|
||||||
result = await AgentGraphExecution.prisma().create(
|
result = await AgentGraphExecution.prisma().create(
|
||||||
data={
|
data=AgentGraphExecutionCreateInput(
|
||||||
"agentGraphId": graph_id,
|
agentGraphId=graph_id,
|
||||||
"agentGraphVersion": graph_version,
|
agentGraphVersion=graph_version,
|
||||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
executionStatus=ExecutionStatus.INCOMPLETE,
|
||||||
"inputs": SafeJson(inputs),
|
inputs=SafeJson(inputs),
|
||||||
"credentialInputs": (
|
credentialInputs=(
|
||||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||||
),
|
),
|
||||||
"nodesInputMasks": (
|
nodesInputMasks=(
|
||||||
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
|
||||||
),
|
),
|
||||||
"NodeExecutions": {
|
NodeExecutions={
|
||||||
"create": [
|
"create": [
|
||||||
AgentNodeExecutionCreateInput(
|
AgentNodeExecutionCreateInput(
|
||||||
agentNodeId=node_id,
|
agentNodeId=node_id,
|
||||||
@@ -736,10 +736,10 @@ async def create_graph_execution(
|
|||||||
for node_id, node_input in starting_nodes_input
|
for node_id, node_input in starting_nodes_input
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"userId": user_id,
|
userId=user_id,
|
||||||
"agentPresetId": preset_id,
|
agentPresetId=preset_id,
|
||||||
"parentGraphExecutionId": parent_graph_exec_id,
|
parentGraphExecutionId=parent_graph_exec_id,
|
||||||
},
|
),
|
||||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -831,10 +831,10 @@ async def upsert_execution_output(
|
|||||||
"""
|
"""
|
||||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||||
"""
|
"""
|
||||||
data: AgentNodeExecutionInputOutputCreateInput = {
|
data = AgentNodeExecutionInputOutputCreateInput(
|
||||||
"name": output_name,
|
name=output_name,
|
||||||
"referencedByOutputExecId": node_exec_id,
|
referencedByOutputExecId=node_exec_id,
|
||||||
}
|
)
|
||||||
if output_data is not None:
|
if output_data is not None:
|
||||||
data["data"] = SafeJson(output_data)
|
data["data"] = SafeJson(output_data)
|
||||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||||
@@ -964,6 +964,12 @@ async def update_node_execution_status(
|
|||||||
execution_data: BlockInput | None = None,
|
execution_data: BlockInput | None = None,
|
||||||
stats: dict[str, Any] | None = None,
|
stats: dict[str, Any] | None = None,
|
||||||
) -> NodeExecutionResult:
|
) -> NodeExecutionResult:
|
||||||
|
"""
|
||||||
|
Update a node execution's status with validation of allowed transitions.
|
||||||
|
|
||||||
|
⚠️ Internal executor use only - no user_id check. Callers (executor/manager.py)
|
||||||
|
are responsible for validating user authorization before invoking this function.
|
||||||
|
"""
|
||||||
if status == ExecutionStatus.QUEUED and execution_data is None:
|
if status == ExecutionStatus.QUEUED and execution_data is None:
|
||||||
raise ValueError("Execution data must be provided when queuing an execution.")
|
raise ValueError("Execution data must be provided when queuing an execution.")
|
||||||
|
|
||||||
@@ -974,25 +980,27 @@ async def update_node_execution_status(
|
|||||||
f"Invalid status transition: {status} has no valid source statuses"
|
f"Invalid status transition: {status} has no valid source statuses"
|
||||||
)
|
)
|
||||||
|
|
||||||
if res := await AgentNodeExecution.prisma().update(
|
# Fetch current execution to validate status transition
|
||||||
where=cast(
|
current = await AgentNodeExecution.prisma().find_unique(
|
||||||
AgentNodeExecutionWhereUniqueInput,
|
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
|
||||||
{
|
)
|
||||||
"id": node_exec_id,
|
if not current:
|
||||||
"executionStatus": {"in": [s.value for s in allowed_from]},
|
raise ValueError(f"Execution {node_exec_id} not found.")
|
||||||
},
|
|
||||||
),
|
# Validate current status allows transition to the new status
|
||||||
|
if current.executionStatus not in allowed_from:
|
||||||
|
# Return current state without updating if transition is not allowed
|
||||||
|
return NodeExecutionResult.from_db(current)
|
||||||
|
|
||||||
|
# Perform the update with only the unique identifier
|
||||||
|
res = await AgentNodeExecution.prisma().update(
|
||||||
|
where={"id": node_exec_id},
|
||||||
data=_get_update_status_data(status, execution_data, stats),
|
data=_get_update_status_data(status, execution_data, stats),
|
||||||
include=EXECUTION_RESULT_INCLUDE,
|
include=EXECUTION_RESULT_INCLUDE,
|
||||||
):
|
)
|
||||||
return NodeExecutionResult.from_db(res)
|
if not res:
|
||||||
|
raise ValueError(f"Failed to update execution {node_exec_id}.")
|
||||||
if res := await AgentNodeExecution.prisma().find_unique(
|
return NodeExecutionResult.from_db(res)
|
||||||
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
|
|
||||||
):
|
|
||||||
return NodeExecutionResult.from_db(res)
|
|
||||||
|
|
||||||
raise ValueError(f"Execution {node_exec_id} not found.")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_update_status_data(
|
def _get_update_status_data(
|
||||||
|
|||||||
@@ -10,7 +10,11 @@ from typing import Optional
|
|||||||
|
|
||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
from prisma.models import PendingHumanReview
|
from prisma.models import PendingHumanReview
|
||||||
from prisma.types import PendingHumanReviewUpdateInput
|
from prisma.types import (
|
||||||
|
PendingHumanReviewCreateInput,
|
||||||
|
PendingHumanReviewUpdateInput,
|
||||||
|
PendingHumanReviewUpsertInput,
|
||||||
|
)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.api.features.executions.review.model import (
|
from backend.api.features.executions.review.model import (
|
||||||
@@ -66,20 +70,20 @@ async def get_or_create_human_review(
|
|||||||
# Upsert - get existing or create new review
|
# Upsert - get existing or create new review
|
||||||
review = await PendingHumanReview.prisma().upsert(
|
review = await PendingHumanReview.prisma().upsert(
|
||||||
where={"nodeExecId": node_exec_id},
|
where={"nodeExecId": node_exec_id},
|
||||||
data={
|
data=PendingHumanReviewUpsertInput(
|
||||||
"create": {
|
create=PendingHumanReviewCreateInput(
|
||||||
"userId": user_id,
|
userId=user_id,
|
||||||
"nodeExecId": node_exec_id,
|
nodeExecId=node_exec_id,
|
||||||
"graphExecId": graph_exec_id,
|
graphExecId=graph_exec_id,
|
||||||
"graphId": graph_id,
|
graphId=graph_id,
|
||||||
"graphVersion": graph_version,
|
graphVersion=graph_version,
|
||||||
"payload": SafeJson(input_data),
|
payload=SafeJson(input_data),
|
||||||
"instructions": message,
|
instructions=message,
|
||||||
"editable": editable,
|
editable=editable,
|
||||||
"status": ReviewStatus.WAITING,
|
status=ReviewStatus.WAITING,
|
||||||
},
|
),
|
||||||
"update": {}, # Do nothing on update - keep existing review as is
|
update={}, # Do nothing on update - keep existing review as is
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -7,7 +7,11 @@ import prisma
|
|||||||
import pydantic
|
import pydantic
|
||||||
from prisma.enums import OnboardingStep
|
from prisma.enums import OnboardingStep
|
||||||
from prisma.models import UserOnboarding
|
from prisma.models import UserOnboarding
|
||||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
from prisma.types import (
|
||||||
|
UserOnboardingCreateInput,
|
||||||
|
UserOnboardingUpdateInput,
|
||||||
|
UserOnboardingUpsertInput,
|
||||||
|
)
|
||||||
|
|
||||||
from backend.api.features.store.model import StoreAgentDetails
|
from backend.api.features.store.model import StoreAgentDetails
|
||||||
from backend.api.model import OnboardingNotificationPayload
|
from backend.api.model import OnboardingNotificationPayload
|
||||||
@@ -92,6 +96,7 @@ async def reset_user_onboarding(user_id: str):
|
|||||||
|
|
||||||
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||||
update: UserOnboardingUpdateInput = {}
|
update: UserOnboardingUpdateInput = {}
|
||||||
|
# get_user_onboarding guarantees the record exists via upsert
|
||||||
onboarding = await get_user_onboarding(user_id)
|
onboarding = await get_user_onboarding(user_id)
|
||||||
if data.walletShown:
|
if data.walletShown:
|
||||||
update["walletShown"] = data.walletShown
|
update["walletShown"] = data.walletShown
|
||||||
@@ -110,12 +115,14 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
|||||||
if data.onboardingAgentExecutionId is not None:
|
if data.onboardingAgentExecutionId is not None:
|
||||||
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
|
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
|
||||||
|
|
||||||
|
# The create branch is never taken since get_user_onboarding ensures the record exists,
|
||||||
|
# but upsert requires a create payload so we provide a minimal one
|
||||||
return await UserOnboarding.prisma().upsert(
|
return await UserOnboarding.prisma().upsert(
|
||||||
where={"userId": user_id},
|
where={"userId": user_id},
|
||||||
data={
|
data=UserOnboardingUpsertInput(
|
||||||
"create": {"userId": user_id, **update},
|
create=UserOnboardingCreateInput(userId=user_id),
|
||||||
"update": update,
|
update=update,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import random
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
|
from prisma.types import AgentBlockCreateInput
|
||||||
|
|
||||||
# Import API functions from the backend
|
# Import API functions from the backend
|
||||||
from backend.api.features.library.db import create_library_agent, create_preset
|
from backend.api.features.library.db import create_library_agent, create_preset
|
||||||
@@ -179,12 +180,12 @@ class TestDataCreator:
|
|||||||
for block in blocks_to_create:
|
for block in blocks_to_create:
|
||||||
try:
|
try:
|
||||||
await prisma.agentblock.create(
|
await prisma.agentblock.create(
|
||||||
data={
|
data=AgentBlockCreateInput(
|
||||||
"id": block.id,
|
id=block.id,
|
||||||
"name": block.name,
|
name=block.name,
|
||||||
"inputSchema": "{}",
|
inputSchema="{}",
|
||||||
"outputSchema": "{}",
|
outputSchema="{}",
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating block {block.name}: {e}")
|
print(f"Error creating block {block.name}: {e}")
|
||||||
|
|||||||
@@ -30,13 +30,19 @@ from prisma.types import (
|
|||||||
AgentGraphCreateInput,
|
AgentGraphCreateInput,
|
||||||
AgentNodeCreateInput,
|
AgentNodeCreateInput,
|
||||||
AgentNodeLinkCreateInput,
|
AgentNodeLinkCreateInput,
|
||||||
|
AgentPresetCreateInput,
|
||||||
AnalyticsDetailsCreateInput,
|
AnalyticsDetailsCreateInput,
|
||||||
AnalyticsMetricsCreateInput,
|
AnalyticsMetricsCreateInput,
|
||||||
|
APIKeyCreateInput,
|
||||||
CreditTransactionCreateInput,
|
CreditTransactionCreateInput,
|
||||||
IntegrationWebhookCreateInput,
|
IntegrationWebhookCreateInput,
|
||||||
|
LibraryAgentCreateInput,
|
||||||
ProfileCreateInput,
|
ProfileCreateInput,
|
||||||
|
StoreListingCreateInput,
|
||||||
StoreListingReviewCreateInput,
|
StoreListingReviewCreateInput,
|
||||||
|
StoreListingVersionCreateInput,
|
||||||
UserCreateInput,
|
UserCreateInput,
|
||||||
|
UserOnboardingCreateInput,
|
||||||
)
|
)
|
||||||
|
|
||||||
faker = Faker()
|
faker = Faker()
|
||||||
@@ -172,14 +178,14 @@ async def main():
|
|||||||
for _ in range(num_presets): # Create 1 AgentPreset per user
|
for _ in range(num_presets): # Create 1 AgentPreset per user
|
||||||
graph = random.choice(agent_graphs)
|
graph = random.choice(agent_graphs)
|
||||||
preset = await db.agentpreset.create(
|
preset = await db.agentpreset.create(
|
||||||
data={
|
data=AgentPresetCreateInput(
|
||||||
"name": faker.sentence(nb_words=3),
|
name=faker.sentence(nb_words=3),
|
||||||
"description": faker.text(max_nb_chars=200),
|
description=faker.text(max_nb_chars=200),
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"agentGraphId": graph.id,
|
agentGraphId=graph.id,
|
||||||
"agentGraphVersion": graph.version,
|
agentGraphVersion=graph.version,
|
||||||
"isActive": True,
|
isActive=True,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
agent_presets.append(preset)
|
agent_presets.append(preset)
|
||||||
|
|
||||||
@@ -220,18 +226,18 @@ async def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
library_agent = await db.libraryagent.create(
|
library_agent = await db.libraryagent.create(
|
||||||
data={
|
data=LibraryAgentCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"agentGraphId": graph.id,
|
agentGraphId=graph.id,
|
||||||
"agentGraphVersion": graph.version,
|
agentGraphVersion=graph.version,
|
||||||
"creatorId": creator_profile.id if creator_profile else None,
|
creatorId=creator_profile.id if creator_profile else None,
|
||||||
"imageUrl": get_image() if random.random() < 0.5 else None,
|
imageUrl=get_image() if random.random() < 0.5 else None,
|
||||||
"useGraphIsActiveVersion": random.choice([True, False]),
|
useGraphIsActiveVersion=random.choice([True, False]),
|
||||||
"isFavorite": random.choice([True, False]),
|
isFavorite=random.choice([True, False]),
|
||||||
"isCreatedByUser": random.choice([True, False]),
|
isCreatedByUser=random.choice([True, False]),
|
||||||
"isArchived": random.choice([True, False]),
|
isArchived=random.choice([True, False]),
|
||||||
"isDeleted": random.choice([True, False]),
|
isDeleted=random.choice([True, False]),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
library_agents.append(library_agent)
|
library_agents.append(library_agent)
|
||||||
|
|
||||||
@@ -392,13 +398,13 @@ async def main():
|
|||||||
user = random.choice(users)
|
user = random.choice(users)
|
||||||
slug = faker.slug()
|
slug = faker.slug()
|
||||||
listing = await db.storelisting.create(
|
listing = await db.storelisting.create(
|
||||||
data={
|
data=StoreListingCreateInput(
|
||||||
"agentGraphId": graph.id,
|
agentGraphId=graph.id,
|
||||||
"agentGraphVersion": graph.version,
|
agentGraphVersion=graph.version,
|
||||||
"owningUserId": user.id,
|
owningUserId=user.id,
|
||||||
"hasApprovedVersion": random.choice([True, False]),
|
hasApprovedVersion=random.choice([True, False]),
|
||||||
"slug": slug,
|
slug=slug,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
store_listings.append(listing)
|
store_listings.append(listing)
|
||||||
|
|
||||||
@@ -408,26 +414,26 @@ async def main():
|
|||||||
for listing in store_listings:
|
for listing in store_listings:
|
||||||
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
|
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
|
||||||
version = await db.storelistingversion.create(
|
version = await db.storelistingversion.create(
|
||||||
data={
|
data=StoreListingVersionCreateInput(
|
||||||
"agentGraphId": graph.id,
|
agentGraphId=graph.id,
|
||||||
"agentGraphVersion": graph.version,
|
agentGraphVersion=graph.version,
|
||||||
"name": graph.name or faker.sentence(nb_words=3),
|
name=graph.name or faker.sentence(nb_words=3),
|
||||||
"subHeading": faker.sentence(),
|
subHeading=faker.sentence(),
|
||||||
"videoUrl": get_video_url() if random.random() < 0.3 else None,
|
videoUrl=get_video_url() if random.random() < 0.3 else None,
|
||||||
"imageUrls": [get_image() for _ in range(3)],
|
imageUrls=[get_image() for _ in range(3)],
|
||||||
"description": faker.text(),
|
description=faker.text(),
|
||||||
"categories": [faker.word() for _ in range(3)],
|
categories=[faker.word() for _ in range(3)],
|
||||||
"isFeatured": random.choice([True, False]),
|
isFeatured=random.choice([True, False]),
|
||||||
"isAvailable": True,
|
isAvailable=True,
|
||||||
"storeListingId": listing.id,
|
storeListingId=listing.id,
|
||||||
"submissionStatus": random.choice(
|
submissionStatus=random.choice(
|
||||||
[
|
[
|
||||||
prisma.enums.SubmissionStatus.PENDING,
|
prisma.enums.SubmissionStatus.PENDING,
|
||||||
prisma.enums.SubmissionStatus.APPROVED,
|
prisma.enums.SubmissionStatus.APPROVED,
|
||||||
prisma.enums.SubmissionStatus.REJECTED,
|
prisma.enums.SubmissionStatus.REJECTED,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
store_listing_versions.append(version)
|
store_listing_versions.append(version)
|
||||||
|
|
||||||
@@ -469,51 +475,49 @@ async def main():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await db.useronboarding.create(
|
await db.useronboarding.create(
|
||||||
data={
|
data=UserOnboardingCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"completedSteps": completed_steps,
|
completedSteps=completed_steps,
|
||||||
"walletShown": random.choice([True, False]),
|
walletShown=random.choice([True, False]),
|
||||||
"notified": (
|
notified=(
|
||||||
random.sample(completed_steps, k=min(3, len(completed_steps)))
|
random.sample(completed_steps, k=min(3, len(completed_steps)))
|
||||||
if completed_steps
|
if completed_steps
|
||||||
else []
|
else []
|
||||||
),
|
),
|
||||||
"rewardedFor": (
|
rewardedFor=(
|
||||||
random.sample(completed_steps, k=min(2, len(completed_steps)))
|
random.sample(completed_steps, k=min(2, len(completed_steps)))
|
||||||
if completed_steps
|
if completed_steps
|
||||||
else []
|
else []
|
||||||
),
|
),
|
||||||
"usageReason": (
|
usageReason=(
|
||||||
random.choice(["personal", "business", "research", "learning"])
|
random.choice(["personal", "business", "research", "learning"])
|
||||||
if random.random() < 0.7
|
if random.random() < 0.7
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
"integrations": random.sample(
|
integrations=random.sample(
|
||||||
["github", "google", "discord", "slack"], k=random.randint(0, 2)
|
["github", "google", "discord", "slack"], k=random.randint(0, 2)
|
||||||
),
|
),
|
||||||
"otherIntegrations": (
|
otherIntegrations=(faker.word() if random.random() < 0.2 else None),
|
||||||
faker.word() if random.random() < 0.2 else None
|
selectedStoreListingVersionId=(
|
||||||
),
|
|
||||||
"selectedStoreListingVersionId": (
|
|
||||||
random.choice(store_listing_versions).id
|
random.choice(store_listing_versions).id
|
||||||
if store_listing_versions and random.random() < 0.5
|
if store_listing_versions and random.random() < 0.5
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
"onboardingAgentExecutionId": (
|
onboardingAgentExecutionId=(
|
||||||
random.choice(agent_graph_executions).id
|
random.choice(agent_graph_executions).id
|
||||||
if agent_graph_executions and random.random() < 0.3
|
if agent_graph_executions and random.random() < 0.3
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
"agentRuns": random.randint(0, 10),
|
agentRuns=random.randint(0, 10),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating onboarding for user {user.id}: {e}")
|
print(f"Error creating onboarding for user {user.id}: {e}")
|
||||||
# Try simpler version
|
# Try simpler version
|
||||||
await db.useronboarding.create(
|
await db.useronboarding.create(
|
||||||
data={
|
data=UserOnboardingCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Insert IntegrationWebhooks for some users
|
# Insert IntegrationWebhooks for some users
|
||||||
@@ -544,20 +548,20 @@ async def main():
|
|||||||
for user in users:
|
for user in users:
|
||||||
api_key = APIKeySmith().generate_key()
|
api_key = APIKeySmith().generate_key()
|
||||||
await db.apikey.create(
|
await db.apikey.create(
|
||||||
data={
|
data=APIKeyCreateInput(
|
||||||
"name": faker.word(),
|
name=faker.word(),
|
||||||
"head": api_key.head,
|
head=api_key.head,
|
||||||
"tail": api_key.tail,
|
tail=api_key.tail,
|
||||||
"hash": api_key.hash,
|
hash=api_key.hash,
|
||||||
"salt": api_key.salt,
|
salt=api_key.salt,
|
||||||
"status": prisma.enums.APIKeyStatus.ACTIVE,
|
status=prisma.enums.APIKeyStatus.ACTIVE,
|
||||||
"permissions": [
|
permissions=[
|
||||||
prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
|
prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
|
||||||
prisma.enums.APIKeyPermission.READ_GRAPH,
|
prisma.enums.APIKeyPermission.READ_GRAPH,
|
||||||
],
|
],
|
||||||
"description": faker.text(),
|
description=faker.text(),
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Refresh materialized views
|
# Refresh materialized views
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from datetime import datetime, timedelta
|
|||||||
import prisma.enums
|
import prisma.enums
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
from prisma import Json, Prisma
|
from prisma import Json, Prisma
|
||||||
|
from prisma.types import CreditTransactionCreateInput, StoreListingReviewCreateInput
|
||||||
|
|
||||||
faker = Faker()
|
faker = Faker()
|
||||||
|
|
||||||
@@ -166,16 +167,16 @@ async def main():
|
|||||||
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
|
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
|
||||||
|
|
||||||
await db.storelistingreview.create(
|
await db.storelistingreview.create(
|
||||||
data={
|
data=StoreListingReviewCreateInput(
|
||||||
"storeListingVersionId": version.id,
|
storeListingVersionId=version.id,
|
||||||
"reviewByUserId": reviewer.id,
|
reviewByUserId=reviewer.id,
|
||||||
"score": score,
|
score=score,
|
||||||
"comments": (
|
comments=(
|
||||||
faker.text(max_nb_chars=200)
|
faker.text(max_nb_chars=200)
|
||||||
if random.random() < 0.7
|
if random.random() < 0.7
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
new_reviews_count += 1
|
new_reviews_count += 1
|
||||||
|
|
||||||
@@ -244,17 +245,17 @@ async def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
await db.credittransaction.create(
|
await db.credittransaction.create(
|
||||||
data={
|
data=CreditTransactionCreateInput(
|
||||||
"userId": user.id,
|
userId=user.id,
|
||||||
"amount": amount,
|
amount=amount,
|
||||||
"type": transaction_type,
|
type=transaction_type,
|
||||||
"metadata": Json(
|
metadata=Json(
|
||||||
{
|
{
|
||||||
"source": "test_updater",
|
"source": "test_updater",
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
transaction_count += 1
|
transaction_count += 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user