mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(rnd): Add initial credit accounting system for block execution (#8047)
### Background
We need a way to set an execution quota per user for each block execution.
### Changes 🏗️
* Introduced `UserBlockCredit`, a transaction table tracking block usage along with its cost/quota.
* The tracking is toggled by `ENABLE_CREDIT` config, default = false.
* Introduced `BLOCK_COSTS` | `GET /blocks/costs` as a source of information for the cost on each block depending on the input configuration.
Improvements:
* Refactor logging in manager.py to always print a prefix and pass the metadata.
* Make `executionStatus` on `AgentNodeExecution` a Prisma enum, and add `executionStatus` to `AgentGraphExecution`.
* Use executionStatus from AgentGraphExecution to improve waiting logic on test_manager.py.
This commit is contained in:
@@ -9,7 +9,8 @@ REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
AUTH_ENABLED=false
|
||||
ENABLE_AUTH=false
|
||||
ENABLE_CREDIT=false
|
||||
APP_ENV="local"
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
|
||||
@@ -24,6 +24,7 @@ LlmApiKeys = {
|
||||
class ModelMetadata(NamedTuple):
    """Static facts recorded for each LLM model in MODEL_METADATA."""

    # Name of the backend serving the model (e.g. "openai", "anthropic", "groq").
    provider: str
    # Context window size of the model.
    context_window: int
    # Relative price of the model; consumed by the credit system as a cost amount.
    cost_factor: int
|
||||
|
||||
|
||||
class LlmModel(str, Enum):
|
||||
@@ -55,26 +56,29 @@ class LlmModel(str, Enum):
|
||||
|
||||
|
||||
MODEL_METADATA = {
|
||||
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000),
|
||||
LlmModel.GPT4O: ModelMetadata("openai", 128000),
|
||||
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000),
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385),
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000),
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000),
|
||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192),
|
||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192),
|
||||
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768),
|
||||
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192),
|
||||
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192),
|
||||
LlmModel.LLAMA3_1_405B: ModelMetadata(
|
||||
"groq", 8192
|
||||
), # Limited to 16k during preview
|
||||
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
|
||||
LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
|
||||
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=8),
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=14),
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=13),
|
||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=6),
|
||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=9),
|
||||
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=7),
|
||||
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=6),
|
||||
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=7),
|
||||
LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=10),
|
||||
# Limited to 16k during preview
|
||||
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
|
||||
}
|
||||
|
||||
# Import-time sanity check: every LlmModel member needs a MODEL_METADATA entry,
# so a newly added model cannot be shipped without its metadata.
for model in LlmModel:
    if model in MODEL_METADATA:
        continue
    raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
