fix(backend): Update deprecated code caused by upgrades (#9758)

This series of upgrades:
https://github.com/significant-gravitas/autogpt/pull/9727
https://github.com/Significant-Gravitas/AutoGPT/pull/9728
https://github.com/Significant-Gravitas/AutoGPT/pull/9560

These upgrades caused some code in the repo to become deprecated; this PR addresses those deprecations.

### Changes 🏗️

Fix pydantic config, usage of field, usage of proper prisma
`CreateInput` type, pytest loop-scope.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] CI, manual test on running some agents.

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
This commit is contained in:
Zamil Majdy
2025-04-04 20:34:40 +04:00
committed by GitHub
parent 4397746a87
commit 3771a0924c
38 changed files with 433 additions and 382 deletions

View File

@@ -1,6 +1,7 @@
import logging
import re
from typing import Any
import uvicorn.config
from colorama import Fore

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from backend.data.model import SchemaField
@@ -143,11 +143,12 @@ class ContactEmail(BaseModel):
class EmploymentHistory(BaseModel):
"""An employment history in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
_id: Optional[str] = None
created_at: Optional[str] = None
@@ -188,11 +189,12 @@ class TypedCustomField(BaseModel):
class Pagination(BaseModel):
"""Pagination in Apollo"""
class Config:
extra = "allow" # Allow extra fields
arbitrary_types_allowed = True # Allow any type
from_attributes = True # Allow from_orm
populate_by_name = True # Allow field aliases to work both ways
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
page: int = 0
per_page: int = 0
@@ -230,11 +232,12 @@ class PhoneNumber(BaseModel):
class Organization(BaseModel):
"""An organization in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
@@ -268,11 +271,12 @@ class Organization(BaseModel):
class Contact(BaseModel):
"""A contact in Apollo"""
class Config:
extra = "allow"
arbitrary_types_allowed = True
from_attributes = True
populate_by_name = True
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
contact_roles: list[Any] = []
id: Optional[str] = None
@@ -522,11 +526,12 @@ Use the page parameter to search the different pages of data.""",
class SearchPeopleResponse(BaseModel):
"""Response from Apollo's search people API"""
class Config:
extra = "allow" # Allow extra fields
arbitrary_types_allowed = True # Allow any type
from_attributes = True # Allow from_orm
populate_by_name = True # Allow field aliases to work both ways
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
from_attributes=True,
populate_by_name=True,
)
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True

View File

@@ -42,7 +42,7 @@ class AIVideoGeneratorBlock(Block):
description="Error message if video generation failed."
)
logs: list[str] = SchemaField(
description="Generation progress logs.", optional=True
description="Generation progress logs.",
)
def __init__(self):

View File

@@ -27,7 +27,6 @@ class HubSpotEngagementBlock(Block):
timeframe_days: int = SchemaField(
description="Number of days to look back for engagement",
default=30,
optional=True,
)
class Output(BlockSchema):

View File

@@ -23,7 +23,7 @@ class JinaChunkingBlock(Block):
class Output(BlockSchema):
chunks: list = SchemaField(description="List of chunked texts")
tokens: list = SchemaField(
description="List of token information for each chunk", optional=True
description="List of token information for each chunk",
)
def __init__(self):

View File

