feat(backend): Track LLM token usage + LLM blocks cleanup (#8367)

Author: Zamil Majdy
Date: 2024-10-23 02:13:27 +03:00
Committed by: GitHub
Parent: 0c517216df
Commit: f4dac22335
7 changed files with 195 additions and 133 deletions

View File

@@ -2,6 +2,7 @@ import importlib
 import os
 import re
 from pathlib import Path
+from typing import Type, TypeVar

 from backend.data.block import Block
@@ -24,28 +25,31 @@ for module in modules:
     AVAILABLE_MODULES.append(module)

 # Load all Block instances from the available modules
-AVAILABLE_BLOCKS = {}
+AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}

-def all_subclasses(clz):
-    subclasses = clz.__subclasses__()
+T = TypeVar("T")
+
+
+def all_subclasses(cls: Type[T]) -> list[Type[T]]:
+    subclasses = cls.__subclasses__()
     for subclass in subclasses:
         subclasses += all_subclasses(subclass)
     return subclasses

-for cls in all_subclasses(Block):
-    name = cls.__name__
+for block_cls in all_subclasses(Block):
+    name = block_cls.__name__

-    if cls.__name__.endswith("Base"):
+    if block_cls.__name__.endswith("Base"):
         continue

-    if not cls.__name__.endswith("Block"):
+    if not block_cls.__name__.endswith("Block"):
         raise ValueError(
-            f"Block class {cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
+            f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
         )

-    block = cls()
+    block = block_cls.create()

     if not isinstance(block.id, str) or len(block.id) != 36:
         raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
@@ -87,6 +91,6 @@ for cls in all_subclasses(Block):
     if block.disabled:
         continue

-    AVAILABLE_BLOCKS[block.id] = block
+    AVAILABLE_BLOCKS[block.id] = block_cls

 __all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]
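Note: a minimal sketch (not the repo's code) of why this file now registers block *classes* instead of instances. A shared singleton would carry per-execution state such as the new `execution_stats` dict across unrelated runs; instantiating per use starts from a clean slate. `FakeBlock` and the block ID below are hypothetical stand-ins.

```python
class FakeBlock:  # hypothetical stand-in for a Block subclass
    def __init__(self) -> None:
        self.execution_stats: dict = {}  # fresh, per-instance stats

AVAILABLE: dict[str, type[FakeBlock]] = {"some-block-id": FakeBlock}

block = AVAILABLE["some-block-id"]()  # instantiate per execution
assert block.execution_stats == {}    # stats always start empty
```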

View File

@@ -122,6 +122,17 @@ for model in LlmModel:
         raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")


+class MessageRole(str, Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+
+class Message(BlockSchema):
+    role: MessageRole
+    content: str
+
+
 class AIStructuredResponseGeneratorBlock(Block):
     class Input(BlockSchema):
         prompt: str = SchemaField(
@@ -144,6 +155,10 @@ class AIStructuredResponseGeneratorBlock(Block):
             default="",
             description="The system prompt to provide additional context to the model.",
         )
+        conversation_history: list[Message] = SchemaField(
+            default=[],
+            description="The conversation history to provide context for the prompt.",
+        )
         retry: int = SchemaField(
             title="Retry Count",
             default=3,
@@ -152,6 +167,11 @@ class AIStructuredResponseGeneratorBlock(Block):
         prompt_values: dict[str, str] = SchemaField(
             advanced=False, default={}, description="Values used to fill in the prompt."
         )
+        max_tokens: int | None = SchemaField(
+            advanced=True,
+            default=None,
+            description="The maximum number of tokens to generate in the chat completion.",
+        )

     class Output(BlockSchema):
         response: dict[str, Any] = SchemaField(
@@ -177,26 +197,47 @@ class AIStructuredResponseGeneratorBlock(Block):
             },
             test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
             test_mock={
-                "llm_call": lambda *args, **kwargs: json.dumps(
-                    {
-                        "key1": "key1Value",
-                        "key2": "key2Value",
-                    }
+                "llm_call": lambda *args, **kwargs: (
+                    json.dumps(
+                        {
+                            "key1": "key1Value",
+                            "key2": "key2Value",
+                        }
+                    ),
+                    0,
+                    0,
                 )
             },
         )

     @staticmethod
     def llm_call(
-        api_key: str, model: LlmModel, prompt: list[dict], json_format: bool
-    ) -> str:
-        provider = model.metadata.provider
+        api_key: str,
+        llm_model: LlmModel,
+        prompt: list[dict],
+        json_format: bool,
+        max_tokens: int | None = None,
+    ) -> tuple[str, int, int]:
+        """
+        Args:
+            api_key: API key for the LLM provider.
+            llm_model: The LLM model to use.
+            prompt: The prompt to send to the LLM.
+            json_format: Whether the response should be in JSON format.
+            max_tokens: The maximum number of tokens to generate in the chat completion.
+
+        Returns:
+            The response from the LLM.
+            The number of tokens used in the prompt.
+            The number of tokens used in the completion.
+        """
+        provider = llm_model.metadata.provider

         if provider == "openai":
             openai.api_key = api_key
             response_format = None

-            if model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
+            if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
                 sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
                 usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
                 prompt = [
@@ -207,11 +248,17 @@ class AIStructuredResponseGeneratorBlock(Block):
                 response_format = {"type": "json_object"}

             response = openai.chat.completions.create(
-                model=model.value,
+                model=llm_model.value,
                 messages=prompt,  # type: ignore
                 response_format=response_format,  # type: ignore
+                max_completion_tokens=max_tokens,
             )
-            return response.choices[0].message.content or ""
+            return (
+                response.choices[0].message.content or "",
+                response.usage.prompt_tokens if response.usage else 0,
+                response.usage.completion_tokens if response.usage else 0,
+            )
         elif provider == "anthropic":
             system_messages = [p["content"] for p in prompt if p["role"] == "system"]
             sysprompt = " ".join(system_messages)
@@ -229,13 +276,18 @@ class AIStructuredResponseGeneratorBlock(Block):
             client = anthropic.Anthropic(api_key=api_key)
             try:
-                response = client.messages.create(
-                    model=model.value,
-                    max_tokens=4096,
+                resp = client.messages.create(
+                    model=llm_model.value,
                     system=sysprompt,
                     messages=messages,
+                    max_tokens=max_tokens or 8192,
                 )
-                return response.content[0].text if response.content else ""
+                return (
+                    resp.content[0].text if resp.content else "",
+                    resp.usage.input_tokens,
+                    resp.usage.output_tokens,
+                )
             except anthropic.APIError as e:
                 error_message = f"Anthropic API error: {str(e)}"
                 logger.error(error_message)
@@ -244,23 +296,35 @@ class AIStructuredResponseGeneratorBlock(Block):
             client = Groq(api_key=api_key)
             response_format = {"type": "json_object"} if json_format else None
             response = client.chat.completions.create(
-                model=model.value,
+                model=llm_model.value,
                 messages=prompt,  # type: ignore
                 response_format=response_format,  # type: ignore
+                max_tokens=max_tokens,
             )
-            return response.choices[0].message.content or ""
+            return (
+                response.choices[0].message.content or "",
+                response.usage.prompt_tokens if response.usage else 0,
+                response.usage.completion_tokens if response.usage else 0,
+            )
         elif provider == "ollama":
+            sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
+            usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
             response = ollama.generate(
-                model=model.value,
-                prompt=prompt[0]["content"],
+                model=llm_model.value,
+                prompt=f"{sys_messages}\n\n{usr_messages}",
                 stream=False,
             )
-            return response["response"]
+            return (
+                response.get("response") or "",
+                response.get("prompt_eval_count") or 0,
+                response.get("eval_count") or 0,
+            )
         else:
             raise ValueError(f"Unsupported LLM provider: {provider}")

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
         logger.debug(f"Calling LLM with input data: {input_data}")
-        prompt = []
+        prompt = [p.model_dump() for p in input_data.conversation_history]

         def trim_prompt(s: str) -> str:
             lines = s.strip().split("\n")
@@ -289,7 +353,8 @@ class AIStructuredResponseGeneratorBlock(Block):
             )
             prompt.append({"role": "system", "content": sys_prompt})

-        prompt.append({"role": "user", "content": input_data.prompt})
+        if input_data.prompt:
+            prompt.append({"role": "user", "content": input_data.prompt})

         def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
             try:
@@ -305,19 +370,26 @@ class AIStructuredResponseGeneratorBlock(Block):
         logger.info(f"LLM request: {prompt}")
         retry_prompt = ""
-        model = input_data.model
+        llm_model = input_data.model
         api_key = (
             input_data.api_key.get_secret_value()
-            or LlmApiKeys[model.metadata.provider].get_secret_value()
+            or LlmApiKeys[llm_model.metadata.provider].get_secret_value()
         )

         for retry_count in range(input_data.retry):
             try:
-                response_text = self.llm_call(
+                response_text, input_token, output_token = self.llm_call(
                     api_key=api_key,
-                    model=model,
+                    llm_model=llm_model,
                     prompt=prompt,
                     json_format=bool(input_data.expected_format),
+                    max_tokens=input_data.max_tokens,
                 )
+                self.merge_stats(
+                    {
+                        "input_token_count": input_token,
+                        "output_token_count": output_token,
+                    }
+                )
                 logger.info(f"LLM attempt-{retry_count} response: {response_text}")
@@ -354,8 +426,15 @@ class AIStructuredResponseGeneratorBlock(Block):
                 )
                 prompt.append({"role": "user", "content": retry_prompt})
             except Exception as e:
-                logger.error(f"Error calling LLM: {e}")
+                logger.exception(f"Error calling LLM: {e}")
                 retry_prompt = f"Error calling LLM: {e}"
+            finally:
+                self.merge_stats(
+                    {
+                        "llm_call_count": retry_count + 1,
+                        "llm_retry_count": retry_count,
+                    }
+                )

         raise RuntimeError(retry_prompt)
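Note: a toy, stand-alone re-enactment (not the repo's classes) of the retry loop's stat accounting, using the stat keys from the diff above; the token numbers are invented.

```python
stats: dict = {}

def merge_stats(new: dict) -> None:
    # Same rule as Block.merge_stats for numbers: accumulate.
    for key, value in new.items():
        if isinstance(value, (int, float)):
            stats[key] = stats.get(key, 0) + value
        else:
            stats[key] = value

# Suppose attempt 0 raises and attempt 1 succeeds:
merge_stats({"llm_call_count": 1, "llm_retry_count": 0})            # finally, attempt 0
merge_stats({"input_token_count": 120, "output_token_count": 40})  # success path
merge_stats({"llm_call_count": 2, "llm_retry_count": 1})            # finally, attempt 1

# Because numeric values are summed, llm_call_count ends at 1 + 2 = 3 after
# two attempts rather than 2: each iteration's snapshot is added, not substituted.
print(stats)
```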
@@ -386,6 +465,11 @@ class AITextGeneratorBlock(Block):
         prompt_values: dict[str, str] = SchemaField(
             advanced=False, default={}, description="Values used to fill in the prompt."
         )
+        max_tokens: int | None = SchemaField(
+            advanced=True,
+            default=None,
+            description="The maximum number of tokens to generate in the chat completion.",
+        )

     class Output(BlockSchema):
         response: str = SchemaField(
@@ -405,15 +489,11 @@ class AITextGeneratorBlock(Block):
             test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
         )

-    @staticmethod
-    def llm_call(input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
-        object_block = AIStructuredResponseGeneratorBlock()
-        for output_name, output_data in object_block.run(input_data):
-            if output_name == "response":
-                return output_data["response"]
-            else:
-                raise RuntimeError(output_data)
-        raise ValueError("Failed to get a response from the LLM.")
+    def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
+        block = AIStructuredResponseGeneratorBlock()
+        response = block.run_once(input_data, "response")
+        self.merge_stats(block.execution_stats)
+        return response["response"]

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
         object_input_data = AIStructuredResponseGeneratorBlock.Input(
@@ -517,15 +597,11 @@ class AITextSummarizerBlock(Block):
         return chunks

-    @staticmethod
-    def llm_call(
-        input_data: AIStructuredResponseGeneratorBlock.Input,
-    ) -> dict[str, str]:
-        llm_block = AIStructuredResponseGeneratorBlock()
-        for output_name, output_data in llm_block.run(input_data):
-            if output_name == "response":
-                return output_data
-        raise ValueError("Failed to get a response from the LLM.")
+    def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> dict:
+        block = AIStructuredResponseGeneratorBlock()
+        response = block.run_once(input_data, "response")
+        self.merge_stats(block.execution_stats)
+        return response

     def _summarize_chunk(self, chunk: str, input_data: Input) -> str:
         prompt = f"Summarize the following text in a {input_data.style} form. Focus your summary on the topic of `{input_data.focus}` if present, otherwise just provide a general summary:\n\n```{chunk}```"
@@ -574,17 +650,6 @@ class AITextSummarizerBlock(Block):
         ]  # Get the first yielded value


-class MessageRole(str, Enum):
-    SYSTEM = "system"
-    USER = "user"
-    ASSISTANT = "assistant"
-
-
-class Message(BlockSchema):
-    role: MessageRole
-    content: str
-
-
 class AIConversationBlock(Block):
     class Input(BlockSchema):
         messages: List[Message] = SchemaField(
@@ -599,9 +664,9 @@ class AIConversationBlock(Block):
             value="", description="API key for the chosen language model provider."
         )
         max_tokens: int | None = SchemaField(
+            advanced=True,
             default=None,
             description="The maximum number of tokens to generate in the chat completion.",
-            ge=1,
         )

     class Output(BlockSchema):
@@ -639,62 +704,22 @@ class AIConversationBlock(Block):
             },
         )

-    @staticmethod
-    def llm_call(
-        api_key: str,
-        model: LlmModel,
-        messages: List[dict[str, str]],
-        max_tokens: int | None = None,
-    ) -> str:
-        provider = model.metadata.provider
-
-        if provider == "openai":
-            openai.api_key = api_key
-            response = openai.chat.completions.create(
-                model=model.value,
-                messages=messages,  # type: ignore
-                max_tokens=max_tokens,
-            )
-            return response.choices[0].message.content or ""
-        elif provider == "anthropic":
-            client = anthropic.Anthropic(api_key=api_key)
-            response = client.messages.create(
-                model=model.value,
-                max_tokens=max_tokens or 4096,
-                messages=messages,  # type: ignore
-            )
-            return response.content[0].text if response.content else ""
-        elif provider == "groq":
-            client = Groq(api_key=api_key)
-            response = client.chat.completions.create(
-                model=model.value,
-                messages=messages,  # type: ignore
-                max_tokens=max_tokens,
-            )
-            return response.choices[0].message.content or ""
-        elif provider == "ollama":
-            response = ollama.chat(
-                model=model.value,
-                messages=messages,  # type: ignore
-                stream=False,  # type: ignore
-            )
-            return response["message"]["content"]
-        else:
-            raise ValueError(f"Unsupported LLM provider: {provider}")
+    def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
+        block = AIStructuredResponseGeneratorBlock()
+        response = block.run_once(input_data, "response")
+        self.merge_stats(block.execution_stats)
+        return response["response"]

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        api_key = (
-            input_data.api_key.get_secret_value()
-            or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
-        )
-
-        messages = [message.model_dump() for message in input_data.messages]
-
         response = self.llm_call(
-            api_key=api_key,
-            model=input_data.model,
-            messages=messages,
-            max_tokens=input_data.max_tokens,
+            AIStructuredResponseGeneratorBlock.Input(
+                prompt="",
+                api_key=input_data.api_key,
+                model=input_data.model,
+                conversation_history=input_data.messages,
+                max_tokens=input_data.max_tokens,
+                expected_format={},
+            )
        )

         yield "response", response
@@ -727,6 +752,11 @@ class AIListGeneratorBlock(Block):
             ge=1,
             le=5,
         )
+        max_tokens: int | None = SchemaField(
+            advanced=True,
+            default=None,
+            description="The maximum number of tokens to generate in the chat completion.",
+        )

     class Output(BlockSchema):
         generated_list: List[str] = SchemaField(description="The generated list.")
@@ -781,11 +811,8 @@ class AIListGeneratorBlock(Block):
         input_data: AIStructuredResponseGeneratorBlock.Input,
     ) -> dict[str, str]:
         llm_block = AIStructuredResponseGeneratorBlock()
-        for output_name, output_data in llm_block.run(input_data):
-            if output_name == "response":
-                logger.debug(f"Received response from LLM: {output_data}")
-                return output_data
-        raise ValueError("Failed to get a response from the LLM.")
+        response = llm_block.run_once(input_data, "response")
+        return response

     @staticmethod
     def string_to_list(string):
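Note: the cleanup replaces each block's hand-rolled provider dispatch with delegation to AIStructuredResponseGeneratorBlock via the new run_once helper, then folds the child's stats into the parent. A hedged sketch of that pattern in isolation; MiniBlock, ChildBlock, and ParentBlock are stand-ins, not the repo's classes.

```python
from typing import Any, Iterator

class MiniBlock:
    def __init__(self) -> None:
        self.execution_stats: dict[str, Any] = {}

    def run(self, input_data: dict) -> Iterator[tuple[str, Any]]:
        raise NotImplementedError

    def run_once(self, input_data: dict, output: str) -> Any:
        # Return the first yielded value for the named output.
        for name, data in self.run(input_data):
            if name == output:
                return data
        raise ValueError(f"no output named {output!r}")

class ChildBlock(MiniBlock):
    def run(self, input_data: dict) -> Iterator[tuple[str, Any]]:
        self.execution_stats = {"input_token_count": 10, "output_token_count": 5}
        yield "response", {"text": "hi"}

class ParentBlock(MiniBlock):
    def llm_call(self, input_data: dict) -> Any:
        child = ChildBlock()
        response = child.run_once(input_data, "response")
        # Fold the child's token accounting into the parent's stats.
        self.execution_stats.update(child.execution_stats)
        return response

parent = ParentBlock()
print(parent.llm_call({}), parent.execution_stats)
```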

View File

@@ -230,6 +230,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         self.disabled = disabled
         self.static_output = static_output
         self.block_type = block_type
+        self.execution_stats = {}
+
+    @classmethod
+    def create(cls: Type["Block"]) -> "Block":
+        return cls()

     @abstractmethod
     def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
@@ -244,6 +249,26 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         """
         pass

+    def run_once(self, input_data: BlockSchemaInputType, output: str, **kwargs) -> Any:
+        for name, data in self.run(input_data, **kwargs):
+            if name == output:
+                return data
+        raise ValueError(f"{self.name} did not produce any output for {output}")
+
+    def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
+        for key, value in stats.items():
+            if isinstance(value, dict):
+                self.execution_stats.setdefault(key, {}).update(value)
+            elif isinstance(value, (int, float)):
+                self.execution_stats.setdefault(key, 0)
+                self.execution_stats[key] += value
+            elif isinstance(value, list):
+                self.execution_stats.setdefault(key, [])
+                self.execution_stats[key].extend(value)
+            else:
+                self.execution_stats[key] = value
+        return self.execution_stats
+
     @property
     def name(self):
         return self.__class__.__name__
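Note: merge_stats is type-directed — dict values merge key-by-key, numbers sum, lists concatenate, and anything else overwrites. A self-contained sketch mirroring the logic in the diff above; the token-count key comes from the diff, while "errors" and "model" are illustrative.

```python
from typing import Any

def merge_stats(execution_stats: dict[str, Any], stats: dict[str, Any]) -> dict[str, Any]:
    for key, value in stats.items():
        if isinstance(value, dict):
            execution_stats.setdefault(key, {}).update(value)  # shallow dict merge
        elif isinstance(value, (int, float)):
            execution_stats.setdefault(key, 0)
            execution_stats[key] += value                      # numbers accumulate
        elif isinstance(value, list):
            execution_stats.setdefault(key, [])
            execution_stats[key].extend(value)                 # lists concatenate
        else:
            execution_stats[key] = value                       # everything else overwrites
    return execution_stats

acc: dict[str, Any] = {}
merge_stats(acc, {"input_token_count": 120, "errors": ["timeout"]})
merge_stats(acc, {"input_token_count": 80, "errors": ["rate_limit"], "model": "m1"})
assert acc == {
    "input_token_count": 200,
    "errors": ["timeout", "rate_limit"],
    "model": "m1",
}
```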
@@ -282,14 +307,15 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
 # ======================= Block Helper Functions ======================= #


-def get_blocks() -> dict[str, Block]:
+def get_blocks() -> dict[str, Type[Block]]:
     from backend.blocks import AVAILABLE_BLOCKS  # noqa: E402

     return AVAILABLE_BLOCKS


 async def initialize_blocks() -> None:
-    for block in get_blocks().values():
+    for cls in get_blocks().values():
+        block = cls()
         existing_block = await AgentBlock.prisma().find_first(
             where={"OR": [{"id": block.id}, {"name": block.name}]}
         )
@@ -324,4 +350,5 @@ async def initialize_blocks() -> None:
 def get_block(block_id: str) -> Block | None:
-    return get_blocks().get(block_id)
+    cls = get_blocks().get(block_id)
+    return cls() if cls else None

View File

@@ -257,7 +257,7 @@ class Graph(GraphMeta):
             block = get_block(node.block_id)
             if not block:
-                blocks = {v.id: v.name for v in get_blocks().values()}
+                blocks = {v().id: v().name for v in get_blocks().values()}
                 raise ValueError(
                     f"{suffix}, {node.block_id} is invalid block id, available blocks: {blocks}"
                 )

View File

@@ -104,6 +104,7 @@ def execute_node(
     Args:
         db_client: The client to send execution updates to the server.
         creds_manager: The manager to acquire and release credentials.
+        data: The execution data for executing the current node.
         execution_stats: The execution statistics to be updated.
@@ -209,6 +210,7 @@ def execute_node(
         if creds_lock:
             creds_lock.release()
         if execution_stats is not None:
+            execution_stats.update(node_block.execution_stats)
             execution_stats["input_size"] = input_size
             execution_stats["output_size"] = output_size

View File

@@ -331,9 +331,9 @@ class AgentServer(AppService):
     @classmethod
     def get_graph_blocks(cls) -> list[dict[Any, Any]]:
-        blocks = block.get_blocks()
+        blocks = [cls() for cls in block.get_blocks().values()]
         costs = get_block_costs()
-        return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks.values()]
+        return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]

     @classmethod
     def execute_graph_block(

View File

@@ -1,3 +1,5 @@
+from typing import Type
+
 import pytest

 from backend.data.block import Block, get_blocks
@@ -5,5 +7,5 @@ from backend.util.test import execute_block_test
 @pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
-def test_available_blocks(block: Block):
-    execute_block_test(type(block)())
+def test_available_blocks(block: Type[Block]):
+    execute_block_test(block())