mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Move python code to subdir (#98)
This commit is contained in:
9
python/src/agnext/components/__init__.py
Normal file
9
python/src/agnext/components/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
The :mod:`agnext.components` module provides building blocks for creating single agents
|
||||
"""
|
||||
|
||||
from ._image import Image
|
||||
from ._type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from ._types import FunctionCall
|
||||
|
||||
__all__ = ["Image", "TypeRoutedAgent", "message_handler", "FunctionCall"]
|
||||
337
python/src/agnext/components/_function_utils.py
Normal file
337
python/src/agnext/components/_function_utils.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py
|
||||
# Credit to original authors
|
||||
|
||||
import inspect
|
||||
from logging import getLogger
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, create_model # type: ignore
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Resolve a parameter annotation, evaluating string forward references.

    Args:
        annotation: The raw annotation (possibly a string forward reference).
        globalns: The global namespace used to resolve forward references.

    Returns:
        The resolved type annotation.
    """
    if not isinstance(annotation, str):
        return annotation
    return evaluate_forwardref(ForwardRef(annotation), globalns, globalns)
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Build an ``inspect.Signature`` for ``call`` with annotations resolved.

    Args:
        call: The callable to inspect.

    Returns:
        A signature in which every parameter annotation and the return
        annotation have had string forward references evaluated.
    """
    raw_signature = inspect.signature(call)
    namespace = getattr(call, "__globals__", {})
    resolved_params = []
    for param in raw_signature.parameters.values():
        resolved_params.append(param.replace(annotation=get_typed_annotation(param.annotation, namespace)))
    return inspect.Signature(
        resolved_params,
        return_annotation=get_typed_annotation(raw_signature.return_annotation, namespace),
    )
|
||||
|
||||
|
||||
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return the resolved return annotation of ``call``, or ``None`` if absent.

    Args:
        call: The function to get the return annotation for.

    Returns:
        The resolved return annotation, or ``None`` when the function has no
        return annotation at all.
    """
    raw_annotation = inspect.signature(call).return_annotation
    if raw_annotation is inspect.Signature.empty:
        return None
    return get_typed_annotation(raw_annotation, getattr(call, "__globals__", {}))
|
||||
|
||||
|
||||
def get_param_annotations(
    typed_signature: inspect.Signature,
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
    """Map parameter names to their annotations, skipping unannotated parameters.

    Args:
        typed_signature: The signature of the function with type annotations.

    Returns:
        A dictionary from parameter name to annotation; parameters without an
        annotation are omitted.
    """
    annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]] = {}
    for param_name, param in typed_signature.parameters.items():
        if param.annotation is not inspect.Signature.empty:
            annotations[param_name] = param.annotation
    return annotations
|
||||
|
||||
|
||||
class Parameters(BaseModel):
    """Parameters of a function as defined by the OpenAI API (JSON-schema ``parameters`` object)."""

    # Always the literal "object" per the OpenAI function-calling schema.
    type: Literal["object"] = "object"
    # Maps each parameter name to its JSON schema.
    properties: Dict[str, Dict[str, Any]]
    # Names of the parameters that have no default value.
    required: List[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
    """A function as defined by the OpenAI API"""

    # Human-readable summary shown to the model.
    description: Annotated[str, Field(description="Description of the function")]
    # Identifier the model uses to request the call.
    name: Annotated[str, Field(description="Name of the function")]
    # JSON-schema description of the accepted arguments.
    parameters: Annotated[Parameters, Field(description="Parameters of the function")]
|
||||
|
||||
|
||||
class ToolFunction(BaseModel):
    """A function under tool as defined by the OpenAI API."""

    # Discriminator: tools of this kind are always "function".
    type: Literal["function"] = "function"
    function: Annotated[Function, Field(description="Function under tool")]
|
||||
|
||||
|
||||
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
    """Extract a human-readable description for parameter ``k``.

    For ``Annotated[T, "description"]`` the first metadata entry is used and
    must be a string; plain annotations fall back to the parameter name.

    Raises:
        ValueError: If the first ``Annotated`` metadata entry is not a string.
    """
    metadata = getattr(v, "__metadata__", None)
    if metadata is None:
        # Plain (non-Annotated) type: use the parameter name as description.
        return k
    description = metadata[0]
    if not isinstance(description, str):
        raise ValueError(f"Invalid description {description} for parameter {k}, should be a string.")
    return description
|
||||
|
||||
|
||||
def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]:
    """Build the JSON schema for a single parameter as defined by the OpenAI API.

    Args:
        k: The name of the parameter.
        v: The type annotation of the parameter.
        default_values: Defaults of the function's parameters; if ``k`` is
            present its default is recorded in the schema.

    Returns:
        The JSON schema dictionary for the parameter.
    """
    schema = type2schema(v)
    if k in default_values:
        schema["default"] = default_values[k]
    schema["description"] = type2description(k, v)
    return schema
|
||||
|
||||
|
||||
def get_required_params(typed_signature: inspect.Signature) -> List[str]:
    """List the names of parameters that have no default value.

    Args:
        typed_signature: The signature of the function as returned by ``inspect.signature``.

    Returns:
        The required (no-default) parameter names, in declaration order.
    """
    required: List[str] = []
    for param_name, param in typed_signature.parameters.items():
        if param.default == inspect.Signature.empty:
            required.append(param_name)
    return required
|
||||
|
||||
|
||||
def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
    """Map parameter names to their default values, omitting parameters without one.

    Args:
        typed_signature: The signature of the function as returned by ``inspect.signature``.

    Returns:
        A dictionary from parameter name to its default value.
    """
    defaults: Dict[str, Any] = {}
    for param_name, param in typed_signature.parameters.items():
        if param.default != inspect.Signature.empty:
            defaults[param_name] = param.default
    return defaults
|
||||
|
||||
|
||||
def get_parameters(
    required: List[str],
    param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
    default_values: Dict[str, Any],
) -> Parameters:
    """Get the parameters of a function as defined by the OpenAI API.

    Args:
        required: Names of the parameters without default values.
        param_annotations: Mapping of parameter names to their type annotations.
        default_values: Mapping of parameter names to their default values.

    Returns:
        A Pydantic model for the parameters of the function.
    """
    return Parameters(
        properties={
            k: get_parameter_json_schema(k, v, default_values)
            for k, v in param_annotations.items()
            if v is not inspect.Signature.empty
        },
        required=required,
    )
|
||||
|
||||
|
||||
def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]:
|
||||
"""Get the missing annotations of a function
|
||||
|
||||
Ignores the parameters with default values as they are not required to be annotated, but logs a warning.
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
required: The required parameters of the function
|
||||
|
||||
Returns:
|
||||
A set of the missing annotations of the function
|
||||
"""
|
||||
all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty}
|
||||
missing = all_missing.intersection(set(required))
|
||||
unannotated_with_default = all_missing.difference(missing)
|
||||
return missing, unannotated_with_default
|
||||
|
||||
|
||||
def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]:
    """Get a JSON schema for a function as defined by the OpenAI API

    Args:
        f: The function to get the JSON schema for
        name: The name of the function; defaults to ``f.__name__`` when omitted
        description: The description of the function

    Returns:
        A JSON schema for the function

    Raises:
        TypeError: If a required (no-default) parameter of the function is not annotated

    Examples:

    .. code-block:: python

        def f(
            a: Annotated[str, "Parameter a"],
            b: int = 2,
            c: Annotated[float, "Parameter c"] = 0.1,
        ) -> None:
            pass


        get_function_schema(f, description="function f")

        # {'type': 'function',
        #  'function': {'description': 'function f',
        #      'name': 'f',
        #      'parameters': {'type': 'object',
        #         'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
        #             'b': {'type': 'int', 'description': 'b'},
        #             'c': {'type': 'float', 'description': 'Parameter c'}},
        #         'required': ['a']}}}

    """
    # Gather all the pieces of the signature with forward references resolved.
    typed_signature = get_typed_signature(f)
    required = get_required_params(typed_signature)
    default_values = get_default_values(typed_signature)
    param_annotations = get_param_annotations(typed_signature)
    return_annotation = get_typed_return_annotation(f)
    missing, unannotated_with_default = get_missing_annotations(typed_signature, required)

    # A missing return annotation is tolerated, but flagged.
    if return_annotation is None:
        logger.warning(
            f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
            + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
        )

    # Unannotated parameters with defaults are tolerated, but flagged.
    if unannotated_with_default != set():
        unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)]
        logger.warning(
            f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
            + f"{', '.join(unannotated_with_default_s)}."
        )

    # Unannotated required parameters are a hard error: their schema cannot be derived.
    if missing != set():
        missing_s = [f"'{k}'" for k in sorted(missing)]
        raise TypeError(
            f"All parameters of the function '{f.__name__}' without default values must be annotated. "
            + f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
        )

    fname = name if name else f.__name__

    parameters = get_parameters(required, param_annotations, default_values=default_values)

    # Assemble the OpenAI "tool" wrapper around the function definition.
    function = ToolFunction(
        function=Function(
            description=description,
            name=fname,
            parameters=parameters,
        )
    )

    return model_dump(function)
|
||||
|
||||
|
||||
def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
    """Strip a ``typing.Annotated`` wrapper, returning the underlying type."""
    if get_origin(type_hint) is not Annotated:
        return type_hint
    # First argument of Annotated[...] is the wrapped type itself.
    return get_args(type_hint)[0]  # type: ignore
|
||||
|
||||
|
||||
def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
    """Create a Pydantic model describing the arguments of ``sig``.

    Args:
        name: Name given to the generated model.
        sig: Signature whose parameters become model fields.

    Returns:
        A dynamically created Pydantic model class with one field per
        parameter (excluding ``cancellation_token``).

    Raises:
        ValueError: If any parameter lacks a type annotation.
    """
    fields: Dict[str, tuple[Type[Any], Any]] = {}
    # Fix: the loop variable used to be called ``name``, shadowing the model
    # name argument — create_model then received the last parameter's name
    # instead of the requested model name.
    for param_name, param in sig.parameters.items():
        # This is handled externally
        if param_name == "cancellation_token":
            continue

        if param.annotation is inspect.Parameter.empty:
            raise ValueError("No annotation")

        field_type = normalize_annotated_type(param.annotation)
        description = type2description(param_name, param.annotation)
        # PydanticUndefined marks the field as required (no default).
        default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined

        fields[param_name] = (field_type, Field(default=default_value, description=description))

    # cast to Type[BaseModel] (not BaseModel): create_model returns a class.
    return cast(Type[BaseModel], create_model(name, **fields))  # type: ignore
