fix(backend/db): Fix unchecked Prisma statements (#9805)

This commit is contained in:
Reinier van der Leer
2025-04-10 23:04:42 +02:00
committed by GitHub
parent 2ca18d77a4
commit 8ea3bfabc4
11 changed files with 82 additions and 133 deletions

View File

@@ -3,7 +3,6 @@ import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import cast
import stripe
from autogpt_libs.utils.cache import thread_cached
@@ -19,7 +18,6 @@ from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
IntFilter,
)
from tenacity import retry, stop_after_attempt, wait_exponential
@@ -215,7 +213,7 @@ class UserCreditBase(ABC):
"userId": user_id,
"createdAt": {"lte": top_time},
"isActive": True,
"runningBalance": cast(IntFilter, {"not": None}),
"NOT": [{"runningBalance": None}],
},
order={"createdAt": "desc"},
)

View File

@@ -89,7 +89,7 @@ async def transaction():
async def locked_transaction(key: str):
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction() as tx:
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
yield tx

View File

@@ -492,7 +492,8 @@ async def update_graph_execution_stats(
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
updated_count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"OR": [
@@ -504,10 +505,15 @@ async def update_graph_execution_stats(
"executionStatus": status,
"stats": Json(data),
},
)
if updated_count == 0:
return None
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(res) if res else None
return GraphExecution.from_db(graph_exec)
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
@@ -643,7 +649,7 @@ async def get_latest_node_execution(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
"NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
},
order=[
{"queuedTime": "desc"},

View File

@@ -10,9 +10,7 @@ from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVers
from prisma.types import (
AgentGraphCreateInput,
AgentGraphWhereInput,
AgentGraphWhereInputRecursive1,
AgentNodeCreateInput,
AgentNodeIncludeFromAgentNodeRecursive1,
AgentNodeLinkCreateInput,
)
from pydantic.fields import computed_field
@@ -655,14 +653,11 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
graphs = await AgentGraph.prisma().find_many(
where={
"OR": [
type_utils.typed(
AgentGraphWhereInputRecursive1,
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
},
)
{
"id": graph_id,
"version": graph_version,
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
}
for graph_id, graph_version in sub_graph_ids
]
},
@@ -678,13 +673,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
links = await AgentNodeLink.prisma().find_many(
where={"agentNodeSourceId": node_id},
include={
"AgentNodeSink": {
"include": cast(
AgentNodeIncludeFromAgentNodeRecursive1, AGENT_NODE_INCLUDE
)
}
},
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
)
return [
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
@@ -930,12 +919,19 @@ async def migrate_llm_models(migrate_to: LlmModel):
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel.__members__.values()]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
query = f"""
UPDATE "AgentNode"
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
WHERE "agentBlockId" = '{id}'
AND "constantInput" ? '{path}'
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
SET "constantInput" = jsonb_set("constantInput", $1, $2, true)
WHERE "agentBlockId" = $3
AND "constantInput" ? $4
AND "constantInput"->>$4 NOT IN {escaped_enum_values}
"""
await db.execute_raw(query)
await db.execute_raw(
query, # type: ignore - is supposed to be LiteralString
"{" + path + "}",
f'"{migrate_to.value}"',
id,
path,
)

View File

@@ -4,7 +4,6 @@ import prisma.enums
import prisma.types
from backend.blocks.io import IO_BLOCK_IDs
from backend.util.type import typed_cast
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
@@ -14,13 +13,7 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
}
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
"Nodes": {
"include": typed_cast(
prisma.types.AgentNodeIncludeFromAgentNodeRecursive1,
prisma.types.AgentNodeIncludeFromAgentNode,
AGENT_NODE_INCLUDE,
)
}
"Nodes": {"include": AGENT_NODE_INCLUDE}
}
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
@@ -56,13 +49,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
),
"where": {
"Node": typed_cast(
prisma.types.AgentNodeRelationFilter,
prisma.types.AgentNodeWhereInput,
{
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}},
},
),
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
},
}
@@ -70,13 +57,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {
"include": typed_cast(
prisma.types.AgentNodeIncludeFromAgentNodeRecursive1,
prisma.types.AgentNodeInclude,
AGENT_NODE_INCLUDE,
)
}
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
}

View File

@@ -62,7 +62,7 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, **update}, # type: ignore
"create": {"userId": user_id, **update},
"update": update,
},
)

View File

@@ -149,7 +149,7 @@ async def migrate_and_encrypt_user_integrations():
logger.info(f"Migrating integration credentials for {len(users)} users")
for user in users:
raw_metadata = cast(UserMetadataRaw, user.metadata)
raw_metadata = cast(dict, user.metadata)
metadata = UserMetadata.model_validate(raw_metadata)
# Get existing integrations data
@@ -165,7 +165,6 @@ async def migrate_and_encrypt_user_integrations():
await update_user_integrations(user_id=user.id, data=integrations)
# Remove from metadata
raw_metadata = dict(raw_metadata)
raw_metadata.pop("integration_credentials", None)
raw_metadata.pop("integration_oauth_states", None)