|
||||
|
||||
class AIStructuredResponseGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
@@ -301,7 +305,7 @@ class AITextGeneratorBlock(Block):
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TextSummarizerBlock(Block):
|
||||
class AITextSummarizerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str
|
||||
model: LlmModel = LlmModel.GPT4_TURBO
|
||||
@@ -319,8 +323,8 @@ class TextSummarizerBlock(Block):
|
||||
id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
|
||||
description="Utilize a Large Language Model (LLM) to summarize a long text.",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=TextSummarizerBlock.Input,
|
||||
output_schema=TextSummarizerBlock.Output,
|
||||
input_schema=AITextSummarizerBlock.Input,
|
||||
output_schema=AITextSummarizerBlock.Output,
|
||||
test_input={"text": "Lorem ipsum..." * 100},
|
||||
test_output=("summary", "Final summary of a long text"),
|
||||
test_mock={
|
||||
@@ -412,7 +416,7 @@ class TextSummarizerBlock(Block):
|
||||
else:
|
||||
# If combined summaries are still too long, recursively summarize
|
||||
return self._run(
|
||||
TextSummarizerBlock.Input(
|
||||
AITextSummarizerBlock.Input(
|
||||
text=combined_text,
|
||||
api_key=input_data.api_key,
|
||||
model=input_data.model,
|
||||
|
||||
263
rnd/autogpt_server/autogpt_server/data/credit.py
Normal file
263
rnd/autogpt_server/autogpt_server/data/credit.py
Normal file
@@ -0,0 +1,263 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import prisma.errors
|
||||
from prisma import Json
|
||||
from prisma.enums import UserBlockCreditType
|
||||
from prisma.models import UserBlockCredit
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt_server.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
AIConversationBlock,
|
||||
AIStructuredResponseGeneratorBlock,
|
||||
AITextGeneratorBlock,
|
||||
AITextSummarizerBlock,
|
||||
)
|
||||
from autogpt_server.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from autogpt_server.data.block import Block, BlockInput
|
||||
from autogpt_server.util.settings import Config
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
    """Unit in which a `BlockCost.cost_amount` is charged."""

    # cost X credits per run
    RUN = "run"
    # cost X credits per byte
    BYTE = "byte"
    # cost X credits per second
    SECOND = "second"
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
    """One pricing rule for a block: an amount, its unit, and an input filter."""

    cost_amount: int
    cost_filter: BlockInput
    cost_type: BlockCostType

    def __init__(
        self,
        cost_amount: int,
        cost_type: BlockCostType = BlockCostType.RUN,
        cost_filter: Optional[BlockInput] = None,
        **data: Any,
    ) -> None:
        """
        Build a pricing rule.

        Args:
            cost_amount (int): Credits charged per unit of `cost_type`.
            cost_type (BlockCostType): Charging unit; defaults to per-run.
            cost_filter (Optional[BlockInput]): Input key/value pairs that must
                match for this rule to apply. A missing or empty filter is
                stored as an empty dict.
            **data (Any): Extra keyword arguments forwarded to the model.
        """
        effective_filter = cost_filter if cost_filter else {}
        super().__init__(
            cost_amount=cost_amount,
            cost_filter=effective_filter,
            cost_type=cost_type,
            **data,
        )
|
||||
|
||||
|
||||
# One per-run pricing rule for every known LLM model, priced by the model's
# `cost_factor`. The rule only matches when no API key is supplied.
llm_cost = [
    BlockCost(
        cost_amount=metadata.cost_factor,
        cost_type=BlockCostType.RUN,
        # Running LLM with user own API key is free.
        cost_filter={"model": model, "api_key": None},
    )
    for model, metadata in MODEL_METADATA.items()
]

# Pricing rules per block class. All LLM blocks share the same `llm_cost` list.
BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
    AIConversationBlock: llm_cost,
    AITextGeneratorBlock: llm_cost,
    AIStructuredResponseGeneratorBlock: llm_cost,
    AITextSummarizerBlock: llm_cost,
    CreateTalkingAvatarVideoBlock: [
        BlockCost(
            cost_amount=15,
            cost_type=BlockCostType.RUN,
            cost_filter={"api_key": None},
        )
    ],
}
|
||||
|
||||
|
||||
class UserCreditBase(ABC):
    """Interface shared by the real and the disabled credit accounting models."""

    def __init__(self, num_user_credits_refill: int):
        # Amount granted to a user by a refill.
        self.num_user_credits_refill = num_user_credits_refill

    @abstractmethod
    async def get_or_refill_credit(self, user_id: str) -> int:
        """
        Fetch the user's current credit, refilling it first when no
        transaction has been made in the current cycle.

        Args:
            user_id (str): The user ID.

        Returns:
            int: The current credit for the user.
        """

    @abstractmethod
    async def spend_credits(
        self,
        user_id: str,
        user_credit: int,
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
    ) -> int:
        """
        Charge the user for one block usage.

        Args:
            user_id (str): The user ID.
            user_credit (int): The current credit for the user.
            block (Block): The block that is being used.
            input_data (BlockInput): The input data for the block.
            data_size (float): The size of the data being processed.
            run_time (float): The time taken to run the block.

        Returns:
            int: amount of credit spent
        """

    @abstractmethod
    async def top_up_credits(self, user_id: str, amount: int):
        """
        Add credits to the user's balance.

        Args:
            user_id (str): The user ID.
            amount (int): The amount to top up.
        """
|
||||
|
||||
|
||||
class UserCredit(UserCreditBase):
    """Credit accounting backed by the `UserBlockCredit` transaction table."""

    async def get_or_refill_credit(self, user_id: str) -> int:
        """
        Return the user's credit for the current calendar month (UTC).

        Only active transactions created within the current month are summed.
        When the user has no such transaction yet, a monthly top-up of
        `num_user_credits_refill` is recorded and that amount is returned.

        Args:
            user_id (str): The user ID.

        Returns:
            int: The current credit for the user.
        """
        cur_time = self.time_now()
        cur_month = cur_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        # December has to roll over into January of the following year;
        # `replace(month=13)` would raise ValueError.
        if cur_month.month == 12:
            nxt_month = cur_month.replace(year=cur_month.year + 1, month=1)
        else:
            nxt_month = cur_month.replace(month=cur_month.month + 1)

        user_credit = await UserBlockCredit.prisma().group_by(
            by=["userId"],
            sum={"amount": True},
            where={
                "userId": user_id,
                "createdAt": {"gte": cur_month, "lt": nxt_month},
                "isActive": True,
            },
        )

        if user_credit:
            credit_sum = user_credit[0].get("_sum") or {}
            # Coerce a missing/None aggregate to 0 so an int is always returned.
            return credit_sum.get("amount") or 0

        # No transaction this month yet: record the monthly refill. The key is
        # derived from the month start so a repeated refill attempt collides.
        key = f"MONTHLY-CREDIT-TOP-UP-{cur_month}"

        try:
            await UserBlockCredit.prisma().create(
                data={
                    "amount": self.num_user_credits_refill,
                    "type": UserBlockCreditType.TOP_UP,
                    "userId": user_id,
                    "transactionKey": key,
                    "createdAt": self.time_now(),
                }
            )
        except prisma.errors.UniqueViolationError:
            pass  # Already refilled this month

        return self.num_user_credits_refill

    @staticmethod
    def time_now():
        """Return the current timezone-aware UTC time."""
        return datetime.now(timezone.utc)

    @staticmethod
    def _block_usage_cost(
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
    ) -> tuple[int, BlockInput]:
        """
        Compute the cost of one block usage from `BLOCK_COSTS`.

        The first rule whose `cost_filter` entries all equal the matching
        `input_data` values is applied; per-second and per-byte costs are
        truncated to an int.

        Args:
            block (Block): The block that is being used.
            input_data (BlockInput): The input data for the block.
            data_size (float): The size of the data being processed.
            run_time (float): The time taken to run the block.

        Returns:
            tuple[int, BlockInput]: The cost and the filter of the rule that
            matched, or `(0, {})` when the block has no matching rule.
        """
        block_costs = BLOCK_COSTS.get(type(block))
        if not block_costs:
            return 0, {}

        for block_cost in block_costs:
            if all(input_data.get(k) == b for k, b in block_cost.cost_filter.items()):
                if block_cost.cost_type == BlockCostType.RUN:
                    return block_cost.cost_amount, block_cost.cost_filter

                if block_cost.cost_type == BlockCostType.SECOND:
                    return (
                        int(run_time * block_cost.cost_amount),
                        block_cost.cost_filter,
                    )

                if block_cost.cost_type == BlockCostType.BYTE:
                    return (
                        int(data_size * block_cost.cost_amount),
                        block_cost.cost_filter,
                    )

        return 0, {}

    async def spend_credits(
        self,
        user_id: str,
        user_credit: int,
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
        validate_balance: bool = True,
    ) -> int:
        """
        Record a usage transaction for one block execution.

        Args:
            user_id (str): The user ID.
            user_credit (int): The current credit for the user.
            block (Block): The block that is being used.
            input_data (BlockInput): The input data for the block.
            data_size (float): The size of the data being processed.
            run_time (float): The time taken to run the block.
            validate_balance (bool): When True, refuse to spend more than
                `user_credit`.

        Returns:
            int: amount of credit spent (0 when the usage is free, in which
            case nothing is recorded).

        Raises:
            ValueError: If `validate_balance` is set and the cost exceeds
                `user_credit`.
        """
        cost, matching_filter = self._block_usage_cost(
            block=block, input_data=input_data, data_size=data_size, run_time=run_time
        )
        if cost <= 0:
            return 0

        if validate_balance and user_credit < cost:
            raise ValueError(f"Insufficient credit: {user_credit} < {cost}")

        # Usage is stored as a negative amount so the balance is a plain sum.
        await UserBlockCredit.prisma().create(
            data={
                "userId": user_id,
                "amount": -cost,
                "type": UserBlockCreditType.USAGE,
                "blockId": block.id,
                "metadata": Json(
                    {
                        "block": block.name,
                        "input": matching_filter,
                    }
                ),
                "createdAt": self.time_now(),
            }
        )
        return cost

    async def top_up_credits(self, user_id: str, amount: int):
        """
        Record a top-up transaction for the user.

        Args:
            user_id (str): The user ID.
            amount (int): The amount to top up.
        """
        await UserBlockCredit.prisma().create(
            data={
                "userId": user_id,
                "amount": amount,
                "type": UserBlockCreditType.TOP_UP,
                "createdAt": self.time_now(),
            }
        )
|
||||
|
||||
|
||||
class DisabledUserCredit(UserCreditBase):
    """No-op credit model: reports zero credit, charges nothing, stores nothing."""

    async def get_or_refill_credit(self, *args, **kwargs) -> int:
        # No balance is tracked, so the reported credit is always zero.
        return 0

    async def spend_credits(self, *args, **kwargs) -> int:
        # Nothing is ever charged.
        return 0

    async def top_up_credits(self, *args, **kwargs):
        # Top-ups are accepted and ignored.
        return None
|
||||
|
||||
|
||||
def get_user_credit_model() -> UserCreditBase:
    """
    Pick the credit model according to the `enable_credit` config value.

    Returns:
        UserCreditBase: A `UserCredit` when `enable_credit` equals "true"
        (case-insensitive), otherwise a `DisabledUserCredit`.
    """
    settings = Config()
    credit_enabled = settings.enable_credit.lower() == "true"
    if not credit_enabled:
        return DisabledUserCredit(0)
    return UserCredit(settings.num_user_credits_refill)
|
||||
|
||||
|
||||
def get_block_costs() -> dict[str, list[BlockCost]]:
    """
    Return `BLOCK_COSTS` re-keyed by block id instead of block class.

    Each block class is instantiated once to read its `id`.
    """
    costs_by_id: dict[str, list[BlockCost]] = {}
    for block_cls, costs in BLOCK_COSTS.items():
        costs_by_id[block_cls().id] = costs
    return costs_by_id
|
||||
@@ -1,9 +1,9 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
AgentNodeExecution,
|
||||
@@ -21,12 +21,14 @@ from autogpt_server.util import json, mock
|
||||
|
||||
|
||||
class GraphExecution(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
start_node_execs: list["NodeExecution"]
|
||||
graph_id: str
|
||||
start_node_execs: list["NodeExecution"]
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
node_exec_id: str
|
||||
@@ -34,13 +36,7 @@ class NodeExecution(BaseModel):
|
||||
data: BlockInput
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
INCOMPLETE = "INCOMPLETE"
|
||||
QUEUED = "QUEUED"
|
||||
RUNNING = "RUNNING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -148,6 +144,7 @@ async def create_graph_execution(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"AgentNodeExecutions": {
|
||||
"create": [ # type: ignore
|
||||
{
|
||||
@@ -259,10 +256,20 @@ async def upsert_execution_output(
|
||||
)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(graph_exec_id: str):
    """Mark the graph execution as RUNNING and stamp its start time in UTC."""
    started_at = datetime.now(tz=timezone.utc)
    await AgentGraphExecution.prisma().update(
        where={"id": graph_exec_id},
        data={"executionStatus": ExecutionStatus.RUNNING, "startedAt": started_at},
    )
|
||||
|
||||
|
||||
async def update_graph_execution_stats(graph_exec_id: str, stats: dict[str, Any]):
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
data={"stats": json.dumps(stats)},
|
||||
data={"executionStatus": ExecutionStatus.COMPLETED, "stats": json.dumps(stats)},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,8 +17,10 @@ if TYPE_CHECKING:
|
||||
from autogpt_server.blocks.basic import AgentInputBlock
|
||||
from autogpt_server.data import db
|
||||
from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
|
||||
from autogpt_server.data.credit import get_user_credit_model
|
||||
from autogpt_server.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionResult,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
NodeExecution,
|
||||
@@ -45,25 +47,41 @@ from autogpt_server.util.type import convert
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_log_metadata(
|
||||
graph_eid: str,
|
||||
graph_id: str,
|
||||
node_eid: str,
|
||||
node_id: str,
|
||||
block_name: str,
|
||||
) -> dict:
|
||||
return {
|
||||
"component": "ExecutionManager",
|
||||
"graph_eid": graph_eid,
|
||||
"graph_id": graph_id,
|
||||
"node_eid": node_eid,
|
||||
"node_id": node_id,
|
||||
"block_name": block_name,
|
||||
}
|
||||
class LogMetadata:
    """
    Logging helper that tags every record with execution identifiers.

    Each message is prepended with a human-readable prefix, and the same
    identifiers are attached as structured `json_fields` in the record's
    `extra`.
    """

    def __init__(
        self,
        user_id: str,
        graph_eid: str,
        graph_id: str,
        node_eid: str,
        node_id: str,
        block_name: str,
    ):
        """
        Args:
            user_id (str): ID of the user owning the execution.
            graph_eid (str): Graph execution ID.
            graph_id (str): Graph ID.
            node_eid (str): Node execution ID.
            node_id (str): Node ID.
            block_name (str): Name of the block being executed.
        """
        self.metadata = {
            "component": "ExecutionManager",
            "user_id": user_id,
            "graph_eid": graph_eid,
            "graph_id": graph_id,
            "node_eid": node_eid,
            "node_id": node_id,
            "block_name": block_name,
        }
        self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|nid:{node_eid}|{block_name}]"

    def _with_prefix(self, msg: str) -> str:
        """Return `msg` with the execution prefix prepended."""
        return f"{self.prefix} {msg}"

    def _fields(self, extra: dict) -> dict:
        """Build the logging `extra` payload; per-call fields override metadata."""
        return {"json_fields": {**self.metadata, **extra}}

    def info(self, msg: str, **extra):
        logger.info(self._with_prefix(msg), extra=self._fields(extra))

    def warning(self, msg: str, **extra):
        logger.warning(self._with_prefix(msg), extra=self._fields(extra))

    def error(self, msg: str, **extra):
        logger.error(self._with_prefix(msg), extra=self._fields(extra))

    def debug(self, msg: str, **extra):
        logger.debug(self._with_prefix(msg), extra=self._fields(extra))

    def exception(self, msg: str, **extra):
        # Must be called from an exception handler: the traceback is attached.
        logger.exception(self._with_prefix(msg), extra=self._fields(extra))


def get_log_prefix(graph_eid: str, node_eid: str, block_name: str = "-"):
    return f"[ExecutionManager][graph-eid-{graph_eid}|node-eid-{node_eid}|{block_name}]"
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -89,6 +107,7 @@ def execute_node(
|
||||
Returns:
|
||||
The subsequent node to be enqueued, or None if there is no subsequent node.
|
||||
"""
|
||||
user_id = data.user_id
|
||||
graph_exec_id = data.graph_exec_id
|
||||
graph_id = data.graph_id
|
||||
node_exec_id = data.node_exec_id
|
||||
@@ -99,9 +118,10 @@ def execute_node(
|
||||
def wait(f: Coroutine[Any, Any, T]) -> T:
|
||||
return loop.run_until_complete(f)
|
||||
|
||||
def update_execution(status: ExecutionStatus):
|
||||
def update_execution(status: ExecutionStatus) -> ExecutionResult:
|
||||
exec_update = wait(update_execution_status(node_exec_id, status))
|
||||
api_client.send_execution_update(exec_update.model_dump())
|
||||
return exec_update
|
||||
|
||||
node = wait(get_node(node_id))
|
||||
|
||||
@@ -111,43 +131,35 @@ def execute_node(
|
||||
return
|
||||
|
||||
# Sanity check: validate the execution input.
|
||||
log_metadata = get_log_metadata(
|
||||
log_metadata = LogMetadata(
|
||||
user_id=user_id,
|
||||
graph_eid=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
node_eid=node_exec_id,
|
||||
node_id=node_id,
|
||||
block_name=node_block.name,
|
||||
)
|
||||
prefix = get_log_prefix(
|
||||
graph_eid=graph_exec_id,
|
||||
node_eid=node_exec_id,
|
||||
block_name=node_block.name,
|
||||
)
|
||||
input_data, error = validate_exec(node, data.data, resolve_input=False)
|
||||
if input_data is None:
|
||||
logger.error(
|
||||
"{prefix} Skip execution, input validation error",
|
||||
extra={"json_fields": {**log_metadata, "error": error}},
|
||||
)
|
||||
log_metadata.error(f"Skip execution, input validation error: {error}")
|
||||
return
|
||||
|
||||
# Execute the node
|
||||
input_data_str = json.dumps(input_data)
|
||||
input_size = len(input_data_str)
|
||||
logger.info(
|
||||
f"{prefix} Executed node with input",
|
||||
extra={"json_fields": {**log_metadata, "input": input_data_str}},
|
||||
)
|
||||
log_metadata.info("Executed node with input", input=input_data_str)
|
||||
update_execution(ExecutionStatus.RUNNING)
|
||||
user_credit = get_user_credit_model()
|
||||
|
||||
output_size = 0
|
||||
try:
|
||||
credit = wait(user_credit.get_or_refill_credit(user_id))
|
||||
if credit < 0:
|
||||
raise ValueError("Insufficient credit: {credit}")
|
||||
|
||||
for output_name, output_data in node_block.execute(input_data):
|
||||
output_size += len(json.dumps(output_data))
|
||||
logger.info(
|
||||
f"{prefix} Node produced output",
|
||||
extra={"json_fields": {**log_metadata, output_name: output_data}},
|
||||
)
|
||||
log_metadata.info("Node produced output", output_name=output_data)
|
||||
wait(upsert_execution_output(node_exec_id, output_name, output_data))
|
||||
|
||||
for execution in _enqueue_next_nodes(
|
||||
@@ -155,20 +167,25 @@ def execute_node(
|
||||
loop=loop,
|
||||
node=node,
|
||||
output=(output_name, output_data),
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
log_metadata=log_metadata,
|
||||
):
|
||||
yield execution
|
||||
|
||||
update_execution(ExecutionStatus.COMPLETED)
|
||||
r = update_execution(ExecutionStatus.COMPLETED)
|
||||
s = input_size + output_size
|
||||
t = (
|
||||
(r.end_time - r.start_time).total_seconds()
|
||||
if r.end_time and r.start_time
|
||||
else 0
|
||||
)
|
||||
wait(user_credit.spend_credits(user_id, credit, node_block, input_data, s, t))
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"{e.__class__.__name__}: {e}"
|
||||
logger.exception(
|
||||
f"{prefix} Node execution failed with error",
|
||||
extra={"json_fields": {**log_metadata, error: error_msg}},
|
||||
)
|
||||
error_msg = str(e)
|
||||
log_metadata.exception(f"Node execution failed with error {error_msg}")
|
||||
wait(upsert_execution_output(node_exec_id, "error", error_msg))
|
||||
update_execution(ExecutionStatus.FAILED)
|
||||
|
||||
@@ -194,9 +211,10 @@ def _enqueue_next_nodes(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
node: Node,
|
||||
output: BlockData,
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: dict,
|
||||
log_metadata: LogMetadata,
|
||||
) -> list[NodeExecution]:
|
||||
def wait(f: Coroutine[Any, Any, T]) -> T:
|
||||
return loop.run_until_complete(f)
|
||||
@@ -209,6 +227,7 @@ def _enqueue_next_nodes(
|
||||
)
|
||||
api_client.send_execution_update(exec_update.model_dump())
|
||||
return NodeExecution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
node_exec_id=node_exec_id,
|
||||
@@ -262,17 +281,11 @@ def _enqueue_next_nodes(
|
||||
|
||||
# Incomplete input data, skip queueing the execution.
|
||||
if not next_node_input:
|
||||
logger.warning(
|
||||
f"Skipped queueing {suffix}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
log_metadata.warning(f"Skipped queueing {suffix}")
|
||||
return enqueued_executions
|
||||
|
||||
# Input is complete, enqueue the execution.
|
||||
logger.info(
|
||||
f"Enqueued {suffix}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
log_metadata.info(f"Enqueued {suffix}")
|
||||
enqueued_executions.append(
|
||||
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
|
||||
)
|
||||
@@ -298,11 +311,9 @@ def _enqueue_next_nodes(
|
||||
idata, msg = validate_exec(next_node, idata)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
|
||||
if not idata:
|
||||
logger.info(
|
||||
f"{log_metadata} Enqueueing static-link skipped: {suffix}"
|
||||
)
|
||||
log_metadata.info(f"Enqueueing static-link skipped: {suffix}")
|
||||
continue
|
||||
logger.info(f"{log_metadata} Enqueueing static-link execution {suffix}")
|
||||
log_metadata.info(f"Enqueueing static-link execution {suffix}")
|
||||
enqueued_executions.append(
|
||||
add_enqueued_execution(iexec.node_exec_id, next_node_id, idata)
|
||||
)
|
||||
@@ -443,22 +454,18 @@ class Executor:
|
||||
def on_node_execution(
|
||||
cls, q: ExecutionQueue[NodeExecution], node_exec: NodeExecution
|
||||
):
|
||||
log_metadata = get_log_metadata(
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
graph_eid=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_eid=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_name="-",
|
||||
)
|
||||
prefix = get_log_prefix(
|
||||
graph_eid=node_exec.graph_exec_id,
|
||||
node_eid=node_exec.node_exec_id,
|
||||
block_name="-",
|
||||
)
|
||||
|
||||
execution_stats = {}
|
||||
timing_info, _ = cls._on_node_execution(
|
||||
q, node_exec, log_metadata, prefix, execution_stats
|
||||
q, node_exec, log_metadata, execution_stats
|
||||
)
|
||||
execution_stats["walltime"] = timing_info.wall_time
|
||||
execution_stats["cputime"] = timing_info.cpu_time
|
||||
@@ -473,29 +480,19 @@ class Executor:
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
log_metadata: dict,
|
||||
prefix: str,
|
||||
log_metadata: LogMetadata,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
try:
|
||||
logger.info(
|
||||
f"{prefix} Start node execution {node_exec.node_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
for execution in execute_node(
|
||||
cls.loop, cls.agent_server_client, node_exec, stats
|
||||
):
|
||||
q.add(execution)
|
||||
logger.info(
|
||||
f"{prefix} Finished node execution {node_exec.node_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed node execution {node_exec.node_exec_id}: {e}",
|
||||
extra={
|
||||
**log_metadata,
|
||||
},
|
||||
log_metadata.exception(
|
||||
f"Failed node execution {node_exec.node_exec_id}: {e}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -517,10 +514,12 @@ class Executor:
|
||||
|
||||
@classmethod
|
||||
def on_graph_executor_stop(cls):
|
||||
logger.info(
|
||||
f"[on_graph_executor_stop {cls.pid}] ⏳ Terminating node executor pool..."
|
||||
)
|
||||
prefix = f"[on_graph_executor_stop {cls.pid}]"
|
||||
logger.info(f"{prefix} ⏳ Disconnecting DB...")
|
||||
cls.loop.run_until_complete(db.disconnect())
|
||||
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
|
||||
cls.executor.terminate()
|
||||
logger.info(f"{prefix} ✅ Finished cleanup")
|
||||
|
||||
@classmethod
|
||||
def _init_node_executor_pool(cls):
|
||||
@@ -532,20 +531,16 @@ class Executor:
|
||||
@classmethod
|
||||
@error_logged
|
||||
def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
|
||||
log_metadata = get_log_metadata(
|
||||
log_metadata = LogMetadata(
|
||||
user_id=graph_exec.user_id,
|
||||
graph_eid=graph_exec.graph_exec_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
node_id="*",
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
prefix = get_log_prefix(
|
||||
graph_eid=graph_exec.graph_exec_id,
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
timing_info, node_count = cls._on_graph_execution(
|
||||
graph_exec, cancel, log_metadata, prefix
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
|
||||
cls.loop.run_until_complete(
|
||||
@@ -565,13 +560,9 @@ class Executor:
|
||||
cls,
|
||||
graph_exec: GraphExecution,
|
||||
cancel: threading.Event,
|
||||
log_metadata: dict,
|
||||
prefix: str,
|
||||
log_metadata: LogMetadata,
|
||||
) -> int:
|
||||
logger.info(
|
||||
f"{prefix} Start graph execution {graph_exec.graph_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
|
||||
n_node_executions = 0
|
||||
finished = False
|
||||
|
||||
@@ -581,10 +572,7 @@ class Executor:
|
||||
if finished:
|
||||
return
|
||||
cls.executor.terminate()
|
||||
logger.info(
|
||||
f"{prefix} Terminated graph execution {graph_exec.graph_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
log_metadata.info(f"Terminated graph execution {graph_exec.graph_exec_id}")
|
||||
cls._init_node_executor_pool()
|
||||
|
||||
cancel_thread = threading.Thread(target=cancel_handler)
|
||||
@@ -622,10 +610,9 @@ class Executor:
|
||||
# Re-enqueueing the data back to the queue will disrupt the order.
|
||||
execution.wait()
|
||||
|
||||
logger.debug(
|
||||
f"{prefix} Dispatching node execution {exec_data.node_exec_id} "
|
||||
log_metadata.debug(
|
||||
f"Dispatching node execution {exec_data.node_exec_id} "
|
||||
f"for node {exec_data.node_id}",
|
||||
extra={**log_metadata},
|
||||
)
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
@@ -635,10 +622,8 @@ class Executor:
|
||||
|
||||
# Avoid terminating graph execution when some nodes are still running.
|
||||
while queue.empty() and running_executions:
|
||||
logger.debug(
|
||||
"Queue empty; running nodes: "
|
||||
f"{list(running_executions.keys())}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
log_metadata.debug(
|
||||
f"Queue empty; running nodes: {list(running_executions.keys())}"
|
||||
)
|
||||
for node_id, execution in list(running_executions.items()):
|
||||
if cancel.is_set():
|
||||
@@ -647,20 +632,13 @@ class Executor:
|
||||
if not queue.empty():
|
||||
break # yield to parent loop to execute new queue items
|
||||
|
||||
logger.debug(
|
||||
f"Waiting on execution of node {node_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
log_metadata.debug(f"Waiting on execution of node {node_id}")
|
||||
execution.wait(3)
|
||||
|
||||
logger.info(
|
||||
f"{prefix} Finished graph execution {graph_exec.graph_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"{prefix} Failed graph execution {graph_exec.graph_exec_id}: {e}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
log_metadata.exception(
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
if not cancel.is_set():
|
||||
@@ -747,6 +725,7 @@ class ExecutionManager(AppService):
|
||||
for node_exec in node_execs:
|
||||
starting_node_execs.append(
|
||||
NodeExecution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
@@ -762,6 +741,7 @@ class ExecutionManager(AppService):
|
||||
self.agent_server_client.send_execution_update(exec_update.model_dump())
|
||||
|
||||
graph_exec = GraphExecution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
start_node_execs=starting_node_execs,
|
||||
|
||||
@@ -15,6 +15,7 @@ from autogpt_server.data import execution as execution_db
|
||||
from autogpt_server.data import graph as graph_db
|
||||
from autogpt_server.data import user as user_db
|
||||
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
|
||||
from autogpt_server.data.credit import get_block_costs, get_user_credit_model
|
||||
from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
|
||||
from autogpt_server.data.user import get_or_create_user
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
@@ -32,6 +33,7 @@ class AgentServer(AppService):
|
||||
mutex = KeyedMutex()
|
||||
use_redis = True
|
||||
_test_dependency_overrides = {}
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
def __init__(self, event_queue: AsyncEventQueue | None = None):
|
||||
super().__init__(port=Config().agent_server_port)
|
||||
@@ -91,6 +93,11 @@ class AgentServer(AppService):
|
||||
endpoint=self.get_graph_blocks,
|
||||
methods=["GET"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/blocks/costs",
|
||||
endpoint=self.get_graph_block_costs,
|
||||
methods=["GET"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/blocks/{block_id}/execute",
|
||||
endpoint=self.execute_graph_block,
|
||||
@@ -196,6 +203,11 @@ class AgentServer(AppService):
|
||||
endpoint=self.update_schedule,
|
||||
methods=["PUT"],
|
||||
)
|
||||
api_router.add_api_route(
|
||||
path="/credits",
|
||||
endpoint=self.get_user_credits,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
api_router.add_api_route(
|
||||
path="/settings",
|
||||
@@ -265,6 +277,10 @@ class AgentServer(AppService):
|
||||
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
|
||||
return [v.to_dict() for v in block.get_blocks().values()]
|
||||
|
||||
@classmethod
def get_graph_block_costs(cls) -> dict[Any, Any]:
    """Return the block cost table produced by `get_block_costs()`.

    Registered as the endpoint for `GET /blocks/costs`.
    """
    block_costs = get_block_costs()
    return block_costs
|
||||
|
||||
@classmethod
|
||||
def execute_graph_block(
|
||||
cls, block_id: str, data: BlockInput
|
||||
@@ -481,6 +497,25 @@ class AgentServer(AppService):
|
||||
|
||||
return await execution_db.list_executions(graph_id, graph_version)
|
||||
|
||||
@classmethod
async def get_graph_run_status(
    cls,
    graph_id: str,
    graph_exec_id: str,
    user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.ExecutionStatus:
    """Return the execution status of a single graph run.

    Args:
        graph_id: ID of the graph the run belongs to.
        graph_exec_id: ID of the graph execution to inspect.
        user_id: ID of the requesting user, resolved through `get_user_id`.

    Returns:
        The `executionStatus` stored on the graph execution record.

    Raises:
        HTTPException: 404 when the graph lookup or the execution lookup
            returns nothing for this user.
    """
    # The graph is looked up first so an unknown graph is reported before
    # the execution is queried at all.
    if not await graph_db.get_graph(graph_id, user_id=user_id):
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

    graph_run = await execution_db.get_graph_execution(graph_exec_id, user_id)
    if graph_run:
        return graph_run.executionStatus

    raise HTTPException(
        status_code=404, detail=f"Execution #{graph_exec_id} not found."
    )
|
||||
|
||||
@classmethod
|
||||
async def get_graph_run_node_execution_results(
|
||||
cls,
|
||||
@@ -522,6 +557,11 @@ class AgentServer(AppService):
|
||||
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id)
|
||||
return {"id": schedule_id}
|
||||
|
||||
async def get_user_credits(
    self, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, int]:
    """Return the requesting user's credit balance under the "credits" key.

    The balance is whatever `get_or_refill_credit` on the configured credit
    model reports for `user_id`.
    """
    balance = await self._user_credit_model.get_or_refill_credit(user_id)
    return {"credits": balance}
|
||||
|
||||
def get_execution_schedules(
|
||||
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> dict[str, str]:
|
||||
|
||||
@@ -252,7 +252,6 @@ Here are a couple of sample of the Block class implementation:
|
||||
|
||||
async def block_autogen_agent():
|
||||
async with SpinTestServer() as server:
|
||||
test_manager = server.exec_manager
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||
@@ -261,10 +260,8 @@ async def block_autogen_agent():
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
exec_manager=test_manager,
|
||||
graph_id=test_graph.id,
|
||||
graph_exec_id=response["id"],
|
||||
num_execs=10,
|
||||
timeout=1200,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
@@ -153,7 +153,6 @@ async def create_test_user() -> User:
|
||||
|
||||
async def reddit_marketing_agent():
|
||||
async with SpinTestServer() as server:
|
||||
exec_man = server.exec_manager
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
@@ -161,9 +160,7 @@ async def reddit_marketing_agent():
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
exec_man, test_user.id, test_graph.id, response["id"], 13, 120
|
||||
)
|
||||
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
@@ -75,7 +75,6 @@ def create_test_graph() -> graph.Graph:
|
||||
|
||||
async def sample_agent():
|
||||
async with SpinTestServer() as server:
|
||||
exec_man = server.exec_manager
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
@@ -83,9 +82,7 @@ async def sample_agent():
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
exec_man, test_user.id, test_graph.id, response["id"], 4, 10
|
||||
)
|
||||
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
@@ -42,15 +42,15 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
"""Config for the server."""
|
||||
|
||||
num_graph_workers: int = Field(
|
||||
default=1,
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
le=1000,
|
||||
description="Maximum number of workers to use for graph execution.",
|
||||
)
|
||||
num_node_workers: int = Field(
|
||||
default=1,
|
||||
default=5,
|
||||
ge=1,
|
||||
le=100,
|
||||
le=1000,
|
||||
description="Maximum number of workers to use for node execution within a single graph.",
|
||||
)
|
||||
pyro_host: str = Field(
|
||||
@@ -61,6 +61,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default="false",
|
||||
description="If authentication is enabled or not",
|
||||
)
|
||||
enable_credit: str = Field(
|
||||
default="false",
|
||||
description="If user credit system is enabled or not",
|
||||
)
|
||||
num_user_credits_refill: int = Field(
|
||||
default=1500,
|
||||
description="Number of credits to refill for each user",
|
||||
)
|
||||
# Add more configuration fields as needed
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
|
||||
@@ -5,6 +5,7 @@ from autogpt_server.data import db
|
||||
from autogpt_server.data.block import Block, initialize_blocks
|
||||
from autogpt_server.data.execution import ExecutionResult, ExecutionStatus
|
||||
from autogpt_server.data.queue import AsyncEventQueue
|
||||
from autogpt_server.data.user import create_default_user
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.server.rest_api import get_user_id
|
||||
@@ -64,6 +65,7 @@ class SpinTestServer:
|
||||
|
||||
await db.connect()
|
||||
await initialize_blocks()
|
||||
await create_default_user("false")
|
||||
|
||||
return self
|
||||
|
||||
@@ -82,25 +84,18 @@ class SpinTestServer:
|
||||
|
||||
|
||||
async def wait_execution(
|
||||
exec_manager: ExecutionManager,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
num_execs: int,
|
||||
timeout: int = 20,
|
||||
) -> list:
|
||||
async def is_execution_completed():
|
||||
execs = await AgentServer().get_graph_run_node_execution_results(
|
||||
status = await AgentServer().get_graph_run_status(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
return (
|
||||
exec_manager.queue.empty()
|
||||
and len(execs) == num_execs
|
||||
and all(
|
||||
v.status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]
|
||||
for v in execs
|
||||
)
|
||||
)
|
||||
if status == ExecutionStatus.FAILED:
|
||||
raise Exception("Execution failed")
|
||||
return status == ExecutionStatus.COMPLETED
|
||||
|
||||
# Wait for the executions to complete
|
||||
for i in range(timeout):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
{
|
||||
"num_graph_workers": 10,
|
||||
"num_node_workers": 5
|
||||
"num_node_workers": 5,
|
||||
"num_user_credits_refill": 1500
|
||||
}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- The `executionStatus` column on the `AgentNodeExecution` table would be dropped and recreated. This will lead to data loss if there is data in the column.
|
||||
|
||||
*/
|
||||
-- CreateEnum
|
||||
CREATE TYPE "AgentExecutionStatus" AS ENUM ('INCOMPLETE', 'QUEUED', 'RUNNING', 'COMPLETED', 'FAILED');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "UserBlockCreditType" AS ENUM ('TOP_UP', 'USAGE');
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraphExecution" ADD COLUMN "executionStatus" "AgentExecutionStatus" NOT NULL DEFAULT 'COMPLETED',
|
||||
ADD COLUMN "startedAt" TIMESTAMP(3);
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentNodeExecution" DROP COLUMN "executionStatus",
|
||||
ADD COLUMN "executionStatus" "AgentExecutionStatus" NOT NULL DEFAULT 'COMPLETED';
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "UserBlockCredit" (
|
||||
"transactionKey" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"blockId" TEXT,
|
||||
"amount" INTEGER NOT NULL,
|
||||
"type" "UserBlockCreditType" NOT NULL,
|
||||
"isActive" BOOLEAN NOT NULL DEFAULT true,
|
||||
"metadata" JSONB,
|
||||
|
||||
CONSTRAINT "UserBlockCredit_pkey" PRIMARY KEY ("transactionKey","userId")
|
||||
);
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "UserBlockCredit" ADD CONSTRAINT "UserBlockCredit_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "UserBlockCredit" ADD CONSTRAINT "UserBlockCredit_blockId_fkey" FOREIGN KEY ("blockId") REFERENCES "AgentBlock"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@@ -22,6 +22,7 @@ model User {
|
||||
AgentGraphs AgentGraph[]
|
||||
AgentGraphExecutions AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
|
||||
UserBlockCredit UserBlockCredit[]
|
||||
|
||||
@@index([id])
|
||||
@@index([email])
|
||||
@@ -29,9 +30,9 @@ model User {
|
||||
|
||||
// This model describes the Agent Graph/Flow (Multi Agent System).
|
||||
model AgentGraph {
|
||||
id String @default(uuid())
|
||||
version Int @default(1)
|
||||
createdAt DateTime @default(now())
|
||||
id String @default(uuid())
|
||||
version Int @default(1)
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime? @updatedAt
|
||||
|
||||
name String?
|
||||
@@ -111,13 +112,26 @@ model AgentBlock {
|
||||
|
||||
// Prisma requires explicit back-references.
|
||||
ReferencedByAgentNode AgentNode[]
|
||||
UserBlockCredit UserBlockCredit[]
|
||||
}
|
||||
|
||||
// This model describes the status of an AgentGraphExecution or AgentNodeExecution.
|
||||
enum AgentExecutionStatus {
|
||||
INCOMPLETE
|
||||
QUEUED
|
||||
RUNNING
|
||||
COMPLETED
|
||||
FAILED
|
||||
}
|
||||
|
||||
// This model describes the execution of an AgentGraph.
|
||||
model AgentGraphExecution {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime? @updatedAt
|
||||
startedAt DateTime?
|
||||
|
||||
executionStatus AgentExecutionStatus @default(COMPLETED)
|
||||
|
||||
agentGraphId String
|
||||
agentGraphVersion Int @default(1)
|
||||
@@ -145,12 +159,10 @@ model AgentNodeExecution {
|
||||
Input AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionInput")
|
||||
Output AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionOutput")
|
||||
|
||||
// sqlite does not support enum
|
||||
// enum Status { INCOMPLETE, QUEUED, RUNNING, SUCCESS, FAILED }
|
||||
executionStatus String
|
||||
executionStatus AgentExecutionStatus @default(COMPLETED)
|
||||
// Final JSON serialized input data for the node execution.
|
||||
executionData String?
|
||||
addedTime DateTime @default(now())
|
||||
addedTime DateTime @default(now())
|
||||
queuedTime DateTime?
|
||||
startedTime DateTime?
|
||||
endedTime DateTime?
|
||||
@@ -178,8 +190,8 @@ model AgentNodeExecutionInputOutput {
|
||||
|
||||
// This model describes the recurring execution schedule of an Agent.
|
||||
model AgentGraphExecutionSchedule {
|
||||
id String @id
|
||||
createdAt DateTime @default(now())
|
||||
id String @id
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime? @updatedAt
|
||||
|
||||
agentGraphId String
|
||||
@@ -199,3 +211,27 @@ model AgentGraphExecutionSchedule {
|
||||
|
||||
@@index([isEnabled])
|
||||
}
|
||||
|
||||
enum UserBlockCreditType {
|
||||
TOP_UP
|
||||
USAGE
|
||||
}
|
||||
|
||||
model UserBlockCredit {
|
||||
transactionKey String @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id])
|
||||
|
||||
blockId String?
|
||||
block AgentBlock? @relation(fields: [blockId], references: [id])
|
||||
|
||||
amount Int
|
||||
type UserBlockCreditType
|
||||
|
||||
isActive Boolean @default(true)
|
||||
metadata Json?
|
||||
|
||||
@@id(name: "creditTransactionIdentifier", [transactionKey, userId])
|
||||
}
|
||||
|
||||
90
rnd/autogpt_server/test/data/test_credit.py
Normal file
90
rnd/autogpt_server/test/data/test_credit.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from prisma.models import UserBlockCredit
|
||||
|
||||
from autogpt_server.blocks.llm import AITextGeneratorBlock
|
||||
from autogpt_server.data.credit import UserCredit
|
||||
from autogpt_server.data.user import DEFAULT_USER_ID
|
||||
from autogpt_server.util.test import SpinTestServer
|
||||
|
||||
REFILL_VALUE = 1000
|
||||
user_credit = UserCredit(REFILL_VALUE)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_block_credit_usage(server: SpinTestServer):
    """Check what `spend_credits` charges with and without an own API key.

    The run without an `api_key` in the block input must cost something, the
    run that supplies one must cost nothing, and the balance afterwards must
    drop by exactly the sum of both charges.
    """
    starting_balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)

    async def spend(block_input: dict):
        # Both spends differ only in the block input; everything else is fixed.
        return await user_credit.spend_credits(
            DEFAULT_USER_ID,
            starting_balance,
            AITextGeneratorBlock(),
            block_input,
            0.0,
            0.0,
            validate_balance=False,
        )

    charged = await spend({"model": "gpt-4-turbo"})
    assert charged > 0

    free_of_charge = await spend({"model": "gpt-4-turbo", "api_key": "owned_api_key"})
    assert free_of_charge == 0

    final_balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
    assert final_balance == starting_balance - charged - free_of_charge
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_block_credit_top_up(server: SpinTestServer):
    """A top-up of 100 credits must raise the reported balance by exactly 100."""
    top_up_amount = 100
    balance_before = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)

    await user_credit.top_up_credits(DEFAULT_USER_ID, top_up_amount)

    balance_after = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
    assert balance_after == balance_before + top_up_amount
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_block_credit_reset(server: SpinTestServer):
    """A top-up made in one month must not change another month's balance.

    The clock of the shared `user_credit` object is switched between
    2022-01-15 and 2022-02-15 by replacing its `time_now` attribute. The
    February balance is recorded, a top-up is made "in January", and the
    February balance is then expected to be unchanged.
    """
    month1 = datetime(2022, 1, 15)
    month2 = datetime(2022, 2, 15)

    # `user_credit` is a module-level object shared by every test in this
    # file, so the patched clock is put back afterwards instead of leaking
    # into whichever test runs next.
    original_time_now = user_credit.time_now
    try:
        user_credit.time_now = lambda: month2
        month2credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)

        # Month 1 result should only affect month 1
        user_credit.time_now = lambda: month1
        month1credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
        await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
        assert (
            await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
            == month1credit + 100
        )

        # Month 2 balance is unaffected
        user_credit.time_now = lambda: month2
        assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month2credit
    finally:
        user_credit.time_now = original_time_now
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
async def test_credit_refill(server: SpinTestServer):
    """With no active February 2022 transactions the balance is the refill value.

    Every `UserBlockCredit` row of the default user created in February 2022
    is marked inactive, the clock is set to 2022-02-15, and the balance is
    then expected to equal `REFILL_VALUE`.
    """
    # Clear all transactions within the month
    await UserBlockCredit.prisma().update_many(
        where={
            "userId": DEFAULT_USER_ID,
            "createdAt": {
                "gte": datetime(2022, 2, 1),
                "lt": datetime(2022, 3, 1),
            },
        },
        data={"isActive": False},
    )

    # `user_credit` is shared across this module's tests; restore its clock
    # once the assertion has run so the patch does not outlive this test.
    original_time_now = user_credit.time_now
    try:
        user_credit.time_now = lambda: datetime(2022, 2, 15)

        balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
        assert balance == REFILL_VALUE
    finally:
        user_credit.time_now = original_time_now
|
||||
@@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
from autogpt_server.blocks.basic import AgentInputBlock, StoreValueBlock
|
||||
from autogpt_server.data.graph import Graph, Link, Node
|
||||
from autogpt_server.data.user import DEFAULT_USER_ID, create_default_user
|
||||
from autogpt_server.data.user import DEFAULT_USER_ID
|
||||
from autogpt_server.server.model import CreateGraph
|
||||
from autogpt_server.util.test import SpinTestServer
|
||||
|
||||
@@ -22,8 +22,6 @@ async def test_graph_creation(server: SpinTestServer):
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
await create_default_user("false")
|
||||
|
||||
value_block = StoreValueBlock().id
|
||||
input_block = AgentInputBlock().id
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from prisma.models import User
|
||||
from autogpt_server.blocks.basic import FindInDictionaryBlock, StoreValueBlock
|
||||
from autogpt_server.blocks.maths import CalculatorBlock, Operation
|
||||
from autogpt_server.data import execution, graph
|
||||
from autogpt_server.executor import ExecutionManager
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.usecases.sample import create_test_graph, create_test_user
|
||||
from autogpt_server.util.test import SpinTestServer, wait_execution
|
||||
@@ -12,7 +11,6 @@ from autogpt_server.util.test import SpinTestServer, wait_execution
|
||||
|
||||
async def execute_graph(
|
||||
agent_server: AgentServer,
|
||||
test_manager: ExecutionManager,
|
||||
test_graph: graph.Graph,
|
||||
test_user: User,
|
||||
input_data: dict,
|
||||
@@ -23,9 +21,8 @@ async def execute_graph(
|
||||
graph_exec_id = response["id"]
|
||||
|
||||
# Execution queue should be empty
|
||||
assert await wait_execution(
|
||||
test_manager, test_user.id, test_graph.id, graph_exec_id, num_execs
|
||||
)
|
||||
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||
assert result and len(result) == num_execs
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@@ -108,7 +105,6 @@ async def test_agent_execution(server: SpinTestServer):
|
||||
data = {"input_1": "Hello", "input_2": "World"}
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server,
|
||||
server.exec_manager,
|
||||
test_graph,
|
||||
test_user,
|
||||
data,
|
||||
@@ -169,7 +165,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
|
||||
server.agent_server, test_graph, test_user, {}, 3
|
||||
)
|
||||
|
||||
executions = await server.agent_server.get_graph_run_node_execution_results(
|
||||
@@ -250,7 +246,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
|
||||
server.agent_server, test_graph, test_user, {}, 8
|
||||
)
|
||||
executions = await server.agent_server.get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
|
||||
Reference in New Issue
Block a user