mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
implement annotation expansion for non-builtin types
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
import re
|
||||
from logging import Logger
|
||||
from typing import Iterable, Sequence
|
||||
from typing import Callable, Iterable, Sequence, get_args, get_origin
|
||||
|
||||
from forge.command import Command
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
@@ -202,28 +203,40 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
]
|
||||
|
||||
def _generate_function_headers(self, commands: Iterable[Command]) -> str:
|
||||
return "\n\n".join(
|
||||
f.header
|
||||
+ "\n"
|
||||
+ indent(
|
||||
(
|
||||
'"""\n'
|
||||
f"{f.description}\n\n"
|
||||
"Params:\n"
|
||||
+ indent(
|
||||
"\n".join(
|
||||
f"{param.name}: {param.spec.description}"
|
||||
for param in f.parameters
|
||||
if param.spec.description
|
||||
)
|
||||
)
|
||||
+ "\n"
|
||||
'"""\n'
|
||||
"pass"
|
||||
),
|
||||
function_stubs: list[str] = []
|
||||
annotation_types_in_context: set[type] = set()
|
||||
for f in commands:
|
||||
# Add source code of non-builtin types from function signatures
|
||||
new_annotation_types = extract_annotation_types(f.method).difference(
|
||||
annotation_types_in_context
|
||||
)
|
||||
for f in commands
|
||||
)
|
||||
new_annotation_types_src = [
|
||||
f"# {a.__module__}.{a.__qualname__}\n{inspect.getsource(a)}"
|
||||
for a in new_annotation_types
|
||||
]
|
||||
annotation_types_in_context.update(new_annotation_types)
|
||||
|
||||
param_descriptions = "\n".join(
|
||||
f"{param.name}: {param.spec.description}"
|
||||
for param in f.parameters
|
||||
if param.spec.description
|
||||
)
|
||||
full_function_stub = (
|
||||
("\n".join(new_annotation_types_src) + "\n" + f.header).strip()
|
||||
+ "\n"
|
||||
+ indent(
|
||||
(
|
||||
'"""\n'
|
||||
f"{f.description}\n\n"
|
||||
f"Params:\n{indent(param_descriptions)}\n"
|
||||
'"""\n'
|
||||
"pass"
|
||||
),
|
||||
)
|
||||
)
|
||||
function_stubs.append(full_function_stub)
|
||||
|
||||
return "\n\n\n".join(function_stubs)
|
||||
|
||||
async def parse_response_content(
|
||||
self,
|
||||
@@ -305,3 +318,25 @@ class CodeFlowAgentPromptStrategy(PromptStrategy):
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def extract_annotation_types(func: Callable) -> set[type]:
|
||||
annotation_types = set()
|
||||
for annotation in inspect.get_annotations(func).values():
|
||||
annotation_types.update(_get_nested_types(annotation))
|
||||
return annotation_types
|
||||
|
||||
|
||||
def _get_nested_types(annotation: type) -> Iterable[type]:
|
||||
if _args := get_args(annotation):
|
||||
for a in _args:
|
||||
yield from _get_nested_types(a)
|
||||
if not _is_builtin_type(_a := get_origin(annotation) or annotation):
|
||||
yield _a
|
||||
|
||||
|
||||
def _is_builtin_type(_type: type):
|
||||
"""Check if a given type is a built-in type."""
|
||||
import sys
|
||||
|
||||
return _type.__module__ in sys.stdlib_module_names
|
||||
|
||||
Reference in New Issue
Block a user