|
||||
78
python/src/agnext/components/_image.py
Normal file
78
python/src/agnext/components/_image.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
from openai.types.chat import ChatCompletionContentPartImageParam
|
||||
from PIL import Image as PILImage
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class Image:
    """An RGB image with conversions to and from common representations."""

    def __init__(self, image: PILImage.Image):
        # Normalize to RGB so the PNG re-encoding in `data_uri` is consistent.
        self.image: PILImage.Image = image.convert("RGB")

    @classmethod
    def from_pil(cls, pil_image: PILImage.Image) -> Image:
        """Wrap an existing PIL image."""
        return cls(pil_image)

    @classmethod
    def from_uri(cls, uri: str) -> Image:
        """Create an image from a ``data:image/...;base64,`` URI."""
        if not re.match(r"data:image/(?:png|jpeg);base64,", uri):
            raise ValueError("Invalid URI format. It should be a base64 encoded image URI.")

        # A URI. Remove the prefix and decode the base64 string.
        base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", uri)
        return cls.from_base64(base64_data)

    @classmethod
    async def from_url(cls, url: str) -> Image:
        """Download an image over HTTP and wrap it."""
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                content = await response.read()
                # Fix: PIL's Image.open expects a path or file-like object,
                # not raw bytes — wrap the payload in BytesIO.
                return cls(PILImage.open(BytesIO(content)))

    @classmethod
    def from_base64(cls, base64_str: str) -> Image:
        """Decode a base64-encoded image payload."""
        return cls(PILImage.open(BytesIO(base64.b64decode(base64_str))))

    @classmethod
    def from_file(cls, file_path: Path) -> Image:
        """Load an image from disk."""
        return cls(PILImage.open(file_path))

    def _repr_html_(self) -> str:
        # Show the image in Jupyter notebook
        return f'<img src="{self.data_uri}"/>'

    @property
    def data_uri(self) -> str:
        """The image re-encoded as a PNG ``data:`` URI."""
        buffered = BytesIO()
        self.image.save(buffered, format="PNG")
        content = buffered.getvalue()
        return _convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))

    def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> ChatCompletionContentPartImageParam:
        """Render as an OpenAI chat-completion image content part."""
        return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}}
|
||||
|
||||
|
||||
def _convert_base64_to_data_uri(base64_image: str) -> str:
|
||||
def _get_mime_type_from_data_uri(base64_image: str) -> str:
|
||||
# Decode the base64 string
|
||||
image_data = base64.b64decode(base64_image)
|
||||
# Check the first few bytes for known signatures
|
||||
if image_data.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
|
||||
return "image/gif"
|
||||
elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
return "image/jpeg" # use jpeg for unknown formats, best guess.
|
||||
|
||||
mime_type = _get_mime_type_from_data_uri(base64_image)
|
||||
data_uri = f"data:{mime_type};base64,{base64_image}"
|
||||
return data_uri
|
||||
65
python/src/agnext/components/_pydantic_compat.py
Normal file
65
python/src/agnext/components/_pydantic_compat.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/_pydantic.py
|
||||
# Credit to original authors
|
||||
|
||||
|
||||
from typing import Any, Dict, Tuple, Type, Union, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from typing_extensions import get_origin
|
||||
|
||||
__all__ = ("model_dump", "type2schema", "evaluate_forwardref")
|
||||
|
||||
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
|
||||
|
||||
|
||||
def evaluate_forwardref(
    value: Any,
    globalns: dict[str, Any] | None = None,
    localns: dict[str, Any] | None = None,
) -> Any:
    """Evaluate a forward reference, dispatching on the installed pydantic major version.

    Args:
        value: The forward reference (or annotation) to evaluate.
        globalns: Global namespace used for resolution.
        localns: Local namespace used for resolution.

    Returns:
        The evaluated annotation.
    """
    if PYDANTIC_V1:
        from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal

        return evaluate_forwardref_internal(value, globalns, localns)
    else:
        # NOTE(review): relies on a pydantic v2 private module — may break
        # across pydantic releases.
        from pydantic._internal._typing_extra import eval_type_lenient

        return eval_type_lenient(value, globalns, localns)
|
||||
|
||||
|
||||
def type2schema(t: Type[Any] | None) -> Dict[str, Any]:
    """Convert a Python type annotation to a JSON schema dictionary.

    Dispatches on the installed pydantic major version. Under pydantic v1 the
    None, Union and Tuple cases are handled manually because ``schema_of``
    does not cover them directly; under v2 ``TypeAdapter`` handles everything.
    """
    if PYDANTIC_V1:
        from pydantic import schema_of  # type: ignore

        if t is None:
            return {"type": "null"}
        elif get_origin(t) is Union:
            # Unions become "anyOf" over the member schemas.
            return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
        elif get_origin(t) in [Tuple, tuple]:
            # Fixed-length tuples become arrays with positional item schemas.
            prefixItems = [type2schema(tt) for tt in get_args(t)]
            return {
                "maxItems": len(prefixItems),
                "minItems": len(prefixItems),
                "prefixItems": prefixItems,
                "type": "array",
            }

        d = schema_of(t)  # type: ignore
        # Drop metadata keys that are not part of the parameter schema proper.
        if "title" in d:
            d.pop("title")
        if "description" in d:
            d.pop("description")

        return d
    else:
        from pydantic import TypeAdapter

        return TypeAdapter(t).json_schema()
|
||||
|
||||
|
||||
def model_dump(model: BaseModel) -> Dict[str, Any]:
    """Serialize a pydantic model to a plain dict, across pydantic v1/v2."""
    if PYDANTIC_V1:
        return model.dict()  # type: ignore
    else:
        return model.model_dump()
|
||||
191
python/src/agnext/components/_type_routed_agent.py
Normal file
191
python/src/agnext/components/_type_routed_agent.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Literal,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ..core import BaseAgent, CancellationToken
|
||||
from ..core.exceptions import CantHandleException
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
|
||||
ReceivesT = TypeVar("ReceivesT", contravariant=True)
|
||||
ProducesT = TypeVar("ProducesT", covariant=True)
|
||||
|
||||
# TODO: Generic typevar bound binding U to agent type
|
||||
# Can't do because python doesnt support it
|
||||
|
||||
|
||||
def is_union(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Union or origin is UnionType
|
||||
|
||||
|
||||
def is_optional(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Optional
|
||||
|
||||
|
||||
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
|
||||
class AnyType:
    """Sentinel type standing in for ``typing.Any`` in handler type lists."""

    pass
|
||||
|
||||
|
||||
def get_types(t: object) -> Sequence[Type[Any]] | None:
    """Flatten an annotation into the tuple of concrete types it may denote.

    Returns ``None`` when the annotation cannot be interpreted.
    """
    if is_union(t):
        return get_args(t)
    if is_optional(t):
        # Optional members plus an explicit NoneType entry.
        return (*get_args(t), NoneType)
    if t is Any:
        # Any is represented by the AnyType sentinel.
        return (AnyType,)
    if isinstance(t, type):
        return (t,)
    if isinstance(t, NoneType):
        return (NoneType,)
    return None
|
||||
|
||||
|
||||
@runtime_checkable
class MessageHandler(Protocol[ReceivesT, ProducesT]):
    """Protocol satisfied by methods decorated with :func:`message_handler`."""

    # Concrete message types this handler accepts.
    target_types: Sequence[type]
    # Concrete types the handler may return.
    produces_types: Sequence[type]
    # Marker attribute used by TypeRoutedAgent to discover handlers.
    is_message_handler: Literal[True]

    async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ...
|
||||
|
||||
|
||||
# NOTE: this works on concrete types and not inheritance
|
||||
# TODO: Use a protocl for the outer function to check checked arg names
|
||||
|
||||
|
||||
@overload
def message_handler(
    func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]: ...


@overload
def message_handler(
    func: None = None,
    *,
    strict: bool = ...,
) -> Callable[
    [Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
    MessageHandler[ReceivesT, ProducesT],
]: ...


def message_handler(
    func: None | Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]] = None,
    *,
    strict: bool = True,
) -> (
    Callable[
        [Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
        MessageHandler[ReceivesT, ProducesT],
    ]
    | MessageHandler[ReceivesT, ProducesT]
):
    """Decorator turning an async method into a typed message handler.

    The decorated method must annotate its ``message`` parameter and its
    return type; those annotations are recorded on the wrapper as
    ``target_types`` / ``produces_types`` and used by
    :class:`TypeRoutedAgent` for routing. Usable both bare
    (``@message_handler``) and with arguments (``@message_handler(strict=False)``).

    Args:
        func: The handler, when the decorator is used without parentheses.
        strict: When True (default), a message or return value whose concrete
            type is not in the annotated types raises; when False it only logs
            a warning.

    Raises:
        AssertionError: If the ``message`` parameter or the return annotation
            is missing, or cannot be resolved to concrete types.
    """

    def decorator(
        func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
    ) -> MessageHandler[ReceivesT, ProducesT]:
        type_hints = get_type_hints(func)
        if "message" not in type_hints:
            raise AssertionError("message parameter not found in function signature")

        if "return" not in type_hints:
            raise AssertionError("return not found in function signature")

        # Get the type of the message parameter
        target_types = get_types(type_hints["message"])
        if target_types is None:
            raise AssertionError("Message type not found")

        return_types = get_types(type_hints["return"])

        if return_types is None:
            raise AssertionError("Return type not found")

        # Convert target_types to list and stash

        @wraps(func)
        async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
            # Runtime check of the incoming message's concrete type.
            # NOTE: matches type(message) exactly — subclasses do not match.
            if type(message) not in target_types:
                if strict:
                    raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
                else:
                    logger.warning(f"Message type {type(message)} not in target types {target_types}")

            return_value = await func(self, message, cancellation_token)

            # Runtime check of the return value (skipped when Any is allowed).
            if AnyType not in return_types and type(return_value) not in return_types:
                if strict:
                    raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
                else:
                    logger.warning(f"Return type {type(return_value)} not in return types {return_types}")

            return return_value

        # Attach the routing metadata consumed by TypeRoutedAgent.
        wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
        wrapper_handler.target_types = list(target_types)
        wrapper_handler.produces_types = list(return_types)
        wrapper_handler.is_message_handler = True

        return wrapper_handler

    # Bare usage (@message_handler) passes the function directly; usage with
    # arguments (@message_handler(strict=...)) returns the decorator.
    if func is None and not callable(func):
        return decorator
    elif callable(func):
        return decorator(func)
    else:
        raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
class TypeRoutedAgent(BaseAgent):
    """Agent base class that routes incoming messages to methods decorated
    with :func:`message_handler`, keyed on the message's concrete type."""

    def __init__(self, description: str) -> None:
        # Self is already bound to the handlers
        self._handlers: Dict[
            Type[Any],
            Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]],
        ] = {}

        # Discover decorated methods on the instance and index them by every
        # message type they declare. Later handlers overwrite earlier ones for
        # the same type.
        for attr in dir(self):
            if callable(getattr(self, attr, None)):
                handler = getattr(self, attr)
                if hasattr(handler, "is_message_handler"):
                    message_handler = cast(MessageHandler[Any, Any], handler)
                    for target_type in message_handler.target_types:
                        self._handlers[target_type] = message_handler
        # Subscribe to exactly the message types we can handle.
        subscriptions = list(self._handlers.keys())
        super().__init__(description, subscriptions)

    async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
        """Dispatch ``message`` to the handler registered for its exact type.

        NOTE: routing matches ``type(message)`` exactly — a subclass of a
        registered type is not routed to the base type's handler; it goes to
        :meth:`on_unhandled_message`.
        """
        key_type: Type[Any] = type(message)  # type: ignore
        handler = self._handlers.get(key_type)  # type: ignore
        if handler is not None:
            return await handler(message, cancellation_token)
        else:
            return await self.on_unhandled_message(message, cancellation_token)

    async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn:
        """Called when no handler matches; raises by default. Override to customize."""
        raise CantHandleException(f"Unhandled message: {message}")
