mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 08:38:09 -05:00
Compare commits
58 Commits
native-aut
...
zamilmajdy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e0b1156cc | ||
|
|
da9360fdeb | ||
|
|
9dea6a273e | ||
|
|
e19636ac3e | ||
|
|
f03c6546b8 | ||
|
|
2c4afd4458 | ||
|
|
8b1d416de3 | ||
|
|
7f6b7d6d7e | ||
|
|
736ac778cc | ||
|
|
38eafdbb66 | ||
|
|
6d9f564dc5 | ||
|
|
3e675123d7 | ||
|
|
37cc047656 | ||
|
|
9f804080ed | ||
|
|
680fbf49aa | ||
|
|
901dadefc3 | ||
|
|
e204491c6c | ||
|
|
3597f801a7 | ||
|
|
b59862c402 | ||
|
|
81bac301e8 | ||
|
|
a9eb49d54e | ||
|
|
2c6e1eb4c8 | ||
|
|
3e8849b08e | ||
|
|
111e8585b5 | ||
|
|
8144d26cef | ||
|
|
e264bf7764 | ||
|
|
6dd0975236 | ||
|
|
c3acb99314 | ||
|
|
0578fb0246 | ||
|
|
731d0345f0 | ||
|
|
b4cd735f26 | ||
|
|
6e715b6c71 | ||
|
|
fcca4cc893 | ||
|
|
5c7c276c10 | ||
|
|
ae63aa8ebb | ||
|
|
fdd9f9b5ec | ||
|
|
a825aa8515 | ||
|
|
ae43136c2c | ||
|
|
c8e16f3fe1 | ||
|
|
3a60504138 | ||
|
|
dfa77739c3 | ||
|
|
9f6e25664c | ||
|
|
3c4ff60e11 | ||
|
|
47eeaf0325 | ||
|
|
81ad3cb69a | ||
|
|
834eb6c6e0 | ||
|
|
fb802400ba | ||
|
|
922e643737 | ||
|
|
7b5272f1f2 | ||
|
|
ea134c7dbd | ||
|
|
f7634524fa | ||
|
|
0eccbe1483 | ||
|
|
0916df4df7 | ||
|
|
22e2373a0b | ||
|
|
40426e4646 | ||
|
|
ef1fe7c4e8 | ||
|
|
ca7ca226ff | ||
|
|
ed5f12c02b |
@@ -23,6 +23,7 @@ from forge.components.code_executor.code_executor import (
|
||||
CodeExecutorComponent,
|
||||
CodeExecutorConfiguration,
|
||||
)
|
||||
from forge.components.code_flow_executor import CodeFlowExecutionComponent
|
||||
from forge.components.context.context import AgentContext, ContextComponent
|
||||
from forge.components.file_manager import FileManagerComponent
|
||||
from forge.components.git_operations import GitOperationsComponent
|
||||
@@ -40,7 +41,6 @@ from forge.llm.providers import (
|
||||
ChatModelResponse,
|
||||
MultiProvider,
|
||||
)
|
||||
from forge.llm.providers.utils import function_specs_from_commands
|
||||
from forge.models.action import (
|
||||
ActionErrorResult,
|
||||
ActionInterruptedByHuman,
|
||||
@@ -56,6 +56,7 @@ from forge.utils.exceptions import (
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from .prompt_strategies.code_flow import CodeFlowAgentPromptStrategy
|
||||
from .prompt_strategies.one_shot import (
|
||||
OneShotAgentActionProposal,
|
||||
OneShotAgentPromptStrategy,
|
||||
@@ -96,11 +97,14 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
|
||||
llm_provider: MultiProvider,
|
||||
file_storage: FileStorage,
|
||||
app_config: AppConfig,
|
||||
prompt_strategy_class: type[
|
||||
OneShotAgentPromptStrategy | CodeFlowAgentPromptStrategy
|
||||
] = CodeFlowAgentPromptStrategy,
|
||||
):
|
||||
super().__init__(settings)
|
||||
|
||||
self.llm_provider = llm_provider
|
||||
prompt_config = OneShotAgentPromptStrategy.default_configuration.model_copy(
|
||||
prompt_config = prompt_strategy_class.default_configuration.model_copy(
|
||||
deep=True
|
||||
)
|
||||
prompt_config.use_functions_api = (
|
||||
@@ -108,7 +112,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
|
||||
# Anthropic currently doesn't support tools + prefilling :(
|
||||
and self.llm.provider_name != "anthropic"
|
||||
)
|
||||
self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
|
||||
self.prompt_strategy = prompt_strategy_class(prompt_config, logger)
|
||||
self.commands: list[Command] = []
|
||||
|
||||
# Components
|
||||
@@ -145,6 +149,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
|
||||
self.watchdog = WatchdogComponent(settings.config, settings.history).run_after(
|
||||
ContextComponent
|
||||
)
|
||||
self.code_flow_executor = CodeFlowExecutionComponent(lambda: self.commands)
|
||||
|
||||
self.event_history = settings.history
|
||||
self.app_config = app_config
|
||||
@@ -185,7 +190,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
|
||||
task=self.state.task,
|
||||
ai_profile=self.state.ai_profile,
|
||||
ai_directives=directives,
|
||||
commands=function_specs_from_commands(self.commands),
|
||||
commands=self.commands,
|
||||
include_os_info=include_os_info,
|
||||
)
|
||||
|
||||
@@ -201,9 +206,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
|
||||
if exception:
|
||||
prompt.messages.append(ChatMessage.system(f"Error: {exception}"))
|
||||
|
||||
response: ChatModelResponse[
|
||||
OneShotAgentActionProposal
|
||||
] = await self.llm_provider.create_chat_completion(
|
||||
response: ChatModelResponse = await self.llm_provider.create_chat_completion(
|
||||
prompt.messages,
|
||||
model_name=self.llm.name,
|
||||
completion_parser=self.prompt_strategy.parse_response_content,
|
||||
@@ -281,7 +284,7 @@ class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
|
||||
except AgentException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(str(e))
|
||||
raise CommandExecutionError(str(e)) from e
|
||||
|
||||
def _get_command(self, command_name: str) -> Command:
|
||||
for command in reversed(self.commands):
|
||||
|
||||
355
autogpt/autogpt/agents/prompt_strategies/code_flow.py
Normal file
355
autogpt/autogpt/agents/prompt_strategies/code_flow.py
Normal file
@@ -0,0 +1,355 @@
|
||||
import inspect
|
||||
import re
|
||||
from logging import Logger
|
||||
from typing import Callable, Iterable, Sequence, get_args, get_origin
|
||||
|
||||
from forge.command import Command
|
||||
from forge.components.code_flow_executor import CodeFlowExecutionComponent
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
from forge.config.ai_profile import AIProfile
|
||||
from forge.json.parsing import extract_dict_from_json
|
||||
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
|
||||
from forge.llm.prompting.utils import indent
|
||||
from forge.llm.providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
ChatMessage,
|
||||
)
|
||||
from forge.models.config import SystemConfiguration
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.utils.exceptions import InvalidAgentResponseError
|
||||
from forge.utils.function.code_validation import CodeValidator
|
||||
from forge.utils.function.model import FunctionDef
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt.agents.prompt_strategies.one_shot import (
|
||||
AssistantThoughts,
|
||||
OneShotAgentActionProposal,
|
||||
OneShotAgentPromptConfiguration,
|
||||
)
|
||||
|
||||
_RESPONSE_INTERFACE_NAME = "AssistantResponse"
|
||||
|
||||
|
||||
class CodeFlowAgentActionProposal(BaseModel):
|
||||
thoughts: AssistantThoughts
|
||||
immediate_plan: str = Field(
|
||||
...,
|
||||
description="We will be running an iterative process to execute the plan, "
|
||||
"Write the partial / immediate plan to execute your plan as detailed and "
|
||||
"efficiently as possible without the help of the reasoning/intelligence. "
|
||||
"The plan should describe the output of the immediate plan, so that the next "
|
||||
"iteration can be executed by taking the output into account. "
|
||||
"Try to do as much as possible without making any assumption or uninformed "
|
||||
"guesses. Avoid large output at all costs!!!\n"
|
||||
"Format: Objective[Objective of this iteration, explain what's the use of this "
|
||||
"iteration for the next one] Plan[Plan that does not require any reasoning or "
|
||||
"intelligence] Output[Output of the plan / should be small, avoid whole file "
|
||||
"output]",
|
||||
)
|
||||
python_code: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Write the fully-functional Python code of the immediate plan. "
|
||||
"The output will be an `async def main() -> str` function of the immediate "
|
||||
"plan that return the string output, the output will be passed into the "
|
||||
"LLM context window so avoid returning the whole content!. "
|
||||
"Use ONLY the listed available functions and built-in Python features. "
|
||||
"Leverage the given magic functions to implement function calls for which "
|
||||
"the arguments can't be determined yet. "
|
||||
"Example:`async def main() -> str:\n"
|
||||
" return await provided_function('arg1', 'arg2').split('\\n')[0]`"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
FINAL_INSTRUCTION: str = (
|
||||
"You have to give the answer in the from of JSON schema specified previously. "
|
||||
"For the `python_code` field, you have to write Python code to execute your plan "
|
||||
"as efficiently as possible. Your code will be executed directly without any "
|
||||
"editing, if it doesn't work you will be held responsible. "
|
||||
"Use ONLY the listed available functions and built-in Python features. "
|
||||
"Do not make uninformed assumptions "
|
||||
"(e.g. about the content or format of an unknown file). Leverage the given magic "
|
||||
"functions to implement function calls for which the arguments can't be determined "
|
||||
"yet. Reduce the amount of unnecessary data passed into these magic functions "
|
||||
"where possible, because magic costs money and magically processing large amounts "
|
||||
"of data is expensive. If you think are done with the task, you can simply call "
|
||||
"finish(reason='your reason') to end the task, "
|
||||
"a function that has one `finish` command, don't mix finish with other functions! "
|
||||
"If you still need to do other functions, "
|
||||
"let the next cycle execute the `finish` function. "
|
||||
"Avoid hard-coding input values as input, and avoid returning large outputs. "
|
||||
"The code that you have been executing in the past cycles can also be buggy, "
|
||||
"so if you see undesired output, you can always try to re-plan, and re-code. "
|
||||
)
|
||||
|
||||
|
||||
class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
default_configuration: OneShotAgentPromptConfiguration = (
|
||||
OneShotAgentPromptConfiguration()
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
configuration: SystemConfiguration,
|
||||
logger: Logger,
|
||||
):
|
||||
self.config = configuration
|
||||
self.response_schema = JSONSchema.from_dict(
|
||||
CodeFlowAgentActionProposal.model_json_schema()
|
||||
)
|
||||
self.logger = logger
|
||||
self.commands: Sequence[Command] = [] # Sequence -> disallow list modification
|
||||
|
||||
@property
|
||||
def llm_classification(self) -> LanguageModelClassification:
|
||||
return LanguageModelClassification.SMART_MODEL # FIXME: dynamic switching
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
*,
|
||||
messages: list[ChatMessage],
|
||||
task: str,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: Sequence[Command],
|
||||
**extras,
|
||||
) -> ChatPrompt:
|
||||
"""Constructs and returns a prompt with the following structure:
|
||||
1. System prompt
|
||||
3. `cycle_instruction`
|
||||
"""
|
||||
system_prompt, response_prefill = self.build_system_prompt(
|
||||
ai_profile=ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=commands,
|
||||
)
|
||||
|
||||
self.commands = commands
|
||||
final_instruction_msg = ChatMessage.system(FINAL_INSTRUCTION)
|
||||
|
||||
return ChatPrompt(
|
||||
messages=[
|
||||
ChatMessage.system(system_prompt),
|
||||
ChatMessage.user(f'"""{task}"""'),
|
||||
*messages,
|
||||
*(
|
||||
[final_instruction_msg]
|
||||
if not any(m.role == "assistant" for m in messages)
|
||||
else []
|
||||
),
|
||||
],
|
||||
prefill_response=response_prefill,
|
||||
)
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: Iterable[Command],
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds the system prompt.
|
||||
|
||||
Returns:
|
||||
str: The system prompt body
|
||||
str: The desired start for the LLM's response; used to steer the output
|
||||
"""
|
||||
response_fmt_instruction, response_prefill = self.response_format_instruction()
|
||||
system_prompt_parts = (
|
||||
self._generate_intro_prompt(ai_profile)
|
||||
+ [
|
||||
"## Your Task\n"
|
||||
"The user will specify a task for you to execute, in triple quotes,"
|
||||
" in the next message. Your job is to complete the task, "
|
||||
"and terminate when your task is done."
|
||||
]
|
||||
+ ["## Available Functions\n" + self._generate_function_headers(commands)]
|
||||
+ ["## RESPONSE FORMAT\n" + response_fmt_instruction]
|
||||
)
|
||||
|
||||
# Join non-empty parts together into paragraph format
|
||||
return (
|
||||
"\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
|
||||
response_prefill,
|
||||
)
|
||||
|
||||
def response_format_instruction(self) -> tuple[str, str]:
|
||||
response_schema = self.response_schema.model_copy(deep=True)
|
||||
assert response_schema.properties
|
||||
|
||||
# Unindent for performance
|
||||
response_format = re.sub(
|
||||
r"\n\s+",
|
||||
"\n",
|
||||
response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
|
||||
)
|
||||
response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'
|
||||
|
||||
return (
|
||||
(
|
||||
f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
|
||||
f"{response_format}"
|
||||
),
|
||||
response_prefill,
|
||||
)
|
||||
|
||||
def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:
|
||||
"""Generates the introduction part of the prompt.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of strings forming the introduction part of the prompt.
|
||||
"""
|
||||
return [
|
||||
f"You are {ai_profile.ai_name}, {ai_profile.ai_role.rstrip('.')}.",
|
||||
# "Your decisions must always be made independently without seeking "
|
||||
# "user assistance. Play to your strengths as an LLM and pursue "
|
||||
# "simple strategies with no legal complications.",
|
||||
]
|
||||
|
||||
def _generate_function_headers(self, commands: Iterable[Command]) -> str:
|
||||
function_stubs: list[str] = []
|
||||
annotation_types_in_context: set[type] = set()
|
||||
for f in commands:
|
||||
# Add source code of non-builtin types from function signatures
|
||||
new_annotation_types = extract_annotation_types(f.method).difference(
|
||||
annotation_types_in_context
|
||||
)
|
||||
new_annotation_types_src = [
|
||||
f"# {a.__module__}.{a.__qualname__}\n{inspect.getsource(a)}"
|
||||
for a in new_annotation_types
|
||||
]
|
||||
annotation_types_in_context.update(new_annotation_types)
|
||||
|
||||
param_descriptions = "\n".join(
|
||||
f"{param.name}: {param.spec.description}"
|
||||
for param in f.parameters
|
||||
if param.spec.description
|
||||
)
|
||||
full_function_stub = (
|
||||
("\n".join(new_annotation_types_src) + "\n" + f.header).strip()
|
||||
+ "\n"
|
||||
+ indent(
|
||||
(
|
||||
'"""\n'
|
||||
f"{f.description}\n\n"
|
||||
f"Params:\n{indent(param_descriptions)}\n"
|
||||
'"""\n'
|
||||
"pass"
|
||||
),
|
||||
)
|
||||
)
|
||||
function_stubs.append(full_function_stub)
|
||||
|
||||
return "\n\n\n".join(function_stubs)
|
||||
|
||||
async def parse_response_content(
|
||||
self,
|
||||
response: AssistantChatMessage,
|
||||
) -> OneShotAgentActionProposal:
|
||||
if not response.content:
|
||||
raise InvalidAgentResponseError("Assistant response has no text content")
|
||||
|
||||
self.logger.debug(
|
||||
"LLM response content:"
|
||||
+ (
|
||||
f"\n{response.content}"
|
||||
if "\n" in response.content
|
||||
else f" '{response.content}'"
|
||||
)
|
||||
)
|
||||
assistant_reply_dict = extract_dict_from_json(response.content)
|
||||
|
||||
parsed_response = CodeFlowAgentActionProposal.model_validate(
|
||||
assistant_reply_dict
|
||||
)
|
||||
if not parsed_response.python_code:
|
||||
raise ValueError("python_code is empty")
|
||||
|
||||
available_functions = {
|
||||
c.name: FunctionDef(
|
||||
name=c.name,
|
||||
arg_types=[(p.name, p.spec.python_type) for p in c.parameters],
|
||||
arg_descs={p.name: p.spec.description for p in c.parameters},
|
||||
arg_defaults={
|
||||
p.name: p.spec.default or "None"
|
||||
for p in c.parameters
|
||||
if p.spec.default or not p.spec.required
|
||||
},
|
||||
return_type=c.return_type,
|
||||
return_desc="Output of the function",
|
||||
function_desc=c.description,
|
||||
is_async=c.is_async,
|
||||
)
|
||||
for c in self.commands
|
||||
}
|
||||
available_functions.update(
|
||||
{
|
||||
"main": FunctionDef(
|
||||
name="main",
|
||||
arg_types=[],
|
||||
arg_descs={},
|
||||
return_type="str",
|
||||
return_desc="Output of the function",
|
||||
function_desc="The main function to execute the plan",
|
||||
is_async=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
code_validation = await CodeValidator(
|
||||
function_name="main",
|
||||
available_functions=available_functions,
|
||||
).validate_code(parsed_response.python_code)
|
||||
|
||||
clean_response = response.model_copy()
|
||||
clean_response.content = parsed_response.model_dump_json(indent=4)
|
||||
|
||||
# TODO: prevent combining finish with other functions
|
||||
if _finish_call := re.search(
|
||||
r"finish\((reason=)?(.*?)\)", code_validation.functionCode
|
||||
):
|
||||
finish_reason = _finish_call.group(2)[1:-1] # remove quotes
|
||||
result = OneShotAgentActionProposal(
|
||||
thoughts=parsed_response.thoughts,
|
||||
use_tool=AssistantFunctionCall(
|
||||
name="finish",
|
||||
arguments={"reason": finish_reason},
|
||||
),
|
||||
raw_message=clean_response,
|
||||
)
|
||||
else:
|
||||
result = OneShotAgentActionProposal(
|
||||
thoughts=parsed_response.thoughts,
|
||||
use_tool=AssistantFunctionCall(
|
||||
name=CodeFlowExecutionComponent.execute_code_flow.name,
|
||||
arguments={
|
||||
"python_code": code_validation.functionCode,
|
||||
"plan_text": parsed_response.immediate_plan,
|
||||
},
|
||||
),
|
||||
raw_message=clean_response,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def extract_annotation_types(func: Callable) -> set[type]:
|
||||
annotation_types = set()
|
||||
for annotation in inspect.get_annotations(func).values():
|
||||
annotation_types.update(_get_nested_types(annotation))
|
||||
return annotation_types
|
||||
|
||||
|
||||
def _get_nested_types(annotation: type) -> Iterable[type]:
|
||||
if _args := get_args(annotation):
|
||||
for a in _args:
|
||||
yield from _get_nested_types(a)
|
||||
if not _is_builtin_type(_a := get_origin(annotation) or annotation):
|
||||
yield _a
|
||||
|
||||
|
||||
def _is_builtin_type(_type: type):
|
||||
"""Check if a given type is a built-in type."""
|
||||
import sys
|
||||
|
||||
return _type.__module__ in sys.stdlib_module_names
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
from logging import Logger
|
||||
|
||||
import distro
|
||||
from forge.command import Command
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
from forge.config.ai_profile import AIProfile
|
||||
from forge.json.parsing import extract_dict_from_json
|
||||
@@ -16,6 +17,7 @@ from forge.llm.providers.schema import (
|
||||
ChatMessage,
|
||||
CompletionModelFunction,
|
||||
)
|
||||
from forge.llm.providers.utils import function_specs_from_commands
|
||||
from forge.models.action import ActionProposal
|
||||
from forge.models.config import SystemConfiguration, UserConfigurable
|
||||
from forge.models.json_schema import JSONSchema
|
||||
@@ -27,13 +29,21 @@ _RESPONSE_INTERFACE_NAME = "AssistantResponse"
|
||||
|
||||
|
||||
class AssistantThoughts(ModelWithSummary):
|
||||
past_action_summary: str = Field(
|
||||
...,
|
||||
description="Summary of the last action you took, if there is none, "
|
||||
"you can leave it empty",
|
||||
)
|
||||
observations: str = Field(
|
||||
description="Relevant observations from your last action (if any)"
|
||||
description="Relevant observations from your last actions (if any)"
|
||||
)
|
||||
text: str = Field(description="Thoughts")
|
||||
reasoning: str = Field(description="Reasoning behind the thoughts")
|
||||
self_criticism: str = Field(description="Constructive self-criticism")
|
||||
plan: list[str] = Field(description="Short list that conveys the long-term plan")
|
||||
plan: list[str] = Field(
|
||||
description="Short list that conveys the long-term plan, "
|
||||
"considering the progress on your task so far",
|
||||
)
|
||||
speak: str = Field(description="Summary of thoughts, to say to user")
|
||||
|
||||
def summary(self) -> str:
|
||||
@@ -101,7 +111,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
|
||||
@property
|
||||
def llm_classification(self) -> LanguageModelClassification:
|
||||
return LanguageModelClassification.FAST_MODEL # FIXME: dynamic switching
|
||||
return LanguageModelClassification.SMART_MODEL # FIXME: dynamic switching
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
@@ -110,7 +120,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
task: str,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
commands: list[Command],
|
||||
include_os_info: bool,
|
||||
**extras,
|
||||
) -> ChatPrompt:
|
||||
@@ -118,10 +128,11 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
1. System prompt
|
||||
3. `cycle_instruction`
|
||||
"""
|
||||
functions = function_specs_from_commands(commands)
|
||||
system_prompt, response_prefill = self.build_system_prompt(
|
||||
ai_profile=ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=commands,
|
||||
functions=functions,
|
||||
include_os_info=include_os_info,
|
||||
)
|
||||
|
||||
@@ -135,14 +146,14 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
final_instruction_msg,
|
||||
],
|
||||
prefill_response=response_prefill,
|
||||
functions=commands if self.config.use_functions_api else [],
|
||||
functions=functions if self.config.use_functions_api else [],
|
||||
)
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
functions: list[CompletionModelFunction],
|
||||
include_os_info: bool,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
@@ -162,7 +173,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
self.config.body_template.format(
|
||||
constraints=format_numbered_list(ai_directives.constraints),
|
||||
resources=format_numbered_list(ai_directives.resources),
|
||||
commands=self._generate_commands_list(commands),
|
||||
commands=self._generate_commands_list(functions),
|
||||
best_practices=format_numbered_list(ai_directives.best_practices),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -23,6 +23,7 @@ from forge.agent_protocol.models import (
|
||||
TaskRequestBody,
|
||||
TaskStepsListResponse,
|
||||
)
|
||||
from forge.components.code_flow_executor import CodeFlowExecutionComponent
|
||||
from forge.file_storage import FileStorage
|
||||
from forge.llm.providers import ModelProviderBudget, MultiProvider
|
||||
from forge.models.action import ActionErrorResult, ActionSuccessResult
|
||||
@@ -298,11 +299,16 @@ class AgentProtocolServer:
|
||||
else ""
|
||||
)
|
||||
output += f"{assistant_response.thoughts.speak}\n\n"
|
||||
output += (
|
||||
f"Next Command: {next_tool_to_use}"
|
||||
if next_tool_to_use.name != ASK_COMMAND
|
||||
else next_tool_to_use.arguments["question"]
|
||||
)
|
||||
if next_tool_to_use.name == CodeFlowExecutionComponent.execute_code_flow.name:
|
||||
code = next_tool_to_use.arguments["python_code"]
|
||||
plan = next_tool_to_use.arguments["plan_text"]
|
||||
output += f"Code for next step:\n```py\n# {plan}\n\n{code}\n```"
|
||||
else:
|
||||
output += (
|
||||
f"Next Command: {next_tool_to_use}"
|
||||
if next_tool_to_use.name != ASK_COMMAND
|
||||
else next_tool_to_use.arguments["question"]
|
||||
)
|
||||
|
||||
additional_output = {
|
||||
**(
|
||||
|
||||
@@ -630,6 +630,9 @@ def update_user(
|
||||
command_args: The arguments for the command.
|
||||
assistant_reply_dict: The assistant's reply.
|
||||
"""
|
||||
from forge.components.code_flow_executor import CodeFlowExecutionComponent
|
||||
from forge.llm.prompting.utils import indent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
print_assistant_thoughts(
|
||||
@@ -644,15 +647,29 @@ def update_user(
|
||||
# First log new-line so user can differentiate sections better in console
|
||||
print()
|
||||
safe_tool_name = remove_ansi_escape(action_proposal.use_tool.name)
|
||||
logger.info(
|
||||
f"COMMAND = {Fore.CYAN}{safe_tool_name}{Style.RESET_ALL} "
|
||||
f"ARGUMENTS = {Fore.CYAN}{action_proposal.use_tool.arguments}{Style.RESET_ALL}",
|
||||
extra={
|
||||
"title": "NEXT ACTION:",
|
||||
"title_color": Fore.CYAN,
|
||||
"preserve_color": True,
|
||||
},
|
||||
)
|
||||
if safe_tool_name == CodeFlowExecutionComponent.execute_code_flow.name:
|
||||
plan = action_proposal.use_tool.arguments["plan_text"]
|
||||
code = action_proposal.use_tool.arguments["python_code"]
|
||||
logger.info(
|
||||
f"\n{indent(code, f'{Fore.GREEN}>>> {Fore.RESET}')}\n",
|
||||
extra={
|
||||
"title": "PROPOSED ACTION:",
|
||||
"title_color": Fore.GREEN,
|
||||
"preserve_color": True,
|
||||
},
|
||||
)
|
||||
logger.debug(
|
||||
f"{plan}\n", extra={"title": "EXPLANATION:", "title_color": Fore.YELLOW}
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
str(action_proposal.use_tool),
|
||||
extra={
|
||||
"title": "PROPOSED ACTION:",
|
||||
"title_color": Fore.GREEN,
|
||||
"preserve_color": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_user_feedback(
|
||||
@@ -732,6 +749,12 @@ def print_assistant_thoughts(
|
||||
)
|
||||
|
||||
if isinstance(thoughts, AssistantThoughts):
|
||||
if thoughts.observations:
|
||||
print_attribute(
|
||||
"OBSERVATIONS",
|
||||
remove_ansi_escape(thoughts.observations),
|
||||
title_color=Fore.YELLOW,
|
||||
)
|
||||
print_attribute(
|
||||
"REASONING", remove_ansi_escape(thoughts.reasoning), title_color=Fore.YELLOW
|
||||
)
|
||||
@@ -753,7 +776,7 @@ def print_assistant_thoughts(
|
||||
line.strip(), extra={"title": "- ", "title_color": Fore.GREEN}
|
||||
)
|
||||
print_attribute(
|
||||
"CRITICISM",
|
||||
"SELF-CRITICISM",
|
||||
remove_ansi_escape(thoughts.self_criticism),
|
||||
title_color=Fore.YELLOW,
|
||||
)
|
||||
@@ -764,7 +787,7 @@ def print_assistant_thoughts(
|
||||
speak(assistant_thoughts_speak)
|
||||
else:
|
||||
print_attribute(
|
||||
"SPEAK", assistant_thoughts_speak, title_color=Fore.YELLOW
|
||||
"TL;DR", assistant_thoughts_speak, title_color=Fore.YELLOW
|
||||
)
|
||||
else:
|
||||
speak(thoughts_text)
|
||||
|
||||
30
autogpt/poetry.lock
generated
30
autogpt/poetry.lock
generated
@@ -4216,7 +4216,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
|
||||
name = "ptyprocess"
|
||||
version = "0.7.0"
|
||||
description = "Run a subprocess in a pseudo terminal"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
|
||||
@@ -5212,6 +5212,32 @@ files = [
|
||||
[package.dependencies]
|
||||
pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.4.4"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "ruff-0.4.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:29d44ef5bb6a08e235c8249294fa8d431adc1426bfda99ed493119e6f9ea1bf6"},
|
||||
{file = "ruff-0.4.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c4efe62b5bbb24178c950732ddd40712b878a9b96b1d02b0ff0b08a090cbd891"},
|
||||
{file = "ruff-0.4.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c8e2f1e8fc12d07ab521a9005d68a969e167b589cbcaee354cb61e9d9de9c15"},
|
||||
{file = "ruff-0.4.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:60ed88b636a463214905c002fa3eaab19795679ed55529f91e488db3fe8976ab"},
|
||||
{file = "ruff-0.4.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b90fc5e170fc71c712cc4d9ab0e24ea505c6a9e4ebf346787a67e691dfb72e85"},
|
||||
{file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8e7e6ebc10ef16dcdc77fd5557ee60647512b400e4a60bdc4849468f076f6eef"},
|
||||
{file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b9ddb2c494fb79fc208cd15ffe08f32b7682519e067413dbaf5f4b01a6087bcd"},
|
||||
{file = "ruff-0.4.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c51c928a14f9f0a871082603e25a1588059b7e08a920f2f9fa7157b5bf08cfe9"},
|
||||
{file = "ruff-0.4.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5eb0a4bfd6400b7d07c09a7725e1a98c3b838be557fee229ac0f84d9aa49c36"},
|
||||
{file = "ruff-0.4.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b1867ee9bf3acc21778dcb293db504692eda5f7a11a6e6cc40890182a9f9e595"},
|
||||
{file = "ruff-0.4.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1aecced1269481ef2894cc495647392a34b0bf3e28ff53ed95a385b13aa45768"},
|
||||
{file = "ruff-0.4.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9da73eb616b3241a307b837f32756dc20a0b07e2bcb694fec73699c93d04a69e"},
|
||||
{file = "ruff-0.4.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:958b4ea5589706a81065e2a776237de2ecc3e763342e5cc8e02a4a4d8a5e6f95"},
|
||||
{file = "ruff-0.4.4-py3-none-win32.whl", hash = "sha256:cb53473849f011bca6e754f2cdf47cafc9c4f4ff4570003a0dad0b9b6890e876"},
|
||||
{file = "ruff-0.4.4-py3-none-win_amd64.whl", hash = "sha256:424e5b72597482543b684c11def82669cc6b395aa8cc69acc1858b5ef3e5daae"},
|
||||
{file = "ruff-0.4.4-py3-none-win_arm64.whl", hash = "sha256:39df0537b47d3b597293edbb95baf54ff5b49589eb7ff41926d8243caa995ea6"},
|
||||
{file = "ruff-0.4.4.tar.gz", hash = "sha256:f87ea42d5cdebdc6a69761a9d0bc83ae9b3b30d0ad78952005ba6568d6c022af"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "s3transfer"
|
||||
version = "0.10.0"
|
||||
@@ -6758,4 +6784,4 @@ benchmark = ["agbenchmark"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "b3d4efee5861b32152024dada1ec61f4241122419cb538012c00a6ed55ac8a4b"
|
||||
content-hash = "c729e10fd5ac85400d2499397974d1b1831fed3b591657a2fea9e86501b96e19"
|
||||
|
||||
@@ -30,9 +30,12 @@ gitpython = "^3.1.32"
|
||||
hypercorn = "^0.14.4"
|
||||
openai = "^1.7.2"
|
||||
orjson = "^3.8.10"
|
||||
ptyprocess = "^0.7.0"
|
||||
pydantic = "^2.7.2"
|
||||
pyright = "^1.1.364"
|
||||
python-dotenv = "^1.0.0"
|
||||
requests = "*"
|
||||
ruff = "^0.4.4"
|
||||
sentry-sdk = "^1.40.4"
|
||||
|
||||
# Benchmarking
|
||||
@@ -47,7 +50,6 @@ black = "^23.12.1"
|
||||
flake8 = "^7.0.0"
|
||||
isort = "^5.13.1"
|
||||
pre-commit = "*"
|
||||
pyright = "^1.1.364"
|
||||
|
||||
# Type stubs
|
||||
types-colorama = "*"
|
||||
|
||||
126
autogpt/tests/unit/test_code_flow_strategy.py
Normal file
126
autogpt/tests/unit/test_code_flow_strategy.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import Command, command
|
||||
from forge.components.code_flow_executor import CodeFlowExecutionComponent
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
from forge.config.ai_profile import AIProfile
|
||||
from forge.llm.providers import AssistantChatMessage
|
||||
from forge.llm.providers.schema import JSONSchema
|
||||
|
||||
from autogpt.agents.prompt_strategies.code_flow import CodeFlowAgentPromptStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = CodeFlowAgentPromptStrategy.default_configuration.copy(deep=True)
|
||||
prompt_strategy = CodeFlowAgentPromptStrategy(config, logger)
|
||||
|
||||
|
||||
class MockWebSearchProvider(CommandProvider):
|
||||
def get_commands(self):
|
||||
yield self.mock_web_search
|
||||
|
||||
@command(
|
||||
description="Searches the web",
|
||||
parameters={
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
),
|
||||
"num_results": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The number of results to return",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
def mock_web_search(self, query: str, num_results: Optional[int] = None) -> str:
|
||||
return "results"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_flow_build_prompt():
|
||||
commands = list(MockWebSearchProvider().get_commands())
|
||||
|
||||
ai_profile = AIProfile()
|
||||
ai_profile.ai_name = "DummyGPT"
|
||||
ai_profile.ai_goals = ["A model for testing purposes"]
|
||||
ai_profile.ai_role = "Help Testing"
|
||||
|
||||
ai_directives = AIDirectives()
|
||||
ai_directives.resources = ["resource_1"]
|
||||
ai_directives.constraints = ["constraint_1"]
|
||||
ai_directives.best_practices = ["best_practice_1"]
|
||||
|
||||
prompt = str(
|
||||
prompt_strategy.build_prompt(
|
||||
task="Figure out from file.csv how much was spent on utilities",
|
||||
messages=[],
|
||||
ai_profile=ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=commands,
|
||||
)
|
||||
)
|
||||
assert "DummyGPT" in prompt
|
||||
assert (
|
||||
"def mock_web_search(query: str, num_results: Optional[int] = None)" in prompt
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_flow_parse_response():
|
||||
response_content = """
|
||||
{
|
||||
"thoughts": {
|
||||
"past_action_summary": "This is the past action summary.",
|
||||
"observations": "This is the observation.",
|
||||
"text": "Some text on the AI's thoughts.",
|
||||
"reasoning": "This is the reasoning.",
|
||||
"self_criticism": "This is the self-criticism.",
|
||||
"plan": [
|
||||
"Plan 1",
|
||||
"Plan 2",
|
||||
"Plan 3"
|
||||
],
|
||||
"speak": "This is what the AI would say."
|
||||
},
|
||||
"immediate_plan": "Objective[objective1] Plan[plan1] Output[out1]",
|
||||
"python_code": "async def main() -> str:\n return 'You passed the test.'",
|
||||
}
|
||||
"""
|
||||
response = await CodeFlowAgentPromptStrategy(config, logger).parse_response_content(
|
||||
AssistantChatMessage(content=response_content)
|
||||
)
|
||||
assert "This is the observation." == response.thoughts.observations
|
||||
assert "This is the reasoning." == response.thoughts.reasoning
|
||||
|
||||
assert CodeFlowExecutionComponent.execute_code_flow.name == response.use_tool.name
|
||||
assert "async def main() -> str" in response.use_tool.arguments["python_code"]
|
||||
assert (
|
||||
"Objective[objective1] Plan[plan1] Output[out1]"
|
||||
in response.use_tool.arguments["plan_text"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_flow_execution():
|
||||
executor = CodeFlowExecutionComponent(
|
||||
lambda: [
|
||||
Command(
|
||||
names=["test_func"],
|
||||
description="",
|
||||
parameters=[],
|
||||
method=lambda: "You've passed the test!",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = await executor.execute_code_flow(
|
||||
python_code="async def main() -> str:\n return test_func()",
|
||||
plan_text="This is the plan text.",
|
||||
)
|
||||
assert "You've passed the test!" in result
|
||||
75
autogpt/tests/unit/test_function_code_validation.py
Normal file
75
autogpt/tests/unit/test_function_code_validation.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
from forge.utils.function.code_validation import CodeValidator, FunctionDef
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_validation():
|
||||
validator = CodeValidator(
|
||||
available_functions={
|
||||
"read_webpage": FunctionDef(
|
||||
name="read_webpage",
|
||||
arg_types=[("url", "str"), ("query", "str")],
|
||||
arg_descs={
|
||||
"url": "URL to read",
|
||||
"query": "Query to search",
|
||||
"return_type": "Type of return value",
|
||||
},
|
||||
return_type="str",
|
||||
return_desc="Information matching the query",
|
||||
function_desc="Read a webpage and return the info matching the query",
|
||||
is_async=True,
|
||||
),
|
||||
"web_search": FunctionDef(
|
||||
name="web_search",
|
||||
arg_types=[("query", "str")],
|
||||
arg_descs={"query": "Query to search"},
|
||||
return_type="list[(str,str)]",
|
||||
return_desc="List of tuples with title and URL",
|
||||
function_desc="Search the web and return the search results",
|
||||
is_async=True,
|
||||
),
|
||||
"main": FunctionDef(
|
||||
name="main",
|
||||
arg_types=[],
|
||||
arg_descs={},
|
||||
return_type="str",
|
||||
return_desc="Answer in the text format",
|
||||
function_desc="Get the num of contributors to the autogpt github repo",
|
||||
is_async=False,
|
||||
),
|
||||
},
|
||||
available_objects={},
|
||||
)
|
||||
response = await validator.validate_code(
|
||||
raw_code="""
|
||||
def crawl_info(url: str, query: str) -> str | None:
|
||||
info = await read_webpage(url, query)
|
||||
if info:
|
||||
return info
|
||||
|
||||
urls = await read_webpage(url, "autogpt github contributor page")
|
||||
for url in urls.split('\\n'):
|
||||
info = await crawl_info(url, query)
|
||||
if info:
|
||||
return info
|
||||
|
||||
return None
|
||||
|
||||
def hehe():
|
||||
return 'hehe'
|
||||
|
||||
def main() -> str:
|
||||
query = "Find the number of contributors to the autogpt github repository"
|
||||
for title, url in ("autogpt github contributor page"):
|
||||
info = await crawl_info(url, query)
|
||||
if info:
|
||||
return info
|
||||
x = await hehe()
|
||||
return "No info found"
|
||||
""",
|
||||
packages=[],
|
||||
)
|
||||
assert response.functionCode is not None
|
||||
assert "async def crawl_info" in response.functionCode # async is added
|
||||
assert "async def main" in response.functionCode
|
||||
assert "x = hehe()" in response.functionCode # await is removed
|
||||
@@ -1,17 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar, cast
|
||||
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from typing import Callable, Generic, ParamSpec, TypeVar
|
||||
|
||||
from .parameter import CommandParameter
|
||||
|
||||
P = ParamSpec("P")
|
||||
CO = TypeVar("CO") # command output
|
||||
|
||||
_CP = TypeVar("_CP", bound=CommandProvider)
|
||||
|
||||
|
||||
class Command(Generic[P, CO]):
|
||||
"""A class representing a command.
|
||||
@@ -26,37 +22,60 @@ class Command(Generic[P, CO]):
|
||||
self,
|
||||
names: list[str],
|
||||
description: str,
|
||||
method: Callable[Concatenate[_CP, P], CO],
|
||||
method: Callable[P, CO],
|
||||
parameters: list[CommandParameter],
|
||||
):
|
||||
# Check if all parameters are provided
|
||||
if not self._parameters_match(method, parameters):
|
||||
raise ValueError(
|
||||
f"Command {names[0]} has different parameters than provided schema"
|
||||
)
|
||||
self.names = names
|
||||
self.description = description
|
||||
# Method technically has a `self` parameter, but we can ignore that
|
||||
# since Python passes it internally.
|
||||
self.method = cast(Callable[P, CO], method)
|
||||
self.method = method
|
||||
self.parameters = parameters
|
||||
|
||||
# Check if all parameters are provided
|
||||
if not self._parameters_match_signature():
|
||||
raise ValueError(
|
||||
f"Command {self.name} has different parameters than provided schema"
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.names[0] # TODO: fallback to other name if first one is taken
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
return inspect.iscoroutinefunction(self.method)
|
||||
|
||||
def _parameters_match(
|
||||
self, func: Callable, parameters: list[CommandParameter]
|
||||
) -> bool:
|
||||
@property
|
||||
def return_type(self) -> str:
|
||||
_type = inspect.signature(self.method).return_annotation
|
||||
if _type == inspect.Signature.empty:
|
||||
return "None"
|
||||
return _type.__name__
|
||||
|
||||
@property
|
||||
def header(self) -> str:
|
||||
"""Returns a function header representing the command's signature
|
||||
|
||||
Examples:
|
||||
```py
|
||||
def execute_python_code(code: str) -> str:
|
||||
|
||||
async def extract_info_from_content(content: str, instruction: str, output_type: type[~T]) -> ~T:
|
||||
""" # noqa
|
||||
return (
|
||||
f"{'async ' if self.is_async else ''}"
|
||||
f"def {self.name}{inspect.signature(self.method)}:"
|
||||
)
|
||||
|
||||
def _parameters_match_signature(self) -> bool:
|
||||
# Get the function's signature
|
||||
signature = inspect.signature(func)
|
||||
signature = inspect.signature(self.method)
|
||||
# Extract parameter names, ignoring 'self' for methods
|
||||
func_param_names = [
|
||||
param.name
|
||||
for param in signature.parameters.values()
|
||||
if param.name != "self"
|
||||
]
|
||||
names = [param.name for param in parameters]
|
||||
names = [param.name for param in self.parameters]
|
||||
# Check if sorted lists of names/keys are equal
|
||||
return sorted(func_param_names) == sorted(names)
|
||||
|
||||
@@ -71,7 +90,7 @@ class Command(Generic[P, CO]):
|
||||
for param in self.parameters
|
||||
]
|
||||
return (
|
||||
f"{self.names[0]}: {self.description.rstrip('.')}. "
|
||||
f"{self.name}: {self.description.rstrip('.')}. "
|
||||
f"Params: ({', '.join(params)})"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, Concatenate, Optional, TypeVar
|
||||
from typing import Callable, Concatenate, Optional, TypeVar, cast
|
||||
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.models.json_schema import JSONSchema
|
||||
|
||||
from .command import CO, Command, CommandParameter, P
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CP = TypeVar("_CP", bound=CommandProvider)
|
||||
|
||||
|
||||
def command(
|
||||
names: list[str] = [],
|
||||
names: Optional[list[str]] = None,
|
||||
description: Optional[str] = None,
|
||||
parameters: dict[str, JSONSchema] = {},
|
||||
) -> Callable[[Callable[Concatenate[_CP, P], CO]], Command[P, CO]]:
|
||||
parameters: Optional[dict[str, JSONSchema]] = None,
|
||||
) -> Callable[[Callable[Concatenate[_CP, P], CO] | Callable[P, CO]], Command[P, CO]]:
|
||||
"""
|
||||
The command decorator is used to make a Command from a function.
|
||||
Make a `Command` from a function or a method on a `CommandProvider`.
|
||||
All parameters are optional if the decorated function has a fully featured
|
||||
docstring. For the requirements of such a docstring,
|
||||
see `get_param_descriptions_from_docstring`.
|
||||
|
||||
Args:
|
||||
names (list[str]): The names of the command.
|
||||
@@ -27,34 +34,141 @@ def command(
|
||||
that the command executes.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[Concatenate[_CP, P], CO]) -> Command[P, CO]:
|
||||
doc = func.__doc__ or ""
|
||||
def decorator(
|
||||
func: Callable[Concatenate[_CP, P], CO] | Callable[P, CO]
|
||||
) -> Command[P, CO]:
|
||||
# If names is not provided, use the function name
|
||||
command_names = names or [func.__name__]
|
||||
# If description is not provided, use the first part of the docstring
|
||||
if not (command_description := description):
|
||||
if not func.__doc__:
|
||||
raise ValueError("Description is required if function has no docstring")
|
||||
# Return the part of the docstring before double line break or everything
|
||||
command_description = re.sub(r"\s+", " ", doc.split("\n\n")[0].strip())
|
||||
_names = names or [func.__name__]
|
||||
|
||||
# If description is not provided, use the first part of the docstring
|
||||
docstring = inspect.getdoc(func)
|
||||
if not (_description := description):
|
||||
if not docstring:
|
||||
raise ValueError(
|
||||
"'description' is required if function has no docstring"
|
||||
)
|
||||
_description = get_clean_description_from_docstring(docstring)
|
||||
|
||||
if not (_parameters := parameters):
|
||||
if not docstring:
|
||||
raise ValueError(
|
||||
"'parameters' is required if function has no docstring"
|
||||
)
|
||||
|
||||
# Combine descriptions from docstring with JSONSchemas from annotations
|
||||
param_descriptions = get_param_descriptions_from_docstring(docstring)
|
||||
_parameters = get_params_json_schemas(func)
|
||||
for param, param_schema in _parameters.items():
|
||||
if desc := param_descriptions.get(param):
|
||||
param_schema.description = desc
|
||||
|
||||
# Parameters
|
||||
typed_parameters = [
|
||||
CommandParameter(
|
||||
name=param_name,
|
||||
spec=spec,
|
||||
)
|
||||
for param_name, spec in parameters.items()
|
||||
for param_name, spec in _parameters.items()
|
||||
]
|
||||
|
||||
# Wrap func with Command
|
||||
command = Command(
|
||||
names=command_names,
|
||||
description=command_description,
|
||||
method=func,
|
||||
names=_names,
|
||||
description=_description,
|
||||
# Method technically has a `self` parameter, but we can ignore that
|
||||
# since Python passes it internally.
|
||||
method=cast(Callable[P, CO], func),
|
||||
parameters=typed_parameters,
|
||||
)
|
||||
|
||||
return command
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_clean_description_from_docstring(docstring: str) -> str:
|
||||
"""Return the part of the docstring before double line break or everything"""
|
||||
return re.sub(r"\s+", " ", docstring.split("\n\n")[0].strip())
|
||||
|
||||
|
||||
def get_params_json_schemas(func: Callable) -> dict[str, JSONSchema]:
|
||||
"""Gets the annotations of the given function and converts them to JSONSchemas"""
|
||||
result: dict[str, JSONSchema] = {}
|
||||
for name, parameter in inspect.signature(func).parameters.items():
|
||||
if name == "return":
|
||||
continue
|
||||
param_schema = result[name] = JSONSchema.from_python_type(parameter.annotation)
|
||||
if parameter.default:
|
||||
param_schema.default = parameter.default
|
||||
param_schema.required = False
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_param_descriptions_from_docstring(docstring: str) -> dict[str, str]:
|
||||
"""
|
||||
Get parameter descriptions from a docstring. Requirements for the docstring:
|
||||
- The section describing the parameters MUST start with `Params:` or `Args:`, in any
|
||||
capitalization.
|
||||
- An entry describing a parameter MUST be indented by 4 spaces.
|
||||
- An entry describing a parameter MUST start with the parameter name, an optional
|
||||
type annotation, followed by a colon `:`.
|
||||
- Continuations of parameter descriptions MUST be indented relative to the first
|
||||
line of the entry.
|
||||
- The docstring must not be indented as a whole. To get a docstring with the uniform
|
||||
indentation stripped off, use `inspect.getdoc(func)`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
\"\"\"
|
||||
This is the description. This will be ignored.
|
||||
The description can span multiple lines,
|
||||
|
||||
or contain any number of line breaks.
|
||||
|
||||
Params:
|
||||
param1: This is a simple parameter description.
|
||||
param2 (list[str]): This parameter also has a type annotation.
|
||||
param3: This parameter has a long description. This means we will have to let it
|
||||
continue on the next line. The continuation is indented relative to the first
|
||||
line of the entry.
|
||||
|
||||
param4: Extra line breaks to group parameters are allowed. This will not break
|
||||
the algorithm.
|
||||
|
||||
This text is
|
||||
is indented by
|
||||
less than 4 spaces
|
||||
and is interpreted as the end of the `Params:` section.
|
||||
\"\"\"
|
||||
```
|
||||
"""
|
||||
param_descriptions: dict[str, str] = {}
|
||||
param_section = False
|
||||
last_param_name = ""
|
||||
for line in docstring.split("\n"):
|
||||
if not line.strip(): # ignore empty lines
|
||||
continue
|
||||
|
||||
if line.lower().startswith(("params:", "args:")):
|
||||
param_section = True
|
||||
continue
|
||||
|
||||
if param_section:
|
||||
if line.strip() and not line.startswith(" " * 4): # end of section
|
||||
break
|
||||
|
||||
line = line[4:]
|
||||
if line.startswith(" ") and last_param_name: # continuation of description
|
||||
param_descriptions[last_param_name] += " " + line.strip()
|
||||
else:
|
||||
if _match := re.match(r"^(\w+).*?: (.*)", line):
|
||||
param_name = _match.group(1)
|
||||
param_desc = _match.group(2).strip()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid line in docstring's parameter section: {repr(line)}"
|
||||
)
|
||||
continue
|
||||
param_descriptions[param_name] = param_desc
|
||||
last_param_name = param_name
|
||||
return param_descriptions
|
||||
|
||||
@@ -102,6 +102,8 @@ class ActionHistoryComponent(
|
||||
|
||||
@staticmethod
|
||||
def _make_result_message(episode: Episode, result: ActionResult) -> ChatMessage:
|
||||
from forge.components.code_flow_executor import CodeFlowExecutionComponent
|
||||
|
||||
if result.status == "success":
|
||||
return (
|
||||
ToolResultMessage(
|
||||
@@ -110,11 +112,18 @@ class ActionHistoryComponent(
|
||||
)
|
||||
if episode.action.raw_message.tool_calls
|
||||
else ChatMessage.user(
|
||||
f"{episode.action.use_tool.name} returned: "
|
||||
(
|
||||
"Execution result:"
|
||||
if (
|
||||
episode.action.use_tool.name
|
||||
== CodeFlowExecutionComponent.execute_code_flow.name
|
||||
)
|
||||
else f"{episode.action.use_tool.name} returned:"
|
||||
)
|
||||
+ (
|
||||
f"```\n{result.outputs}\n```"
|
||||
f"\n```\n{result.outputs}\n```"
|
||||
if "\n" in str(result.outputs)
|
||||
else f"`{result.outputs}`"
|
||||
else f" `{result.outputs}`"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
3
forge/forge/components/code_flow_executor/__init__.py
Normal file
3
forge/forge/components/code_flow_executor/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .code_flow_executor import CodeFlowExecutionComponent
|
||||
|
||||
__all__ = ["CodeFlowExecutionComponent"]
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Commands to generate images based on text input"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Callable, Iterable, Iterator
|
||||
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import Command, command
|
||||
from forge.models.json_schema import JSONSchema
|
||||
|
||||
MAX_RESULT_LENGTH = 1000
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodeFlowExecutionComponent(CommandProvider):
|
||||
"""A component that provides commands to execute code flow."""
|
||||
|
||||
def __init__(self, get_available_commands: Callable[[], Iterable[Command]]):
|
||||
self.get_available_commands = get_available_commands
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.execute_code_flow
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"python_code": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The Python code to execute",
|
||||
required=True,
|
||||
),
|
||||
"plan_text": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The plan to written in a natural language",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
async def execute_code_flow(self, python_code: str, plan_text: str) -> str:
|
||||
"""Execute the code flow.
|
||||
|
||||
Args:
|
||||
python_code: The Python code to execute
|
||||
plan_text: The plan implemented by the given Python code
|
||||
|
||||
Returns:
|
||||
str: The result of the code execution
|
||||
"""
|
||||
locals: dict[str, Any] = {}
|
||||
locals.update(self._get_available_functions())
|
||||
code = f"{python_code}" "\n\n" "exec_output = main()"
|
||||
logger.debug(f"Code-Flow Execution code:\n```py\n{code}\n```")
|
||||
exec(code, locals)
|
||||
result = await locals["exec_output"]
|
||||
logger.debug(f"Code-Flow Execution result:\n{result}")
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
# limit the result to limit the characters
|
||||
if len(result) > MAX_RESULT_LENGTH:
|
||||
result = result[:MAX_RESULT_LENGTH] + "...[Truncated, Content is too long]"
|
||||
return result
|
||||
|
||||
def _get_available_functions(self) -> dict[str, Callable]:
|
||||
return {
|
||||
name: command
|
||||
for command in self.get_available_commands()
|
||||
for name in command.names
|
||||
if name != self.execute_code_flow.name
|
||||
}
|
||||
@@ -169,7 +169,8 @@ class FileManagerComponent(
|
||||
parameters={
|
||||
"folder": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The folder to list files in",
|
||||
description="The folder to list files in. "
|
||||
"Pass an empty string to list files in the workspace.",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
|
||||
@@ -25,7 +25,7 @@ class UserInteractionComponent(CommandProvider):
|
||||
},
|
||||
)
|
||||
def ask_user(self, question: str) -> str:
|
||||
"""If you need more details or information regarding the given goals,
|
||||
"""If you need more details or information regarding the given task,
|
||||
you can ask the user for input."""
|
||||
print(f"\nQ: {question}")
|
||||
resp = click.prompt("A")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -154,7 +155,10 @@ class BaseOpenAIChatProvider(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: _ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
completion_parser: (
|
||||
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||
| Callable[[AssistantChatMessage], _T]
|
||||
) = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
@@ -208,7 +212,15 @@ class BaseOpenAIChatProvider(
|
||||
parsed_result: _T = None # type: ignore
|
||||
if not parse_errors:
|
||||
try:
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
parsed_result = (
|
||||
await _result
|
||||
if inspect.isawaitable(
|
||||
_result := completion_parser(assistant_msg)
|
||||
)
|
||||
# cast(..) needed because inspect.isawaitable(..) loses type:
|
||||
# https://github.com/microsoft/pyright/issues/3690
|
||||
else cast(_T, _result)
|
||||
)
|
||||
except Exception as e:
|
||||
parse_errors.append(e)
|
||||
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, ParamSpec, Sequence, TypeVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
@@ -171,7 +182,10 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
completion_parser: (
|
||||
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||
| Callable[[AssistantChatMessage], _T]
|
||||
) = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
@@ -237,7 +251,14 @@ class AnthropicProvider(BaseChatModelProvider[AnthropicModelName, AnthropicSetti
|
||||
+ "\n".join(str(e) for e in tool_call_errors)
|
||||
)
|
||||
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
# cast(..) needed because inspect.isawaitable(..) loses type info:
|
||||
# https://github.com/microsoft/pyright/issues/3690
|
||||
parsed_result = cast(
|
||||
_T,
|
||||
await _result
|
||||
if inspect.isawaitable(_result := completion_parser(assistant_msg))
|
||||
else _result,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.debug(
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Callable, Optional, Sequence, TypeVar, get_args
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
get_args,
|
||||
)
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
@@ -99,7 +108,10 @@ class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
completion_parser: (
|
||||
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||
| Callable[[AssistantChatMessage], _T]
|
||||
) = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
|
||||
@@ -6,6 +6,7 @@ from collections import defaultdict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
@@ -135,6 +136,8 @@ class CompletionModelFunction(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict[str, "JSONSchema"]
|
||||
return_type: str | None = None
|
||||
is_async: bool = False
|
||||
|
||||
def fmt_line(self) -> str:
|
||||
params = ", ".join(
|
||||
@@ -143,6 +146,44 @@ class CompletionModelFunction(BaseModel):
|
||||
)
|
||||
return f"{self.name}: {self.description}. Params: ({params})"
|
||||
|
||||
def fmt_function_stub(self, impl: str = "pass") -> str:
|
||||
"""
|
||||
Formats and returns a function stub as a string with types and descriptions.
|
||||
|
||||
Returns:
|
||||
str: The formatted function header.
|
||||
"""
|
||||
from forge.llm.prompting.utils import indent
|
||||
|
||||
params = ", ".join(
|
||||
f"{name}: {p.python_type}"
|
||||
+ (
|
||||
f" = {str(p.default)}"
|
||||
if p.default
|
||||
else " = None"
|
||||
if not p.required
|
||||
else ""
|
||||
)
|
||||
for name, p in self.parameters.items()
|
||||
)
|
||||
_def = "async def" if self.is_async else "def"
|
||||
_return = f" -> {self.return_type}" if self.return_type else ""
|
||||
return f"{_def} {self.name}({params}){_return}:\n" + indent(
|
||||
'"""\n'
|
||||
f"{self.description}\n\n"
|
||||
"Params:\n"
|
||||
+ indent(
|
||||
"\n".join(
|
||||
f"{name}: {param.description}"
|
||||
for name, param in self.parameters.items()
|
||||
if param.description
|
||||
)
|
||||
)
|
||||
+ "\n"
|
||||
'"""\n'
|
||||
f"{impl}"
|
||||
)
|
||||
|
||||
def validate_call(
|
||||
self, function_call: AssistantFunctionCall
|
||||
) -> tuple[bool, list["ValidationError"]]:
|
||||
@@ -415,7 +456,10 @@ class BaseChatModelProvider(BaseModelProvider[_ModelName, _ModelProviderSettings
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: _ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
completion_parser: (
|
||||
Callable[[AssistantChatMessage], Awaitable[_T]]
|
||||
| Callable[[AssistantChatMessage], _T]
|
||||
) = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
|
||||
@@ -80,7 +80,7 @@ def function_specs_from_commands(
|
||||
"""Get LLM-consumable function specs for the agent's available commands."""
|
||||
return [
|
||||
CompletionModelFunction(
|
||||
name=command.names[0],
|
||||
name=command.name,
|
||||
description=command.description,
|
||||
parameters={param.name: param.spec for param in command.parameters},
|
||||
)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import ast
|
||||
import enum
|
||||
import typing
|
||||
from textwrap import indent
|
||||
from typing import Optional, overload
|
||||
from types import NoneType
|
||||
from typing import Any, Optional, is_typeddict, overload
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
from pydantic import BaseModel
|
||||
@@ -14,14 +17,17 @@ class JSONSchema(BaseModel):
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
BOOLEAN = "boolean"
|
||||
TYPE = "type"
|
||||
|
||||
# TODO: add docstrings
|
||||
description: Optional[str] = None
|
||||
type: Optional[Type] = None
|
||||
enum: Optional[list] = None
|
||||
required: bool = False
|
||||
default: Any = None
|
||||
items: Optional["JSONSchema"] = None
|
||||
properties: Optional[dict[str, "JSONSchema"]] = None
|
||||
additional_properties: Optional["JSONSchema"] = None
|
||||
minimum: Optional[int | float] = None
|
||||
maximum: Optional[int | float] = None
|
||||
minItems: Optional[int] = None
|
||||
@@ -31,6 +37,7 @@ class JSONSchema(BaseModel):
|
||||
schema: dict = {
|
||||
"type": self.type.value if self.type else None,
|
||||
"description": self.description,
|
||||
"default": repr(self.default),
|
||||
}
|
||||
if self.type == "array":
|
||||
if self.items:
|
||||
@@ -45,6 +52,8 @@ class JSONSchema(BaseModel):
|
||||
schema["required"] = [
|
||||
name for name, prop in self.properties.items() if prop.required
|
||||
]
|
||||
if self.additional_properties:
|
||||
schema["additionalProperties"] = self.additional_properties.to_dict()
|
||||
elif self.enum:
|
||||
schema["enum"] = self.enum
|
||||
else:
|
||||
@@ -63,11 +72,15 @@ class JSONSchema(BaseModel):
|
||||
return JSONSchema(
|
||||
description=schema.get("description"),
|
||||
type=schema["type"],
|
||||
default=ast.literal_eval(d) if (d := schema.get("default")) else None,
|
||||
enum=schema.get("enum"),
|
||||
items=JSONSchema.from_dict(schema["items"]) if "items" in schema else None,
|
||||
items=JSONSchema.from_dict(i) if (i := schema.get("items")) else None,
|
||||
properties=JSONSchema.parse_properties(schema)
|
||||
if schema["type"] == "object"
|
||||
else None,
|
||||
additional_properties=JSONSchema.from_dict(ap)
|
||||
if schema["type"] == "object" and (ap := schema.get("additionalProperties"))
|
||||
else None,
|
||||
minimum=schema.get("minimum"),
|
||||
maximum=schema.get("maximum"),
|
||||
minItems=schema.get("minItems"),
|
||||
@@ -123,6 +136,82 @@ class JSONSchema(BaseModel):
|
||||
f"interface {interface_name} " if interface_name else ""
|
||||
) + f"{{\n{indent(attributes_string, ' ')}\n}}"
|
||||
|
||||
_PYTHON_TO_JSON_TYPE: dict[typing.Type, Type] = {
|
||||
int: Type.INTEGER,
|
||||
str: Type.STRING,
|
||||
bool: Type.BOOLEAN,
|
||||
float: Type.NUMBER,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_python_type(cls, T: typing.Type) -> "JSONSchema":
|
||||
if _t := cls._PYTHON_TO_JSON_TYPE.get(T):
|
||||
partial_schema = cls(type=_t, required=True)
|
||||
elif (
|
||||
typing.get_origin(T) is typing.Union and typing.get_args(T)[-1] is NoneType
|
||||
):
|
||||
if len(typing.get_args(T)[:-1]) > 1:
|
||||
raise NotImplementedError("Union types are currently not supported")
|
||||
partial_schema = cls.from_python_type(typing.get_args(T)[0])
|
||||
partial_schema.required = False
|
||||
return partial_schema
|
||||
elif issubclass(T, BaseModel):
|
||||
partial_schema = JSONSchema.from_dict(T.schema())
|
||||
elif T is list or typing.get_origin(T) is list:
|
||||
partial_schema = JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
items=JSONSchema.from_python_type(T_v)
|
||||
if (T_v := typing.get_args(T)[0])
|
||||
else None,
|
||||
)
|
||||
elif T is dict or typing.get_origin(T) is dict:
|
||||
partial_schema = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
additional_properties=JSONSchema.from_python_type(T_v)
|
||||
if (T_v := typing.get_args(T)[1])
|
||||
else None,
|
||||
)
|
||||
elif is_typeddict(T):
|
||||
partial_schema = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={
|
||||
k: JSONSchema.from_python_type(v)
|
||||
for k, v in T.__annotations__.items()
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"JSONSchema.from_python_type is not implemented for {T}")
|
||||
|
||||
partial_schema.required = True
|
||||
return partial_schema
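# Self-contained sketch (not part of this diff) of the typing introspection that
# from_python_type relies on: Optional[X] is Union[X, None], so its last type
# argument is NoneType, and list/dict generics expose their item types via get_args.
import typing
from types import NoneType

assert typing.get_origin(typing.Optional[int]) is typing.Union
assert typing.get_args(typing.Optional[int]) == (int, NoneType)
assert typing.get_args(list[str]) == (str,)
assert typing.get_args(dict[str, int]) == (str, int)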
|
||||
|
||||
_JSON_TO_PYTHON_TYPE: dict[Type, typing.Type] = {
|
||||
j: p for p, j in _PYTHON_TO_JSON_TYPE.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def python_type(self) -> str:
|
||||
if self.type in self._JSON_TO_PYTHON_TYPE:
|
||||
return self._JSON_TO_PYTHON_TYPE[self.type].__name__
|
||||
elif self.type == JSONSchema.Type.ARRAY:
|
||||
return f"list[{self.items.python_type}]" if self.items else "list"
|
||||
elif self.type == JSONSchema.Type.OBJECT:
|
||||
if not self.properties:
|
||||
return "dict"
|
||||
raise NotImplementedError(
|
||||
"JSONSchema.python_type doesn't support TypedDicts yet"
|
||||
)
|
||||
elif self.enum:
|
||||
return "Union[" + ", ".join(repr(v) for v in self.enum) + "]"
|
||||
elif self.type == JSONSchema.Type.TYPE:
|
||||
return "type"
|
||||
elif self.type is None:
|
||||
return "Any"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"JSONSchema.python_type does not support Type.{self.type.name} yet"
|
||||
)
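# Illustrative mappings produced by the python_type property above:
#   JSONSchema(type=Type.STRING).python_type  -> "str"
#   JSONSchema(type=Type.ARRAY, items=JSONSchema(type=Type.INTEGER)).python_type
#                                             -> "list[int]"
#   JSONSchema(type=Type.OBJECT).python_type  -> "dict"  (when no properties are set)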
|
||||
|
||||
@property
|
||||
def typescript_type(self) -> str:
|
||||
if not self.type:
|
||||
@@ -141,6 +230,10 @@ class JSONSchema(BaseModel):
|
||||
return self.to_typescript_object_interface()
|
||||
if self.enum:
|
||||
return " | ".join(repr(v) for v in self.enum)
|
||||
elif self.type == JSONSchema.Type.TYPE:
|
||||
return "type"
|
||||
elif self.type is None:
|
||||
return "any"
|
||||
|
||||
raise NotImplementedError(
|
||||
f"JSONSchema.typescript_type does not support Type.{self.type.name} yet"
|
||||
|
||||
515 forge/forge/utils/function/code_validation.py Normal file
@@ -0,0 +1,515 @@
|
||||
import ast
|
||||
import collections
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import pathlib
|
||||
import re
|
||||
import typing
|
||||
import black
|
||||
import isort
|
||||
|
||||
from forge.utils.function.model import FunctionDef, ObjectType, ValidationResponse
|
||||
from forge.utils.function.visitor import FunctionVisitor
|
||||
from forge.utils.function.util import (
|
||||
genererate_line_error,
|
||||
generate_object_code,
|
||||
generate_compiled_code,
|
||||
validate_matching_function,
|
||||
)
|
||||
from forge.utils.function.exec import (
|
||||
exec_external_on_contents,
|
||||
ExecError,
|
||||
PROJECT_TEMP_DIR,
|
||||
DEFAULT_DEPS,
|
||||
execute_command,
|
||||
setup_if_required,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodeValidator:
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str | None = None,
|
||||
available_functions: dict[str, FunctionDef] | None = None,
|
||||
available_objects: dict[str, ObjectType] | None = None,
|
||||
):
|
||||
self.func_name: str = function_name or ""
|
||||
self.available_functions: dict[str, FunctionDef] = available_functions or {}
|
||||
self.available_objects: dict[str, ObjectType] = available_objects or {}
|
||||
|
||||
async def reformat_code(
|
||||
self,
|
||||
code: str,
|
||||
packages: list[str] = [],
|
||||
) -> str:
|
||||
"""
|
||||
Reformat the code snippet
|
||||
Args:
|
||||
code (str): The code snippet to reformat
|
||||
packages (list[str]): The list of packages to validate
|
||||
Returns:
|
||||
str: The reformatted code snippet
|
||||
"""
|
||||
try:
|
||||
code = (
|
||||
await self.validate_code(
|
||||
raw_code=code,
|
||||
packages=packages,
|
||||
raise_validation_error=False,
|
||||
add_code_stubs=False,
|
||||
)
|
||||
).get_compiled_code()
|
||||
except Exception as e:
|
||||
# Validation failed even with raise_validation_error=False; log and re-raise
|
||||
logger.warning(f"Error formatting code for route #{self.func_name}: {e}")
|
||||
raise e
|
||||
|
||||
for formatter in [
|
||||
lambda code: isort.code(code),
|
||||
lambda code: black.format_str(code, mode=black.FileMode()),
|
||||
]:
|
||||
try:
|
||||
code = formatter(code)
|
||||
except Exception as e:
|
||||
# We move on with unformatted code if there's an error
|
||||
logger.warning(
|
||||
f"Error formatting code for route #{self.func_name}: {e}"
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
async def validate_code(
|
||||
self,
|
||||
raw_code: str,
|
||||
packages: list[str] = [],
|
||||
raise_validation_error: bool = True,
|
||||
add_code_stubs: bool = True,
|
||||
call_cnt: int = 0,
|
||||
) -> ValidationResponse:
|
||||
"""
|
||||
Validate the code snippet for any error
|
||||
Args:
|
||||
packages (list[Package]): The list of packages to validate
|
||||
raw_code (str): The code snippet to validate
|
||||
Returns:
|
||||
ValidationResponse: The response of the validation
|
||||
Raise:
|
||||
ValidationError(e): The list of validation errors in the code snippet
|
||||
"""
|
||||
validation_errors: list[str] = []
|
||||
|
||||
try:
|
||||
tree = ast.parse(raw_code)
|
||||
visitor = FunctionVisitor()
|
||||
visitor.visit(tree)
|
||||
validation_errors.extend(visitor.errors)
|
||||
except Exception as e:
|
||||
# parse invalid code line and add it to the error message
|
||||
error = f"Error parsing code: {e}"
|
||||
|
||||
if "async lambda" in raw_code:
|
||||
error += (
    "\nAsync lambda is not supported in Python. "
    "Use async def instead!"
)
|
||||
|
||||
if line := re.search(r"line (\d+)", error):
|
||||
raise Exception(
|
||||
genererate_line_error(error, raw_code, int(line.group(1)))
|
||||
)
|
||||
else:
|
||||
raise Exception(error)
|
||||
|
||||
# Eliminate duplicate visitor.functions and visitor.objects, prefer the last one
|
||||
visitor.imports = list(set(visitor.imports))
|
||||
visitor.functions = list({f.name: f for f in visitor.functions}.values())
|
||||
visitor.objects = list(
|
||||
{
|
||||
o.name: o
|
||||
for o in visitor.objects
|
||||
if o.name not in self.available_objects
|
||||
}.values()
|
||||
)
|
||||
|
||||
# Add implemented functions into the main function, only link the stub functions
|
||||
deps_funcs = [f for f in visitor.functions if f.is_implemented]
|
||||
stub_funcs = [f for f in visitor.functions if not f.is_implemented]
|
||||
|
||||
objects_block = zip(
|
||||
["\n\n" + generate_object_code(obj) for obj in visitor.objects],
|
||||
visitor.objectsIdx,
|
||||
)
|
||||
functions_block = zip(
|
||||
["\n\n" + fun.function_code for fun in deps_funcs], visitor.functionsIdx
|
||||
)
|
||||
globals_block = zip(
|
||||
["\n\n" + glob for glob in visitor.globals], visitor.globalsIdx
|
||||
)
|
||||
function_code = "".join(
|
||||
code
|
||||
for code, _ in sorted(
|
||||
list(objects_block) + list(functions_block) + list(globals_block),
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
).strip()
|
||||
|
||||
# No need to validate main function if it's not provided
|
||||
if self.func_name:
|
||||
main_func = self.__validate_main_function(
|
||||
deps_funcs=deps_funcs,
|
||||
function_code=function_code,
|
||||
validation_errors=validation_errors,
|
||||
)
|
||||
function_template = main_func.function_template
|
||||
else:
|
||||
function_template = None
|
||||
|
||||
# Validate that code is not re-declaring any existing entities.
|
||||
already_declared_entities = set(
|
||||
[
|
||||
obj.name
|
||||
for obj in visitor.objects
|
||||
if obj.name in self.available_objects.keys()
|
||||
]
|
||||
+ [
|
||||
func.name
|
||||
for func in visitor.functions
|
||||
if func.name in self.available_functions.keys()
|
||||
]
|
||||
)
|
||||
if already_declared_entities:
validation_errors.append(
"These class/function names have already been declared in the code, "
"no need to declare them again: " + ", ".join(already_declared_entities)
)
|
||||
|
||||
result = ValidationResponse(
|
||||
function_name=self.func_name,
|
||||
available_objects=self.available_objects,
|
||||
available_functions=self.available_functions,
|
||||
rawCode=function_code,
|
||||
imports=visitor.imports.copy(),
|
||||
objects=[], # Objects will be bundled in the function_code instead.
|
||||
template=function_template or "",
|
||||
functionCode=function_code,
|
||||
functions=stub_funcs,
|
||||
packages=packages,
|
||||
)
|
||||
|
||||
# Execute static validators and fixers.
|
||||
# print('old compiled code import ---->', result.imports)
|
||||
old_compiled_code = generate_compiled_code(result, add_code_stubs)
|
||||
validation_errors.extend(await static_code_analysis(result))
|
||||
new_compiled_code = result.get_compiled_code()
|
||||
|
||||
# If the auto-fixer changed the code, retry validation (capped at 5 attempts to avoid an infinite loop)
|
||||
if old_compiled_code != new_compiled_code and call_cnt < 5:
|
||||
return await self.validate_code(
|
||||
packages=packages,
|
||||
raw_code=new_compiled_code,
|
||||
raise_validation_error=raise_validation_error,
|
||||
add_code_stubs=add_code_stubs,
|
||||
call_cnt=call_cnt + 1,
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
if raise_validation_error:
|
||||
error_message = "".join("\n * " + e for e in validation_errors)
|
||||
raise Exception("Error validating code: " + error_message)
|
||||
else:
|
||||
# This should happen only on `reformat_code` call
|
||||
logger.warning("Error validating code: %s", validation_errors)
|
||||
|
||||
return result
|
||||
|
||||
def __validate_main_function(
|
||||
self,
|
||||
deps_funcs: list[FunctionDef],
|
||||
function_code: str,
|
||||
validation_errors: list[str],
|
||||
) -> FunctionDef:
|
||||
"""
|
||||
Validate the main function body and signature
|
||||
Returns:
|
||||
tuple[str, FunctionDef]: The function ID and the function object
|
||||
"""
|
||||
# Validate that the main function is implemented.
|
||||
func_obj = next((f for f in deps_funcs if f.name == self.func_name), None)
|
||||
if not func_obj or not func_obj.is_implemented:
|
||||
raise Exception(
|
||||
f"Main Function body {self.func_name} is not implemented."
|
||||
f" Please complete the implementation of this function!"
|
||||
)
|
||||
func_obj.function_code = function_code
|
||||
|
||||
# Validate that the main function is matching the expected signature.
|
||||
func_req: FunctionDef | None = self.available_functions.get(self.func_name)
|
||||
if not func_req:
|
||||
raise AssertionError(f"Function {self.func_name} does not exist on DB")
|
||||
try:
|
||||
validate_matching_function(func_obj, func_req)
|
||||
except Exception as e:
|
||||
validation_errors.append(e.__str__())
|
||||
|
||||
return func_obj
|
||||
|
||||
|
||||
# ======= Static Code Validation Helper Functions =======#
|
||||
|
||||
|
||||
async def static_code_analysis(func: ValidationResponse) -> list[str]:
|
||||
"""
|
||||
Run static code analysis on the function code and mutate the function code to
|
||||
fix any issues.
|
||||
Args:
|
||||
func (ValidationResponse):
|
||||
The function to run static code analysis on. `func` will be mutated.
|
||||
Returns:
|
||||
list[str]: The list of validation errors
|
||||
"""
|
||||
validation_errors = []
|
||||
validation_errors += await __execute_ruff(func)
|
||||
validation_errors += await __execute_pyright(func)
|
||||
|
||||
return validation_errors
|
||||
|
||||
|
||||
CODE_SEPARATOR = "#------Code-Start------#"
|
||||
|
||||
|
||||
def __pack_import_and_function_code(func: ValidationResponse) -> str:
|
||||
return "\n".join(func.imports + [CODE_SEPARATOR, func.rawCode])
|
||||
|
||||
|
||||
def __unpack_import_and_function_code(code: str) -> tuple[list[str], str]:
|
||||
split = code.split(CODE_SEPARATOR)
|
||||
return split[0].splitlines(), split[1].strip()
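# Round-trip sketch of the pack/unpack helpers above, using plain strings in place
# of a ValidationResponse (all names prefixed with "_" here are illustrative only):
_sep = "#------Code-Start------#"  # same value as CODE_SEPARATOR above
_imports = ["import json", "from typing import Any"]
_raw = "def handler(x: Any) -> str:\n    return json.dumps(x)"
_packed = "\n".join(_imports + [_sep, _raw])
_head, _tail = _packed.split(_sep)
assert _head.splitlines() == _imports
assert _tail.strip() == _raw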
|
||||
|
||||
|
||||
async def __execute_ruff(func: ValidationResponse) -> list[str]:
|
||||
code = __pack_import_and_function_code(func)
|
||||
|
||||
try:
|
||||
# Currently Disabled Rule List
|
||||
# E402 module level import not at top of file
|
||||
# F841 local variable is assigned to but never used
|
||||
code = await exec_external_on_contents(
|
||||
command_arguments=[
|
||||
"ruff",
|
||||
"check",
|
||||
"--fix",
|
||||
"--ignore",
|
||||
"F841",
|
||||
"--ignore",
|
||||
"E402",
|
||||
"--ignore",
|
||||
"F811", # Redefinition of unused '...' from line ...
|
||||
],
|
||||
file_contents=code,
|
||||
suffix=".py",
|
||||
raise_file_contents_on_error=True,
|
||||
)
|
||||
func.imports, func.rawCode = __unpack_import_and_function_code(code)
|
||||
return []
|
||||
|
||||
except ExecError as e:
|
||||
if e.content:
|
||||
# Ruff failed, but the code is reformatted
|
||||
code = e.content
|
||||
e = str(e)
|
||||
|
||||
error_messages = [
|
||||
v
|
||||
for v in str(e).split("\n")
|
||||
if v.strip()
|
||||
if re.match(r"Found \d+ errors?\.*", v) is None
|
||||
]
|
||||
|
||||
added_imports, error_messages = await __fix_missing_imports(
|
||||
error_messages, func
|
||||
)
|
||||
|
||||
# Append problematic line to the error message or add it as TODO line
|
||||
validation_errors: list[str] = []
|
||||
split_pattern = r"(.+):(\d+):(\d+): (.+)"
|
||||
for error_message in error_messages:
|
||||
error_split = re.match(split_pattern, error_message)
|
||||
|
||||
if not error_split:
|
||||
error = error_message
|
||||
else:
|
||||
_, line, _, error = error_split.groups()
|
||||
error = genererate_line_error(error, code, int(line))
|
||||
|
||||
validation_errors.append(error)
|
||||
|
||||
func.imports, func.rawCode = __unpack_import_and_function_code(code)
|
||||
func.imports.extend(added_imports) # Avoid line-code change, do it at the end.
|
||||
|
||||
return validation_errors
|
||||
|
||||
|
||||
async def __execute_pyright(func: ValidationResponse) -> list[str]:
|
||||
code = __pack_import_and_function_code(func)
|
||||
validation_errors: list[str] = []
|
||||
|
||||
# Create temporary directory under the TEMP_DIR with random name
|
||||
temp_dir = PROJECT_TEMP_DIR / (func.function_name)
|
||||
py_path = await setup_if_required(temp_dir)
|
||||
|
||||
async def __execute_pyright_commands(code: str) -> list[str]:
|
||||
try:
|
||||
await execute_command(
|
||||
["pip", "install", "-r", "requirements.txt"], temp_dir, py_path
|
||||
)
|
||||
except Exception as e:
|
||||
# Unknown deps should be reported as validation errors
|
||||
validation_errors.append(e.__str__())
|
||||
|
||||
# execute pyright
|
||||
result = await execute_command(
|
||||
["pyright", "--outputjson"], temp_dir, py_path, raise_on_error=False
|
||||
)
|
||||
if not result:
|
||||
return []
|
||||
|
||||
try:
|
||||
json_response = json.loads(result)["generalDiagnostics"]
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing pyright output, error: {e} output: {result}")
|
||||
raise e
|
||||
|
||||
for e in json_response:
|
||||
rule: str = e.get("rule", "")
|
||||
severity: str = e.get("severity", "")
|
||||
excluded_rules = ["reportRedeclaration"]
|
||||
if severity != "error" or any([rule.startswith(r) for r in excluded_rules]):
|
||||
continue
|
||||
|
||||
e = genererate_line_error(
|
||||
error=f"{e['message']}. {e.get('rule', '')}",
|
||||
code=code,
|
||||
line_number=e["range"]["start"]["line"] + 1,
|
||||
)
|
||||
validation_errors.append(e)
|
||||
|
||||
# read code from code.py. split the code into imports and raw code
|
||||
code = open(f"{temp_dir}/code.py").read()
|
||||
code, error_messages = await __fix_async_calls(code, validation_errors)
|
||||
func.imports, func.rawCode = __unpack_import_and_function_code(code)
|
||||
|
||||
return validation_errors
|
||||
|
||||
packages = "\n".join([str(p) for p in func.packages if p not in DEFAULT_DEPS])
|
||||
(temp_dir / "requirements.txt").write_text(packages)
|
||||
(temp_dir / "code.py").write_text(code)
|
||||
|
||||
return await __execute_pyright_commands(code)
|
||||
|
||||
|
||||
async def find_module_dist_and_source(
|
||||
module: str, py_path: pathlib.Path | str
|
||||
) -> typing.Tuple[pathlib.Path | None, pathlib.Path | None]:
|
||||
# Find the module in the env
|
||||
modules_path = pathlib.Path(py_path).parent / "lib" / "python3.11" / "site-packages"
|
||||
matches = modules_path.glob(f"{module}*")
|
||||
|
||||
# resolve the generator to an array
|
||||
matches = list(matches)
|
||||
if not matches:
|
||||
return None, None
|
||||
|
||||
# find the dist info path and the module path
|
||||
dist_info_path: typing.Optional[pathlib.Path] = None
|
||||
module_path: typing.Optional[pathlib.Path] = None
|
||||
|
||||
# find the dist info path
|
||||
for match in matches:
|
||||
if re.match(f"{module}-[0-9]+.[0-9]+.[0-9]+.dist-info", match.name):
|
||||
dist_info_path = match
|
||||
break
|
||||
# Get the module path
|
||||
for match in matches:
|
||||
if module == match.name:
|
||||
module_path = match
|
||||
break
|
||||
|
||||
return dist_info_path, module_path
|
||||
|
||||
|
||||
AUTO_IMPORT_TYPES: dict[str, str] = {
|
||||
"Enum": "from enum import Enum",
|
||||
"array": "from array import array",
|
||||
}
|
||||
for t in typing.__all__:
|
||||
AUTO_IMPORT_TYPES[t] = f"from typing import {t}"
|
||||
for t in datetime.__all__:
|
||||
AUTO_IMPORT_TYPES[t] = f"from datetime import {t}"
|
||||
for t in collections.__all__:
|
||||
AUTO_IMPORT_TYPES[t] = f"from collections import {t}"
|
||||
|
||||
|
||||
async def __fix_async_calls(code: str, errors: list[str]) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Fix the async calls in the code
|
||||
Args:
|
||||
code (str): The code snippet
|
||||
errors (list[str]): The list of errors
|
||||
func (ValidationResponse): The function to fix the async calls
|
||||
Returns:
|
||||
tuple[str, list[str]]: The fixed code snippet and the list of errors
|
||||
"""
|
||||
async_calls = set()
|
||||
new_errors = []
|
||||
for error in errors:
|
||||
pattern = '"__await__" is not present. reportGeneralTypeIssues -> (.+)'
|
||||
match = re.search(pattern, error)
|
||||
if match:
|
||||
async_calls.add(match.group(1))
|
||||
else:
|
||||
new_errors.append(error)
|
||||
|
||||
for async_call in async_calls:
|
||||
func_call = re.search(r"await ([a-zA-Z0-9_]+)", async_call)
|
||||
if func_call:
|
||||
func_name = func_call.group(1)
|
||||
code = code.replace(f"await {func_name}", f"{func_name}")
|
||||
|
||||
return code, new_errors
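# Self-contained sketch (illustrative error text and names) of the await-stripping
# step above: pyright reports that the awaited value has no __await__, so the
# spurious "await" is removed from the call site and the error is dropped.
import re

_error = '"__await__" is not present. reportGeneralTypeIssues -> await fetch_data(x)'
_code = "result = await fetch_data(x)"
if _m := re.search(r"await ([a-zA-Z0-9_]+)", _error):
    _code = _code.replace(f"await {_m.group(1)}", _m.group(1))
assert _code == "result = fetch_data(x)"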
|
||||
|
||||
|
||||
async def __fix_missing_imports(
|
||||
errors: list[str], func: ValidationResponse
|
||||
) -> tuple[set[str], list[str]]:
|
||||
"""
|
||||
Generate missing imports based on the errors
|
||||
Args:
|
||||
errors (list[str]): The list of errors
|
||||
func (ValidationResponse): The function to fix the imports
|
||||
Returns:
|
||||
tuple[set[str], list[str]]: The set of missing imports and the list
|
||||
of non-missing import errors
|
||||
"""
|
||||
missing_imports = []
|
||||
filtered_errors = []
|
||||
for error in errors:
|
||||
pattern = r"Undefined name `(.+?)`"
|
||||
match = re.search(pattern, error)
|
||||
if not match:
|
||||
filtered_errors.append(error)
|
||||
continue
|
||||
|
||||
missing = match.group(1)
|
||||
if missing in AUTO_IMPORT_TYPES:
|
||||
missing_imports.append(AUTO_IMPORT_TYPES[missing])
|
||||
elif missing in func.available_functions:
|
||||
# TODO: FIX THIS!! Import the correct AutoGPT service.
|
||||
missing_imports.append(f"from project.{missing}_service import {missing}")
|
||||
elif missing in func.available_objects:
|
||||
# TODO: FIX THIS!! Import the correct AutoGPT service.
|
||||
missing_imports.append(f"from project.{missing}_object import {missing}")
|
||||
else:
|
||||
filtered_errors.append(error)
|
||||
|
||||
return set(missing_imports), filtered_errors
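# Sketch of the missing-import resolution above (the error string is an illustrative
# ruff F821 message): the undefined name is extracted and then looked up in
# AUTO_IMPORT_TYPES or in the known project functions/objects.
import re

_error = "code.py:3:12: F821 Undefined name `Optional`"
_match = re.search(r"Undefined name `(.+?)`", _error)
assert _match is not None and _match.group(1) == "Optional"
# AUTO_IMPORT_TYPES["Optional"] then supplies "from typing import Optional".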
|
||||
195 forge/forge/utils/function/exec.py Normal file
@@ -0,0 +1,195 @@
|
||||
import asyncio
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from asyncio.subprocess import Process
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputType(enum.Enum):
|
||||
STD_OUT = "stdout"
|
||||
STD_ERR = "stderr"
|
||||
BOTH = "both"
|
||||
|
||||
|
||||
class ExecError(Exception):
|
||||
content: str | None
|
||||
|
||||
def __init__(self, error: str, content: str | None = None):
|
||||
super().__init__(error)
|
||||
self.content = content
|
||||
|
||||
|
||||
async def exec_external_on_contents(
|
||||
command_arguments: list[str],
|
||||
file_contents,
|
||||
suffix: str = ".py",
|
||||
output_type: OutputType = OutputType.BOTH,
|
||||
raise_file_contents_on_error: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Execute an external tool with the provided command arguments and file contents
|
||||
:param command_arguments: The command arguments to execute
|
||||
:param file_contents: The file contents to execute the command on
|
||||
:param suffix: The suffix of the temporary file. Default is ".py"
|
||||
:return: The file contents after the command has been executed
|
||||
|
||||
Note: The file contents are written to a temporary file and the command is executed
|
||||
on that file. The command arguments should be a list of strings, where the first
|
||||
element is the command to execute and the rest of the elements are the arguments to
|
||||
the command. There is no need to provide the file path as an argument, as it will
|
||||
be appended to the command arguments.
|
||||
|
||||
Example:
|
||||
exec_external_on_contents(["ruff", "check"], "print('Hello World')")
|
||||
will run the command "ruff check <temp_file_path>" with the file contents
|
||||
"print('Hello World')" and return the file contents after the command
|
||||
has been executed.
|
||||
|
||||
"""
|
||||
errors = ""
|
||||
if len(command_arguments) == 0:
|
||||
raise ExecError("No command arguments provided")
|
||||
|
||||
# Run ruff to validate the code
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.write(file_contents.encode("utf-8"))
|
||||
temp_file.flush()
|
||||
|
||||
command_arguments.append(str(temp_file_path))
|
||||
|
||||
# Run Ruff on the temporary file
|
||||
try:
|
||||
r: Process = await asyncio.create_subprocess_exec(
|
||||
*command_arguments,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
result = await r.communicate()
|
||||
stdout, stderr = result[0].decode("utf-8"), result[1].decode("utf-8")
|
||||
logger.debug(f"Output: {stdout}")
|
||||
if temp_file_path in stdout:
|
||||
stdout = stdout # .replace(temp_file.name, "/generated_file")
|
||||
logger.debug(f"Errors: {stderr}")
|
||||
if output_type == OutputType.STD_OUT:
|
||||
errors = stdout
|
||||
elif output_type == OutputType.STD_ERR:
|
||||
errors = stderr
|
||||
else:
|
||||
errors = stdout + "\n" + stderr
|
||||
|
||||
with open(temp_file_path, "r") as f:
|
||||
file_contents = f.read()
|
||||
finally:
|
||||
# Ensure the temporary file is deleted
|
||||
os.remove(temp_file_path)
|
||||
|
||||
if not errors:
|
||||
return file_contents
|
||||
|
||||
if raise_file_contents_on_error:
|
||||
raise ExecError(errors, file_contents)
|
||||
|
||||
raise ExecError(errors)
|
||||
|
||||
|
||||
FOLDER_NAME = "agpt-static-code-analysis"
|
||||
PROJECT_PARENT_DIR = Path(__file__).resolve().parent.parent.parent / f".{FOLDER_NAME}"
|
||||
PROJECT_TEMP_DIR = Path(tempfile.gettempdir()) / FOLDER_NAME
|
||||
DEFAULT_DEPS = ["pyright", "pydantic", "virtualenv-clone"]
|
||||
|
||||
|
||||
def is_env_exists(path: Path):
|
||||
return (
|
||||
(path / "venv/bin/python").exists()
|
||||
and (path / "venv/bin/pip").exists()
|
||||
and (path / "venv/bin/virtualenv-clone").exists()
|
||||
and (path / "venv/bin/pyright").exists()
|
||||
)
|
||||
|
||||
|
||||
async def setup_if_required(
|
||||
cwd: Path = PROJECT_PARENT_DIR, copy_from_parent: bool = True
|
||||
) -> Path:
|
||||
"""
|
||||
Set up the virtual environment if it does not exist.
This setup is expected to run only once per application run.
|
||||
Args:
|
||||
cwd (Path): The current working directory
|
||||
copy_from_parent (bool):
|
||||
Whether to copy the virtual environment from PROJECT_PARENT_DIR
|
||||
Returns:
|
||||
Path: The path to the virtual environment
|
||||
"""
|
||||
if not cwd.exists():
|
||||
cwd.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
path = cwd / "venv/bin"
|
||||
if is_env_exists(cwd):
|
||||
return path
|
||||
|
||||
if copy_from_parent and cwd != PROJECT_PARENT_DIR:
|
||||
if (cwd / "venv").exists():
|
||||
await execute_command(["rm", "-rf", str(cwd / "venv")], cwd, None)
|
||||
await execute_command(
|
||||
["virtualenv-clone", str(PROJECT_PARENT_DIR / "venv"), str(cwd / "venv")],
|
||||
cwd,
|
||||
await setup_if_required(PROJECT_PARENT_DIR),
|
||||
)
|
||||
return path
|
||||
|
||||
# Create a virtual environment
|
||||
output = await execute_command(["python", "-m", "venv", "venv"], cwd, None)
|
||||
logger.info(f"[Setup] Created virtual environment: {output}")
|
||||
|
||||
# Install dependencies
|
||||
output = await execute_command(["pip", "install", "-I"] + DEFAULT_DEPS, cwd, path)
|
||||
logger.info(f"[Setup] Installed {DEFAULT_DEPS}: {output}")
|
||||
|
||||
output = await execute_command(["pyright"], cwd, path, raise_on_error=False)
|
||||
logger.info(f"[Setup] Set up pyright: {output}")
|
||||
|
||||
return path
|
||||
|
||||
|
||||
async def execute_command(
|
||||
command: list[str],
|
||||
cwd: str | Path | None,
|
||||
python_path: str | Path | None = None,
|
||||
raise_on_error: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Execute a command in the shell
|
||||
Args:
|
||||
command (list[str]): The command to execute
|
||||
cwd (str | Path): The current working directory
|
||||
python_path (str | Path): The python executable path
|
||||
raise_on_error (bool): Whether to raise an error if the command fails
|
||||
Returns:
|
||||
str: The output of the command
|
||||
"""
|
||||
# Set the python path by replacing the env 'PATH' with the provided python path
|
||||
venv = os.environ.copy()
|
||||
if python_path:
|
||||
# PATH resolution picks the first matching entry, so prepend python_path.
|
||||
venv["PATH"] = f"{python_path}:{venv['PATH']}"
|
||||
r = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=str(cwd),
|
||||
env=venv,
|
||||
)
|
||||
stdout, stderr = await r.communicate()
|
||||
if r.returncode == 0:
|
||||
return (stdout or stderr).decode("utf-8")
|
||||
|
||||
if raise_on_error:
|
||||
raise Exception((stderr or stdout).decode("utf-8"))
|
||||
else:
|
||||
return (stderr or stdout).decode("utf-8")
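# Usage sketch for the helpers above (commands and paths are illustrative; running
# this requires the tools installed in the bootstrapped venv):
#
#   import asyncio
#
#   async def _demo() -> None:
#       py_path = await setup_if_required()  # create or reuse the shared venv
#       version = await execute_command(["pyright", "--version"], PROJECT_PARENT_DIR, py_path)
#       checked = await exec_external_on_contents(
#           ["ruff", "check", "--fix"], "import os\n\nprint(os.getcwd())\n"
#       )
#       print(version, checked)
#
#   asyncio.run(_demo())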
|
||||
110 forge/forge/utils/function/model.py Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ObjectType(BaseModel):
|
||||
name: str = Field(description="The name of the object")
|
||||
code: Optional[str] = Field(description="The code of the object", default=None)
|
||||
description: Optional[str] = Field(
|
||||
description="The description of the object", default=None
|
||||
)
|
||||
Fields: List["ObjectField"] = Field(description="The fields of the object")
|
||||
is_pydantic: bool = Field(
|
||||
description="Whether the object is a pydantic model", default=True
|
||||
)
|
||||
is_implemented: bool = Field(
|
||||
description="Whether the object is implemented", default=True
|
||||
)
|
||||
is_enum: bool = Field(description="Whether the object is an enum", default=False)
|
||||
|
||||
|
||||
class ObjectField(BaseModel):
|
||||
name: str = Field(description="The name of the field")
|
||||
description: Optional[str] = Field(
|
||||
description="The description of the field", default=None
|
||||
)
|
||||
type: str = Field(
|
||||
description="The type of the field. Can be a string like List[str] or an use "
|
||||
"any of they related types like list[User]",
|
||||
)
|
||||
value: Optional[str] = Field(description="The value of the field", default=None)
|
||||
related_types: Optional[List[ObjectType]] = Field(
|
||||
description="The related types of the field", default=[]
|
||||
)
|
||||
|
||||
|
||||
class FunctionDef(BaseModel):
|
||||
name: str
|
||||
arg_types: list[tuple[str, str]]
|
||||
arg_defaults: dict[str, str] = {}
|
||||
arg_descs: dict[str, str]
|
||||
return_type: str | None = None
|
||||
return_desc: str
|
||||
function_desc: str
|
||||
is_implemented: bool = False
|
||||
function_code: str = ""
|
||||
function_template: str | None = None
|
||||
is_async: bool = False
|
||||
|
||||
def __generate_function_template(self) -> str:
|
||||
args_str = ", ".join(
|
||||
[
|
||||
f"{name}: {type}"
|
||||
+ (
|
||||
f" = {self.arg_defaults.get(name, '')}"
|
||||
if name in self.arg_defaults
|
||||
else ""
|
||||
)
|
||||
for name, type in self.arg_types
|
||||
]
|
||||
)
|
||||
arg_desc = f"\n{' '*4}".join(
|
||||
[
|
||||
f'{name} ({type}): {self.arg_descs.get(name, "-")}'
|
||||
for name, type in self.arg_types
|
||||
]
|
||||
)
|
||||
|
||||
_def = "async def" if "await " in self.function_code or self.is_async else "def"
|
||||
_return_type = f" -> {self.return_type}" if self.return_type else ""
|
||||
func_desc = self.function_desc.replace("\n", "\n ")
|
||||
|
||||
template = f"""
|
||||
{_def} {self.name}({args_str}){_return_type}:
|
||||
\"\"\"
|
||||
{func_desc}
|
||||
|
||||
Args:
|
||||
{arg_desc}
|
||||
|
||||
Returns:
|
||||
{self.return_type}{': ' + self.return_desc if self.return_desc else ''}
|
||||
\"\"\"
|
||||
pass
|
||||
"""
|
||||
return "\n".join([line for line in template.split("\n")]).strip()
|
||||
|
||||
def __init__(self, function_template: Optional[str] = None, **data):
|
||||
super().__init__(**data)
|
||||
self.function_template = (
|
||||
function_template or self.__generate_function_template()
|
||||
)
|
||||
|
||||
|
||||
class ValidationResponse(BaseModel):
|
||||
function_name: str
|
||||
available_objects: dict[str, ObjectType]
|
||||
available_functions: dict[str, FunctionDef]
|
||||
|
||||
template: str
|
||||
rawCode: str
|
||||
packages: List[str]
|
||||
imports: List[str]
|
||||
functionCode: str
|
||||
|
||||
functions: List[FunctionDef]
|
||||
objects: List[ObjectType]
|
||||
|
||||
def get_compiled_code(self) -> str:
|
||||
return "\n".join(self.imports) + "\n\n" + self.rawCode.strip()
|
||||
292 forge/forge/utils/function/util.py Normal file
@@ -0,0 +1,292 @@
|
||||
from typing import List, Tuple, __all__ as all_types
|
||||
from forge.utils.function.model import FunctionDef, ObjectType, ValidationResponse
|
||||
|
||||
OPEN_BRACES = {"{": "Dict", "[": "List", "(": "Tuple"}
|
||||
CLOSE_BRACES = {"}": "Dict", "]": "List", ")": "Tuple"}
|
||||
|
||||
RENAMED_TYPES = {
|
||||
"dict": "Dict",
|
||||
"list": "List",
|
||||
"tuple": "Tuple",
|
||||
"set": "Set",
|
||||
"frozenset": "FrozenSet",
|
||||
"type": "Type",
|
||||
}
|
||||
PYTHON_TYPES = set(all_types)
|
||||
|
||||
|
||||
def unwrap_object_type(type: str) -> Tuple[str, List[str]]:
|
||||
"""
|
||||
Get the type and children of a composite type.
|
||||
Args:
|
||||
type (str): The type to parse.
|
||||
Returns:
|
||||
str: The type.
|
||||
[str]: The children types.
|
||||
"""
|
||||
type = type.replace(" ", "")
|
||||
if not type:
|
||||
return "", []
|
||||
|
||||
def split_outer_level(type: str, separator: str) -> List[str]:
|
||||
brace_count = 0
|
||||
last_index = 0
|
||||
splits = []
|
||||
|
||||
for i, c in enumerate(type):
|
||||
if c in OPEN_BRACES:
|
||||
brace_count += 1
|
||||
elif c in CLOSE_BRACES:
|
||||
brace_count -= 1
|
||||
elif c == separator and brace_count == 0:
|
||||
splits.append(type[last_index:i])
|
||||
last_index = i + 1
|
||||
|
||||
splits.append(type[last_index:])
|
||||
return splits
|
||||
|
||||
# Unwrap primitive union types
|
||||
union_split = split_outer_level(type, "|")
|
||||
if len(union_split) > 1:
|
||||
if len(union_split) == 2 and "None" in union_split:
|
||||
return "Optional", [v for v in union_split if v != "None"]
|
||||
return "Union", union_split
|
||||
|
||||
# Unwrap primitive dict/list/tuple types
|
||||
if type[0] in OPEN_BRACES and type[-1] in CLOSE_BRACES:
|
||||
type_name = OPEN_BRACES[type[0]]
|
||||
type_children = split_outer_level(type[1:-1], ",")
|
||||
return type_name, type_children
|
||||
|
||||
brace_pos = type.find("[")
|
||||
if brace_pos != -1 and type[-1] == "]":
|
||||
# Unwrap normal composite types
|
||||
type_name = type[:brace_pos]
|
||||
type_children = split_outer_level(type[brace_pos + 1 : -1], ",")
|
||||
else:
|
||||
# Non-composite types, no need to unwrap
|
||||
type_name = type
|
||||
type_children = []
|
||||
|
||||
return RENAMED_TYPES.get(type_name, type_name), type_children
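# Illustrative results of unwrap_object_type, matching the logic above:
#   unwrap_object_type("dict[str, int | float]") -> ("Dict", ["str", "int|float"])
#   unwrap_object_type("str | None")             -> ("Optional", ["str"])
#   unwrap_object_type("(int, str)")             -> ("Tuple", ["int", "str"])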
|
||||
|
||||
|
||||
def is_type_equal(type1: str | None, type2: str | None) -> bool:
|
||||
"""
|
||||
Check if two types are equal.
|
||||
This function handles composite types like list, dict, and tuple, and
groups similar types like list[str], List[str], and [str] as equal.
|
||||
"""
|
||||
if type1 is None and type2 is None:
|
||||
return True
|
||||
if type1 is None or type2 is None:
|
||||
return False
|
||||
|
||||
evaluated_type1, children1 = unwrap_object_type(type1)
|
||||
evaluated_type2, children2 = unwrap_object_type(type2)
|
||||
|
||||
# Compare the class name of the types (ignoring the module)
|
||||
# TODO(majdyz): compare the module name as well.
|
||||
parts1, parts2 = evaluated_type1.split("."), evaluated_type2.split(".")
t_len = min(len(parts1), len(parts2))
if parts1[-t_len:] != parts2[-t_len:]:
    return False
|
||||
|
||||
if len(children1) != len(children2):
|
||||
return False
|
||||
|
||||
if len(children1) == len(children2) == 0:
|
||||
return True
|
||||
|
||||
for c1, c2 in zip(children1, children2):
|
||||
if not is_type_equal(c1, c2):
|
||||
return False
|
||||
|
||||
return True
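# Illustrative equivalences under is_type_equal: "list[str]" equals "List[str]"
# (both unwrap to ("List", ["str"])) and "dict[str, int]" equals "Dict[str,int]",
# while "list[str]" and "list[int]" differ in their child types.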
|
||||
|
||||
|
||||
def validate_matching_function(this: FunctionDef, that: FunctionDef):
|
||||
expected_args = that.arg_types
|
||||
expected_rets = that.return_type
|
||||
func_name = that.name
|
||||
errors = []
|
||||
|
||||
# Fix the async flag based on the expectation.
|
||||
if this.is_async != that.is_async:
|
||||
this.is_async = that.is_async
|
||||
if this.is_async and f"async def {this.name}" not in this.function_code:
|
||||
this.function_code = this.function_code.replace(
|
||||
f"def {this.name}", f"async def {this.name}"
|
||||
)
|
||||
if not this.is_async and f"async def {this.name}" in this.function_code:
|
||||
this.function_code = this.function_code.replace(
|
||||
f"async def {this.name}", f"def {this.name}"
|
||||
)
|
||||
|
||||
if any(
|
||||
[
|
||||
x[0] != y[0] or not is_type_equal(x[1], y[1]) and x[1] != "object"
|
||||
# TODO: remove sorted and provide a stable order for one-to-many arg-types.
|
||||
for x, y in zip(sorted(expected_args), sorted(this.arg_types))
|
||||
]
|
||||
):
|
||||
errors.append(
|
||||
f"Function {func_name} has different arguments than expected, "
|
||||
f"expected {expected_args} but got {this.arg_types}"
|
||||
)
|
||||
if not is_type_equal(expected_rets, this.return_type) and expected_rets != "object":
|
||||
errors.append(
|
||||
f"Function {func_name} has different return type than expected, expected "
|
||||
f"{expected_rets} but got {this.return_type}"
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception("Signature validation errors:\n " + "\n ".join(errors))
|
||||
|
||||
|
||||
def normalize_type(type: str, renamed_types: dict[str, str] = {}) -> str:
|
||||
"""
|
||||
Normalize the type to a standard format.
|
||||
e.g. list[str] -> List[str], dict[str, int | float] -> Dict[str, Union[int, float]]
|
||||
|
||||
Args:
|
||||
type (str): The type to normalize.
|
||||
Returns:
|
||||
str: The normalized type.
|
||||
"""
|
||||
parent_type, children = unwrap_object_type(type)
|
||||
|
||||
if parent_type in renamed_types:
|
||||
parent_type = renamed_types[parent_type]
|
||||
|
||||
if len(children) == 0:
|
||||
return parent_type
|
||||
|
||||
content_type = ", ".join([normalize_type(c, renamed_types) for c in children])
|
||||
return f"{parent_type}[{content_type}]"
|
||||
|
||||
|
||||
def generate_object_code(obj: ObjectType) -> str:
|
||||
if not obj.name:
|
||||
return "" # Avoid generating an empty object
|
||||
|
||||
# Auto-generate a template for the object, this will not capture any class functions
|
||||
fields = f"\n{' ' * 4}".join(
|
||||
[
|
||||
f"{field.name}: {field.type} "
|
||||
f"{('= '+field.value) if field.value else ''} "
|
||||
f"{('# '+field.description) if field.description else ''}"
|
||||
for field in obj.Fields or []
|
||||
]
|
||||
)
|
||||
|
||||
parent_class = ""
|
||||
if obj.is_enum:
|
||||
parent_class = "Enum"
|
||||
elif obj.is_pydantic:
|
||||
parent_class = "BaseModel"
|
||||
|
||||
doc_string = (
|
||||
f"""\"\"\"
|
||||
{obj.description}
|
||||
\"\"\""""
|
||||
if obj.description
|
||||
else ""
|
||||
)
|
||||
|
||||
method_body = ("\n" + " " * 4).join(obj.code.split("\n")) + "\n" if obj.code else ""
|
||||
|
||||
template = f"""
|
||||
class {obj.name}({parent_class}):
|
||||
{doc_string if doc_string else ""}
|
||||
{fields if fields else ""}
|
||||
{method_body if method_body else ""}
|
||||
{"pass" if not fields and not method_body else ""}
|
||||
"""
|
||||
return "\n".join(line for line in template.split("\n")).strip()
|
||||
|
||||
|
||||
def genererate_line_error(error: str, code: str, line_number: int) -> str:
|
||||
lines = code.split("\n")
|
||||
if line_number > len(lines):
|
||||
return error
|
||||
|
||||
code_line = lines[line_number - 1]
|
||||
return f"{error} -> '{code_line.strip()}'"
|
||||
|
||||
|
||||
def generate_compiled_code(
|
||||
resp: ValidationResponse, add_code_stubs: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Regenerate imports & raw code using the available objects and functions.
|
||||
"""
|
||||
resp.imports = sorted(set(resp.imports))
|
||||
|
||||
def __append_comment(code_block: str, comment: str) -> str:
|
||||
"""
|
||||
Append the given comment to the first line of the code block,
e.g. "# noqa" to suppress flake8 warnings or "# type: ignore" for pyright.
|
||||
"""
|
||||
lines = code_block.split("\n")
|
||||
lines[0] = lines[0] + " # " + comment
|
||||
return "\n".join(lines)
|
||||
|
||||
def __generate_stub(name, is_enum):
|
||||
if not name:
|
||||
return ""
|
||||
elif is_enum:
|
||||
return f"class {name}(Enum):\n pass"
|
||||
else:
|
||||
return f"class {name}(BaseModel):\n pass"
|
||||
|
||||
stub_objects = resp.available_objects if add_code_stubs else {}
|
||||
stub_functions = resp.available_functions if add_code_stubs else {}
|
||||
|
||||
object_stubs_code = "\n\n".join(
|
||||
[
|
||||
__append_comment(__generate_stub(obj.name, obj.is_enum), "type: ignore")
|
||||
for obj in stub_objects.values()
|
||||
]
|
||||
+ [
|
||||
__append_comment(__generate_stub(obj.name, obj.is_enum), "type: ignore")
|
||||
for obj in resp.objects
|
||||
if obj.name not in stub_objects
|
||||
]
|
||||
)
|
||||
|
||||
objects_code = "\n\n".join(
|
||||
[
|
||||
__append_comment(generate_object_code(obj), "noqa")
|
||||
for obj in stub_objects.values()
|
||||
]
|
||||
+ [
|
||||
__append_comment(generate_object_code(obj), "noqa")
|
||||
for obj in resp.objects
|
||||
if obj.name not in stub_objects
|
||||
]
|
||||
)
|
||||
|
||||
functions_code = "\n\n".join(
|
||||
[
|
||||
__append_comment(f.function_template.strip(), "type: ignore")
|
||||
for f in stub_functions.values()
|
||||
if f.name != resp.function_name and f.function_template
|
||||
]
|
||||
+ [
|
||||
__append_comment(f.function_template.strip(), "type: ignore")
|
||||
for f in resp.functions
|
||||
if f.name not in stub_functions and f.function_template
|
||||
]
|
||||
)
|
||||
|
||||
resp.rawCode = (
|
||||
object_stubs_code.strip()
|
||||
+ "\n\n"
|
||||
+ objects_code.strip()
|
||||
+ "\n\n"
|
||||
+ functions_code.strip()
|
||||
+ "\n\n"
|
||||
+ resp.functionCode.strip()
|
||||
)
|
||||
|
||||
return resp.get_compiled_code()
|
||||
222 forge/forge/utils/function/visitor.py Normal file
@@ -0,0 +1,222 @@
|
||||
import ast
|
||||
import re
|
||||
|
||||
from forge.utils.function.model import FunctionDef, ObjectType, ObjectField
|
||||
from forge.utils.function.util import normalize_type, PYTHON_TYPES
|
||||
|
||||
|
||||
class FunctionVisitor(ast.NodeVisitor):
|
||||
"""
|
||||
Visits a Python AST and extracts function definitions and Pydantic class definitions
|
||||
|
||||
To use this class, create an instance and call the visit method with the AST
as the argument. The extracted function definitions and Pydantic class definitions
can be accessed from the functions and objects attributes, respectively.
|
||||
|
||||
Example:
|
||||
```
|
||||
visitor = FunctionVisitor()
|
||||
visitor.visit(ast.parse("def foo(x: int) -> int: return x"))
|
||||
print(visitor.functions)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.functions: list[FunctionDef] = []
|
||||
self.functionsIdx: list[int] = []
|
||||
self.objects: list[ObjectType] = []
|
||||
self.objectsIdx: list[int] = []
|
||||
self.globals: list[str] = []
|
||||
self.globalsIdx: list[int] = []
|
||||
self.imports: list[str] = []
|
||||
self.errors: list[str] = []
|
||||
|
||||
def visit_Import(self, node):
|
||||
for alias in node.names:
|
||||
import_line = f"import {alias.name}"
|
||||
if alias.asname:
|
||||
import_line += f" as {alias.asname}"
|
||||
self.imports.append(import_line)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
for alias in node.names:
|
||||
import_line = f"from {node.module} import {alias.name}"
|
||||
if alias.asname:
|
||||
import_line += f" as {alias.asname}"
|
||||
self.imports.append(import_line)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
|
||||
# treat async functions as normal functions
|
||||
self.visit_FunctionDef(node) # type: ignore
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
args = []
|
||||
for arg in node.args.args:
|
||||
arg_type = ast.unparse(arg.annotation) if arg.annotation else "object"
|
||||
args.append((arg.arg, normalize_type(arg_type)))
|
||||
return_type = (
|
||||
normalize_type(ast.unparse(node.returns)) if node.returns else None
|
||||
)
|
||||
|
||||
# Extract doc_string & function body
|
||||
if (
|
||||
node.body
|
||||
and isinstance(node.body[0], ast.Expr)
|
||||
and isinstance(node.body[0].value, ast.Constant)
|
||||
):
|
||||
doc_string = node.body[0].value.s.strip()
|
||||
template_body = [node.body[0], ast.Pass()]
|
||||
is_implemented = not isinstance(node.body[1], ast.Pass)
|
||||
else:
|
||||
doc_string = ""
|
||||
template_body = [ast.Pass()]
|
||||
is_implemented = not isinstance(node.body[0], ast.Pass)
|
||||
|
||||
# Construct function template
|
||||
original_body = node.body.copy()
|
||||
node.body = template_body # type: ignore
|
||||
function_template = ast.unparse(node)
|
||||
node.body = original_body
|
||||
|
||||
function_code = ast.unparse(node)
|
||||
if "await" in function_code and "async def" not in function_code:
|
||||
function_code = function_code.replace("def ", "async def ")
|
||||
function_template = function_template.replace("def ", "async def ")
|
||||
|
||||
def split_doc(keywords: list[str], doc: str) -> tuple[str, str]:
|
||||
for keyword in keywords:
|
||||
if match := re.search(f"{keyword}\\s?:", doc):
|
||||
return doc[: match.start()], doc[match.end() :]
|
||||
return doc, ""
|
||||
|
||||
# Decompose doc_pattern into func_doc, args_doc, rets_doc, errs_doc, usage_doc
|
||||
# by splitting in reverse order
|
||||
func_doc = doc_string
|
||||
func_doc, usage_doc = split_doc(
|
||||
["Ex", "Usage", "Usages", "Example", "Examples"], func_doc
|
||||
)
|
||||
func_doc, errs_doc = split_doc(["Error", "Errors", "Raise", "Raises"], func_doc)
|
||||
func_doc, rets_doc = split_doc(["Return", "Returns"], func_doc)
|
||||
func_doc, args_doc = split_doc(
|
||||
["Arg", "Args", "Argument", "Arguments"], func_doc
|
||||
)
|
||||
|
||||
# Extract Func
|
||||
function_desc = func_doc.strip()
|
||||
|
||||
# Extract Args
|
||||
args_descs = {}
|
||||
split_pattern = r"\n(\s+.+):"
|
||||
for match in reversed(list(re.finditer(split_pattern, string=args_doc))):
|
||||
arg = match.group(1).strip().split(" ")[0]
|
||||
desc = args_doc.rsplit(match.group(1), 1)[1].strip(": ")
|
||||
args_descs[arg] = desc.strip()
|
||||
args_doc = args_doc[: match.start()]
|
||||
|
||||
# Extract Returns
|
||||
return_desc = ""
|
||||
if match := re.match(split_pattern, string=rets_doc):
|
||||
return_desc = rets_doc[match.end() :].strip()
|
||||
|
||||
self.functions.append(
|
||||
FunctionDef(
|
||||
name=node.name,
|
||||
arg_types=args,
|
||||
arg_descs=args_descs,
|
||||
return_type=return_type,
|
||||
return_desc=return_desc,
|
||||
is_implemented=is_implemented,
|
||||
function_desc=function_desc,
|
||||
function_template=function_template,
|
||||
function_code=function_code,
|
||||
)
|
||||
)
|
||||
self.functionsIdx.append(node.lineno)
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
||||
"""
|
||||
Visits a ClassDef node in the AST and checks if it is a Pydantic class.
|
||||
If it is a Pydantic class, adds its name to the list of Pydantic classes.
|
||||
"""
|
||||
is_pydantic = any(
|
||||
[
|
||||
(isinstance(base, ast.Name) and base.id == "BaseModel")
|
||||
or (isinstance(base, ast.Attribute) and base.attr == "BaseModel")
|
||||
for base in node.bases
|
||||
]
|
||||
)
|
||||
is_enum = any(
|
||||
[
|
||||
(isinstance(base, ast.Name) and base.id.endswith("Enum"))
|
||||
or (isinstance(base, ast.Attribute) and base.attr.endswith("Enum"))
|
||||
for base in node.bases
|
||||
]
|
||||
)
|
||||
is_implemented = not any(isinstance(v, ast.Pass) for v in node.body)
|
||||
doc_string = ""
|
||||
if (
|
||||
node.body
|
||||
and isinstance(node.body[0], ast.Expr)
|
||||
and isinstance(node.body[0].value, ast.Constant)
|
||||
):
|
||||
doc_string = node.body[0].value.s.strip()
|
||||
|
||||
if node.name in PYTHON_TYPES:
|
||||
self.errors.append(
|
||||
f"Can't declare class with a Python built-in name "
|
||||
f"`{node.name}`. Please use a different name."
|
||||
)
|
||||
|
||||
fields = []
|
||||
methods = []
|
||||
for v in node.body:
|
||||
if isinstance(v, ast.AnnAssign):
|
||||
field = ObjectField(
|
||||
name=ast.unparse(v.target),
|
||||
type=normalize_type(ast.unparse(v.annotation)),
|
||||
value=ast.unparse(v.value) if v.value else None,
|
||||
)
|
||||
if field.value is None and field.type.startswith("Optional"):
|
||||
field.value = "None"
|
||||
elif isinstance(v, ast.Assign):
|
||||
if len(v.targets) > 1:
|
||||
self.errors.append(
|
||||
f"Class {node.name} has multiple assignments in a single line."
|
||||
)
|
||||
field = ObjectField(
|
||||
name=ast.unparse(v.targets[0]),
|
||||
type=type(ast.unparse(v.value)).__name__,
|
||||
value=ast.unparse(v.value) if v.value else None,
|
||||
)
|
||||
elif isinstance(v, ast.Expr) and isinstance(v.value, ast.Constant):
|
||||
# skip comments and docstrings
|
||||
continue
|
||||
else:
|
||||
methods.append(ast.unparse(v))
|
||||
continue
|
||||
fields.append(field)
|
||||
|
||||
self.objects.append(
|
||||
ObjectType(
|
||||
name=node.name,
|
||||
code="\n".join(methods),
|
||||
description=doc_string,
|
||||
Fields=fields,
|
||||
is_pydantic=is_pydantic,
|
||||
is_enum=is_enum,
|
||||
is_implemented=is_implemented,
|
||||
)
|
||||
)
|
||||
self.objectsIdx.append(node.lineno)
|
||||
|
||||
def visit(self, node):
|
||||
if (
|
||||
isinstance(node, ast.Assign)
|
||||
or isinstance(node, ast.AnnAssign)
|
||||
or isinstance(node, ast.AugAssign)
|
||||
) and node.col_offset == 0:
|
||||
self.globals.append(ast.unparse(node))
|
||||
self.globalsIdx.append(node.lineno)
|
||||
super().visit(node)
|
||||