Implement closure agent (#143)

This commit is contained in:
Jack Gerrits
2024-06-28 10:22:44 -04:00
committed by GitHub
parent 8901b4d224
commit 13b0d0deb4
12 changed files with 218 additions and 45 deletions

View File

@@ -2,8 +2,9 @@
The :mod:`agnext.components` module provides building blocks for creating single agents
"""
from ._closure_agent import ClosureAgent
from ._image import Image
from ._type_routed_agent import TypeRoutedAgent, message_handler
from ._types import FunctionCall
__all__ = ["Image", "TypeRoutedAgent", "message_handler", "FunctionCall"]
__all__ = ["Image", "TypeRoutedAgent", "ClosureAgent", "message_handler", "FunctionCall"]

View File

@@ -0,0 +1,93 @@
import inspect
from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_type_hints
from ..core._agent import Agent
from ..core._agent_id import AgentId
from ..core._agent_metadata import AgentMetadata
from ..core._agent_runtime import AgentRuntime, agent_instantiation_context
from ..core._cancellation_token import CancellationToken
from ..core.exceptions import CantHandleException
from ._type_helpers import get_types
T = TypeVar("T")
def get_subscriptions_from_closure(
closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]],
) -> Sequence[type]:
args = inspect.getfullargspec(closure)[0]
if len(args) != 4:
raise AssertionError("Closure must have 4 arguments")
message_arg_name = args[2]
type_hints = get_type_hints(closure)
if "return" not in type_hints:
raise AssertionError("return not found in function signature")
# Get the type of the message parameter
target_types = get_types(type_hints[message_arg_name])
if target_types is None:
raise AssertionError("Message type not found")
# print(type_hints)
return_types = get_types(type_hints["return"])
if return_types is None:
raise AssertionError("Return type not found")
return target_types
class ClosureAgent(Agent):
def __init__(
self, description: str, closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]]
) -> None:
try:
runtime, id = agent_instantiation_context.get()
except LookupError as e:
raise RuntimeError(
"ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
) from e
self._runtime: AgentRuntime = runtime
self._id: AgentId = id
self._description = description
self._subscriptions = get_subscriptions_from_closure(closure)
self._closure = closure
@property
def metadata(self) -> AgentMetadata:
assert self._id is not None
return AgentMetadata(
namespace=self._id.namespace,
name=self._id.name,
description=self._description,
subscriptions=self._subscriptions,
)
@property
def name(self) -> str:
return self.id.name
@property
def id(self) -> AgentId:
return self._id
@property
def runtime(self) -> AgentRuntime:
return self._runtime
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
if type(message) not in self._subscriptions:
raise CantHandleException(
f"Message type {type(message)} not in target types {self._subscriptions} of {self.id}"
)
return await self._closure(self._runtime, self._id, message, cancellation_token)
def save_state(self) -> Mapping[str, Any]:
raise ValueError("save_state not implemented for ClosureAgent")
def load_state(self, state: Mapping[str, Any]) -> None:
raise ValueError("load_state not implemented for ClosureAgent")

View File

@@ -0,0 +1,33 @@
from collections.abc import Sequence
from types import NoneType, UnionType
from typing import Any, Optional, Type, Union, get_args, get_origin
def is_union(t: object) -> bool:
origin = get_origin(t)
return origin is Union or origin is UnionType
def is_optional(t: object) -> bool:
origin = get_origin(t)
return origin is Optional
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
class AnyType:
pass
def get_types(t: object) -> Sequence[Type[Any]] | None:
if is_union(t):
return get_args(t)
elif is_optional(t):
return tuple(list(get_args(t)) + [NoneType])
elif t is Any:
return (AnyType,)
elif isinstance(t, type):
return (t,)
elif isinstance(t, NoneType):
return (NoneType,)
else:
return None

View File

@@ -1,6 +1,5 @@
import logging
from functools import wraps
from types import NoneType, UnionType
from typing import (
Any,
Callable,
@@ -8,15 +7,11 @@ from typing import (
Dict,
Literal,
NoReturn,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
overload,
runtime_checkable,
@@ -24,6 +19,7 @@ from typing import (
from ..core import BaseAgent, CancellationToken
from ..core.exceptions import CantHandleException
from ._type_helpers import AnyType, get_types
logger = logging.getLogger("agnext")
@@ -34,36 +30,6 @@ ProducesT = TypeVar("ProducesT", covariant=True)
# Can't do because python doesnt support it
def is_union(t: object) -> bool:
origin = get_origin(t)
return origin is Union or origin is UnionType
def is_optional(t: object) -> bool:
origin = get_origin(t)
return origin is Optional
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
class AnyType:
pass
def get_types(t: object) -> Sequence[Type[Any]] | None:
if is_union(t):
return get_args(t)
elif is_optional(t):
return tuple(list(get_args(t)) + [NoneType])
elif t is Any:
return (AnyType,)
elif isinstance(t, type):
return (t,)
elif isinstance(t, NoneType):
return (NoneType,)
else:
return None
@runtime_checkable
class MessageHandler(Protocol[ReceivesT, ProducesT]):
target_types: Sequence[type]