|
||||
12
python/src/agnext/components/_types.py
Normal file
12
python/src/agnext/components/_types.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class FunctionCall:
    """A request to invoke a named function with JSON-encoded arguments."""

    # Identifier for this specific call.
    id: str
    # JSON args
    arguments: str
    # Function to call
    name: str
|
||||
17
python/src/agnext/components/code_executor/__init__.py
Normal file
17
python/src/agnext/components/code_executor/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from ._base import CodeBlock, CodeExecutor, CodeResult
|
||||
from ._func_with_reqs import Alias, FunctionWithRequirements, Import, ImportFromModule, with_requirements
|
||||
from ._impl.command_line_code_result import CommandLineCodeResult
|
||||
from ._impl.local_commandline_code_executor import LocalCommandLineCodeExecutor
|
||||
|
||||
__all__ = [
|
||||
"LocalCommandLineCodeExecutor",
|
||||
"CommandLineCodeResult",
|
||||
"CodeBlock",
|
||||
"CodeResult",
|
||||
"CodeExecutor",
|
||||
"Alias",
|
||||
"ImportFromModule",
|
||||
"Import",
|
||||
"FunctionWithRequirements",
|
||||
"with_requirements",
|
||||
]
|
||||
50
python/src/agnext/components/code_executor/_base.py
Normal file
50
python/src/agnext/components/code_executor/_base.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/base.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@dataclass
class CodeBlock:
    """A code block extracted from an agent message."""

    # The source code of the block.
    code: str
    # The language the block is written in (as tagged in the message).
    language: str
|
||||
|
||||
|
||||
@dataclass
class CodeResult:
    """Result of a code execution."""

    # Exit code of the execution process.
    exit_code: int
    # Captured output of the execution.
    output: str
|
||||
|
||||
|
||||
@runtime_checkable
class CodeExecutor(Protocol):
    """Executes code blocks and returns the result.

    Structural (duck-typed) protocol: any object providing these two methods
    is accepted where a ``CodeExecutor`` is expected.
    """

    def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
        """Execute code blocks and return the result.

        This method should be implemented by the code executor.

        Args:
            code_blocks (List[CodeBlock]): The code blocks to execute.

        Returns:
            CodeResult: The result of the code execution.
        """
        ...

    def restart(self) -> None:
        """Restart the code executor.

        This method should be implemented by the code executor.

        This method is called when the agent is reset.
        """
        ...
|
||||
200
python/src/agnext/components/code_executor/_func_with_reqs.py
Normal file
200
python/src/agnext/components/code_executor/_func_with_reqs.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/func_with_reqs.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from importlib.abc import SourceLoader
|
||||
from importlib.util import module_from_spec, spec_from_loader
|
||||
from textwrap import dedent, indent
|
||||
from typing import Any, Callable, Generic, List, Sequence, Set, TypeVar, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
    """Return the source code of ``func`` as a string.

    String-based functions already carry their source; callables are read via
    ``inspect.getsource``.
    """
    if isinstance(func, FunctionWithRequirementsStr):
        return func.func

    code = inspect.getsource(func)
    # Strip the decorator (first line only; assumes a single-line decorator)
    if code.startswith("@"):
        code = code[code.index("\n") + 1 :]
    return code
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Alias:
    """An ``import <name> as <alias>`` requirement.

    Frozen so instances are hashable: ``build_python_functions_file`` collects
    imports into a ``set``, and a plain (non-frozen) dataclass with ``eq=True``
    sets ``__hash__ = None``, making that raise ``TypeError``.
    """

    # Module or name being imported.
    name: str
    # Name it is bound to.
    alias: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ImportFromModule:
    """A ``from <module> import <names>`` requirement.

    Frozen, with ``imports`` stored as a tuple, so instances are hashable:
    ``build_python_functions_file`` collects imports into a ``set``, and the
    original mutable dataclass with a ``List`` field raised ``TypeError``
    there. The constructor still accepts any iterable of names.
    """

    module: str
    # Stored as a tuple for hashability; iteration/indexing behave as before.
    imports: Tuple[Union[str, Alias], ...]

    def __init__(self, module: str, imports: List[Union[str, Alias]]):
        # frozen dataclass: assign via object.__setattr__.
        object.__setattr__(self, "module", module)
        object.__setattr__(self, "imports", tuple(imports))
|
||||
|
||||
|
||||
# An import requirement: a plain module name, a from-import spec, or an
# aliased import.
Import = Union[str, ImportFromModule, Alias]
|
||||
|
||||
|
||||
def _import_to_str(im: Import) -> str:
    """Render an :data:`Import` specification as a Python import statement."""
    if isinstance(im, str):
        return f"import {im}"
    if isinstance(im, Alias):
        return f"import {im.name} as {im.alias}"

    # ImportFromModule: items of a from-import may themselves be aliased.
    def render(item: Union[str, Alias]) -> str:
        if isinstance(item, str):
            return item
        return f"{item.name} as {item.alias}"

    joined = ", ".join(render(item) for item in im.imports)
    return f"from {im.module} import {joined}"
|
||||
|
||||
|
||||
class _StringLoader(SourceLoader):
    """An importlib loader that serves Python source held in a string."""

    def __init__(self, data: str):
        # The module's source text.
        self.data = data

    def get_source(self, fullname: str) -> str:
        """Return the stored source text."""
        return self.data

    def get_data(self, path: str) -> bytes:
        """Return the stored source encoded as UTF-8 bytes."""
        return self.data.encode("utf-8")

    def get_filename(self, fullname: str) -> str:
        """Return a fake filename; the module never exists on disk."""
        return "<not a real path>/" + fullname + ".py"
