implement annotation expansion for non-builtin types

This commit is contained in:
Reinier van der Leer
2024-06-08 02:01:20 +02:00
parent b4cd735f26
commit 731d0345f0

View File

@@ -1,6 +1,7 @@
import inspect
import re
from logging import Logger
from typing import Iterable, Sequence
from typing import Callable, Iterable, Sequence, get_args, get_origin
from forge.command import Command
from forge.config.ai_directives import AIDirectives
@@ -202,28 +203,40 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
]
def _generate_function_headers(self, commands: Iterable[Command]) -> str:
return "\n\n".join(
f.header
+ "\n"
+ indent(
(
'"""\n'
f"{f.description}\n\n"
"Params:\n"
+ indent(
"\n".join(
f"{param.name}: {param.spec.description}"
for param in f.parameters
if param.spec.description
)
)
+ "\n"
'"""\n'
"pass"
),
function_stubs: list[str] = []
annotation_types_in_context: set[type] = set()
for f in commands:
# Add source code of non-builtin types from function signatures
new_annotation_types = extract_annotation_types(f.method).difference(
annotation_types_in_context
)
for f in commands
)
new_annotation_types_src = [
f"# {a.__module__}.{a.__qualname__}\n{inspect.getsource(a)}"
for a in new_annotation_types
]
annotation_types_in_context.update(new_annotation_types)
param_descriptions = "\n".join(
f"{param.name}: {param.spec.description}"
for param in f.parameters
if param.spec.description
)
full_function_stub = (
("\n".join(new_annotation_types_src) + "\n" + f.header).strip()
+ "\n"
+ indent(
(
'"""\n'
f"{f.description}\n\n"
f"Params:\n{indent(param_descriptions)}\n"
'"""\n'
"pass"
),
)
)
function_stubs.append(full_function_stub)
return "\n\n\n".join(function_stubs)
async def parse_response_content(
self,
@@ -305,3 +318,25 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
),
)
return result
def extract_annotation_types(func: Callable) -> set[type]:
annotation_types = set()
for annotation in inspect.get_annotations(func).values():
annotation_types.update(_get_nested_types(annotation))
return annotation_types
def _get_nested_types(annotation: type) -> Iterable[type]:
if _args := get_args(annotation):
for a in _args:
yield from _get_nested_types(a)
if not _is_builtin_type(_a := get_origin(annotation) or annotation):
yield _a
def _is_builtin_type(_type: type):
"""Check if a given type is a built-in type."""
import sys
return _type.__module__ in sys.stdlib_module_names