Compare commits

...

58 Commits

Author SHA1 Message Date
Reinier van der Leer
7e0b1156cc feat(agent): Improve history format for code flow execution results 2024-07-23 23:15:50 +02:00
Reinier van der Leer
da9360fdeb feat(agent/api): Pretty-print execute_code_flow proposal in Agent Protocol output 2024-07-23 22:55:16 +02:00
Reinier van der Leer
9dea6a273e Merge branch 'master' into zamilmajdy/code-validation 2024-07-23 22:42:32 +02:00
Reinier van der Leer
e19636ac3e feat(agent/cli): Pretty-print code flow proposal
- Amend `main.py:update_user(..)` to improve displaying of existing + new prompt strategy output
- Replace `"execute_code_flow"` literals with references to the command definition
2024-07-23 22:18:54 +02:00
Reinier van der Leer
f03c6546b8 Merge branch 'master' into zamilmajdy/code-validation 2024-07-23 20:38:30 +02:00
Reinier van der Leer
2c4afd4458 Migrate autogpt/agents/prompt_strategies/code_flow.py to Pydantic v2 2024-07-04 01:07:16 -06:00
Reinier van der Leer
8b1d416de3 Merge branch 'master' into zamilmajdy/code-validation 2024-07-04 01:05:40 -06:00
Reinier van der Leer
7f6b7d6d7e remove unused import in forge/llm/providers/openai.py 2024-07-02 13:05:53 -06:00
Reinier van der Leer
736ac778cc Merge branch 'master' into zamilmajdy/code-validation 2024-07-02 13:05:31 -06:00
Reinier van der Leer
38eafdbb66 Update CodeFlowPromptStrategy with upstream changes (#7223) 2024-07-02 04:29:55 +02:00
Reinier van der Leer
6d9f564dc5 Merge branch 'master' into zamilmajdy/code-validation 2024-07-01 20:16:31 -06:00
Reinier van der Leer
3e675123d7 Merge branch 'master' into zamilmajdy/code-validation 2024-06-27 14:12:19 -06:00
Reinier van der Leer
37cc047656 lint-fix + minor refactor 2024-06-25 09:55:13 -07:00
Reinier van der Leer
9f804080ed address feedback: pass commands getter to CodeFlowExecutionComponent(..) 2024-06-25 09:30:26 -07:00
Reinier van der Leer
680fbf49aa Merge branch 'master' into zamilmajdy/code-validation 2024-06-24 20:42:52 -07:00
Krzysztof Czerwinski
901dadefc3 Merge branch 'master' into zamilmajdy/code-validation 2024-06-19 13:05:40 +02:00
Nicholas Tindle
e204491c6c Merge branch 'master' into zamilmajdy/code-validation 2024-06-13 17:33:44 -05:00
Zamil Majdy
3597f801a7 Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into zamilmajdy/code-validation 2024-06-10 13:05:05 +07:00
Zamil Majdy
b59862c402 Address comment 2024-06-10 13:04:54 +07:00
Reinier van der Leer
81bac301e8 fix type issues 2024-06-08 23:44:45 +02:00
Reinier van der Leer
a9eb49d54e Merge branch 'master' into zamilmajdy/code-validation 2024-06-08 21:52:36 +02:00
Reinier van der Leer
2c6e1eb4c8 fix type issue in test_code_flow_strategy.py 2024-06-08 21:38:22 +02:00
Reinier van der Leer
3e8849b08e fix linting and type issues 2024-06-08 21:32:10 +02:00
Reinier van der Leer
111e8585b5 feat(forge/llm): allow async completion parsers 2024-06-08 21:29:35 +02:00
Reinier van der Leer
8144d26cef fix type issues 2024-06-08 21:02:44 +02:00
Reinier van der Leer
e264bf7764 forge.llm.providers.schema + code_flow_executor lint-fix and cleanup 2024-06-08 15:28:52 +02:00
Reinier van der Leer
6dd0975236 clean up & improve @command decorator
- add ability to extract parameter descriptions from docstring
- add ability to determine parameter JSON schemas from function signature
- add `JSONSchema.from_python_type` factory
2024-06-08 15:05:45 +02:00
Reinier van der Leer
c3acb99314 clean up forge.command.command 2024-06-08 15:01:00 +02:00
Reinier van der Leer
0578fb0246 fix async issues with code flow execution 2024-06-08 02:03:22 +02:00
Reinier van der Leer
731d0345f0 implement annotation expansion for non-builtin types 2024-06-08 02:01:20 +02:00
Reinier van der Leer
b4cd735f26 fix name collision with type in Command.return_type 2024-06-07 12:57:30 +02:00
Reinier van der Leer
6e715b6c71 simplify function header generation 2024-06-07 12:56:41 +02:00
Reinier van der Leer
fcca4cc893 clarify execute_code_flow 2024-06-03 22:00:34 +02:00
Reinier van der Leer
5c7c276c10 Merge branch 'master' into zamilmajdy/code-validation 2024-06-03 21:43:59 +02:00
Zamil Majdy
ae63aa8ebb Merge remote-tracking branch 'origin/zamilmajdy/code-validation' into zamilmajdy/code-validation 2024-05-20 22:39:47 +07:00
Zamil Majdy
fdd9f9b5ec Log fix 2024-05-20 22:39:30 +07:00
Zamil Majdy
a825aa8515 Merge branch 'master' into zamilmajdy/code-validation 2024-05-20 16:53:52 +02:00
Zamil Majdy
ae43136c2c Fix linting 2024-05-20 18:48:44 +07:00
Zamil Majdy
c8e16f3fe1 Fix linting 2024-05-20 18:42:36 +07:00
Zamil Majdy
3a60504138 isort 2024-05-20 18:21:17 +07:00
Zamil Majdy
dfa77739c3 Remove unnecessary changes 2024-05-20 18:14:39 +07:00
Zamil Majdy
9f6e25664c Debug Log changes 2024-05-20 18:11:53 +07:00
Zamil Majdy
3c4ff60e11 Add unit tests 2024-05-20 18:09:16 +07:00
Zamil Majdy
47eeaf0325 Revert dumb changes 2024-05-20 17:07:55 +07:00
Zamil Majdy
81ad3cb69a Merge conflicts 2024-05-20 17:00:25 +07:00
Zamil Majdy
834eb6c6e0 Some quality polishing 2024-05-20 15:56:18 +07:00
Zamil Majdy
fb802400ba Add return type 2024-05-17 17:10:54 +02:00
Zamil Majdy
922e643737 Fix Await fiasco 2024-05-17 00:57:29 +02:00
Zamil Majdy
7b5272f1f2 Fix Await fiasco 2024-05-17 00:50:11 +02:00
Zamil Majdy
ea134c7dbd Benchmark test 2024-05-16 20:09:10 +02:00
Zamil Majdy
f7634524fa More prompt engineering 2024-05-16 19:53:42 +02:00
Zamil Majdy
0eccbe1483 Prompt change 2024-05-15 21:04:52 +02:00
Zamil Majdy
0916df4df7 Fix async fiasco 2024-05-15 19:30:29 +02:00
Zamil Majdy
22e2373a0b Add code flow as a loop 2024-05-15 17:10:51 +02:00
Zamil Majdy
40426e4646 Merge master 2024-05-14 23:36:31 +02:00
Zamil Majdy
ef1fe7c4e8 Update notebook 2024-05-11 12:12:30 +02:00
Reinier van der Leer
ca7ca226ff one_shot_flow.ipynb + edits to make it work 2024-05-10 20:03:40 +02:00
Zamil Majdy
ed5f12c02b Add code validation 2024-05-10 16:43:53 +02:00
27 changed files with 2448 additions and 90 deletions

View File

@@ -23,6 +23,7 @@ from forge.components.code_executor.code_executor import (
CodeExecutorComponent,
CodeExecutorConfiguration,
)
from forge.components.code_flow_executor import CodeFlowExecutionComponent
from forge.components.context.context import AgentContext, ContextComponent
from forge.components.file_manager import FileManagerComponent
from forge.components.git_operations import GitOperationsComponent
@@ -40,7 +41,6 @@ from forge.llm.providers import (
ChatModelResponse,
MultiProvider,
)
from forge.llm.providers.utils import function_specs_from_commands
from forge.models.action import (
ActionErrorResult,
ActionInterruptedByHuman,
@@ -56,6 +56,7 @@ from forge.utils.exceptions import (
)
from pydantic import Field
from .prompt_strategies.code_flow import CodeFlowAgentPromptStrategy
from .prompt_strategies.one_shot import (
OneShotAgentActionProposal,
OneShotAgentPromptStrategy,
@@ -96,11 +97,14 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
llm_provider: MultiProvider,
file_storage: FileStorage,
app_config: AppConfig,
prompt_strategy_class: type[
OneShotAgentPromptStrategy | CodeFlowAgentPromptStrategy
] = CodeFlowAgentPromptStrategy,
):
super().__init__(settings)
self.llm_provider = llm_provider
prompt_config = OneShotAgentPromptStrategy.default_configuration.model_copy(
prompt_config = prompt_strategy_class.default_configuration.model_copy(
deep=True
)
prompt_config.use_functions_api = (
@@ -108,7 +112,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
# Anthropic currently doesn't support tools + prefilling :(
and self.llm.provider_name != "anthropic"
)
self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
self.prompt_strategy = prompt_strategy_class(prompt_config, logger)
self.commands: list[Command] = []
# Components
@@ -145,6 +149,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
self.watchdog = WatchdogComponent(settings.config, settings.history).run_after(
ContextComponent
)
self.code_flow_executor = CodeFlowExecutionComponent(lambda: self.commands)
self.event_history = settings.history
self.app_config = app_config
@@ -185,7 +190,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
task=self.state.task,
ai_profile=self.state.ai_profile,
ai_directives=directives,
commands=function_specs_from_commands(self.commands),
commands=self.commands,
include_os_info=include_os_info,
)
@@ -201,9 +206,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
if exception:
prompt.messages.append(ChatMessage.system(f"Error: {exception}"))
response: ChatModelResponse[
OneShotAgentActionProposal
] = await self.llm_provider.create_chat_completion(
response: ChatModelResponse = await self.llm_provider.create_chat_completion(
prompt.messages,
model_name=self.llm.name,
completion_parser=self.prompt_strategy.parse_response_content,
@@ -281,7 +284,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
except AgentException:
raise
except Exception as e:
raise CommandExecutionError(str(e))
raise CommandExecutionError(str(e)) from e
def _get_command(self, command_name: str) -> Command:
for command in reversed(self.commands):
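
For illustration, a minimal standalone sketch of the constructor change above (hypothetical classes, not the real AutoGPT/forge ones): the agent now receives a `prompt_strategy_class`, defaulting to the code-flow strategy, and builds its prompt config from that class's `default_configuration` instead of hard-coding `OneShotAgentPromptStrategy`.

```python
from copy import deepcopy

# Hypothetical stand-ins for the real strategy classes; only the pattern
# (class-level default_configuration + pluggable strategy class) is taken
# from the diff above.
class OneShotStrategy:
    default_configuration = {"use_functions_api": True}

    def __init__(self, config):
        self.config = config


class CodeFlowStrategy:
    default_configuration = {"use_functions_api": False}

    def __init__(self, config):
        self.config = config


class MiniAgent:
    def __init__(self, prompt_strategy_class=CodeFlowStrategy):
        # Deep-copy the class-level defaults so instances don't share state,
        # mirroring `default_configuration.model_copy(deep=True)` above.
        config = deepcopy(prompt_strategy_class.default_configuration)
        self.prompt_strategy = prompt_strategy_class(config)


agent = MiniAgent(prompt_strategy_class=OneShotStrategy)
print(type(agent.prompt_strategy).__name__)  # OneShotStrategy
```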

View File

@@ -0,0 +1,355 @@
import inspect
import re
from logging import Logger
from typing import Callable, Iterable, Sequence, get_args, get_origin
from forge.command import Command
from forge.components.code_flow_executor import CodeFlowExecutionComponent
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
from forge.llm.prompting.utils import indent
from forge.llm.providers.schema import (
AssistantChatMessage,
AssistantFunctionCall,
ChatMessage,
)
from forge.models.config import SystemConfiguration
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import InvalidAgentResponseError
from forge.utils.function.code_validation import CodeValidator
from forge.utils.function.model import FunctionDef
from pydantic import BaseModel, Field
from autogpt.agents.prompt_strategies.one_shot import (
AssistantThoughts,
OneShotAgentActionProposal,
OneShotAgentPromptConfiguration,
)
_RESPONSE_INTERFACE_NAME = "AssistantResponse"
class CodeFlowAgentActionProposal(BaseModel):
thoughts: AssistantThoughts
immediate_plan: str = Field(
...,
description="We will be running an iterative process to execute the plan, "
"Write the partial / immediate plan to execute your plan as detailed and "
"efficiently as possible without the help of the reasoning/intelligence. "
"The plan should describe the output of the immediate plan, so that the next "
"iteration can be executed by taking the output into account. "
"Try to do as much as possible without making any assumption or uninformed "
"guesses. Avoid large output at all costs!!!\n"
"Format: Objective[Objective of this iteration, explain what's the use of this "
"iteration for the next one] Plan[Plan that does not require any reasoning or "
"intelligence] Output[Output of the plan / should be small, avoid whole file "
"output]",
)
python_code: str = Field(
...,
description=(
"Write the fully-functional Python code of the immediate plan. "
"The output will be an `async def main() -> str` function of the immediate "
"plan that return the string output, the output will be passed into the "
"LLM context window so avoid returning the whole content!. "
"Use ONLY the listed available functions and built-in Python features. "
"Leverage the given magic functions to implement function calls for which "
"the arguments can't be determined yet. "
"Example:`async def main() -> str:\n"
" return await provided_function('arg1', 'arg2').split('\\n')[0]`"
),
)
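
For reference, a hypothetical response payload matching the fields declared above; the field names follow `CodeFlowAgentActionProposal` and `AssistantThoughts` (as updated later in this diff), while the values are invented.

```python
# Invented example values; only the field names come from the models above.
proposal = {
    "thoughts": {
        "past_action_summary": "Read file.csv and located the expense rows.",
        "observations": "The file has 'category' and 'amount' columns.",
        "text": "Sum the amounts where category == 'utilities'.",
        "reasoning": "A single pass over the rows is enough.",
        "self_criticism": "I should verify the column names before summing.",
        "plan": ["Read file.csv", "Sum utilities", "Report the total"],
        "speak": "I will total the utilities expenses from file.csv.",
    },
    "immediate_plan": "Objective[sum utilities] Plan[read + filter + sum] Output[total as str]",
    "python_code": "async def main() -> str:\n    return str(42.0)",
}
print(proposal["immediate_plan"])
```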
FINAL_INSTRUCTION: str = (
"You have to give the answer in the from of JSON schema specified previously. "
"For the `python_code` field, you have to write Python code to execute your plan "
"as efficiently as possible. Your code will be executed directly without any "
"editing, if it doesn't work you will be held responsible. "
"Use ONLY the listed available functions and built-in Python features. "
"Do not make uninformed assumptions "
"(e.g. about the content or format of an unknown file). Leverage the given magic "
"functions to implement function calls for which the arguments can't be determined "
"yet. Reduce the amount of unnecessary data passed into these magic functions "
"where possible, because magic costs money and magically processing large amounts "
"of data is expensive. If you think are done with the task, you can simply call "
"finish(reason='your reason') to end the task, "
"a function that has one `finish` command, don't mix finish with other functions! "
"If you still need to do other functions, "
"let the next cycle execute the `finish` function. "
"Avoid hard-coding input values as input, and avoid returning large outputs. "
"The code that you have been executing in the past cycles can also be buggy, "
"so if you see undesired output, you can always try to re-plan, and re-code. "
)
class CodeFlowAgentPromptStrategy(PromptStrategy):
default_configuration: OneShotAgentPromptConfiguration = (
OneShotAgentPromptConfiguration()
)
def __init__(
self,
configuration: SystemConfiguration,
logger: Logger,
):
self.config = configuration
self.response_schema = JSONSchema.from_dict(
CodeFlowAgentActionProposal.model_json_schema()
)
self.logger = logger
self.commands: Sequence[Command] = [] # Sequence -> disallow list modification
@property
def llm_classification(self) -> LanguageModelClassification:
return LanguageModelClassification.SMART_MODEL # FIXME: dynamic switching
def build_prompt(
self,
*,
messages: list[ChatMessage],
task: str,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: Sequence[Command],
**extras,
) -> ChatPrompt:
"""Constructs and returns a prompt with the following structure:
1. System prompt
3. `cycle_instruction`
"""
system_prompt, response_prefill = self.build_system_prompt(
ai_profile=ai_profile,
ai_directives=ai_directives,
commands=commands,
)
self.commands = commands
final_instruction_msg = ChatMessage.system(FINAL_INSTRUCTION)
return ChatPrompt(
messages=[
ChatMessage.system(system_prompt),
ChatMessage.user(f'"""{task}"""'),
*messages,
*(
[final_instruction_msg]
if not any(m.role == "assistant" for m in messages)
else []
),
],
prefill_response=response_prefill,
)
def build_system_prompt(
self,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: Iterable[Command],
) -> tuple[str, str]:
"""
Builds the system prompt.
Returns:
str: The system prompt body
str: The desired start for the LLM's response; used to steer the output
"""
response_fmt_instruction, response_prefill = self.response_format_instruction()
system_prompt_parts = (
self._generate_intro_prompt(ai_profile)
+ [
"## Your Task\n"
"The user will specify a task for you to execute, in triple quotes,"
" in the next message. Your job is to complete the task, "
"and terminate when your task is done."
]
+ ["## Available Functions\n" + self._generate_function_headers(commands)]
+ ["## RESPONSE FORMAT\n" + response_fmt_instruction]
)
# Join non-empty parts together into paragraph format
return (
"\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
response_prefill,
)
def response_format_instruction(self) -> tuple[str, str]:
response_schema = self.response_schema.model_copy(deep=True)
assert response_schema.properties
# Unindent for performance
response_format = re.sub(
r"\n\s+",
"\n",
response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
)
response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'
return (
(
f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
f"{response_format}"
),
response_prefill,
)
def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:
"""Generates the introduction part of the prompt.
Returns:
list[str]: A list of strings forming the introduction part of the prompt.
"""
return [
f"You are {ai_profile.ai_name}, {ai_profile.ai_role.rstrip('.')}.",
# "Your decisions must always be made independently without seeking "
# "user assistance. Play to your strengths as an LLM and pursue "
# "simple strategies with no legal complications.",
]
def _generate_function_headers(self, commands: Iterable[Command]) -> str:
function_stubs: list[str] = []
annotation_types_in_context: set[type] = set()
for f in commands:
# Add source code of non-builtin types from function signatures
new_annotation_types = extract_annotation_types(f.method).difference(
annotation_types_in_context
)
new_annotation_types_src = [
f"# {a.__module__}.{a.__qualname__}\n{inspect.getsource(a)}"
for a in new_annotation_types
]
annotation_types_in_context.update(new_annotation_types)
param_descriptions = "\n".join(
f"{param.name}: {param.spec.description}"
for param in f.parameters
if param.spec.description
)
full_function_stub = (
("\n".join(new_annotation_types_src) + "\n" + f.header).strip()
+ "\n"
+ indent(
(
'"""\n'
f"{f.description}\n\n"
f"Params:\n{indent(param_descriptions)}\n"
'"""\n'
"pass"
),
)
)
function_stubs.append(full_function_stub)
return "\n\n\n".join(function_stubs)
async def parse_response_content(
self,
response: AssistantChatMessage,
) -> OneShotAgentActionProposal:
if not response.content:
raise InvalidAgentResponseError("Assistant response has no text content")
self.logger.debug(
"LLM response content:"
+ (
f"\n{response.content}"
if "\n" in response.content
else f" '{response.content}'"
)
)
assistant_reply_dict = extract_dict_from_json(response.content)
parsed_response = CodeFlowAgentActionProposal.model_validate(
assistant_reply_dict
)
if not parsed_response.python_code:
raise ValueError("python_code is empty")
available_functions = {
c.name: FunctionDef(
name=c.name,
arg_types=[(p.name, p.spec.python_type) for p in c.parameters],
arg_descs={p.name: p.spec.description for p in c.parameters},
arg_defaults={
p.name: p.spec.default or "None"
for p in c.parameters
if p.spec.default or not p.spec.required
},
return_type=c.return_type,
return_desc="Output of the function",
function_desc=c.description,
is_async=c.is_async,
)
for c in self.commands
}
available_functions.update(
{
"main": FunctionDef(
name="main",
arg_types=[],
arg_descs={},
return_type="str",
return_desc="Output of the function",
function_desc="The main function to execute the plan",
is_async=True,
)
}
)
code_validation = await CodeValidator(
function_name="main",
available_functions=available_functions,
).validate_code(parsed_response.python_code)
clean_response = response.model_copy()
clean_response.content = parsed_response.model_dump_json(indent=4)
# TODO: prevent combining finish with other functions
if _finish_call := re.search(
r"finish\((reason=)?(.*?)\)", code_validation.functionCode
):
finish_reason = _finish_call.group(2)[1:-1] # remove quotes
result = OneShotAgentActionProposal(
thoughts=parsed_response.thoughts,
use_tool=AssistantFunctionCall(
name="finish",
arguments={"reason": finish_reason},
),
raw_message=clean_response,
)
else:
result = OneShotAgentActionProposal(
thoughts=parsed_response.thoughts,
use_tool=AssistantFunctionCall(
name=CodeFlowExecutionComponent.execute_code_flow.name,
arguments={
"python_code": code_validation.functionCode,
"plan_text": parsed_response.immediate_plan,
},
),
raw_message=clean_response,
)
return result
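
The `finish(...)` detection above can be checked in isolation; the regex and quote-stripping are taken from the diff, the sample code string is made up.

```python
import re

# Sample LLM-generated code containing a finish(...) call.
code = "async def main() -> str:\n    return await finish(reason='task complete')"

if match := re.search(r"finish\((reason=)?(.*?)\)", code):
    finish_reason = match.group(2)[1:-1]  # strip the surrounding quotes
    print(finish_reason)  # task complete
```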
def extract_annotation_types(func: Callable) -> set[type]:
annotation_types = set()
for annotation in inspect.get_annotations(func).values():
annotation_types.update(_get_nested_types(annotation))
return annotation_types
def _get_nested_types(annotation: type) -> Iterable[type]:
if _args := get_args(annotation):
for a in _args:
yield from _get_nested_types(a)
if not _is_builtin_type(_a := get_origin(annotation) or annotation):
yield _a
def _is_builtin_type(_type: type):
"""Check if a given type is a built-in type."""
import sys
return _type.__module__ in sys.stdlib_module_names
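
A self-contained rerun of the two helpers above on a made-up function shows which annotation types end up having their source code inlined into the prompt (builtin and stdlib types are skipped).

```python
import inspect
import sys
from dataclasses import dataclass
from typing import Iterable, get_args, get_origin


@dataclass
class Invoice:  # made-up non-builtin annotation type
    total: float


def summarize(invoices: list[Invoice], limit: int | None) -> str:
    return ""


def _is_builtin(t: type) -> bool:
    return t.__module__ in sys.stdlib_module_names


def _nested_types(annotation) -> Iterable[type]:
    if args := get_args(annotation):
        for a in args:
            yield from _nested_types(a)
    if not _is_builtin(origin := get_origin(annotation) or annotation):
        yield origin


found: set[type] = set()
for annotation in inspect.get_annotations(summarize).values():
    found.update(_nested_types(annotation))
print(found)  # {<class '__main__.Invoice'>}
```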

View File

@@ -6,6 +6,7 @@ import re
from logging import Logger
import distro
from forge.command import Command
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
@@ -16,6 +17,7 @@ from forge.llm.providers.schema import (
ChatMessage,
CompletionModelFunction,
)
from forge.llm.providers.utils import function_specs_from_commands
from forge.models.action import ActionProposal
from forge.models.config import SystemConfiguration, UserConfigurable
from forge.models.json_schema import JSONSchema
@@ -27,13 +29,21 @@ _RESPONSE_INTERFACE_NAME = "AssistantResponse"
class AssistantThoughts(ModelWithSummary):
past_action_summary: str = Field(
...,
description="Summary of the last action you took, if there is none, "
"you can leave it empty",
)
observations: str = Field(
description="Relevant observations from your last action (if any)"
description="Relevant observations from your last actions (if any)"
)
text: str = Field(description="Thoughts")
reasoning: str = Field(description="Reasoning behind the thoughts")
self_criticism: str = Field(description="Constructive self-criticism")
plan: list[str] = Field(description="Short list that conveys the long-term plan")
plan: list[str] = Field(
description="Short list that conveys the long-term plan, "
"considering the progress on your task so far",
)
speak: str = Field(description="Summary of thoughts, to say to user")
def summary(self) -> str:
@@ -101,7 +111,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
@property
def llm_classification(self) -> LanguageModelClassification:
return LanguageModelClassification.FAST_MODEL # FIXME: dynamic switching
return LanguageModelClassification.SMART_MODEL # FIXME: dynamic switching
def build_prompt(
self,
@@ -110,7 +120,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
task: str,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: list[CompletionModelFunction],
commands: list[Command],
include_os_info: bool,
**extras,
) -> ChatPrompt:
@@ -118,10 +128,11 @@ class OneShotAgentPromptStrategy(PromptStrategy):
1. System prompt
3. `cycle_instruction`
"""
functions = function_specs_from_commands(commands)
system_prompt, response_prefill = self.build_system_prompt(
ai_profile=ai_profile,
ai_directives=ai_directives,
commands=commands,
functions=functions,
include_os_info=include_os_info,
)
@@ -135,14 +146,14 @@ class OneShotAgentPromptStrategy(PromptStrategy):
final_instruction_msg,
],
prefill_response=response_prefill,
functions=commands if self.config.use_functions_api else [],
functions=functions if self.config.use_functions_api else [],
)
def build_system_prompt(
self,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: list[CompletionModelFunction],
functions: list[CompletionModelFunction],
include_os_info: bool,
) -> tuple[str, str]:
"""
@@ -162,7 +173,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
self.config.body_template.format(
constraints=format_numbered_list(ai_directives.constraints),
resources=format_numbered_list(ai_directives.resources),
commands=self._generate_commands_list(commands),
commands=self._generate_commands_list(functions),
best_practices=format_numbered_list(ai_directives.best_practices),
)
]

View File

@@ -23,6 +23,7 @@ from forge.agent_protocol.models import (
TaskRequestBody,
TaskStepsListResponse,
)
from forge.components.code_flow_executor import CodeFlowExecutionComponent
from forge.file_storage import FileStorage
from forge.llm.providers import ModelProviderBudget, MultiProvider
from forge.models.action import ActionErrorResult, ActionSuccessResult
@@ -298,11 +299,16 @@ class AgentProtocolServer:
else ""
)
output += f"{assistant_response.thoughts.speak}\n\n"
output += (
f"Next Command: {next_tool_to_use}"
if next_tool_to_use.name != ASK_COMMAND
else next_tool_to_use.arguments["question"]
)
if next_tool_to_use.name == CodeFlowExecutionComponent.execute_code_flow.name:
code = next_tool_to_use.arguments["python_code"]
plan = next_tool_to_use.arguments["plan_text"]
output += f"Code for next step:\n```py\n# {plan}\n\n{code}\n```"
else:
output += (
f"Next Command: {next_tool_to_use}"
if next_tool_to_use.name != ASK_COMMAND
else next_tool_to_use.arguments["question"]
)
additional_output = {
**(

View File

@@ -630,6 +630,9 @@ def update_user(
command_args: The arguments for the command.
assistant_reply_dict: The assistant's reply.
"""
from forge.components.code_flow_executor import CodeFlowExecutionComponent
from forge.llm.prompting.utils import indent
logger = logging.getLogger(__name__)
print_assistant_thoughts(
@@ -644,15 +647,29 @@ def update_user(
# First log new-line so user can differentiate sections better in console
print()
safe_tool_name = remove_ansi_escape(action_proposal.use_tool.name)
logger.info(
f"COMMAND = {Fore.CYAN}{safe_tool_name}{Style.RESET_ALL} "
f"ARGUMENTS = {Fore.CYAN}{action_proposal.use_tool.arguments}{Style.RESET_ALL}",
extra={
"title": "NEXT ACTION:",
"title_color": Fore.CYAN,
"preserve_color": True,
},
)
if safe_tool_name == CodeFlowExecutionComponent.execute_code_flow.name:
plan = action_proposal.use_tool.arguments["plan_text"]
code = action_proposal.use_tool.arguments["python_code"]
logger.info(
f"\n{indent(code, f'{Fore.GREEN}>>> {Fore.RESET}')}\n",
extra={
"title": "PROPOSED ACTION:",
"title_color": Fore.GREEN,
"preserve_color": True,
},
)
logger.debug(
f"{plan}\n", extra={"title": "EXPLANATION:", "title_color": Fore.YELLOW}
)
else:
logger.info(
str(action_proposal.use_tool),
extra={
"title": "PROPOSED ACTION:",
"title_color": Fore.GREEN,
"preserve_color": True,
},
)
async def get_user_feedback(
@@ -732,6 +749,12 @@ def print_assistant_thoughts(
)
if isinstance(thoughts, AssistantThoughts):
if thoughts.observations:
print_attribute(
"OBSERVATIONS",
remove_ansi_escape(thoughts.observations),
title_color=Fore.YELLOW,
)
print_attribute(
"REASONING", remove_ansi_escape(thoughts.reasoning), title_color=Fore.YELLOW
)
@@ -753,7 +776,7 @@ def print_assistant_thoughts(
line.strip(), extra={"title": "- ", "title_color": Fore.GREEN}
)
print_attribute(
"CRITICISM",
"SELF-CRITICISM",
remove_ansi_escape(thoughts.self_criticism),
title_color=Fore.YELLOW,
)
@@ -764,7 +787,7 @@ def print_assistant_thoughts(
speak(assistant_thoughts_speak)
else:
print_attribute(
"SPEAK", assistant_thoughts_speak, title_color=Fore.YELLOW
"TL;DR", assistant_thoughts_speak, title_color=Fore.YELLOW
)
else:
speak(thoughts_text)
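
Roughly what the CLI preview above produces, using the stdlib `textwrap.indent` as a stand-in for forge's `indent` helper and omitting the ANSI colors:

```python
import textwrap

code = "async def main() -> str:\n    return 'done'"
print(textwrap.indent(code, ">>> "))
# >>> async def main() -> str:
# >>>     return 'done'
```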

autogpt/poetry.lock generated
View File

@@ -4216,7 +4216,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
name = "ptyprocess"
version = "0.7.0"
description = "Run a subprocess in a pseudo terminal"
optional = true
optional = false
python-versions = "*"
files = [
{file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
@@ -5212,6 +5212,32 @@ files = [
[package.dependencies]
pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.4.4"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.4.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:29d44ef5bb6a08e235c8249294fa8d431adc1426bfda99ed493119e6f9ea1bf6"},
{file = "ruff-0.4.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c4efe62b5bbb24178c950732ddd40712b878a9b96b1d02b0ff0b08a090cbd891"},
{file = "ruff-0.4.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c8e2f1e8fc12d07ab521a9005d68a969e167b589cbcaee354cb61e9d9de9c15"},
{file = "ruff-0.4.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:60ed88b636a463214905c002fa3eaab19795679ed55529f91e488db3fe8976ab"},
{file = "ruff-0.4.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b90fc5e170fc71c712cc4d9ab0e24ea505c6a9e4ebf346787a67e691dfb72e85"},
{file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8e7e6ebc10ef16dcdc77fd5557ee60647512b400e4a60bdc4849468f076f6eef"},
{file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b9ddb2c494fb79fc208cd15ffe08f32b7682519e067413dbaf5f4b01a6087bcd"},
{file = "ruff-0.4.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c51c928a14f9f0a871082603e25a1588059b7e08a920f2f9fa7157b5bf08cfe9"},
{file = "ruff-0.4.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5eb0a4bfd6400b7d07c09a7725e1a98c3b838be557fee229ac0f84d9aa49c36"},
{file = "ruff-0.4.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b1867ee9bf3acc21778dcb293db504692eda5f7a11a6e6cc40890182a9f9e595"},
{file = "ruff-0.4.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1aecced1269481ef2894cc495647392a34b0bf3e28ff53ed95a385b13aa45768"},
{file = "ruff-0.4.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9da73eb616b3241a307b837f32756dc20a0b07e2bcb694fec73699c93d04a69e"},
{file = "ruff-0.4.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:958b4ea5589706a81065e2a776237de2ecc3e763342e5cc8e02a4a4d8a5e6f95"},
{file = "ruff-0.4.4-py3-none-win32.whl", hash = "sha256:cb53473849f011bca6e754f2cdf47cafc9c4f4ff4570003a0dad0b9b6890e876"},
{file = "ruff-0.4.4-py3-none-win_amd64.whl", hash = "sha256:424e5b72597482543b684c11def82669cc6b395aa8cc69acc1858b5ef3e5daae"},
{file = "ruff-0.4.4-py3-none-win_arm64.whl", hash = "sha256:39df0537b47d3b597293edbb95baf54ff5b49589eb7ff41926d8243caa995ea6"},
{file = "ruff-0.4.4.tar.gz", hash = "sha256:f87ea42d5cdebdc6a69761a9d0bc83ae9b3b30d0ad78952005ba6568d6c022af"},
]
[[package]]
name = "s3transfer"
version = "0.10.0"
@@ -6758,4 +6784,4 @@ benchmark = ["agbenchmark"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "b3d4efee5861b32152024dada1ec61f4241122419cb538012c00a6ed55ac8a4b"
content-hash = "c729e10fd5ac85400d2499397974d1b1831fed3b591657a2fea9e86501b96e19"

View File

@@ -30,9 +30,12 @@ gitpython = "^3.1.32"
hypercorn = "^0.14.4"
openai = "^1.7.2"
orjson = "^3.8.10"
ptyprocess = "^0.7.0"
pydantic = "^2.7.2"
pyright = "^1.1.364"
python-dotenv = "^1.0.0"
requests = "*"
ruff = "^0.4.4"
sentry-sdk = "^1.40.4"
# Benchmarking
@@ -47,7 +50,6 @@ black = "^23.12.1"
flake8 = "^7.0.0"
isort = "^5.13.1"
pre-commit = "*"
pyright = "^1.1.364"
# Type stubs
types-colorama = "*"

View File

@@ -0,0 +1,126 @@
import logging
from typing import Optional
import pytest
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.components.code_flow_executor import CodeFlowExecutionComponent
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.llm.providers import AssistantChatMessage
from forge.llm.providers.schema import JSONSchema
from autogpt.agents.prompt_strategies.code_flow import CodeFlowAgentPromptStrategy
logger = logging.getLogger(__name__)
config = CodeFlowAgentPromptStrategy.default_configuration.copy(deep=True)
prompt_strategy = CodeFlowAgentPromptStrategy(config, logger)
class MockWebSearchProvider(CommandProvider):
def get_commands(self):
yield self.mock_web_search
@command(
description="Searches the web",
parameters={
"query": JSONSchema(
type=JSONSchema.Type.STRING,
description="The search query",
required=True,
),
"num_results": JSONSchema(
type=JSONSchema.Type.INTEGER,
description="The number of results to return",
minimum=1,
maximum=10,
required=False,
),
},
)
def mock_web_search(self, query: str, num_results: Optional[int] = None) -> str:
return "results"
@pytest.mark.asyncio
async def test_code_flow_build_prompt():
commands = list(MockWebSearchProvider().get_commands())
ai_profile = AIProfile()
ai_profile.ai_name = "DummyGPT"
ai_profile.ai_goals = ["A model for testing purposes"]
ai_profile.ai_role = "Help Testing"
ai_directives = AIDirectives()
ai_directives.resources = ["resource_1"]
ai_directives.constraints = ["constraint_1"]
ai_directives.best_practices = ["best_practice_1"]
prompt = str(
prompt_strategy.build_prompt(
task="Figure out from file.csv how much was spent on utilities",
messages=[],
ai_profile=ai_profile,
ai_directives=ai_directives,
commands=commands,
)
)
assert "DummyGPT" in prompt
assert (
"def mock_web_search(query: str, num_results: Optional[int] = None)" in prompt
)
@pytest.mark.asyncio
async def test_code_flow_parse_response():
response_content = """
{
"thoughts": {
"past_action_summary": "This is the past action summary.",
"observations": "This is the observation.",
"text": "Some text on the AI's thoughts.",
"reasoning": "This is the reasoning.",
"self_criticism": "This is the self-criticism.",
"plan": [
"Plan 1",
"Plan 2",
"Plan 3"
],
"speak": "This is what the AI would say."
},
"immediate_plan": "Objective[objective1] Plan[plan1] Output[out1]",
"python_code": "async def main() -> str:\n return 'You passed the test.'",
}
"""
response = await CodeFlowAgentPromptStrategy(config, logger).parse_response_content(
AssistantChatMessage(content=response_content)
)
assert "This is the observation." == response.thoughts.observations
assert "This is the reasoning." == response.thoughts.reasoning
assert CodeFlowExecutionComponent.execute_code_flow.name == response.use_tool.name
assert "async def main() -> str" in response.use_tool.arguments["python_code"]
assert (
"Objective[objective1] Plan[plan1] Output[out1]"
in response.use_tool.arguments["plan_text"]
)
@pytest.mark.asyncio
async def test_code_flow_execution():
executor = CodeFlowExecutionComponent(
lambda: [
Command(
names=["test_func"],
description="",
parameters=[],
method=lambda: "You've passed the test!",
)
]
)
result = await executor.execute_code_flow(
python_code="async def main() -> str:\n return test_func()",
plan_text="This is the plan text.",
)
assert "You've passed the test!" in result

View File

@@ -0,0 +1,75 @@
import pytest
from forge.utils.function.code_validation import CodeValidator, FunctionDef
@pytest.mark.asyncio
async def test_code_validation():
validator = CodeValidator(
available_functions={
"read_webpage": FunctionDef(
name="read_webpage",
arg_types=[("url", "str"), ("query", "str")],
arg_descs={
"url": "URL to read",
"query": "Query to search",
"return_type": "Type of return value",
},
return_type="str",
return_desc="Information matching the query",
function_desc="Read a webpage and return the info matching the query",
is_async=True,
),
"web_search": FunctionDef(
name="web_search",
arg_types=[("query", "str")],
arg_descs={"query": "Query to search"},
return_type="list[(str,str)]",
return_desc="List of tuples with title and URL",
function_desc="Search the web and return the search results",
is_async=True,
),
"main": FunctionDef(
name="main",
arg_types=[],
arg_descs={},
return_type="str",
return_desc="Answer in the text format",
function_desc="Get the num of contributors to the autogpt github repo",
is_async=False,
),
},
available_objects={},
)
response = await validator.validate_code(
raw_code="""
def crawl_info(url: str, query: str) -> str | None:
info = await read_webpage(url, query)
if info:
return info
urls = await read_webpage(url, "autogpt github contributor page")
for url in urls.split('\\n'):
info = await crawl_info(url, query)
if info:
return info
return None
def hehe():
return 'hehe'
def main() -> str:
query = "Find the number of contributors to the autogpt github repository"
for title, url in ("autogpt github contributor page"):
info = await crawl_info(url, query)
if info:
return info
x = await hehe()
return "No info found"
""",
packages=[],
)
assert response.functionCode is not None
assert "async def crawl_info" in response.functionCode # async is added
assert "async def main" in response.functionCode
assert "x = hehe()" in response.functionCode # await is removed

View File

@@ -1,17 +1,13 @@
from __future__ import annotations
import inspect
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
from forge.agent.protocols import CommandProvider
from typing import Callable, Generic, ParamSpec, TypeVar
from .parameter import CommandParameter
P = ParamSpec("P")
CO = TypeVar("CO") # command output
_CP = TypeVar("_CP", bound=CommandProvider)
class Command(Generic[P, CO]):
"""A class representing a command.
@@ -26,37 +22,60 @@ class Command(Generic[P, CO]):
self,
names: list[str],
description: str,
method: Callable[Concatenate[_CP, P], CO],
method: Callable[P, CO],
parameters: list[CommandParameter],
):
# Check if all parameters are provided
if not self._parameters_match(method, parameters):
raise ValueError(
f"Command {names[0]} has different parameters than provided schema"
)
self.names = names
self.description = description
# Method technically has a `self` parameter, but we can ignore that
# since Python passes it internally.
self.method = cast(Callable[P, CO], method)
self.method = method
self.parameters = parameters
# Check if all parameters are provided
if not self._parameters_match_signature():
raise ValueError(
f"Command {self.name} has different parameters than provided schema"
)
@property
def name(self) -> str:
return self.names[0] # TODO: fallback to other name if first one is taken
@property
def is_async(self) -> bool:
return inspect.iscoroutinefunction(self.method)
def _parameters_match(
self, func: Callable, parameters: list[CommandParameter]
) -> bool:
@property
def return_type(self) -> str:
_type = inspect.signature(self.method).return_annotation
if _type == inspect.Signature.empty:
return "None"
return _type.__name__
@property
def header(self) -> str:
"""Returns a function header representing the command's signature
Examples:
```py
def execute_python_code(code: str) -> str:
async def extract_info_from_content(content: str, instruction: str, output_type: type[~T]) -> ~T:
""" # noqa
return (
f"{'async ' if self.is_async else ''}"
f"def {self.name}{inspect.signature(self.method)}:"
)
def _parameters_match_signature(self) -> bool:
# Get the function's signature
signature = inspect.signature(func)
signature = inspect.signature(self.method)
# Extract parameter names, ignoring 'self' for methods
func_param_names = [
param.name
for param in signature.parameters.values()
if param.name != "self"
]
names = [param.name for param in parameters]
names = [param.name for param in self.parameters]
# Check if sorted lists of names/keys are equal
return sorted(func_param_names) == sorted(names)
@@ -71,7 +90,7 @@ class Command(Generic[P, CO]):
for param in self.parameters
]
return (
f"{self.names[0]}: {self.description.rstrip('.')}. "
f"{self.name}: {self.description.rstrip('.')}. "
f"Params: ({', '.join(params)})"
)
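
The new `header` property simply renders the wrapped method's signature, prefixed with `async ` for coroutine functions; a standalone illustration with a made-up function:

```python
import inspect


async def read_webpage(url: str, query: str) -> str:
    """Read a webpage and return the info matching the query."""
    return ""


is_async = inspect.iscoroutinefunction(read_webpage)
header = f"{'async ' if is_async else ''}def read_webpage{inspect.signature(read_webpage)}:"
print(header)  # async def read_webpage(url: str, query: str) -> str:
```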

View File

@@ -1,21 +1,28 @@
import inspect
import logging
import re
from typing import Callable, Concatenate, Optional, TypeVar
from typing import Callable, Concatenate, Optional, TypeVar, cast
from forge.agent.protocols import CommandProvider
from forge.models.json_schema import JSONSchema
from .command import CO, Command, CommandParameter, P
logger = logging.getLogger(__name__)
_CP = TypeVar("_CP", bound=CommandProvider)
def command(
names: list[str] = [],
names: Optional[list[str]] = None,
description: Optional[str] = None,
parameters: dict[str, JSONSchema] = {},
) -> Callable[[Callable[Concatenate[_CP, P], CO]], Command[P, CO]]:
parameters: Optional[dict[str, JSONSchema]] = None,
) -> Callable[[Callable[Concatenate[_CP, P], CO] | Callable[P, CO]], Command[P, CO]]:
"""
The command decorator is used to make a Command from a function.
Make a `Command` from a function or a method on a `CommandProvider`.
All parameters are optional if the decorated function has a fully featured
docstring. For the requirements of such a docstring,
see `get_param_descriptions_from_docstring`.
Args:
names (list[str]): The names of the command.
@@ -27,34 +34,141 @@ def command(
that the command executes.
"""
def decorator(func: Callable[Concatenate[_CP, P], CO]) -> Command[P, CO]:
doc = func.__doc__ or ""
def decorator(
func: Callable[Concatenate[_CP, P], CO] | Callable[P, CO]
) -> Command[P, CO]:
# If names is not provided, use the function name
command_names = names or [func.__name__]
# If description is not provided, use the first part of the docstring
if not (command_description := description):
if not func.__doc__:
raise ValueError("Description is required if function has no docstring")
# Return the part of the docstring before double line break or everything
command_description = re.sub(r"\s+", " ", doc.split("\n\n")[0].strip())
_names = names or [func.__name__]
# If description is not provided, use the first part of the docstring
docstring = inspect.getdoc(func)
if not (_description := description):
if not docstring:
raise ValueError(
"'description' is required if function has no docstring"
)
_description = get_clean_description_from_docstring(docstring)
if not (_parameters := parameters):
if not docstring:
raise ValueError(
"'parameters' is required if function has no docstring"
)
# Combine descriptions from docstring with JSONSchemas from annotations
param_descriptions = get_param_descriptions_from_docstring(docstring)
_parameters = get_params_json_schemas(func)
for param, param_schema in _parameters.items():
if desc := param_descriptions.get(param):
param_schema.description = desc
# Parameters
typed_parameters = [
CommandParameter(
name=param_name,
spec=spec,
)
for param_name, spec in parameters.items()
for param_name, spec in _parameters.items()
]
# Wrap func with Command
command = Command(
names=command_names,
description=command_description,
method=func,
names=_names,
description=_description,
# Method technically has a `self` parameter, but we can ignore that
# since Python passes it internally.
method=cast(Callable[P, CO], func),
parameters=typed_parameters,
)
return command
return decorator
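
Assuming the `forge` package from this branch is installed, the decorator can now derive a command's name, description, and parameter specs from the function signature and docstring alone (via `get_params_json_schemas` and `get_param_descriptions_from_docstring` below); `count_lines` is an invented example.

```python
from forge.command import command


@command()
def count_lines(path: str, ignore_blank: bool = False) -> int:
    """Count the number of lines in a text file.

    Params:
        path: Path of the file to count lines in.
        ignore_blank: Whether to skip empty lines.
    """
    with open(path) as f:
        return sum(1 for line in f if line.strip() or not ignore_blank)


print(count_lines.description)                   # Count the number of lines in a text file.
print([p.name for p in count_lines.parameters])  # ['path', 'ignore_blank']
print(count_lines.header)  # def count_lines(path: str, ignore_blank: bool = False) -> int:
```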
def get_clean_description_from_docstring(docstring: str) -> str:
"""Return the part of the docstring before double line break or everything"""
return re.sub(r"\s+", " ", docstring.split("\n\n")[0].strip())
def get_params_json_schemas(func: Callable) -> dict[str, JSONSchema]:
"""Gets the annotations of the given function and converts them to JSONSchemas"""
result: dict[str, JSONSchema] = {}
for name, parameter in inspect.signature(func).parameters.items():
if name == "return":
continue
param_schema = result[name] = JSONSchema.from_python_type(parameter.annotation)
if parameter.default:
param_schema.default = parameter.default
param_schema.required = False
return result
def get_param_descriptions_from_docstring(docstring: str) -> dict[str, str]:
"""
Get parameter descriptions from a docstring. Requirements for the docstring:
- The section describing the parameters MUST start with `Params:` or `Args:`, in any
capitalization.
- An entry describing a parameter MUST be indented by 4 spaces.
- An entry describing a parameter MUST start with the parameter name, an optional
type annotation, followed by a colon `:`.
- Continuations of parameter descriptions MUST be indented relative to the first
line of the entry.
- The docstring must not be indented as a whole. To get a docstring with the uniform
indentation stripped off, use `inspect.getdoc(func)`.
Example:
```python
\"\"\"
This is the description. This will be ignored.
The description can span multiple lines,
or contain any number of line breaks.
Params:
param1: This is a simple parameter description.
param2 (list[str]): This parameter also has a type annotation.
param3: This parameter has a long description. This means we will have to let it
continue on the next line. The continuation is indented relative to the first
line of the entry.
param4: Extra line breaks to group parameters are allowed. This will not break
the algorithm.
This text is
is indented by
less than 4 spaces
and is interpreted as the end of the `Params:` section.
\"\"\"
```
"""
param_descriptions: dict[str, str] = {}
param_section = False
last_param_name = ""
for line in docstring.split("\n"):
if not line.strip(): # ignore empty lines
continue
if line.lower().startswith(("params:", "args:")):
param_section = True
continue
if param_section:
if line.strip() and not line.startswith(" " * 4): # end of section
break
line = line[4:]
if line.startswith(" ") and last_param_name: # continuation of description
param_descriptions[last_param_name] += " " + line.strip()
else:
if _match := re.match(r"^(\w+).*?: (.*)", line):
param_name = _match.group(1)
param_desc = _match.group(2).strip()
else:
logger.warning(
f"Invalid line in docstring's parameter section: {repr(line)}"
)
continue
param_descriptions[param_name] = param_desc
last_param_name = param_name
return param_descriptions

View File

@@ -102,6 +102,8 @@ class ActionHistoryComponent(
@staticmethod
def _make_result_message(episode: Episode, result: ActionResult) -> ChatMessage:
from forge.components.code_flow_executor import CodeFlowExecutionComponent
if result.status == "success":
return (
ToolResultMessage(
@@ -110,11 +112,18 @@ class ActionHistoryComponent(
)
if episode.action.raw_message.tool_calls
else ChatMessage.user(
f"{episode.action.use_tool.name} returned: "
(
"Execution result:"
if (
episode.action.use_tool.name
== CodeFlowExecutionComponent.execute_code_flow.name
)
else f"{episode.action.use_tool.name} returned:"
)
+ (
f"```\n{result.outputs}\n```"
f"\n```\n{result.outputs}\n```"
if "\n" in str(result.outputs)
else f"`{result.outputs}`"
else f" `{result.outputs}`"
)
)
)

View File

@@ -0,0 +1,3 @@
from .code_flow_executor import CodeFlowExecutionComponent
__all__ = ["CodeFlowExecutionComponent"]

View File

@@ -0,0 +1,69 @@
"""Commands to generate images based on text input"""
import inspect
import logging
from typing import Any, Callable, Iterable, Iterator
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.models.json_schema import JSONSchema
MAX_RESULT_LENGTH = 1000
logger = logging.getLogger(__name__)
class CodeFlowExecutionComponent(CommandProvider):
"""A component that provides commands to execute code flow."""
def __init__(self, get_available_commands: Callable[[], Iterable[Command]]):
self.get_available_commands = get_available_commands
def get_commands(self) -> Iterator[Command]:
yield self.execute_code_flow
@command(
parameters={
"python_code": JSONSchema(
type=JSONSchema.Type.STRING,
description="The Python code to execute",
required=True,
),
"plan_text": JSONSchema(
type=JSONSchema.Type.STRING,
description="The plan to written in a natural language",
required=False,
),
},
)
async def execute_code_flow(self, python_code: str, plan_text: str) -> str:
"""Execute the code flow.
Args:
python_code: The Python code to execute
plan_text: The plan implemented by the given Python code
Returns:
str: The result of the code execution
"""
locals: dict[str, Any] = {}
locals.update(self._get_available_functions())
code = f"{python_code}" "\n\n" "exec_output = main()"
logger.debug(f"Code-Flow Execution code:\n```py\n{code}\n```")
exec(code, locals)
result = await locals["exec_output"]
logger.debug(f"Code-Flow Execution result:\n{result}")
if inspect.isawaitable(result):
result = await result
# truncate the result to limit its length
if len(result) > MAX_RESULT_LENGTH:
result = result[:MAX_RESULT_LENGTH] + "...[Truncated, Content is too long]"
return result
def _get_available_functions(self) -> dict[str, Callable]:
return {
name: command
for command in self.get_available_commands()
for name in command.names
if name != self.execute_code_flow.name
}
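
A minimal standalone sketch of the execution pattern above: the generated `async def main()` is exec'd together with the available functions as globals, and the resulting coroutine is awaited. Names are illustrative.

```python
import asyncio


def test_func() -> str:  # stand-in for an agent command
    return "You've passed the test!"


python_code = "async def main() -> str:\n    return test_func()"


async def run() -> str:
    scope: dict = {"test_func": test_func}
    # Define main() and create the coroutine, as execute_code_flow does above.
    exec(python_code + "\n\nexec_output = main()", scope)
    return await scope["exec_output"]


print(asyncio.run(run()))  # You've passed the test!
```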

View File

@@ -169,7 +169,8 @@ class FileManagerComponent(
parameters={
"folder": JSONSchema(
type=JSONSchema.Type.STRING,
description="The folder to list files in",
description="The folder to list files in. "
"Pass an empty string to list files in the workspace.",
required=True,
)
},

View File

@@ -25,7 +25,7 @@ class UserInteractionComponent(CommandProvider):
},
)
def ask_user(self, question: str) -> str:
"""If you need more details or information regarding the given goals,
"""If you need more details or information regarding the given task,
you can ask the user for input."""
print(f"\nQ: {question}")
resp = click.prompt("A")

View File

@@ -1,3 +1,4 @@
import inspect
import logging
from typing import (
Any,
@@ -154,7 +155,10 @@ class BaseOpenAIChatProvider(
self,
model_prompt: list[ChatMessage],
model_name: _ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
completion_parser: (
Callable[[AssistantChatMessage], Awaitable[_T]]
| Callable[[AssistantChatMessage], _T]
) = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
@@ -208,7 +212,15 @@ class BaseOpenAIChatProvider(
parsed_result: _T = None # type: ignore
if not parse_errors:
try:
parsed_result = completion_parser(assistant_msg)
parsed_result = (
await _result
if inspect.isawaitable(
_result := completion_parser(assistant_msg)
)
# cast(..) needed because inspect.isawaitable(..) loses type:
# https://github.com/microsoft/pyright/issues/3690
else cast(_T, _result)
)
except Exception as e:
parse_errors.append(e)
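
The parser-dispatch pattern above (also applied to the Anthropic and multi-provider classes below) awaits the parser's result only when it is awaitable, so both sync and async completion parsers are accepted; a standalone sketch with made-up parsers:

```python
import asyncio
import inspect


def sync_parser(msg: str) -> str:
    return msg.upper()


async def async_parser(msg: str) -> str:
    return msg.upper()


async def parse(parser, msg: str) -> str:
    result = parser(msg)
    return await result if inspect.isawaitable(result) else result


print(asyncio.run(parse(sync_parser, "ok")))   # OK
print(asyncio.run(parse(async_parser, "ok")))  # OK
```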

View File

@@ -1,8 +1,19 @@
from __future__ import annotations
import enum
import inspect
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
ParamSpec,
Sequence,
TypeVar,
cast,
)
import sentry_sdk
import tenacity
@@ -171,7 +182,10 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
self,
model_prompt: list[ChatMessage],
model_name: AnthropicModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
completion_parser: (
Callable[[AssistantChatMessage], Awaitable[_T]]
| Callable[[AssistantChatMessage], _T]
) = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
@@ -237,7 +251,14 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
+ "\n".join(str(e) for e in tool_call_errors)
)
parsed_result = completion_parser(assistant_msg)
# cast(..) needed because inspect.isawaitable(..) loses type info:
# https://github.com/microsoft/pyright/issues/3690
parsed_result = cast(
_T,
await _result
if inspect.isawaitable(_result := completion_parser(assistant_msg))
else _result,
)
break
except Exception as e:
self._logger.debug(

View File

@@ -1,7 +1,16 @@
from __future__ import annotations
import logging
from typing import Any, AsyncIterator, Callable, Optional, Sequence, TypeVar, get_args
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Optional,
Sequence,
TypeVar,
get_args,
)
from pydantic import ValidationError
@@ -99,7 +108,10 @@ class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
self,
model_prompt: list[ChatMessage],
model_name: ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
completion_parser: (
Callable[[AssistantChatMessage], Awaitable[_T]]
| Callable[[AssistantChatMessage], _T]
) = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",

View File

@@ -6,6 +6,7 @@ from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
ClassVar,
Generic,
@@ -135,6 +136,8 @@ class CompletionModelFunction(BaseModel):
name: str
description: str
parameters: dict[str, "JSONSchema"]
return_type: str | None = None
is_async: bool = False
def fmt_line(self) -> str:
params = ", ".join(
@@ -143,6 +146,44 @@ class CompletionModelFunction(BaseModel):
)
return f"{self.name}: {self.description}. Params: ({params})"
def fmt_function_stub(self, impl: str = "pass") -> str:
"""
Formats and returns a function stub as a string with types and descriptions.
Returns:
str: The formatted function header.
"""
from forge.llm.prompting.utils import indent
params = ", ".join(
f"{name}: {p.python_type}"
+ (
f" = {str(p.default)}"
if p.default
else " = None"
if not p.required
else ""
)
for name, p in self.parameters.items()
)
_def = "async def" if self.is_async else "def"
_return = f" -> {self.return_type}" if self.return_type else ""
return f"{_def} {self.name}({params}){_return}:\n" + indent(
'"""\n'
f"{self.description}\n\n"
"Params:\n"
+ indent(
"\n".join(
f"{name}: {param.description}"
for name, param in self.parameters.items()
if param.description
)
)
+ "\n"
'"""\n'
f"{impl}"
)
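
Assuming the `forge` package from this branch is installed, `fmt_function_stub` renders a prompt-ready stub approximately as shown in the trailing comment; `web_search` is an invented function spec.

```python
from forge.llm.providers.schema import CompletionModelFunction
from forge.models.json_schema import JSONSchema

web_search = CompletionModelFunction(
    name="web_search",
    description="Search the web and return the results",
    parameters={
        "query": JSONSchema(
            type=JSONSchema.Type.STRING,
            description="Query to search",
            required=True,
        ),
    },
    return_type="str",
    is_async=True,
)
print(web_search.fmt_function_stub())
# async def web_search(query: str) -> str:
#     """
#     Search the web and return the results
#
#     Params:
#         query: Query to search
#     """
#     pass
```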
def validate_call(
self, function_call: AssistantFunctionCall
) -> tuple[bool, list["ValidationError"]]:
@@ -415,7 +456,10 @@ class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings
self,
model_prompt: list[ChatMessage],
model_name: _ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
completion_parser: (
Callable[[AssistantChatMessage], Awaitable[_T]]
| Callable[[AssistantChatMessage], _T]
) = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",

View File

@@ -80,7 +80,7 @@ def function_specs_from_commands(
"""Get LLM-consumable function specs for the agent's available commands."""
return [
CompletionModelFunction(
name=command.names[0],
name=command.name,
description=command.description,
parameters={param.name: param.spec for param in command.parameters},
)

View File

@@ -1,6 +1,9 @@
import ast
import enum
import typing
from textwrap import indent
from typing import Optional, overload
from types import NoneType
from typing import Any, Optional, is_typeddict, overload
from jsonschema import Draft7Validator, ValidationError
from pydantic import BaseModel
@@ -14,14 +17,17 @@ class JSONSchema(BaseModel):
NUMBER = "number"
INTEGER = "integer"
BOOLEAN = "boolean"
TYPE = "type"
# TODO: add docstrings
description: Optional[str] = None
type: Optional[Type] = None
enum: Optional[list] = None
required: bool = False
default: Any = None
items: Optional["JSONSchema"] = None
properties: Optional[dict[str, "JSONSchema"]] = None
additional_properties: Optional["JSONSchema"] = None
minimum: Optional[int | float] = None
maximum: Optional[int | float] = None
minItems: Optional[int] = None
@@ -31,6 +37,7 @@ class JSONSchema(BaseModel):
schema: dict = {
"type": self.type.value if self.type else None,
"description": self.description,
"default": repr(self.default),
}
if self.type == "array":
if self.items:
@@ -45,6 +52,8 @@ class JSONSchema(BaseModel):
schema["required"] = [
name for name, prop in self.properties.items() if prop.required
]
if self.additional_properties:
schema["additionalProperties"] = self.additional_properties.to_dict()
elif self.enum:
schema["enum"] = self.enum
else:
@@ -63,11 +72,15 @@ class JSONSchema(BaseModel):
return JSONSchema(
description=schema.get("description"),
type=schema["type"],
default=ast.literal_eval(d) if (d := schema.get("default")) else None,
enum=schema.get("enum"),
items=JSONSchema.from_dict(schema["items"]) if "items" in schema else None,
items=JSONSchema.from_dict(i) if (i := schema.get("items")) else None,
properties=JSONSchema.parse_properties(schema)
if schema["type"] == "object"
else None,
additional_properties=JSONSchema.from_dict(ap)
if schema["type"] == "object" and (ap := schema.get("additionalProperties"))
else None,
minimum=schema.get("minimum"),
maximum=schema.get("maximum"),
minItems=schema.get("minItems"),
@@ -123,6 +136,82 @@ class JSONSchema(BaseModel):
f"interface {interface_name} " if interface_name else ""
) + f"{{\n{indent(attributes_string, ' ')}\n}}"
_PYTHON_TO_JSON_TYPE: dict[typing.Type, Type] = {
int: Type.INTEGER,
str: Type.STRING,
bool: Type.BOOLEAN,
float: Type.NUMBER,
}
@classmethod
def from_python_type(cls, T: typing.Type) -> "JSONSchema":
if _t := cls._PYTHON_TO_JSON_TYPE.get(T):
partial_schema = cls(type=_t, required=True)
elif (
typing.get_origin(T) is typing.Union and typing.get_args(T)[-1] is NoneType
):
if len(typing.get_args(T)[:-1]) > 1:
raise NotImplementedError("Union types are currently not supported")
partial_schema = cls.from_python_type(typing.get_args(T)[0])
partial_schema.required = False
return partial_schema
elif issubclass(T, BaseModel):
partial_schema = JSONSchema.from_dict(T.schema())
elif T is list or typing.get_origin(T) is list:
partial_schema = JSONSchema(
type=JSONSchema.Type.ARRAY,
items=JSONSchema.from_python_type(T_v)
if (T_v := typing.get_args(T)[0])
else None,
)
elif T is dict or typing.get_origin(T) is dict:
partial_schema = JSONSchema(
type=JSONSchema.Type.OBJECT,
additional_properties=JSONSchema.from_python_type(T_v)
if (T_v := typing.get_args(T)[1])
else None,
)
elif is_typeddict(T):
partial_schema = JSONSchema(
type=JSONSchema.Type.OBJECT,
properties={
k: JSONSchema.from_python_type(v)
for k, v in T.__annotations__.items()
},
)
else:
raise TypeError(f"JSONSchema.from_python_type is not implemented for {T}")
partial_schema.required = True
return partial_schema
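
Assuming the `forge` package from this branch is installed, the new factory maps plain Python annotations to `JSONSchema` instances, and `python_type` (below) maps them back to annotation strings for the generated function headers:

```python
from typing import Optional

from forge.models.json_schema import JSONSchema

str_schema = JSONSchema.from_python_type(str)
print(str_schema.type is JSONSchema.Type.STRING)  # True
print(str_schema.required)                        # True

opt_schema = JSONSchema.from_python_type(Optional[int])
print(opt_schema.required)     # False  (Optional[...] marks the parameter as not required)
print(opt_schema.python_type)  # int
```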
_JSON_TO_PYTHON_TYPE: dict[Type, typing.Type] = {
j: p for p, j in _PYTHON_TO_JSON_TYPE.items()
}
@property
def python_type(self) -> str:
if self.type in self._JSON_TO_PYTHON_TYPE:
return self._JSON_TO_PYTHON_TYPE[self.type].__name__
elif self.type == JSONSchema.Type.ARRAY:
return f"list[{self.items.python_type}]" if self.items else "list"
elif self.type == JSONSchema.Type.OBJECT:
if not self.properties:
return "dict"
raise NotImplementedError(
"JSONSchema.python_type doesn't support TypedDicts yet"
)
elif self.enum:
return "Union[" + ", ".join(repr(v) for v in self.enum) + "]"
elif self.type == JSONSchema.Type.TYPE:
return "type"
elif self.type is None:
return "Any"
else:
raise NotImplementedError(
f"JSONSchema.python_type does not support Type.{self.type.name} yet"
)
@property
def typescript_type(self) -> str:
if not self.type:
@@ -141,6 +230,10 @@ class JSONSchema(BaseModel):
return self.to_typescript_object_interface()
if self.enum:
return " | ".join(repr(v) for v in self.enum)
elif self.type == JSONSchema.Type.TYPE:
return "type"
elif self.type is None:
return "any"
raise NotImplementedError(
f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"

View File

@@ -0,0 +1,515 @@
import ast
import collections
import datetime
import json
import logging
import pathlib
import re
import typing
import black
import isort
from forge.utils.function.model import FunctionDef, ObjectType, ValidationResponse
from forge.utils.function.visitor import FunctionVisitor
from forge.utils.function.util import (
genererate_line_error,
generate_object_code,
generate_compiled_code,
validate_matching_function,
)
from forge.utils.function.exec import (
exec_external_on_contents,
ExecError,
PROJECT_TEMP_DIR,
DEFAULT_DEPS,
execute_command,
setup_if_required,
)
logger = logging.getLogger(__name__)
class CodeValidator:
def __init__(
self,
function_name: str | None = None,
available_functions: dict[str, FunctionDef] | None = None,
available_objects: dict[str, ObjectType] | None = None,
):
self.func_name: str = function_name or ""
self.available_functions: dict[str, FunctionDef] = available_functions or {}
self.available_objects: dict[str, ObjectType] = available_objects or {}
async def reformat_code(
self,
code: str,
packages: list[str] = [],
) -> str:
"""
Reformat the code snippet
Args:
code (str): The code snippet to reformat
packages (list[str]): The list of packages to validate
Returns:
str: The reformatted code snippet
"""
try:
code = (
await self.validate_code(
raw_code=code,
packages=packages,
raise_validation_error=False,
add_code_stubs=False,
)
).get_compiled_code()
except Exception as e:
# Re-raise: reformatting cannot proceed if the code fails to parse or compile
logger.warning(f"Error formatting code for route #{self.func_name}: {e}")
raise e
for formatter in [
lambda code: isort.code(code),
lambda code: black.format_str(code, mode=black.FileMode()),
]:
try:
code = formatter(code)
except Exception as e:
# We move on with unformatted code if there's an error
logger.warning(
f"Error formatting code for route #{self.func_name}: {e}"
)
return code
async def validate_code(
self,
raw_code: str,
packages: list[str] = [],
raise_validation_error: bool = True,
add_code_stubs: bool = True,
call_cnt: int = 0,
) -> ValidationResponse:
"""
Validate the code snippet for any error
Args:
packages (list[str]): The list of packages to validate
raw_code (str): The code snippet to validate
Returns:
ValidationResponse: The response of the validation
Raises:
Exception: If parsing fails or the code has validation errors
"""
validation_errors: list[str] = []
try:
tree = ast.parse(raw_code)
visitor = FunctionVisitor()
visitor.visit(tree)
validation_errors.extend(visitor.errors)
except Exception as e:
# parse invalid code line and add it to the error message
error = f"Error parsing code: {e}"
if "async lambda" in raw_code:
error += "\nAsync lambda is not supported in Python. "
"Use async def instead!"
if line := re.search(r"line (\d+)", error):
raise Exception(
genererate_line_error(error, raw_code, int(line.group(1)))
)
else:
raise Exception(error)
# Eliminate duplicate visitor.functions and visitor.objects, prefer the last one
visitor.imports = list(set(visitor.imports))
visitor.functions = list({f.name: f for f in visitor.functions}.values())
visitor.objects = list(
{
o.name: o
for o in visitor.objects
if o.name not in self.available_objects
}.values()
)
# Add implemented functions into the main function, only link the stub functions
deps_funcs = [f for f in visitor.functions if f.is_implemented]
stub_funcs = [f for f in visitor.functions if not f.is_implemented]
objects_block = zip(
["\n\n" + generate_object_code(obj) for obj in visitor.objects],
visitor.objectsIdx,
)
functions_block = zip(
["\n\n" + fun.function_code for fun in deps_funcs], visitor.functionsIdx
)
globals_block = zip(
["\n\n" + glob for glob in visitor.globals], visitor.globalsIdx
)
function_code = "".join(
code
for code, _ in sorted(
list(objects_block) + list(functions_block) + list(globals_block),
key=lambda x: x[1],
)
).strip()
# No need to validate main function if it's not provided
if self.func_name:
main_func = self.__validate_main_function(
deps_funcs=deps_funcs,
function_code=function_code,
validation_errors=validation_errors,
)
function_template = main_func.function_template
else:
function_template = None
# Validate that code is not re-declaring any existing entities.
already_declared_entities = set(
[
obj.name
for obj in visitor.objects
if obj.name in self.available_objects.keys()
]
+ [
func.name
for func in visitor.functions
if func.name in self.available_functions.keys()
]
)
if already_declared_entities:
validation_errors.append(
"These class/function names have already been declared in the code, "
"no need to declare them again: " + ", ".join(already_declared_entities)
)
result = ValidationResponse(
function_name=self.func_name,
available_objects=self.available_objects,
available_functions=self.available_functions,
rawCode=function_code,
imports=visitor.imports.copy(),
objects=[], # Objects will be bundled in the function_code instead.
template=function_template or "",
functionCode=function_code,
functions=stub_funcs,
packages=packages,
)
# Execute static validators and fixers.
old_compiled_code = generate_compiled_code(result, add_code_stubs)
validation_errors.extend(await static_code_analysis(result))
new_compiled_code = result.get_compiled_code()
# Auto-fixer works, retry validation (limit to 5 times, to avoid infinite loop)
if old_compiled_code != new_compiled_code and call_cnt < 5:
return await self.validate_code(
packages=packages,
raw_code=new_compiled_code,
raise_validation_error=raise_validation_error,
add_code_stubs=add_code_stubs,
call_cnt=call_cnt + 1,
)
if validation_errors:
if raise_validation_error:
error_message = "".join("\n * " + e for e in validation_errors)
raise Exception("Error validating code: " + error_message)
else:
# This should happen only on `reformat_code` call
logger.warning("Error validating code: %s", validation_errors)
return result
def __validate_main_function(
self,
deps_funcs: list[FunctionDef],
function_code: str,
validation_errors: list[str],
) -> FunctionDef:
"""
Validate the main function body and signature
Returns:
FunctionDef: The validated main function definition
"""
# Validate that the main function is implemented.
func_obj = next((f for f in deps_funcs if f.name == self.func_name), None)
if not func_obj or not func_obj.is_implemented:
raise Exception(
f"Main Function body {self.func_name} is not implemented."
f" Please complete the implementation of this function!"
)
func_obj.function_code = function_code
# Validate that the main function is matching the expected signature.
func_req: FunctionDef | None = self.available_functions.get(self.func_name)
if not func_req:
raise AssertionError(f"Function {self.func_name} does not exist on DB")
try:
validate_matching_function(func_obj, func_req)
except Exception as e:
validation_errors.append(e.__str__())
return func_obj
# ======= Static Code Validation Helper Functions =======#
async def static_code_analysis(func: ValidationResponse) -> list[str]:
"""
Run static code analysis on the function code and mutate the function code to
fix any issues.
Args:
func (ValidationResponse):
The function to run static code analysis on. `func` will be mutated.
Returns:
list[str]: The list of validation errors
"""
validation_errors = []
validation_errors += await __execute_ruff(func)
validation_errors += await __execute_pyright(func)
return validation_errors
CODE_SEPARATOR = "#------Code-Start------#"
def __pack_import_and_function_code(func: ValidationResponse) -> str:
return "\n".join(func.imports + [CODE_SEPARATOR, func.rawCode])
def __unpack_import_and_function_code(code: str) -> tuple[list[str], str]:
split = code.split(CODE_SEPARATOR)
return split[0].splitlines(), split[1].strip()
async def __execute_ruff(func: ValidationResponse) -> list[str]:
code = __pack_import_and_function_code(func)
try:
# Currently Disabled Rule List
# E402 module level import not at top of file
# F841 local variable is assigned to but never used
code = await exec_external_on_contents(
command_arguments=[
"ruff",
"check",
"--fix",
"--ignore",
"F841",
"--ignore",
"E402",
"--ignore",
"F811", # Redefinition of unused '...' from line ...
],
file_contents=code,
suffix=".py",
raise_file_contents_on_error=True,
)
func.imports, func.rawCode = __unpack_import_and_function_code(code)
return []
except ExecError as e:
if e.content:
# Ruff failed, but the code is reformatted
code = e.content
e = str(e)
error_messages = [
v
for v in str(e).split("\n")
if v.strip()
if re.match(r"Found \d+ errors?\.*", v) is None
]
added_imports, error_messages = await __fix_missing_imports(
error_messages, func
)
# Append problematic line to the error message or add it as TODO line
validation_errors: list[str] = []
split_pattern = r"(.+):(\d+):(\d+): (.+)"
for error_message in error_messages:
error_split = re.match(split_pattern, error_message)
if not error_split:
error = error_message
else:
_, line, _, error = error_split.groups()
error = genererate_line_error(error, code, int(line))
validation_errors.append(error)
func.imports, func.rawCode = __unpack_import_and_function_code(code)
func.imports.extend(added_imports) # Avoid line-code change, do it at the end.
return validation_errors
async def __execute_pyright(func: ValidationResponse) -> list[str]:
code = __pack_import_and_function_code(func)
validation_errors: list[str] = []
# Create a working directory under PROJECT_TEMP_DIR, named after the function
temp_dir = PROJECT_TEMP_DIR / (func.function_name)
py_path = await setup_if_required(temp_dir)
async def __execute_pyright_commands(code: str) -> list[str]:
try:
await execute_command(
["pip", "install", "-r", "requirements.txt"], temp_dir, py_path
)
except Exception as e:
# Unknown deps should be reported as validation errors
validation_errors.append(e.__str__())
# execute pyright
result = await execute_command(
["pyright", "--outputjson"], temp_dir, py_path, raise_on_error=False
)
if not result:
return []
try:
json_response = json.loads(result)["generalDiagnostics"]
except Exception as e:
logger.error(f"Error parsing pyright output, error: {e} output: {result}")
raise e
for e in json_response:
rule: str = e.get("rule", "")
severity: str = e.get("severity", "")
excluded_rules = ["reportRedeclaration"]
if severity != "error" or any([rule.startswith(r) for r in excluded_rules]):
continue
e = genererate_line_error(
error=f"{e['message']}. {e.get('rule', '')}",
code=code,
line_number=e["range"]["start"]["line"] + 1,
)
validation_errors.append(e)
# Read the code back from code.py and split it into imports and raw code
code = (temp_dir / "code.py").read_text()
code, error_messages = await __fix_async_calls(code, validation_errors)
func.imports, func.rawCode = __unpack_import_and_function_code(code)
return validation_errors
packages = "\n".join([str(p) for p in func.packages if p not in DEFAULT_DEPS])
(temp_dir / "requirements.txt").write_text(packages)
(temp_dir / "code.py").write_text(code)
return await __execute_pyright_commands(code)
async def find_module_dist_and_source(
module: str, py_path: pathlib.Path | str
) -> typing.Tuple[pathlib.Path | None, pathlib.Path | None]:
# Find the module in the env
modules_path = pathlib.Path(py_path).parent / "lib" / "python3.11" / "site-packages"
matches = modules_path.glob(f"{module}*")
# resolve the generator to an array
matches = list(matches)
if not matches:
return None, None
# find the dist info path and the module path
dist_info_path: typing.Optional[pathlib.Path] = None
module_path: typing.Optional[pathlib.Path] = None
# find the dist info path
for match in matches:
if re.match(f"{module}-[0-9]+.[0-9]+.[0-9]+.dist-info", match.name):
dist_info_path = match
break
# Get the module path
for match in matches:
if module == match.name:
module_path = match
break
return dist_info_path, module_path
AUTO_IMPORT_TYPES: dict[str, str] = {
"Enum": "from enum import Enum",
"array": "from array import array",
}
for t in typing.__all__:
AUTO_IMPORT_TYPES[t] = f"from typing import {t}"
for t in datetime.__all__:
AUTO_IMPORT_TYPES[t] = f"from datetime import {t}"
for t in collections.__all__:
AUTO_IMPORT_TYPES[t] = f"from collections import {t}"
async def __fix_async_calls(code: str, errors: list[str]) -> tuple[str, list[str]]:
"""
Fix the async calls in the code
Args:
code (str): The code snippet
errors (list[str]): The list of errors
Returns:
tuple[str, list[str]]: The fixed code snippet and the list of errors
"""
async_calls = set()
new_errors = []
for error in errors:
pattern = '"__await__" is not present. reportGeneralTypeIssues -> (.+)'
match = re.search(pattern, error)
if match:
async_calls.add(match.group(1))
else:
new_errors.append(error)
for async_call in async_calls:
func_call = re.search(r"await ([a-zA-Z0-9_]+)", async_call)
if func_call:
func_name = func_call.group(1)
code = code.replace(f"await {func_name}", f"{func_name}")
return code, new_errors
async def __fix_missing_imports(
errors: list[str], func: ValidationResponse
) -> tuple[set[str], list[str]]:
"""
Generate missing imports based on the errors
Args:
errors (list[str]): The list of errors
func (ValidationResponse): The function to fix the imports
Returns:
tuple[set[str], list[str]]: The set of missing imports and the list
of non-missing import errors
"""
missing_imports = []
filtered_errors = []
for error in errors:
pattern = r"Undefined name `(.+?)`"
match = re.search(pattern, error)
if not match:
filtered_errors.append(error)
continue
missing = match.group(1)
if missing in AUTO_IMPORT_TYPES:
missing_imports.append(AUTO_IMPORT_TYPES[missing])
elif missing in func.available_functions:
# TODO: FIX THIS!! Import the correct AutoGPT service module.
missing_imports.append(f"from project.{missing}_service import {missing}")
elif missing in func.available_objects:
# TODO: FIX THIS!! Import the correct AutoGPT object module.
missing_imports.append(f"from project.{missing}_object import {missing}")
else:
filtered_errors.append(error)
return set(missing_imports), filtered_errors
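
A rough usage sketch of `CodeValidator`. The module path `forge.utils.function.code_validation` is an assumption (use wherever this file actually lives), and `ruff`/`pyright` must be installed because `validate_code` shells out to them:

import asyncio

from forge.utils.function.code_validation import CodeValidator  # assumed module path


async def main() -> None:
    validator = CodeValidator()  # no main-function signature to enforce
    # Parses, lints (ruff), type-checks (pyright) and reformats the snippet;
    # validation errors are only logged since reformat_code doesn't raise them.
    cleaned = await validator.reformat_code(
        "import os , sys\ndef f(x :int)->int:\n    return x\n"
    )
    print(cleaned)


asyncio.run(main())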

View File

@@ -0,0 +1,195 @@
import asyncio
import enum
import logging
import os
import subprocess
import tempfile
from asyncio.subprocess import Process
from pathlib import Path
logger = logging.getLogger(__name__)
class OutputType(enum.Enum):
STD_OUT = "stdout"
STD_ERR = "stderr"
BOTH = "both"
class ExecError(Exception):
content: str | None
def __init__(self, error: str, content: str | None = None):
super().__init__(error)
self.content = content
async def exec_external_on_contents(
command_arguments: list[str],
file_contents: str,
suffix: str = ".py",
output_type: OutputType = OutputType.BOTH,
raise_file_contents_on_error: bool = False,
) -> str:
"""
Execute an external tool with the provided command arguments and file contents
:param command_arguments: The command arguments to execute
:param file_contents: The file contents to execute the command on
:param suffix: The suffix of the temporary file. Default is ".py"
:return: The file contents after the command has been executed
Note: The file contents are written to a temporary file and the command is executed
on that file. The command arguments should be a list of strings, where the first
element is the command to execute and the rest of the elements are the arguments to
the command. There is no need to provide the file path as an argument, as it will
be appended to the command arguments.
Example:
exec_external(["ruff", "check"], "print('Hello World')")
will run the command "ruff check <temp_file_path>" with the file contents
"print('Hello World')" and return the file contents after the command
has been executed.
"""
errors = ""
if len(command_arguments) == 0:
raise ExecError("No command arguments provided")
# Write the contents to a temporary file for the external tool to process
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
temp_file_path = temp_file.name
temp_file.write(file_contents.encode("utf-8"))
temp_file.flush()
command_arguments.append(str(temp_file_path))
# Run the external command on the temporary file
try:
r: Process = await asyncio.create_subprocess_exec(
*command_arguments,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
result = await r.communicate()
stdout, stderr = result[0].decode("utf-8"), result[1].decode("utf-8")
logger.debug(f"Output: {stdout}")
if temp_file_path in stdout:
stdout = stdout # .replace(temp_file.name, "/generated_file")
logger.debug(f"Errors: {stderr}")
if output_type == OutputType.STD_OUT:
errors = stdout
elif output_type == OutputType.STD_ERR:
errors = stderr
else:
errors = stdout + "\n" + stderr
with open(temp_file_path, "r") as f:
file_contents = f.read()
finally:
# Ensure the temporary file is deleted
os.remove(temp_file_path)
if not errors:
return file_contents
if raise_file_contents_on_error:
raise ExecError(errors, file_contents)
raise ExecError(errors)
FOLDER_NAME = "agpt-static-code-analysis"
PROJECT_PARENT_DIR = Path(__file__).resolve().parent.parent.parent / f".{FOLDER_NAME}"
PROJECT_TEMP_DIR = Path(tempfile.gettempdir()) / FOLDER_NAME
DEFAULT_DEPS = ["pyright", "pydantic", "virtualenv-clone"]
def is_env_exists(path: Path):
return (
(path / "venv/bin/python").exists()
and (path / "venv/bin/pip").exists()
and (path / "venv/bin/virtualenv-clone").exists()
and (path / "venv/bin/pyright").exists()
)
async def setup_if_required(
cwd: Path = PROJECT_PARENT_DIR, copy_from_parent: bool = True
) -> Path:
"""
Set up the virtual environment if it does not exist.
This setup is expected to run only once per application run.
Args:
cwd (Path): The current working directory
copy_from_parent (bool):
Whether to copy the virtual environment from PROJECT_PARENT_DIR
Returns:
Path: The path to the virtual environment
"""
if not cwd.exists():
cwd.mkdir(parents=True, exist_ok=True)
path = cwd / "venv/bin"
if is_env_exists(cwd):
return path
if copy_from_parent and cwd != PROJECT_PARENT_DIR:
if (cwd / "venv").exists():
await execute_command(["rm", "-rf", str(cwd / "venv")], cwd, None)
await execute_command(
["virtualenv-clone", str(PROJECT_PARENT_DIR / "venv"), str(cwd / "venv")],
cwd,
await setup_if_required(PROJECT_PARENT_DIR),
)
return path
# Create a virtual environment
output = await execute_command(["python", "-m", "venv", "venv"], cwd, None)
logger.info(f"[Setup] Created virtual environment: {output}")
# Install dependencies
output = await execute_command(["pip", "install", "-I"] + DEFAULT_DEPS, cwd, path)
logger.info(f"[Setup] Installed {DEFAULT_DEPS}: {output}")
output = await execute_command(["pyright"], cwd, path, raise_on_error=False)
logger.info(f"[Setup] Set up pyright: {output}")
return path
async def execute_command(
command: list[str],
cwd: str | Path | None,
python_path: str | Path | None = None,
raise_on_error: bool = True,
) -> str:
"""
Execute a command in the shell
Args:
command (list[str]): The command to execute
cwd (str | Path): The current working directory
python_path (str | Path): The python executable path
raise_on_error (bool): Whether to raise an error if the command fails
Returns:
str: The output of the command
"""
# Make the provided python path take precedence by prepending it to PATH
venv = os.environ.copy()
if python_path:
# The first matching entry in PATH wins, so python_path must be prepended.
venv["PATH"] = f"{python_path}:{venv['PATH']}"
r = await asyncio.create_subprocess_exec(
*command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=str(cwd),
env=venv,
)
stdout, stderr = await r.communicate()
if r.returncode == 0:
return (stdout or stderr).decode("utf-8")
if raise_on_error:
raise Exception((stderr or stdout).decode("utf-8"))
else:
return (stderr or stdout).decode("utf-8")
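
A small sketch of `exec_external_on_contents`, assuming `ruff` is available on PATH. Any tool output is surfaced as an `ExecError`, and with `raise_file_contents_on_error=True` the (possibly auto-fixed) file contents are preserved on `e.content`:

import asyncio

from forge.utils.function.exec import ExecError, exec_external_on_contents


async def main() -> None:
    source = "import os\nprint('hello')\n"
    try:
        cleaned = await exec_external_on_contents(
            command_arguments=["ruff", "check", "--fix"],
            file_contents=source,
            raise_file_contents_on_error=True,
        )
    except ExecError as e:
        # ruff wrote diagnostics to stdout/stderr; the auto-fixed contents
        # are still available because raise_file_contents_on_error=True
        cleaned = e.content or source
    print(cleaned)


asyncio.run(main())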

View File

@@ -0,0 +1,110 @@
from typing import List, Optional
from pydantic import BaseModel, Field
class ObjectType(BaseModel):
name: str = Field(description="The name of the object")
code: Optional[str] = Field(description="The code of the object", default=None)
description: Optional[str] = Field(
description="The description of the object", default=None
)
Fields: List["ObjectField"] = Field(description="The fields of the object")
is_pydantic: bool = Field(
description="Whether the object is a pydantic model", default=True
)
is_implemented: bool = Field(
description="Whether the object is implemented", default=True
)
is_enum: bool = Field(description="Whether the object is an enum", default=False)
class ObjectField(BaseModel):
name: str = Field(description="The name of the field")
description: Optional[str] = Field(
description="The description of the field", default=None
)
type: str = Field(
description="The type of the field. Can be a string like List[str] or an use "
"any of they related types like list[User]",
)
value: Optional[str] = Field(description="The value of the field", default=None)
related_types: Optional[List[ObjectType]] = Field(
description="The related types of the field", default=[]
)
class FunctionDef(BaseModel):
name: str
arg_types: list[tuple[str, str]]
arg_defaults: dict[str, str] = {}
arg_descs: dict[str, str]
return_type: str | None = None
return_desc: str
function_desc: str
is_implemented: bool = False
function_code: str = ""
function_template: str | None = None
is_async: bool = False
def __generate_function_template(self) -> str:
args_str = ", ".join(
[
f"{name}: {type}"
+ (
f" = {self.arg_defaults.get(name, '')}"
if name in self.arg_defaults
else ""
)
for name, type in self.arg_types
]
)
arg_desc = f"\n{' '*4}".join(
[
f'{name} ({type}): {self.arg_descs.get(name, "-")}'
for name, type in self.arg_types
]
)
_def = "async def" if "await " in self.function_code or self.is_async else "def"
_return_type = f" -> {self.return_type}" if self.return_type else ""
func_desc = self.function_desc.replace("\n", "\n ")
template = f"""
{_def} {self.name}({args_str}){_return_type}:
\"\"\"
{func_desc}
Args:
{arg_desc}
Returns:
{self.return_type}{': ' + self.return_desc if self.return_desc else ''}
\"\"\"
pass
"""
return "\n".join([line for line in template.split("\n")]).strip()
def __init__(self, function_template: Optional[str] = None, **data):
super().__init__(**data)
self.function_template = (
function_template or self.__generate_function_template()
)
class ValidationResponse(BaseModel):
function_name: str
available_objects: dict[str, ObjectType]
available_functions: dict[str, FunctionDef]
template: str
rawCode: str
packages: List[str]
imports: List[str]
functionCode: str
functions: List[FunctionDef]
objects: List[ObjectType]
def get_compiled_code(self) -> str:
return "\n".join(self.imports) + "\n\n" + self.rawCode.strip()

View File

@@ -0,0 +1,292 @@
from typing import List, Tuple, __all__ as all_types
from forge.utils.function.model import FunctionDef, ObjectType, ValidationResponse
OPEN_BRACES = {"{": "Dict", "[": "List", "(": "Tuple"}
CLOSE_BRACES = {"}": "Dict", "]": "List", ")": "Tuple"}
RENAMED_TYPES = {
"dict": "Dict",
"list": "List",
"tuple": "Tuple",
"set": "Set",
"frozenset": "FrozenSet",
"type": "Type",
}
PYTHON_TYPES = set(all_types)
def unwrap_object_type(type: str) -> Tuple[str, List[str]]:
"""
Get the type and children of a composite type.
Args:
type (str): The type to parse.
Returns:
str: The type.
[str]: The children types.
"""
type = type.replace(" ", "")
if not type:
return "", []
def split_outer_level(type: str, separator: str) -> List[str]:
brace_count = 0
last_index = 0
splits = []
for i, c in enumerate(type):
if c in OPEN_BRACES:
brace_count += 1
elif c in CLOSE_BRACES:
brace_count -= 1
elif c == separator and brace_count == 0:
splits.append(type[last_index:i])
last_index = i + 1
splits.append(type[last_index:])
return splits
# Unwrap primitive union types
union_split = split_outer_level(type, "|")
if len(union_split) > 1:
if len(union_split) == 2 and "None" in union_split:
return "Optional", [v for v in union_split if v != "None"]
return "Union", union_split
# Unwrap primitive dict/list/tuple types
if type[0] in OPEN_BRACES and type[-1] in CLOSE_BRACES:
type_name = OPEN_BRACES[type[0]]
type_children = split_outer_level(type[1:-1], ",")
return type_name, type_children
brace_pos = type.find("[")
if brace_pos != -1 and type[-1] == "]":
# Unwrap normal composite types
type_name = type[:brace_pos]
type_children = split_outer_level(type[brace_pos + 1 : -1], ",")
else:
# Non-composite types, no need to unwrap
type_name = type
type_children = []
return RENAMED_TYPES.get(type_name, type_name), type_children
def is_type_equal(type1: str | None, type2: str | None) -> bool:
"""
Check if two types are equal.
This function handles composite types like list, dict, and tuple, and
treats similar types like list[str], List[str], and [str] as equal.
"""
if type1 is None and type2 is None:
return True
if type1 is None or type2 is None:
return False
evaluated_type1, children1 = unwrap_object_type(type1)
evaluated_type2, children2 = unwrap_object_type(type2)
# Compare the class name of the types (ignoring the module)
# TODO(majdyz): compare the module name as well.
parts1, parts2 = evaluated_type1.split("."), evaluated_type2.split(".")
t_len = min(len(parts1), len(parts2))
if parts1[-t_len:] != parts2[-t_len:]:
return False
if len(children1) != len(children2):
return False
if len(children1) == len(children2) == 0:
return True
for c1, c2 in zip(children1, children2):
if not is_type_equal(c1, c2):
return False
return True
def validate_matching_function(this: FunctionDef, that: FunctionDef):
expected_args = that.arg_types
expected_rets = that.return_type
func_name = that.name
errors = []
# Fix the async flag based on the expectation.
if this.is_async != that.is_async:
this.is_async = that.is_async
if this.is_async and f"async def {this.name}" not in this.function_code:
this.function_code = this.function_code.replace(
f"def {this.name}", f"async def {this.name}"
)
if not this.is_async and f"async def {this.name}" in this.function_code:
this.function_code = this.function_code.replace(
f"async def {this.name}", f"def {this.name}"
)
if any(
[
x[0] != y[0] or (not is_type_equal(x[1], y[1]) and x[1] != "object")
# TODO: remove sorted and provide a stable order for one-to-many arg-types.
for x, y in zip(sorted(expected_args), sorted(this.arg_types))
]
):
errors.append(
f"Function {func_name} has different arguments than expected, "
f"expected {expected_args} but got {this.arg_types}"
)
if not is_type_equal(expected_rets, this.return_type) and expected_rets != "object":
errors.append(
f"Function {func_name} has different return type than expected, expected "
f"{expected_rets} but got {this.return_type}"
)
if errors:
raise Exception("Signature validation errors:\n " + "\n ".join(errors))
def normalize_type(type: str, renamed_types: dict[str, str] = {}) -> str:
"""
Normalize the type to a standard format.
e.g. list[str] -> List[str], dict[str, int | float] -> Dict[str, Union[int, float]]
Args:
type (str): The type to normalize.
Returns:
str: The normalized type.
"""
parent_type, children = unwrap_object_type(type)
if parent_type in renamed_types:
parent_type = renamed_types[parent_type]
if len(children) == 0:
return parent_type
content_type = ", ".join([normalize_type(c, renamed_types) for c in children])
return f"{parent_type}[{content_type}]"
def generate_object_code(obj: ObjectType) -> str:
if not obj.name:
return "" # Avoid generating an empty object
# Auto-generate a template for the object; this will not capture any class methods
fields = f"\n{' ' * 4}".join(
[
f"{field.name}: {field.type} "
f"{('= '+field.value) if field.value else ''} "
f"{('# '+field.description) if field.description else ''}"
for field in obj.Fields or []
]
)
parent_class = ""
if obj.is_enum:
parent_class = "Enum"
elif obj.is_pydantic:
parent_class = "BaseModel"
doc_string = (
f"""\"\"\"
{obj.description}
\"\"\""""
if obj.description
else ""
)
method_body = ("\n" + " " * 4).join(obj.code.split("\n")) + "\n" if obj.code else ""
template = f"""
class {obj.name}({parent_class}):
{doc_string if doc_string else ""}
{fields if fields else ""}
{method_body if method_body else ""}
{"pass" if not fields and not method_body else ""}
"""
return "\n".join(line for line in template.split("\n")).strip()
def genererate_line_error(error: str, code: str, line_number: int) -> str:
lines = code.split("\n")
if line_number > len(lines):
return error
code_line = lines[line_number - 1]
return f"{error} -> '{code_line.strip()}'"
def generate_compiled_code(
resp: ValidationResponse, add_code_stubs: bool = True
) -> str:
"""
Regenerate imports & raw code using the available objects and functions.
"""
resp.imports = sorted(set(resp.imports))
def __append_comment(code_block: str, comment: str) -> str:
"""
Append `# <comment>` to the first line of the code block,
e.g. to suppress linter or type-checker warnings for redefined names.
"""
lines = code_block.split("\n")
lines[0] = lines[0] + " # " + comment
return "\n".join(lines)
def __generate_stub(name, is_enum):
if not name:
return ""
elif is_enum:
return f"class {name}(Enum):\n pass"
else:
return f"class {name}(BaseModel):\n pass"
stub_objects = resp.available_objects if add_code_stubs else {}
stub_functions = resp.available_functions if add_code_stubs else {}
object_stubs_code = "\n\n".join(
[
__append_comment(__generate_stub(obj.name, obj.is_enum), "type: ignore")
for obj in stub_objects.values()
]
+ [
__append_comment(__generate_stub(obj.name, obj.is_enum), "type: ignore")
for obj in resp.objects
if obj.name not in stub_objects
]
)
objects_code = "\n\n".join(
[
__append_comment(generate_object_code(obj), "noqa")
for obj in stub_objects.values()
]
+ [
__append_comment(generate_object_code(obj), "noqa")
for obj in resp.objects
if obj.name not in stub_objects
]
)
functions_code = "\n\n".join(
[
__append_comment(f.function_template.strip(), "type: ignore")
for f in stub_functions.values()
if f.name != resp.function_name and f.function_template
]
+ [
__append_comment(f.function_template.strip(), "type: ignore")
for f in resp.functions
if f.name not in stub_functions and f.function_template
]
)
resp.rawCode = (
object_stubs_code.strip()
+ "\n\n"
+ objects_code.strip()
+ "\n\n"
+ functions_code.strip()
+ "\n\n"
+ resp.functionCode.strip()
)
return resp.get_compiled_code()
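
A quick sketch of the type-normalization helpers defined above (module path as imported elsewhere in this changeset):

from forge.utils.function.util import is_type_equal, normalize_type, unwrap_object_type

# Composite types are unwrapped into a parent type and its children
assert unwrap_object_type("dict[str, int | float]") == ("Dict", ["str", "int|float"])

# normalize_type rewrites lowercase / bar-union syntax into typing-style names
assert normalize_type("dict[str, int | float]") == "Dict[str, Union[int, float]]"

# is_type_equal ignores list/List-style spelling differences
assert is_type_equal("list[str]", "List[str]")
assert not is_type_equal("tuple[int, str]", "tuple[str, int]")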

View File

@@ -0,0 +1,222 @@
import ast
import re
from forge.utils.function.model import FunctionDef, ObjectType, ObjectField
from forge.utils.function.util import normalize_type, PYTHON_TYPES
class FunctionVisitor(ast.NodeVisitor):
"""
Visits a Python AST and extracts function definitions and Pydantic class definitions
To use this class, create an instance and call the visit method with the AST
as the argument. The extracted function definitions and Pydantic class
definitions can be accessed from the functions and objects attributes,
respectively.
Example:
```
visitor = FunctionVisitor()
visitor.visit(ast.parse("def foo(x: int) -> int: return x"))
print(visitor.functions)
```
"""
def __init__(self):
self.functions: list[FunctionDef] = []
self.functionsIdx: list[int] = []
self.objects: list[ObjectType] = []
self.objectsIdx: list[int] = []
self.globals: list[str] = []
self.globalsIdx: list[int] = []
self.imports: list[str] = []
self.errors: list[str] = []
def visit_Import(self, node):
for alias in node.names:
import_line = f"import {alias.name}"
if alias.asname:
import_line += f" as {alias.asname}"
self.imports.append(import_line)
self.generic_visit(node)
def visit_ImportFrom(self, node):
for alias in node.names:
import_line = f"from {node.module} import {alias.name}"
if alias.asname:
import_line += f" as {alias.asname}"
self.imports.append(import_line)
self.generic_visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
# treat async functions as normal functions
self.visit_FunctionDef(node) # type: ignore
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
args = []
for arg in node.args.args:
arg_type = ast.unparse(arg.annotation) if arg.annotation else "object"
args.append((arg.arg, normalize_type(arg_type)))
return_type = (
normalize_type(ast.unparse(node.returns)) if node.returns else None
)
# Extract doc_string & function body
if (
node.body
and isinstance(node.body[0], ast.Expr)
and isinstance(node.body[0].value, ast.Constant)
):
doc_string = node.body[0].value.s.strip()
template_body = [node.body[0], ast.Pass()]
is_implemented = len(node.body) > 1 and not isinstance(node.body[1], ast.Pass)
else:
doc_string = ""
template_body = [ast.Pass()]
is_implemented = not isinstance(node.body[0], ast.Pass)
# Construct function template
original_body = node.body.copy()
node.body = template_body # type: ignore
function_template = ast.unparse(node)
node.body = original_body
function_code = ast.unparse(node)
if "await" in function_code and "async def" not in function_code:
function_code = function_code.replace("def ", "async def ")
function_template = function_template.replace("def ", "async def ")
def split_doc(keywords: list[str], doc: str) -> tuple[str, str]:
for keyword in keywords:
if match := re.search(f"{keyword}\\s?:", doc):
return doc[: match.start()], doc[match.end() :]
return doc, ""
# Decompose doc_string into func_doc, args_doc, rets_doc, errs_doc, usage_doc
# by splitting in reverse order
func_doc = doc_string
func_doc, usage_doc = split_doc(
["Ex", "Usage", "Usages", "Example", "Examples"], func_doc
)
func_doc, errs_doc = split_doc(["Error", "Errors", "Raise", "Raises"], func_doc)
func_doc, rets_doc = split_doc(["Return", "Returns"], func_doc)
func_doc, args_doc = split_doc(
["Arg", "Args", "Argument", "Arguments"], func_doc
)
# Extract Func
function_desc = func_doc.strip()
# Extract Args
args_descs = {}
split_pattern = r"\n(\s+.+):"
for match in reversed(list(re.finditer(split_pattern, string=args_doc))):
arg = match.group(1).strip().split(" ")[0]
desc = args_doc.rsplit(match.group(1), 1)[1].strip(": ")
args_descs[arg] = desc.strip()
args_doc = args_doc[: match.start()]
# Extract Returns
return_desc = ""
if match := re.match(split_pattern, string=rets_doc):
return_desc = rets_doc[match.end() :].strip()
self.functions.append(
FunctionDef(
name=node.name,
arg_types=args,
arg_descs=args_descs,
return_type=return_type,
return_desc=return_desc,
is_implemented=is_implemented,
function_desc=function_desc,
function_template=function_template,
function_code=function_code,
)
)
self.functionsIdx.append(node.lineno)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
"""
Visits a ClassDef node in the AST, determines whether it is a Pydantic model
or an Enum, and records it (with its fields and methods) as an ObjectType.
"""
is_pydantic = any(
[
(isinstance(base, ast.Name) and base.id == "BaseModel")
or (isinstance(base, ast.Attribute) and base.attr == "BaseModel")
for base in node.bases
]
)
is_enum = any(
[
(isinstance(base, ast.Name) and base.id.endswith("Enum"))
or (isinstance(base, ast.Attribute) and base.attr.endswith("Enum"))
for base in node.bases
]
)
is_implemented = not any(isinstance(v, ast.Pass) for v in node.body)
doc_string = ""
if (
node.body
and isinstance(node.body[0], ast.Expr)
and isinstance(node.body[0].value, ast.Constant)
):
doc_string = node.body[0].value.s.strip()
if node.name in PYTHON_TYPES:
self.errors.append(
f"Can't declare class with a Python built-in name "
f"`{node.name}`. Please use a different name."
)
fields = []
methods = []
for v in node.body:
if isinstance(v, ast.AnnAssign):
field = ObjectField(
name=ast.unparse(v.target),
type=normalize_type(ast.unparse(v.annotation)),
value=ast.unparse(v.value) if v.value else None,
)
if field.value is None and field.type.startswith("Optional"):
field.value = "None"
elif isinstance(v, ast.Assign):
if len(v.targets) > 1:
self.errors.append(
f"Class {node.name} has multiple assignments in a single line."
)
field = ObjectField(
name=ast.unparse(v.targets[0]),
type=type(ast.unparse(v.value)).__name__,
value=ast.unparse(v.value) if v.value else None,
)
elif isinstance(v, ast.Expr) and isinstance(v.value, ast.Constant):
# skip comments and docstrings
continue
else:
methods.append(ast.unparse(v))
continue
fields.append(field)
self.objects.append(
ObjectType(
name=node.name,
code="\n".join(methods),
description=doc_string,
Fields=fields,
is_pydantic=is_pydantic,
is_enum=is_enum,
is_implemented=is_implemented,
)
)
self.objectsIdx.append(node.lineno)
def visit(self, node):
if (
isinstance(node, ast.Assign)
or isinstance(node, ast.AnnAssign)
or isinstance(node, ast.AugAssign)
) and node.col_offset == 0:
self.globals.append(ast.unparse(node))
self.globalsIdx.append(node.lineno)
super().visit(node)
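
Finally, a sketch of `FunctionVisitor` on a snippet that mixes a Pydantic model with a documented function, showing what ends up in `functions`, `objects`, and `imports`:

import ast

from forge.utils.function.visitor import FunctionVisitor

source = '''
from pydantic import BaseModel

class User(BaseModel):
    name: str

def greet(user: User) -> str:
    """Greet a user.

    Args:
        user (User): The user to greet.

    Returns:
        str: The greeting.
    """
    return f"Hello, {user.name}!"
'''

visitor = FunctionVisitor()
visitor.visit(ast.parse(source))
print([f.name for f in visitor.functions])  # ['greet']
print([o.name for o in visitor.objects])    # ['User']
print(visitor.imports)                      # ['from pydantic import BaseModel']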