mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend/db): Fix unchecked Prisma statements (#9805)
This commit is contained in:
committed by
GitHub
parent
2ca18d77a4
commit
8ea3bfabc4
@@ -3,7 +3,6 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
|
||||
import stripe
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
@@ -19,7 +18,6 @@ from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
CreditTransactionWhereInput,
|
||||
IntFilter,
|
||||
)
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
@@ -215,7 +213,7 @@ class UserCreditBase(ABC):
|
||||
"userId": user_id,
|
||||
"createdAt": {"lte": top_time},
|
||||
"isActive": True,
|
||||
"runningBalance": cast(IntFilter, {"not": None}),
|
||||
"NOT": [{"runningBalance": None}],
|
||||
},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ async def transaction():
|
||||
async def locked_transaction(key: str):
|
||||
lock_key = zlib.crc32(key.encode("utf-8"))
|
||||
async with transaction() as tx:
|
||||
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
|
||||
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
|
||||
yield tx
|
||||
|
||||
|
||||
|
||||
@@ -492,7 +492,8 @@ async def update_graph_execution_stats(
|
||||
data = stats.model_dump() if stats else {}
|
||||
if isinstance(data.get("error"), Exception):
|
||||
data["error"] = str(data["error"])
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
@@ -504,10 +505,15 @@ async def update_graph_execution_stats(
|
||||
"executionStatus": status,
|
||||
"stats": Json(data),
|
||||
},
|
||||
)
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
|
||||
where={"id": graph_exec_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
|
||||
return GraphExecution.from_db(res) if res else None
|
||||
return GraphExecution.from_db(graph_exec)
|
||||
|
||||
|
||||
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
|
||||
@@ -643,7 +649,7 @@ async def get_latest_node_execution(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
"NOT": [{"executionStatus": ExecutionStatus.INCOMPLETE}],
|
||||
},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
|
||||
@@ -10,9 +10,7 @@ from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVers
|
||||
from prisma.types import (
|
||||
AgentGraphCreateInput,
|
||||
AgentGraphWhereInput,
|
||||
AgentGraphWhereInputRecursive1,
|
||||
AgentNodeCreateInput,
|
||||
AgentNodeIncludeFromAgentNodeRecursive1,
|
||||
AgentNodeLinkCreateInput,
|
||||
)
|
||||
from pydantic.fields import computed_field
|
||||
@@ -655,14 +653,11 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
type_utils.typed(
|
||||
AgentGraphWhereInputRecursive1,
|
||||
{
|
||||
"id": graph_id,
|
||||
"version": graph_version,
|
||||
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
|
||||
},
|
||||
)
|
||||
{
|
||||
"id": graph_id,
|
||||
"version": graph_version,
|
||||
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
|
||||
}
|
||||
for graph_id, graph_version in sub_graph_ids
|
||||
]
|
||||
},
|
||||
@@ -678,13 +673,7 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
|
||||
links = await AgentNodeLink.prisma().find_many(
|
||||
where={"agentNodeSourceId": node_id},
|
||||
include={
|
||||
"AgentNodeSink": {
|
||||
"include": cast(
|
||||
AgentNodeIncludeFromAgentNodeRecursive1, AGENT_NODE_INCLUDE
|
||||
)
|
||||
}
|
||||
},
|
||||
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}},
|
||||
)
|
||||
return [
|
||||
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
|
||||
@@ -930,12 +919,19 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel.__members__.values()]
|
||||
|
||||
escaped_enum_values = repr(tuple(enum_values)) # hack but works
|
||||
query = f"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
|
||||
WHERE "agentBlockId" = '{id}'
|
||||
AND "constantInput" ? '{path}'
|
||||
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
|
||||
SET "constantInput" = jsonb_set("constantInput", $1, $2, true)
|
||||
WHERE "agentBlockId" = $3
|
||||
AND "constantInput" ? $4
|
||||
AND "constantInput"->>$4 NOT IN {escaped_enum_values}
|
||||
"""
|
||||
|
||||
await db.execute_raw(query)
|
||||
await db.execute_raw(
|
||||
query, # type: ignore - is supposed to be LiteralString
|
||||
"{" + path + "}",
|
||||
f'"{migrate_to.value}"',
|
||||
id,
|
||||
path,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ import prisma.enums
|
||||
import prisma.types
|
||||
|
||||
from backend.blocks.io import IO_BLOCK_IDs
|
||||
from backend.util.type import typed_cast
|
||||
|
||||
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
@@ -14,13 +13,7 @@ AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
}
|
||||
|
||||
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
"Nodes": {
|
||||
"include": typed_cast(
|
||||
prisma.types.AgentNodeIncludeFromAgentNodeRecursive1,
|
||||
prisma.types.AgentNodeIncludeFromAgentNode,
|
||||
AGENT_NODE_INCLUDE,
|
||||
)
|
||||
}
|
||||
"Nodes": {"include": AGENT_NODE_INCLUDE}
|
||||
}
|
||||
|
||||
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
@@ -56,13 +49,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES["NodeExecutions"],
|
||||
),
|
||||
"where": {
|
||||
"Node": typed_cast(
|
||||
prisma.types.AgentNodeRelationFilter,
|
||||
prisma.types.AgentNodeWhereInput,
|
||||
{
|
||||
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}},
|
||||
},
|
||||
),
|
||||
"Node": {"is": {"AgentBlock": {"is": {"id": {"in": IO_BLOCK_IDs}}}}},
|
||||
"NOT": [{"executionStatus": prisma.enums.AgentExecutionStatus.INCOMPLETE}],
|
||||
},
|
||||
}
|
||||
@@ -70,13 +57,7 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {
|
||||
"include": typed_cast(
|
||||
prisma.types.AgentNodeIncludeFromAgentNodeRecursive1,
|
||||
prisma.types.AgentNodeInclude,
|
||||
AGENT_NODE_INCLUDE,
|
||||
)
|
||||
}
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, **update}, # type: ignore
|
||||
"create": {"userId": user_id, **update},
|
||||
"update": update,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -149,7 +149,7 @@ async def migrate_and_encrypt_user_integrations():
|
||||
logger.info(f"Migrating integration credentials for {len(users)} users")
|
||||
|
||||
for user in users:
|
||||
raw_metadata = cast(UserMetadataRaw, user.metadata)
|
||||
raw_metadata = cast(dict, user.metadata)
|
||||
metadata = UserMetadata.model_validate(raw_metadata)
|
||||
|
||||
# Get existing integrations data
|
||||
@@ -165,7 +165,6 @@ async def migrate_and_encrypt_user_integrations():
|
||||
await update_user_integrations(user_id=user.id, data=integrations)
|
||||
|
||||
# Remove from metadata
|
||||
raw_metadata = dict(raw_metadata)
|
||||
raw_metadata.pop("integration_credentials", None)
|
||||
raw_metadata.pop("integration_oauth_states", None)
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.model
|
||||
from backend.data.graph import GraphModel, get_sub_graphs
|
||||
from backend.data.includes import AGENT_GRAPH_INCLUDE
|
||||
from backend.util.type import typed_cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -960,7 +959,7 @@ async def get_my_agents(
|
||||
try:
|
||||
search_filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"AgentGraph": {"is": {"StoreListing": {"none": {"isDeleted": False}}}},
|
||||
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
|
||||
"isArchived": False,
|
||||
"isDeleted": False,
|
||||
}
|
||||
@@ -1088,13 +1087,7 @@ async def review_store_submission(
|
||||
where={"id": store_listing_version_id},
|
||||
include={
|
||||
"StoreListing": True,
|
||||
"AgentGraph": {
|
||||
"include": typed_cast(
|
||||
prisma.types.AgentGraphIncludeFromAgentGraphRecursive1,
|
||||
prisma.types.AgentGraphInclude,
|
||||
AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
},
|
||||
"AgentGraph": {"include": AGENT_GRAPH_INCLUDE},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -198,18 +198,6 @@ def convert(value: Any, target_type: Type[T]) -> T:
|
||||
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
|
||||
|
||||
|
||||
def typed(type: type[T], value: T) -> T:
|
||||
"""
|
||||
Add an explicit type to a value. Useful in nested statements, e.g. dict literals.
|
||||
"""
|
||||
return value
|
||||
|
||||
|
||||
def typed_cast(to_type: type[TT], from_type: type[T], value: T) -> TT:
|
||||
"""Strict cast to preserve type checking abilities."""
|
||||
return cast(TT, value)
|
||||
|
||||
|
||||
class FormattedStringType(str):
|
||||
string_format: str
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ datasource db {
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = 5
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views"]
|
||||
}
|
||||
|
||||
@@ -140,14 +140,14 @@ async def main():
|
||||
for _ in range(num_presets): # Create 1 AgentPreset per user
|
||||
graph = random.choice(agent_graphs)
|
||||
preset = await db.agentpreset.create(
|
||||
data=AgentPresetCreateInput(
|
||||
name=faker.sentence(nb_words=3),
|
||||
description=faker.text(max_nb_chars=200),
|
||||
userId=user.id,
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
isActive=True,
|
||||
)
|
||||
data={
|
||||
"name": faker.sentence(nb_words=3),
|
||||
"description": faker.text(max_nb_chars=200),
|
||||
"userId": user.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"isActive": True,
|
||||
}
|
||||
)
|
||||
agent_presets.append(preset)
|
||||
|
||||
@@ -160,16 +160,15 @@ async def main():
|
||||
graph = random.choice(agent_graphs)
|
||||
preset = random.choice(agent_presets)
|
||||
user_agent = await db.libraryagent.create(
|
||||
data=LibraryAgentCreateInput(
|
||||
userId=user.id,
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
agentPresetId=preset.id,
|
||||
isFavorite=random.choice([True, False]),
|
||||
isCreatedByUser=random.choice([True, False]),
|
||||
isArchived=random.choice([True, False]),
|
||||
isDeleted=random.choice([True, False]),
|
||||
)
|
||||
data={
|
||||
"userId": user.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"isFavorite": random.choice([True, False]),
|
||||
"isCreatedByUser": random.choice([True, False]),
|
||||
"isArchived": random.choice([True, False]),
|
||||
"isDeleted": random.choice([True, False]),
|
||||
}
|
||||
)
|
||||
user_agents.append(user_agent)
|
||||
|
||||
@@ -346,13 +345,13 @@ async def main():
|
||||
user = random.choice(users)
|
||||
slug = faker.slug()
|
||||
listing = await db.storelisting.create(
|
||||
data=StoreListingCreateInput(
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
owningUserId=user.id,
|
||||
hasApprovedVersion=random.choice([True, False]),
|
||||
slug=slug,
|
||||
)
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"owningUserId": user.id,
|
||||
"hasApprovedVersion": random.choice([True, False]),
|
||||
"slug": slug,
|
||||
}
|
||||
)
|
||||
store_listings.append(listing)
|
||||
|
||||
@@ -362,26 +361,26 @@ async def main():
|
||||
for listing in store_listings:
|
||||
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
|
||||
version = await db.storelistingversion.create(
|
||||
data=StoreListingVersionCreateInput(
|
||||
agentId=graph.id,
|
||||
agentVersion=graph.version,
|
||||
name=graph.name or faker.sentence(nb_words=3),
|
||||
subHeading=faker.sentence(),
|
||||
videoUrl=faker.url(),
|
||||
imageUrls=[get_image() for _ in range(3)],
|
||||
description=faker.text(),
|
||||
categories=[faker.word() for _ in range(3)],
|
||||
isFeatured=random.choice([True, False]),
|
||||
isAvailable=True,
|
||||
storeListingId=listing.id,
|
||||
submissionStatus=random.choice(
|
||||
data={
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"name": graph.name or faker.sentence(nb_words=3),
|
||||
"subHeading": faker.sentence(),
|
||||
"videoUrl": faker.url(),
|
||||
"imageUrls": [get_image() for _ in range(3)],
|
||||
"description": faker.text(),
|
||||
"categories": [faker.word() for _ in range(3)],
|
||||
"isFeatured": random.choice([True, False]),
|
||||
"isAvailable": True,
|
||||
"storeListingId": listing.id,
|
||||
"submissionStatus": random.choice(
|
||||
[
|
||||
prisma.enums.SubmissionStatus.PENDING,
|
||||
prisma.enums.SubmissionStatus.APPROVED,
|
||||
prisma.enums.SubmissionStatus.REJECTED,
|
||||
]
|
||||
),
|
||||
)
|
||||
}
|
||||
)
|
||||
store_listing_versions.append(version)
|
||||
|
||||
@@ -422,23 +421,12 @@ async def main():
|
||||
)
|
||||
await db.storelistingversion.update(
|
||||
where={"id": version.id},
|
||||
data=StoreListingVersionCreateInput(
|
||||
submissionStatus=status,
|
||||
Reviewer={"connect": {"id": reviewer.id}},
|
||||
reviewComments=faker.text(),
|
||||
reviewedAt=datetime.now(),
|
||||
agentId=version.agentId, # preserving existing fields
|
||||
agentVersion=version.agentVersion,
|
||||
name=version.name,
|
||||
subHeading=version.subHeading,
|
||||
videoUrl=version.videoUrl,
|
||||
imageUrls=version.imageUrls,
|
||||
description=version.description,
|
||||
categories=version.categories,
|
||||
isFeatured=version.isFeatured,
|
||||
isAvailable=version.isAvailable,
|
||||
storeListingId=version.storeListingId,
|
||||
),
|
||||
data={
|
||||
"submissionStatus": status,
|
||||
"Reviewer": {"connect": {"id": reviewer.id}},
|
||||
"reviewComments": faker.text(),
|
||||
"reviewedAt": datetime.now(),
|
||||
},
|
||||
)
|
||||
|
||||
# Insert APIKeys
|
||||
|
||||
Reference in New Issue
Block a user