mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Fix Await fiasco
This commit is contained in:
@@ -241,6 +241,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
# Get commands
|
||||
self.commands = await self.run_pipeline(CommandProvider.get_commands)
|
||||
self._remove_disabled_commands()
|
||||
self.code_flow_executor.set_available_functions(self.commands)
|
||||
|
||||
try:
|
||||
return_value = await self._execute_tool(tool)
|
||||
|
||||
@@ -4,8 +4,7 @@ from logging import Logger
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptConfiguration, AssistantThoughts
|
||||
from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptConfiguration, AssistantThoughts, OneShotAgentActionProposal
|
||||
from autogpt.config.ai_directives import AIDirectives
|
||||
from autogpt.config.ai_profile import AIProfile
|
||||
from autogpt.core.configuration.schema import SystemConfiguration
|
||||
@@ -59,7 +58,8 @@ FINAL_INSTRUCTION: str = (
|
||||
"arguments can't be determined yet. Reduce the amount of unnecessary data passed into "
|
||||
"these magic functions where possible, because magic costs money and magically "
|
||||
"processing large amounts of data is expensive. If you think are done with the task, "
|
||||
"you can simply call finish(reason='your reason') to end the task. "
|
||||
"you can simply call finish(reason='your reason') to end the task, "
|
||||
"a function that has one `finish` command, don't mix finish with other functions. "
|
||||
)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
async def parse_response_content(
|
||||
self,
|
||||
response: AssistantChatMessage,
|
||||
) -> BaseAgentActionProposal:
|
||||
) -> OneShotAgentActionProposal:
|
||||
if not response.content:
|
||||
raise InvalidAgentResponseError("Assistant response has no text content")
|
||||
|
||||
@@ -210,6 +210,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
name=f.name,
|
||||
arg_types=[(name, p.python_type) for name, p in f.parameters.items()],
|
||||
arg_descs={name: p.description for name, p in f.parameters.items()},
|
||||
arg_defaults={name: p.default or "None" for name, p in f.parameters.items() if p.default or not p.required},
|
||||
return_type="str",
|
||||
return_desc="Output of the function",
|
||||
function_desc=f.description,
|
||||
@@ -235,14 +236,24 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
available_functions=available_functions,
|
||||
).validate_code(parsed_response.python_code)
|
||||
|
||||
result = BaseAgentActionProposal(
|
||||
thoughts=parsed_response.thoughts,
|
||||
use_tool=AssistantFunctionCall(
|
||||
name="execute_code_flow",
|
||||
arguments={
|
||||
"python_code": code_validation.functionCode,
|
||||
"plan_text": parsed_response.immediate_plan,
|
||||
},
|
||||
),
|
||||
)
|
||||
if re.search(r"finish\((.*?)\)", code_validation.functionCode):
|
||||
finish_reason = re.search(r"finish\((reason=)?(.*?)\)", code_validation.functionCode).group(2)
|
||||
result = OneShotAgentActionProposal(
|
||||
thoughts=parsed_response.thoughts,
|
||||
use_tool=AssistantFunctionCall(
|
||||
name="finish",
|
||||
arguments={"reason": finish_reason[1:-1]},
|
||||
),
|
||||
)
|
||||
else:
|
||||
result = OneShotAgentActionProposal(
|
||||
thoughts=parsed_response.thoughts,
|
||||
use_tool=AssistantFunctionCall(
|
||||
name="execute_code_flow",
|
||||
arguments={
|
||||
"python_code": code_validation.functionCode,
|
||||
"plan_text": parsed_response.immediate_plan,
|
||||
},
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -162,7 +162,6 @@ class CodeValidator:
|
||||
validation_errors=validation_errors,
|
||||
)
|
||||
function_template = main_func.function_template
|
||||
function_code = main_func.function_code
|
||||
else:
|
||||
function_template = None
|
||||
|
||||
@@ -397,6 +396,7 @@ async def __execute_pyright(func: ValidationResponse) -> list[str]:
|
||||
|
||||
# read code from code.py. split the code into imports and raw code
|
||||
code = open(f"{temp_dir}/code.py").read()
|
||||
code, error_messages = await __fix_async_calls(code, validation_errors)
|
||||
func.imports, func.rawCode = __unpack_import_and_function_code(code)
|
||||
|
||||
return validation_errors
|
||||
@@ -450,6 +450,35 @@ for t in collections.__all__:
|
||||
AUTO_IMPORT_TYPES[t] = f"from collections import {t}"
|
||||
|
||||
|
||||
async def __fix_async_calls(code: str, errors: list[str]) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Fix the async calls in the code
|
||||
Args:
|
||||
code (str): The code snippet
|
||||
errors (list[str]): The list of errors
|
||||
func (ValidationResponse): The function to fix the async calls
|
||||
Returns:
|
||||
tuple[str, list[str]]: The fixed code snippet and the list of errors
|
||||
"""
|
||||
async_calls = set()
|
||||
new_errors = []
|
||||
for error in errors:
|
||||
pattern = '"__await__" is not present. reportGeneralTypeIssues -> (.+)'
|
||||
match = re.search(pattern, error)
|
||||
if match:
|
||||
async_calls.add(match.group(1))
|
||||
else:
|
||||
new_errors.append(error)
|
||||
|
||||
for async_call in async_calls:
|
||||
func_call = re.search(r"await ([a-zA-Z0-9_]+)", async_call)
|
||||
if func_call:
|
||||
func_name = func_call.group(1)
|
||||
code = code.replace(f"await {func_name}", f"{func_name}")
|
||||
|
||||
return code, new_errors
|
||||
|
||||
|
||||
async def __fix_missing_imports(
|
||||
errors: list[str], func: ValidationResponse
|
||||
) -> tuple[set[str], list[str]]:
|
||||
|
||||
@@ -36,6 +36,7 @@ class ObjectField(BaseModel):
|
||||
class FunctionDef(BaseModel):
|
||||
name: str
|
||||
arg_types: list[tuple[str, str]]
|
||||
arg_defaults: dict[str, str] = {}
|
||||
arg_descs: dict[str, str]
|
||||
return_type: str | None = None
|
||||
return_desc: str
|
||||
@@ -46,7 +47,10 @@ class FunctionDef(BaseModel):
|
||||
is_async: bool = False
|
||||
|
||||
def __generate_function_template(f) -> str:
|
||||
args_str = ", ".join([f"{name}: {type}" for name, type in f.arg_types])
|
||||
args_str = ", ".join([
|
||||
f"{name}: {type}" + (f" = {f.arg_defaults.get(name, '')}" if name in f.arg_defaults else "")
|
||||
for name, type in f.arg_types
|
||||
])
|
||||
arg_desc = f"\n{' '*4}".join(
|
||||
[
|
||||
f'{name} ({type}): {f.arg_descs.get(name, "-")}'
|
||||
|
||||
@@ -56,12 +56,16 @@ def crawl_info(url: str, query: str) -> str | None:
|
||||
|
||||
return None
|
||||
|
||||
def hehe():
|
||||
return 'hehe'
|
||||
|
||||
def main() -> str:
|
||||
query = "Find the number of contributors to the autogpt github repository, or if any, list of urls that can be crawled to find the number of contributors"
|
||||
for title, url in ("autogpt github contributor page"):
|
||||
info = await crawl_info(url, query)
|
||||
if info:
|
||||
return info
|
||||
x = await hehe()
|
||||
return "No info found"
|
||||
""",
|
||||
packages=[],
|
||||
@@ -69,3 +73,5 @@ def main() -> str:
|
||||
assert response.functionCode is not None
|
||||
assert "async def crawl_info" in response.functionCode # async is added
|
||||
assert "async def main" in response.functionCode
|
||||
assert "x = hehe()" in response.functionCode # await is removed
|
||||
|
||||
|
||||
Reference in New Issue
Block a user