Add return type

This commit is contained in:
Zamil Majdy
2024-05-17 17:10:54 +02:00
parent 922e643737
commit fb802400ba
5 changed files with 36 additions and 11 deletions

View File

@@ -59,7 +59,8 @@ FINAL_INSTRUCTION: str = (
"these magic functions where possible, because magic costs money and magically "
"processing large amounts of data is expensive. If you think are done with the task, "
"you can simply call finish(reason='your reason') to end the task, "
"a function that has one `finish` command, don't mix finish with other functions. "
"a function that has one `finish` command, don't mix finish with other functions! "
"If you still need to do other functions, let the next cycle execute the `finish` function. "
)
@@ -182,7 +183,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
]
def _generate_function_headers(self, funcs: list[CompletionModelFunction]) -> str:
return "\n\n".join(f.fmt_header() for f in funcs)
return "\n\n".join(f.fmt_header(force_async=True) for f in funcs)
async def parse_response_content(
self,
@@ -211,10 +212,10 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
arg_types=[(name, p.python_type) for name, p in f.parameters.items()],
arg_descs={name: p.description for name, p in f.parameters.items()},
arg_defaults={name: p.default or "None" for name, p in f.parameters.items() if p.default or not p.required},
return_type="str",
return_type=f.return_type,
return_desc="Output of the function",
function_desc=f.description,
is_async=f.is_async,
is_async=True,
)
for f in self.commands
}
@@ -236,6 +237,7 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
available_functions=available_functions,
).validate_code(parsed_response.python_code)
# TODO: prevent combining finish with other functions
if re.search(r"finish\((.*?)\)", code_validation.functionCode):
finish_reason = re.search(r"finish\((reason=)?(.*?)\)", code_validation.functionCode).group(2)
result = OneShotAgentActionProposal(

View File

@@ -51,13 +51,26 @@ class CodeFlowExecutionComponent(CommandProvider):
Returns:
str: The result of the code execution
"""
code = f"{python_code}\nexec_output = main()"
code_header = "import inspect\n".join(
[
f"""
async def {name}(*args, **kwargs):
result = {name}_func(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
return result
"""
for name in self.available_functions.keys()
]
)
result = {
name: func
for name, func in self.available_functions.items()
name + "_func": func for name, func in self.available_functions.items()
}
code = f"{code_header}\n{python_code}\nexec_output = main()"
print("----> Executing code:", python_code)
exec(code, result)
result = await result['exec_output']
result = await result["exec_output"]
print("----> Execution result:", result)
if inspect.isawaitable(result):
result = await result
return f"Execution Plan:\n{plan_text}\n\nExecution Output:\n{result}"

View File

@@ -125,6 +125,7 @@ class CompletionModelFunction(BaseModel):
name: str
description: str
parameters: dict[str, "JSONSchema"]
return_type: str | None = None
is_async: bool = False
@property
@@ -160,7 +161,7 @@ class CompletionModelFunction(BaseModel):
)
return f"{self.name}: {self.description}. Params: ({params})"
def fmt_header(self, impl="pass") -> str:
def fmt_header(self, impl="pass", force_async=False) -> str:
"""
Formats and returns the function header as a string with types and descriptions.
@@ -171,8 +172,9 @@ class CompletionModelFunction(BaseModel):
f"{name}: {p.python_type}{f'= {str(p.default)}' if p.default else ' = None' if not p.required else ''}"
for name, p in self.parameters.items()
)
func = "async def" if self.is_async else "def"
return f"{func} {self.name}({params}) -> str:\n" + indent(
func = "async def" if self.is_async or force_async else "def"
return_str = f" -> {self.return_type}" if self.return_type else ""
return f"{func} {self.name}({params}){return_str}:\n" + indent(
(
'"""\n'
f"{self.description}\n\n"

View File

@@ -26,6 +26,7 @@ def function_specs_from_commands(
description=command.description,
is_async=command.is_async,
parameters={param.name: param.spec for param in command.parameters},
return_type=command.return_type,
)
for command in commands
]

View File

@@ -42,6 +42,13 @@ class Command(Generic[P, CO]):
@property
def is_async(self) -> bool:
return inspect.iscoroutinefunction(self.method)
@property
def return_type(self) -> type:
type = inspect.signature(self.method).return_annotation
if type == inspect.Signature.empty:
return None
return type.__name__
def _parameters_match(
self, func: Callable, parameters: list[CommandParameter]