@@ -1,4 +1,4 @@
from groq._utils._utils import quote
from urllib.parse import quote
from backend.blocks.jina._auth import (
TEST_CREDENTIALS,

View File

@@ -28,8 +28,8 @@ class LinearCreateIssueBlock(Block):
priority: int | None = SchemaField(
description="Priority of the issue",
default=None,
minimum=0,
maximum=4,
ge=0,
le=4,
)
project_name: str | None = SchemaField(
description="Name of the project to create the issue on",

View File

@@ -4,30 +4,25 @@ from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional
from pydantic import BaseModel, SecretStr
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
from enum import _EnumMemberT
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
import anthropic
import ollama
import openai
from anthropic._types import NotGiven
from anthropic import NotGiven
from anthropic.types import ToolParam
from groq import Groq
from pydantic import BaseModel, SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.settings import BehaveAs, Settings
from backend.util.text import TextFormatter
@@ -77,12 +72,10 @@ class ModelMetadata(NamedTuple):
class LlmModelMeta(EnumMeta):
@property
def __members__(
self: type["_EnumMemberT"],
) -> MappingProxyType[str, "_EnumMemberT"]:
def __members__(self) -> MappingProxyType:
if Settings().config.behave_as == BehaveAs.LOCAL:
members = super().__members__
return members
return MappingProxyType(members)
else:
removed_providers = ["ollama"]
existing_members = super().__members__
@@ -424,7 +417,7 @@ def llm_call(
response=(
resp.content[0].name
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
else resp.content[0].text
else getattr(resp.content[0], "text", "")
),
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
@@ -528,7 +521,7 @@ def llm_call(
class AIBlockBase(Block, ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt = ""
self.prompt = []
def merge_llm_stats(self, block: "AIBlockBase"):
self.merge_stats(block.execution_stats)
@@ -587,7 +580,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
response: dict[str, Any] = SchemaField(
description="The response object generated by the language model."
)
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -609,7 +602,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", {"key1": "key1Value", "key2": "key2Value"}),
("prompt", str),
("prompt", list),
],
test_mock={
"llm_call": lambda *args, **kwargs: LLMResponse(
@@ -642,6 +635,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
Test mocks work only on class functions, this wraps the llm_call function
so that it can be mocked within the block testing framework.
"""
self.prompt = prompt
return llm_call(
credentials=credentials,
llm_model=llm_model,
@@ -814,7 +808,7 @@ class AITextGeneratorBlock(AIBlockBase):
response: str = SchemaField(
description="The response generated by the language model."
)
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -831,7 +825,7 @@ class AITextGeneratorBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", "Response text"),
("prompt", str),
("prompt", list),
],
test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
)
@@ -850,7 +844,10 @@ class AITextGeneratorBlock(AIBlockBase):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
object_input_data = AIStructuredResponseGeneratorBlock.Input(
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
**{
attr: getattr(input_data, attr)
for attr in AITextGeneratorBlock.Input.model_fields
},
expected_format={},
)
yield "response", self.llm_call(object_input_data, credentials)
@@ -907,7 +904,7 @@ class AITextSummarizerBlock(AIBlockBase):
class Output(BlockSchema):
summary: str = SchemaField(description="The final summary of the text.")
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -924,7 +921,7 @@ class AITextSummarizerBlock(AIBlockBase):
test_credentials=TEST_CREDENTIALS,
test_output=[
("summary", "Final summary of a long text"),
("prompt", str),
("prompt", list),
],
test_mock={
"llm_call": lambda input_data, credentials: (
@@ -1033,8 +1030,14 @@ class AITextSummarizerBlock(AIBlockBase):
class AIConversationBlock(AIBlockBase):
class Input(BlockSchema):
prompt: str = SchemaField(
description="The prompt to send to the language model.",
placeholder="Enter your prompt here...",
default="",
advanced=False,
)
messages: List[Any] = SchemaField(
description="List of messages in the conversation.", min_length=1
description="List of messages in the conversation.",
)
model: LlmModel = SchemaField(
title="LLM Model",
@@ -1057,7 +1060,7 @@ class AIConversationBlock(AIBlockBase):
response: str = SchemaField(
description="The model's response to the conversation."
)
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
@@ -1086,7 +1089,7 @@ class AIConversationBlock(AIBlockBase):
"response",
"The 2020 World Series was played at Globe Life Field in Arlington, Texas.",
),
("prompt", str),
("prompt", list),
],
test_mock={
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
@@ -1108,7 +1111,7 @@ class AIConversationBlock(AIBlockBase):
) -> BlockOutput:
response = self.llm_call(
AIStructuredResponseGeneratorBlock.Input(
prompt="",
prompt=input_data.prompt,
credentials=input_data.credentials,
model=input_data.model,
conversation_history=input_data.messages,
@@ -1166,7 +1169,7 @@ class AIListGeneratorBlock(AIBlockBase):
list_item: str = SchemaField(
description="Each individual item in the list.",
)
prompt: str = SchemaField(description="The prompt sent to the language model.")
prompt: list = SchemaField(description="The prompt sent to the language model.")
error: str = SchemaField(
description="Error message if the list generation failed."
)
@@ -1198,7 +1201,7 @@ class AIListGeneratorBlock(AIBlockBase):
"generated_list",
["Zylora Prime", "Kharon-9", "Vortexia", "Oceara", "Draknos"],
),
("prompt", str),
("prompt", list),
("list_item", "Zylora Prime"),
("list_item", "Kharon-9"),
("list_item", "Vortexia"),

View File

@@ -39,7 +39,6 @@ class TwitterGetListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to lookup",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
@@ -184,7 +183,6 @@ class TwitterGetOwnedListsBlock(Block):
user_id: str = SchemaField(
description="The user ID whose owned Lists to retrieve",
placeholder="Enter user ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -45,13 +45,11 @@ class TwitterRemoveListMemberBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to remove the member from",
placeholder="Enter list ID",
required=True,
)
user_id: str = SchemaField(
description="The ID of the user to remove from the List",
placeholder="Enter user ID to remove",
required=True,
)
class Output(BlockSchema):
@@ -120,13 +118,11 @@ class TwitterAddListMemberBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to add the member to",
placeholder="Enter list ID",
required=True,
)
user_id: str = SchemaField(
description="The ID of the user to add to the List",
placeholder="Enter user ID to add",
required=True,
)
class Output(BlockSchema):
@@ -195,7 +191,6 @@ class TwitterGetListMembersBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to get members from",
placeholder="Enter list ID",
required=True,
)
max_results: int | None = SchemaField(
@@ -376,7 +371,6 @@ class TwitterGetListMembershipsBlock(Block):
user_id: str = SchemaField(
description="The ID of the user whose List memberships to retrieve",
placeholder="Enter user ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -42,7 +42,6 @@ class TwitterGetListTweetsBlock(Block):
list_id: str = SchemaField(
description="The ID of the List whose Tweets you would like to retrieve",
placeholder="Enter list ID",
required=True,
)
max_results: int | None = SchemaField(

View File

@@ -28,7 +28,6 @@ class TwitterDeleteListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to be deleted",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -39,7 +39,6 @@ class TwitterUnpinListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to unpin",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
@@ -103,7 +102,6 @@ class TwitterPinListBlock(Block):
list_id: str = SchemaField(
description="The ID of the List to pin",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -227,7 +227,6 @@ class TwitterGetSpaceByIdBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
@@ -389,7 +388,6 @@ class TwitterGetSpaceBuyersBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup buyers for",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
@@ -517,7 +515,6 @@ class TwitterGetSpaceTweetsBlock(Block):
space_id: str = SchemaField(
description="Space ID to lookup tweets for",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):

View File

@@ -12,12 +12,12 @@ async def log_raw_analytics(
data_index: str,
):
details = await prisma.models.AnalyticsDetails.prisma().create(
data={
"userId": user_id,
"type": type,
"data": prisma.Json(data),
"dataIndex": data_index,
}
data=prisma.types.AnalyticsDetailsCreateInput(
userId=user_id,
type=type,
data=prisma.Json(data),
dataIndex=data_index,
)
)
return details
@@ -32,12 +32,12 @@ async def log_raw_metric(
raise ValueError("metric_value must be non-negative")
result = await prisma.models.AnalyticsMetrics.prisma().create(
data={
"value": metric_value,
"analyticMetric": metric_name,
"userId": user_id,
"dataString": data_string,
},
data=prisma.types.AnalyticsMetricsCreateInput(
value=metric_value,
analyticMetric=metric_name,
userId=user_id,
dataString=data_string,
)
)
return result

View File

@@ -17,6 +17,7 @@ from typing import (
import jsonref
import jsonschema
from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
@@ -480,12 +481,12 @@ async def initialize_blocks() -> None:
)
if not existing_block:
await AgentBlock.prisma().create(
data={
"id": block.id,
"name": block.name,
"inputSchema": json.dumps(block.input_schema.jsonschema()),
"outputSchema": json.dumps(block.output_schema.jsonschema()),
}
data=AgentBlockCreateInput(
id=block.id,
name=block.name,
inputSchema=json.dumps(block.input_schema.jsonschema()),
outputSchema=json.dumps(block.output_schema.jsonschema()),
)
)
continue

View File

@@ -14,7 +14,11 @@ from prisma.enums import (
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from tenacity import retry, stop_after_attempt, wait_exponential
from backend.data import db
@@ -331,15 +335,15 @@ class UserCreditBase(ABC):
amount = min(-user_balance, 0)
# Create the transaction
transaction_data: CreditTransactionCreateInput = {
"userId": user_id,
"amount": amount,
"runningBalance": user_balance + amount,
"type": transaction_type,
"metadata": metadata,
"isActive": is_active,
"createdAt": self.time_now(),
}
transaction_data = CreditTransactionCreateInput(
userId=user_id,
amount=amount,
runningBalance=user_balance + amount,
type=transaction_type,
metadata=metadata,
isActive=is_active,
createdAt=self.time_now(),
)
if transaction_key:
transaction_data["transactionKey"] = transaction_key
tx = await CreditTransaction.prisma().create(data=transaction_data)
@@ -422,15 +426,15 @@ class UserCredit(UserCreditBase):
try:
refund_request = await CreditRefundRequest.prisma().create(
data={
"id": refund_key,
"transactionKey": transaction_key,
"userId": user_id,
"amount": amount,
"reason": metadata.get("reason", ""),
"status": CreditRefundRequestStatus.PENDING,
"result": "The refund request is under review.",
}
data=CreditRefundRequestCreateInput(
id=refund_key,
transactionKey=transaction_key,
userId=user_id,
amount=amount,
reason=metadata.get("reason", ""),
status=CreditRefundRequestStatus.PENDING,
result="The refund request is under review.",
)
)
except UniqueViolationError:
raise ValueError(

View File

@@ -24,6 +24,8 @@ from prisma.models import (
)
from prisma.types import (
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
AgentNodeExecutionInputOutputCreateInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
@@ -167,7 +169,7 @@ class GraphExecution(GraphExecutionMeta):
return GraphExecution(
**{
field_name: getattr(graph_exec, field_name)
for field_name in graph_exec.model_fields
for field_name in GraphExecutionMeta.model_fields
},
inputs=inputs,
outputs=outputs,
@@ -195,7 +197,7 @@ class GraphExecutionWithNodes(GraphExecution):
return GraphExecutionWithNodes(
**{
field_name: getattr(graph_exec_with_io, field_name)
for field_name in graph_exec_with_io.model_fields
for field_name in GraphExecution.model_fields
},
node_executions=node_executions,
)
@@ -418,11 +420,11 @@ async def upsert_execution_input(
if existing_execution:
await AgentNodeExecutionInputOutput.prisma().create(
data={
"name": input_name,
"data": json_input_data,
"referencedByInputExecId": existing_execution.id,
}
data=AgentNodeExecutionInputOutputCreateInput(
name=input_name,
data=json_input_data,
referencedByInputExecId=existing_execution.id,
)
)
return existing_execution.id, {
**{
@@ -434,12 +436,12 @@ async def upsert_execution_input(
elif not node_exec_id:
result = await AgentNodeExecution.prisma().create(
data={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_exec_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {"create": {"name": input_name, "data": json_input_data}},
}
data=AgentNodeExecutionCreateInput(
agentNodeId=node_id,
agentGraphExecutionId=graph_exec_id,
executionStatus=ExecutionStatus.INCOMPLETE,
Input={"create": {"name": input_name, "data": json_input_data}},
)
)
return result.id, {input_name: input_data}
@@ -458,11 +460,11 @@ async def upsert_execution_output(
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
"""
await AgentNodeExecutionInputOutput.prisma().create(
data={
"name": output_name,
"data": Json(output_data),
"referencedByOutputExecId": node_exec_id,
}
data=AgentNodeExecutionInputOutputCreateInput(
name=output_name,
data=Json(output_data),
referencedByOutputExecId=node_exec_id,
)
)

View File

@@ -7,7 +7,12 @@ import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
from prisma.types import AgentGraphWhereInput
from prisma.types import (
AgentGraphCreateInput,
AgentGraphWhereInput,
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
)
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
@@ -737,28 +742,28 @@ async def __create_graph(tx, graph: Graph, user_id: str):
await AgentGraph.prisma(tx).create_many(
data=[
{
"id": graph.id,
"version": graph.version,
"name": graph.name,
"description": graph.description,
"isActive": graph.is_active,
"userId": user_id,
}
AgentGraphCreateInput(
id=graph.id,
version=graph.version,
name=graph.name,
description=graph.description,
isActive=graph.is_active,
userId=user_id,
)
for graph in graphs
]
)
await AgentNode.prisma(tx).create_many(
data=[
{
"id": node.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"agentBlockId": node.block_id,
"constantInput": Json(node.input_default),
"metadata": Json(node.metadata),
}
AgentNodeCreateInput(
id=node.id,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
agentBlockId=node.block_id,
constantInput=Json(node.input_default),
metadata=Json(node.metadata),
)
for graph in graphs
for node in graph.nodes
]
@@ -766,14 +771,14 @@ async def __create_graph(tx, graph: Graph, user_id: str):
await AgentNodeLink.prisma(tx).create_many(
data=[
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
"sinkName": link.sink_name,
"agentNodeSourceId": link.source_id,
"agentNodeSinkId": link.sink_id,
"isStatic": link.is_static,
}
AgentNodeLinkCreateInput(
id=str(uuid.uuid4()),
sourceName=link.source_name,
sinkName=link.sink_name,
agentNodeSourceId=link.source_id,
agentNodeSinkId=link.sink_id,
isStatic=link.is_static,
)
for graph in graphs
for link in graph.links
]

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional
from prisma import Json
from prisma.models import IntegrationWebhook
from prisma.types import IntegrationWebhookCreateInput
from pydantic import Field, computed_field
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
@@ -66,18 +67,18 @@ class Webhook(BaseDbModel):
async def create_webhook(webhook: Webhook) -> Webhook:
created_webhook = await IntegrationWebhook.prisma().create(
data={
"id": webhook.id,
"userId": webhook.user_id,
"provider": webhook.provider.value,
"credentialsId": webhook.credentials_id,
"webhookType": webhook.webhook_type,
"resource": webhook.resource,
"events": webhook.events,
"config": Json(webhook.config),
"secret": webhook.secret,
"providerWebhookId": webhook.provider_webhook_id,
}
data=IntegrationWebhookCreateInput(
id=webhook.id,
userId=webhook.user_id,
provider=webhook.provider.value,
credentialsId=webhook.credentials_id,
webhookType=webhook.webhook_type,
resource=webhook.resource,
events=webhook.events,
config=Json(webhook.config),
secret=webhook.secret,
providerWebhookId=webhook.provider_webhook_id,
)
)
return Webhook.from_db(created_webhook)

View File

@@ -142,8 +142,12 @@ def SchemaField(
exclude: bool = False,
hidden: Optional[bool] = None,
depends_on: Optional[list[str]] = None,
ge: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
discriminator: Optional[str] = None,
json_schema_extra: Optional[dict[str, Any]] = None,
**kwargs,
) -> T:
if default is PydanticUndefined and default_factory is None:
advanced = False
@@ -170,8 +174,12 @@ def SchemaField(
title=title,
description=description,
exclude=exclude,
ge=ge,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
json_schema_extra=json_schema_extra,
**kwargs,
) # type: ignore
@@ -405,9 +413,10 @@ class RefundRequest(BaseModel):
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
class Config:
arbitrary_types_allowed = True
extra = "allow"
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
error: Optional[Exception | str] = None
walltime: float = 0
@@ -423,9 +432,10 @@ class NodeExecutionStats(BaseModel):
class GraphExecutionStats(BaseModel):
"""Execution statistics for a graph execution."""
class Config:
arbitrary_types_allowed = True
extra = "allow"
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
error: Optional[Exception | str] = None
walltime: float = Field(

View File

@@ -6,10 +6,14 @@ from typing import Annotated, Any, Generic, Optional, TypeVar, Union
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import NotificationEvent, UserNotificationBatch
from prisma.types import UserNotificationBatchWhereInput
from prisma.types import (
NotificationEventCreateInput,
UserNotificationBatchCreateInput,
UserNotificationBatchWhereInput,
)
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, EmailStr, Field, field_validator
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from backend.server.v2.store.exceptions import DatabaseError
@@ -35,8 +39,7 @@ class QueueType(Enum):
class BaseNotificationData(BaseModel):
class Config:
extra = "allow"
model_config = ConfigDict(extra="allow")
class AgentRunData(BaseNotificationData):
@@ -418,30 +421,30 @@ async def create_or_add_to_user_notification_batch(
if not existing_batch:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data={
"type": notification_type,
"data": json_data,
}
data=NotificationEventCreateInput(
type=notification_type,
data=json_data,
)
)
# Create new batch
resp = await tx.usernotificationbatch.create(
data={
"userId": user_id,
"type": notification_type,
"Notifications": {"connect": [{"id": notification_event.id}]},
},
data=UserNotificationBatchCreateInput(
userId=user_id,
type=notification_type,
Notifications={"connect": [{"id": notification_event.id}]},
),
include={"Notifications": True},
)
return UserNotificationBatchDTO.from_db(resp)
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data={
"type": notification_type,
"data": json_data,
"UserNotificationBatch": {"connect": {"id": existing_batch.id}},
}
data=NotificationEventCreateInput(
type=notification_type,
data=json_data,
UserNotificationBatch={"connect": {"id": existing_batch.id}},
)
)
# Add to existing batch
resp = await tx.usernotificationbatch.update(

View File

@@ -11,7 +11,7 @@ from fastapi import HTTPException
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import User
from prisma.types import UserUpdateInput
from prisma.types import UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
@@ -36,11 +36,11 @@ async def get_or_create_user(user_data: dict) -> User:
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
user = await prisma.user.create(
data={
"id": user_id,
"email": user_email,
"name": user_data.get("user_metadata", {}).get("name"),
}
data=UserCreateInput(
id=user_id,
email=user_email,
name=user_data.get("user_metadata", {}).get("name"),
)
)
return User.model_validate(user)
@@ -84,11 +84,11 @@ async def create_default_user() -> Optional[User]:
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
if not user:
user = await prisma.user.create(
data={
"id": DEFAULT_USER_ID,
"email": "default@example.com",
"name": "Default User",
}
data=UserCreateInput(
id=DEFAULT_USER_ID,
email="default@example.com",
name="Default User",
)
)
return User.model_validate(user)

View File

@@ -10,6 +10,7 @@ from backend.data import redis
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings
@@ -153,7 +154,8 @@ class IntegrationCredentialsManager:
self.store.locks.release_all_locks()
def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandler":
provider_name = ProviderName(provider_name_str)
if provider_name not in HANDLERS_BY_NAME:
raise KeyError(f"Unknown provider '{provider_name}'")

View File

@@ -3,7 +3,6 @@ import logging
from typing import Any, Optional
import autogpt_libs.auth.models
from autogpt_libs.logging.utils import generate_uvicorn_config
import fastapi
import fastapi.responses
import starlette.middleware.cors
@@ -12,6 +11,7 @@ from autogpt_libs.feature_flag.client import (
initialize_launchdarkly,
shutdown_launchdarkly,
)
from autogpt_libs.logging.utils import generate_uvicorn_config
import backend.data.block
import backend.data.db

View File

@@ -6,6 +6,7 @@ import prisma.errors
import prisma.fields
import prisma.models
import prisma.types
from prisma.types import AgentPresetCreateInput
import backend.data.graph
import backend.server.model
@@ -228,16 +229,16 @@ async def create_library_agent(
try:
return await prisma.models.LibraryAgent.prisma().create(
data={
"isCreatedByUser": (user_id == graph.user_id),
"useGraphIsActiveVersion": True,
"User": {"connect": {"id": user_id}},
"Agent": {
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
Agent={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
}
)
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating agent in library: {e}")
@@ -418,12 +419,12 @@ async def add_store_agent_to_library(
# Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create(
data={
"userId": user_id,
"agentId": graph.id,
"agentVersion": graph.version,
"isCreatedByUser": False,
},
data=prisma.types.LibraryAgentCreateInput(
userId=user_id,
agentId=graph.id,
agentVersion=graph.version,
isCreatedByUser=False,
),
include=library_agent_include(user_id),
)
logger.debug(
@@ -601,17 +602,17 @@ async def upsert_preset(
# Update existing preset
updated = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data={
"name": preset.name,
"description": preset.description,
"isActive": preset.is_active,
"InputPresets": {
data=AgentPresetCreateInput(
name=preset.name,
description=preset.description,
isActive=preset.is_active,
InputPresets={
"create": [
{"name": name, "data": prisma.fields.Json(data)}
for name, data in preset.inputs.items()
]
},
},
),
include={"InputPresets": True},
)
if not updated:
@@ -620,20 +621,20 @@ async def upsert_preset(
else:
# Create new preset
new_preset = await prisma.models.AgentPreset.prisma().create(
data={
"userId": user_id,
"name": preset.name,
"description": preset.description,
"agentId": preset.agent_id,
"agentVersion": preset.agent_version,
"isActive": preset.is_active,
"InputPresets": {
data=prisma.types.AgentPresetCreateInput(
userId=user_id,
name=preset.name,
description=preset.description,
agentId=preset.agent_id,
agentVersion=preset.agent_version,
isActive=preset.is_active,
InputPresets={
"create": [
{"name": name, "data": prisma.fields.Json(data)}
for name, data in preset.inputs.items()
]
},
},
),
include={"InputPresets": True},
)
return library_model.LibraryAgentPreset.from_db(new_preset)

View File

@@ -81,7 +81,7 @@ async def test_get_library_agents(mocker):
assert result.pagination.page_size == 50
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_add_agent_to_library(mocker):
await connect()
# Mock data
@@ -165,7 +165,7 @@ async def test_add_agent_to_library(mocker):
)
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_add_agent_to_library_not_found(mocker):
await connect()
# Mock prisma calls

View File

@@ -3,9 +3,9 @@ import logging
from contextlib import asynccontextmanager
from typing import Protocol
from autogpt_libs.logging.utils import generate_uvicorn_config
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.logging.utils import generate_uvicorn_config
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware

View File

@@ -142,7 +142,7 @@ def validate_url(
# Resolve all IP addresses for the hostname
try:
ip_list = [res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)]
ip_list = [str(res[4][0]) for res in socket.getaddrinfo(ascii_hostname, None)]
ipv4 = [ip for ip in ip_list if ":" not in ip]
ipv6 = [ip for ip in ip_list if ":" in ip]
ip_addresses = ipv4 + ipv6 # Prefer IPv4 over IPv6

View File

@@ -324,14 +324,16 @@ class FastApiAppService(BaseAppService, ABC):
async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
return await f(
**{name: getattr(body, name) for name in body.model_fields}
**{name: getattr(body, name) for name in type(body).model_fields}
)
return async_endpoint
else:
def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
return f(**{name: getattr(body, name) for name in body.model_fields})
return f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
return sync_endpoint

View File

@@ -31,12 +31,12 @@ class UpdateTrackingModel(BaseModel, Generic[T]):
_updated_fields: Set[str] = PrivateAttr(default_factory=set)
def __setattr__(self, name: str, value) -> None:
if name in self.model_fields:
if name in UpdateTrackingModel.model_fields:
self._updated_fields.add(name)
super().__setattr__(name, value)
def mark_updated(self, field_name: str) -> None:
if field_name in self.model_fields:
if field_name in UpdateTrackingModel.model_fields:
self._updated_fields.add(field_name)
def clear_updates(self) -> None:

View File

@@ -52,7 +52,7 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
return cost
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_block_credit_usage(server: SpinTestServer):
await disable_test_user_transactions()
await top_up(100)
@@ -95,7 +95,7 @@ async def test_block_credit_usage(server: SpinTestServer):
assert new_credit == current_credit - spending_amount_1 - spending_amount_2
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_block_credit_top_up(server: SpinTestServer):
await disable_test_user_transactions()
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
@@ -106,7 +106,7 @@ async def test_block_credit_top_up(server: SpinTestServer):
assert new_credit == current_credit + 100
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_block_credit_reset(server: SpinTestServer):
await disable_test_user_transactions()
month1 = 1
@@ -133,7 +133,7 @@ async def test_block_credit_reset(server: SpinTestServer):
assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_credit_refill(server: SpinTestServer):
await disable_test_user_transactions()
balance = await user_credit.get_credits(DEFAULT_USER_ID)

View File

@@ -17,7 +17,7 @@ from backend.usecases.sample import create_test_user
from backend.util.test import SpinTestServer
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_creation(server: SpinTestServer):
"""
Test the creation of a graph with nodes and links.
@@ -71,7 +71,7 @@ async def test_graph_creation(server: SpinTestServer):
assert links[0].sink_id in {nodes[0].id, nodes[1].id, nodes[2].id}
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_get_input_schema(server: SpinTestServer):
"""
Test the get_input_schema method of a created graph.
@@ -167,7 +167,7 @@ async def test_get_input_schema(server: SpinTestServer):
assert output_schema == ExpectedOutputSchema.jsonschema()
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_clean_graph(server: SpinTestServer):
"""
Test the clean_graph function that:
@@ -211,7 +211,7 @@ async def test_clean_graph(server: SpinTestServer):
assert input_node.input_default["value"] == ""
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_access_store_listing_graph(server: SpinTestServer):
"""
Test the access of a store listing graph.

View File

@@ -127,7 +127,7 @@ async def assert_sample_graph_executions(
assert exec.node_id == test_graph.nodes[3].id
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_agent_execution(server: SpinTestServer):
logger.info("Starting test_agent_execution")
test_user = await create_test_user()
@@ -144,7 +144,7 @@ async def test_agent_execution(server: SpinTestServer):
logger.info("Completed test_agent_execution")
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_input_pin_always_waited(server: SpinTestServer):
"""
This test is asserting that the input pin should always be waited for the execution,
@@ -211,7 +211,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
logger.info("Completed test_input_pin_always_waited")
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_static_input_link_on_graph(server: SpinTestServer):
"""
This test is asserting the behaviour of static input link, e.g: reusable input link.
@@ -296,7 +296,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
logger.info("Completed test_static_input_link_on_graph")
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_preset(server: SpinTestServer):
"""
Test executing a preset.
@@ -392,7 +392,7 @@ async def test_execute_preset(server: SpinTestServer):
assert executions[3].output_data == {"output": ["World"]}
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_execute_preset_with_clash(server: SpinTestServer):
"""
Test executing a preset with clashing input data.
@@ -482,7 +482,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
assert executions[3].output_data == {"output": ["Hello"]}
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_store_listing_graph(server: SpinTestServer):
logger.info("Starting test_agent_execution")
test_user = await create_test_user()

View File

@@ -8,7 +8,7 @@ from backend.util.service import get_service_client
from backend.util.test import SpinTestServer
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_agent_schedule(server: SpinTestServer):
await db.connect()
test_user = await create_test_user()

View File

@@ -61,7 +61,7 @@ async def execute_graph(
@pytest.mark.skip()
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
test_user = await create_test_user()
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
@@ -111,7 +111,7 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
@pytest.mark.skip()
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer):
test_user = await create_test_user()
@@ -172,7 +172,7 @@ async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestSer
@pytest.mark.skip()
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
test_user = await create_test_user()
test_tool_graph = await create_graph(server, create_test_graph(), test_user)

View File

@@ -5,6 +5,23 @@ from datetime import datetime
import prisma.enums
from faker import Faker
from prisma import Json, Prisma
from prisma.types import (
AgentBlockCreateInput,
AgentGraphCreateInput,
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
AgentPresetCreateInput,
AnalyticsDetailsCreateInput,
AnalyticsMetricsCreateInput,
APIKeyCreateInput,
CreditTransactionCreateInput,
LibraryAgentCreateInput,
ProfileCreateInput,
StoreListingCreateInput,
StoreListingReviewCreateInput,
StoreListingVersionCreateInput,
UserCreateInput,
)
faker = Faker()
@@ -55,13 +72,13 @@ async def main():
users = []
for _ in range(NUM_USERS):
user = await db.user.create(
data={
"id": str(faker.uuid4()),
"email": faker.unique.email(),
"name": faker.name(),
"metadata": prisma.Json({}),
"integrations": "",
}
data=UserCreateInput(
id=str(faker.uuid4()),
email=faker.unique.email(),
name=faker.name(),
metadata=prisma.Json({}),
integrations="",
)
)
users.append(user)
@@ -70,11 +87,11 @@ async def main():
print(f"Inserting {NUM_AGENT_BLOCKS} agent blocks")
for _ in range(NUM_AGENT_BLOCKS):
block = await db.agentblock.create(
data={
"name": f"{faker.word()}_{str(faker.uuid4())[:8]}",
"inputSchema": "{}",
"outputSchema": "{}",
}
data=AgentBlockCreateInput(
name=f"{faker.word()}_{str(faker.uuid4())[:8]}",
inputSchema="{}",
outputSchema="{}",
)
)
agent_blocks.append(block)
@@ -86,12 +103,12 @@ async def main():
random.randint(MIN_GRAPHS_PER_USER, MAX_GRAPHS_PER_USER)
): # Adjust the range to create more graphs per user if desired
graph = await db.agentgraph.create(
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"isActive": True,
}
data=AgentGraphCreateInput(
name=faker.sentence(nb_words=3),
description=faker.text(max_nb_chars=200),
userId=user.id,
isActive=True,
)
)
agent_graphs.append(graph)
@@ -105,13 +122,13 @@ async def main():
for _ in range(num_nodes): # Create 5 AgentNodes per graph
block = random.choice(agent_blocks)
node = await db.agentnode.create(
data={
"agentBlockId": block.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"constantInput": Json({}),
"metadata": Json({}),
}
data=AgentNodeCreateInput(
agentBlockId=block.id,
agentGraphId=graph.id,
agentGraphVersion=graph.version,
constantInput=Json({}),
metadata=Json({}),
)
)
agent_nodes.append(node)
@@ -123,14 +140,14 @@ async def main():
for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs)
preset = await db.agentpreset.create(
data={
"name": faker.sentence(nb_words=3),
"description": faker.text(max_nb_chars=200),
"userId": user.id,
"agentId": graph.id,
"agentVersion": graph.version,
"isActive": True,
}
data=AgentPresetCreateInput(
name=faker.sentence(nb_words=3),
description=faker.text(max_nb_chars=200),
userId=user.id,
agentId=graph.id,
agentVersion=graph.version,
isActive=True,
)
)
agent_presets.append(preset)
@@ -143,20 +160,19 @@ async def main():
graph = random.choice(agent_graphs)
preset = random.choice(agent_presets)
user_agent = await db.libraryagent.create(
data={
"userId": user.id,
"agentId": graph.id,
"agentVersion": graph.version,
"agentPresetId": preset.id,
"isFavorite": random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]),
"isArchived": random.choice([True, False]),
"isDeleted": random.choice([True, False]),
}
data=LibraryAgentCreateInput(
userId=user.id,
agentId=graph.id,
agentVersion=graph.version,
agentPresetId=preset.id,
isFavorite=random.choice([True, False]),
isCreatedByUser=random.choice([True, False]),
isArchived=random.choice([True, False]),
isDeleted=random.choice([True, False]),
)
)
user_agents.append(user_agent)
# Insert AgentGraphExecutions
# Insert AgentGraphExecutions
agent_graph_executions = []
print(
@@ -253,13 +269,13 @@ async def main():
source_node = nodes[0]
sink_node = nodes[1]
await db.agentnodelink.create(
data={
"agentNodeSourceId": source_node.id,
"sourceName": "output1",
"agentNodeSinkId": sink_node.id,
"sinkName": "input1",
"isStatic": False,
}
data=AgentNodeLinkCreateInput(
agentNodeSourceId=source_node.id,
sourceName="output1",
agentNodeSinkId=sink_node.id,
sinkName="input1",
isStatic=False,
)
)
# Insert AnalyticsDetails
@@ -267,12 +283,12 @@ async def main():
for user in users:
for _ in range(1):
await db.analyticsdetails.create(
data={
"userId": user.id,
"type": faker.word(),
"data": prisma.Json({}),
"dataIndex": faker.word(),
}
data=AnalyticsDetailsCreateInput(
userId=user.id,
type=faker.word(),
data=prisma.Json({}),
dataIndex=faker.word(),
)
)
# Insert AnalyticsMetrics
@@ -280,12 +296,12 @@ async def main():
for user in users:
for _ in range(1):
await db.analyticsmetrics.create(
data={
"userId": user.id,
"analyticMetric": faker.word(),
"value": random.uniform(0, 100),
"dataString": faker.word(),
}
data=AnalyticsMetricsCreateInput(
userId=user.id,
analyticMetric=faker.word(),
value=random.uniform(0, 100),
dataString=faker.word(),
)
)
# Insert CreditTransaction (formerly UserBlockCredit)
@@ -294,17 +310,17 @@ async def main():
for _ in range(1):
block = random.choice(agent_blocks)
await db.credittransaction.create(
data={
"transactionKey": str(faker.uuid4()),
"userId": user.id,
"amount": random.randint(1, 100),
"type": (
data=CreditTransactionCreateInput(
transactionKey=str(faker.uuid4()),
userId=user.id,
amount=random.randint(1, 100),
type=(
prisma.enums.CreditTransactionType.TOP_UP
if random.random() < 0.5
else prisma.enums.CreditTransactionType.USAGE
),
"metadata": prisma.Json({}),
}
metadata=prisma.Json({}),
)
)
# Insert Profiles
@@ -312,14 +328,14 @@ async def main():
print(f"Inserting {NUM_USERS} profiles")
for user in users:
profile = await db.profile.create(
data={
"userId": user.id,
"name": user.name or faker.name(),
"username": faker.unique.user_name(),
"description": faker.text(),
"links": [faker.url() for _ in range(3)],
"avatarUrl": get_image(),
}
data=ProfileCreateInput(
userId=user.id,
name=user.name or faker.name(),
username=faker.unique.user_name(),
description=faker.text(),
links=[faker.url() for _ in range(3)],
avatarUrl=get_image(),
)
)
profiles.append(profile)
@@ -330,13 +346,13 @@ async def main():
user = random.choice(users)
slug = faker.slug()
listing = await db.storelisting.create(
data={
"agentId": graph.id,
"agentVersion": graph.version,
"owningUserId": user.id,
"hasApprovedVersion": random.choice([True, False]),
"slug": slug,
}
data=StoreListingCreateInput(
agentId=graph.id,
agentVersion=graph.version,
owningUserId=user.id,
hasApprovedVersion=random.choice([True, False]),
slug=slug,
)
)
store_listings.append(listing)
@@ -346,26 +362,26 @@ async def main():
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentId][0]
version = await db.storelistingversion.create(
data={
"agentId": graph.id,
"agentVersion": graph.version,
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": faker.url(),
"imageUrls": [get_image() for _ in range(3)],
"description": faker.text(),
"categories": [faker.word() for _ in range(3)],
"isFeatured": random.choice([True, False]),
"isAvailable": True,
"storeListingId": listing.id,
"submissionStatus": random.choice(
data=StoreListingVersionCreateInput(
agentId=graph.id,
agentVersion=graph.version,
name=graph.name or faker.sentence(nb_words=3),
subHeading=faker.sentence(),
videoUrl=faker.url(),
imageUrls=[get_image() for _ in range(3)],
description=faker.text(),
categories=[faker.word() for _ in range(3)],
isFeatured=random.choice([True, False]),
isAvailable=True,
storeListingId=listing.id,
submissionStatus=random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
),
}
)
)
store_listing_versions.append(version)
@@ -385,12 +401,12 @@ async def main():
# Take only the first num_reviews reviewers
for reviewer in available_reviewers[:num_reviews]:
await db.storelistingreview.create(
data={
"storeListingVersionId": version.id,
"reviewByUserId": reviewer.id,
"score": random.randint(1, 5),
"comments": faker.text(),
}
data=StoreListingReviewCreateInput(
storeListingVersionId=version.id,
reviewByUserId=reviewer.id,
score=random.randint(1, 5),
comments=faker.text(),
)
)
# Update StoreListingVersions with submission status (StoreListingSubmissions table no longer exists)
@@ -406,31 +422,42 @@ async def main():
)
await db.storelistingversion.update(
where={"id": version.id},
data={
"submissionStatus": status,
"Reviewer": {"connect": {"id": reviewer.id}},
"reviewComments": faker.text(),
"reviewedAt": datetime.now(),
},
data=StoreListingVersionCreateInput(
submissionStatus=status,
Reviewer={"connect": {"id": reviewer.id}},
reviewComments=faker.text(),
reviewedAt=datetime.now(),
agentId=version.agentId, # preserving existing fields
agentVersion=version.agentVersion,
name=version.name,
subHeading=version.subHeading,
videoUrl=version.videoUrl,
imageUrls=version.imageUrls,
description=version.description,
categories=version.categories,
isFeatured=version.isFeatured,
isAvailable=version.isAvailable,
storeListingId=version.storeListingId,
),
)
# Insert APIKeys
print(f"Inserting {NUM_USERS} api keys")
for user in users:
await db.apikey.create(
data={
"name": faker.word(),
"prefix": str(faker.uuid4())[:8],
"postfix": str(faker.uuid4())[-8:],
"key": str(faker.sha256()),
"status": prisma.enums.APIKeyStatus.ACTIVE,
"permissions": [
data=APIKeyCreateInput(
name=faker.word(),
prefix=str(faker.uuid4())[:8],
postfix=str(faker.uuid4())[-8:],
key=str(faker.sha256()),
status=prisma.enums.APIKeyStatus.ACTIVE,
permissions=[
prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
prisma.enums.APIKeyPermission.READ_GRAPH,
],
"description": faker.text(),
"userId": user.id,
}
description=faker.text(),
userId=user.id,
)
)
await db.disconnect()

View File

@@ -29,7 +29,7 @@ class ServiceTest(AppService):
return self.run_and_wait(add_async(a, b))
@pytest.mark.asyncio(scope="session")
@pytest.mark.asyncio(loop_scope="session")
async def test_service_creation(server):
with ServiceTest():
client = get_service_client(ServiceTest)