feat(rnd): Add initial credit accounting system for block execution (#8047)

### Background

We need a way to set an execution quota per user for each block execution.

### Changes 🏗️

* Introduced a `UserBlockCredit`, a transaction table tracking the block usage along with its cost/quota.
* The tracking is toggled by `ENABLE_CREDIT` config, default = false.
* Introduced `BLOCK_COSTS` | `GET /blocks/costs` as a source of information for the cost of each block depending on the input configuration.

Improvements:
* Refactor logging in manager.py to always print a prefix and pass the metadata.
* Make executionStatus on AgentNodeExecution prisma enum. And add executionStatus on AgentGraphExecution.
* Use executionStatus from AgentGraphExecution to improve waiting logic on test_manager.py.
This commit is contained in:
Zamil Majdy
2024-09-14 11:47:28 -05:00
committed by GitHub
parent f32244a112
commit c1f301ab8b
17 changed files with 646 additions and 197 deletions

View File

@@ -9,7 +9,8 @@ REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=password
AUTH_ENABLED=false
ENABLE_AUTH=false
ENABLE_CREDIT=false
APP_ENV="local"
PYRO_HOST=localhost
SENTRY_DSN=

View File

@@ -24,6 +24,7 @@ LlmApiKeys = {
class ModelMetadata(NamedTuple):
provider: str
context_window: int
cost_factor: int
class LlmModel(str, Enum):
@@ -55,26 +56,29 @@ class LlmModel(str, Enum):
MODEL_METADATA = {
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000),
LlmModel.GPT4O: ModelMetadata("openai", 128000),
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000),
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385),
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000),
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000),
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192),
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192),
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768),
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192),
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192),
LlmModel.LLAMA3_1_405B: ModelMetadata(
"groq", 8192
), # Limited to 16k during preview
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072),
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=8),
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=14),
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=13),
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=6),
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=9),
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=7),
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=6),
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=7),
LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=10),
# Limited to 16k during preview
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15),
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
}
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class AIStructuredResponseGeneratorBlock(Block):
class Input(BlockSchema):
@@ -301,7 +305,7 @@ class AITextGeneratorBlock(Block):
yield "error", str(e)
class TextSummarizerBlock(Block):
class AITextSummarizerBlock(Block):
class Input(BlockSchema):
text: str
model: LlmModel = LlmModel.GPT4_TURBO
@@ -319,8 +323,8 @@ class TextSummarizerBlock(Block):
id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
description="Utilize a Large Language Model (LLM) to summarize a long text.",
categories={BlockCategory.AI, BlockCategory.TEXT},
input_schema=TextSummarizerBlock.Input,
output_schema=TextSummarizerBlock.Output,
input_schema=AITextSummarizerBlock.Input,
output_schema=AITextSummarizerBlock.Output,
test_input={"text": "Lorem ipsum..." * 100},
test_output=("summary", "Final summary of a long text"),
test_mock={
@@ -412,7 +416,7 @@ class TextSummarizerBlock(Block):
else:
# If combined summaries are still too long, recursively summarize
return self._run(
TextSummarizerBlock.Input(
AITextSummarizerBlock.Input(
text=combined_text,
api_key=input_data.api_key,
model=input_data.model,

View File

@@ -0,0 +1,263 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Optional, Type
import prisma.errors
from prisma import Json
from prisma.enums import UserBlockCreditType
from prisma.models import UserBlockCredit
from pydantic import BaseModel
from autogpt_server.blocks.llm import (
MODEL_METADATA,
AIConversationBlock,
AIStructuredResponseGeneratorBlock,
AITextGeneratorBlock,
AITextSummarizerBlock,
)
from autogpt_server.blocks.talking_head import CreateTalkingAvatarVideoBlock
from autogpt_server.data.block import Block, BlockInput
from autogpt_server.util.settings import Config
class BlockCostType(str, Enum):
    """Unit against which a block's credit cost is charged."""

    RUN = "run"  # cost X credits per run
    BYTE = "byte"  # cost X credits per byte
    SECOND = "second"  # cost X credits per second
class BlockCost(BaseModel):
    """Price of running a block under a given input configuration.

    ``cost_filter`` narrows when this price applies: the cost entry is used
    only when every (key, value) pair in the filter matches the block input.
    """

    cost_amount: int
    cost_filter: BlockInput
    cost_type: BlockCostType

    def __init__(
        self,
        cost_amount: int,
        cost_type: BlockCostType = BlockCostType.RUN,
        cost_filter: Optional[BlockInput] = None,
        **data: Any,
    ) -> None:
        # Normalize the optional filter to an empty dict, then hand all
        # fields to pydantic's validating constructor in one call.
        field_values: dict[str, Any] = {
            "cost_amount": cost_amount,
            "cost_type": cost_type,
            "cost_filter": cost_filter or {},
            **data,
        }
        super().__init__(**field_values)
# One BlockCost entry per LLM model. The filter only matches requests that
# rely on the platform's key (api_key is None).
llm_cost = [
    BlockCost(
        cost_type=BlockCostType.RUN,
        cost_filter={
            "model": model,
            "api_key": None,  # Running LLM with user own API key is free.
        },
        cost_amount=metadata.cost_factor,  # per-model price from MODEL_METADATA
    )
    for model, metadata in MODEL_METADATA.items()
]

# Source of truth for block pricing, exposed via GET /blocks/costs.
# Blocks absent from this mapping cost nothing to run.
BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
    AIConversationBlock: llm_cost,
    AITextGeneratorBlock: llm_cost,
    AIStructuredResponseGeneratorBlock: llm_cost,
    AITextSummarizerBlock: llm_cost,
    CreateTalkingAvatarVideoBlock: [
        BlockCost(cost_amount=15, cost_filter={"api_key": None})
    ],
}
class UserCreditBase(ABC):
    """Interface for tracking and mutating a user's block-execution credit."""

    def __init__(self, num_user_credits_refill: int):
        # Number of credits granted to a user on each refill cycle.
        self.num_user_credits_refill = num_user_credits_refill

    @abstractmethod
    async def get_or_refill_credit(self, user_id: str) -> int:
        """
        Get the current credit for the user and refill if no transaction has been made in the current cycle.

        Returns:
            int: The current credit for the user.
        """
        pass

    @abstractmethod
    async def spend_credits(
        self,
        user_id: str,
        user_credit: int,
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
    ) -> int:
        """
        Spend the credits for the user based on the block usage.

        Args:
            user_id (str): The user ID.
            user_credit (int): The current credit for the user.
            block (Block): The block that is being used.
            input_data (BlockInput): The input data for the block.
            data_size (float): The size of the data being processed.
            run_time (float): The time taken to run the block.

        Returns:
            int: amount of credit spent
        """
        pass

    @abstractmethod
    async def top_up_credits(self, user_id: str, amount: int):
        """
        Top up the credits for the user.

        Args:
            user_id (str): The user ID.
            amount (int): The amount to top up.
        """
        pass
def _next_month_start(month_start: datetime) -> datetime:
    """Return the first instant of the month after *month_start*.

    A plain ``replace(month=month + 1)`` raises ValueError in December
    (month would become 13); roll the year over explicitly instead.
    """
    if month_start.month == 12:
        return month_start.replace(year=month_start.year + 1, month=1)
    return month_start.replace(month=month_start.month + 1)


class UserCredit(UserCreditBase):
    """Prisma-backed credit ledger with an automatic monthly top-up.

    Credits are rows in UserBlockCredit: positive amounts are top-ups,
    negative amounts are usage. The current balance is the sum of active
    rows created within the current calendar month (UTC).
    """

    async def get_or_refill_credit(self, user_id: str) -> int:
        """Return the user's balance for the current month, granting the
        monthly refill once if no transaction exists yet this month.
        """
        cur_time = self.time_now()
        cur_month = cur_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        # BUGFIX: previously `cur_month.replace(month=cur_month.month + 1)`,
        # which raises ValueError every December.
        nxt_month = _next_month_start(cur_month)

        user_credit = await UserBlockCredit.prisma().group_by(
            by=["userId"],
            sum={"amount": True},
            where={
                "userId": user_id,
                "createdAt": {"gte": cur_month, "lt": nxt_month},
                "isActive": True,
            },
        )

        if user_credit:
            credit_sum = user_credit[0].get("_sum") or {}
            return credit_sum.get("amount", 0)

        # No transaction this month yet: grant the refill. The deterministic
        # transaction key makes the insert idempotent per (user, month) via
        # the (transactionKey, userId) primary key.
        key = f"MONTHLY-CREDIT-TOP-UP-{cur_month}"
        try:
            await UserBlockCredit.prisma().create(
                data={
                    "amount": self.num_user_credits_refill,
                    "type": UserBlockCreditType.TOP_UP,
                    "userId": user_id,
                    "transactionKey": key,
                    "createdAt": self.time_now(),
                }
            )
        except prisma.errors.UniqueViolationError:
            pass  # Already refilled this month (e.g. concurrent request).
        return self.num_user_credits_refill

    @staticmethod
    def time_now() -> datetime:
        """Current timezone-aware UTC time."""
        return datetime.now(timezone.utc)

    @staticmethod
    def _block_usage_cost(
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
    ) -> tuple[int, BlockInput]:
        """Resolve (credit cost, matched filter) for one block execution.

        The first BLOCK_COSTS entry whose filter fully matches *input_data*
        wins; blocks without a BLOCK_COSTS entry are free.
        """
        block_costs = BLOCK_COSTS.get(type(block))
        if not block_costs:
            return 0, {}

        for block_cost in block_costs:
            if all(input_data.get(k) == b for k, b in block_cost.cost_filter.items()):
                if block_cost.cost_type == BlockCostType.RUN:
                    return block_cost.cost_amount, block_cost.cost_filter
                if block_cost.cost_type == BlockCostType.SECOND:
                    return (
                        int(run_time * block_cost.cost_amount),
                        block_cost.cost_filter,
                    )
                if block_cost.cost_type == BlockCostType.BYTE:
                    return (
                        int(data_size * block_cost.cost_amount),
                        block_cost.cost_filter,
                    )

        return 0, {}

    async def spend_credits(
        self,
        user_id: str,
        user_credit: int,
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
        validate_balance: bool = True,
    ) -> int:
        """Charge the user for one block execution.

        Returns:
            int: the amount of credit actually charged (0 for free blocks).

        Raises:
            ValueError: if *validate_balance* is True and the balance is
                lower than the computed cost.
        """
        cost, matching_filter = self._block_usage_cost(
            block=block, input_data=input_data, data_size=data_size, run_time=run_time
        )
        if cost <= 0:
            return 0

        if validate_balance and user_credit < cost:
            raise ValueError(f"Insufficient credit: {user_credit} < {cost}")

        await UserBlockCredit.prisma().create(
            data={
                "userId": user_id,
                "amount": -cost,  # usage is stored as a negative ledger entry
                "type": UserBlockCreditType.USAGE,
                "blockId": block.id,
                "metadata": Json(
                    {
                        "block": block.name,
                        "input": matching_filter,
                    }
                ),
                "createdAt": self.time_now(),
            }
        )
        return cost

    async def top_up_credits(self, user_id: str, amount: int):
        """Record a manual top-up of *amount* credits for the user.

        NOTE(review): unlike the monthly refill, no transactionKey is set
        here; confirm the Prisma schema provides a default for that column,
        since the migration declares it NOT NULL and part of the primary key.
        """
        await UserBlockCredit.prisma().create(
            data={
                "userId": user_id,
                "amount": amount,
                "type": UserBlockCreditType.TOP_UP,
                "createdAt": self.time_now(),
            }
        )
class DisabledUserCredit(UserCreditBase):
    """No-op credit backend used when ENABLE_CREDIT is off: everything is
    free and no transactions are recorded."""

    async def get_or_refill_credit(self, *args, **kwargs) -> int:
        return 0

    async def spend_credits(self, *args, **kwargs) -> int:
        return 0

    async def top_up_credits(self, *args, **kwargs):
        pass
def get_user_credit_model() -> UserCreditBase:
    """Select the credit backend based on the ENABLE_CREDIT config flag."""
    config = Config()
    # The flag is a string field; anything other than "true" disables credit.
    if config.enable_credit.lower() != "true":
        return DisabledUserCredit(0)
    return UserCredit(config.num_user_credits_refill)
def get_block_costs() -> dict[str, list[BlockCost]]:
    """Map each priced block's ID to its cost entries (for GET /blocks/costs)."""
    costs_by_block_id: dict[str, list[BlockCost]] = {}
    for block_cls, costs in BLOCK_COSTS.items():
        # Instantiate the block only to read its stable `id`.
        costs_by_block_id[block_cls().id] = costs
    return costs_by_block_id

View File

@@ -1,9 +1,9 @@
from collections import defaultdict
from datetime import datetime, timezone
from enum import Enum
from multiprocessing import Manager
from typing import Any, Generic, TypeVar
from prisma.enums import AgentExecutionStatus
from prisma.models import (
AgentGraphExecution,
AgentNodeExecution,
@@ -21,12 +21,14 @@ from autogpt_server.util import json, mock
class GraphExecution(BaseModel):
user_id: str
graph_exec_id: str
start_node_execs: list["NodeExecution"]
graph_id: str
start_node_execs: list["NodeExecution"]
class NodeExecution(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
node_exec_id: str
@@ -34,13 +36,7 @@ class NodeExecution(BaseModel):
data: BlockInput
class ExecutionStatus(str, Enum):
INCOMPLETE = "INCOMPLETE"
QUEUED = "QUEUED"
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
ExecutionStatus = AgentExecutionStatus
T = TypeVar("T")
@@ -148,6 +144,7 @@ async def create_graph_execution(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.QUEUED,
"AgentNodeExecutions": {
"create": [ # type: ignore
{
@@ -259,10 +256,20 @@ async def upsert_execution_output(
)
async def update_graph_execution_start_time(graph_exec_id: str):
await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
)
async def update_graph_execution_stats(graph_exec_id: str, stats: dict[str, Any]):
await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={"stats": json.dumps(stats)},
data={"executionStatus": ExecutionStatus.COMPLETED, "stats": json.dumps(stats)},
)

View File

@@ -17,8 +17,10 @@ if TYPE_CHECKING:
from autogpt_server.blocks.basic import AgentInputBlock
from autogpt_server.data import db
from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
from autogpt_server.data.credit import get_user_credit_model
from autogpt_server.data.execution import (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecution,
NodeExecution,
@@ -45,25 +47,41 @@ from autogpt_server.util.type import convert
logger = logging.getLogger(__name__)
def get_log_metadata(
graph_eid: str,
graph_id: str,
node_eid: str,
node_id: str,
block_name: str,
) -> dict:
return {
"component": "ExecutionManager",
"graph_eid": graph_eid,
"graph_id": graph_id,
"node_eid": node_eid,
"node_id": node_id,
"block_name": block_name,
}
class LogMetadata:
def __init__(
self,
user_id: str,
graph_eid: str,
graph_id: str,
node_eid: str,
node_id: str,
block_name: str,
):
self.metadata = {
"component": "ExecutionManager",
"user_id": user_id,
"graph_eid": graph_eid,
"graph_id": graph_id,
"node_eid": node_eid,
"node_id": node_id,
"block_name": block_name,
}
self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|nid:{node_eid}|{block_name}]"
def info(self, msg: str, **extra):
logger.info(msg, extra={"json_fields": {**self.metadata, **extra}})
def get_log_prefix(graph_eid: str, node_eid: str, block_name: str = "-"):
return f"[ExecutionManager][graph-eid-{graph_eid}|node-eid-{node_eid}|{block_name}]"
def warning(self, msg: str, **extra):
logger.warning(msg, extra={"json_fields": {**self.metadata, **extra}})
def error(self, msg: str, **extra):
logger.error(msg, extra={"json_fields": {**self.metadata, **extra}})
def debug(self, msg: str, **extra):
logger.debug(msg, extra={"json_fields": {**self.metadata, **extra}})
def exception(self, msg: str, **extra):
logger.exception(msg, extra={"json_fields": {**self.metadata, **extra}})
T = TypeVar("T")
@@ -89,6 +107,7 @@ def execute_node(
Returns:
The subsequent node to be enqueued, or None if there is no subsequent node.
"""
user_id = data.user_id
graph_exec_id = data.graph_exec_id
graph_id = data.graph_id
node_exec_id = data.node_exec_id
@@ -99,9 +118,10 @@ def execute_node(
def wait(f: Coroutine[Any, Any, T]) -> T:
return loop.run_until_complete(f)
def update_execution(status: ExecutionStatus):
def update_execution(status: ExecutionStatus) -> ExecutionResult:
exec_update = wait(update_execution_status(node_exec_id, status))
api_client.send_execution_update(exec_update.model_dump())
return exec_update
node = wait(get_node(node_id))
@@ -111,43 +131,35 @@ def execute_node(
return
# Sanity check: validate the execution input.
log_metadata = get_log_metadata(
log_metadata = LogMetadata(
user_id=user_id,
graph_eid=graph_exec_id,
graph_id=graph_id,
node_eid=node_exec_id,
node_id=node_id,
block_name=node_block.name,
)
prefix = get_log_prefix(
graph_eid=graph_exec_id,
node_eid=node_exec_id,
block_name=node_block.name,
)
input_data, error = validate_exec(node, data.data, resolve_input=False)
if input_data is None:
logger.error(
"{prefix} Skip execution, input validation error",
extra={"json_fields": {**log_metadata, "error": error}},
)
log_metadata.error(f"Skip execution, input validation error: {error}")
return
# Execute the node
input_data_str = json.dumps(input_data)
input_size = len(input_data_str)
logger.info(
f"{prefix} Executed node with input",
extra={"json_fields": {**log_metadata, "input": input_data_str}},
)
log_metadata.info("Executed node with input", input=input_data_str)
update_execution(ExecutionStatus.RUNNING)
user_credit = get_user_credit_model()
output_size = 0
try:
credit = wait(user_credit.get_or_refill_credit(user_id))
if credit < 0:
raise ValueError("Insufficient credit: {credit}")
for output_name, output_data in node_block.execute(input_data):
output_size += len(json.dumps(output_data))
logger.info(
f"{prefix} Node produced output",
extra={"json_fields": {**log_metadata, output_name: output_data}},
)
log_metadata.info("Node produced output", output_name=output_data)
wait(upsert_execution_output(node_exec_id, output_name, output_data))
for execution in _enqueue_next_nodes(
@@ -155,20 +167,25 @@ def execute_node(
loop=loop,
node=node,
output=(output_name, output_data),
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
log_metadata=log_metadata,
):
yield execution
update_execution(ExecutionStatus.COMPLETED)
r = update_execution(ExecutionStatus.COMPLETED)
s = input_size + output_size
t = (
(r.end_time - r.start_time).total_seconds()
if r.end_time and r.start_time
else 0
)
wait(user_credit.spend_credits(user_id, credit, node_block, input_data, s, t))
except Exception as e:
error_msg = f"{e.__class__.__name__}: {e}"
logger.exception(
f"{prefix} Node execution failed with error",
extra={"json_fields": {**log_metadata, error: error_msg}},
)
error_msg = str(e)
log_metadata.exception(f"Node execution failed with error {error_msg}")
wait(upsert_execution_output(node_exec_id, "error", error_msg))
update_execution(ExecutionStatus.FAILED)
@@ -194,9 +211,10 @@ def _enqueue_next_nodes(
loop: asyncio.AbstractEventLoop,
node: Node,
output: BlockData,
user_id: str,
graph_exec_id: str,
graph_id: str,
log_metadata: dict,
log_metadata: LogMetadata,
) -> list[NodeExecution]:
def wait(f: Coroutine[Any, Any, T]) -> T:
return loop.run_until_complete(f)
@@ -209,6 +227,7 @@ def _enqueue_next_nodes(
)
api_client.send_execution_update(exec_update.model_dump())
return NodeExecution(
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
node_exec_id=node_exec_id,
@@ -262,17 +281,11 @@ def _enqueue_next_nodes(
# Incomplete input data, skip queueing the execution.
if not next_node_input:
logger.warning(
f"Skipped queueing {suffix}",
extra={"json_fields": {**log_metadata}},
)
log_metadata.warning(f"Skipped queueing {suffix}")
return enqueued_executions
# Input is complete, enqueue the execution.
logger.info(
f"Enqueued {suffix}",
extra={"json_fields": {**log_metadata}},
)
log_metadata.info(f"Enqueued {suffix}")
enqueued_executions.append(
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
)
@@ -298,11 +311,9 @@ def _enqueue_next_nodes(
idata, msg = validate_exec(next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
if not idata:
logger.info(
f"{log_metadata} Enqueueing static-link skipped: {suffix}"
)
log_metadata.info(f"Enqueueing static-link skipped: {suffix}")
continue
logger.info(f"{log_metadata} Enqueueing static-link execution {suffix}")
log_metadata.info(f"Enqueueing static-link execution {suffix}")
enqueued_executions.append(
add_enqueued_execution(iexec.node_exec_id, next_node_id, idata)
)
@@ -443,22 +454,18 @@ class Executor:
def on_node_execution(
cls, q: ExecutionQueue[NodeExecution], node_exec: NodeExecution
):
log_metadata = get_log_metadata(
log_metadata = LogMetadata(
user_id=node_exec.user_id,
graph_eid=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_eid=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_name="-",
)
prefix = get_log_prefix(
graph_eid=node_exec.graph_exec_id,
node_eid=node_exec.node_exec_id,
block_name="-",
)
execution_stats = {}
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, prefix, execution_stats
q, node_exec, log_metadata, execution_stats
)
execution_stats["walltime"] = timing_info.wall_time
execution_stats["cputime"] = timing_info.cpu_time
@@ -473,29 +480,19 @@ class Executor:
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
log_metadata: dict,
prefix: str,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
):
try:
logger.info(
f"{prefix} Start node execution {node_exec.node_exec_id}",
extra={"json_fields": {**log_metadata}},
)
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
for execution in execute_node(
cls.loop, cls.agent_server_client, node_exec, stats
):
q.add(execution)
logger.info(
f"{prefix} Finished node execution {node_exec.node_exec_id}",
extra={"json_fields": {**log_metadata}},
)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
except Exception as e:
logger.exception(
f"Failed node execution {node_exec.node_exec_id}: {e}",
extra={
**log_metadata,
},
log_metadata.exception(
f"Failed node execution {node_exec.node_exec_id}: {e}"
)
@classmethod
@@ -517,10 +514,12 @@ class Executor:
@classmethod
def on_graph_executor_stop(cls):
logger.info(
f"[on_graph_executor_stop {cls.pid}]Terminating node executor pool..."
)
prefix = f"[on_graph_executor_stop {cls.pid}]"
logger.info(f"{prefix}Disconnecting DB...")
cls.loop.run_until_complete(db.disconnect())
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
cls.executor.terminate()
logger.info(f"{prefix} ✅ Finished cleanup")
@classmethod
def _init_node_executor_pool(cls):
@@ -532,20 +531,16 @@ class Executor:
@classmethod
@error_logged
def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
log_metadata = get_log_metadata(
log_metadata = LogMetadata(
user_id=graph_exec.user_id,
graph_eid=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
node_id="*",
node_eid="*",
block_name="-",
)
prefix = get_log_prefix(
graph_eid=graph_exec.graph_exec_id,
node_eid="*",
block_name="-",
)
timing_info, node_count = cls._on_graph_execution(
graph_exec, cancel, log_metadata, prefix
graph_exec, cancel, log_metadata
)
cls.loop.run_until_complete(
@@ -565,13 +560,9 @@ class Executor:
cls,
graph_exec: GraphExecution,
cancel: threading.Event,
log_metadata: dict,
prefix: str,
log_metadata: LogMetadata,
) -> int:
logger.info(
f"{prefix} Start graph execution {graph_exec.graph_exec_id}",
extra={"json_fields": {**log_metadata}},
)
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
n_node_executions = 0
finished = False
@@ -581,10 +572,7 @@ class Executor:
if finished:
return
cls.executor.terminate()
logger.info(
f"{prefix} Terminated graph execution {graph_exec.graph_exec_id}",
extra={"json_fields": {**log_metadata}},
)
log_metadata.info(f"Terminated graph execution {graph_exec.graph_exec_id}")
cls._init_node_executor_pool()
cancel_thread = threading.Thread(target=cancel_handler)
@@ -622,10 +610,9 @@ class Executor:
# Re-enqueueing the data back to the queue will disrupt the order.
execution.wait()
logger.debug(
f"{prefix} Dispatching node execution {exec_data.node_exec_id} "
log_metadata.debug(
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
extra={**log_metadata},
)
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
@@ -635,10 +622,8 @@ class Executor:
# Avoid terminating graph execution when some nodes are still running.
while queue.empty() and running_executions:
logger.debug(
"Queue empty; running nodes: "
f"{list(running_executions.keys())}",
extra={"json_fields": {**log_metadata}},
log_metadata.debug(
f"Queue empty; running nodes: {list(running_executions.keys())}"
)
for node_id, execution in list(running_executions.items()):
if cancel.is_set():
@@ -647,20 +632,13 @@ class Executor:
if not queue.empty():
break # yield to parent loop to execute new queue items
logger.debug(
f"Waiting on execution of node {node_id}",
extra={"json_fields": {**log_metadata}},
)
log_metadata.debug(f"Waiting on execution of node {node_id}")
execution.wait(3)
logger.info(
f"{prefix} Finished graph execution {graph_exec.graph_exec_id}",
extra={"json_fields": {**log_metadata}},
)
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
except Exception as e:
logger.exception(
f"{prefix} Failed graph execution {graph_exec.graph_exec_id}: {e}",
extra={"json_fields": {**log_metadata}},
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
)
finally:
if not cancel.is_set():
@@ -747,6 +725,7 @@ class ExecutionManager(AppService):
for node_exec in node_execs:
starting_node_execs.append(
NodeExecution(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
@@ -762,6 +741,7 @@ class ExecutionManager(AppService):
self.agent_server_client.send_execution_update(exec_update.model_dump())
graph_exec = GraphExecution(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
start_node_execs=starting_node_execs,

View File

@@ -15,6 +15,7 @@ from autogpt_server.data import execution as execution_db
from autogpt_server.data import graph as graph_db
from autogpt_server.data import user as user_db
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
from autogpt_server.data.credit import get_block_costs, get_user_credit_model
from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
from autogpt_server.data.user import get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
@@ -32,6 +33,7 @@ class AgentServer(AppService):
mutex = KeyedMutex()
use_redis = True
_test_dependency_overrides = {}
_user_credit_model = get_user_credit_model()
def __init__(self, event_queue: AsyncEventQueue | None = None):
super().__init__(port=Config().agent_server_port)
@@ -91,6 +93,11 @@ class AgentServer(AppService):
endpoint=self.get_graph_blocks,
methods=["GET"],
)
api_router.add_api_route(
path="/blocks/costs",
endpoint=self.get_graph_block_costs,
methods=["GET"],
)
api_router.add_api_route(
path="/blocks/{block_id}/execute",
endpoint=self.execute_graph_block,
@@ -196,6 +203,11 @@ class AgentServer(AppService):
endpoint=self.update_schedule,
methods=["PUT"],
)
api_router.add_api_route(
path="/credits",
endpoint=self.get_user_credits,
methods=["GET"],
)
api_router.add_api_route(
path="/settings",
@@ -265,6 +277,10 @@ class AgentServer(AppService):
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
return [v.to_dict() for v in block.get_blocks().values()]
@classmethod
def get_graph_block_costs(cls) -> dict[Any, Any]:
return get_block_costs()
@classmethod
def execute_graph_block(
cls, block_id: str, data: BlockInput
@@ -481,6 +497,25 @@ class AgentServer(AppService):
return await execution_db.list_executions(graph_id, graph_version)
@classmethod
async def get_graph_run_status(
cls,
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.ExecutionStatus:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
if not execution:
raise HTTPException(
status_code=404, detail=f"Execution #{graph_exec_id} not found."
)
return execution.executionStatus
@classmethod
async def get_graph_run_node_execution_results(
cls,
@@ -522,6 +557,11 @@ class AgentServer(AppService):
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id)
return {"id": schedule_id}
async def get_user_credits(
self, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, int]:
return {"credits": await self._user_credit_model.get_or_refill_credit(user_id)}
def get_execution_schedules(
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, str]:

View File

@@ -252,7 +252,6 @@ Here are a couple of sample of the Block class implementation:
async def block_autogen_agent():
async with SpinTestServer() as server:
test_manager = server.exec_manager
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
@@ -261,10 +260,8 @@ async def block_autogen_agent():
)
print(response)
result = await wait_execution(
exec_manager=test_manager,
graph_id=test_graph.id,
graph_exec_id=response["id"],
num_execs=10,
timeout=1200,
user_id=test_user.id,
)

View File

@@ -153,7 +153,6 @@ async def create_test_user() -> User:
async def reddit_marketing_agent():
async with SpinTestServer() as server:
exec_man = server.exec_manager
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
@@ -161,9 +160,7 @@ async def reddit_marketing_agent():
test_graph.id, input_data, test_user.id
)
print(response)
result = await wait_execution(
exec_man, test_user.id, test_graph.id, response["id"], 13, 120
)
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)
print(result)

View File

@@ -75,7 +75,6 @@ def create_test_graph() -> graph.Graph:
async def sample_agent():
async with SpinTestServer() as server:
exec_man = server.exec_manager
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
@@ -83,9 +82,7 @@ async def sample_agent():
test_graph.id, input_data, test_user.id
)
print(response)
result = await wait_execution(
exec_man, test_user.id, test_graph.id, response["id"], 4, 10
)
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)
print(result)

View File

@@ -42,15 +42,15 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
"""Config for the server."""
num_graph_workers: int = Field(
default=1,
default=10,
ge=1,
le=100,
le=1000,
description="Maximum number of workers to use for graph execution.",
)
num_node_workers: int = Field(
default=1,
default=5,
ge=1,
le=100,
le=1000,
description="Maximum number of workers to use for node execution within a single graph.",
)
pyro_host: str = Field(
@@ -61,6 +61,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="false",
description="If authentication is enabled or not",
)
enable_credit: str = Field(
default="false",
description="If user credit system is enabled or not",
)
num_user_credits_refill: int = Field(
default=1500,
description="Number of credits to refill for each user",
)
# Add more configuration fields as needed
model_config = SettingsConfigDict(

View File

@@ -5,6 +5,7 @@ from autogpt_server.data import db
from autogpt_server.data.block import Block, initialize_blocks
from autogpt_server.data.execution import ExecutionResult, ExecutionStatus
from autogpt_server.data.queue import AsyncEventQueue
from autogpt_server.data.user import create_default_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server import AgentServer
from autogpt_server.server.rest_api import get_user_id
@@ -64,6 +65,7 @@ class SpinTestServer:
await db.connect()
await initialize_blocks()
await create_default_user("false")
return self
@@ -82,25 +84,18 @@ class SpinTestServer:
async def wait_execution(
exec_manager: ExecutionManager,
user_id: str,
graph_id: str,
graph_exec_id: str,
num_execs: int,
timeout: int = 20,
) -> list:
async def is_execution_completed():
execs = await AgentServer().get_graph_run_node_execution_results(
status = await AgentServer().get_graph_run_status(
graph_id, graph_exec_id, user_id
)
return (
exec_manager.queue.empty()
and len(execs) == num_execs
and all(
v.status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]
for v in execs
)
)
if status == ExecutionStatus.FAILED:
raise Exception("Execution failed")
return status == ExecutionStatus.COMPLETED
# Wait for the executions to complete
for i in range(timeout):

View File

@@ -1,4 +1,5 @@
{
"num_graph_workers": 10,
"num_node_workers": 5
"num_node_workers": 5,
"num_user_credits_refill": 1500
}

View File

@@ -0,0 +1,39 @@
/*
Warnings:
- The `executionStatus` column on the `AgentNodeExecution` table would be dropped and recreated. This will lead to data loss if there is data in the column.
*/
-- CreateEnum
-- Shared status enum for both graph-level and node-level executions.
CREATE TYPE "AgentExecutionStatus" AS ENUM ('INCOMPLETE', 'QUEUED', 'RUNNING', 'COMPLETED', 'FAILED');
-- CreateEnum
-- Credit transaction direction: TOP_UP adds credits, USAGE deducts them.
CREATE TYPE "UserBlockCreditType" AS ENUM ('TOP_UP', 'USAGE');
-- AlterTable
-- Track overall run status and start time on the graph execution itself.
-- Default COMPLETED so pre-existing rows (which all finished) stay consistent.
ALTER TABLE "AgentGraphExecution" ADD COLUMN "executionStatus" "AgentExecutionStatus" NOT NULL DEFAULT 'COMPLETED',
ADD COLUMN "startedAt" TIMESTAMP(3);
-- AlterTable
-- Convert executionStatus from free-form text to the enum; the old column is
-- dropped and recreated, so any prior values are lost (see warning above).
ALTER TABLE "AgentNodeExecution" DROP COLUMN "executionStatus",
ADD COLUMN "executionStatus" "AgentExecutionStatus" NOT NULL DEFAULT 'COMPLETED';
-- CreateTable
-- Ledger of per-user credit transactions, optionally tied to the block whose
-- execution was charged. Composite key allows the same transactionKey across users.
CREATE TABLE "UserBlockCredit" (
"transactionKey" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"blockId" TEXT,
"amount" INTEGER NOT NULL,
"type" "UserBlockCreditType" NOT NULL,
"isActive" BOOLEAN NOT NULL DEFAULT true,
"metadata" JSONB,
CONSTRAINT "UserBlockCredit_pkey" PRIMARY KEY ("transactionKey","userId")
);
-- AddForeignKey
-- RESTRICT on user delete: a user with credit history cannot be removed silently.
ALTER TABLE "UserBlockCredit" ADD CONSTRAINT "UserBlockCredit_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
-- SET NULL on block delete: keep the ledger row but detach it from the block.
ALTER TABLE "UserBlockCredit" ADD CONSTRAINT "UserBlockCredit_blockId_fkey" FOREIGN KEY ("blockId") REFERENCES "AgentBlock"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -22,6 +22,7 @@ model User {
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
UserBlockCredit UserBlockCredit[]
@@index([id])
@@index([email])
@@ -29,9 +30,9 @@ model User {
// This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph {
id String @default(uuid())
version Int @default(1)
createdAt DateTime @default(now())
id String @default(uuid())
version Int @default(1)
createdAt DateTime @default(now())
updatedAt DateTime? @updatedAt
name String?
@@ -111,13 +112,26 @@ model AgentBlock {
// Prisma requires explicit back-references.
ReferencedByAgentNode AgentNode[]
UserBlockCredit UserBlockCredit[]
}
// This enum describes the status of an AgentGraphExecution or AgentNodeExecution.
enum AgentExecutionStatus {
INCOMPLETE
QUEUED
RUNNING
COMPLETED
FAILED
}
// This model describes the execution of an AgentGraph.
model AgentGraphExecution {
id String @id @default(uuid())
createdAt DateTime @default(now())
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime? @updatedAt
startedAt DateTime?
executionStatus AgentExecutionStatus @default(COMPLETED)
agentGraphId String
agentGraphVersion Int @default(1)
@@ -145,12 +159,10 @@ model AgentNodeExecution {
Input AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionInput")
Output AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionOutput")
// sqlite does not support enum
// enum Status { INCOMPLETE, QUEUED, RUNNING, SUCCESS, FAILED }
executionStatus String
executionStatus AgentExecutionStatus @default(COMPLETED)
// Final JSON serialized input data for the node execution.
executionData String?
addedTime DateTime @default(now())
addedTime DateTime @default(now())
queuedTime DateTime?
startedTime DateTime?
endedTime DateTime?
@@ -178,8 +190,8 @@ model AgentNodeExecutionInputOutput {
// This model describes the recurring execution schedule of an Agent.
model AgentGraphExecutionSchedule {
id String @id
createdAt DateTime @default(now())
id String @id
createdAt DateTime @default(now())
updatedAt DateTime? @updatedAt
agentGraphId String
@@ -199,3 +211,27 @@ model AgentGraphExecutionSchedule {
@@index([isEnabled])
}
// Direction of a credit transaction: TOP_UP adds credits, USAGE deducts them.
enum UserBlockCreditType {
TOP_UP
USAGE
}
// Ledger of per-user credit transactions for block executions.
// Each row is one transaction; balances are derived by summing `amount`.
model UserBlockCredit {
transactionKey String @default(uuid())
createdAt DateTime @default(now())
userId String
user User @relation(fields: [userId], references: [id])
// Optional link to the block whose execution was charged (null for top-ups).
blockId String?
block AgentBlock? @relation(fields: [blockId], references: [id])
// Signed credit amount; sign convention per `type` — confirm against UserCredit logic.
amount Int
type UserBlockCreditType
// Soft-disable flag so transactions can be excluded without deletion.
isActive Boolean @default(true)
metadata Json?
// Composite key: transactionKey is only unique per user.
@@id(name: "creditTransactionIdentifier", [transactionKey, userId])
}

View File

@@ -0,0 +1,90 @@
from datetime import datetime
import pytest
from prisma.models import UserBlockCredit
from autogpt_server.blocks.llm import AITextGeneratorBlock
from autogpt_server.data.credit import UserCredit
from autogpt_server.data.user import DEFAULT_USER_ID
from autogpt_server.util.test import SpinTestServer
# Monthly refill quota used by these tests; test_credit_refill asserts this exact value.
REFILL_VALUE = 1000
# Shared credit-accounting instance under test, configured with the test refill amount.
user_credit = UserCredit(REFILL_VALUE)
@pytest.mark.asyncio(scope="session")
async def test_block_credit_usage(server: SpinTestServer):
    """Spending credits charges for platform-key usage, is free with a user-owned
    API key, and the final balance reflects exactly what was spent."""
    balance_before = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)

    # Using the platform's key: the block execution must cost something.
    cost_platform_key = await user_credit.spend_credits(
        DEFAULT_USER_ID,
        balance_before,
        AITextGeneratorBlock(),
        {"model": "gpt-4-turbo"},
        0.0,
        0.0,
        validate_balance=False,
    )
    assert cost_platform_key > 0

    # Bringing your own API key: the same block execution must be free.
    cost_own_key = await user_credit.spend_credits(
        DEFAULT_USER_ID,
        balance_before,
        AITextGeneratorBlock(),
        {"model": "gpt-4-turbo", "api_key": "owned_api_key"},
        0.0,
        0.0,
        validate_balance=False,
    )
    assert cost_own_key == 0

    # The ledger balance must drop by exactly the sum of both charges.
    balance_after = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
    assert balance_after == balance_before - cost_platform_key - cost_own_key
@pytest.mark.asyncio(scope="session")
async def test_block_credit_top_up(server: SpinTestServer):
    """Topping up increases the user's balance by exactly the amount added."""
    balance_before = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)

    await user_credit.top_up_credits(DEFAULT_USER_ID, 100)

    balance_after = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
    assert balance_after == balance_before + 100
@pytest.mark.asyncio(scope="session")
async def test_block_credit_reset(server: SpinTestServer):
    """Balances are scoped per calendar month: a top-up made in one month must
    not leak into another month's balance."""
    january = datetime(2022, 1, 15)
    february = datetime(2022, 2, 15)

    def freeze_time(moment):
        # Pin the credit system's clock so each lookup lands in a known month.
        user_credit.time_now = lambda: moment

    freeze_time(february)
    february_balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)

    # A top-up made while the clock reads January affects only January.
    freeze_time(january)
    january_balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
    await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
    assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == january_balance + 100

    # Switching back to February shows a balance untouched by January's top-up.
    freeze_time(february)
    assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == february_balance
@pytest.mark.asyncio(scope="session")
async def test_credit_refill(server: SpinTestServer):
    """With no active transactions in the current month, the balance is refilled
    to exactly REFILL_VALUE."""
    # Deactivate every February 2022 transaction so the month starts from a
    # clean slate; only active transactions count toward the balance.
    february_window = {
        "gte": datetime(2022, 2, 1),
        "lt": datetime(2022, 3, 1),
    }
    await UserBlockCredit.prisma().update_many(
        where={
            "userId": DEFAULT_USER_ID,
            "createdAt": february_window,
        },
        data={"isActive": False},
    )

    # Pin the clock inside the cleared month and trigger the refill.
    user_credit.time_now = lambda: datetime(2022, 2, 15)
    assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == REFILL_VALUE

View File

@@ -4,7 +4,7 @@ import pytest
from autogpt_server.blocks.basic import AgentInputBlock, StoreValueBlock
from autogpt_server.data.graph import Graph, Link, Node
from autogpt_server.data.user import DEFAULT_USER_ID, create_default_user
from autogpt_server.data.user import DEFAULT_USER_ID
from autogpt_server.server.model import CreateGraph
from autogpt_server.util.test import SpinTestServer
@@ -22,8 +22,6 @@ async def test_graph_creation(server: SpinTestServer):
Args:
server (SpinTestServer): The test server instance.
"""
await create_default_user("false")
value_block = StoreValueBlock().id
input_block = AgentInputBlock().id

View File

@@ -4,7 +4,6 @@ from prisma.models import User
from autogpt_server.blocks.basic import FindInDictionaryBlock, StoreValueBlock
from autogpt_server.blocks.maths import CalculatorBlock, Operation
from autogpt_server.data import execution, graph
from autogpt_server.executor import ExecutionManager
from autogpt_server.server import AgentServer
from autogpt_server.usecases.sample import create_test_graph, create_test_user
from autogpt_server.util.test import SpinTestServer, wait_execution
@@ -12,7 +11,6 @@ from autogpt_server.util.test import SpinTestServer, wait_execution
async def execute_graph(
agent_server: AgentServer,
test_manager: ExecutionManager,
test_graph: graph.Graph,
test_user: User,
input_data: dict,
@@ -23,9 +21,8 @@ async def execute_graph(
graph_exec_id = response["id"]
# Execution queue should be empty
assert await wait_execution(
test_manager, test_user.id, test_graph.id, graph_exec_id, num_execs
)
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
assert result and len(result) == num_execs
return graph_exec_id
@@ -108,7 +105,6 @@ async def test_agent_execution(server: SpinTestServer):
data = {"input_1": "Hello", "input_2": "World"}
graph_exec_id = await execute_graph(
server.agent_server,
server.exec_manager,
test_graph,
test_user,
data,
@@ -169,7 +165,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
test_user = await create_test_user()
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
server.agent_server, test_graph, test_user, {}, 3
)
executions = await server.agent_server.get_graph_run_node_execution_results(
@@ -250,7 +246,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
test_user = await create_test_user()
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
server.agent_server, test_graph, test_user, {}, 8
)
executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id