Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-22 13:38:10 -05:00)

Compare commits: branch testing-cl, 58 commits by zamilmajdy.

Commit SHA1s:
7e0b1156cc, da9360fdeb, 9dea6a273e, e19636ac3e, f03c6546b8, 2c4afd4458, 8b1d416de3,
7f6b7d6d7e, 736ac778cc, 38eafdbb66, 6d9f564dc5, 3e675123d7, 37cc047656, 9f804080ed,
680fbf49aa, 901dadefc3, e204491c6c, 3597f801a7, b59862c402, 81bac301e8, a9eb49d54e,
2c6e1eb4c8, 3e8849b08e, 111e8585b5, 8144d26cef, e264bf7764, 6dd0975236, c3acb99314,
0578fb0246, 731d0345f0, b4cd735f26, 6e715b6c71, fcca4cc893, 5c7c276c10, ae63aa8ebb,
fdd9f9b5ec, a825aa8515, ae43136c2c, c8e16f3fe1, 3a60504138, dfa77739c3, 9f6e25664c,
3c4ff60e11, 47eeaf0325, 81ad3cb69a, 834eb6c6e0, fb802400ba, 922e643737, 7b5272f1f2,
ea134c7dbd, f7634524fa, 0eccbe1483, 0916df4df7, 22e2373a0b, 40426e4646, ef1fe7c4e8,
ca7ca226ff, ed5f12c02b
@@ -23,6 +23,7 @@ from forge.components.code_executor.code_executor import (
     CodeExecutorComponent,
     CodeExecutorConfiguration,
 )
+from forge.components.code_flow_executor import CodeFlowExecutionComponent
 from forge.components.context.context import AgentContext, ContextComponent
 from forge.components.file_manager import FileManagerComponent
 from forge.components.git_operations import GitOperationsComponent
@@ -40,7 +41,6 @@ from forge.llm.providers import (
     ChatModelResponse,
     MultiProvider,
 )
-from forge.llm.providers.utils import function_specs_from_commands
 from forge.models.action import (
     ActionErrorResult,
     ActionInterruptedByHuman,
@@ -56,6 +56,7 @@ from forge.utils.exceptions import (
 )
 from pydantic import Field

+from .prompt_strategies.code_flow import CodeFlowAgentPromptStrategy
 from .prompt_strategies.one_shot import (
     OneShotAgentActionProposal,
     OneShotAgentPromptStrategy,
@@ -96,11 +97,14 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
         llm_provider: MultiProvider,
         file_storage: FileStorage,
         app_config: AppConfig,
+        prompt_strategy_class: type[
+            OneShotAgentPromptStrategy | CodeFlowAgentPromptStrategy
+        ] = CodeFlowAgentPromptStrategy,
     ):
         super().__init__(settings)

         self.llm_provider = llm_provider
-        prompt_config = OneShotAgentPromptStrategy.default_configuration.model_copy(
+        prompt_config = prompt_strategy_class.default_configuration.model_copy(
             deep=True
         )
         prompt_config.use_functions_api = (
@@ -108,7 +112,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
             # Anthropic currently doesn't support tools + prefilling :(
             and self.llm.provider_name != "anthropic"
         )
-        self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
+        self.prompt_strategy = prompt_strategy_class(prompt_config, logger)
         self.commands: list[Command] = []

         # Components
@@ -145,6 +149,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
         self.watchdog = WatchdogComponent(settings.config, settings.history).run_after(
             ContextComponent
         )
+        self.code_flow_executor = CodeFlowExecutionComponent(lambda: self.commands)

         self.event_history = settings.history
         self.app_config = app_config
@@ -185,7 +190,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
             task=self.state.task,
             ai_profile=self.state.ai_profile,
             ai_directives=directives,
-            commands=function_specs_from_commands(self.commands),
+            commands=self.commands,
             include_os_info=include_os_info,
         )

@@ -201,9 +206,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
         if exception:
             prompt.messages.append(ChatMessage.system(f"Error: {exception}"))

-        response: ChatModelResponse[
-            OneShotAgentActionProposal
-        ] = await self.llm_provider.create_chat_completion(
+        response: ChatModelResponse = await self.llm_provider.create_chat_completion(
             prompt.messages,
             model_name=self.llm.name,
             completion_parser=self.prompt_strategy.parse_response_content,
@@ -281,7 +284,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
         except AgentException:
             raise
         except Exception as e:
-            raise CommandExecutionError(str(e))
+            raise CommandExecutionError(str(e)) from e

     def _get_command(self, command_name: str) -> Command:
         for command in reversed(self.commands):
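For illustration: the new `prompt_strategy_class` parameter above defaults to `CodeFlowAgentPromptStrategy`, so a caller that wants the previous one-shot behaviour has to pass the strategy class explicitly. A minimal sketch of a hypothetical call site, assuming the constructor shown in the hunk above (variable names are placeholders):

# Hypothetical call site; settings/provider objects are placeholders.
agent = Agent(
    settings=agent_settings,
    llm_provider=llm_provider,
    file_storage=file_storage,
    app_config=app_config,
    prompt_strategy_class=OneShotAgentPromptStrategy,  # opt out of the code-flow default
)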
autogpt/autogpt/agents/prompt_strategies/code_flow.py (new file, 355 lines)
@@ -0,0 +1,355 @@
import inspect
import re
from logging import Logger
from typing import Callable, Iterable, Sequence, get_args, get_origin

from forge.command import Command
from forge.components.code_flow_executor import CodeFlowExecutionComponent
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
from forge.llm.prompting.utils import indent
from forge.llm.providers.schema import (
    AssistantChatMessage,
    AssistantFunctionCall,
    ChatMessage,
)
from forge.models.config import SystemConfiguration
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import InvalidAgentResponseError
from forge.utils.function.code_validation import CodeValidator
from forge.utils.function.model import FunctionDef
from pydantic import BaseModel, Field

from autogpt.agents.prompt_strategies.one_shot import (
    AssistantThoughts,
    OneShotAgentActionProposal,
    OneShotAgentPromptConfiguration,
)

_RESPONSE_INTERFACE_NAME = "AssistantResponse"


class CodeFlowAgentActionProposal(BaseModel):
    thoughts: AssistantThoughts
    immediate_plan: str = Field(
        ...,
        description="We will be running an iterative process to execute the plan, "
        "Write the partial / immediate plan to execute your plan as detailed and "
        "efficiently as possible without the help of the reasoning/intelligence. "
        "The plan should describe the output of the immediate plan, so that the next "
        "iteration can be executed by taking the output into account. "
        "Try to do as much as possible without making any assumption or uninformed "
        "guesses. Avoid large output at all costs!!!\n"
        "Format: Objective[Objective of this iteration, explain what's the use of this "
        "iteration for the next one] Plan[Plan that does not require any reasoning or "
        "intelligence] Output[Output of the plan / should be small, avoid whole file "
        "output]",
    )
    python_code: str = Field(
        ...,
        description=(
            "Write the fully-functional Python code of the immediate plan. "
            "The output will be an `async def main() -> str` function of the immediate "
            "plan that return the string output, the output will be passed into the "
            "LLM context window so avoid returning the whole content!. "
            "Use ONLY the listed available functions and built-in Python features. "
            "Leverage the given magic functions to implement function calls for which "
            "the arguments can't be determined yet. "
            "Example:`async def main() -> str:\n"
            "  return await provided_function('arg1', 'arg2').split('\\n')[0]`"
        ),
    )


FINAL_INSTRUCTION: str = (
    "You have to give the answer in the from of JSON schema specified previously. "
    "For the `python_code` field, you have to write Python code to execute your plan "
    "as efficiently as possible. Your code will be executed directly without any "
    "editing, if it doesn't work you will be held responsible. "
    "Use ONLY the listed available functions and built-in Python features. "
    "Do not make uninformed assumptions "
    "(e.g. about the content or format of an unknown file). Leverage the given magic "
    "functions to implement function calls for which the arguments can't be determined "
    "yet. Reduce the amount of unnecessary data passed into these magic functions "
    "where possible, because magic costs money and magically processing large amounts "
    "of data is expensive. If you think are done with the task, you can simply call "
    "finish(reason='your reason') to end the task, "
    "a function that has one `finish` command, don't mix finish with other functions! "
    "If you still need to do other functions, "
    "let the next cycle execute the `finish` function. "
    "Avoid hard-coding input values as input, and avoid returning large outputs. "
    "The code that you have been executing in the past cycles can also be buggy, "
    "so if you see undesired output, you can always try to re-plan, and re-code. "
)


class CodeFlowAgentPromptStrategy(PromptStrategy):
    default_configuration: OneShotAgentPromptConfiguration = (
        OneShotAgentPromptConfiguration()
    )

    def __init__(
        self,
        configuration: SystemConfiguration,
        logger: Logger,
    ):
        self.config = configuration
        self.response_schema = JSONSchema.from_dict(
            CodeFlowAgentActionProposal.model_json_schema()
        )
        self.logger = logger
        self.commands: Sequence[Command] = []  # Sequence -> disallow list modification

    @property
    def llm_classification(self) -> LanguageModelClassification:
        return LanguageModelClassification.SMART_MODEL  # FIXME: dynamic switching

    def build_prompt(
        self,
        *,
        messages: list[ChatMessage],
        task: str,
        ai_profile: AIProfile,
        ai_directives: AIDirectives,
        commands: Sequence[Command],
        **extras,
    ) -> ChatPrompt:
        """Constructs and returns a prompt with the following structure:
        1. System prompt
        3. `cycle_instruction`
        """
        system_prompt, response_prefill = self.build_system_prompt(
            ai_profile=ai_profile,
            ai_directives=ai_directives,
            commands=commands,
        )

        self.commands = commands
        final_instruction_msg = ChatMessage.system(FINAL_INSTRUCTION)

        return ChatPrompt(
            messages=[
                ChatMessage.system(system_prompt),
                ChatMessage.user(f'"""{task}"""'),
                *messages,
                *(
                    [final_instruction_msg]
                    if not any(m.role == "assistant" for m in messages)
                    else []
                ),
            ],
            prefill_response=response_prefill,
        )

    def build_system_prompt(
        self,
        ai_profile: AIProfile,
        ai_directives: AIDirectives,
        commands: Iterable[Command],
    ) -> tuple[str, str]:
        """
        Builds the system prompt.

        Returns:
            str: The system prompt body
            str: The desired start for the LLM's response; used to steer the output
        """
        response_fmt_instruction, response_prefill = self.response_format_instruction()
        system_prompt_parts = (
            self._generate_intro_prompt(ai_profile)
            + [
                "## Your Task\n"
                "The user will specify a task for you to execute, in triple quotes,"
                " in the next message. Your job is to complete the task, "
                "and terminate when your task is done."
            ]
            + ["## Available Functions\n" + self._generate_function_headers(commands)]
            + ["## RESPONSE FORMAT\n" + response_fmt_instruction]
        )

        # Join non-empty parts together into paragraph format
        return (
            "\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
            response_prefill,
        )

    def response_format_instruction(self) -> tuple[str, str]:
        response_schema = self.response_schema.model_copy(deep=True)
        assert response_schema.properties

        # Unindent for performance
        response_format = re.sub(
            r"\n\s+",
            "\n",
            response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
        )
        response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'

        return (
            (
                f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
                f"{response_format}"
            ),
            response_prefill,
        )

    def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:
        """Generates the introduction part of the prompt.

        Returns:
            list[str]: A list of strings forming the introduction part of the prompt.
        """
        return [
            f"You are {ai_profile.ai_name}, {ai_profile.ai_role.rstrip('.')}.",
            # "Your decisions must always be made independently without seeking "
            # "user assistance. Play to your strengths as an LLM and pursue "
            # "simple strategies with no legal complications.",
        ]

    def _generate_function_headers(self, commands: Iterable[Command]) -> str:
        function_stubs: list[str] = []
        annotation_types_in_context: set[type] = set()
        for f in commands:
            # Add source code of non-builtin types from function signatures
            new_annotation_types = extract_annotation_types(f.method).difference(
                annotation_types_in_context
            )
            new_annotation_types_src = [
                f"# {a.__module__}.{a.__qualname__}\n{inspect.getsource(a)}"
                for a in new_annotation_types
            ]
            annotation_types_in_context.update(new_annotation_types)

            param_descriptions = "\n".join(
                f"{param.name}: {param.spec.description}"
                for param in f.parameters
                if param.spec.description
            )
            full_function_stub = (
                ("\n".join(new_annotation_types_src) + "\n" + f.header).strip()
                + "\n"
                + indent(
                    (
                        '"""\n'
                        f"{f.description}\n\n"
                        f"Params:\n{indent(param_descriptions)}\n"
                        '"""\n'
                        "pass"
                    ),
                )
            )
            function_stubs.append(full_function_stub)

        return "\n\n\n".join(function_stubs)

    async def parse_response_content(
        self,
        response: AssistantChatMessage,
    ) -> OneShotAgentActionProposal:
        if not response.content:
            raise InvalidAgentResponseError("Assistant response has no text content")

        self.logger.debug(
            "LLM response content:"
            + (
                f"\n{response.content}"
                if "\n" in response.content
                else f" '{response.content}'"
            )
        )
        assistant_reply_dict = extract_dict_from_json(response.content)

        parsed_response = CodeFlowAgentActionProposal.model_validate(
            assistant_reply_dict
        )
        if not parsed_response.python_code:
            raise ValueError("python_code is empty")

        available_functions = {
            c.name: FunctionDef(
                name=c.name,
                arg_types=[(p.name, p.spec.python_type) for p in c.parameters],
                arg_descs={p.name: p.spec.description for p in c.parameters},
                arg_defaults={
                    p.name: p.spec.default or "None"
                    for p in c.parameters
                    if p.spec.default or not p.spec.required
                },
                return_type=c.return_type,
                return_desc="Output of the function",
                function_desc=c.description,
                is_async=c.is_async,
            )
            for c in self.commands
        }
        available_functions.update(
            {
                "main": FunctionDef(
                    name="main",
                    arg_types=[],
                    arg_descs={},
                    return_type="str",
                    return_desc="Output of the function",
                    function_desc="The main function to execute the plan",
                    is_async=True,
                )
            }
        )
        code_validation = await CodeValidator(
            function_name="main",
            available_functions=available_functions,
        ).validate_code(parsed_response.python_code)

        clean_response = response.model_copy()
        clean_response.content = parsed_response.model_dump_json(indent=4)

        # TODO: prevent combining finish with other functions
        if _finish_call := re.search(
            r"finish\((reason=)?(.*?)\)", code_validation.functionCode
        ):
            finish_reason = _finish_call.group(2)[1:-1]  # remove quotes
            result = OneShotAgentActionProposal(
                thoughts=parsed_response.thoughts,
                use_tool=AssistantFunctionCall(
                    name="finish",
                    arguments={"reason": finish_reason},
                ),
                raw_message=clean_response,
            )
        else:
            result = OneShotAgentActionProposal(
                thoughts=parsed_response.thoughts,
                use_tool=AssistantFunctionCall(
                    name=CodeFlowExecutionComponent.execute_code_flow.name,
                    arguments={
                        "python_code": code_validation.functionCode,
                        "plan_text": parsed_response.immediate_plan,
                    },
                ),
                raw_message=clean_response,
            )
        return result


def extract_annotation_types(func: Callable) -> set[type]:
    annotation_types = set()
    for annotation in inspect.get_annotations(func).values():
        annotation_types.update(_get_nested_types(annotation))
    return annotation_types


def _get_nested_types(annotation: type) -> Iterable[type]:
    if _args := get_args(annotation):
        for a in _args:
            yield from _get_nested_types(a)
    if not _is_builtin_type(_a := get_origin(annotation) or annotation):
        yield _a


def _is_builtin_type(_type: type):
    """Check if a given type is a built-in type."""
    import sys

    return _type.__module__ in sys.stdlib_module_names
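For illustration, `_generate_function_headers` above renders each available command as a Python stub inside the "## Available Functions" section of the system prompt. A hypothetical command such as `web_search(query: str) -> str` would come out roughly as follows (a sketch only; the exact rendering depends on `Command.header` and the command's parameter specs):

def web_search(query: str) -> str:
    """
    Searches the web

    Params:
        query: The search query
    """
    pass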
@@ -6,6 +6,7 @@ import re
 from logging import Logger

 import distro
+from forge.command import Command
 from forge.config.ai_directives import AIDirectives
 from forge.config.ai_profile import AIProfile
 from forge.json.parsing import extract_dict_from_json
@@ -16,6 +17,7 @@ from forge.llm.providers.schema import (
     ChatMessage,
     CompletionModelFunction,
 )
+from forge.llm.providers.utils import function_specs_from_commands
 from forge.models.action import ActionProposal
 from forge.models.config import SystemConfiguration, UserConfigurable
 from forge.models.json_schema import JSONSchema
@@ -27,13 +29,21 @@ _RESPONSE_INTERFACE_NAME = "AssistantResponse"


 class AssistantThoughts(ModelWithSummary):
+    past_action_summary: str = Field(
+        ...,
+        description="Summary of the last action you took, if there is none, "
+        "you can leave it empty",
+    )
     observations: str = Field(
-        description="Relevant observations from your last action (if any)"
+        description="Relevant observations from your last actions (if any)"
     )
     text: str = Field(description="Thoughts")
     reasoning: str = Field(description="Reasoning behind the thoughts")
     self_criticism: str = Field(description="Constructive self-criticism")
-    plan: list[str] = Field(description="Short list that conveys the long-term plan")
+    plan: list[str] = Field(
+        description="Short list that conveys the long-term plan, "
+        "considering the progress on your task so far",
+    )
     speak: str = Field(description="Summary of thoughts, to say to user")

     def summary(self) -> str:
@@ -101,7 +111,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):

     @property
     def llm_classification(self) -> LanguageModelClassification:
-        return LanguageModelClassification.FAST_MODEL  # FIXME: dynamic switching
+        return LanguageModelClassification.SMART_MODEL  # FIXME: dynamic switching

     def build_prompt(
         self,
@@ -110,7 +120,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
         task: str,
         ai_profile: AIProfile,
         ai_directives: AIDirectives,
-        commands: list[CompletionModelFunction],
+        commands: list[Command],
         include_os_info: bool,
         **extras,
     ) -> ChatPrompt:
@@ -118,10 +128,11 @@ class OneShotAgentPromptStrategy(PromptStrategy):
         1. System prompt
         3. `cycle_instruction`
         """
+        functions = function_specs_from_commands(commands)
         system_prompt, response_prefill = self.build_system_prompt(
             ai_profile=ai_profile,
             ai_directives=ai_directives,
-            commands=commands,
+            functions=functions,
             include_os_info=include_os_info,
         )

@@ -135,14 +146,14 @@ class OneShotAgentPromptStrategy(PromptStrategy):
                 final_instruction_msg,
             ],
             prefill_response=response_prefill,
-            functions=commands if self.config.use_functions_api else [],
+            functions=functions if self.config.use_functions_api else [],
         )

     def build_system_prompt(
         self,
         ai_profile: AIProfile,
         ai_directives: AIDirectives,
-        commands: list[CompletionModelFunction],
+        functions: list[CompletionModelFunction],
         include_os_info: bool,
     ) -> tuple[str, str]:
         """
@@ -162,7 +173,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
             self.config.body_template.format(
                 constraints=format_numbered_list(ai_directives.constraints),
                 resources=format_numbered_list(ai_directives.resources),
-                commands=self._generate_commands_list(commands),
+                commands=self._generate_commands_list(functions),
                 best_practices=format_numbered_list(ai_directives.best_practices),
             )
         ]
@@ -23,6 +23,7 @@ from forge.agent_protocol.models import (
     TaskRequestBody,
     TaskStepsListResponse,
 )
+from forge.components.code_flow_executor import CodeFlowExecutionComponent
 from forge.file_storage import FileStorage
 from forge.llm.providers import ModelProviderBudget, MultiProvider
 from forge.models.action import ActionErrorResult, ActionSuccessResult
@@ -298,11 +299,16 @@ class AgentProtocolServer:
             else ""
         )
         output += f"{assistant_response.thoughts.speak}\n\n"
-        output += (
-            f"Next Command: {next_tool_to_use}"
-            if next_tool_to_use.name != ASK_COMMAND
-            else next_tool_to_use.arguments["question"]
-        )
+        if next_tool_to_use.name == CodeFlowExecutionComponent.execute_code_flow.name:
+            code = next_tool_to_use.arguments["python_code"]
+            plan = next_tool_to_use.arguments["plan_text"]
+            output += f"Code for next step:\n```py\n# {plan}\n\n{code}\n```"
+        else:
+            output += (
+                f"Next Command: {next_tool_to_use}"
+                if next_tool_to_use.name != ASK_COMMAND
+                else next_tool_to_use.arguments["question"]
+            )

         additional_output = {
             **(
@@ -630,6 +630,9 @@ def update_user(
         command_args: The arguments for the command.
         assistant_reply_dict: The assistant's reply.
     """
+    from forge.components.code_flow_executor import CodeFlowExecutionComponent
+    from forge.llm.prompting.utils import indent
+
     logger = logging.getLogger(__name__)

     print_assistant_thoughts(
@@ -644,15 +647,29 @@ def update_user(
     # First log new-line so user can differentiate sections better in console
     print()
     safe_tool_name = remove_ansi_escape(action_proposal.use_tool.name)
-    logger.info(
-        f"COMMAND = {Fore.CYAN}{safe_tool_name}{Style.RESET_ALL} "
-        f"ARGUMENTS = {Fore.CYAN}{action_proposal.use_tool.arguments}{Style.RESET_ALL}",
-        extra={
-            "title": "NEXT ACTION:",
-            "title_color": Fore.CYAN,
-            "preserve_color": True,
-        },
-    )
+    if safe_tool_name == CodeFlowExecutionComponent.execute_code_flow.name:
+        plan = action_proposal.use_tool.arguments["plan_text"]
+        code = action_proposal.use_tool.arguments["python_code"]
+        logger.info(
+            f"\n{indent(code, f'{Fore.GREEN}>>> {Fore.RESET}')}\n",
+            extra={
+                "title": "PROPOSED ACTION:",
+                "title_color": Fore.GREEN,
+                "preserve_color": True,
+            },
+        )
+        logger.debug(
+            f"{plan}\n", extra={"title": "EXPLANATION:", "title_color": Fore.YELLOW}
+        )
+    else:
+        logger.info(
+            str(action_proposal.use_tool),
+            extra={
+                "title": "PROPOSED ACTION:",
+                "title_color": Fore.GREEN,
+                "preserve_color": True,
+            },
+        )


 async def get_user_feedback(
@@ -732,6 +749,12 @@ def print_assistant_thoughts(
     )

     if isinstance(thoughts, AssistantThoughts):
+        if thoughts.observations:
+            print_attribute(
+                "OBSERVATIONS",
+                remove_ansi_escape(thoughts.observations),
+                title_color=Fore.YELLOW,
+            )
         print_attribute(
             "REASONING", remove_ansi_escape(thoughts.reasoning), title_color=Fore.YELLOW
         )
@@ -753,7 +776,7 @@ def print_assistant_thoughts(
                 line.strip(), extra={"title": "- ", "title_color": Fore.GREEN}
             )
         print_attribute(
-            "CRITICISM",
+            "SELF-CRITICISM",
            remove_ansi_escape(thoughts.self_criticism),
            title_color=Fore.YELLOW,
        )
@@ -764,7 +787,7 @@ def print_assistant_thoughts(
                 speak(assistant_thoughts_speak)
             else:
                 print_attribute(
-                    "SPEAK", assistant_thoughts_speak, title_color=Fore.YELLOW
+                    "TL;DR", assistant_thoughts_speak, title_color=Fore.YELLOW
                 )
     else:
         speak(thoughts_text)
autogpt/poetry.lock (generated, 30 changed lines)
@@ -4216,7 +4216,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
 name = "ptyprocess"
 version = "0.7.0"
 description = "Run a subprocess in a pseudo terminal"
-optional = true
+optional = false
 python-versions = "*"
 files = [
     {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
@@ -5212,6 +5212,32 @@ files = [
 [package.dependencies]
 pyasn1 = ">=0.1.3"

+[[package]]
+name = "ruff"
+version = "0.4.4"
+description = "An extremely fast Python linter and code formatter, written in Rust."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "ruff-0.4.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:29d44ef5bb6a08e235c8249294fa8d431adc1426bfda99ed493119e6f9ea1bf6"},
+    {file = "ruff-0.4.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c4efe62b5bbb24178c950732ddd40712b878a9b96b1d02b0ff0b08a090cbd891"},
+    {file = "ruff-0.4.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c8e2f1e8fc12d07ab521a9005d68a969e167b589cbcaee354cb61e9d9de9c15"},
+    {file = "ruff-0.4.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:60ed88b636a463214905c002fa3eaab19795679ed55529f91e488db3fe8976ab"},
+    {file = "ruff-0.4.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b90fc5e170fc71c712cc4d9ab0e24ea505c6a9e4ebf346787a67e691dfb72e85"},
+    {file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8e7e6ebc10ef16dcdc77fd5557ee60647512b400e4a60bdc4849468f076f6eef"},
+    {file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b9ddb2c494fb79fc208cd15ffe08f32b7682519e067413dbaf5f4b01a6087bcd"},
+    {file = "ruff-0.4.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c51c928a14f9f0a871082603e25a1588059b7e08a920f2f9fa7157b5bf08cfe9"},
+    {file = "ruff-0.4.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5eb0a4bfd6400b7d07c09a7725e1a98c3b838be557fee229ac0f84d9aa49c36"},
+    {file = "ruff-0.4.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b1867ee9bf3acc21778dcb293db504692eda5f7a11a6e6cc40890182a9f9e595"},
+    {file = "ruff-0.4.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1aecced1269481ef2894cc495647392a34b0bf3e28ff53ed95a385b13aa45768"},
+    {file = "ruff-0.4.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9da73eb616b3241a307b837f32756dc20a0b07e2bcb694fec73699c93d04a69e"},
+    {file = "ruff-0.4.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:958b4ea5589706a81065e2a776237de2ecc3e763342e5cc8e02a4a4d8a5e6f95"},
+    {file = "ruff-0.4.4-py3-none-win32.whl", hash = "sha256:cb53473849f011bca6e754f2cdf47cafc9c4f4ff4570003a0dad0b9b6890e876"},
+    {file = "ruff-0.4.4-py3-none-win_amd64.whl", hash = "sha256:424e5b72597482543b684c11def82669cc6b395aa8cc69acc1858b5ef3e5daae"},
+    {file = "ruff-0.4.4-py3-none-win_arm64.whl", hash = "sha256:39df0537b47d3b597293edbb95baf54ff5b49589eb7ff41926d8243caa995ea6"},
+    {file = "ruff-0.4.4.tar.gz", hash = "sha256:f87ea42d5cdebdc6a69761a9d0bc83ae9b3b30d0ad78952005ba6568d6c022af"},
+]

 [[package]]
 name = "s3transfer"
 version = "0.10.0"
@@ -6758,4 +6784,4 @@ benchmark = ["agbenchmark"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "b3d4efee5861b32152024dada1ec61f4241122419cb538012c00a6ed55ac8a4b"
+content-hash = "c729e10fd5ac85400d2499397974d1b1831fed3b591657a2fea9e86501b96e19"
@@ -30,9 +30,12 @@ gitpython = "^3.1.32"
 hypercorn = "^0.14.4"
 openai = "^1.7.2"
 orjson = "^3.8.10"
+ptyprocess = "^0.7.0"
 pydantic = "^2.7.2"
+pyright = "^1.1.364"
 python-dotenv = "^1.0.0"
 requests = "*"
+ruff = "^0.4.4"
 sentry-sdk = "^1.40.4"

 # Benchmarking
@@ -47,7 +50,6 @@ black = "^23.12.1"
 flake8 = "^7.0.0"
 isort = "^5.13.1"
 pre-commit = "*"
-pyright = "^1.1.364"

 # Type stubs
 types-colorama = "*"
autogpt/tests/unit/test_code_flow_strategy.py (new file, 126 lines)
@@ -0,0 +1,126 @@
import logging
from typing import Optional

import pytest
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.components.code_flow_executor import CodeFlowExecutionComponent
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.llm.providers import AssistantChatMessage
from forge.llm.providers.schema import JSONSchema

from autogpt.agents.prompt_strategies.code_flow import CodeFlowAgentPromptStrategy

logger = logging.getLogger(__name__)
config = CodeFlowAgentPromptStrategy.default_configuration.copy(deep=True)
prompt_strategy = CodeFlowAgentPromptStrategy(config, logger)


class MockWebSearchProvider(CommandProvider):
    def get_commands(self):
        yield self.mock_web_search

    @command(
        description="Searches the web",
        parameters={
            "query": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The search query",
                required=True,
            ),
            "num_results": JSONSchema(
                type=JSONSchema.Type.INTEGER,
                description="The number of results to return",
                minimum=1,
                maximum=10,
                required=False,
            ),
        },
    )
    def mock_web_search(self, query: str, num_results: Optional[int] = None) -> str:
        return "results"


@pytest.mark.asyncio
async def test_code_flow_build_prompt():
    commands = list(MockWebSearchProvider().get_commands())

    ai_profile = AIProfile()
    ai_profile.ai_name = "DummyGPT"
    ai_profile.ai_goals = ["A model for testing purposes"]
    ai_profile.ai_role = "Help Testing"

    ai_directives = AIDirectives()
    ai_directives.resources = ["resource_1"]
    ai_directives.constraints = ["constraint_1"]
    ai_directives.best_practices = ["best_practice_1"]

    prompt = str(
        prompt_strategy.build_prompt(
            task="Figure out from file.csv how much was spent on utilities",
            messages=[],
            ai_profile=ai_profile,
            ai_directives=ai_directives,
            commands=commands,
        )
    )
    assert "DummyGPT" in prompt
    assert (
        "def mock_web_search(query: str, num_results: Optional[int] = None)" in prompt
    )


@pytest.mark.asyncio
async def test_code_flow_parse_response():
    response_content = """
    {
        "thoughts": {
            "past_action_summary": "This is the past action summary.",
            "observations": "This is the observation.",
            "text": "Some text on the AI's thoughts.",
            "reasoning": "This is the reasoning.",
            "self_criticism": "This is the self-criticism.",
            "plan": [
                "Plan 1",
                "Plan 2",
                "Plan 3"
            ],
            "speak": "This is what the AI would say."
        },
        "immediate_plan": "Objective[objective1] Plan[plan1] Output[out1]",
        "python_code": "async def main() -> str:\n return 'You passed the test.'",
    }
    """
    response = await CodeFlowAgentPromptStrategy(config, logger).parse_response_content(
        AssistantChatMessage(content=response_content)
    )
    assert "This is the observation." == response.thoughts.observations
    assert "This is the reasoning." == response.thoughts.reasoning

    assert CodeFlowExecutionComponent.execute_code_flow.name == response.use_tool.name
    assert "async def main() -> str" in response.use_tool.arguments["python_code"]
    assert (
        "Objective[objective1] Plan[plan1] Output[out1]"
        in response.use_tool.arguments["plan_text"]
    )


@pytest.mark.asyncio
async def test_code_flow_execution():
    executor = CodeFlowExecutionComponent(
        lambda: [
            Command(
                names=["test_func"],
                description="",
                parameters=[],
                method=lambda: "You've passed the test!",
            )
        ]
    )

    result = await executor.execute_code_flow(
        python_code="async def main() -> str:\n return test_func()",
        plan_text="This is the plan text.",
    )
    assert "You've passed the test!" in result
autogpt/tests/unit/test_function_code_validation.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import pytest
from forge.utils.function.code_validation import CodeValidator, FunctionDef


@pytest.mark.asyncio
async def test_code_validation():
    validator = CodeValidator(
        available_functions={
            "read_webpage": FunctionDef(
                name="read_webpage",
                arg_types=[("url", "str"), ("query", "str")],
                arg_descs={
                    "url": "URL to read",
                    "query": "Query to search",
                    "return_type": "Type of return value",
                },
                return_type="str",
                return_desc="Information matching the query",
                function_desc="Read a webpage and return the info matching the query",
                is_async=True,
            ),
            "web_search": FunctionDef(
                name="web_search",
                arg_types=[("query", "str")],
                arg_descs={"query": "Query to search"},
                return_type="list[(str,str)]",
                return_desc="List of tuples with title and URL",
                function_desc="Search the web and return the search results",
                is_async=True,
            ),
            "main": FunctionDef(
                name="main",
                arg_types=[],
                arg_descs={},
                return_type="str",
                return_desc="Answer in the text format",
                function_desc="Get the num of contributors to the autogpt github repo",
                is_async=False,
            ),
        },
        available_objects={},
    )
    response = await validator.validate_code(
        raw_code="""
def crawl_info(url: str, query: str) -> str | None:
    info = await read_webpage(url, query)
    if info:
        return info

    urls = await read_webpage(url, "autogpt github contributor page")
    for url in urls.split('\\n'):
        info = await crawl_info(url, query)
        if info:
            return info

    return None

def hehe():
    return 'hehe'

def main() -> str:
    query = "Find the number of contributors to the autogpt github repository"
    for title, url in ("autogpt github contributor page"):
        info = await crawl_info(url, query)
        if info:
            return info
    x = await hehe()
    return "No info found"
""",
        packages=[],
    )
    assert response.functionCode is not None
    assert "async def crawl_info" in response.functionCode  # async is added
    assert "async def main" in response.functionCode
    assert "x = hehe()" in response.functionCode  # await is removed
@@ -1,17 +1,13 @@
 from __future__ import annotations

 import inspect
-from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
+from typing import Callable, Generic, ParamSpec, TypeVar

-from forge.agent.protocols import CommandProvider
-
 from .parameter import CommandParameter

 P = ParamSpec("P")
 CO = TypeVar("CO")  # command output

-_CP = TypeVar("_CP", bound=CommandProvider)
-

 class Command(Generic[P, CO]):
     """A class representing a command.
@@ -26,37 +22,60 @@ class Command(Generic[P, CO]):
         self,
         names: list[str],
         description: str,
-        method: Callable[Concatenate[_CP, P], CO],
+        method: Callable[P, CO],
         parameters: list[CommandParameter],
     ):
-        # Check if all parameters are provided
-        if not self._parameters_match(method, parameters):
-            raise ValueError(
-                f"Command {names[0]} has different parameters than provided schema"
-            )
         self.names = names
         self.description = description
-        # Method technically has a `self` parameter, but we can ignore that
-        # since Python passes it internally.
-        self.method = cast(Callable[P, CO], method)
+        self.method = method
         self.parameters = parameters

+        # Check if all parameters are provided
+        if not self._parameters_match_signature():
+            raise ValueError(
+                f"Command {self.name} has different parameters than provided schema"
+            )
+
+    @property
+    def name(self) -> str:
+        return self.names[0]  # TODO: fallback to other name if first one is taken
+
     @property
     def is_async(self) -> bool:
         return inspect.iscoroutinefunction(self.method)

-    def _parameters_match(
-        self, func: Callable, parameters: list[CommandParameter]
-    ) -> bool:
+    @property
+    def return_type(self) -> str:
+        _type = inspect.signature(self.method).return_annotation
+        if _type == inspect.Signature.empty:
+            return "None"
+        return _type.__name__
+
+    @property
+    def header(self) -> str:
+        """Returns a function header representing the command's signature
+
+        Examples:
+        ```py
+        def execute_python_code(code: str) -> str:
+
+        async def extract_info_from_content(content: str, instruction: str, output_type: type[~T]) -> ~T:
+        """  # noqa
+        return (
+            f"{'async ' if self.is_async else ''}"
+            f"def {self.name}{inspect.signature(self.method)}:"
+        )
+
+    def _parameters_match_signature(self) -> bool:
         # Get the function's signature
-        signature = inspect.signature(func)
+        signature = inspect.signature(self.method)
         # Extract parameter names, ignoring 'self' for methods
         func_param_names = [
             param.name
             for param in signature.parameters.values()
             if param.name != "self"
         ]
-        names = [param.name for param in parameters]
+        names = [param.name for param in self.parameters]
         # Check if sorted lists of names/keys are equal
         return sorted(func_param_names) == sorted(names)

@@ -71,7 +90,7 @@ class Command(Generic[P, CO]):
             for param in self.parameters
         ]
         return (
-            f"{self.names[0]}: {self.description.rstrip('.')}. "
+            f"{self.name}: {self.description.rstrip('.')}. "
             f"Params: ({', '.join(params)})"
         )

@@ -1,21 +1,28 @@
|
|||||||
|
import inspect
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Concatenate, Optional, TypeVar
|
from typing import Callable, Concatenate, Optional, TypeVar, cast
|
||||||
|
|
||||||
from forge.agent.protocols import CommandProvider
|
from forge.agent.protocols import CommandProvider
|
||||||
from forge.models.json_schema import JSONSchema
|
from forge.models.json_schema import JSONSchema
|
||||||
|
|
||||||
from .command import CO, Command, CommandParameter, P
|
from .command import CO, Command, CommandParameter, P
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_CP = TypeVar("_CP", bound=CommandProvider)
|
_CP = TypeVar("_CP", bound=CommandProvider)
|
||||||
|
|
||||||
|
|
||||||
def command(
|
def command(
|
||||||
names: list[str] = [],
|
names: Optional[list[str]] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
parameters: dict[str, JSONSchema] = {},
|
parameters: Optional[dict[str, JSONSchema]] = None,
|
||||||
) -> Callable[[Callable[Concatenate[_CP, P], CO]], Command[P, CO]]:
|
) -> Callable[[Callable[Concatenate[_CP, P], CO] | Callable[P, CO]], Command[P, CO]]:
|
||||||
"""
|
"""
|
||||||
The command decorator is used to make a Command from a function.
|
Make a `Command` from a function or a method on a `CommandProvider`.
|
||||||
|
All parameters are optional if the decorated function has a fully featured
|
||||||
|
docstring. For the requirements of such a docstring,
|
||||||
|
see `get_param_descriptions_from_docstring`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
names (list[str]): The names of the command.
|
names (list[str]): The names of the command.
|
||||||
@@ -27,34 +34,141 @@ def command(
|
|||||||
that the command executes.
|
that the command executes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[Concatenate[_CP, P], CO]) -> Command[P, CO]:
|
def decorator(
|
||||||
doc = func.__doc__ or ""
|
func: Callable[Concatenate[_CP, P], CO] | Callable[P, CO]
|
||||||
|
) -> Command[P, CO]:
|
||||||
# If names is not provided, use the function name
|
# If names is not provided, use the function name
|
||||||
command_names = names or [func.__name__]
|
_names = names or [func.__name__]
|
||||||
# If description is not provided, use the first part of the docstring
|
|
||||||
if not (command_description := description):
|
# If description is not provided, use the first part of the docstring
|
||||||
if not func.__doc__:
|
docstring = inspect.getdoc(func)
|
||||||
raise ValueError("Description is required if function has no docstring")
|
if not (_description := description):
|
||||||
# Return the part of the docstring before double line break or everything
|
if not docstring:
|
||||||
command_description = re.sub(r"\s+", " ", doc.split("\n\n")[0].strip())
|
raise ValueError(
|
||||||
|
"'description' is required if function has no docstring"
|
||||||
|
)
|
||||||
|
_description = get_clean_description_from_docstring(docstring)
|
||||||
|
|
||||||
|
if not (_parameters := parameters):
|
||||||
|
if not docstring:
|
||||||
|
raise ValueError(
|
||||||
|
"'parameters' is required if function has no docstring"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine descriptions from docstring with JSONSchemas from annotations
|
||||||
|
param_descriptions = get_param_descriptions_from_docstring(docstring)
|
||||||
|
_parameters = get_params_json_schemas(func)
|
||||||
|
for param, param_schema in _parameters.items():
|
||||||
|
if desc := param_descriptions.get(param):
|
||||||
|
param_schema.description = desc
|
||||||
|
|
||||||
# Parameters
|
|
||||||
typed_parameters = [
|
typed_parameters = [
|
||||||
CommandParameter(
|
CommandParameter(
|
||||||
name=param_name,
|
name=param_name,
|
||||||
spec=spec,
|
spec=spec,
|
||||||
)
|
)
|
||||||
for param_name, spec in parameters.items()
|
for param_name, spec in _parameters.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
# Wrap func with Command
|
# Wrap func with Command
|
||||||
command = Command(
|
command = Command(
|
||||||
names=command_names,
|
names=_names,
|
||||||
description=command_description,
|
description=_description,
|
||||||
method=func,
|
# Method technically has a `self` parameter, but we can ignore that
|
||||||
|
# since Python passes it internally.
|
||||||
|
method=cast(Callable[P, CO], func),
|
||||||
parameters=typed_parameters,
|
parameters=typed_parameters,
|
||||||
)
|
)
|
||||||
|
|
||||||
return command
|
return command
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def get_clean_description_from_docstring(docstring: str) -> str:
|
||||||
|
"""Return the part of the docstring before double line break or everything"""
|
||||||
|
return re.sub(r"\s+", " ", docstring.split("\n\n")[0].strip())
|
||||||
|
|
||||||
|
|
||||||
|
def get_params_json_schemas(func: Callable) -> dict[str, JSONSchema]:
|
||||||
|
"""Gets the annotations of the given function and converts them to JSONSchemas"""
|
||||||
|
result: dict[str, JSONSchema] = {}
|
||||||
|
for name, parameter in inspect.signature(func).parameters.items():
|
||||||
|
if name == "return":
|
||||||
|
continue
|
||||||
|
param_schema = result[name] = JSONSchema.from_python_type(parameter.annotation)
|
||||||
|
if parameter.default:
|
||||||
|
param_schema.default = parameter.default
|
||||||
|
param_schema.required = False
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_param_descriptions_from_docstring(docstring: str) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Get parameter descriptions from a docstring. Requirements for the docstring:
|
||||||
|
- The section describing the parameters MUST start with `Params:` or `Args:`, in any
|
||||||
|
capitalization.
|
||||||
|
- An entry describing a parameter MUST be indented by 4 spaces.
|
||||||
|
- An entry describing a parameter MUST start with the parameter name, an optional
|
||||||
|
type annotation, followed by a colon `:`.
|
||||||
|
- Continuations of parameter descriptions MUST be indented relative to the first
|
||||||
|
line of the entry.
|
||||||
|
- The docstring must not be indented as a whole. To get a docstring with the uniform
|
||||||
|
indentation stripped off, use `inspect.getdoc(func)`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
\"\"\"
|
||||||
|
This is the description. This will be ignored.
|
||||||
|
The description can span multiple lines,
|
||||||
|
|
||||||
|
or contain any number of line breaks.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
param1: This is a simple parameter description.
|
||||||
|
param2 (list[str]): This parameter also has a type annotation.
|
||||||
|
param3: This parameter has a long description. This means we will have to let it
|
||||||
|
continue on the next line. The continuation is indented relative to the first
|
||||||
|
line of the entry.
|
||||||
|
|
||||||
|
param4: Extra line breaks to group parameters are allowed. This will not break
|
||||||
|
the algorithm.
|
||||||
|
|
||||||
|
This text is
|
||||||
|
is indented by
|
||||||
|
less than 4 spaces
|
||||||
|
and is interpreted as the end of the `Params:` section.
|
||||||
|
\"\"\"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
param_descriptions: dict[str, str] = {}
|
||||||
|
param_section = False
|
||||||
|
last_param_name = ""
|
||||||
|
for line in docstring.split("\n"):
|
||||||
|
if not line.strip(): # ignore empty lines
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line.lower().startswith(("params:", "args:")):
|
||||||
|
param_section = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if param_section:
|
||||||
|
if line.strip() and not line.startswith(" " * 4): # end of section
|
||||||
|
break
|
||||||
|
|
||||||
|
line = line[4:]
|
||||||
|
if line.startswith(" ") and last_param_name: # continuation of description
|
||||||
|
param_descriptions[last_param_name] += " " + line.strip()
|
||||||
|
else:
|
||||||
|
if _match := re.match(r"^(\w+).*?: (.*)", line):
|
||||||
|
param_name = _match.group(1)
|
||||||
|
param_desc = _match.group(2).strip()
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid line in docstring's parameter section: {repr(line)}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
param_descriptions[param_name] = param_desc
|
||||||
|
last_param_name = param_name
|
||||||
|
return param_descriptions
|
||||||
|
|||||||
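Context for the decorator change above: when `description` or `parameters` are not passed explicitly, they are now derived from the function's docstring (its `Params:`/`Args:` section) and its type annotations. Below is a minimal standalone sketch of that docstring contract, using a hypothetical `example_command` and a simplified version of the parsing regex; it is illustrative only and is not the forge helpers themselves.

```python
import inspect
import re


def example_command(query: str, max_results: int = 5) -> list[str]:
    """Search the workspace for files matching a query.

    Params:
        query: The search string to look for.
        max_results: Maximum number of matches to return.
    """
    return []


# Parse the `Params:` section the same way the new decorator expects it to be written.
docstring = inspect.getdoc(example_command) or ""
descriptions = {}
in_params = False
for line in docstring.split("\n"):
    if line.lower().startswith(("params:", "args:")):
        in_params = True
        continue
    if in_params and (m := re.match(r"^\s{4}(\w+).*?: (.*)", line)):
        descriptions[m.group(1)] = m.group(2).strip()

print(descriptions)
# {'query': 'The search string to look for.',
#  'max_results': 'Maximum number of matches to return.'}
```

The JSON schemas themselves come from the annotations (`str`, `int`, the default value), so a command author only has to write an ordinary typed function with a docstring.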
@@ -102,6 +102,8 @@ class ActionHistoryComponent(
 
     @staticmethod
     def _make_result_message(episode: Episode, result: ActionResult) -> ChatMessage:
+        from forge.components.code_flow_executor import CodeFlowExecutionComponent
+
         if result.status == "success":
             return (
                 ToolResultMessage(
@@ -110,11 +112,18 @@ class ActionHistoryComponent(
                 )
                 if episode.action.raw_message.tool_calls
                 else ChatMessage.user(
-                    f"{episode.action.use_tool.name} returned: "
+                    (
+                        "Execution result:"
+                        if (
+                            episode.action.use_tool.name
+                            == CodeFlowExecutionComponent.execute_code_flow.name
+                        )
+                        else f"{episode.action.use_tool.name} returned:"
+                    )
                     + (
-                        f"```\n{result.outputs}\n```"
+                        f"\n```\n{result.outputs}\n```"
                         if "\n" in str(result.outputs)
-                        else f"`{result.outputs}`"
+                        else f" `{result.outputs}`"
                     )
                 )
             )
forge/forge/components/code_flow_executor/__init__.py (new file)
@@ -0,0 +1,3 @@
from .code_flow_executor import CodeFlowExecutionComponent

__all__ = ["CodeFlowExecutionComponent"]
@@ -0,0 +1,69 @@ (new file)
"""Command for executing agent-generated code flows"""

import inspect
import logging
from typing import Any, Callable, Iterable, Iterator

from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.models.json_schema import JSONSchema

MAX_RESULT_LENGTH = 1000
logger = logging.getLogger(__name__)


class CodeFlowExecutionComponent(CommandProvider):
    """A component that provides commands to execute code flow."""

    def __init__(self, get_available_commands: Callable[[], Iterable[Command]]):
        self.get_available_commands = get_available_commands

    def get_commands(self) -> Iterator[Command]:
        yield self.execute_code_flow

    @command(
        parameters={
            "python_code": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The Python code to execute",
                required=True,
            ),
            "plan_text": JSONSchema(
                type=JSONSchema.Type.STRING,
                description="The plan, written in natural language",
                required=False,
            ),
        },
    )
    async def execute_code_flow(self, python_code: str, plan_text: str) -> str:
        """Execute the code flow.

        Args:
            python_code: The Python code to execute
            plan_text: The plan implemented by the given Python code

        Returns:
            str: The result of the code execution
        """
        locals: dict[str, Any] = {}
        locals.update(self._get_available_functions())
        code = f"{python_code}" "\n\n" "exec_output = main()"
        logger.debug(f"Code-Flow Execution code:\n```py\n{code}\n```")
        exec(code, locals)
        result = await locals["exec_output"]
        logger.debug(f"Code-Flow Execution result:\n{result}")
        if inspect.isawaitable(result):
            result = await result

        # Truncate the result to limit the number of characters
        if len(result) > MAX_RESULT_LENGTH:
            result = result[:MAX_RESULT_LENGTH] + "...[Truncated, Content is too long]"
        return result

    def _get_available_functions(self) -> dict[str, Callable]:
        return {
            name: command
            for command in self.get_available_commands()
            for name in command.names
            if name != self.execute_code_flow.name
        }
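The component above runs LLM-generated code by injecting the agent's other commands into the execution namespace and expecting the snippet to define a `main()` entry point. A rough standalone sketch of that contract follows, with a hypothetical `list_files` coroutine standing in for an injected forge command; it is not the component itself.

```python
import asyncio
import inspect


async def list_files(folder: str) -> str:  # stand-in for an injected forge command
    return f"(files in {folder!r})"


# What the LLM is expected to produce: a snippet defining `main()`.
generated_code = """
async def main():
    files = await list_files("")
    return f"workspace contents: {files}"
"""

# Inject the available commands, append the entry-point call, and execute.
namespace = {"list_files": list_files}
exec(generated_code + "\n\nexec_output = main()", namespace)

result = asyncio.run(namespace["exec_output"])
if inspect.isawaitable(result):
    result = asyncio.run(result)
print(result)
```

The key design choice is that the generated code calls other commands as ordinary (awaitable) functions instead of issuing one tool call per step, which is what the code-flow prompt strategy is built around.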
@@ -169,7 +169,8 @@ class FileManagerComponent(
         parameters={
             "folder": JSONSchema(
                 type=JSONSchema.Type.STRING,
-                description="The folder to list files in",
+                description="The folder to list files in. "
+                "Pass an empty string to list files in the workspace.",
                 required=True,
             )
         },
@@ -25,7 +25,7 @@ class UserInteractionComponent
         },
     )
     def ask_user(self, question: str) -> str:
-        """If you need more details or information regarding the given goals,
+        """If you need more details or information regarding the given task,
         you can ask the user for input."""
         print(f"\nQ: {question}")
         resp = click.prompt("A")
@@ -1,3 +1,4 @@
+import inspect
 import logging
 from typing import (
     Any,
@@ -154,7 +155,10 @@ class BaseOpenAIChatProvider(
         self,
         model_prompt: list[ChatMessage],
         model_name: _ModelName,
-        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
+        completion_parser: (
+            Callable[[AssistantChatMessage], Awaitable[_T]]
+            | Callable[[AssistantChatMessage], _T]
+        ) = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
         max_output_tokens: Optional[int] = None,
         prefill_response: str = "",
@@ -208,7 +212,15 @@ class BaseOpenAIChatProvider(
         parsed_result: _T = None  # type: ignore
         if not parse_errors:
             try:
-                parsed_result = completion_parser(assistant_msg)
+                parsed_result = (
+                    await _result
+                    if inspect.isawaitable(
+                        _result := completion_parser(assistant_msg)
+                    )
+                    # cast(..) needed because inspect.isawaitable(..) loses type:
+                    # https://github.com/microsoft/pyright/issues/3690
+                    else cast(_T, _result)
+                )
             except Exception as e:
                 parse_errors.append(e)
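The signature change above (repeated for the Anthropic and multi-provider classes below) lets `completion_parser` be either a plain callable or a coroutine function; the provider awaits the parsed result only when it is actually awaitable. A small standalone sketch of that dispatch pattern, separate from the provider code:

```python
import asyncio
import inspect


async def run_parser(parser, raw):
    # Works for both a plain function and a coroutine function:
    # await only if the call actually returned an awaitable.
    result = parser(raw)
    if inspect.isawaitable(result):
        result = await result
    return result


def sync_parser(raw: str) -> int:
    return len(raw)


async def async_parser(raw: str) -> int:
    return len(raw)


print(asyncio.run(run_parser(sync_parser, "hello")))   # 5
print(asyncio.run(run_parser(async_parser, "hello")))  # 5
```

This is what allows parsers that themselves need to call the LLM or other async services (as the code-flow strategy's validator does) to plug into the existing `create_chat_completion` retry loop.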
@@ -1,8 +1,19 @@
 from __future__ import annotations
 
 import enum
+import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Optional,
+    ParamSpec,
+    Sequence,
+    TypeVar,
+    cast,
+)
 
 import sentry_sdk
 import tenacity
@@ -171,7 +182,10 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
         self,
         model_prompt: list[ChatMessage],
         model_name: AnthropicModelName,
-        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
+        completion_parser: (
+            Callable[[AssistantChatMessage], Awaitable[_T]]
+            | Callable[[AssistantChatMessage], _T]
+        ) = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
         max_output_tokens: Optional[int] = None,
         prefill_response: str = "",
@@ -237,7 +251,14 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
                         + "\n".join(str(e) for e in tool_call_errors)
                     )
 
-                parsed_result = completion_parser(assistant_msg)
+                # cast(..) needed because inspect.isawaitable(..) loses type info:
+                # https://github.com/microsoft/pyright/issues/3690
+                parsed_result = cast(
+                    _T,
+                    await _result
+                    if inspect.isawaitable(_result := completion_parser(assistant_msg))
+                    else _result,
+                )
                 break
             except Exception as e:
                 self._logger.debug(
@@ -1,7 +1,16 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, AsyncIterator, Callable, Optional, Sequence, TypeVar, get_args
+from typing import (
+    Any,
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Optional,
+    Sequence,
+    TypeVar,
+    get_args,
+)
 
 from pydantic import ValidationError
 
@@ -99,7 +108,10 @@ class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
         self,
         model_prompt: list[ChatMessage],
         model_name: ModelName,
-        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
+        completion_parser: (
+            Callable[[AssistantChatMessage], Awaitable[_T]]
+            | Callable[[AssistantChatMessage], _T]
+        ) = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
         max_output_tokens: Optional[int] = None,
         prefill_response: str = "",
|||||||
@@ -6,6 +6,7 @@ from collections import defaultdict
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Generic,
|
Generic,
|
||||||
@@ -135,6 +136,8 @@ class CompletionModelFunction(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
parameters: dict[str, "JSONSchema"]
|
parameters: dict[str, "JSONSchema"]
|
||||||
|
return_type: str | None = None
|
||||||
|
is_async: bool = False
|
||||||
|
|
||||||
def fmt_line(self) -> str:
|
def fmt_line(self) -> str:
|
||||||
params = ", ".join(
|
params = ", ".join(
|
||||||
@@ -143,6 +146,44 @@ class CompletionModelFunction(BaseModel):
|
|||||||
)
|
)
|
||||||
return f"{self.name}: {self.description}. Params: ({params})"
|
return f"{self.name}: {self.description}. Params: ({params})"
|
||||||
|
|
||||||
|
def fmt_function_stub(self, impl: str = "pass") -> str:
|
||||||
|
"""
|
||||||
|
Formats and returns a function stub as a string with types and descriptions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The formatted function header.
|
||||||
|
"""
|
||||||
|
from forge.llm.prompting.utils import indent
|
||||||
|
|
||||||
|
params = ", ".join(
|
||||||
|
f"{name}: {p.python_type}"
|
||||||
|
+ (
|
||||||
|
f" = {str(p.default)}"
|
||||||
|
if p.default
|
||||||
|
else " = None"
|
||||||
|
if not p.required
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
for name, p in self.parameters.items()
|
||||||
|
)
|
||||||
|
_def = "async def" if self.is_async else "def"
|
||||||
|
_return = f" -> {self.return_type}" if self.return_type else ""
|
||||||
|
return f"{_def} {self.name}({params}){_return}:\n" + indent(
|
||||||
|
'"""\n'
|
||||||
|
f"{self.description}\n\n"
|
||||||
|
"Params:\n"
|
||||||
|
+ indent(
|
||||||
|
"\n".join(
|
||||||
|
f"{name}: {param.description}"
|
||||||
|
for name, param in self.parameters.items()
|
||||||
|
if param.description
|
||||||
|
)
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
'"""\n'
|
||||||
|
f"{impl}"
|
||||||
|
)
|
||||||
|
|
||||||
def validate_call(
|
def validate_call(
|
||||||
self, function_call: AssistantFunctionCall
|
self, function_call: AssistantFunctionCall
|
||||||
) -> tuple[bool, list["ValidationError"]]:
|
) -> tuple[bool, list["ValidationError"]]:
|
||||||
@@ -415,7 +456,10 @@ class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings
|
|||||||
self,
|
self,
|
||||||
model_prompt: list[ChatMessage],
|
model_prompt: list[ChatMessage],
|
||||||
model_name: _ModelName,
|
model_name: _ModelName,
|
||||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
completion_parser: (
|
||||||
|
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||||
|
| Callable[[AssistantChatMessage], _T]
|
||||||
|
) = lambda _: None,
|
||||||
functions: Optional[list[CompletionModelFunction]] = None,
|
functions: Optional[list[CompletionModelFunction]] = None,
|
||||||
max_output_tokens: Optional[int] = None,
|
max_output_tokens: Optional[int] = None,
|
||||||
prefill_response: str = "",
|
prefill_response: str = "",
|
||||||
|
|||||||
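For orientation, the new `fmt_function_stub` renders each available command as a Python function stub so the code-flow prompt can present commands as callable functions rather than tool schemas. Below is roughly the shape of the output for a hypothetical `read_file` command with one required and one optional parameter; the exact indentation depends on forge's `indent` helper, so treat it as illustrative only.

```python
# Hypothetical command: read_file(path: str, line_limit: int | None) -> str.
# fmt_function_stub(impl="pass") is expected to render something like:
def read_file(path: str, line_limit: int = None) -> str:
    """
    Read a file and return its contents

    Params:
        path: The path of the file to read
        line_limit: Maximum number of lines to return
    """
    pass
```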
@@ -80,7 +80,7 @@ def function_specs_from_commands(
     """Get LLM-consumable function specs for the agent's available commands."""
     return [
         CompletionModelFunction(
-            name=command.names[0],
+            name=command.name,
             description=command.description,
             parameters={param.name: param.spec for param in command.parameters},
         )
@@ -1,6 +1,9 @@
+import ast
 import enum
+import typing
 from textwrap import indent
-from typing import Optional, overload
+from types import NoneType
+from typing import Any, Optional, is_typeddict, overload
 
 from jsonschema import Draft7Validator, ValidationError
 from pydantic import BaseModel
@@ -14,14 +17,17 @@ class JSONSchema(BaseModel):
         NUMBER = "number"
         INTEGER = "integer"
         BOOLEAN = "boolean"
+        TYPE = "type"
 
     # TODO: add docstrings
     description: Optional[str] = None
     type: Optional[Type] = None
     enum: Optional[list] = None
     required: bool = False
+    default: Any = None
     items: Optional["JSONSchema"] = None
     properties: Optional[dict[str, "JSONSchema"]] = None
+    additional_properties: Optional["JSONSchema"] = None
    minimum: Optional[int | float] = None
    maximum: Optional[int | float] = None
    minItems: Optional[int] = None
@@ -31,6 +37,7 @@ class JSONSchema(BaseModel):
         schema: dict = {
             "type": self.type.value if self.type else None,
             "description": self.description,
+            "default": repr(self.default),
         }
         if self.type == "array":
             if self.items:
@@ -45,6 +52,8 @@ class JSONSchema(BaseModel):
                 schema["required"] = [
                     name for name, prop in self.properties.items() if prop.required
                 ]
+            if self.additional_properties:
+                schema["additionalProperties"] = self.additional_properties.to_dict()
         elif self.enum:
             schema["enum"] = self.enum
         else:
@@ -63,11 +72,15 @@ class JSONSchema(BaseModel):
         return JSONSchema(
             description=schema.get("description"),
             type=schema["type"],
+            default=ast.literal_eval(d) if (d := schema.get("default")) else None,
             enum=schema.get("enum"),
-            items=JSONSchema.from_dict(schema["items"]) if "items" in schema else None,
+            items=JSONSchema.from_dict(i) if (i := schema.get("items")) else None,
             properties=JSONSchema.parse_properties(schema)
             if schema["type"] == "object"
             else None,
+            additional_properties=JSONSchema.from_dict(ap)
+            if schema["type"] == "object" and (ap := schema.get("additionalProperties"))
+            else None,
             minimum=schema.get("minimum"),
             maximum=schema.get("maximum"),
             minItems=schema.get("minItems"),
@@ -123,6 +136,82 @@ class JSONSchema(BaseModel):
             f"interface {interface_name} " if interface_name else ""
         ) + f"{{\n{indent(attributes_string, ' ')}\n}}"
 
+    _PYTHON_TO_JSON_TYPE: dict[typing.Type, Type] = {
+        int: Type.INTEGER,
+        str: Type.STRING,
+        bool: Type.BOOLEAN,
+        float: Type.NUMBER,
+    }
+
+    @classmethod
+    def from_python_type(cls, T: typing.Type) -> "JSONSchema":
+        if _t := cls._PYTHON_TO_JSON_TYPE.get(T):
+            partial_schema = cls(type=_t, required=True)
+        elif (
+            typing.get_origin(T) is typing.Union and typing.get_args(T)[-1] is NoneType
+        ):
+            if len(typing.get_args(T)[:-1]) > 1:
+                raise NotImplementedError("Union types are currently not supported")
+            partial_schema = cls.from_python_type(typing.get_args(T)[0])
+            partial_schema.required = False
+            return partial_schema
+        elif issubclass(T, BaseModel):
+            partial_schema = JSONSchema.from_dict(T.schema())
+        elif T is list or typing.get_origin(T) is list:
+            partial_schema = JSONSchema(
+                type=JSONSchema.Type.ARRAY,
+                items=JSONSchema.from_python_type(T_v)
+                if (T_v := typing.get_args(T)[0])
+                else None,
+            )
+        elif T is dict or typing.get_origin(T) is dict:
+            partial_schema = JSONSchema(
+                type=JSONSchema.Type.OBJECT,
+                additional_properties=JSONSchema.from_python_type(T_v)
+                if (T_v := typing.get_args(T)[1])
+                else None,
+            )
+        elif is_typeddict(T):
+            partial_schema = JSONSchema(
+                type=JSONSchema.Type.OBJECT,
+                properties={
+                    k: JSONSchema.from_python_type(v)
+                    for k, v in T.__annotations__.items()
+                },
+            )
+        else:
+            raise TypeError(f"JSONSchema.from_python_type is not implemented for {T}")
+
+        partial_schema.required = True
+        return partial_schema
+
+    _JSON_TO_PYTHON_TYPE: dict[Type, typing.Type] = {
+        j: p for p, j in _PYTHON_TO_JSON_TYPE.items()
+    }
+
+    @property
+    def python_type(self) -> str:
+        if self.type in self._JSON_TO_PYTHON_TYPE:
+            return self._JSON_TO_PYTHON_TYPE[self.type].__name__
+        elif self.type == JSONSchema.Type.ARRAY:
+            return f"list[{self.items.python_type}]" if self.items else "list"
+        elif self.type == JSONSchema.Type.OBJECT:
+            if not self.properties:
+                return "dict"
+            raise NotImplementedError(
+                "JSONSchema.python_type doesn't support TypedDicts yet"
+            )
+        elif self.enum:
+            return "Union[" + ", ".join(repr(v) for v in self.enum) + "]"
+        elif self.type == JSONSchema.Type.TYPE:
+            return "type"
+        elif self.type is None:
+            return "Any"
+        else:
+            raise NotImplementedError(
+                f"JSONSchema.python_type does not support Type.{self.type.name} yet"
+            )
+
     @property
     def typescript_type(self) -> str:
         if not self.type:
@@ -141,6 +230,10 @@ class JSONSchema(BaseModel):
             return self.to_typescript_object_interface()
         if self.enum:
             return " | ".join(repr(v) for v in self.enum)
+        elif self.type == JSONSchema.Type.TYPE:
+            return "type"
+        elif self.type is None:
+            return "any"
 
         raise NotImplementedError(
             f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"
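`JSONSchema.from_python_type` above maps Python annotations to schemas: primitives map directly, `list[...]` becomes an array with `items`, and `dict[...]` becomes an object with `additionalProperties`. The following is a standalone sketch of just those simple cases (the real method also handles `Optional[...]`, pydantic models, and TypedDicts); it is not the forge class itself.

```python
import typing

# Simplified primitive mapping, mirroring _PYTHON_TO_JSON_TYPE in spirit.
PRIMITIVES = {int: "integer", str: "string", bool: "boolean", float: "number"}


def sketch_schema(annotation: typing.Any) -> dict:
    if annotation in PRIMITIVES:
        return {"type": PRIMITIVES[annotation]}
    if typing.get_origin(annotation) is list:
        (item_t,) = typing.get_args(annotation) or (None,)
        return {"type": "array", "items": sketch_schema(item_t) if item_t else None}
    if typing.get_origin(annotation) is dict:
        _, value_t = typing.get_args(annotation) or (None, None)
        return {
            "type": "object",
            "additionalProperties": sketch_schema(value_t) if value_t else None,
        }
    raise TypeError(f"unsupported annotation: {annotation!r}")


print(sketch_schema(list[str]))       # {'type': 'array', 'items': {'type': 'string'}}
print(sketch_schema(dict[str, int]))  # {'type': 'object', 'additionalProperties': {'type': 'integer'}}
```

Together with the new `python_type` property this gives a round trip between Python annotations and schemas, which is what lets commands be rendered as typed function stubs for the code-flow prompt.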
forge/forge/utils/function/code_validation.py (new file)
@@ -0,0 +1,515 @@
import ast
import collections
import datetime
import json
import logging
import pathlib
import re
import typing

import black
import isort

from forge.utils.function.model import FunctionDef, ObjectType, ValidationResponse
from forge.utils.function.visitor import FunctionVisitor
from forge.utils.function.util import (
    genererate_line_error,
    generate_object_code,
    generate_compiled_code,
    validate_matching_function,
)
from forge.utils.function.exec import (
    exec_external_on_contents,
    ExecError,
    PROJECT_TEMP_DIR,
    DEFAULT_DEPS,
    execute_command,
    setup_if_required,
)

logger = logging.getLogger(__name__)


class CodeValidator:
    def __init__(
        self,
        function_name: str | None = None,
        available_functions: dict[str, FunctionDef] | None = None,
        available_objects: dict[str, ObjectType] | None = None,
    ):
        self.func_name: str = function_name or ""
        self.available_functions: dict[str, FunctionDef] = available_functions or {}
        self.available_objects: dict[str, ObjectType] = available_objects or {}

    async def reformat_code(
        self,
        code: str,
        packages: list[str] = [],
    ) -> str:
        """
        Reformat the code snippet
        Args:
            code (str): The code snippet to reformat
            packages (list[str]): The list of packages to validate
        Returns:
            str: The reformatted code snippet
        """
        try:
            code = (
                await self.validate_code(
                    raw_code=code,
                    packages=packages,
                    raise_validation_error=False,
                    add_code_stubs=False,
                )
            ).get_compiled_code()
        except Exception as e:
            # We move on with unfixed code if there's an error
            logger.warning(f"Error formatting code for route #{self.func_name}: {e}")
            raise e

        for formatter in [
            lambda code: isort.code(code),
            lambda code: black.format_str(code, mode=black.FileMode()),
        ]:
            try:
                code = formatter(code)
            except Exception as e:
                # We move on with unformatted code if there's an error
                logger.warning(
                    f"Error formatting code for route #{self.func_name}: {e}"
                )

        return code

    async def validate_code(
        self,
        raw_code: str,
        packages: list[str] = [],
        raise_validation_error: bool = True,
        add_code_stubs: bool = True,
        call_cnt: int = 0,
    ) -> ValidationResponse:
        """
        Validate the code snippet for any error
        Args:
            packages (list[Package]): The list of packages to validate
            raw_code (str): The code snippet to validate
        Returns:
            ValidationResponse: The response of the validation
        Raise:
            ValidationError(e): The list of validation errors in the code snippet
        """
        validation_errors: list[str] = []

        try:
            tree = ast.parse(raw_code)
            visitor = FunctionVisitor()
            visitor.visit(tree)
            validation_errors.extend(visitor.errors)
        except Exception as e:
            # parse invalid code line and add it to the error message
            error = f"Error parsing code: {e}"

            if "async lambda" in raw_code:
                error += "\nAsync lambda is not supported in Python. "
                "Use async def instead!"

            if line := re.search(r"line (\d+)", error):
                raise Exception(
                    genererate_line_error(error, raw_code, int(line.group(1)))
                )
            else:
                raise Exception(error)

        # Eliminate duplicate visitor.functions and visitor.objects, prefer the last one
        visitor.imports = list(set(visitor.imports))
        visitor.functions = list({f.name: f for f in visitor.functions}.values())
        visitor.objects = list(
            {
                o.name: o
                for o in visitor.objects
                if o.name not in self.available_objects
            }.values()
        )

        # Add implemented functions into the main function, only link the stub functions
        deps_funcs = [f for f in visitor.functions if f.is_implemented]
        stub_funcs = [f for f in visitor.functions if not f.is_implemented]

        objects_block = zip(
            ["\n\n" + generate_object_code(obj) for obj in visitor.objects],
            visitor.objectsIdx,
        )
        functions_block = zip(
            ["\n\n" + fun.function_code for fun in deps_funcs], visitor.functionsIdx
        )
        globals_block = zip(
            ["\n\n" + glob for glob in visitor.globals], visitor.globalsIdx
        )
        function_code = "".join(
            code
            for code, _ in sorted(
                list(objects_block) + list(functions_block) + list(globals_block),
                key=lambda x: x[1],
            )
        ).strip()

        # No need to validate main function if it's not provided
        if self.func_name:
            main_func = self.__validate_main_function(
                deps_funcs=deps_funcs,
                function_code=function_code,
                validation_errors=validation_errors,
            )
            function_template = main_func.function_template
        else:
            function_template = None

        # Validate that code is not re-declaring any existing entities.
        already_declared_entities = set(
            [
                obj.name
                for obj in visitor.objects
                if obj.name in self.available_objects.keys()
            ]
            + [
                func.name
                for func in visitor.functions
                if func.name in self.available_functions.keys()
            ]
        )
        if not already_declared_entities:
            validation_errors.append(
                "These class/function names has already been declared in the code, "
                "no need to declare them again: " + ", ".join(already_declared_entities)
            )

        result = ValidationResponse(
            function_name=self.func_name,
            available_objects=self.available_objects,
            available_functions=self.available_functions,
            rawCode=function_code,
            imports=visitor.imports.copy(),
            objects=[],  # Objects will be bundled in the function_code instead.
            template=function_template or "",
            functionCode=function_code,
            functions=stub_funcs,
            packages=packages,
        )

        # Execute static validators and fixers.
        # print('old compiled code import ---->', result.imports)
        old_compiled_code = generate_compiled_code(result, add_code_stubs)
        validation_errors.extend(await static_code_analysis(result))
        new_compiled_code = result.get_compiled_code()

        # Auto-fixer works, retry validation (limit to 5 times, to avoid infinite loop)
        if old_compiled_code != new_compiled_code and call_cnt < 5:
            return await self.validate_code(
                packages=packages,
                raw_code=new_compiled_code,
                raise_validation_error=raise_validation_error,
                add_code_stubs=add_code_stubs,
                call_cnt=call_cnt + 1,
            )

        if validation_errors:
            if raise_validation_error:
                error_message = "".join("\n * " + e for e in validation_errors)
                raise Exception("Error validating code: " + error_message)
            else:
                # This should happen only on `reformat_code` call
                logger.warning("Error validating code: %s", validation_errors)

        return result

    def __validate_main_function(
        self,
        deps_funcs: list[FunctionDef],
        function_code: str,
        validation_errors: list[str],
    ) -> FunctionDef:
        """
        Validate the main function body and signature
        Returns:
            tuple[str, FunctionDef]: The function ID and the function object
        """
        # Validate that the main function is implemented.
        func_obj = next((f for f in deps_funcs if f.name == self.func_name), None)
        if not func_obj or not func_obj.is_implemented:
            raise Exception(
                f"Main Function body {self.func_name} is not implemented."
                f" Please complete the implementation of this function!"
            )
        func_obj.function_code = function_code

        # Validate that the main function is matching the expected signature.
        func_req: FunctionDef | None = self.available_functions.get(self.func_name)
        if not func_req:
            raise AssertionError(f"Function {self.func_name} does not exist on DB")
        try:
            validate_matching_function(func_obj, func_req)
        except Exception as e:
            validation_errors.append(e.__str__())

        return func_obj


# ======= Static Code Validation Helper Functions =======#


async def static_code_analysis(func: ValidationResponse) -> list[str]:
    """
    Run static code analysis on the function code and mutate the function code to
    fix any issues.
    Args:
        func (ValidationResponse):
            The function to run static code analysis on. `func` will be mutated.
    Returns:
        list[str]: The list of validation errors
    """
    validation_errors = []
    validation_errors += await __execute_ruff(func)
    validation_errors += await __execute_pyright(func)

    return validation_errors


CODE_SEPARATOR = "#------Code-Start------#"


def __pack_import_and_function_code(func: ValidationResponse) -> str:
    return "\n".join(func.imports + [CODE_SEPARATOR, func.rawCode])


def __unpack_import_and_function_code(code: str) -> tuple[list[str], str]:
    split = code.split(CODE_SEPARATOR)
    return split[0].splitlines(), split[1].strip()


async def __execute_ruff(func: ValidationResponse) -> list[str]:
    code = __pack_import_and_function_code(func)

    try:
        # Currently Disabled Rule List
        # E402 module level import not at top of file
        # F841 local variable is assigned to but never used
        code = await exec_external_on_contents(
            command_arguments=[
                "ruff",
                "check",
                "--fix",
                "--ignore",
                "F841",
                "--ignore",
                "E402",
                "--ignore",
                "F811",  # Redefinition of unused '...' from line ...
            ],
            file_contents=code,
            suffix=".py",
            raise_file_contents_on_error=True,
        )
        func.imports, func.rawCode = __unpack_import_and_function_code(code)
        return []

    except ExecError as e:
        if e.content:
            # Ruff failed, but the code is reformatted
            code = e.content
        e = str(e)

        error_messages = [
            v
            for v in str(e).split("\n")
            if v.strip()
            if re.match(r"Found \d+ errors?\.*", v) is None
        ]

        added_imports, error_messages = await __fix_missing_imports(
            error_messages, func
        )

        # Append problematic line to the error message or add it as TODO line
        validation_errors: list[str] = []
        split_pattern = r"(.+):(\d+):(\d+): (.+)"
        for error_message in error_messages:
            error_split = re.match(split_pattern, error_message)

            if not error_split:
                error = error_message
            else:
                _, line, _, error = error_split.groups()
                error = genererate_line_error(error, code, int(line))

            validation_errors.append(error)

        func.imports, func.rawCode = __unpack_import_and_function_code(code)
        func.imports.extend(added_imports)  # Avoid line-code change, do it at the end.

        return validation_errors


async def __execute_pyright(func: ValidationResponse) -> list[str]:
    code = __pack_import_and_function_code(func)
    validation_errors: list[str] = []

    # Create temporary directory under the TEMP_DIR with random name
    temp_dir = PROJECT_TEMP_DIR / (func.function_name)
    py_path = await setup_if_required(temp_dir)

    async def __execute_pyright_commands(code: str) -> list[str]:
        try:
            await execute_command(
                ["pip", "install", "-r", "requirements.txt"], temp_dir, py_path
            )
        except Exception as e:
            # Unknown deps should be reported as validation errors
            validation_errors.append(e.__str__())

        # execute pyright
        result = await execute_command(
            ["pyright", "--outputjson"], temp_dir, py_path, raise_on_error=False
        )
        if not result:
            return []

        try:
            json_response = json.loads(result)["generalDiagnostics"]
        except Exception as e:
            logger.error(f"Error parsing pyright output, error: {e} output: {result}")
            raise e

        for e in json_response:
            rule: str = e.get("rule", "")
            severity: str = e.get("severity", "")
            excluded_rules = ["reportRedeclaration"]
            if severity != "error" or any([rule.startswith(r) for r in excluded_rules]):
                continue

            e = genererate_line_error(
                error=f"{e['message']}. {e.get('rule', '')}",
                code=code,
                line_number=e["range"]["start"]["line"] + 1,
            )
            validation_errors.append(e)

        # read code from code.py. split the code into imports and raw code
        code = open(f"{temp_dir}/code.py").read()
        code, error_messages = await __fix_async_calls(code, validation_errors)
        func.imports, func.rawCode = __unpack_import_and_function_code(code)

        return validation_errors

    packages = "\n".join([str(p) for p in func.packages if p not in DEFAULT_DEPS])
    (temp_dir / "requirements.txt").write_text(packages)
    (temp_dir / "code.py").write_text(code)

    return await __execute_pyright_commands(code)


async def find_module_dist_and_source(
    module: str, py_path: pathlib.Path | str
) -> typing.Tuple[pathlib.Path | None, pathlib.Path | None]:
    # Find the module in the env
    modules_path = pathlib.Path(py_path).parent / "lib" / "python3.11" / "site-packages"
    matches = modules_path.glob(f"{module}*")

    # resolve the generator to an array
    matches = list(matches)
    if not matches:
        return None, None

    # find the dist info path and the module path
    dist_info_path: typing.Optional[pathlib.Path] = None
    module_path: typing.Optional[pathlib.Path] = None

    # find the dist info path
    for match in matches:
        if re.match(f"{module}-[0-9]+.[0-9]+.[0-9]+.dist-info", match.name):
            dist_info_path = match
            break
    # Get the module path
    for match in matches:
        if module == match.name:
            module_path = match
            break

    return dist_info_path, module_path


AUTO_IMPORT_TYPES: dict[str, str] = {
    "Enum": "from enum import Enum",
    "array": "from array import array",
}
for t in typing.__all__:
    AUTO_IMPORT_TYPES[t] = f"from typing import {t}"
for t in datetime.__all__:
    AUTO_IMPORT_TYPES[t] = f"from datetime import {t}"
for t in collections.__all__:
    AUTO_IMPORT_TYPES[t] = f"from collections import {t}"


async def __fix_async_calls(code: str, errors: list[str]) -> tuple[str, list[str]]:
    """
    Fix the async calls in the code
    Args:
        code (str): The code snippet
        errors (list[str]): The list of errors
        func (ValidationResponse): The function to fix the async calls
    Returns:
        tuple[str, list[str]]: The fixed code snippet and the list of errors
    """
    async_calls = set()
    new_errors = []
    for error in errors:
        pattern = '"__await__" is not present. reportGeneralTypeIssues -> (.+)'
        match = re.search(pattern, error)
        if match:
            async_calls.add(match.group(1))
        else:
            new_errors.append(error)

    for async_call in async_calls:
        func_call = re.search(r"await ([a-zA-Z0-9_]+)", async_call)
        if func_call:
            func_name = func_call.group(1)
            code = code.replace(f"await {func_name}", f"{func_name}")

    return code, new_errors


async def __fix_missing_imports(
    errors: list[str], func: ValidationResponse
) -> tuple[set[str], list[str]]:
    """
    Generate missing imports based on the errors
    Args:
        errors (list[str]): The list of errors
        func (ValidationResponse): The function to fix the imports
    Returns:
        tuple[set[str], list[str]]: The set of missing imports and the list
            of non-missing import errors
    """
    missing_imports = []
    filtered_errors = []
    for error in errors:
        pattern = r"Undefined name `(.+?)`"
        match = re.search(pattern, error)
        if not match:
            filtered_errors.append(error)
            continue

        missing = match.group(1)
        if missing in AUTO_IMPORT_TYPES:
            missing_imports.append(AUTO_IMPORT_TYPES[missing])
        elif missing in func.available_functions:
            # TODO FIX THIS!! IMPORT AUTOGPT CORRECY SERVICE.
            missing_imports.append(f"from project.{missing}_service import {missing}")
        elif missing in func.available_objects:
            # TODO FIX THIS!! IMPORT AUTOGPT CORRECY SERVICE.
            missing_imports.append(f"from project.{missing}_object import {missing}")
        else:
            filtered_errors.append(error)

    return set(missing_imports), filtered_errors
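The validators above hand the generated code to external tools (ruff, pyright) as a single packed file: imports and the function body are joined around `CODE_SEPARATOR` and split apart again afterwards. A standalone sketch of that round trip, independent of the forge classes:

```python
# Same separator value as in code_validation.py above.
CODE_SEPARATOR = "#------Code-Start------#"

imports = ["import re", "import json"]
raw_code = "def main():\n    return json.dumps({'ok': True})"

# Pack: one file the external tool can lint/typecheck as a whole.
packed = "\n".join(imports + [CODE_SEPARATOR, raw_code])

# Unpack: split the (possibly auto-fixed) file back into imports and body.
head, _, body = packed.partition(CODE_SEPARATOR)
assert head.splitlines() == imports
assert body.strip() == raw_code
```

Keeping imports and body separable is what lets `__fix_missing_imports` append auto-generated import lines without disturbing the line numbers referenced in the tool diagnostics.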
forge/forge/utils/function/exec.py (new file)
@@ -0,0 +1,195 @@
import asyncio
import enum
import logging
import os
import subprocess
import tempfile
from asyncio.subprocess import Process
from pathlib import Path

logger = logging.getLogger(__name__)


class OutputType(enum.Enum):
    STD_OUT = "stdout"
    STD_ERR = "stderr"
    BOTH = "both"


class ExecError(Exception):
    content: str | None

    def __init__(self, error: str, content: str | None = None):
        super().__init__(error)
        self.content = content


async def exec_external_on_contents(
    command_arguments: list[str],
    file_contents,
    suffix: str = ".py",
    output_type: OutputType = OutputType.BOTH,
    raise_file_contents_on_error: bool = False,
) -> str:
    """
    Execute an external tool with the provided command arguments and file contents
    :param command_arguments: The command arguments to execute
    :param file_contents: The file contents to execute the command on
    :param suffix: The suffix of the temporary file. Default is ".py"
    :return: The file contents after the command has been executed

    Note: The file contents are written to a temporary file and the command is executed
    on that file. The command arguments should be a list of strings, where the first
    element is the command to execute and the rest of the elements are the arguments to
    the command. There is no need to provide the file path as an argument, as it will
    be appended to the command arguments.

    Example:
        exec_external(["ruff", "check"], "print('Hello World')")
        will run the command "ruff check <temp_file_path>" with the file contents
        "print('Hello World')" and return the file contents after the command
        has been executed.
    """
    errors = ""
    if len(command_arguments) == 0:
        raise ExecError("No command arguments provided")

    # Run ruff to validate the code
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file_path = temp_file.name
        temp_file.write(file_contents.encode("utf-8"))
        temp_file.flush()

        command_arguments.append(str(temp_file_path))

        # Run Ruff on the temporary file
        try:
            r: Process = await asyncio.create_subprocess_exec(
                *command_arguments,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            result = await r.communicate()
            stdout, stderr = result[0].decode("utf-8"), result[1].decode("utf-8")
            logger.debug(f"Output: {stdout}")
            if temp_file_path in stdout:
                stdout = stdout  # .replace(temp_file.name, "/generated_file")
            logger.debug(f"Errors: {stderr}")
            if output_type == OutputType.STD_OUT:
                errors = stdout
            elif output_type == OutputType.STD_ERR:
                errors = stderr
            else:
                errors = stdout + "\n" + stderr

            with open(temp_file_path, "r") as f:
                file_contents = f.read()
        finally:
            # Ensure the temporary file is deleted
            os.remove(temp_file_path)

    if not errors:
        return file_contents

    if raise_file_contents_on_error:
        raise ExecError(errors, file_contents)

    raise ExecError(errors)


FOLDER_NAME = "agpt-static-code-analysis"
PROJECT_PARENT_DIR = Path(__file__).resolve().parent.parent.parent / f".{FOLDER_NAME}"
PROJECT_TEMP_DIR = Path(tempfile.gettempdir()) / FOLDER_NAME
DEFAULT_DEPS = ["pyright", "pydantic", "virtualenv-clone"]


def is_env_exists(path: Path):
    return (
        (path / "venv/bin/python").exists()
        and (path / "venv/bin/pip").exists()
        and (path / "venv/bin/virtualenv-clone").exists()
        and (path / "venv/bin/pyright").exists()
    )


async def setup_if_required(
    cwd: Path = PROJECT_PARENT_DIR, copy_from_parent: bool = True
) -> Path:
    """
    Set-up the virtual environment if it does not exist
    This setup is executed expectedly once per application run
    Args:
        cwd (Path): The current working directory
        copy_from_parent (bool):
            Whether to copy the virtual environment from PROJECT_PARENT_DIR
    Returns:
        Path: The path to the virtual environment
    """
    if not cwd.exists():
        cwd.mkdir(parents=True, exist_ok=True)

    path = cwd / "venv/bin"
    if is_env_exists(cwd):
        return path

    if copy_from_parent and cwd != PROJECT_PARENT_DIR:
        if (cwd / "venv").exists():
            await execute_command(["rm", "-rf", str(cwd / "venv")], cwd, None)
        await execute_command(
            ["virtualenv-clone", str(PROJECT_PARENT_DIR / "venv"), str(cwd / "venv")],
            cwd,
            await setup_if_required(PROJECT_PARENT_DIR),
        )
        return path

    # Create a virtual environment
    output = await execute_command(["python", "-m", "venv", "venv"], cwd, None)
    logger.info(f"[Setup] Created virtual environment: {output}")

    # Install dependencies
    output = await execute_command(["pip", "install", "-I"] + DEFAULT_DEPS, cwd, path)
    logger.info(f"[Setup] Installed {DEFAULT_DEPS}: {output}")

    output = await execute_command(["pyright"], cwd, path, raise_on_error=False)
    logger.info(f"[Setup] Set up pyright: {output}")

    return path


async def execute_command(
    command: list[str],
    cwd: str | Path | None,
    python_path: str | Path | None = None,
    raise_on_error: bool = True,
) -> str:
    """
    Execute a command in the shell
    Args:
        command (list[str]): The command to execute
        cwd (str | Path): The current working directory
        python_path (str | Path): The python executable path
        raise_on_error (bool): Whether to raise an error if the command fails
    Returns:
        str: The output of the command
    """
    # Set the python path by replacing the env 'PATH' with the provided python path
    venv = os.environ.copy()
    if python_path:
        # PATH prioritize first occurrence of python_path, so we need to prepend.
        venv["PATH"] = f"{python_path}:{venv['PATH']}"
    r = await asyncio.create_subprocess_exec(
        *command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=str(cwd),
        env=venv,
    )
    stdout, stderr = await r.communicate()
    if r.returncode == 0:
        return (stdout or stderr).decode("utf-8")

    if raise_on_error:
        raise Exception((stderr or stdout).decode("utf-8"))
    else:
        return (stderr or stdout).decode("utf-8")
forge/forge/utils/function/model.py (new file)
@@ -0,0 +1,110 @@
from typing import List, Optional

from pydantic import BaseModel, Field


class ObjectType(BaseModel):
    name: str = Field(description="The name of the object")
    code: Optional[str] = Field(description="The code of the object", default=None)
    description: Optional[str] = Field(
        description="The description of the object", default=None
    )
    Fields: List["ObjectField"] = Field(description="The fields of the object")
    is_pydantic: bool = Field(
        description="Whether the object is a pydantic model", default=True
    )
    is_implemented: bool = Field(
        description="Whether the object is implemented", default=True
    )
    is_enum: bool = Field(description="Whether the object is an enum", default=False)


class ObjectField(BaseModel):
    name: str = Field(description="The name of the field")
    description: Optional[str] = Field(
        description="The description of the field", default=None
    )
    type: str = Field(
        description="The type of the field. Can be a string like List[str] or an use "
        "any of they related types like list[User]",
    )
    value: Optional[str] = Field(description="The value of the field", default=None)
    related_types: Optional[List[ObjectType]] = Field(
        description="The related types of the field", default=[]
    )


class FunctionDef(BaseModel):
    name: str
    arg_types: list[tuple[str, str]]
    arg_defaults: dict[str, str] = {}
    arg_descs: dict[str, str]
    return_type: str | None = None
    return_desc: str
    function_desc: str
    is_implemented: bool = False
    function_code: str = ""
    function_template: str | None = None
    is_async: bool = False

    def __generate_function_template(self) -> str:
        args_str = ", ".join(
            [
                f"{name}: {type}"
                + (
                    f" = {self.arg_defaults.get(name, '')}"
                    if name in self.arg_defaults
                    else ""
                )
                for name, type in self.arg_types
            ]
        )
        arg_desc = f"\n{' '*4}".join(
            [
                f'{name} ({type}): {self.arg_descs.get(name, "-")}'
                for name, type in self.arg_types
            ]
        )

        _def = "async def" if "await " in self.function_code or self.is_async else "def"
        _return_type = f" -> {self.return_type}" if self.return_type else ""
        func_desc = self.function_desc.replace("\n", "\n    ")

        template = f"""
{_def} {self.name}({args_str}){_return_type}:
    \"\"\"
    {func_desc}

    Args:
        {arg_desc}

    Returns:
        {self.return_type}{': ' + self.return_desc if self.return_desc else ''}
    \"\"\"
    pass
"""
        return "\n".join([line for line in template.split("\n")]).strip()

    def __init__(self, function_template: Optional[str] = None, **data):
        super().__init__(**data)
        self.function_template = (
            function_template or self.__generate_function_template()
        )


class ValidationResponse(BaseModel):
    function_name: str
    available_objects: dict[str, ObjectType]
    available_functions: dict[str, FunctionDef]

    template: str
    rawCode: str
    packages: List[str]
    imports: List[str]
    functionCode: str

    functions: List[FunctionDef]
    objects: List[ObjectType]

    def get_compiled_code(self) -> str:
        return "\n".join(self.imports) + "\n\n" + self.rawCode.strip()
292
forge/forge/utils/function/util.py
Normal file
292
forge/forge/utils/function/util.py
Normal file
@@ -0,0 +1,292 @@
from typing import List, Tuple, __all__ as all_types
from forge.utils.function.model import FunctionDef, ObjectType, ValidationResponse

OPEN_BRACES = {"{": "Dict", "[": "List", "(": "Tuple"}
CLOSE_BRACES = {"}": "Dict", "]": "List", ")": "Tuple"}

RENAMED_TYPES = {
    "dict": "Dict",
    "list": "List",
    "tuple": "Tuple",
    "set": "Set",
    "frozenset": "FrozenSet",
    "type": "Type",
}
PYTHON_TYPES = set(all_types)


def unwrap_object_type(type: str) -> Tuple[str, List[str]]:
    """
    Get the type and children of a composite type.
    Args:
        type (str): The type to parse.
    Returns:
        str: The type.
        [str]: The children types.
    """
    type = type.replace(" ", "")
    if not type:
        return "", []

    def split_outer_level(type: str, separator: str) -> List[str]:
        brace_count = 0
        last_index = 0
        splits = []

        for i, c in enumerate(type):
            if c in OPEN_BRACES:
                brace_count += 1
            elif c in CLOSE_BRACES:
                brace_count -= 1
            elif c == separator and brace_count == 0:
                splits.append(type[last_index:i])
                last_index = i + 1

        splits.append(type[last_index:])
        return splits

    # Unwrap primitive union types
    union_split = split_outer_level(type, "|")
    if len(union_split) > 1:
        if len(union_split) == 2 and "None" in union_split:
            return "Optional", [v for v in union_split if v != "None"]
        return "Union", union_split

    # Unwrap primitive dict/list/tuple types
    if type[0] in OPEN_BRACES and type[-1] in CLOSE_BRACES:
        type_name = OPEN_BRACES[type[0]]
        type_children = split_outer_level(type[1:-1], ",")
        return type_name, type_children

    brace_pos = type.find("[")
    if brace_pos != -1 and type[-1] == "]":
        # Unwrap normal composite types
        type_name = type[:brace_pos]
        type_children = split_outer_level(type[brace_pos + 1 : -1], ",")
    else:
        # Non-composite types, no need to unwrap
        type_name = type
        type_children = []

    return RENAMED_TYPES.get(type_name, type_name), type_children
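
A quick illustration (not from the diff) of what unwrap_object_type returns for a few inputs, assuming the module is importable as forge.utils.function.util as in the imports above:

# Illustrative sketch of unwrap_object_type's outer-type / children split.
from forge.utils.function.util import unwrap_object_type

print(unwrap_object_type("dict[str, int | float]"))  # ('Dict', ['str', 'int|float'])
print(unwrap_object_type("str | None"))              # ('Optional', ['str'])
print(unwrap_object_type("User"))                    # ('User', [])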


def is_type_equal(type1: str | None, type2: str | None) -> bool:
    """
    Check if two types are equal.
    This function handles composite types like list, dict, and tuple, and it
    groups similar types like list[str], List[str], and [str] as equal.
    """
    if type1 is None and type2 is None:
        return True
    if type1 is None or type2 is None:
        return False

    evaluated_type1, children1 = unwrap_object_type(type1)
    evaluated_type2, children2 = unwrap_object_type(type2)

    # Compare the class name of the types (ignoring the module)
    # TODO(majdyz): compare the module name as well.
    t_len = min(len(evaluated_type1), len(evaluated_type2))
    if evaluated_type1.split(".")[-t_len:] != evaluated_type2.split(".")[-t_len:]:
        return False

    if len(children1) != len(children2):
        return False

    if len(children1) == len(children2) == 0:
        return True

    for c1, c2 in zip(children1, children2):
        if not is_type_equal(c1, c2):
            return False

    return True
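
As the docstring says, spelling differences between builtin and typing forms are ignored; a small sketch (same import assumption as above):

# Illustrative sketch of is_type_equal's normalization behaviour.
from forge.utils.function.util import is_type_equal

print(is_type_equal("list[str]", "List[str]"))            # True  (renamed builtins match)
print(is_type_equal("dict[str, int]", "Dict[str, str]"))  # False (child types differ)
print(is_type_equal(None, None))                          # True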


def validate_matching_function(this: FunctionDef, that: FunctionDef):
    expected_args = that.arg_types
    expected_rets = that.return_type
    func_name = that.name
    errors = []

    # Fix the async flag based on the expectation.
    if this.is_async != that.is_async:
        this.is_async = that.is_async
        if this.is_async and f"async def {this.name}" not in this.function_code:
            this.function_code = this.function_code.replace(
                f"def {this.name}", f"async def {this.name}"
            )
        if not this.is_async and f"async def {this.name}" in this.function_code:
            this.function_code = this.function_code.replace(
                f"async def {this.name}", f"def {this.name}"
            )

    if any(
        [
            x[0] != y[0] or not is_type_equal(x[1], y[1]) and x[1] != "object"
            # TODO: remove sorted and provide a stable order for one-to-many arg-types.
            for x, y in zip(sorted(expected_args), sorted(this.arg_types))
        ]
    ):
        errors.append(
            f"Function {func_name} has different arguments than expected, "
            f"expected {expected_args} but got {this.arg_types}"
        )
    if not is_type_equal(expected_rets, this.return_type) and expected_rets != "object":
        errors.append(
            f"Function {func_name} has different return type than expected, expected "
            f"{expected_rets} but got {this.return_type}"
        )

    if errors:
        raise Exception("Signature validation errors:\n  " + "\n  ".join(errors))


def normalize_type(type: str, renamed_types: dict[str, str] = {}) -> str:
    """
    Normalize the type to a standard format.
    e.g. list[str] -> List[str], dict[str, int | float] -> Dict[str, Union[int, float]]

    Args:
        type (str): The type to normalize.
    Returns:
        str: The normalized type.
    """
    parent_type, children = unwrap_object_type(type)

    if parent_type in renamed_types:
        parent_type = renamed_types[parent_type]

    if len(children) == 0:
        return parent_type

    content_type = ", ".join([normalize_type(c, renamed_types) for c in children])
    return f"{parent_type}[{content_type}]"
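
The docstring's example, spelled out as a quick sketch (same import assumption as above):

# Illustrative sketch of normalize_type.
from forge.utils.function.util import normalize_type

print(normalize_type("list[str]"))               # List[str]
print(normalize_type("dict[str, int | float]"))  # Dict[str, Union[int, float]]
# Caller-supplied renames are applied to each unwrapped type name.
print(normalize_type("user", renamed_types={"user": "User"}))  # User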


def generate_object_code(obj: ObjectType) -> str:
    if not obj.name:
        return ""  # Avoid generating an empty object

    # Auto-generate a template for the object, this will not capture any class functions
    fields = f"\n{' ' * 4}".join(
        [
            f"{field.name}: {field.type} "
            f"{('= '+field.value) if field.value else ''} "
            f"{('# '+field.description) if field.description else ''}"
            for field in obj.Fields or []
        ]
    )

    parent_class = ""
    if obj.is_enum:
        parent_class = "Enum"
    elif obj.is_pydantic:
        parent_class = "BaseModel"

    doc_string = (
        f"""\"\"\"
    {obj.description}
    \"\"\""""
        if obj.description
        else ""
    )

    method_body = ("\n" + " " * 4).join(obj.code.split("\n")) + "\n" if obj.code else ""

    template = f"""
class {obj.name}({parent_class}):
    {doc_string if doc_string else ""}
    {fields if fields else ""}
    {method_body if method_body else ""}
    {"pass" if not fields and not method_body else ""}
"""
    return "\n".join(line for line in template.split("\n")).strip()
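
A sketch of the class source this produces. The keyword arguments for ObjectType/ObjectField below are inferred from how FunctionVisitor constructs them further down; the exact model signatures live in forge/forge/utils/function/model.py, so treat this as illustrative rather than definitive.

# Illustrative sketch; field/keyword names inferred from visitor.py below.
from forge.utils.function.model import ObjectField, ObjectType
from forge.utils.function.util import generate_object_code

user = ObjectType(
    name="User",
    description="A registered user.",
    code="",
    Fields=[
        ObjectField(name="id", type="int", description="primary key"),
        ObjectField(name="email", type="Optional[str]", value="None"),
    ],
    is_pydantic=True,
    is_enum=False,
    is_implemented=True,
)
print(generate_object_code(user))
# Roughly:
# class User(BaseModel):
#     """
#     A registered user.
#     """
#     id: int  # primary key
#     email: Optional[str] = None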


def genererate_line_error(error: str, code: str, line_number: int) -> str:
    lines = code.split("\n")
    if line_number > len(lines):
        return error

    code_line = lines[line_number - 1]
    return f"{error} -> '{code_line.strip()}'"
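
For completeness, the error-decoration helper in use (a sketch, same import assumption as above):

# Illustrative sketch: attach the offending source line to an error message.
from forge.utils.function.util import genererate_line_error

code = "x = 1\nreturn x +"
print(genererate_line_error("SyntaxError: invalid syntax", code, 2))
# SyntaxError: invalid syntax -> 'return x +'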


def generate_compiled_code(
    resp: ValidationResponse, add_code_stubs: bool = True
) -> str:
    """
    Regenerate imports & raw code using the available objects and functions.
    """
    resp.imports = sorted(set(resp.imports))

    def __append_comment(code_block: str, comment: str) -> str:
        """
        Append `# <comment>` (e.g. `# noqa`) to the first line of the code block.
        This is used to suppress flake8 / type-checker warnings for redefined names.
        """
        lines = code_block.split("\n")
        lines[0] = lines[0] + "  # " + comment
        return "\n".join(lines)

    def __generate_stub(name, is_enum):
        if not name:
            return ""
        elif is_enum:
            return f"class {name}(Enum):\n    pass"
        else:
            return f"class {name}(BaseModel):\n    pass"

    stub_objects = resp.available_objects if add_code_stubs else {}
    stub_functions = resp.available_functions if add_code_stubs else {}

    object_stubs_code = "\n\n".join(
        [
            __append_comment(__generate_stub(obj.name, obj.is_enum), "type: ignore")
            for obj in stub_objects.values()
        ]
        + [
            __append_comment(__generate_stub(obj.name, obj.is_enum), "type: ignore")
            for obj in resp.objects
            if obj.name not in stub_objects
        ]
    )

    objects_code = "\n\n".join(
        [
            __append_comment(generate_object_code(obj), "noqa")
            for obj in stub_objects.values()
        ]
        + [
            __append_comment(generate_object_code(obj), "noqa")
            for obj in resp.objects
            if obj.name not in stub_objects
        ]
    )

    functions_code = "\n\n".join(
        [
            __append_comment(f.function_template.strip(), "type: ignore")
            for f in stub_functions.values()
            if f.name != resp.function_name and f.function_template
        ]
        + [
            __append_comment(f.function_template.strip(), "type: ignore")
            for f in resp.functions
            if f.name not in stub_functions and f.function_template
        ]
    )

    resp.rawCode = (
        object_stubs_code.strip()
        + "\n\n"
        + objects_code.strip()
        + "\n\n"
        + functions_code.strip()
        + "\n\n"
        + resp.functionCode.strip()
    )

    return resp.get_compiled_code()
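
A minimal end-to-end sketch of generate_compiled_code (not part of the commit), using only the fields declared on ValidationResponse above and the import paths implied by this diff:

# Illustrative sketch: compile a single validated function with no stubs.
from forge.utils.function.model import ValidationResponse
from forge.utils.function.util import generate_compiled_code

resp = ValidationResponse(
    function_name="add",
    available_objects={},
    available_functions={},
    template="",
    rawCode="",
    packages=[],
    imports=["from pydantic import BaseModel"],
    functionCode="def add(a: int, b: int) -> int:\n    return a + b",
    functions=[],
    objects=[],
)
# With no stub objects/functions, the compiled module is just the sorted
# imports followed by the validated function body.
print(generate_compiled_code(resp))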
222  forge/forge/utils/function/visitor.py  Normal file
@@ -0,0 +1,222 @@
import ast
import re

from forge.utils.function.model import FunctionDef, ObjectType, ObjectField
from forge.utils.function.util import normalize_type, PYTHON_TYPES


class FunctionVisitor(ast.NodeVisitor):
    """
    Visits a Python AST and extracts function definitions and Pydantic class definitions.

    To use this class, create an instance and call the visit method with the AST
    as the argument. The extracted function definitions and Pydantic class
    definitions can be accessed from the functions and objects attributes respectively.

    Example:
    ```
    visitor = FunctionVisitor()
    visitor.visit(ast.parse("def foo(x: int) -> int: return x"))
    print(visitor.functions)
    ```
    """

    def __init__(self):
        self.functions: list[FunctionDef] = []
        self.functionsIdx: list[int] = []
        self.objects: list[ObjectType] = []
        self.objectsIdx: list[int] = []
        self.globals: list[str] = []
        self.globalsIdx: list[int] = []
        self.imports: list[str] = []
        self.errors: list[str] = []

    def visit_Import(self, node):
        for alias in node.names:
            import_line = f"import {alias.name}"
            if alias.asname:
                import_line += f" as {alias.asname}"
            self.imports.append(import_line)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        for alias in node.names:
            import_line = f"from {node.module} import {alias.name}"
            if alias.asname:
                import_line += f" as {alias.asname}"
            self.imports.append(import_line)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        # treat async functions as normal functions
        self.visit_FunctionDef(node)  # type: ignore

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        args = []
        for arg in node.args.args:
            arg_type = ast.unparse(arg.annotation) if arg.annotation else "object"
            args.append((arg.arg, normalize_type(arg_type)))
        return_type = (
            normalize_type(ast.unparse(node.returns)) if node.returns else None
        )

        # Extract doc_string & function body
        if (
            node.body
            and isinstance(node.body[0], ast.Expr)
            and isinstance(node.body[0].value, ast.Constant)
        ):
            doc_string = node.body[0].value.s.strip()
            template_body = [node.body[0], ast.Pass()]
            is_implemented = not isinstance(node.body[1], ast.Pass)
        else:
            doc_string = ""
            template_body = [ast.Pass()]
            is_implemented = not isinstance(node.body[0], ast.Pass)

        # Construct function template
        original_body = node.body.copy()
        node.body = template_body  # type: ignore
        function_template = ast.unparse(node)
        node.body = original_body

        function_code = ast.unparse(node)
        if "await" in function_code and "async def" not in function_code:
            function_code = function_code.replace("def ", "async def ")
            function_template = function_template.replace("def ", "async def ")

        def split_doc(keywords: list[str], doc: str) -> tuple[str, str]:
            for keyword in keywords:
                if match := re.search(f"{keyword}\\s?:", doc):
                    return doc[: match.start()], doc[match.end() :]
            return doc, ""

        # Decompose doc_pattern into func_doc, args_doc, rets_doc, errs_doc, usage_doc
        # by splitting in reverse order
        func_doc = doc_string
        func_doc, usage_doc = split_doc(
            ["Ex", "Usage", "Usages", "Example", "Examples"], func_doc
        )
        func_doc, errs_doc = split_doc(["Error", "Errors", "Raise", "Raises"], func_doc)
        func_doc, rets_doc = split_doc(["Return", "Returns"], func_doc)
        func_doc, args_doc = split_doc(
            ["Arg", "Args", "Argument", "Arguments"], func_doc
        )

        # Extract Func
        function_desc = func_doc.strip()

        # Extract Args
        args_descs = {}
        split_pattern = r"\n(\s+.+):"
        for match in reversed(list(re.finditer(split_pattern, string=args_doc))):
            arg = match.group(1).strip().split(" ")[0]
            desc = args_doc.rsplit(match.group(1), 1)[1].strip(": ")
            args_descs[arg] = desc.strip()
            args_doc = args_doc[: match.start()]

        # Extract Returns
        return_desc = ""
        if match := re.match(split_pattern, string=rets_doc):
            return_desc = rets_doc[match.end() :].strip()

        self.functions.append(
            FunctionDef(
                name=node.name,
                arg_types=args,
                arg_descs=args_descs,
                return_type=return_type,
                return_desc=return_desc,
                is_implemented=is_implemented,
                function_desc=function_desc,
                function_template=function_template,
                function_code=function_code,
            )
        )
        self.functionsIdx.append(node.lineno)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """
        Visits a ClassDef node in the AST and checks if it is a Pydantic class.
        If it is a Pydantic class, adds its name to the list of Pydantic classes.
        """
        is_pydantic = any(
            [
                (isinstance(base, ast.Name) and base.id == "BaseModel")
                or (isinstance(base, ast.Attribute) and base.attr == "BaseModel")
                for base in node.bases
            ]
        )
        is_enum = any(
            [
                (isinstance(base, ast.Name) and base.id.endswith("Enum"))
                or (isinstance(base, ast.Attribute) and base.attr.endswith("Enum"))
                for base in node.bases
            ]
        )
        is_implemented = not any(isinstance(v, ast.Pass) for v in node.body)
        doc_string = ""
        if (
            node.body
            and isinstance(node.body[0], ast.Expr)
            and isinstance(node.body[0].value, ast.Constant)
        ):
            doc_string = node.body[0].value.s.strip()

        if node.name in PYTHON_TYPES:
            self.errors.append(
                f"Can't declare class with a Python built-in name "
                f"`{node.name}`. Please use a different name."
            )

        fields = []
        methods = []
        for v in node.body:
            if isinstance(v, ast.AnnAssign):
                field = ObjectField(
                    name=ast.unparse(v.target),
                    type=normalize_type(ast.unparse(v.annotation)),
                    value=ast.unparse(v.value) if v.value else None,
                )
                if field.value is None and field.type.startswith("Optional"):
                    field.value = "None"
            elif isinstance(v, ast.Assign):
                if len(v.targets) > 1:
                    self.errors.append(
                        f"Class {node.name} has multiple assignments in a single line."
                    )
                field = ObjectField(
                    name=ast.unparse(v.targets[0]),
                    type=type(ast.unparse(v.value)).__name__,
                    value=ast.unparse(v.value) if v.value else None,
                )
            elif isinstance(v, ast.Expr) and isinstance(v.value, ast.Constant):
                # skip comments and docstrings
                continue
            else:
                methods.append(ast.unparse(v))
                continue
            fields.append(field)

        self.objects.append(
            ObjectType(
                name=node.name,
                code="\n".join(methods),
                description=doc_string,
                Fields=fields,
                is_pydantic=is_pydantic,
                is_enum=is_enum,
                is_implemented=is_implemented,
            )
        )
        self.objectsIdx.append(node.lineno)

    def visit(self, node):
        if (
            isinstance(node, ast.Assign)
            or isinstance(node, ast.AnnAssign)
            or isinstance(node, ast.AugAssign)
        ) and node.col_offset == 0:
            self.globals.append(ast.unparse(node))
            self.globalsIdx.append(node.lineno)
        super().visit(node)
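
Beyond the docstring's one-liner, here is a sketch (not part of the commit) of what the visitor extracts from a documented function; the expected values in the comments are traced by hand from the code above, so treat them as approximate.

# Illustrative sketch: extract signature and docstring sections from a function.
import ast

from forge.utils.function.visitor import FunctionVisitor

source = '''
def greet(name: str) -> str:
    """
    Build a greeting.

    Args:
        name (str): who to greet
    Returns:
        str: the greeting text
    """
    return f"Hello {name}"
'''

visitor = FunctionVisitor()
visitor.visit(ast.parse(source))
fn = visitor.functions[0]
print(fn.arg_types)       # [('name', 'str')]
print(fn.arg_descs)       # {'name': 'who to greet'}
print(fn.return_desc)     # the greeting text
print(fn.is_implemented)  # True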