simplify function header generation

This commit is contained in:
Reinier van der Leer
2024-06-07 12:56:41 +02:00
parent fcca4cc893
commit 6e715b6c71
6 changed files with 111 additions and 52 deletions

View File

@@ -41,7 +41,6 @@ from forge.llm.providers import (
ChatModelResponse,
MultiProvider,
)
from forge.llm.providers.utils import function_specs_from_commands
from forge.models.action import (
ActionErrorResult,
ActionInterruptedByHuman,
@@ -184,7 +183,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
task=self.state.task,
ai_profile=self.state.ai_profile,
ai_directives=directives,
commands=function_specs_from_commands(self.commands),
commands=self.commands,
include_os_info=self.legacy_config.execute_local_commands,
)

View File

@@ -1,12 +1,18 @@
import re
from logging import Logger
from typing import Iterable, Sequence
from forge.command import Command
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
from forge.llm.providers import AssistantChatMessage, CompletionModelFunction
from forge.llm.providers.schema import AssistantFunctionCall, ChatMessage
from forge.llm.prompting.utils import indent
from forge.llm.providers.schema import (
AssistantChatMessage,
AssistantFunctionCall,
ChatMessage,
)
from forge.models.config import SystemConfiguration
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import InvalidAgentResponseError
@@ -92,7 +98,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
CodeFlowAgentActionProposal.schema()
)
self.logger = logger
self.commands: list[CompletionModelFunction] = []
self.commands: Sequence[Command] = [] # Sequence -> disallow list modification
@property
def model_classification(self) -> LanguageModelClassification:
@@ -105,7 +111,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
task: str,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: list[CompletionModelFunction],
commands: Sequence[Command],
**extras,
) -> ChatPrompt:
"""Constructs and returns a prompt with the following structure:
@@ -115,7 +121,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
system_prompt, response_prefill = self.build_system_prompt(
ai_profile=ai_profile,
ai_directives=ai_directives,
functions=commands,
commands=commands,
)
self.commands = commands
@@ -135,7 +141,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
self,
ai_profile: AIProfile,
ai_directives: AIDirectives,
functions: list[CompletionModelFunction],
commands: Iterable[Command],
) -> tuple[str, str]:
"""
Builds the system prompt.
@@ -153,7 +159,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
" in the next message. Your job is to complete the task, "
"and terminate when your task is done."
]
+ ["## Available Functions\n" + self._generate_function_headers(functions)]
+ ["## Available Functions\n" + self._generate_function_headers(commands)]
+ ["## RESPONSE FORMAT\n" + response_fmt_instruction]
)
@@ -195,8 +201,29 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
# "simple strategies with no legal complications.",
]
def _generate_function_headers(self, funcs: list[CompletionModelFunction]) -> str:
return "\n\n".join(f.fmt_header(force_async=True) for f in funcs)
def _generate_function_headers(self, commands: Iterable[Command]) -> str:
return "\n\n".join(
f.header
+ "\n"
+ indent(
(
'"""\n'
f"{f.description}\n\n"
"Params:\n"
+ indent(
"\n".join(
f"{param.name}: {param.spec.description}"
for param in f.parameters
if param.spec.description
)
)
+ "\n"
'"""\n'
"pass"
),
)
for f in commands
)
async def parse_response_content(
self,
@@ -220,21 +247,21 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
raise ValueError("python_code is empty")
available_functions = {
f.name: FunctionDef(
name=f.name,
arg_types=[(name, p.python_type) for name, p in f.parameters.items()],
arg_descs={name: p.description for name, p in f.parameters.items()},
c.name: FunctionDef(
name=c.name,
arg_types=[(p.name, p.spec.python_type) for p in c.parameters],
arg_descs={p.name: p.spec.description for p in c.parameters},
arg_defaults={
name: p.default or "None"
for name, p in f.parameters.items()
if p.default or not p.required
p.name: p.spec.default or "None"
for p in c.parameters
if p.spec.default or not p.spec.required
},
return_type=f.return_type,
return_type=c.return_type,
return_desc="Output of the function",
function_desc=f.description,
is_async=True,
function_desc=c.description,
is_async=c.is_async,
)
for f in self.commands
for c in self.commands
}
available_functions.update(
{

View File

@@ -6,6 +6,7 @@ import re
from logging import Logger
import distro
from forge.command import Command
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
@@ -16,6 +17,7 @@ from forge.llm.providers.schema import (
ChatMessage,
CompletionModelFunction,
)
from forge.llm.providers.utils import function_specs_from_commands
from forge.models.action import ActionProposal
from forge.models.config import SystemConfiguration, UserConfigurable
from forge.models.json_schema import JSONSchema
@@ -117,7 +119,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
task: str,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: list[CompletionModelFunction],
commands: list[Command],
include_os_info: bool,
**extras,
) -> ChatPrompt:
@@ -125,10 +127,11 @@ class OneShotAgentPromptStrategy(PromptStrategy):
1. System prompt
3. `cycle_instruction`
"""
functions = function_specs_from_commands(commands)
system_prompt, response_prefill = self.build_system_prompt(
ai_profile=ai_profile,
ai_directives=ai_directives,
commands=commands,
functions=functions,
include_os_info=include_os_info,
)
@@ -142,14 +145,14 @@ class OneShotAgentPromptStrategy(PromptStrategy):
final_instruction_msg,
],
prefill_response=response_prefill,
functions=commands if self.config.use_functions_api else [],
functions=functions if self.config.use_functions_api else [],
)
def build_system_prompt(
self,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: list[CompletionModelFunction],
functions: list[CompletionModelFunction],
include_os_info: bool,
) -> tuple[str, str]:
"""
@@ -169,7 +172,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
self.config.body_template.format(
constraints=format_numbered_list(ai_directives.constraints),
resources=format_numbered_list(ai_directives.resources),
commands=self._generate_commands_list(commands),
commands=self._generate_commands_list(functions),
best_practices=format_numbered_list(ai_directives.best_practices),
)
]

View File

@@ -1,13 +1,16 @@
import logging
from typing import Optional
import pytest
from forge.agent.protocols import CommandProvider
from forge.command import command
from forge.components.code_flow_executor.code_flow_executor import (
CodeFlowExecutionComponent,
)
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.llm.providers import AssistantChatMessage
from forge.llm.providers.schema import CompletionModelFunction, JSONSchema
from forge.llm.providers.schema import JSONSchema
from autogpt.agents.prompt_strategies.code_flow import CodeFlowAgentPromptStrategy
@@ -16,32 +19,38 @@ config = CodeFlowAgentPromptStrategy.default_configuration.copy(deep=True)
prompt_strategy = CodeFlowAgentPromptStrategy(config, logger)
class MockWebSearchProvider(CommandProvider):
def get_commands(self):
yield self.mock_web_search
@command(
description="Searches the web",
parameters={
"query": JSONSchema(
type=JSONSchema.Type.STRING,
description="The search query",
required=True,
),
"num_results": JSONSchema(
type=JSONSchema.Type.INTEGER,
description="The number of results to return",
minimum=1,
maximum=10,
required=False,
),
},
)
def mock_web_search(self, query: str, num_results: Optional[int] = None) -> str:
return "results"
@pytest.mark.asyncio
async def test_code_flow_build_prompt():
commands = [
CompletionModelFunction(
name="web_search",
description="Searches the web",
parameters={
"query": JSONSchema(
type=JSONSchema.Type.STRING,
description="The search query",
required=True,
),
"num_results": JSONSchema(
type=JSONSchema.Type.INTEGER,
description="The number of results to return",
minimum=1,
maximum=10,
required=False,
),
},
),
]
commands = list(MockWebSearchProvider().get_commands())
ai_profile = AIProfile()
ai_profile.ai_name = "DummyGPT"
ai_profile.ai_goals = "A model for testing purpose"
ai_profile.ai_goals = ["A model for testing purposes"]
ai_profile.ai_role = "Help Testing"
ai_directives = AIDirectives()
@@ -59,7 +68,9 @@ async def test_code_flow_build_prompt():
)
)
assert "DummyGPT" in prompt
assert "async def web_search(query: str, num_results: int = None)" in prompt
assert (
"def mock_web_search(query: str, num_results: Optional[int] = None)" in prompt
)
@pytest.mark.asyncio