|
||||
|
||||
|
||||
@dataclass
class FunctionWithRequirementsStr:
    """A function provided as source text, plus its package/import requirements.

    The source is compiled at construction time (producing ``compiled_func``)
    so invalid source fails fast, but instances are not directly callable —
    the source is meant to be executed by a code executor.
    """

    # The function's source code.
    func: str
    # Function object compiled from ``func`` (introspection/validation only).
    compiled_func: Callable[..., Any]
    # Name of the single function defined in ``func``.
    _func_name: str
    # Packages required at execution time; may include version specifiers.
    python_packages: Sequence[str] = field(default_factory=list)
    # Imports that must be emitted before the function's source.
    global_imports: Sequence[Import] = field(default_factory=list)

    def __init__(self, func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []):
        self.func = func
        self.python_packages = python_packages
        self.global_imports = global_imports

        # Load the source through a string loader so it is compiled as a real
        # module and its single function can be extracted.
        module_name = "func_module"
        loader = _StringLoader(func)
        spec = spec_from_loader(module_name, loader)
        if spec is None:
            raise ValueError("Could not create spec")
        module = module_from_spec(spec)
        if spec.loader is None:
            raise ValueError("Could not create loader")

        try:
            spec.loader.exec_module(module)
        except Exception as e:
            raise ValueError(f"Could not compile function: {e}") from e

        # The source must define exactly one function.
        functions = inspect.getmembers(module, inspect.isfunction)
        if len(functions) != 1:
            raise ValueError("The string must contain exactly one function")

        self._func_name, self.compiled_func = functions[0]

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        raise NotImplementedError("String based function with requirement objects are not directly callable")
|
||||
|
||||
|
||||
@dataclass
class FunctionWithRequirements(Generic[T, P]):
    """A callable bundled with the packages and imports it needs to run."""

    # The wrapped callable.
    func: Callable[P, T]
    # Packages required at execution time; may include version specifiers.
    python_packages: Sequence[str] = field(default_factory=list)
    # Imports that must be emitted before the function's source.
    global_imports: Sequence[Import] = field(default_factory=list)

    @classmethod
    def from_callable(
        cls, func: Callable[P, T], python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
    ) -> FunctionWithRequirements[T, P]:
        """Alternate constructor from an existing callable."""
        return cls(python_packages=python_packages, global_imports=global_imports, func=func)

    @staticmethod
    def from_str(
        func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
    ) -> FunctionWithRequirementsStr:
        """Alternate constructor from function source text."""
        return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports)

    # Type this based on F
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Delegate calls to the wrapped function."""
        return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def with_requirements(
    python_packages: Sequence[str] = (), global_imports: Sequence[Import] = ()
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
    """Decorate a function with package and import requirements.

    Args:
        python_packages (Sequence[str], optional): Packages required by the function. Can include version info. Defaults to ().
        global_imports (Sequence[Import], optional): Required imports. Defaults to ().

    Returns:
        Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: The decorated function
    """
    # Immutable tuple defaults avoid the shared-mutable-default pitfall of `[]`.

    def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
        func_with_reqs = FunctionWithRequirements(
            python_packages=python_packages, global_imports=global_imports, func=func
        )
        # Preserve the wrapped function's metadata (__name__, __doc__, ...).
        functools.update_wrapper(func_with_reqs, func)
        return func_with_reqs

    return wrapper
|
||||
|
||||
|
||||
def build_python_functions_file(
    funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
) -> str:
    """Render the given functions, preceded by their collected global imports, as one Python source file."""
    # Gather the union of all declared imports first so they appear once at the top.
    required_imports: Set[Import] = set()
    for entry in funcs:
        if isinstance(entry, (FunctionWithRequirements, FunctionWithRequirementsStr)):
            required_imports.update(entry.global_imports)

    # One blank line separates the import header and every rendered function,
    # and the file ends with a trailing blank line.
    header = "\n".join(_import_to_str(imp) for imp in required_imports)
    sections = [header] + [_to_code(entry) for entry in funcs]
    return "\n\n".join(sections) + "\n\n"
|
||||
|
||||
|
||||
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
    """Generate a stub for a function as a string.

    The stub consists of the signature line, the (dedented, re-indented)
    docstring when present, and an ellipsis body.

    Args:
        func (Callable[..., Any]): The function to generate a stub for

    Returns:
        str: The stub for the function
    """
    # String-based functions are stubbed via the function they compiled to.
    if isinstance(func, FunctionWithRequirementsStr):
        return to_stub(func.compiled_func)

    lines = [f"def {func.__name__}{inspect.signature(func)}:"]
    doc = func.__doc__
    if doc:
        quoted = '"""' + dedent(doc) + '"""'
        lines.append(indent(quoted, "    "))
    lines.append("    ...")
    return "\n".join(lines)
|
||||
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from .._base import CodeResult
|
||||
|
||||
|
||||
@dataclass
class CommandLineCodeResult(CodeResult):
    """A code result class for command line code executor."""

    # Path of the file whose execution produced this result; None when no file was written.
    code_file: Optional[str]
|
||||
@@ -0,0 +1,269 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/local_commandline_code_executor.py
|
||||
# Credit to original authors
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Callable, ClassVar, List, Sequence, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .._base import CodeBlock, CodeExecutor
|
||||
from .._func_with_reqs import (
|
||||
FunctionWithRequirements,
|
||||
FunctionWithRequirementsStr,
|
||||
build_python_functions_file,
|
||||
to_stub,
|
||||
)
|
||||
from .command_line_code_result import CommandLineCodeResult
|
||||
from .utils import PYTHON_VARIANTS, get_file_name_from_content, lang_to_cmd, silence_pip # type: ignore
|
||||
|
||||
__all__ = ("LocalCommandLineCodeExecutor",)
|
||||
|
||||
# Parameter specification used to type user-supplied functions generically.
A = ParamSpec("A")
|
||||
|
||||
|
||||
class LocalCommandLineCodeExecutor(CodeExecutor):
    """(Experimental) Executes code blocks as local subprocesses; see ``__init__`` for details."""

    # Language tags this executor accepts; python variants are normalized before dispatch.
    SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
        "bash",
        "shell",
        "sh",
        "pwsh",
        "powershell",
        "ps1",
        "python",
    ]
    # $module_name / $functions are substituted by format_functions_for_prompt.
    FUNCTION_PROMPT_TEMPLATE: ClassVar[
        str
    ] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.

For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`

$functions"""

    # NOTE(review): the mutable default `functions=[]` is shared across instances; it is
    # never mutated here, but consider `()` for safety.
    def __init__(
        self,
        timeout: int = 60,
        work_dir: Union[Path, str] = Path("."),
        functions: Sequence[
            Union[
                FunctionWithRequirements[Any, A],
                Callable[..., Any],
                FunctionWithRequirementsStr,
            ]
        ] = [],
        functions_module: str = "functions",
    ):
        """(Experimental) A code executor class that executes code through a local command line
        environment.

        **This will execute LLM generated code on the local machine.**

        Each code block is saved as a file and executed in a separate process in
        the working directory, and a unique file is generated and saved in the
        working directory for each code block.
        The code blocks are executed in the order they are received.
        Command line code is sanitized using regular expression match against a list of dangerous commands in order to prevent self-destructive
        commands from being executed which may potentially affect the users environment.
        Currently the only supported languages is Python and shell scripts.
        For Python code, use the language "python" for the code block.
        For shell scripts, use the language "bash", "shell", or "sh" for the code
        block.

        Args:
            timeout (int): The timeout for code execution. Default is 60.
            work_dir (str): The working directory for the code execution. If None,
                a default working directory will be used. The default working
                directory is the current directory ".".
            functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.

        Raises:
            ValueError: If the timeout is below 1 or the functions module name is not a valid identifier.
        """

        if timeout < 1:
            raise ValueError("Timeout must be greater than or equal to 1.")

        if isinstance(work_dir, str):
            work_dir = Path(work_dir)

        # The functions module becomes a Python file, so its name must be importable.
        if not functions_module.isidentifier():
            raise ValueError("Module name must be a valid Python identifier")

        self._functions_module = functions_module

        work_dir.mkdir(exist_ok=True)

        self._timeout = timeout
        self._work_dir: Path = work_dir

        self._functions = functions
        # Setup could take some time so we intentionally wait for the first code block to do it.
        if len(functions) > 0:
            self._setup_functions_complete = False
        else:
            self._setup_functions_complete = True

    def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
        """(Experimental) Format the functions for a prompt.

        The template includes two variables:
        - `$module_name`: The module name.
        - `$functions`: The functions formatted as stubs with two newlines between each function.

        Args:
            prompt_template (str): The prompt template. Default is the class default.

        Returns:
            str: The formatted prompt.
        """

        template = Template(prompt_template)
        return template.substitute(
            module_name=self._functions_module,
            functions="\n\n".join([to_stub(func) for func in self._functions]),
        )

    @property
    def functions_module(self) -> str:
        """(Experimental) The module name for the functions."""
        return self._functions_module

    @property
    def functions(self) -> List[str]:
        # NOTE(review): declared but unimplemented; callers cannot list function names yet.
        raise NotImplementedError

    @property
    def timeout(self) -> int:
        """(Experimental) The timeout for code execution."""
        return self._timeout

    @property
    def work_dir(self) -> Path:
        """(Experimental) The working directory for the code execution."""
        return self._work_dir

    def _setup_functions(self) -> None:
        """Write the functions module into the work dir, pip-install required packages,
        and verify the module loads. Runs once, lazily, before the first execution.

        Raises:
            ValueError: If pip install times out/fails or the functions file fails to load.
        """
        func_file_content = build_python_functions_file(self._functions)
        func_file = self._work_dir / f"{self._functions_module}.py"
        func_file.write_text(func_file_content)

        # Collect requirements
        lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
        flattened_packages = [item for sublist in lists_of_packages for item in sublist]
        required_packages = list(set(flattened_packages))
        if len(required_packages) > 0:
            logging.info("Ensuring packages are installed in executor.")

            cmd = [sys.executable, "-m", "pip", "install"]
            cmd.extend(required_packages)

            try:
                result = subprocess.run(
                    cmd,
                    cwd=self._work_dir,
                    capture_output=True,
                    text=True,
                    timeout=float(self._timeout),
                )
            except subprocess.TimeoutExpired as e:
                raise ValueError("Pip install timed out") from e

            if result.returncode != 0:
                raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}")

        # Attempt to load the function file to check for syntax errors, imports etc.
        exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")])

        if exec_result.exit_code != 0:
            raise ValueError(f"Functions failed to load: {exec_result.output}")

        self._setup_functions_complete = True

    def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
        """(Experimental) Execute the code blocks and return the result.

        Args:
            code_blocks (List[CodeBlock]): The code blocks to execute.

        Returns:
            CommandLineCodeResult: The result of the code execution."""

        # Lazily install and verify user functions on first use.
        if not self._setup_functions_complete:
            self._setup_functions()

        return self._execute_code_dont_check_setup(code_blocks)

    def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
        """Run each block in order, stopping at the first nonzero exit code or timeout.

        Stderr/stdout from every executed block are concatenated into the result output.
        """
        logs_all: str = ""
        file_names: List[Path] = []
        exitcode = 0
        for code_block in code_blocks:
            lang, code = code_block.language, code_block.code
            lang = lang.lower()

            code = silence_pip(code, lang)

            # Normalize "py" / "Python" etc. to "python".
            if lang in PYTHON_VARIANTS:
                lang = "python"

            if lang not in self.SUPPORTED_LANGUAGES:
                # In case the language is not supported, we return an error message.
                exitcode = 1
                logs_all += "\n" + f"unknown language {lang}"
                break

            try:
                # Check if there is a filename comment
                filename = get_file_name_from_content(code, self._work_dir)
            except ValueError:
                # The requested filename escapes the work dir; refuse to write it.
                return CommandLineCodeResult(
                    exit_code=1,
                    output="Filename is not in the workspace",
                    code_file=None,
                )

            if filename is None:
                # create a file with an automatically generated name
                code_hash = md5(code.encode()).hexdigest()
                filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"

            written_file = (self._work_dir / filename).resolve()
            with written_file.open("w", encoding="utf-8") as f:
                f.write(code)
            file_names.append(written_file)

            program = sys.executable if lang.startswith("python") else lang_to_cmd(lang)
            cmd = [program, str(written_file.absolute())]

            try:
                result = subprocess.run(
                    cmd,
                    cwd=self._work_dir,
                    capture_output=True,
                    text=True,
                    timeout=float(self._timeout),
                )
            except subprocess.TimeoutExpired:
                logs_all += "\n Timeout"
                # Same exit code as the timeout command on linux.
                exitcode = 124
                break

            logs_all += result.stderr
            logs_all += result.stdout
            exitcode = result.returncode

            if exitcode != 0:
                break

        code_file = str(file_names[0]) if len(file_names) > 0 else None
        return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file)

    def restart(self) -> None:
        """(Experimental) Restart the code executor."""
        warnings.warn(
            "Restarting local command line code executor is not supported. No action is taken.",
            stacklevel=2,
        )
