mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
simplify function header generation
This commit is contained in:
@@ -41,7 +41,6 @@ from forge.llm.providers import (
|
||||
ChatModelResponse,
|
||||
MultiProvider,
|
||||
)
|
||||
from forge.llm.providers.utils import function_specs_from_commands
|
||||
from forge.models.action import (
|
||||
ActionErrorResult,
|
||||
ActionInterruptedByHuman,
|
||||
@@ -184,7 +183,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
|
||||
task=self.state.task,
|
||||
ai_profile=self.state.ai_profile,
|
||||
ai_directives=directives,
|
||||
commands=function_specs_from_commands(self.commands),
|
||||
commands=self.commands,
|
||||
include_os_info=self.legacy_config.execute_local_commands,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
import re
|
||||
from logging import Logger
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
from forge.command import Command
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
from forge.config.ai_profile import AIProfile
|
||||
from forge.json.parsing import extract_dict_from_json
|
||||
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
|
||||
from forge.llm.providers import AssistantChatMessage, CompletionModelFunction
|
||||
from forge.llm.providers.schema import AssistantFunctionCall, ChatMessage
|
||||
from forge.llm.prompting.utils import indent
|
||||
from forge.llm.providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
ChatMessage,
|
||||
)
|
||||
from forge.models.config import SystemConfiguration
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.utils.exceptions import InvalidAgentResponseError
|
||||
@@ -92,7 +98,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
CodeFlowAgentActionProposal.schema()
|
||||
)
|
||||
self.logger = logger
|
||||
self.commands: list[CompletionModelFunction] = []
|
||||
self.commands: Sequence[Command] = [] # Sequence -> disallow list modification
|
||||
|
||||
@property
|
||||
def model_classification(self) -> LanguageModelClassification:
|
||||
@@ -105,7 +111,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
task: str,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
commands: Sequence[Command],
|
||||
**extras,
|
||||
) -> ChatPrompt:
|
||||
"""Constructs and returns a prompt with the following structure:
|
||||
@@ -115,7 +121,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
system_prompt, response_prefill = self.build_system_prompt(
|
||||
ai_profile=ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
functions=commands,
|
||||
commands=commands,
|
||||
)
|
||||
|
||||
self.commands = commands
|
||||
@@ -135,7 +141,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
self,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
functions: list[CompletionModelFunction],
|
||||
commands: Iterable[Command],
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds the system prompt.
|
||||
@@ -153,7 +159,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
" in the next message. Your job is to complete the task, "
|
||||
"and terminate when your task is done."
|
||||
]
|
||||
+ ["## Available Functions\n" + self._generate_function_headers(functions)]
|
||||
+ ["## Available Functions\n" + self._generate_function_headers(commands)]
|
||||
+ ["## RESPONSE FORMAT\n" + response_fmt_instruction]
|
||||
)
|
||||
|
||||
@@ -195,8 +201,29 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
# "simple strategies with no legal complications.",
|
||||
]
|
||||
|
||||
def _generate_function_headers(self, funcs: list[CompletionModelFunction]) -> str:
|
||||
return "\n\n".join(f.fmt_header(force_async=True) for f in funcs)
|
||||
def _generate_function_headers(self, commands: Iterable[Command]) -> str:
|
||||
return "\n\n".join(
|
||||
f.header
|
||||
+ "\n"
|
||||
+ indent(
|
||||
(
|
||||
'"""\n'
|
||||
f"{f.description}\n\n"
|
||||
"Params:\n"
|
||||
+ indent(
|
||||
"\n".join(
|
||||
f"{param.name}: {param.spec.description}"
|
||||
for param in f.parameters
|
||||
if param.spec.description
|
||||
)
|
||||
)
|
||||
+ "\n"
|
||||
'"""\n'
|
||||
"pass"
|
||||
),
|
||||
)
|
||||
for f in commands
|
||||
)
|
||||
|
||||
async def parse_response_content(
|
||||
self,
|
||||
@@ -220,21 +247,21 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
raise ValueError("python_code is empty")
|
||||
|
||||
available_functions = {
|
||||
f.name: FunctionDef(
|
||||
name=f.name,
|
||||
arg_types=[(name, p.python_type) for name, p in f.parameters.items()],
|
||||
arg_descs={name: p.description for name, p in f.parameters.items()},
|
||||
c.name: FunctionDef(
|
||||
name=c.name,
|
||||
arg_types=[(p.name, p.spec.python_type) for p in c.parameters],
|
||||
arg_descs={p.name: p.spec.description for p in c.parameters},
|
||||
arg_defaults={
|
||||
name: p.default or "None"
|
||||
for name, p in f.parameters.items()
|
||||
if p.default or not p.required
|
||||
p.name: p.spec.default or "None"
|
||||
for p in c.parameters
|
||||
if p.spec.default or not p.spec.required
|
||||
},
|
||||
return_type=f.return_type,
|
||||
return_type=c.return_type,
|
||||
return_desc="Output of the function",
|
||||
function_desc=f.description,
|
||||
is_async=True,
|
||||
function_desc=c.description,
|
||||
is_async=c.is_async,
|
||||
)
|
||||
for f in self.commands
|
||||
for c in self.commands
|
||||
}
|
||||
available_functions.update(
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
from logging import Logger
|
||||
|
||||
import distro
|
||||
from forge.command import Command
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
from forge.config.ai_profile import AIProfile
|
||||
from forge.json.parsing import extract_dict_from_json
|
||||
@@ -16,6 +17,7 @@ from forge.llm.providers.schema import (
|
||||
ChatMessage,
|
||||
CompletionModelFunction,
|
||||
)
|
||||
from forge.llm.providers.utils import function_specs_from_commands
|
||||
from forge.models.action import ActionProposal
|
||||
from forge.models.config import SystemConfiguration, UserConfigurable
|
||||
from forge.models.json_schema import JSONSchema
|
||||
@@ -117,7 +119,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
task: str,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
commands: list[Command],
|
||||
include_os_info: bool,
|
||||
**extras,
|
||||
) -> ChatPrompt:
|
||||
@@ -125,10 +127,11 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
1. System prompt
|
||||
3. `cycle_instruction`
|
||||
"""
|
||||
functions = function_specs_from_commands(commands)
|
||||
system_prompt, response_prefill = self.build_system_prompt(
|
||||
ai_profile=ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=commands,
|
||||
functions=functions,
|
||||
include_os_info=include_os_info,
|
||||
)
|
||||
|
||||
@@ -142,14 +145,14 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
final_instruction_msg,
|
||||
],
|
||||
prefill_response=response_prefill,
|
||||
functions=commands if self.config.use_functions_api else [],
|
||||
functions=functions if self.config.use_functions_api else [],
|
||||
)
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
functions: list[CompletionModelFunction],
|
||||
include_os_info: bool,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
@@ -169,7 +172,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
self.config.body_template.format(
|
||||
constraints=format_numbered_list(ai_directives.constraints),
|
||||
resources=format_numbered_list(ai_directives.resources),
|
||||
commands=self._generate_commands_list(commands),
|
||||
commands=self._generate_commands_list(functions),
|
||||
best_practices=format_numbered_list(ai_directives.best_practices),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import command
|
||||
from forge.components.code_flow_executor.code_flow_executor import (
|
||||
CodeFlowExecutionComponent,
|
||||
)
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
from forge.config.ai_profile import AIProfile
|
||||
from forge.llm.providers import AssistantChatMessage
|
||||
from forge.llm.providers.schema import CompletionModelFunction, JSONSchema
|
||||
from forge.llm.providers.schema import JSONSchema
|
||||
|
||||
from autogpt.agents.prompt_strategies.code_flow import CodeFlowAgentPromptStrategy
|
||||
|
||||
@@ -16,32 +19,38 @@ config = CodeFlowAgentPromptStrategy.default_configuration.copy(deep=True)
|
||||
prompt_strategy = CodeFlowAgentPromptStrategy(config, logger)
|
||||
|
||||
|
||||
class MockWebSearchProvider(CommandProvider):
|
||||
def get_commands(self):
|
||||
yield self.mock_web_search
|
||||
|
||||
@command(
|
||||
description="Searches the web",
|
||||
parameters={
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
),
|
||||
"num_results": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The number of results to return",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
def mock_web_search(self, query: str, num_results: Optional[int] = None) -> str:
|
||||
return "results"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_flow_build_prompt():
|
||||
commands = [
|
||||
CompletionModelFunction(
|
||||
name="web_search",
|
||||
description="Searches the web",
|
||||
parameters={
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
),
|
||||
"num_results": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The number of results to return",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
commands = list(MockWebSearchProvider().get_commands())
|
||||
|
||||
ai_profile = AIProfile()
|
||||
ai_profile.ai_name = "DummyGPT"
|
||||
ai_profile.ai_goals = "A model for testing purpose"
|
||||
ai_profile.ai_goals = ["A model for testing purposes"]
|
||||
ai_profile.ai_role = "Help Testing"
|
||||
|
||||
ai_directives = AIDirectives()
|
||||
@@ -59,7 +68,9 @@ async def test_code_flow_build_prompt():
|
||||
)
|
||||
)
|
||||
assert "DummyGPT" in prompt
|
||||
assert "async def web_search(query: str, num_results: int = None)" in prompt
|
||||
assert (
|
||||
"def mock_web_search(query: str, num_results: Optional[int] = None)" in prompt
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -41,6 +41,10 @@ class Command(Generic[P, CO]):
|
||||
self.method = cast(Callable[P, CO], method)
|
||||
self.parameters = parameters
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.names[0] # TODO: fallback to other name if first one is taken
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
return inspect.iscoroutinefunction(self.method)
|
||||
@@ -52,6 +56,21 @@ class Command(Generic[P, CO]):
|
||||
return None
|
||||
return type.__name__
|
||||
|
||||
@property
|
||||
def header(self) -> str:
|
||||
"""Returns a function header representing the command's signature
|
||||
|
||||
Examples:
|
||||
```py
|
||||
def execute_python_code(code: str) -> str:
|
||||
|
||||
async def extract_info_from_content(content: str, instruction: str, output_type: type[~T]) -> ~T:
|
||||
""" # noqa
|
||||
return (
|
||||
f"{'async ' if self.is_async else ''}"
|
||||
f"def {self.name}{inspect.signature(self.method)}:"
|
||||
)
|
||||
|
||||
def _parameters_match(
|
||||
self, func: Callable, parameters: list[CommandParameter]
|
||||
) -> bool:
|
||||
@@ -78,7 +97,7 @@ class Command(Generic[P, CO]):
|
||||
for param in self.parameters
|
||||
]
|
||||
return (
|
||||
f"{self.names[0]}: {self.description.rstrip('.')}. "
|
||||
f"{self.name}: {self.description.rstrip('.')}. "
|
||||
f"Params: ({', '.join(params)})"
|
||||
)
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ def function_specs_from_commands(
|
||||
"""Get LLM-consumable function specs for the agent's available commands."""
|
||||
return [
|
||||
CompletionModelFunction(
|
||||
name=command.names[0],
|
||||
name=command.name,
|
||||
description=command.description,
|
||||
parameters={param.name: param.spec for param in command.parameters},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user