mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-12 11:25:30 -05:00
Add examples to showcase patterns (#55)
* add chess example * wip * wip * fix tool schema generation * fixes * Agent handle exception Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com> * format * mypy * fix test for annotated --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
@@ -105,18 +105,18 @@ def message_handler(
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
|
||||
if strict:
|
||||
if type(message) not in target_types:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, cancellation_token)
|
||||
|
||||
if strict:
|
||||
if return_value is not AnyType and type(return_value) not in return_types:
|
||||
if AnyType not in return_types and type(return_value) not in return_types:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
elif return_value is not AnyType:
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
@@ -134,7 +134,10 @@ def message_handler(
|
||||
class TypeRoutedAgent(BaseAgent):
|
||||
def __init__(self, name: str, router: AgentRuntime) -> None:
|
||||
# Self is already bound to the handlers
|
||||
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
|
||||
self._handlers: Dict[
|
||||
Type[Any],
|
||||
Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]],
|
||||
] = {}
|
||||
|
||||
for attr in dir(self):
|
||||
if callable(getattr(self, attr, None)):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -11,6 +12,7 @@ from typing import (
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
@@ -27,6 +29,7 @@ from openai.types.chat import (
|
||||
ChatCompletionUserMessageParam,
|
||||
completion_create_params,
|
||||
)
|
||||
from openai.types.shared_params import FunctionDefinition, FunctionParameters
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
|
||||
@@ -205,15 +208,47 @@ def convert_tools(
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
result: List[ChatCompletionToolParam] = []
|
||||
for tool in tools:
|
||||
tool_schema = tool.schema
|
||||
result.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": tool.schema, # type: ignore
|
||||
}
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=FunctionDefinition(
|
||||
name=tool_schema["name"],
|
||||
description=tool_schema["description"] if "description" in tool_schema else "",
|
||||
parameters=cast(FunctionParameters, tool_schema["parameters"])
|
||||
if "parameters" in tool_schema
|
||||
else {},
|
||||
),
|
||||
)
|
||||
)
|
||||
# Check if all tools have valid names.
|
||||
for tool_param in result:
|
||||
assert_valid_name(tool_param["function"]["name"])
|
||||
return result
|
||||
|
||||
|
||||
def normalize_name(name: str) -> str:
|
||||
"""
|
||||
LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
|
||||
|
||||
Prefer _assert_valid_name for validating user configuration or input
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
|
||||
|
||||
|
||||
def assert_valid_name(name: str) -> str:
|
||||
"""
|
||||
Ensure that configured names are valid, raises ValueError if not.
|
||||
|
||||
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
|
||||
"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
||||
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
|
||||
if len(name) > 64:
|
||||
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
|
||||
return name
|
||||
|
||||
|
||||
class BaseOpenAI(ChatCompletionClient):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -293,7 +328,10 @@ class BaseOpenAI(ChatCompletionClient):
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
result = await self._client.chat.completions.create(
|
||||
messages=oai_messages, stream=False, tools=converted_tools, **create_args
|
||||
messages=oai_messages,
|
||||
stream=False,
|
||||
tools=converted_tools,
|
||||
**create_args,
|
||||
)
|
||||
else:
|
||||
result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
|
||||
@@ -331,7 +369,11 @@ class BaseOpenAI(ChatCompletionClient):
|
||||
|
||||
# NOTE: If OAI response type changes, this will need to be updated
|
||||
content = [
|
||||
FunctionCall(id=x.id, arguments=x.function.arguments, name=x.function.name)
|
||||
FunctionCall(
|
||||
id=x.id,
|
||||
arguments=x.function.arguments,
|
||||
name=normalize_name(x.function.name),
|
||||
)
|
||||
for x in choice.message.tool_calls
|
||||
]
|
||||
finish_reason = "function_calls"
|
||||
|
||||
@@ -1,14 +1,29 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from ...core import CancellationToken
|
||||
from .._function_utils import normalize_annotated_type
|
||||
|
||||
T = TypeVar("T", bound=BaseModel, contravariant=True)
|
||||
|
||||
|
||||
class ParametersSchema(TypedDict):
|
||||
type: str
|
||||
properties: Dict[str, Any]
|
||||
required: NotRequired[Sequence[str]]
|
||||
|
||||
|
||||
class ToolSchema(TypedDict):
|
||||
parameters: NotRequired[ParametersSchema]
|
||||
name: str
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class Tool(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
@@ -17,7 +32,7 @@ class Tool(Protocol):
|
||||
def description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def schema(self) -> Mapping[str, Any]: ...
|
||||
def schema(self) -> ToolSchema: ...
|
||||
|
||||
def args_type(self) -> Type[BaseModel]: ...
|
||||
|
||||
@@ -40,20 +55,36 @@ StateT = TypeVar("StateT", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
|
||||
def __init__(self, args_type: Type[ArgsT], return_type: Type[ReturnT], name: str, description: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
self._args_type = args_type
|
||||
self._return_type = return_type
|
||||
# Normalize Annotated to the base type.
|
||||
self._return_type = normalize_annotated_type(return_type)
|
||||
self._name = name
|
||||
self._description = description
|
||||
|
||||
@property
|
||||
def schema(self) -> Mapping[str, Any]:
|
||||
def schema(self) -> ToolSchema:
|
||||
model_schema = self._args_type.model_json_schema()
|
||||
parameter_schema: Dict[str, Any] = dict()
|
||||
parameter_schema["parameters"] = model_schema["properties"]
|
||||
parameter_schema["name"] = self._name
|
||||
parameter_schema["description"] = self._description
|
||||
return parameter_schema
|
||||
|
||||
tool_schema = ToolSchema(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
parameters=ParametersSchema(
|
||||
type="object",
|
||||
properties=model_schema["properties"],
|
||||
),
|
||||
)
|
||||
if "required" in model_schema:
|
||||
assert "parameters" in tool_schema
|
||||
tool_schema["parameters"]["required"] = model_schema["required"]
|
||||
|
||||
return tool_schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -97,7 +128,12 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
|
||||
|
||||
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
|
||||
def __init__(
|
||||
self, args_type: Type[ArgsT], return_type: Type[ReturnT], state_type: Type[StateT], name: str, description: str
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
state_type: Type[StateT],
|
||||
name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
super().__init__(args_type, return_type, name, description)
|
||||
self._state_type = state_type
|
||||
|
||||
@@ -32,7 +32,12 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
|
||||
else:
|
||||
if self._has_cancellation_support:
|
||||
result = await asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(self._func, **args.model_dump(), cancellation_token=cancellation_token)
|
||||
None,
|
||||
functools.partial(
|
||||
self._func,
|
||||
**args.model_dump(),
|
||||
cancellation_token=cancellation_token,
|
||||
),
|
||||
)
|
||||
else:
|
||||
future = asyncio.get_event_loop().run_in_executor(
|
||||
|
||||
Reference in New Issue
Block a user