|
||||
88
python/src/agnext/components/code_executor/_impl/utils.py
Normal file
88
python/src/agnext/components/code_executor/_impl/utils.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/utils.py
|
||||
# Credit to original authors
|
||||
|
||||
# Will return the filename relative to the workspace path
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Raises ValueError if the file is not in the workspace
|
||||
def get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
    """Extract the target filename from a leading "# filename:" comment, if present.

    Args:
        code: The code block's source text.
        workspace_path: Directory the file must resolve inside.

    Returns:
        The filename relative to the workspace, or None when no filename comment exists.

    Raises:
        ValueError: If the named file resolves outside the workspace.
    """
    first_line = code.split("\n")[0]
    # TODO - support other languages
    if first_line.startswith("# filename:"):
        # Split only on the first ":" so filenames that themselves contain ":"
        # (e.g. Windows drive paths) are not truncated.
        filename = first_line.split(":", 1)[1].strip()

        # Handle relative paths in the filename
        path = Path(filename)
        if not path.is_absolute():
            path = workspace_path / path
        path = path.resolve()
        # Throws an error if the file is not in the workspace
        relative = path.relative_to(workspace_path.resolve())
        return str(relative)

    return None
|
||||
|
||||
|
||||
def silence_pip(code: str, lang: str) -> str:
    """Apply -qqq flag to pip install commands."""
    if lang == "python":
        pattern = r"^! ?pip install"
    elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]:
        pattern = r"^pip install"
    else:
        # Nothing to rewrite for other languages.
        return code

    def _quiet(line: str) -> str:
        # Only touch lines that invoke pip install and are not already quiet.
        match = re.search(pattern, line)
        if match is None or "-qqq" in line:
            return line
        return line.replace(match.group(0), match.group(0) + " -qqq")

    return "\n".join(_quiet(line) for line in code.split("\n"))
|
||||
|
||||
|
||||
# Aliases that should all be executed with the "python" command.
PYTHON_VARIANTS = ["python", "Python", "py"]


def lang_to_cmd(lang: str) -> str:
    """Map a code-block language tag to the command used to run it.

    Args:
        lang: The (lowercased) language tag from the code block.

    Returns:
        The executable name to invoke.

    Raises:
        ValueError: If the language is not supported.
    """
    if lang in PYTHON_VARIANTS:
        return "python"
    if lang.startswith("python") or lang in ["bash", "sh"]:
        return lang
    if lang in ["shell"]:
        return "sh"
    # PowerShell variants all run under the cross-platform `pwsh` binary.
    # Added for consistency with LocalCommandLineCodeExecutor.SUPPORTED_LANGUAGES,
    # which advertises these tags but previously crashed here.
    if lang in ["ps1", "pwsh", "powershell"]:
        return "pwsh"
    raise ValueError(f"Unsupported language: {lang}")
