diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index e289738c12..e6927b7e6a 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -4,9 +4,9 @@ from abc import ABC from enum import Enum, EnumMeta from json import JSONDecodeError from types import MappingProxyType -from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple +from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional -from pydantic import SecretStr +from pydantic import BaseModel, SecretStr from backend.integrations.providers import ProviderName @@ -16,6 +16,8 @@ if TYPE_CHECKING: import anthropic import ollama import openai +from anthropic._types import NotGiven +from anthropic.types import ToolParam from groq import Groq from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema @@ -238,6 +240,291 @@ class Message(BlockSchema): content: str +class ToolCall(BaseModel): + name: str + arguments: str + + +class ToolContentBlock(BaseModel): + id: str + type: str + function: ToolCall + + +class LLMResponse(BaseModel): + prompt: str + response: str + tool_calls: Optional[List[ToolContentBlock]] | None + prompt_tokens: int + completion_tokens: int + + +def convert_openai_tool_fmt_to_anthropic( + openai_tools: list[dict] | None = None, +) -> Iterable[ToolParam] | NotGiven: + """ + Convert OpenAI tool format to Anthropic tool format. + """ + if not openai_tools or len(openai_tools) == 0: + return anthropic.NOT_GIVEN + + anthropic_tools = [] + for tool in openai_tools: + if "function" in tool: + # Handle case where tool is already in OpenAI format with "type" and "function" + function_data = tool["function"] + else: + # Handle case where tool is just the function definition + function_data = tool + + anthropic_tool: anthropic.types.ToolParam = { + "name": function_data["name"], + "description": function_data.get("description", ""), + "input_schema": { + "type": "object", + "properties": function_data.get("parameters", {}).get("properties", {}), + "required": function_data.get("parameters", {}).get("required", []), + }, + } + anthropic_tools.append(anthropic_tool) + + return anthropic_tools + + +def llm_call( + credentials: APIKeyCredentials, + llm_model: LlmModel, + prompt: list[dict], + json_format: bool, + max_tokens: int | None, + tools: list[dict] | None = None, + ollama_host: str = "localhost:11434", +) -> LLMResponse: + """ + Make a call to a language model. + + Args: + credentials: The API key credentials to use. + llm_model: The LLM model to use. + prompt: The prompt to send to the LLM. + json_format: Whether the response should be in JSON format. + max_tokens: The maximum number of tokens to generate in the chat completion. + tools: The tools to use in the chat completion. + ollama_host: The host for ollama to use. + + Returns: + LLMResponse object containing: + - prompt: The prompt sent to the LLM. + - response: The text response from the LLM. + - tool_calls: Any tool calls the model made, if applicable. + - prompt_tokens: The number of tokens used in the prompt. + - completion_tokens: The number of tokens used in the completion. 
+ """ + provider = llm_model.metadata.provider + max_tokens = max_tokens or llm_model.max_output_tokens or 4096 + + if provider == "openai": + tools_param = tools if tools else openai.NOT_GIVEN + oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value()) + response_format = None + + if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]: + sys_messages = [p["content"] for p in prompt if p["role"] == "system"] + usr_messages = [p["content"] for p in prompt if p["role"] != "system"] + prompt = [ + {"role": "user", "content": "\n".join(sys_messages)}, + {"role": "user", "content": "\n".join(usr_messages)}, + ] + elif json_format: + response_format = {"type": "json_object"} + + response = oai_client.chat.completions.create( + model=llm_model.value, + messages=prompt, # type: ignore + response_format=response_format, # type: ignore + max_completion_tokens=max_tokens, + tools=tools_param, # type: ignore + ) + + if response.choices[0].message.tool_calls: + tool_calls = [ + ToolContentBlock( + id=tool.id, + type=tool.type, + function=ToolCall( + name=tool.function.name, + arguments=tool.function.arguments, + ), + ) + for tool in response.choices[0].message.tool_calls + ] + else: + tool_calls = None + + return LLMResponse( + prompt=json.dumps(prompt), + response=response.choices[0].message.content or "", + tool_calls=tool_calls, + prompt_tokens=response.usage.prompt_tokens if response.usage else 0, + completion_tokens=response.usage.completion_tokens if response.usage else 0, + ) + elif provider == "anthropic": + + an_tools = convert_openai_tool_fmt_to_anthropic(tools) + + system_messages = [p["content"] for p in prompt if p["role"] == "system"] + sysprompt = " ".join(system_messages) + + messages = [] + last_role = None + for p in prompt: + if p["role"] in ["user", "assistant"]: + if p["role"] != last_role: + messages.append({"role": p["role"], "content": p["content"]}) + last_role = p["role"] + else: + # If the role is the same as the last one, combine the content + messages[-1]["content"] += "\n" + p["content"] + + client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value()) + try: + resp = client.messages.create( + model=llm_model.value, + system=sysprompt, + messages=messages, + max_tokens=max_tokens, + tools=an_tools, + ) + + if not resp.content: + raise ValueError("No content returned from Anthropic.") + + tool_calls = None + for content_block in resp.content: + # Antropic is different to openai, need to iterate through + # the content blocks to find the tool calls + if content_block.type == "tool_use": + if tool_calls is None: + tool_calls = [] + tool_calls.append( + ToolContentBlock( + id=content_block.id, + type=content_block.type, + function=ToolCall( + name=content_block.name, + arguments=json.dumps(content_block.input), + ), + ) + ) + + if not tool_calls and resp.stop_reason == "tool_use": + logger.warning( + "Tool use stop reason but no tool calls found in content. 
%s", resp + ) + + return LLMResponse( + prompt=json.dumps(prompt), + response=( + resp.content[0].name + if isinstance(resp.content[0], anthropic.types.ToolUseBlock) + else resp.content[0].text + ), + tool_calls=tool_calls, + prompt_tokens=resp.usage.input_tokens, + completion_tokens=resp.usage.output_tokens, + ) + except anthropic.APIError as e: + error_message = f"Anthropic API error: {str(e)}" + logger.error(error_message) + raise ValueError(error_message) + elif provider == "groq": + if tools: + raise ValueError("Groq does not support tools.") + + client = Groq(api_key=credentials.api_key.get_secret_value()) + response_format = {"type": "json_object"} if json_format else None + response = client.chat.completions.create( + model=llm_model.value, + messages=prompt, # type: ignore + response_format=response_format, # type: ignore + max_tokens=max_tokens, + ) + return LLMResponse( + prompt=json.dumps(prompt), + response=response.choices[0].message.content or "", + tool_calls=None, + prompt_tokens=response.usage.prompt_tokens if response.usage else 0, + completion_tokens=response.usage.completion_tokens if response.usage else 0, + ) + elif provider == "ollama": + if tools: + raise ValueError("Ollama does not support tools.") + + client = ollama.Client(host=ollama_host) + sys_messages = [p["content"] for p in prompt if p["role"] == "system"] + usr_messages = [p["content"] for p in prompt if p["role"] != "system"] + response = client.generate( + model=llm_model.value, + prompt=f"{sys_messages}\n\n{usr_messages}", + stream=False, + ) + return LLMResponse( + prompt=json.dumps(prompt), + response=response.get("response") or "", + tool_calls=None, + prompt_tokens=response.get("prompt_eval_count") or 0, + completion_tokens=response.get("eval_count") or 0, + ) + elif provider == "open_router": + tools_param = tools if tools else openai.NOT_GIVEN + client = openai.OpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=credentials.api_key.get_secret_value(), + ) + + response = client.chat.completions.create( + extra_headers={ + "HTTP-Referer": "https://agpt.co", + "X-Title": "AutoGPT", + }, + model=llm_model.value, + messages=prompt, # type: ignore + max_tokens=max_tokens, + tools=tools_param, # type: ignore + ) + + # If there's no response, raise an error + if not response.choices: + if response: + raise ValueError(f"OpenRouter error: {response}") + else: + raise ValueError("No response from OpenRouter.") + + if response.choices[0].message.tool_calls: + tool_calls = [ + ToolContentBlock( + id=tool.id, + type=tool.type, + function=ToolCall( + name=tool.function.name, arguments=tool.function.arguments + ), + ) + for tool in response.choices[0].message.tool_calls + ] + else: + tool_calls = None + + return LLMResponse( + prompt=json.dumps(prompt), + response=response.choices[0].message.content or "", + tool_calls=tool_calls, + prompt_tokens=response.usage.prompt_tokens if response.usage else 0, + completion_tokens=response.usage.completion_tokens if response.usage else 0, + ) + else: + raise ValueError(f"Unsupported LLM provider: {provider}") + + class AIBlockBase(Block, ABC): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -260,7 +547,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): ) model: LlmModel = SchemaField( title="LLM Model", - default=LlmModel.GPT4_TURBO, + default=LlmModel.GPT4O, description="The language model to use for answering the prompt.", advanced=False, ) @@ -311,7 +598,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): 
input_schema=AIStructuredResponseGeneratorBlock.Input, output_schema=AIStructuredResponseGeneratorBlock.Output, test_input={ - "model": LlmModel.GPT4_TURBO, + "model": LlmModel.GPT4O, "credentials": TEST_CREDENTIALS_INPUT, "expected_format": { "key1": "value1", @@ -325,19 +612,20 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): ("prompt", str), ], test_mock={ - "llm_call": lambda *args, **kwargs: ( - json.dumps( + "llm_call": lambda *args, **kwargs: LLMResponse( + prompt="", + response=json.dumps( { "key1": "key1Value", "key2": "key2Value", } ), - 0, - 0, + tool_calls=None, + prompt_tokens=0, + completion_tokens=0, ) }, ) - self.prompt = "" def llm_call( self, @@ -346,154 +634,22 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): prompt: list[dict], json_format: bool, max_tokens: int | None, + tools: list[dict] | None = None, ollama_host: str = "localhost:11434", - ) -> tuple[str, int, int]: + ) -> LLMResponse: """ - Args: - credentials: The API key credentials to use. - llm_model: The LLM model to use. - prompt: The prompt to send to the LLM. - json_format: Whether the response should be in JSON format. - max_tokens: The maximum number of tokens to generate in the chat completion. - ollama_host: The host for ollama to use - - Returns: - The response from the LLM. - The number of tokens used in the prompt. - The number of tokens used in the completion. + Test mocks work only on class functions, this wraps the llm_call function + so that it can be mocked withing the block testing framework. """ - provider = llm_model.metadata.provider - max_tokens = max_tokens or llm_model.max_output_tokens or 4096 - - if provider == "openai": - oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value()) - response_format = None - - if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]: - sys_messages = [p["content"] for p in prompt if p["role"] == "system"] - usr_messages = [p["content"] for p in prompt if p["role"] != "system"] - prompt = [ - {"role": "user", "content": "\n".join(sys_messages)}, - {"role": "user", "content": "\n".join(usr_messages)}, - ] - elif json_format: - response_format = {"type": "json_object"} - - response = oai_client.chat.completions.create( - model=llm_model.value, - messages=prompt, # type: ignore - response_format=response_format, # type: ignore - max_completion_tokens=max_tokens, - ) - self.prompt = json.dumps(prompt) - - return ( - response.choices[0].message.content or "", - response.usage.prompt_tokens if response.usage else 0, - response.usage.completion_tokens if response.usage else 0, - ) - elif provider == "anthropic": - system_messages = [p["content"] for p in prompt if p["role"] == "system"] - sysprompt = " ".join(system_messages) - - messages = [] - last_role = None - for p in prompt: - if p["role"] in ["user", "assistant"]: - if p["role"] != last_role: - messages.append({"role": p["role"], "content": p["content"]}) - last_role = p["role"] - else: - # If the role is the same as the last one, combine the content - messages[-1]["content"] += "\n" + p["content"] - - client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value()) - try: - resp = client.messages.create( - model=llm_model.value, - system=sysprompt, - messages=messages, - max_tokens=max_tokens, - ) - self.prompt = json.dumps(prompt) - - if not resp.content: - raise ValueError("No content returned from Anthropic.") - - return ( - ( - resp.content[0].name - if isinstance(resp.content[0], anthropic.types.ToolUseBlock) - else resp.content[0].text - ), - 
resp.usage.input_tokens, - resp.usage.output_tokens, - ) - except anthropic.APIError as e: - error_message = f"Anthropic API error: {str(e)}" - logger.error(error_message) - raise ValueError(error_message) - elif provider == "groq": - client = Groq(api_key=credentials.api_key.get_secret_value()) - response_format = {"type": "json_object"} if json_format else None - response = client.chat.completions.create( - model=llm_model.value, - messages=prompt, # type: ignore - response_format=response_format, # type: ignore - max_tokens=max_tokens, - ) - self.prompt = json.dumps(prompt) - return ( - response.choices[0].message.content or "", - response.usage.prompt_tokens if response.usage else 0, - response.usage.completion_tokens if response.usage else 0, - ) - elif provider == "ollama": - client = ollama.Client(host=ollama_host) - sys_messages = [p["content"] for p in prompt if p["role"] == "system"] - usr_messages = [p["content"] for p in prompt if p["role"] != "system"] - response = client.generate( - model=llm_model.value, - prompt=f"{sys_messages}\n\n{usr_messages}", - stream=False, - ) - self.prompt = json.dumps(prompt) - return ( - response.get("response") or "", - response.get("prompt_eval_count") or 0, - response.get("eval_count") or 0, - ) - elif provider == "open_router": - client = openai.OpenAI( - base_url="https://openrouter.ai/api/v1", - api_key=credentials.api_key.get_secret_value(), - ) - - response = client.chat.completions.create( - extra_headers={ - "HTTP-Referer": "https://agpt.co", - "X-Title": "AutoGPT", - }, - model=llm_model.value, - messages=prompt, # type: ignore - max_tokens=max_tokens, - ) - self.prompt = json.dumps(prompt) - - # If there's no response, raise an error - if not response.choices: - if response: - raise ValueError(f"OpenRouter error: {response}") - else: - raise ValueError("No response from OpenRouter.") - - return ( - response.choices[0].message.content or "", - response.usage.prompt_tokens if response.usage else 0, - response.usage.completion_tokens if response.usage else 0, - ) - else: - raise ValueError(f"Unsupported LLM provider: {provider}") + return llm_call( + credentials=credentials, + llm_model=llm_model, + prompt=prompt, + json_format=json_format, + max_tokens=max_tokens, + tools=tools, + ollama_host=ollama_host, + ) def run( self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs @@ -549,7 +705,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): for retry_count in range(input_data.retry): try: - response_text, input_token, output_token = self.llm_call( + llm_response = self.llm_call( credentials=credentials, llm_model=llm_model, prompt=prompt, @@ -557,10 +713,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): ollama_host=input_data.ollama_host, max_tokens=input_data.max_tokens, ) + response_text = llm_response.response self.merge_stats( { - "input_token_count": input_token, - "output_token_count": output_token, + "input_token_count": llm_response.prompt_tokens, + "output_token_count": llm_response.completion_tokens, } ) logger.info(f"LLM attempt-{retry_count} response: {response_text}") @@ -621,7 +778,7 @@ class AITextGeneratorBlock(AIBlockBase): ) model: LlmModel = SchemaField( title="LLM Model", - default=LlmModel.GPT4_TURBO, + default=LlmModel.GPT4O, description="The language model to use for answering the prompt.", advanced=False, ) @@ -714,7 +871,7 @@ class AITextSummarizerBlock(AIBlockBase): ) model: LlmModel = SchemaField( title="LLM Model", - default=LlmModel.GPT4_TURBO, + default=LlmModel.GPT4O, 
description="The language model to use for summarizing the text.", ) focus: str = SchemaField( @@ -880,7 +1037,7 @@ class AIConversationBlock(AIBlockBase): ) model: LlmModel = SchemaField( title="LLM Model", - default=LlmModel.GPT4_TURBO, + default=LlmModel.GPT4O, description="The language model to use for the conversation.", ) credentials: AICredentials = AICredentialsField() @@ -919,7 +1076,7 @@ class AIConversationBlock(AIBlockBase): }, {"role": "user", "content": "Where was it played?"}, ], - "model": LlmModel.GPT4_TURBO, + "model": LlmModel.GPT4O, "credentials": TEST_CREDENTIALS_INPUT, }, test_credentials=TEST_CREDENTIALS, @@ -981,7 +1138,7 @@ class AIListGeneratorBlock(AIBlockBase): ) model: LlmModel = SchemaField( title="LLM Model", - default=LlmModel.GPT4_TURBO, + default=LlmModel.GPT4O, description="The language model to use for generating the list.", advanced=True, ) @@ -1030,7 +1187,7 @@ class AIListGeneratorBlock(AIBlockBase): "drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of " "fictional worlds." ), - "model": LlmModel.GPT4_TURBO, + "model": LlmModel.GPT4O, "credentials": TEST_CREDENTIALS_INPUT, "max_retries": 3, }, diff --git a/autogpt_platform/backend/backend/blocks/smart_decision_maker.py b/autogpt_platform/backend/backend/blocks/smart_decision_maker.py new file mode 100644 index 0000000000..2894d07efa --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/smart_decision_maker.py @@ -0,0 +1,307 @@ +import json +import logging +import re +from typing import Any, List + +from autogpt_libs.utils.cache import thread_cached + +import backend.blocks.llm as llm +from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType +from backend.data.model import SchemaField + +logger = logging.getLogger(__name__) + + +@thread_cached +def get_database_manager_client(): + from backend.executor import DatabaseManager + from backend.util.service import get_service_client + + return get_service_client(DatabaseManager) + + +class SmartDecisionMakerBlock(Block): + """ + A block that uses a language model to make smart decisions based on a given prompt. + """ + + class Input(BlockSchema): + prompt: str = SchemaField( + description="The prompt to send to the language model.", + placeholder="Enter your prompt here...", + ) + model: llm.LlmModel = SchemaField( + title="LLM Model", + default=llm.LlmModel.GPT4O, + description="The language model to use for answering the prompt.", + advanced=False, + ) + credentials: llm.AICredentials = llm.AICredentialsField() + sys_prompt: str = SchemaField( + title="System Prompt", + default="Thinking carefully step by step decide which function to call. Always choose a function call from the list of function signatures.", + description="The system prompt to provide additional context to the model.", + ) + conversation_history: list[llm.Message] = SchemaField( + default=[], + description="The conversation history to provide context for the prompt.", + ) + retry: int = SchemaField( + title="Retry Count", + default=3, + description="Number of times to retry the LLM call if the response does not match the expected format.", + ) + prompt_values: dict[str, str] = SchemaField( + advanced=False, + default={}, + description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. 
{{variable_name}}.", + ) + max_tokens: int | None = SchemaField( + advanced=True, + default=None, + description="The maximum number of tokens to generate in the chat completion.", + ) + ollama_host: str = SchemaField( + advanced=True, + default="localhost:11434", + description="Ollama host for local models", + ) + + class Output(BlockSchema): + + prompt: str = SchemaField(description="The prompt sent to the language model.") + error: str = SchemaField(description="Error message if the API call failed.") + function_signatures: list[dict[str, Any]] = SchemaField( + description="The function signatures that are sent to the language model." + ) + tools: Any = SchemaField(description="The tools that are available to use.") + finished: str = SchemaField( + description="The finished message to display to the user." + ) + + def __init__(self): + super().__init__( + id="3b191d9f-356f-482d-8238-ba04b6d18381", + description="Uses AI to intelligently decide what tool to use.", + categories={BlockCategory.AI}, + block_type=BlockType.AI, + input_schema=SmartDecisionMakerBlock.Input, + output_schema=SmartDecisionMakerBlock.Output, + test_input={ + "prompt": "Hello, World!", + "credentials": llm.TEST_CREDENTIALS_INPUT, + }, + test_output=[], + test_credentials=llm.TEST_CREDENTIALS, + ) + + # If I import Graph here, it will break with a circular import. + def _get_tool_graph_metadata(self, node_id: str, graph: Any) -> List[Any]: + """ + Retrieves metadata for tool graphs linked to a specified node within a graph. + + This method identifies the tool links connected to the given node_id and fetches + the metadata for each linked tool graph from the database. + + Args: + node_id (str): The ID of the node for which tool graph metadata is to be retrieved. + graph (Any): The graph object containing nodes and links. + + Returns: + List[Any]: A list of metadata for the tool graphs linked to the specified node. + """ + db_client = get_database_manager_client() + graph_meta = [] + + tool_links = { + link.sink_id + for link in graph.links + if link.source_name.startswith("tools_^_") and link.source_id == node_id + } + + for link_id in tool_links: + node = next((node for node in graph.nodes if node.id == link_id), None) + if node: + node_graph_meta = db_client.get_graph_metadata( + node.input_default["graph_id"], node.input_default["graph_version"] + ) + if node_graph_meta: + graph_meta.append(node_graph_meta) + + return graph_meta + + @staticmethod + def _create_function_signature( + # If I import Graph here, it will break with a circular import. + node_id: str, + graph: Any, + tool_graph_metadata: List[Any], + ) -> list[dict[str, Any]]: + """ + Creates function signatures for tools linked to a specified node within a graph. + + This method filters the graph links to identify those that are tools and are + connected to the given node_id. It then constructs function signatures for each + tool based on the metadata and input schema of the linked nodes. + + Args: + node_id (str): The ID of the node for which tool function signatures are to be created. + graph (Any): The graph object containing nodes and links. + tool_graph_metadata (List[Any]): Metadata for the tool graphs, used to retrieve + names and descriptions for the tools. + + Returns: + list[dict[str, Any]]: A list of dictionaries, each representing a function signature + for a tool, including its name, description, and parameters. + + Raises: + ValueError: If no tool links are found for the specified node_id, or if a sink node + or its metadata cannot be found. 
+ """ + # Filter the graph links to find those that are tools and are linked to the specified node_id + tool_links = [ + link + for link in graph.links + # NOTE: Maybe we can do a specific database call to only get relevant nodes + # async def get_connected_output_nodes(source_node_id: str) -> list[Node]: + # links = await AgentNodeLink.prisma().find_many( + # where={"agentNodeSourceId": source_node_id}, + # include={"AgentNode": {"include": AGENT_NODE_INCLUDE}}, + # ) + # return [NodeModel.from_db(link.AgentNodeSink) for link in links] + if link.source_name.startswith("tools_^_") and link.source_id == node_id + ] + + if not tool_links: + raise ValueError( + f"Expected at least one tool link in the graph. Node ID: {node_id}. Graph: {graph.links}" + ) + + return_tool_functions = [] + + grouped_tool_links = {} + + for link in tool_links: + grouped_tool_links.setdefault(link.sink_id, []).append(link) + + for _, links in grouped_tool_links.items(): + sink_node = next( + (node for node in graph.nodes if node.id == links[0].sink_id), None + ) + + if not sink_node: + raise ValueError(f"Sink node not found: {links[0].sink_id}") + + graph_id = sink_node.input_default["graph_id"] + graph_version = sink_node.input_default["graph_version"] + + sink_graph_meta = next( + ( + meta + for meta in tool_graph_metadata + if meta.id == graph_id and meta.version == graph_version + ), + None, + ) + + if not sink_graph_meta: + raise ValueError( + f"Sink graph metadata not found: {graph_id} {graph_version}" + ) + + tool_function: dict[str, Any] = { + "name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(), + "description": sink_graph_meta.description, + } + + properties = {} + required = [] + + for link in links: + sink_block_input_schema = sink_node.input_default["input_schema"] + description = ( + sink_block_input_schema["properties"][link.sink_name]["description"] + if "description" + in sink_block_input_schema["properties"][link.sink_name] + else f"The {link.sink_name} of the tool" + ) + properties[link.sink_name.lower()] = { + "type": "string", + "description": description, + } + + tool_function["parameters"] = { + "type": "object", + "properties": properties, + "required": required, + "additionalProperties": False, + "strict": True, + } + + return_tool_functions.append( + {"type": "function", "function": tool_function} + ) + return return_tool_functions + + def run( + self, + input_data: Input, + *, + credentials: llm.APIKeyCredentials, + graph_id: str, + node_id: str, + graph_exec_id: str, + node_exec_id: str, + user_id: str, + **kwargs, + ) -> BlockOutput: + db_client = get_database_manager_client() + + # Retrieve the current graph and node details + graph = db_client.get_graph(graph_id=graph_id, user_id=user_id) + + if not graph: + raise ValueError( + f"The currently running graph that is executing this node is not found {graph_id}" + ) + + tool_graph_metadata = self._get_tool_graph_metadata(node_id, graph) + + tool_functions = self._create_function_signature( + node_id, graph, tool_graph_metadata + ) + + prompt = [p.model_dump() for p in input_data.conversation_history] + + values = input_data.prompt_values + if values: + input_data.prompt = llm.fmt.format_string(input_data.prompt, values) + input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values) + + if input_data.sys_prompt: + prompt.append({"role": "system", "content": input_data.sys_prompt}) + + if input_data.prompt: + prompt.append({"role": "user", "content": input_data.prompt}) + + response = llm.llm_call( + 
credentials=credentials, + llm_model=input_data.model, + prompt=prompt, + json_format=False, + max_tokens=input_data.max_tokens, + tools=tool_functions, + ollama_host=input_data.ollama_host, + ) + + if not response.tool_calls: + + yield "finished", f"No Decision Made finishing task: {response.response}" + + if response.tool_calls: + for tool_call in response.tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + + for arg_name, arg_value in tool_args.items(): + yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 5592af8dc7..05ff08e08c 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -44,6 +44,7 @@ class BlockType(Enum): WEBHOOK = "Webhook" WEBHOOK_MANUAL = "Webhook (manual)" AGENT = "Agent" + AI = "AI" class BlockCategory(Enum): @@ -351,6 +352,14 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): Run the block with the given input data. Args: input_data: The input data with the structure of input_schema. + + Kwargs: Currently 14/02/2025 these include + graph_id: The ID of the graph. + node_id: The ID of the node. + graph_exec_id: The ID of the graph execution. + node_exec_id: The ID of the node execution. + user_id: The ID of the user. + Returns: A Generator that yields (output_name, output_data). output_name: One of the output name defined in Block's output_schema. diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index eae46278d8..bd6d79d47e 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -399,7 +399,38 @@ OBJC_SPLIT = "_@_" def parse_execution_output(output: BlockData, name: str) -> Any | None: - # Allow extracting partial output data by name. + """ + Extracts partial output data by name from a given BlockData. + + The function supports extracting data from lists, dictionaries, and objects + using specific naming conventions: + - For lists: _$_ + - For dictionaries: _#_ + - For objects: _@_ + + Args: + output (BlockData): A tuple containing the output name and data. + name (str): The name used to extract specific data from the output. + + Returns: + Any | None: The extracted data if found, otherwise None. + + Examples: + >>> output = ("result", [10, 20, 30]) + >>> parse_execution_output(output, "result_$_1") + 20 + + >>> output = ("config", {"key1": "value1", "key2": "value2"}) + >>> parse_execution_output(output, "config_#_key1") + 'value1' + + >>> class Sample: + ... attr1 = "value1" + ... attr2 = "value2" + >>> output = ("object", Sample()) + >>> parse_execution_output(output, "object_@_attr1") + 'value1' + """ output_name, output_data = output if name == output_name: @@ -428,11 +459,37 @@ def parse_execution_output(output: BlockData, name: str) -> Any | None: def merge_execution_input(data: BlockInput) -> BlockInput: """ - Merge all dynamic input pins which described by the following pattern: - - _$_ for list input. - - _#_ for dict input. - - _@_ for object input. - This function will construct pins with the same name into a single list/dict/object. + Merges dynamic input pins into a single list, dictionary, or object based on naming patterns. + + This function processes input keys that follow specific patterns to merge them into a unified structure: + - `_$_` for list inputs. 
+ - `_#_` for dictionary inputs. + - `_@_` for object inputs. + + Args: + data (BlockInput): A dictionary containing input keys and their corresponding values. + + Returns: + BlockInput: A dictionary with merged inputs. + + Raises: + ValueError: If a list index is not an integer. + + Examples: + >>> data = { + ... "list_$_0": "a", + ... "list_$_1": "b", + ... "dict_#_key1": "value1", + ... "dict_#_key2": "value2", + ... "object_@_attr1": "value1", + ... "object_@_attr2": "value2" + ... } + >>> merge_execution_input(data) + { + "list": ["a", "b"], + "dict": {"key1": "value1", "key2": "value2"}, + "object": + } """ # Merge all input with _$_ into a single list. diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 52e37625cd..7df978bc0a 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -312,16 +312,49 @@ class GraphModel(Graph): def validate_graph(self, for_run: bool = False): def sanitize(name): - return name.split("_#_")[0].split("_@_")[0].split("_$_")[0] + sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0] + if sanitized_name.startswith("tools_^_"): + return sanitized_name.split("_^_")[0] + return sanitized_name + + # Validate smart decision maker nodes + smart_decision_maker_nodes = set() + agent_nodes = set() + nodes_block = { + node.id: block + for node in self.nodes + if (block := get_block(node.block_id)) is not None + } + + for node in self.nodes: + if (block := nodes_block.get(node.id)) is None: + raise ValueError(f"Invalid block {node.block_id} for node #{node.id}") + + # Smart decision maker nodes + if block.block_type == BlockType.AI: + smart_decision_maker_nodes.add(node.id) + # Agent nodes + elif block.block_type == BlockType.AGENT: + agent_nodes.add(node.id) input_links = defaultdict(list) + for link in self.links: input_links[link.sink_id].append(link) + # Check if the link is a tool link from a smart decision maker to a non-agent node + if ( + link.source_id in smart_decision_maker_nodes + and link.source_name.startswith("tools_^_") + and link.sink_id not in agent_nodes + ): + raise ValueError( + f"Smart decision maker node {link.source_id} cannot link to non-agent node {link.sink_id}" + ) + # Nodes: required fields are filled or connected and dependencies are satisfied for node in self.nodes: - block = get_block(node.block_id) - if block is None: + if (block := nodes_block.get(node.id)) is None: raise ValueError(f"Invalid block {node.block_id} for node #{node.id}") provided_inputs = set( @@ -341,6 +374,7 @@ class GraphModel(Graph): or block.block_type == BlockType.INPUT or block.block_type == BlockType.OUTPUT or block.block_type == BlockType.AGENT + or block.block_type == BlockType.AI ) ): raise ValueError( @@ -410,16 +444,16 @@ class GraphModel(Graph): if i == 0: fields = ( block.output_schema.get_fields() - if block.block_type != BlockType.AGENT + if block.block_type not in [BlockType.AGENT, BlockType.AI] else vals.get("output_schema", {}).get("properties", {}).keys() ) else: fields = ( block.input_schema.get_fields() - if block.block_type != BlockType.AGENT + if block.block_type not in [BlockType.AGENT, BlockType.AI] else vals.get("input_schema", {}).get("properties", {}).keys() ) - if sanitized_name not in fields: + if sanitized_name not in fields and not name.startswith("tools_"): fields_msg = f"Allowed fields: {fields}" raise ValueError(f"{suffix}, `{name}` invalid, {fields_msg}") @@ -610,6 +644,33 @@ async def 
get_execution(user_id: str, execution_id: str) -> GraphExecution | Non return GraphExecution.from_db(execution) if execution else None +async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None: + where_clause: AgentGraphWhereInput = { + "id": graph_id, + } + + if version is not None: + where_clause["version"] = version + + graph = await AgentGraph.prisma().find_first( + where=where_clause, + include=AGENT_GRAPH_INCLUDE, + order={"version": "desc"}, + ) + + if not graph: + return None + + return Graph( + id=graph.id, + name=graph.name or "", + description=graph.description or "", + version=graph.version, + is_active=graph.isActive, + is_template=graph.isTemplate, + ) + + async def get_graph( graph_id: str, version: int | None = None, diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index 0834d96c29..fad1a678de 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -13,7 +13,7 @@ from backend.data.execution import ( upsert_execution_input, upsert_execution_output, ) -from backend.data.graph import get_graph, get_node +from backend.data.graph import get_graph, get_graph_metadata, get_node from backend.data.user import ( get_user_integrations, get_user_metadata, @@ -60,6 +60,7 @@ class DatabaseManager(AppService): # Graphs get_node = exposed_run_and_wait(get_node) get_graph = exposed_run_and_wait(get_graph) + get_graph_metadata = exposed_run_and_wait(get_graph_metadata) # Credits spend_credits = exposed_run_and_wait(_spend_credits) diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py index d5435e3615..06249e4c02 100644 --- a/autogpt_platform/backend/backend/server/rest_api.py +++ b/autogpt_platform/backend/backend/server/rest_api.py @@ -16,6 +16,7 @@ import backend.data.block import backend.data.db import backend.data.graph import backend.data.user +import backend.server.integrations.router import backend.server.routers.v1 import backend.server.v2.library.db import backend.server.v2.library.model @@ -24,6 +25,8 @@ import backend.server.v2.store.model import backend.server.v2.store.routes import backend.util.service import backend.util.settings +from backend.data.model import Credentials +from backend.integrations.providers import ProviderName from backend.server.external.api import external_app settings = backend.util.settings.Settings() @@ -243,5 +246,15 @@ class AgentServer(backend.util.service.AppProcess): ): return await backend.server.v2.store.routes.review_submission(request, user) + @staticmethod + def test_create_credentials( + user_id: str, + provider: ProviderName, + credentials: Credentials, + ) -> Credentials: + return backend.server.integrations.router.create_credentials( + user_id=user_id, provider=provider, credentials=credentials + ) + def set_test_dependency_overrides(self, overrides: dict): app.dependency_overrides.update(overrides) diff --git a/autogpt_platform/backend/backend/usecases/sample.py b/autogpt_platform/backend/backend/usecases/sample.py index 8171850a7c..e19357fb4a 100644 --- a/autogpt_platform/backend/backend/usecases/sample.py +++ b/autogpt_platform/backend/backend/usecases/sample.py @@ -40,7 +40,10 @@ def create_test_graph() -> graph.Graph: ), graph.Node( block_id=AgentInputBlock().id, - input_default={"name": "input_2"}, + input_default={ + "name": "input_2", + "description": "This is my description of this 
parameter", + }, ), graph.Node( block_id=FillTextTemplateBlock().id, @@ -74,7 +77,7 @@ def create_test_graph() -> graph.Graph: return graph.Graph( name="TestGraph", - description="Test graph", + description="Test graph description", nodes=nodes, links=links, ) diff --git a/autogpt_platform/backend/test/executor/test_execution_functions.py b/autogpt_platform/backend/test/executor/test_execution_functions.py new file mode 100644 index 0000000000..4fce8f49a8 --- /dev/null +++ b/autogpt_platform/backend/test/executor/test_execution_functions.py @@ -0,0 +1,55 @@ +from backend.data.execution import merge_execution_input, parse_execution_output + + +def test_parse_execution_output(): + # Test case for list extraction + output = ("result", [10, 20, 30]) + assert parse_execution_output(output, "result_$_1") == 20 + assert parse_execution_output(output, "result_$_3") is None + + # Test case for dictionary extraction + output = ("config", {"key1": "value1", "key2": "value2"}) + assert parse_execution_output(output, "config_#_key1") == "value1" + assert parse_execution_output(output, "config_#_key3") is None + + # Test case for object extraction + class Sample: + attr1 = "value1" + attr2 = "value2" + + output = ("object", Sample()) + assert parse_execution_output(output, "object_@_attr1") == "value1" + assert parse_execution_output(output, "object_@_attr3") is None + + # Test case for direct match + output = ("direct", "match") + assert parse_execution_output(output, "direct") == "match" + assert parse_execution_output(output, "nomatch") is None + + +def test_merge_execution_input(): + # Test case for merging list inputs + data = {"list_$_0": "a", "list_$_1": "b", "list_$_3": "d"} + merged_data = merge_execution_input(data) + assert merged_data["list"] == ["a", "b", "", "d"] + + # Test case for merging dictionary inputs + data = {"dict_#_key1": "value1", "dict_#_key2": "value2"} + merged_data = merge_execution_input(data) + assert merged_data["dict"] == {"key1": "value1", "key2": "value2"} + + # Test case for merging object inputs + data = {"object_@_attr1": "value1", "object_@_attr2": "value2"} + merged_data = merge_execution_input(data) + assert hasattr(merged_data["object"], "attr1") + assert hasattr(merged_data["object"], "attr2") + assert merged_data["object"].attr1 == "value1" + assert merged_data["object"].attr2 == "value2" + + # Test case for mixed inputs + data = {"list_$_0": "a", "dict_#_key1": "value1", "object_@_attr1": "value1"} + merged_data = merge_execution_input(data) + assert merged_data["list"] == ["a"] + assert merged_data["dict"] == {"key1": "value1"} + assert hasattr(merged_data["object"], "attr1") + assert merged_data["object"].attr1 == "value1" diff --git a/autogpt_platform/backend/test/executor/test_manager.py b/autogpt_platform/backend/test/executor/test_manager.py index 0412a779bc..40dbc9dfaa 100644 --- a/autogpt_platform/backend/test/executor/test_manager.py +++ b/autogpt_platform/backend/test/executor/test_manager.py @@ -73,6 +73,7 @@ async def assert_sample_graph_executions( { "name": "input_2", "value": "World", + "description": "This is my description of this parameter", }, ] diff --git a/autogpt_platform/backend/test/executor/test_tool_use.py b/autogpt_platform/backend/test/executor/test_tool_use.py new file mode 100644 index 0000000000..cc70415001 --- /dev/null +++ b/autogpt_platform/backend/test/executor/test_tool_use.py @@ -0,0 +1,249 @@ +import logging + +import pytest +from prisma.models import User + +import backend.blocks.llm as llm +from backend.blocks.agent 
import AgentExecutorBlock +from backend.blocks.basic import StoreValueBlock +from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock +from backend.data import graph +from backend.data.model import ProviderName +from backend.server.model import CreateGraph +from backend.server.rest_api import AgentServer +from backend.usecases.sample import create_test_graph, create_test_user +from backend.util.test import SpinTestServer, wait_execution + +logger = logging.getLogger(__name__) + + +async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph: + logger.info("Creating graph for user %s", u.id) + return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id) + + +def create_credentials(s: SpinTestServer, u: User): + provider = ProviderName.OPENAI + credentials = llm.TEST_CREDENTIALS + try: + s.agent_server.test_create_credentials(u.id, provider, credentials) + except Exception: + # ValueErrors is raised trying to recreate the same credentials + # so hidding the error + pass + + +async def execute_graph( + agent_server: AgentServer, + test_graph: graph.Graph, + test_user: User, + input_data: dict, + num_execs: int = 4, +) -> str: + logger.info("Executing graph %s for user %s", test_graph.id, test_user.id) + logger.info("Input data: %s", input_data) + + # --- Test adding new executions --- # + response = await agent_server.test_execute_graph( + user_id=test_user.id, + graph_id=test_graph.id, + graph_version=test_graph.version, + node_input=input_data, + ) + graph_exec_id = response.graph_exec_id + logger.info("Created execution with ID: %s", graph_exec_id) + + # Execution queue should be empty + logger.info("Waiting for execution to complete...") + result = await wait_execution(test_user.id, test_graph.id, graph_exec_id, 30) + logger.info("Execution completed with %d results", len(result)) + return graph_exec_id + + +@pytest.mark.skip() +@pytest.mark.asyncio(scope="session") +async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer): + test_user = await create_test_user() + test_tool_graph = await create_graph(server, create_test_graph(), test_user) + create_credentials(server, test_user) + + nodes = [ + graph.Node( + block_id=SmartDecisionMakerBlock().id, + input_default={ + "prompt": "Hello, World!", + "credentials": llm.TEST_CREDENTIALS_INPUT, + }, + ), + graph.Node( + block_id=AgentExecutorBlock().id, + input_default={ + "graph_id": test_tool_graph.id, + "graph_version": test_tool_graph.version, + "input_schema": test_tool_graph.input_schema, + "output_schema": test_tool_graph.output_schema, + }, + ), + ] + + links = [ + graph.Link( + source_id=nodes[0].id, + sink_id=nodes[1].id, + source_name="tools_^_sample_tool_input_1", + sink_name="input_1", + ), + graph.Link( + source_id=nodes[0].id, + sink_id=nodes[1].id, + source_name="tools_^_sample_tool_input_2", + sink_name="input_2", + ), + ] + + test_graph = graph.Graph( + name="TestGraph", + description="Test graph", + nodes=nodes, + links=links, + ) + test_graph = await create_graph(server, test_graph, test_user) + + +@pytest.mark.skip() +@pytest.mark.asyncio(scope="session") +async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer): + + test_user = await create_test_user() + test_tool_graph = await create_graph(server, create_test_graph(), test_user) + create_credentials(server, test_user) + + nodes = [ + graph.Node( + block_id=SmartDecisionMakerBlock().id, + input_default={ + "prompt": "Hello, World!", + "credentials": llm.TEST_CREDENTIALS_INPUT, 
+ }, + ), + graph.Node( + block_id=AgentExecutorBlock().id, + input_default={ + "graph_id": test_tool_graph.id, + "graph_version": test_tool_graph.version, + "input_schema": test_tool_graph.input_schema, + "output_schema": test_tool_graph.output_schema, + }, + ), + graph.Node( + block_id=StoreValueBlock().id, + ), + ] + + links = [ + graph.Link( + source_id=nodes[0].id, + sink_id=nodes[1].id, + source_name="tools_^_sample_tool_input_1", + sink_name="input_1", + ), + graph.Link( + source_id=nodes[0].id, + sink_id=nodes[1].id, + source_name="tools_^_sample_tool_input_2", + sink_name="input_2", + ), + graph.Link( + source_id=nodes[0].id, + sink_id=nodes[2].id, + source_name="tools_^_store_value_input", + sink_name="input", + ), + ] + + test_graph = graph.Graph( + name="TestGraph", + description="Test graph", + nodes=nodes, + links=links, + ) + with pytest.raises(ValueError): + test_graph = await create_graph(server, test_graph, test_user) + + +@pytest.mark.skip() +@pytest.mark.asyncio(scope="session") +async def test_smart_decision_maker_function_signature(server: SpinTestServer): + test_user = await create_test_user() + test_tool_graph = await create_graph(server, create_test_graph(), test_user) + create_credentials(server, test_user) + + nodes = [ + graph.Node( + block_id=SmartDecisionMakerBlock().id, + input_default={ + "prompt": "Hello, World!", + "credentials": llm.TEST_CREDENTIALS_INPUT, + }, + ), + graph.Node( + block_id=AgentExecutorBlock().id, + input_default={ + "graph_id": test_tool_graph.id, + "graph_version": test_tool_graph.version, + "input_schema": test_tool_graph.input_schema, + "output_schema": test_tool_graph.output_schema, + }, + ), + ] + + links = [ + graph.Link( + source_id=nodes[0].id, + sink_id=nodes[1].id, + source_name="tools_^_sample_tool_input_1", + sink_name="input_1", + ), + graph.Link( + source_id=nodes[0].id, + sink_id=nodes[1].id, + source_name="tools_^_sample_tool_input_2", + sink_name="input_2", + ), + ] + + test_graph = graph.Graph( + name="TestGraph", + description="Test graph", + nodes=nodes, + links=links, + ) + test_graph = await create_graph(server, test_graph, test_user) + + tool_functions = SmartDecisionMakerBlock._create_function_signature( + test_graph.nodes[0].id, test_graph, [test_tool_graph] + ) + assert tool_functions is not None, "Tool functions should not be None" + assert ( + len(tool_functions) == 1 + ), f"Expected 1 tool function, got {len(tool_functions)}" + + tool_function = next( + filter(lambda x: x["function"]["name"] == "testgraph", tool_functions), + None, + ) + assert tool_function is not None, f"testgraph function not found: {tool_functions}" + assert ( + tool_function["function"]["name"] == "testgraph" + ), "Incorrect function name for testgraph" + assert ( + tool_function["function"]["parameters"]["properties"]["input_1"]["type"] + == "string" + ), "Input type for input_1 should be 'string'" + assert ( + tool_function["function"]["parameters"]["properties"]["input_2"]["type"] + == "string" + ), "Input type for input_2 should be 'string'" + assert ( + tool_function["function"]["parameters"]["required"] == [] + ), "Required parameters should be an empty list" diff --git a/autogpt_platform/frontend/src/hooks/useAgentGraph.tsx b/autogpt_platform/frontend/src/hooks/useAgentGraph.tsx index 9a589bfaf0..5541c52bd2 100644 --- a/autogpt_platform/frontend/src/hooks/useAgentGraph.tsx +++ b/autogpt_platform/frontend/src/hooks/useAgentGraph.tsx @@ -191,12 +191,15 @@ export default function useAgentGraph( }); setEdges(() => 
graph.links.map((link) => { + const adjustedSourceName = link.source_name?.startsWith("tools_^_") + ? "tools" + : link.source_name; return { id: formatEdgeID(link), type: "custom", data: { edgeColor: getTypeColor( - getOutputType(newNodes, link.source_id, link.source_name!), + getOutputType(newNodes, link.source_id, adjustedSourceName!), ), sourcePos: newNodes.find((node) => node.id === link.source_id) ?.position, @@ -209,12 +212,12 @@ export default function useAgentGraph( type: MarkerType.ArrowClosed, strokeWidth: 2, color: getTypeColor( - getOutputType(newNodes, link.source_id, link.source_name!), + getOutputType(newNodes, link.source_id, adjustedSourceName!), ), }, source: link.source_id, target: link.sink_id, - sourceHandle: link.source_name || undefined, + sourceHandle: adjustedSourceName || undefined, targetHandle: link.sink_name || undefined, }; }), @@ -795,12 +798,35 @@ export default function useAgentGraph( }; }); - const links = edges.map((edge) => ({ - source_id: edge.source, - sink_id: edge.target, - source_name: edge.sourceHandle || "", - sink_name: edge.targetHandle || "", - })); + const links = edges.map((edge) => { + let sourceName = edge.sourceHandle || ""; + if (sourceName.toLowerCase() === "tools") { + const sinkNode = nodes.find((node) => node.id === edge.target); + + const sinkNodeName = sinkNode + ? sinkNode.data.block_id === "e189baac-8c20-45a1-94a7-55177ea42565" // AgentExecutorBlock ID + ? sinkNode.data.hardcodedValues?.graph_id + ? availableFlows + .find( + (flow) => + flow.id === sinkNode.data.hardcodedValues.graph_id, + ) + ?.name?.toLowerCase() + .replace(/ /g, "_") || "agentexecutorblock" + : "agentexecutorblock" + : sinkNode.data.title.toLowerCase().replace(/ /g, "_") + : ""; + + sourceName = + `tools_^_${sinkNodeName}_${edge.targetHandle || ""}`.toLowerCase(); + } + return { + source_id: edge.source, + sink_id: edge.target, + source_name: sourceName, + sink_name: edge.targetHandle || "", + }; + }); const payload = { id: savedAgent?.id!, diff --git a/autogpt_platform/frontend/src/tests/pages/build.page.ts b/autogpt_platform/frontend/src/tests/pages/build.page.ts index 0e76a045c7..ccb26991a2 100644 --- a/autogpt_platform/frontend/src/tests/pages/build.page.ts +++ b/autogpt_platform/frontend/src/tests/pages/build.page.ts @@ -411,6 +411,15 @@ export class BuildPage extends BasePage { }; } + async getSmartDecisionMakerBlockDetails(): Promise { + return { + id: "3b191d9f-356f-482d-8238-ba04b6d18381", + name: "Smart Decision Maker", + description: + "This block is used to make a decision based on the input and the available tools.", + }; + } + async nextTutorialStep(): Promise { console.log(`clicking next tutorial step`); await this.page.getByRole("button", { name: "Next" }).click(); @@ -487,6 +496,7 @@ export class BuildPage extends BasePage { (await this.getAgentInputBlockDetails()).id, (await this.getAgentOutputBlockDetails()).id, (await this.getGithubTriggerBlockDetails()).id, + (await this.getSmartDecisionMakerBlockDetails()).id, ]; }
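
For reference, a minimal sketch (plain dictionaries only, no anthropic SDK types, so it runs without the SDK installed) of the OpenAI-to-Anthropic tool mapping that convert_openai_tool_fmt_to_anthropic performs in llm.py above. The tool name "testgraph" and its single parameter are illustrative stand-ins, not values taken from the patch.

# OpenAI-style tool definition, as produced by
# SmartDecisionMakerBlock._create_function_signature()
openai_tool = {
    "type": "function",
    "function": {
        "name": "testgraph",
        "description": "Example tool",
        "parameters": {
            "type": "object",
            "properties": {
                "input_1": {"type": "string", "description": "The input_1 of the tool"}
            },
            "required": [],
        },
    },
}

# Equivalent Anthropic ToolParam shape produced by the converter
# (an empty or missing tool list is mapped to anthropic.NOT_GIVEN instead).
anthropic_tool = {
    "name": "testgraph",
    "description": "Example tool",
    "input_schema": {
        "type": "object",
        "properties": {
            "input_1": {"type": "string", "description": "The input_1 of the tool"}
        },
        "required": [],
    },
}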
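
And a short runnable sketch of how SmartDecisionMakerBlock.run fans a tool call out into dynamic output pins using the "tools_^_<tool>_<arg>" naming convention that validate_graph and the frontend edge handling above rely on. FakeToolCall and the argument values are hypothetical stand-ins for the ToolContentBlock/ToolCall models defined in llm.py.

import json


class FakeToolCall:
    """Stand-in for llm.ToolCall: a tool name plus JSON-encoded arguments."""

    def __init__(self, name: str, arguments: str):
        self.name = name
        self.arguments = arguments


def fan_out(tool_calls):
    # Mirrors the yield loop at the end of SmartDecisionMakerBlock.run:
    # each decoded argument becomes its own "tools_^_<tool>_<arg>" output pin.
    for call in tool_calls:
        args = json.loads(call.arguments)
        for arg_name, arg_value in args.items():
            yield f"tools_^_{call.name}_{arg_name}".lower(), arg_value


calls = [FakeToolCall("testgraph", json.dumps({"input_1": "a", "input_2": "b"}))]
print(list(fan_out(calls)))
# [('tools_^_testgraph_input_1', 'a'), ('tools_^_testgraph_input_2', 'b')]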