View File

@@ -41,6 +41,10 @@ class Command(Generic[P, CO]):
self.method = cast(Callable[P, CO], method)
self.parameters = parameters
@property
def name(self) -> str:
return self.names[0] # TODO: fallback to other name if first one is taken
@property
def is_async(self) -> bool:
return inspect.iscoroutinefunction(self.method)
@@ -52,6 +56,21 @@ class Command(Generic[P, CO]):
return None
return type.__name__
@property
def header(self) -> str:
"""Returns a function header representing the command's signature
Examples:
```py
def execute_python_code(code: str) -> str:
async def extract_info_from_content(content: str, instruction: str, output_type: type[~T]) -> ~T:
""" # noqa
return (
f"{'async ' if self.is_async else ''}"
f"def {self.name}{inspect.signature(self.method)}:"
)
def _parameters_match(
self, func: Callable, parameters: list[CommandParameter]
) -> bool:
@@ -78,7 +97,7 @@ class Command(Generic[P, CO]):
for param in self.parameters
]
return (
f"{self.names[0]}: {self.description.rstrip('.')}. "
f"{self.name}: {self.description.rstrip('.')}. "
f"Params: ({', '.join(params)})"
)

View File

@@ -80,7 +80,7 @@ def function_specs_from_commands(
"""Get LLM-consumable function specs for the agent's available commands."""
return [
CompletionModelFunction(
name=command.names[0],
name=command.name,
description=command.description,
parameters={param.name: param.spec for param in command.parameters},
)