forge.llm.providers.schema + code_flow_executor lint-fix and cleanup

This commit is contained in:
Reinier van der Leer
2024-06-08 15:28:52 +02:00
parent 6dd0975236
commit e264bf7764
2 changed files with 29 additions and 25 deletions

View File

@@ -1,3 +1,3 @@
from .code_flow_executor import (
CodeFlowExecutionComponent
)
from .code_flow_executor import CodeFlowExecutionComponent
__all__ = ["CodeFlowExecutionComponent"]

View File

@@ -144,38 +144,42 @@ class CompletionModelFunction(BaseModel):
)
return f"{self.name}: {self.description}. Params: ({params})"
def fmt_header(self, impl="pass", force_async=False) -> str:
def fmt_function_stub(self, impl: str = "pass") -> str:
"""
Formats and returns the function header as a string with types and descriptions.
Formats and returns a function stub as a string with types and descriptions.
Returns:
str: The formatted function header.
"""
def indent(content: str, spaces: int = 4):
return " " * spaces + content.replace("\n", "\n" + " " * spaces)
from forge.llm.prompting.utils import indent
params = ", ".join(
f"{name}: {p.python_type}{f'= {str(p.default)}' if p.default else ' = None' if not p.required else ''}"
f"{name}: {p.python_type}"
+ (
f" = {str(p.default)}"
if p.default
else " = None"
if not p.required
else ""
)
for name, p in self.parameters.items()
)
func = "async def" if self.is_async or force_async else "def"
return_str = f" -> {self.return_type}" if self.return_type else ""
return f"{func} {self.name}({params}){return_str}:\n" + indent(
(
'"""\n'
f"{self.description}\n\n"
"Params:\n"
+ indent(
"\n".join(
f"{name}: {param.description}"
for name, param in self.parameters.items()
if param.description
)
_def = "async def" if self.is_async else "def"
_return = f" -> {self.return_type}" if self.return_type else ""
return f"{_def} {self.name}({params}){_return}:\n" + indent(
'"""\n'
f"{self.description}\n\n"
"Params:\n"
+ indent(
"\n".join(
f"{name}: {param.description}"
for name, param in self.parameters.items()
if param.description
)
+ "\n"
'"""\n'
f"{impl}"
),
)
+ "\n"
'"""\n'
f"{impl}"
)
def validate_call(