diff --git a/forge/forge/models/json_schema.py b/forge/forge/models/json_schema.py index a4ab88835a..43d060e90a 100644 --- a/forge/forge/models/json_schema.py +++ b/forge/forge/models/json_schema.py @@ -3,7 +3,7 @@ import enum import typing from textwrap import indent from types import NoneType -from typing import Any, Optional, TypedDict, overload +from typing import Any, Optional, is_typeddict, overload from jsonschema import Draft7Validator, ValidationError from pydantic import BaseModel @@ -171,7 +171,7 @@ class JSONSchema(BaseModel): if (T_v := typing.get_args(T)[1]) else None, ) - elif issubclass(T, TypedDict): + elif is_typeddict(T): partial_schema = JSONSchema( type=JSONSchema.Type.OBJECT, properties={ diff --git a/forge/forge/utils/function/model.py b/forge/forge/utils/function/model.py index 8f409b078f..1f1afa51a4 100644 --- a/forge/forge/utils/function/model.py +++ b/forge/forge/utils/function/model.py @@ -1,4 +1,5 @@ from typing import List, Optional + from pydantic import BaseModel, Field @@ -46,31 +47,31 @@ class FunctionDef(BaseModel): function_template: str | None = None is_async: bool = False - def __generate_function_template(f) -> str: + def __generate_function_template(self) -> str: args_str = ", ".join( [ f"{name}: {type}" + ( - f" = {f.arg_defaults.get(name, '')}" - if name in f.arg_defaults + f" = {self.arg_defaults.get(name, '')}" + if name in self.arg_defaults else "" ) - for name, type in f.arg_types + for name, type in self.arg_types ] ) arg_desc = f"\n{' '*4}".join( [ - f'{name} ({type}): {f.arg_descs.get(name, "-")}' - for name, type in f.arg_types + f'{name} ({type}): {self.arg_descs.get(name, "-")}' + for name, type in self.arg_types ] ) - def_str = "async def" if "await " in f.function_code or f.is_async else "def" - ret_type_str = f" -> {f.return_type}" if f.return_type else "" - func_desc = f.function_desc.replace("\n", "\n ") + _def = "async def" if "await " in self.function_code or self.is_async else "def" + _return_type = f" -> {self.return_type}" if self.return_type else "" + func_desc = self.function_desc.replace("\n", "\n ") template = f""" -{def_str} {f.name}({args_str}){ret_type_str}: +{_def} {self.name}({args_str}){_return_type}: \"\"\" {func_desc} @@ -78,7 +79,7 @@ class FunctionDef(BaseModel): {arg_desc} Returns: - {f.return_type}{': ' + f.return_desc if f.return_desc else ''} + {self.return_type}{': ' + self.return_desc if self.return_desc else ''} \"\"\" pass """