Add examples to showcase patterns (#55)

* add chess example

* wip

* wip

* fix tool schema generation

* fixes

* Agent handle exception

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>

* format

* mypy

* fix test for annotated

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Author: Eric Zhu
Date: 2024-06-07 13:33:51 -07:00
Committed by: GitHub
Parent: c6360feeb6
Commit: b4ade8b735
12 changed files with 420 additions and 52 deletions

View File

@@ -105,18 +105,18 @@ def message_handler(
     @wraps(func)
     async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
-        if strict:
-            if type(message) not in target_types:
+        if type(message) not in target_types:
+            if strict:
                 raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
             else:
                 logger.warning(f"Message type {type(message)} not in target types {target_types}")

         return_value = await func(self, message, cancellation_token)

-        if strict:
-            if return_value is not AnyType and type(return_value) not in return_types:
+        if AnyType not in return_types and type(return_value) not in return_types:
+            if strict:
                 raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
-        elif return_value is not AnyType:
+            else:
                 logger.warning(f"Return type {type(return_value)} not in return types {return_types}")

         return return_value
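
The reordering in this hunk means the type checks now always run; `strict` only decides whether a mismatch raises or is merely logged. A minimal sketch of a handler that relies on this behavior (the `TextMessage` dataclass and the `strict` keyword placement are assumptions for illustration; `message_handler`, `TypeRoutedAgent`, and `CantHandleException` come from the code above):

    from dataclasses import dataclass

    @dataclass
    class TextMessage:
        content: str

    class EchoAgent(TypeRoutedAgent):
        # With strict=True, an undeclared message type raises
        # CantHandleException and an undeclared return type raises
        # ValueError, instead of only logging a warning.
        @message_handler(strict=True)
        async def on_text(self, message: TextMessage, cancellation_token: CancellationToken) -> TextMessage:
            return TextMessage(content=message.content)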
@@ -134,7 +134,10 @@ def message_handler(
 class TypeRoutedAgent(BaseAgent):
     def __init__(self, name: str, router: AgentRuntime) -> None:
         # Self is already bound to the handlers
-        self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
+        self._handlers: Dict[
+            Type[Any],
+            Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]],
+        ] = {}
         for attr in dir(self):
             if callable(getattr(self, attr, None)):
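
The reflowed `_handlers` annotation is purely cosmetic; the interesting part is the discovery loop that follows, which scans `dir(self)` and registers each decorated method under the message types it declares. A hedged sketch of that registration pattern (the `target_types` attribute name is hypothetical, not necessarily what the decorator actually sets):

    for attr in dir(self):
        handler = getattr(self, attr, None)
        # Only methods marked by @message_handler carry type metadata.
        if callable(handler) and hasattr(handler, "target_types"):
            for message_type in handler.target_types:
                self._handlers[message_type] = handler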

View File

@@ -1,5 +1,6 @@
 import inspect
 import logging
+import re
 import warnings
 from typing import (
     Any,
@@ -11,6 +12,7 @@ from typing import (
     Sequence,
     Set,
     Union,
+    cast,
 )

 from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -27,6 +29,7 @@ from openai.types.chat import (
     ChatCompletionUserMessageParam,
     completion_create_params,
 )
+from openai.types.shared_params import FunctionDefinition, FunctionParameters
 from typing_extensions import Unpack

 from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
@@ -205,15 +208,47 @@ def convert_tools(
 ) -> List[ChatCompletionToolParam]:
     result: List[ChatCompletionToolParam] = []
     for tool in tools:
+        tool_schema = tool.schema
         result.append(
-            {
-                "type": "function",
-                "function": tool.schema,  # type: ignore
-            }
+            ChatCompletionToolParam(
+                type="function",
+                function=FunctionDefinition(
+                    name=tool_schema["name"],
+                    description=tool_schema["description"] if "description" in tool_schema else "",
+                    parameters=cast(FunctionParameters, tool_schema["parameters"])
+                    if "parameters" in tool_schema
+                    else {},
+                ),
+            )
         )
+    # Check that all tools have valid names.
+    for tool_param in result:
+        assert_valid_name(tool_param["function"]["name"])
     return result
+
+
+def normalize_name(name: str) -> str:
+    """
+    LLMs sometimes call functions while ignoring their own format requirements; use this to replace invalid characters with "_" and truncate to 64 characters.
+
+    Prefer assert_valid_name for validating user configuration or input.
+    """
+    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
+
+
+def assert_valid_name(name: str) -> str:
+    """
+    Ensure that configured names are valid; raise ValueError if not.
+
+    For munging LLM responses, use normalize_name so that LLM-specified names don't break the API.
+    """
+    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
+        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
+    if len(name) > 64:
+        raise ValueError(f"Invalid name: {name}. Name must be 64 characters or fewer.")
+    return name
+
+
 class BaseOpenAI(ChatCompletionClient):
     def __init__(
         self,
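
As a quick illustration of the two new helpers (standalone, with `normalize_name` inlined so the snippet runs on its own):

    import re

    def normalize_name(name: str) -> str:
        return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]

    # An LLM may emit a tool name the OpenAI API rejects; normalization
    # maps it into the allowed [a-zA-Z0-9_-] alphabet and 64-char limit.
    print(normalize_name("get weather (celsius)"))  # -> get_weather__celsius_

`assert_valid_name` is the strict counterpart: it returns a valid name unchanged and raises `ValueError` for anything `normalize_name` would have had to rewrite.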
@@ -293,7 +328,10 @@ class BaseOpenAI(ChatCompletionClient):
         if len(tools) > 0:
             converted_tools = convert_tools(tools)
             result = await self._client.chat.completions.create(
-                messages=oai_messages, stream=False, tools=converted_tools, **create_args
+                messages=oai_messages,
+                stream=False,
+                tools=converted_tools,
+                **create_args,
             )
         else:
             result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
@@ -331,7 +369,11 @@ class BaseOpenAI(ChatCompletionClient):
             # NOTE: If OAI response type changes, this will need to be updated
             content = [
-                FunctionCall(id=x.id, arguments=x.function.arguments, name=x.function.name)
+                FunctionCall(
+                    id=x.id,
+                    arguments=x.function.arguments,
+                    name=normalize_name(x.function.name),
+                )
                 for x in choice.message.tool_calls
             ]
             finish_reason = "function_calls"

View File

@@ -1,14 +1,29 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar
+from collections.abc import Sequence
+from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar

 from pydantic import BaseModel
+from typing_extensions import NotRequired

 from ...core import CancellationToken
+from .._function_utils import normalize_annotated_type

 T = TypeVar("T", bound=BaseModel, contravariant=True)


+class ParametersSchema(TypedDict):
+    type: str
+    properties: Dict[str, Any]
+    required: NotRequired[Sequence[str]]
+
+
+class ToolSchema(TypedDict):
+    parameters: NotRequired[ParametersSchema]
+    name: str
+    description: NotRequired[str]
+
+
 class Tool(Protocol):
     @property
     def name(self) -> str: ...
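
Because `ToolSchema` and `ParametersSchema` are `TypedDict`s with `NotRequired` fields, a schema can be spelled as a plain dict literal and still type-check. A small sketch for a hypothetical "add" tool:

    add_schema: ToolSchema = {
        "name": "add",
        "description": "Add two integers.",  # NotRequired: may be omitted
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "integer"},
                "b": {"type": "integer"},
            },
            "required": ["a", "b"],  # NotRequired: may be omitted
        },
    }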
@@ -17,7 +32,7 @@ class Tool(Protocol):
     def description(self) -> str: ...

     @property
-    def schema(self) -> Mapping[str, Any]: ...
+    def schema(self) -> ToolSchema: ...

     def args_type(self) -> Type[BaseModel]: ...
@@ -40,20 +55,36 @@ StateT = TypeVar("StateT", bound=BaseModel)
 class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
-    def __init__(self, args_type: Type[ArgsT], return_type: Type[ReturnT], name: str, description: str) -> None:
+    def __init__(
+        self,
+        args_type: Type[ArgsT],
+        return_type: Type[ReturnT],
+        name: str,
+        description: str,
+    ) -> None:
         self._args_type = args_type
-        self._return_type = return_type
+        # Normalize Annotated to the base type.
+        self._return_type = normalize_annotated_type(return_type)
         self._name = name
         self._description = description

     @property
-    def schema(self) -> Mapping[str, Any]:
+    def schema(self) -> ToolSchema:
         model_schema = self._args_type.model_json_schema()
-        parameter_schema: Dict[str, Any] = dict()
-        parameter_schema["parameters"] = model_schema["properties"]
-        parameter_schema["name"] = self._name
-        parameter_schema["description"] = self._description
-        return parameter_schema
+        tool_schema = ToolSchema(
+            name=self._name,
+            description=self._description,
+            parameters=ParametersSchema(
+                type="object",
+                properties=model_schema["properties"],
+            ),
+        )
+        if "required" in model_schema:
+            assert "parameters" in tool_schema
+            tool_schema["parameters"]["required"] = model_schema["required"]
+        return tool_schema

     @property
     def name(self) -> str:
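
To see what the rewritten `schema` property produces, here is a standalone sketch that mirrors its logic on a made-up pydantic model:

    from pydantic import BaseModel, Field

    class AddArgs(BaseModel):
        a: int = Field(description="First addend")  # no default, so required
        b: int = 0                                  # has a default, so optional

    model_schema = AddArgs.model_json_schema()
    print(model_schema["required"])  # -> ['a']
    # The property above would therefore return name/description plus
    # parameters={"type": "object", "properties": ..., "required": ["a"]}.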
@@ -97,7 +128,12 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
 class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
     def __init__(
-        self, args_type: Type[ArgsT], return_type: Type[ReturnT], state_type: Type[StateT], name: str, description: str
+        self,
+        args_type: Type[ArgsT],
+        return_type: Type[ReturnT],
+        state_type: Type[StateT],
+        name: str,
+        description: str,
     ) -> None:
         super().__init__(args_type, return_type, name, description)
         self._state_type = state_type

View File

@@ -32,7 +32,12 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
         else:
             if self._has_cancellation_support:
                 result = await asyncio.get_event_loop().run_in_executor(
-                    None, functools.partial(self._func, **args.model_dump(), cancellation_token=cancellation_token)
+                    None,
+                    functools.partial(
+                        self._func,
+                        **args.model_dump(),
+                        cancellation_token=cancellation_token,
+                    ),
                 )
             else:
                 future = asyncio.get_event_loop().run_in_executor(
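
The reformatted call is the usual pattern for pushing a blocking function onto the default executor while still forwarding keyword arguments, since `run_in_executor` only forwards positional arguments to the callable. A self-contained sketch (`slow_add` is made up):

    import asyncio
    import functools
    import time

    def slow_add(a: int, b: int, *, delay: float = 0.05) -> int:
        time.sleep(delay)  # stands in for blocking work
        return a + b

    async def main() -> None:
        loop = asyncio.get_running_loop()
        # functools.partial bundles the keyword argument up front.
        result = await loop.run_in_executor(None, functools.partial(slow_add, 2, 3, delay=0.01))
        print(result)  # -> 5

    asyncio.run(main())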