View File

@@ -12,7 +12,6 @@ import backend.server.v2.store.exceptions
import backend.server.v2.store.model
from backend.data.graph import GraphModel, get_sub_graphs
from backend.data.includes import AGENT_GRAPH_INCLUDE
from backend.util.type import typed_cast
logger = logging.getLogger(__name__)
@@ -960,7 +959,7 @@ async def get_my_agents(
try:
search_filter: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
"AgentGraph": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
"isArchived": False,
"isDeleted": False,
}
@@ -1088,13 +1087,7 @@ async def review_store_submission(
where={"id": store_listing_version_id},
include={
"StoreListing": True,
"AgentGraph": {
"include": typed_cast(
prisma.types.AgentGraphIncludeFromAgentGraphRecursive1,
prisma.types.AgentGraphInclude,
AGENT_GRAPH_INCLUDE,
)
},
"AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
},
)
)

View File

@@ -198,18 +198,6 @@ def convert(value: Any, target_type: Type[T]) -> T:
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
def typed(type: type[T], value: T) -> T:
"""
Add an explicit type to a value. Useful in nested statements, e.g. dict literals.
"""
return value
def typed_cast(to_type: type[TT], from_type: type[T], value: T) -> TT:
"""Strict cast to preserve type checking abilities."""
return cast(TT, value)
class FormattedStringType(str):
string_format: str

View File

@@ -5,7 +5,7 @@ datasource db {
generator client {
provider = "prisma-client-py"
recursive_type_depth = 5
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views"]
}

View File

@@ -140,14 +140,14 @@ async def main():
for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs)
preset = await db.agentpreset.create(
data=AgentPresetCreateInput(
name=faker.sentence(nb_words=3),
description=faker.text(max_nb_chars=200),
userId=user.id,
agentId=graph.id,
agentVersion=graph.version,
isActive=True,
)
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"isActive": True,
}
)
agent_presets.append(preset)
@@ -160,16 +160,15 @@ async def main():
graph = random.choice(agent_graphs)
preset = random.choice(agent_presets)
user_agent = await db.libraryagent.create(
data=LibraryAgentCreateInput(
userId=user.id,
agentId=graph.id,
agentVersion=graph.version,
agentPresetId=preset.id,
isFavorite=random.choice([True, False]),
isCreatedByUser=random.choice([True, False]),
isArchived=random.choice([True, False]),
isDeleted=random.choice([True, False]),
)
data={
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"isFavorite": random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]),
"isArchived": random.choice([True, False]),
"isDeleted": random.choice([True, False]),
}
)
user_agents.append(user_agent)
@@ -346,13 +345,13 @@ async def main():
user = random.choice(users)
slug = faker.slug()
listing = await db.storelisting.create(
data=StoreListingCreateInput(
agentId=graph.id,
agentVersion=graph.version,
owningUserId=user.id,
hasApprovedVersion=random.choice([True, False]),
slug=slug,
)
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"owningUserId": user.id,
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,
}
)
store_listings.append(listing)
@@ -362,26 +361,26 @@ async def main():
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
version = await db.storelistingversion.create(
data=StoreListingVersionCreateInput(
agentId=graph.id,
agentVersion=graph.version,
name=graph.name or faker.sentence(nb_words=3),
subHeading=faker.sentence(),
videoUrl=faker.url(),
imageUrls=[get_image() for _ in range(3)],
description=faker.text(),
categories=[faker.word() for _ in range(3)],
isFeatured=random.choice([True, False]),
isAvailable=True,
storeListingId=listing.id,
submissionStatus=random.choice(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": faker.url(),
"imageUrls": [get_image() for _ in range(3)],
"description": faker.text(),
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"storeListingId": listing.id,
"submissionStatus": random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
),
)
}
)
store_listing_versions.append(version)
@@ -422,23 +421,12 @@ async def main():
)
await db.storelistingversion.update(
where={"id": version.id},
data=StoreListingVersionCreateInput(
submissionStatus=status,
Reviewer={"connect": {"id": reviewer.id}},
reviewComments=faker.text(),
reviewedAt=datetime.now(),
agentId=version.agentId, # preserving existing fields
agentVersion=version.agentVersion,
name=version.name,
subHeading=version.subHeading,
videoUrl=version.videoUrl,
imageUrls=version.imageUrls,
description=version.description,
categories=version.categories,
isFeatured=version.isFeatured,
isAvailable=version.isAvailable,
storeListingId=version.storeListingId,
),
data={
"submissionStatus": status,
"Reviewer": {"connect": {"id": reviewer.id}},
"reviewComments": faker.text(),
"reviewedAt": datetime.now(),
},
)
# Insert APIKeys