|
||||
|
||||
|
||||
# Regular expression for finding a code block
# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)[ \t]*\r?\n``` Matches multi-line code blocks.
# The [ \t]* matches the potential spaces before language name.
# The (\w+)? matches the language, where the ? indicates it is optional.
# The [ \t]* matches the potential spaces (not newlines) after language name.
# The \r?\n makes sure there is a linebreak after ```.
# The (.*?) matches the code itself (non-greedy).
# The \r?\n makes sure there is a linebreak before ```.
# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation).
# NOTE(review): (.*?) only spans multiple lines when compiled with re.DOTALL — confirm at the call site.
CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```"
|
||||
|
||||
|
||||
def infer_lang(code: str) -> str:
    """infer the language for the code.
    TODO: make it robust.
    """
    # Shell-style invocations of python/pip are commands, not python source.
    if code.startswith(("python ", "pip", "python3 ")):
        return "sh"

    # Otherwise, treat anything that compiles as python source.
    try:
        compile(code, "test", "exec")
    except SyntaxError:
        # not a valid python code
        return "unknown"
    return "python"
|
||||
32
python/src/agnext/components/models/__init__.py
Normal file
32
python/src/agnext/components/models/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities
|
||||
from ._openai_client import (
|
||||
AzureOpenAI,
|
||||
OpenAI,
|
||||
)
|
||||
from ._types import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FinishReasons,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
# Public API of agnext.components.models.
__all__ = [
    "AzureOpenAI",
    "OpenAI",
    "ModelCapabilities",
    "ChatCompletionClient",
    "SystemMessage",
    "UserMessage",
    "AssistantMessage",
    "FunctionExecutionResult",
    "FunctionExecutionResultMessage",
    "LLMMessage",
    "RequestUsage",
    "FinishReasons",
    "CreateResult",
]
|
||||
52
python/src/agnext/components/models/_model_client.py
Normal file
52
python/src/agnext/components/models/_model_client.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Mapping, Optional, Sequence, runtime_checkable
|
||||
|
||||
from typing_extensions import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Protocol,
|
||||
Required,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
from ..tools import Tool
|
||||
from ._types import CreateResult, LLMMessage, RequestUsage
|
||||
|
||||
|
||||
class ModelCapabilities(TypedDict, total=False):
    """Feature flags describing what a chat model supports."""

    # Whether the model accepts image inputs.
    vision: Required[bool]
    # Whether the model supports tool/function calling.
    function_calling: Required[bool]
    # Whether the model can be asked to emit JSON (response_format json_object).
    json_output: Required[bool]
|
||||
|
||||
|
||||
@runtime_checkable
class ChatCompletionClient(Protocol):
    """Protocol for chat-completion model clients (OpenAI, Azure OpenAI, ...)."""

    # Caching has to be handled internally as they can depend on the create args that were stored in the constructor
    async def create(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
    ) -> CreateResult: ...

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
    ) -> AsyncGenerator[Union[str, CreateResult], None]: ...

    # NOTE(review): presumably usage of the most recent request — confirm against implementations.
    def actual_usage(self) -> RequestUsage: ...

    # NOTE(review): presumably cumulative usage across all requests — confirm against implementations.
    def total_usage(self) -> RequestUsage: ...

    @property
    def capabilities(self) -> ModelCapabilities: ...
|
||||
89
python/src/agnext/components/models/_model_info.py
Normal file
89
python/src/agnext/components/models/_model_info.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Dict
|
||||
|
||||
from ._model_client import ModelCapabilities
|
||||
|
||||
# Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
# This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime``
# Maps model aliases to the pinned snapshot they currently point at.
_MODEL_POINTERS = {
    "gpt-4o": "gpt-4o-2024-05-13",
    "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
    "gpt-4-turbo-preview": "gpt-4-0125-preview",
    "gpt-4": "gpt-4-0613",
    "gpt-4-32k": "gpt-4-32k-0613",
    "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
}

# Capability flags keyed by pinned snapshot name (resolve aliases first via resolve_model).
_MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
    "gpt-4o-2024-05-13": {
        "vision": True,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-4-turbo-2024-04-09": {
        "vision": True,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-4-0125-preview": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-4-1106-preview": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-4-1106-vision-preview": {
        "vision": True,
        "function_calling": False,
        "json_output": False,
    },
    "gpt-4-0613": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-4-32k-0613": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-3.5-turbo-0125": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-3.5-turbo-1106": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
    # NOTE(review): "function_calling": True for the instruct model looks suspect — confirm.
    "gpt-3.5-turbo-instruct": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-3.5-turbo-0613": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
    "gpt-3.5-turbo-16k-0613": {
        "vision": False,
        "function_calling": True,
        "json_output": True,
    },
}
|
||||
|
||||
|
||||
def resolve_model(model: str) -> str:
    """Resolve a model alias (e.g. "gpt-4o") to its pinned snapshot name; unknown names pass through."""
    return _MODEL_POINTERS.get(model, model)
|
||||
|
||||
|
||||
def get_capabilties(model: str) -> ModelCapabilities:
    """Look up capability flags for *model* after resolving aliases.

    Note: the misspelled name ("capabilties") is kept because callers depend on it.
    """
    return _MODEL_CAPABILITIES[resolve_model(model)]
|
||||
569
python/src/agnext/components/models/_openai_client.py
Normal file
569
python/src/agnext/components/models/_openai_client.py
Normal file
@@ -0,0 +1,569 @@
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionRole,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
completion_create_params,
|
||||
)
|
||||
from openai.types.shared_params import FunctionDefinition, FunctionParameters
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
|
||||
from .. import (
|
||||
FunctionCall,
|
||||
Image,
|
||||
)
|
||||
from ..tools import Tool
|
||||
from . import _model_info
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities
|
||||
from ._types import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)

# Keyword arguments accepted by the (Azure)OpenAI client constructors, discovered
# by introspection so these stay in sync with the installed SDK version.
openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs)
aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs)

# Arguments accepted by chat.completions.create, plus request-level extras.
create_kwargs = set(completion_create_params.CompletionCreateParamsBase.__annotations__.keys()) | set(
    ("timeout", "stream")
)
# Only single choice allowed
disallowed_create_args = set(["stream", "messages", "function_call", "functions", "n"])
required_create_args: Set[str] = set(["model"])
|
||||
|
||||
|
||||
def _azure_openai_client_from_config(config: Mapping[str, Any]) -> AsyncAzureOpenAI:
    """Build an AsyncAzureOpenAI client from a loosely-typed config mapping."""
    # Work on a mutable copy so the caller's mapping is never touched.
    cfg = dict(config)

    # Fix up Azure-specific fields: the deployment defaults to the model name
    # with "." stripped, and the endpoint falls back to a generic base_url entry.
    deployment = cfg.get("azure_deployment", config.get("model"))
    if deployment is not None:
        deployment = deployment.replace(".", "")
    cfg["azure_deployment"] = deployment
    cfg["azure_endpoint"] = cfg.get("azure_endpoint", cfg.pop("base_url", None))

    # Keep only the kwargs the AsyncAzureOpenAI constructor accepts.
    azure_kwargs = {key: value for key, value in cfg.items() if key in aopenai_init_kwargs}
    return AsyncAzureOpenAI(**azure_kwargs)
|
||||
|
||||
|
||||
def _openai_client_from_config(config: Mapping[str, Any]) -> AsyncOpenAI:
    """Build an AsyncOpenAI client, ignoring config keys its constructor does not accept."""
    openai_kwargs = {key: value for key, value in config.items() if key in openai_init_kwargs}
    return AsyncOpenAI(**openai_kwargs)
|
||||
|
||||
|
||||
def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Extract and validate the completion-create arguments from a config mapping."""
    create_args = {key: value for key, value in config.items() if key in create_kwargs}
    present_keys = set(create_args.keys())
    missing = required_create_args - present_keys
    if missing:
        raise ValueError(f"Required create args are missing: {missing}")
    forbidden = disallowed_create_args.intersection(present_keys)
    if forbidden:
        raise ValueError(f"Disallowed create args are present: {forbidden}")
    return create_args
|
||||
|
||||
|
||||
# TODO check types
|
||||
# oai_system_message_schema = type2schema(ChatCompletionSystemMessageParam)
|
||||
# oai_user_message_schema = type2schema(ChatCompletionUserMessageParam)
|
||||
# oai_assistant_message_schema = type2schema(ChatCompletionAssistantMessageParam)
|
||||
# oai_tool_message_schema = type2schema(ChatCompletionToolMessageParam)
|
||||
|
||||
|
||||
def type_to_role(message: LLMMessage) -> ChatCompletionRole:
    """Map an LLM message type onto the OpenAI chat role string."""
    if isinstance(message, SystemMessage):
        role: ChatCompletionRole = "system"
    elif isinstance(message, UserMessage):
        role = "user"
    elif isinstance(message, AssistantMessage):
        role = "assistant"
    else:
        # Remaining message types (function execution results) map to "tool".
        role = "tool"
    return role
|
||||
|
||||
|
||||
def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam:
    """Convert a UserMessage into the OpenAI user-message wire format.

    Plain-string content passes through unchanged; mixed content becomes a
    list of text and image parts.
    """
    if isinstance(message.content, str):
        return ChatCompletionUserMessageParam(
            content=message.content,
            role="user",
            name=message.source,
        )
    else:
        parts: List[ChatCompletionContentPartParam] = []
        for part in message.content:
            if isinstance(part, str):
                oai_part = ChatCompletionContentPartTextParam(
                    text=part,
                    type="text",
                )
                parts.append(oai_part)
            elif isinstance(part, Image):
                # TODO: support url based images
                # TODO: support specifying details
                parts.append(part.to_openai_format())
            else:
                raise ValueError(f"Unknown content type: {part}")
        return ChatCompletionUserMessageParam(
            content=parts,
            role="user",
            name=message.source,
        )
|
||||
|
||||
|
||||
def system_message_to_oai(message: SystemMessage) -> ChatCompletionSystemMessageParam:
    """Convert a SystemMessage into the OpenAI system-message wire format."""
    return ChatCompletionSystemMessageParam(
        content=message.content,
        role="system",
    )
|
||||
|
||||
|
||||
def func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam:
    """Convert an internal FunctionCall into an OpenAI tool-call parameter."""
    return ChatCompletionMessageToolCallParam(
        id=message.id,
        function={
            "arguments": message.arguments,
            "name": message.name,
        },
        type="function",
    )
|
||||
|
||||
|
||||
def tool_message_to_oai(
    message: FunctionExecutionResultMessage,
) -> Sequence[ChatCompletionToolMessageParam]:
    """Convert a function-execution-result message into one OpenAI tool message per result."""
    return [
        ChatCompletionToolMessageParam(content=x.content, role="tool", tool_call_id=x.call_id) for x in message.content
    ]
|
||||
|
||||
|
||||
def assistant_message_to_oai(
    message: AssistantMessage,
) -> ChatCompletionAssistantMessageParam:
    """Convert an AssistantMessage into the OpenAI assistant-message wire format.

    List content is interpreted as tool calls; any other content is passed through as text.
    """
    if isinstance(message.content, list):
        return ChatCompletionAssistantMessageParam(
            tool_calls=[func_call_to_oai(x) for x in message.content],
            role="assistant",
            name=message.source,
        )
    else:
        return ChatCompletionAssistantMessageParam(
            content=message.content,
            role="assistant",
            name=message.source,
        )
|
||||
|
||||
|
||||
def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
    """Convert any LLM message into a sequence of OpenAI wire-format messages."""
    if isinstance(message, SystemMessage):
        converted: Sequence[ChatCompletionMessageParam] = [system_message_to_oai(message)]
    elif isinstance(message, UserMessage):
        converted = [user_message_to_oai(message)]
    elif isinstance(message, AssistantMessage):
        converted = [assistant_message_to_oai(message)]
    else:
        # Function execution results expand to one tool message per result.
        converted = tool_message_to_oai(message)
    return converted
|
||||
|
||||
|
||||
def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
    """Return the element-wise sum of two usage records."""
    total_prompt = usage1.prompt_tokens + usage2.prompt_tokens
    total_completion = usage1.completion_tokens + usage2.completion_tokens
    return RequestUsage(prompt_tokens=total_prompt, completion_tokens=total_completion)
|
||||
|
||||
|
||||
def convert_tools(
    tools: Sequence[Tool],
) -> List[ChatCompletionToolParam]:
    """Convert internal Tool definitions into OpenAI tool parameters, validating every name."""
    result: List[ChatCompletionToolParam] = []
    for tool in tools:
        tool_schema = tool.schema
        result.append(
            ChatCompletionToolParam(
                type="function",
                function=FunctionDefinition(
                    name=tool_schema["name"],
                    # description/parameters are optional in the internal schema.
                    description=tool_schema["description"] if "description" in tool_schema else "",
                    parameters=cast(FunctionParameters, tool_schema["parameters"])
                    if "parameters" in tool_schema
                    else {},
                ),
            )
        )
    # Check if all tools have valid names.
    for tool_param in result:
        assert_valid_name(tool_param["function"]["name"])
    return result
|
||||
|
||||
|
||||
def normalize_name(name: str) -> str:
    """
    LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".

    Prefer _assert_valid_name for validating user configuration or input
    """
    # Replace every disallowed character, then clamp to the API's 64-char limit.
    sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
    return sanitized[:64]
|
||||
|
||||
|
||||
def assert_valid_name(name: str) -> str:
    """
    Ensure that configured names are valid, raises ValueError if not.

    For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.

    Args:
        name: The configured name to validate.

    Returns:
        The name unchanged, when valid.

    Raises:
        ValueError: If the name contains invalid characters or exceeds 64 characters.
    """
    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
    # The check is len > 64, so a 64-character name is allowed; the message now agrees
    # (it previously said "less than 64 characters").
    if len(name) > 64:
        raise ValueError(f"Invalid name: {name}. Name must be 64 characters or fewer.")
    return name
|
||||
|
||||
|
||||
class BaseOpenAI(ChatCompletionClient):
|
||||
def __init__(
    self,
    client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    create_args: Dict[str, Any],
    model_capabilities: Optional[ModelCapabilities] = None,
):
    """Wrap an (Azure)OpenAI SDK client with resolved capabilities and default create args.

    Raises:
        ValueError: If capabilities are omitted for an Azure client, or if the
            default response_format requests JSON on a model that cannot emit it.
    """
    self._client = client
    # Azure deployments hide the underlying model, so capabilities cannot be inferred.
    if model_capabilities is None and isinstance(client, AsyncAzureOpenAI):
        raise ValueError("AzureOpenAI requires explicit model capabilities")
    elif model_capabilities is None:
        self._model_capabilities = _model_info.get_capabilties(create_args["model"])
    else:
        self._model_capabilities = model_capabilities

    # Resolve aliases like "gpt-4o" to the pinned snapshot, when a model is given.
    self._resolved_model: Optional[str] = None
    if "model" in create_args:
        self._resolved_model = _model_info.resolve_model(create_args["model"])

    # Reject a JSON response_format default on models that cannot produce JSON.
    if (
        "response_format" in create_args
        and create_args["response_format"]["type"] == "json_object"
        and not self._model_capabilities["json_output"]
    ):
        raise ValueError("Model does not support JSON output")

    self._create_args = create_args
    # Token usage accounting (exposed via actual_usage/total_usage).
    self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
    self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
    @classmethod
    def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
        """Construct a client from a raw configuration dictionary.

        NOTE(review): this always builds an ``OpenAI`` client regardless of the
        subclass it is invoked on (including ``AzureOpenAI``) — confirm this is
        intentional.
        """
        return OpenAI(**config)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> CreateResult:
|
||||
# Make sure all extra_create_args are valid
|
||||
extra_create_args_keys = set(extra_create_args.keys())
|
||||
if not create_kwargs.issuperset(extra_create_args_keys):
|
||||
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
|
||||
|
||||
# Copy the create args and overwrite anything in extra_create_args
|
||||
create_args = self._create_args.copy()
|
||||
create_args.update(extra_create_args)
|
||||
|
||||
# TODO: allow custom handling.
|
||||
# For now we raise an error if images are present and vision is not supported
|
||||
if self.capabilities["vision"] is False:
|
||||
for message in messages:
|
||||
if isinstance(message, UserMessage):
|
||||
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
|
||||
raise ValueError("Model does not support vision and image was provided")
|
||||
|
||||
if json_output is not None:
|
||||
if self.capabilities["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output")
|
||||
|
||||
if json_output is True:
|
||||
create_args["response_format"] = {"type": "json_object"}
|
||||
else:
|
||||
create_args["response_format"] = {"type": "text"}
|
||||
|
||||
if self.capabilities["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output")
|
||||
|
||||
oai_messages_nested = [to_oai_type(m) for m in messages]
|
||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||
|
||||
if self.capabilities["function_calling"] is False and len(tools) > 0:
|
||||
raise ValueError("Model does not support function calling")
|
||||
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
result = await self._client.chat.completions.create(
|
||||
messages=oai_messages,
|
||||
stream=False,
|
||||
tools=converted_tools,
|
||||
**create_args,
|
||||
)
|
||||
else:
|
||||
result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
|
||||
|
||||
if result.usage is not None:
|
||||
logger.info(
|
||||
LLMCallEvent(
|
||||
prompt_tokens=result.usage.prompt_tokens,
|
||||
completion_tokens=result.usage.completion_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
usage = RequestUsage(
|
||||
# TODO backup token counting
|
||||
prompt_tokens=result.usage.prompt_tokens if result.usage is not None else 0,
|
||||
completion_tokens=(result.usage.completion_tokens if result.usage is not None else 0),
|
||||
)
|
||||
|
||||
if self._resolved_model is not None:
|
||||
if self._resolved_model != result.model:
|
||||
warnings.warn(
|
||||
f"Resolved model mismatch: {self._resolved_model} != {result.model}. AutoGen model mapping may be incorrect.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Limited to a single choice currently.
|
||||
choice = result.choices[0]
|
||||
if choice.finish_reason == "function_call":
|
||||
raise ValueError("Function calls are not supported in this context")
|
||||
|
||||
content: Union[str, List[FunctionCall]]
|
||||
if choice.finish_reason == "tool_calls":
|
||||
assert choice.message.tool_calls is not None
|
||||
assert choice.message.function_call is None
|
||||
|
||||
# NOTE: If OAI response type changes, this will need to be updated
|
||||
content = [
|
||||
FunctionCall(
|
||||
id=x.id,
|
||||
arguments=x.function.arguments,
|
||||
name=normalize_name(x.function.name),
|
||||
)
|
||||
for x in choice.message.tool_calls
|
||||
]
|
||||
finish_reason = "function_calls"
|
||||
else:
|
||||
finish_reason = choice.finish_reason
|
||||
content = choice.message.content or ""
|
||||
|
||||
response = CreateResult(finish_reason=finish_reason, content=content, usage=usage, cached=False) # type: ignore
|
||||
|
||||
_add_usage(self._actual_usage, usage)
|
||||
_add_usage(self._total_usage, usage)
|
||||
|
||||
# TODO - why is this cast needed?
|
||||
return response
|
||||
|
||||
async def create_stream(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
# Make sure all extra_create_args are valid
|
||||
extra_create_args_keys = set(extra_create_args.keys())
|
||||
if not create_kwargs.issuperset(extra_create_args_keys):
|
||||
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
|
||||
|
||||
# Copy the create args and overwrite anything in extra_create_args
|
||||
create_args = self._create_args.copy()
|
||||
create_args.update(extra_create_args)
|
||||
|
||||
oai_messages_nested = [to_oai_type(m) for m in messages]
|
||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||
|
||||
# TODO: allow custom handling.
|
||||
# For now we raise an error if images are present and vision is not supported
|
||||
if self.capabilities["vision"] is False:
|
||||
for message in messages:
|
||||
if isinstance(message, UserMessage):
|
||||
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
|
||||
raise ValueError("Model does not support vision and image was provided")
|
||||
|
||||
if json_output is not None:
|
||||
if self.capabilities["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output")
|
||||
|
||||
if json_output is True:
|
||||
create_args["response_format"] = {"type": "json_object"}
|
||||
else:
|
||||
create_args["response_format"] = {"type": "text"}
|
||||
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
stream = await self._client.chat.completions.create(
|
||||
messages=oai_messages, stream=True, tools=converted_tools, **create_args
|
||||
)
|
||||
else:
|
||||
stream = await self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
|
||||
|
||||
stop_reason = None
|
||||
maybe_model = None
|
||||
content_deltas: List[str] = []
|
||||
full_tool_calls: Dict[int, FunctionCall] = {}
|
||||
completion_tokens = 0
|
||||
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
stop_reason = choice.finish_reason
|
||||
maybe_model = chunk.model
|
||||
# First try get content
|
||||
if choice.delta.content is not None:
|
||||
content_deltas.append(choice.delta.content)
|
||||
if len(choice.delta.content) > 0:
|
||||
yield choice.delta.content
|
||||
continue
|
||||
|
||||
# Otherwise, get tool calls
|
||||
if choice.delta.tool_calls is not None:
|
||||
for tool_call_chunk in choice.delta.tool_calls:
|
||||
idx = tool_call_chunk.index
|
||||
if idx not in full_tool_calls:
|
||||
# We ignore the type hint here because we want to fill in type when the delta provides it
|
||||
full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
|
||||
|
||||
if tool_call_chunk.id is not None:
|
||||
full_tool_calls[idx].id += tool_call_chunk.id
|
||||
|
||||
if tool_call_chunk.function is not None:
|
||||
if tool_call_chunk.function.name is not None:
|
||||
full_tool_calls[idx].name += tool_call_chunk.function.name
|
||||
if tool_call_chunk.function.arguments is not None:
|
||||
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
|
||||
|
||||
model = maybe_model or create_args["model"]
|
||||
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
|
||||
|
||||
# TODO fix count token
|
||||
prompt_tokens = 0
|
||||
# prompt_tokens = count_token(messages, model=model)
|
||||
if stop_reason is None:
|
||||
raise ValueError("No stop reason found")
|
||||
|
||||
content: Union[str, List[FunctionCall]]
|
||||
if len(content_deltas) > 1:
|
||||
content = "".join(content_deltas)
|
||||
completion_tokens = 0
|
||||
# completion_tokens = count_token(content, model=model)
|
||||
else:
|
||||
completion_tokens = 0
|
||||
# TODO: fix assumption that dict values were added in order and actually order by int index
|
||||
# for tool_call in full_tool_calls.values():
|
||||
# # value = json.dumps(tool_call)
|
||||
# # completion_tokens += count_token(value, model=model)
|
||||
# completion_tokens += 0
|
||||
content = list(full_tool_calls.values())
|
||||
|
||||
usage = RequestUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
if stop_reason == "function_call":
|
||||
raise ValueError("Function calls are not supported in this context")
|
||||
if stop_reason == "tool_calls":
|
||||
stop_reason = "function_calls"
|
||||
|
||||
result = CreateResult(finish_reason=stop_reason, content=content, usage=usage, cached=False)
|
||||
|
||||
_add_usage(self._actual_usage, usage)
|
||||
_add_usage(self._total_usage, usage)
|
||||
|
||||
yield result
|
||||
|
||||
    def actual_usage(self) -> RequestUsage:
        """Token usage accumulated by this client.

        NOTE(review): updated identically to total_usage on every call in this
        class; presumably intended to exclude cached responses — confirm.
        """
        return self._actual_usage
|
||||
|
||||
    def total_usage(self) -> RequestUsage:
        """Total token usage accumulated across all requests made by this client."""
        return self._total_usage
|
||||
|
||||
    @property
    def capabilities(self) -> ModelCapabilities:
        """Model capabilities (vision, json_output, function_calling) resolved at construction."""
        return self._model_capabilities
|
||||
|
||||
|
||||
class OpenAI(BaseOpenAI):
    """Chat completion client for the public OpenAI API.

    Accepts the union of client-construction and request ("create") keyword
    arguments; they are split apart internally. ``model_capabilities`` is
    optional here and inferred from the model name when omitted.
    """

    def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
        if "model" not in kwargs:
            raise ValueError("model is required for OpenAI")

        model_capabilities: Optional[ModelCapabilities] = None
        # dict(kwargs) already produces a shallow copy; the extra .copy() was redundant.
        copied_args = dict(kwargs)
        if "model_capabilities" in kwargs:
            model_capabilities = kwargs["model_capabilities"]
            del copied_args["model_capabilities"]

        client = _openai_client_from_config(copied_args)
        create_args = _create_args_from_config(copied_args)
        # Kept so the client can be rebuilt after unpickling (see __setstate__).
        self._raw_config = copied_args
        super().__init__(client, create_args, model_capabilities)

    def __getstate__(self) -> Dict[str, Any]:
        # The underlying HTTP client is not picklable; drop it and rebuild on load.
        state = self.__dict__.copy()
        state["_client"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._client = _openai_client_from_config(state["_raw_config"])
|
||||
|
||||
|
||||
class AzureOpenAI(BaseOpenAI):
    """Chat completion client for Azure OpenAI deployments.

    Unlike ``OpenAI``, ``model_capabilities`` must be supplied explicitly
    (enforced in ``BaseOpenAI.__init__``) because Azure deployment names do not
    identify the underlying model.
    """

    def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
        if "model" not in kwargs:
            # BUGFIX: message previously said "required for OpenAI" in the Azure class.
            raise ValueError("model is required for AzureOpenAI")

        model_capabilities: Optional[ModelCapabilities] = None
        # dict(kwargs) already produces a shallow copy; the extra .copy() was redundant.
        copied_args = dict(kwargs)
        if "model_capabilities" in kwargs:
            model_capabilities = kwargs["model_capabilities"]
            del copied_args["model_capabilities"]

        client = _azure_openai_client_from_config(copied_args)
        create_args = _create_args_from_config(copied_args)
        # Kept so the client can be rebuilt after unpickling (see __setstate__).
        self._raw_config = copied_args
        super().__init__(client, create_args, model_capabilities)

    def __getstate__(self) -> Dict[str, Any]:
        # The underlying HTTP client is not picklable; drop it and rebuild on load.
        state = self.__dict__.copy()
        state["_client"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._client = _azure_openai_client_from_config(state["_raw_config"])
|
||||
56
python/src/agnext/components/models/_types.py
Normal file
56
python/src/agnext/components/models/_types.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from .. import FunctionCall, Image
|
||||
|
||||
|
||||
@dataclass
class SystemMessage:
    """A system-role message used to steer model behavior."""

    content: str
|
||||
|
||||
|
||||
@dataclass
class UserMessage:
    """A user-role message; content may interleave text and images."""

    content: Union[str, List[Union[str, Image]]]

    # Name of the agent that sent this message
    source: str
|
||||
|
||||
|
||||
@dataclass
class AssistantMessage:
    """An assistant-role message: either text or a list of function calls."""

    content: Union[str, List[FunctionCall]]

    # Name of the agent that sent this message
    source: str
|
||||
|
||||
|
||||
@dataclass
class FunctionExecutionResult:
    """The string result of executing one function call, keyed by its call id."""

    content: str
    # Matches the id of the FunctionCall this result answers.
    call_id: str
|
||||
|
||||
|
||||
@dataclass
class FunctionExecutionResultMessage:
    """A message carrying the results of one batch of function executions."""

    content: List[FunctionExecutionResult]
|
||||
|
||||
|
||||
# Union of all message types accepted by a chat completion client.
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
|
||||
|
||||
|
||||
@dataclass
class RequestUsage:
    """Token counts for one request (or an accumulated running total)."""

    prompt_tokens: int
    completion_tokens: int
|
||||
|
||||
|
||||
# Normalized finish reasons; "function_calls" replaces OpenAI's raw "tool_calls".
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
|
||||
|
||||
|
||||
@dataclass
class CreateResult:
    """The outcome of a single chat completion request."""

    finish_reason: FinishReasons
    # Text for normal completions, or the list of requested function calls.
    content: Union[str, List[FunctionCall]]
    usage: RequestUsage
    # True when the result was served from a cache rather than the API.
    cached: bool
|
||||
52
python/src/agnext/components/models/config/__init__.py
Normal file
52
python/src/agnext/components/models/config/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from .._model_client import ModelCapabilities
|
||||
|
||||
|
||||
class ResponseFormat(TypedDict):
    """Mirrors the OpenAI ``response_format`` create parameter."""

    type: Literal["text", "json_object"]
|
||||
|
||||
|
||||
class CreateArguments(TypedDict, total=False):
    """Optional per-request arguments forwarded to chat.completions.create."""

    frequency_penalty: Optional[float]
    logit_bias: Optional[Dict[str, int]]
    max_tokens: Optional[int]
    n: Optional[int]
    presence_penalty: Optional[float]
    response_format: ResponseFormat
    seed: Optional[int]
    stop: Union[Optional[str], List[str]]
    temperature: Optional[float]
    top_p: Optional[float]
    user: str
|
||||
|
||||
|
||||
# Callable that returns an Azure AD token, either directly or as an awaitable.
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
|
||||
|
||||
|
||||
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
    """Create arguments plus connection settings shared by OpenAI and Azure clients."""

    model: str
    api_key: str
    timeout: Union[float, None]
    max_retries: int
|
||||
|
||||
|
||||
# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
    """Full configuration for the public ``OpenAI`` client."""

    organization: str
    base_url: str
    # Not required
    model_capabilities: ModelCapabilities
|
||||
|
||||
|
||||
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
    """Full configuration for the ``AzureOpenAI`` client."""

    # Azure specific
    azure_endpoint: Required[str]
    azure_deployment: str
    api_version: Required[str]
    azure_ad_token: str
    azure_ad_token_provider: AsyncAzureADTokenProvider
    # Must be provided
    model_capabilities: Required[ModelCapabilities]
|
||||
13
python/src/agnext/components/tools/__init__.py
Normal file
13
python/src/agnext/components/tools/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from ._base import BaseTool, BaseToolWithState, Tool
|
||||
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
|
||||
from ._function_tool import FunctionTool
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"BaseTool",
|
||||
"BaseToolWithState",
|
||||
"PythonCodeExecutionTool",
|
||||
"CodeExecutionInput",
|
||||
"CodeExecutionResult",
|
||||
"FunctionTool",
|
||||
]
|
||||
151
python/src/agnext/components/tools/_base.py
Normal file
151
python/src/agnext/components/tools/_base.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from ...core import CancellationToken
|
||||
from .._function_utils import normalize_annotated_type
|
||||
|
||||
# Contravariant pydantic-model type parameter (used by the tool protocols below).
T = TypeVar("T", bound=BaseModel, contravariant=True)


class ParametersSchema(TypedDict):
    """JSON-schema style description of a tool's parameters."""

    type: str
    properties: Dict[str, Any]
    required: NotRequired[Sequence[str]]
|
||||
|
||||
|
||||
class ToolSchema(TypedDict):
    """Schema describing a tool: its name, description, and parameter schema."""

    parameters: NotRequired[ParametersSchema]
    name: str
    description: NotRequired[str]
|
||||
|
||||
|
||||
class Tool(Protocol):
    """Structural interface every tool must satisfy.

    ``run_json``/``save_state_json``/``load_state_json`` operate on plain
    JSON-compatible mappings so callers need not know the concrete models.
    """

    @property
    def name(self) -> str: ...

    @property
    def description(self) -> str: ...

    @property
    def schema(self) -> ToolSchema: ...

    def args_type(self) -> Type[BaseModel]: ...

    def return_type(self) -> Type[Any]: ...

    def state_type(self) -> Type[BaseModel] | None: ...

    def return_value_as_string(self, value: Any) -> str: ...

    async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any: ...

    def save_state_json(self) -> Mapping[str, Any]: ...

    def load_state_json(self, state: Mapping[str, Any]) -> None: ...
|
||||
|
||||
|
||||
# Argument, return, and state model type parameters for the tool base classes.
ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True)
ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True)
StateT = TypeVar("StateT", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
    """Partial :class:`Tool` implementation backed by pydantic models.

    Subclasses supply the args model, the return type, and implement
    :meth:`run`; JSON schema generation and JSON-level invocation are provided
    here.
    """

    def __init__(
        self,
        args_type: Type[ArgsT],
        return_type: Type[ReturnT],
        name: str,
        description: str,
    ) -> None:
        self._args_type = args_type
        # Normalize Annotated to the base type.
        self._return_type = normalize_annotated_type(return_type)
        self._name = name
        self._description = description

    @property
    def schema(self) -> ToolSchema:
        """Build the tool's JSON schema from the pydantic args model."""
        model_schema = self._args_type.model_json_schema()

        tool_schema = ToolSchema(
            name=self._name,
            description=self._description,
            parameters=ParametersSchema(
                type="object",
                properties=model_schema["properties"],
            ),
        )
        # Pydantic only emits "required" when at least one field is required.
        if "required" in model_schema:
            assert "parameters" in tool_schema
            tool_schema["parameters"]["required"] = model_schema["required"]

        return tool_schema

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    def args_type(self) -> Type[BaseModel]:
        return self._args_type

    def return_type(self) -> Type[Any]:
        return self._return_type

    def state_type(self) -> Type[BaseModel] | None:
        # Stateless by default; BaseToolWithState overrides the state hooks.
        return None

    def return_value_as_string(self, value: Any) -> str:
        """Render a run() result as a string suitable for an LLM message."""
        if isinstance(value, BaseModel):
            dumped = value.model_dump()
            # A custom model_serializer may dump to a non-dict (e.g. a plain
            # str, as CodeExecutionResult does), hence the isinstance check.
            if isinstance(dumped, dict):
                return json.dumps(dumped)
            return str(dumped)

        return str(value)

    @abstractmethod
    async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...

    async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any:
        """Validate raw mapping args against the args model, then execute run()."""
        return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
        return return_value

    def save_state_json(self) -> Mapping[str, Any]:
        return {}

    def load_state_json(self, state: Mapping[str, Any]) -> None:
        pass
|
||||
|
||||
|
||||
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
    """A BaseTool whose state is captured by a pydantic model of type StateT."""

    def __init__(
        self,
        args_type: Type[ArgsT],
        return_type: Type[ReturnT],
        state_type: Type[StateT],
        name: str,
        description: str,
    ) -> None:
        super().__init__(args_type, return_type, name, description)
        self._state_type = state_type

    @abstractmethod
    def save_state(self) -> StateT: ...

    @abstractmethod
    def load_state(self, state: StateT) -> None: ...

    def save_state_json(self) -> Mapping[str, Any]:
        # Bridge the typed state to the JSON-level Tool interface.
        return self.save_state().model_dump()

    def load_state_json(self, state: Mapping[str, Any]) -> None:
        self.load_state(self._state_type.model_validate(state))
|
||||
37
python/src/agnext/components/tools/_code_execution.py
Normal file
37
python/src/agnext/components/tools/_code_execution.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
|
||||
from ...core import CancellationToken
|
||||
from ..code_executor import CodeBlock, CodeExecutor
|
||||
from ._base import BaseTool
|
||||
|
||||
|
||||
class CodeExecutionInput(BaseModel):
    """Arguments for PythonCodeExecutionTool."""

    code: str = Field(description="The contents of the Python code block that should be executed")
|
||||
|
||||
|
||||
class CodeExecutionResult(BaseModel):
    """Result of a code execution: success flag plus captured output."""

    success: bool
    output: str

    @model_serializer
    def ser_model(self) -> str:
        # Serialize to just the output text, so LLM-facing rendering
        # (BaseTool.return_value_as_string) shows the program output alone.
        return self.output
|
||||
|
||||
|
||||
class PythonCodeExecutionTool(BaseTool[CodeExecutionInput, CodeExecutionResult]):
    """Tool that runs a Python code block through a CodeExecutor.

    Execution happens in the default thread-pool executor so the (potentially
    blocking) executor call does not stall the event loop.
    """

    def __init__(self, executor: CodeExecutor):
        super().__init__(CodeExecutionInput, CodeExecutionResult, "CodeExecutor", "Execute Python code blocks.")
        self._executor = executor

    async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult:
        """Execute ``args.code`` and report success (exit code 0) and output.

        The future is linked to ``cancellation_token`` so callers can cancel
        the pending execution.
        """
        code_blocks = [CodeBlock(code=args.code, language="python")]
        # BUGFIX/idiom: inside a coroutine a running loop is guaranteed, and
        # asyncio.get_event_loop() is deprecated for this use — use
        # get_running_loop() instead.
        future = asyncio.get_running_loop().run_in_executor(
            None, functools.partial(self._executor.execute_code_blocks, code_blocks=code_blocks)
        )
        cancellation_token.link_future(future)
        result = await future

        return CodeExecutionResult(success=result.exit_code == 0, output=result.output)
|
||||
50
python/src/agnext/components/tools/_function_tool.py
Normal file
50
python/src/agnext/components/tools/_function_tool.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...core import CancellationToken
|
||||
from .._function_utils import (
|
||||
args_base_model_from_signature,
|
||||
get_typed_signature,
|
||||
)
|
||||
from ._base import BaseTool
|
||||
|
||||
|
||||
class FunctionTool(BaseTool[BaseModel, BaseModel]):
    """Wraps a plain Python callable (sync or async) as a Tool.

    The argument model is derived from the function's typed signature. If the
    function declares a ``cancellation_token`` parameter, the token is passed
    through so the function can observe cancellation itself.
    """

    def __init__(self, func: Callable[..., Any], description: str, name: str | None = None) -> None:
        self._func = func
        signature = get_typed_signature(func)
        func_name = name or func.__name__
        args_model = args_base_model_from_signature(func_name + "args", signature)
        return_type = signature.return_annotation
        # Functions that accept a cancellation_token handle cancellation themselves.
        self._has_cancellation_support = "cancellation_token" in signature.parameters

        super().__init__(args_model, return_type, func_name, description)

    async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
        """Invoke the wrapped callable with validated arguments.

        Async callables are awaited directly; sync callables run in the default
        thread-pool executor so the event loop is not blocked. When the
        function has no cancellation support, the executor future is linked to
        the token so callers can cancel it externally.
        """
        if asyncio.iscoroutinefunction(self._func):
            if self._has_cancellation_support:
                result = await self._func(**args.model_dump(), cancellation_token=cancellation_token)
            else:
                result = await self._func(**args.model_dump())
        else:
            # BUGFIX/idiom: inside a coroutine a running loop is guaranteed and
            # asyncio.get_event_loop() is deprecated for this use — use
            # get_running_loop() instead (both branches).
            if self._has_cancellation_support:
                result = await asyncio.get_running_loop().run_in_executor(
                    None,
                    functools.partial(
                        self._func,
                        **args.model_dump(),
                        cancellation_token=cancellation_token,
                    ),
                )
            else:
                future = asyncio.get_running_loop().run_in_executor(
                    None, functools.partial(self._func, **args.model_dump())
                )
                cancellation_token.link_future(future)
                result = await future

        assert isinstance(result, self.return_type())
        return result
|
||||
Reference in New Issue
Block a user