Move python code to subdir (#98)

This commit is contained in:
Jack Gerrits
2024-06-20 15:19:56 -04:00
committed by GitHub
parent c9e09e2d27
commit d365a588cb
102 changed files with 57 additions and 51 deletions

View File

@@ -0,0 +1,7 @@
"""
The :mod:`agnext.application` module provides implementations of core components that are used to compose an application.
"""
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
__all__ = ["SingleThreadedAgentRuntime"]

View File

@@ -0,0 +1,459 @@
import asyncio
import inspect
import logging
import threading
from asyncio import Future
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from ..core import Agent, AgentId, AgentMetadata, AgentProxy, AgentRuntime, AllNamespaces, BaseAgent, CancellationToken
from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler
logger = logging.getLogger("agnext")
event_logger = logging.getLogger("agnext.events")
@dataclass(kw_only=True)
class PublishMessageEnvelope:
"""A message envelope for publishing messages to all agents that can handle
the message of the type T."""
message: Any
cancellation_token: CancellationToken
sender: AgentId | None
namespace: str
@dataclass(kw_only=True)
class SendMessageEnvelope:
"""A message envelope for sending a message to a specific agent that can handle
the message of the type T."""
message: Any
sender: AgentId | None
recipient: AgentId
future: Future[Any]
cancellation_token: CancellationToken
@dataclass(kw_only=True)
class ResponseMessageEnvelope:
"""A message envelope for sending a response to a message."""
message: Any
future: Future[Any]
sender: AgentId
recipient: AgentId | None
P = ParamSpec("P")
T = TypeVar("T", bound=Agent)
class Counter:
def __init__(self) -> None:
self._count: int = 0
self._lock = threading.Lock()
def increment(self) -> None:
with self._lock:
self._count += 1
def get(self) -> int:
return self._count
def decrement(self) -> None:
with self._lock:
self._count -= 1
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
# If empty, then all namespaces are valid for that agent type
self._valid_namespaces: Dict[str, Sequence[str]] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._before_send = before_send
self._known_namespaces: set[str] = set()
self._outstanding_tasks = Counter()
@property
def unprocessed_messages(
self,
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
return self._message_queue
@property
def outstanding_tasks(self) -> int:
return self._outstanding_tasks.get()
@property
def _known_agent_names(self) -> Set[str]:
return set(self._agent_factories.keys())
# Returns a future that resolves to the response to the message
def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]:
if cancellation_token is None:
cancellation_token = CancellationToken()
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.SEND,
# )
# )
future = asyncio.get_event_loop().create_future()
if recipient.name not in self._known_agent_names:
future.set_exception(Exception("Recipient not found"))
return future
if sender is not None and sender.namespace != recipient.namespace:
raise ValueError("Sender and recipient must be in the same namespace to communicate.")
self._process_seen_namespace(recipient.namespace)
logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}")
self._message_queue.append(
SendMessageEnvelope(
message=message,
recipient=recipient,
future=future,
cancellation_token=cancellation_token,
sender=sender,
)
)
return future
def publish_message(
self,
message: Any,
*,
namespace: str | None = None,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[None]:
if cancellation_token is None:
cancellation_token = CancellationToken()
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {message.__dict__}")
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=None,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.SEND,
# )
# )
if sender is None and namespace is None:
raise ValueError("Namespace must be provided if sender is not provided.")
sender_namespace = sender.namespace if sender is not None else None
explicit_namespace = namespace
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
raise ValueError(
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
)
assert explicit_namespace is not None or sender_namespace is not None
namespace = cast(str, explicit_namespace or sender_namespace)
self._process_seen_namespace(namespace)
self._message_queue.append(
PublishMessageEnvelope(
message=message,
cancellation_token=cancellation_token,
sender=sender,
namespace=namespace,
)
)
future = asyncio.get_event_loop().create_future()
future.set_result(None)
return future
def save_state(self) -> Mapping[str, Any]:
state: Dict[str, Dict[str, Any]] = {}
for agent_id in self._instantiated_agents:
state[str(agent_id)] = dict(self._get_agent(agent_id).save_state())
return state
def load_state(self, state: Mapping[str, Any]) -> None:
for agent_id_str in state:
agent_id = AgentId.from_str(agent_id_str)
if agent_id.name in self._known_agent_names:
self._get_agent(agent_id).load_state(state[str(agent_id)])
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
recipient = message_envelope.recipient
# todo: check if recipient is in the known namespaces
# assert recipient in self._agents
try:
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
recipient_agent = self._get_agent(recipient)
response = await recipient_agent.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
)
except BaseException as e:
self._outstanding_tasks.decrement()
message_envelope.future.set_exception(e)
return
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
)
)
self._outstanding_tasks.decrement()
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
responses: List[Awaitable[Any]] = []
target_namespace = message_envelope.namespace
for agent_id in self._per_type_subscribers[(target_namespace, type(message_envelope.message))]:
if message_envelope.sender is not None and agent_id.name == message_envelope.sender.name:
continue
sender_agent = self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
sender_name = sender_agent.metadata["name"] if sender_agent is not None else "Unknown"
logger.info(
f"Calling message handler for {agent_id.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=agent,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
agent = self._get_agent(agent_id)
future = agent.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
)
responses.append(future)
try:
_all_responses = await asyncio.gather(*responses)
except BaseException:
logger.error("Error processing publish message", exc_info=True)
finally:
self._outstanding_tasks.decrement()
# TODO if responses are given for a publish
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
content = (
message_envelope.message.__dict__
if hasattr(message_envelope.message, "__dict__")
else message_envelope.message
)
logger.info(
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.name}: {content}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=message_envelope.recipient,
# kind=MessageKind.RESPOND,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
self._outstanding_tasks.decrement()
message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None:
if len(self._message_queue) == 0:
# Yield control to the event loop to allow other tasks to run
await asyncio.sleep(0)
return
message_envelope = self._message_queue.pop(0)
match message_envelope:
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
try:
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
except BaseException as e:
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_send(message_envelope))
case PublishMessageEnvelope(
message=message,
sender=sender,
):
if self._before_send is not None:
try:
temp_message = await self._before_send.on_publish(message, sender=sender)
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
# TODO log message dropped
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_publish(message_envelope))
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
try:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
except BaseException as e:
# TODO: should we raise the exception to sender of the response instead?
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
asyncio.create_task(self._process_response(message_envelope))
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
return self._get_agent(agent).metadata
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
return self._get_agent(agent).save_state()
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
self._get_agent(agent).load_state(state)
def register(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> None:
if name in self._agent_factories:
raise ValueError(f"Agent with name {name} already exists.")
self._agent_factories[name] = agent_factory
if valid_namespaces is not AllNamespaces:
self._valid_namespaces[name] = cast(Sequence[str], valid_namespaces)
else:
self._valid_namespaces[name] = []
# Instantiate this agent in every namespace that has already been seen.
for namespace in self._known_namespaces:
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
self._get_agent(AgentId(name=name, namespace=namespace))
def _invoke_agent_factory(
self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId
) -> T:
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
# TODO: should this be part of the base agent interface?
if isinstance(agent, BaseAgent):
agent.bind_id(agent_id)
agent.bind_runtime(self)
return agent
def _type_valid_for_namespace(self, agent_id: AgentId) -> bool:
if agent_id.name not in self._agent_factories:
raise KeyError(f"Agent with name {agent_id.name} not found.")
valid_namespaces = self._valid_namespaces[agent_id.name]
if len(valid_namespaces) == 0:
return True
return agent_id.namespace in valid_namespaces
def _get_agent(self, agent_id: AgentId) -> Agent:
self._process_seen_namespace(agent_id.namespace)
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
if not self._type_valid_for_namespace(agent_id):
raise ValueError(f"Agent with name {agent_id.name} not valid for namespace {agent_id.namespace}.")
if agent_id.name not in self._agent_factories:
raise ValueError(f"Agent with name {agent_id.name} not found.")
agent_factory = self._agent_factories[agent_id.name]
agent = self._invoke_agent_factory(agent_factory, agent_id)
for message_type in agent.metadata["subscriptions"]:
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
self._instantiated_agents[agent_id] = agent
return agent
def get(self, name: str, *, namespace: str = "default") -> AgentId:
return self._get_agent(AgentId(name=name, namespace=namespace)).id
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
id = self.get(name, namespace=namespace)
return AgentProxy(id, self)
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
def _process_seen_namespace(self, namespace: str) -> None:
if namespace in self._known_namespaces:
return
self._known_namespaces.add(namespace)
for name in self._known_agent_names:
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
self._get_agent(AgentId(name=name, namespace=namespace))
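
To illustrate the runtime above, here is a minimal end-to-end sketch. It assumes the TypeRoutedAgent helper and message_handler decorator from agnext.components (used by the agents later in this commit); the Greeting and EchoAgent names are illustrative only, not part of the commit.

import asyncio
from dataclasses import dataclass

from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import CancellationToken

@dataclass
class Greeting:  # a hypothetical message type
    text: str

class EchoAgent(TypeRoutedAgent):  # a hypothetical agent
    def __init__(self) -> None:
        super().__init__("Echoes any Greeting back to the sender")

    @message_handler()
    async def on_greeting(self, message: Greeting, cancellation_token: CancellationToken) -> Greeting:
        return message

async def main() -> None:
    runtime = SingleThreadedAgentRuntime()
    # Factories are invoked lazily, once per (name, namespace) pair.
    runtime.register("echo", lambda: EchoAgent())
    echo = runtime.get("echo")  # uses the "default" namespace
    future = runtime.send_message(Greeting(text="hello"), echo)
    # Drive the queue until the response future resolves.
    while not future.done():
        await runtime.process_next()
    print(future.result())

asyncio.run(main())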

View File

@@ -0,0 +1,13 @@
from ._events import DeliveryStage, LLMCallEvent, MessageEvent, MessageKind
from ._llm_usage import LLMUsageTracker
EVENT_LOGGER_NAME = "agnext.events"
__all__ = [
"LLMCallEvent",
"EVENT_LOGGER_NAME",
"LLMUsageTracker",
"MessageEvent",
"MessageKind",
"DeliveryStage",
]

View File

@@ -0,0 +1,84 @@
import json
from enum import Enum
from typing import Any, cast
from ...core import Agent
class LLMCallEvent:
def __init__(self, *, prompt_tokens: int, completion_tokens: int, **kwargs: Any) -> None:
"""To be used by model clients to log the call to the LLM.
Args:
prompt_tokens (int): Number of tokens used in the prompt.
completion_tokens (int): Number of tokens used in the completion.
Example:
.. code-block:: python
import logging
from agnext.application.logging import LLMCallEvent, EVENT_LOGGER_NAME
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20))
"""
self.kwargs = kwargs
self.kwargs["prompt_tokens"] = prompt_tokens
self.kwargs["completion_tokens"] = completion_tokens
self.kwargs["type"] = "LLMCall"
@property
def prompt_tokens(self) -> int:
return cast(int, self.kwargs["prompt_tokens"])
@property
def completion_tokens(self) -> int:
return cast(int, self.kwargs["completion_tokens"])
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class MessageKind(Enum):
DIRECT = 1
PUBLISH = 2
RESPOND = 3
class DeliveryStage(Enum):
SEND = 1
DELIVER = 2
class MessageEvent:
def __init__(
self,
*,
payload: Any,
sender: Agent | None,
receiver: Agent | None,
kind: MessageKind,
delivery_stage: DeliveryStage,
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["payload"] = payload
self.kwargs["sender"] = None if sender is None else sender.metadata["name"]
self.kwargs["receiver"] = None if receiver is None else receiver.metadata["name"]
self.kwargs["kind"] = kind
self.kwargs["delivery_stage"] = delivery_stage
self.kwargs["type"] = "Message"
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)

View File

@@ -0,0 +1,57 @@
import logging
from ._events import LLMCallEvent
class LLMUsageTracker(logging.Handler):
def __init__(self) -> None:
"""Logging handler that tracks the number of tokens used in the prompt and completion.
Example:
.. code-block:: python
import logging
from agnext.application.logging import LLMUsageTracker, EVENT_LOGGER_NAME
# Set up the logging configuration to use the custom handler
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.INFO)
llm_usage = LLMUsageTracker()
logger.handlers = [llm_usage]
# ...
print(llm_usage.prompt_tokens)
print(llm_usage.completion_tokens)
"""
super().__init__()
self._prompt_tokens = 0
self._completion_tokens = 0
@property
def tokens(self) -> int:
return self._prompt_tokens + self._completion_tokens
@property
def prompt_tokens(self) -> int:
return self._prompt_tokens
@property
def completion_tokens(self) -> int:
return self._completion_tokens
def reset(self) -> None:
self._prompt_tokens = 0
self._completion_tokens = 0
def emit(self, record: logging.LogRecord) -> None:
"""Emit the log record. To be used by the logging module."""
try:
# Accumulate token counts when the record's message is an LLMCallEvent.
if isinstance(record.msg, LLMCallEvent):
event = record.msg
self._prompt_tokens += event.prompt_tokens
self._completion_tokens += event.completion_tokens
except Exception:
self.handleError(record)
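
Putting the two pieces above together, a small sketch based on the docstring examples:

import logging

from agnext.application.logging import EVENT_LOGGER_NAME, LLMCallEvent, LLMUsageTracker

logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.INFO)
tracker = LLMUsageTracker()
logger.handlers = [tracker]

# A model client would normally emit this event after each LLM call.
logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20))
assert tracker.prompt_tokens == 10 and tracker.completion_tokens == 20
assert tracker.tokens == 30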

View File

@@ -0,0 +1,3 @@
"""
The :mod:`agnext.chat` module provides concrete implementations of multi-agent interaction patterns.
"""

View File

@@ -0,0 +1,6 @@
from .chat_completion_agent import ChatCompletionAgent
from .image_generation_agent import ImageGenerationAgent
from .oai_assistant import OpenAIAssistantAgent
from .user_proxy import UserProxyAgent
__all__ = ["ChatCompletionAgent", "OpenAIAssistantAgent", "UserProxyAgent", "ImageGenerationAgent"]

View File

@@ -0,0 +1,264 @@
import asyncio
import json
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
from ...components import (
FunctionCall,
TypeRoutedAgent,
message_handler,
)
from ...components.models import (
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
SystemMessage,
)
from ...components.tools import Tool
from ...core import AgentId, CancellationToken
from ..memory import ChatMemory
from ..types import (
FunctionCallMessage,
Message,
MultiModalMessage,
PublishNow,
Reset,
RespondNow,
ResponseFormat,
TextMessage,
ToolApprovalRequest,
ToolApprovalResponse,
)
from ..utils import convert_messages_to_llm_messages
class ChatCompletionAgent(TypeRoutedAgent):
"""An agent implementation that uses the ChatCompletion API to gnenerate
responses and execute tools.
Args:
name (str): The name of the agent.
description (str): The description of the agent.
runtime (AgentRuntime): The runtime to register the agent.
system_messages (List[SystemMessage]): The system messages to use for
the ChatCompletion API.
memory (ChatMemory): The memory to store and retrieve messages.
model_client (ChatCompletionClient): The client to use for the
ChatCompletion API.
tools (Sequence[Tool], optional): The tools used by the agent. Defaults
to []. If no tools are provided, the agent cannot handle tool calls.
If tools are provided, and the response from the model is a list of
tool calls, the agent will call itselfs with the tool calls until it
gets a response that is not a list of tool calls, and then use that
response as the final response.
tool_approver (Agent | None, optional): The agent that approves tool
calls. Defaults to None. If no tool approver is provided, the agent
will execute the tools without approval. If a tool approver is
provided, the agent will send a request to the tool approver before
executing the tools.
"""
def __init__(
self,
description: str,
system_messages: List[SystemMessage],
memory: ChatMemory,
model_client: ChatCompletionClient,
tools: Sequence[Tool] = [],
tool_approver: AgentId | None = None,
) -> None:
super().__init__(description)
self._description = description
self._system_messages = system_messages
self._client = model_client
self._memory = memory
self._tools = tools
self._tool_approver = tool_approver
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
"""Handle a text message. This method adds the message to the memory and
does not generate any message."""
# Add a user message.
await self._memory.add_message(message)
@message_handler()
async def on_multi_modal_message(self, message: MultiModalMessage, cancellation_token: CancellationToken) -> None:
"""Handle a multimodal message. This method adds the message to the memory
and does not generate any message."""
# Add a user message.
await self._memory.add_message(message)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
"""Handle a reset message. This method clears the memory."""
# Reset the chat messages.
await self._memory.clear()
@message_handler()
async def on_respond_now(
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage | FunctionCallMessage:
"""Handle a respond now message. This method generates a response and
returns it to the sender."""
# Generate a response.
response = await self._generate_response(message.response_format, cancellation_token)
# Return the response.
return response
@message_handler()
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
"""Handle a publish now message. This method generates a response and
publishes it."""
# Generate a response.
response = await self._generate_response(message.response_format, cancellation_token)
# Publish the response.
await self.publish_message(response)
@message_handler()
async def on_tool_call_message(
self, message: FunctionCallMessage, cancellation_token: CancellationToken
) -> FunctionExecutionResultMessage:
"""Handle a tool call message. This method executes the tools and
returns the results."""
if len(self._tools) == 0:
raise ValueError("No tools available")
# Add a tool call message.
await self._memory.add_message(message)
# Execute the tool calls.
results: List[FunctionExecutionResult] = []
execution_futures: List[Coroutine[Any, Any, Tuple[str, str]]] = []
for function_call in message.content:
# Parse the arguments.
try:
arguments = json.loads(function_call.arguments)
except json.JSONDecodeError:
results.append(
FunctionExecutionResult(
content=f"Error: Could not parse arguments for function {function_call.name}.",
call_id=function_call.id,
)
)
continue
# Execute the function.
future = self._execute_function(
function_call.name,
arguments,
function_call.id,
cancellation_token=cancellation_token,
)
# Append the async result.
execution_futures.append(future)
if execution_futures:
# Wait for all async results.
execution_results = await asyncio.gather(*execution_futures)
# Add the results.
for execution_result, call_id in execution_results:
results.append(FunctionExecutionResult(content=execution_result, call_id=call_id))
# Create a tool call result message.
tool_call_result_msg = FunctionExecutionResultMessage(content=results)
# Add tool call result message.
await self._memory.add_message(tool_call_result_msg)
# Return the results.
return tool_call_result_msg
async def _generate_response(
self,
response_format: ResponseFormat,
cancellation_token: CancellationToken,
) -> TextMessage | FunctionCallMessage:
# Get a response from the model.
historical_messages = await self._memory.get_messages()
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(historical_messages, self.metadata["name"]),
tools=self._tools,
json_output=response_format == ResponseFormat.json_object,
)
# If the agent has function executor, and the response is a list of
# tool calls, iterate with itself until we get a response that is not a
# list of tool calls.
while (
len(self._tools) > 0
and isinstance(response.content, list)
and all(isinstance(x, FunctionCall) for x in response.content)
):
# Send a function call message to itself.
response = await self.send_message(
message=FunctionCallMessage(content=response.content, source=self.metadata["name"]),
recipient=self.id,
cancellation_token=cancellation_token,
)
# Make an assistant message from the response.
historical_messages = await self._memory.get_messages()
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(historical_messages, self.metadata["name"]),
tools=self._tools,
json_output=response_format == ResponseFormat.json_object,
)
final_response: Message
if isinstance(response.content, str):
# If the response is a string, return a text message.
final_response = TextMessage(content=response.content, source=self.metadata["name"])
elif isinstance(response.content, list) and all(isinstance(x, FunctionCall) for x in response.content):
# If the response is a list of function calls, return a function call message.
final_response = FunctionCallMessage(content=response.content, source=self.metadata["name"])
else:
raise ValueError(f"Unexpected response: {response.content}")
# Add the response to the chat messages.
await self._memory.add_message(final_response)
return final_response
async def _execute_function(
self,
name: str,
args: Dict[str, Any],
call_id: str,
cancellation_token: CancellationToken,
) -> Tuple[str, str]:
# Find tool
tool = next((t for t in self._tools if t.name == name), None)
if tool is None:
return (f"Error: tool {name} not found.", call_id)
# Check if the tool needs approval
if self._tool_approver is not None:
# Send a tool approval request.
approval_request = ToolApprovalRequest(
tool_call=FunctionCall(id=call_id, arguments=json.dumps(args), name=name)
)
approval_response = await self.send_message(
message=approval_request,
recipient=self._tool_approver,
cancellation_token=cancellation_token,
)
if not isinstance(approval_response, ToolApprovalResponse):
raise ValueError(f"Expecting {ToolApprovalResponse.__name__}, received: {type(approval_response)}")
if not approval_response.approved:
return (f"Error: tool {name} approved, reason: {approval_response.reason}", call_id)
try:
result = await tool.run_json(args, cancellation_token)
result_as_str = tool.return_value_as_string(result)
except Exception as e:
result_as_str = f"Error: {str(e)}"
return (result_as_str, call_id)
def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._memory.save_state(),
"system_messages": self._system_messages,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._memory.load_state(state["memory"])
self._system_messages = state["system_messages"]

View File

@@ -0,0 +1,62 @@
from typing import Literal
import openai
from ...components import (
Image,
TypeRoutedAgent,
message_handler,
)
from ...core import CancellationToken
from ..memory import ChatMemory
from ..types import (
MultiModalMessage,
PublishNow,
Reset,
TextMessage,
)
class ImageGenerationAgent(TypeRoutedAgent):
def __init__(
self,
description: str,
memory: ChatMemory,
client: openai.AsyncClient,
model: Literal["dall-e-2", "dall-e-3"] = "dall-e-2",
):
super().__init__(description)
self._client = client
self._model = model
self._memory = memory
@message_handler
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
await self._memory.add_message(message)
@message_handler
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
await self._memory.clear()
@message_handler
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
response = await self._generate_response(cancellation_token)
await self.publish_message(response)
async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage:
messages = await self._memory.get_messages()
if len(messages) == 0:
return MultiModalMessage(
content=["I need more information to generate an image."], source=self.metadata["name"]
)
prompt = ""
for m in messages:
assert isinstance(m, TextMessage)
prompt += m.content + "\n"
prompt = prompt.strip()
response = await self._client.images.generate(model=self._model, prompt=prompt, response_format="b64_json")
assert len(response.data) > 0 and response.data[0].b64_json is not None
# Create a MultiModalMessage with the image.
image = Image.from_base64(response.data[0].b64_json)
multi_modal_message = MultiModalMessage(content=[image], source=self.metadata["name"])
return multi_modal_message
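
A registration sketch, assuming an OpenAI API key is configured in the environment; the agent name is illustrative:

import openai

from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents import ImageGenerationAgent
from agnext.chat.memory import BufferedChatMemory

runtime = SingleThreadedAgentRuntime()
runtime.register(
    "illustrator",
    lambda: ImageGenerationAgent(
        description="Generates images from the conversation so far",
        memory=BufferedChatMemory(buffer_size=5),
        client=openai.AsyncClient(),
        model="dall-e-3",
    ),
)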

View File

@@ -0,0 +1,134 @@
from typing import Any, Callable, List, Mapping
import openai
from openai import AsyncAssistantEventHandler
from openai.types.beta import AssistantResponseFormatParam
from ...components import TypeRoutedAgent, message_handler
from ...core import CancellationToken
from ..types import PublishNow, Reset, RespondNow, ResponseFormat, TextMessage
class OpenAIAssistantAgent(TypeRoutedAgent):
"""An agent implementation that uses the OpenAI Assistant API to generate
responses.
Args:
description (str): The description of the agent.
client (openai.AsyncClient): The client to use for the OpenAI API.
assistant_id (str): The assistant ID to use for the OpenAI API.
thread_id (str): The thread ID to use for the OpenAI API.
assistant_event_handler_factory (Callable[[], AsyncAssistantEventHandler], optional):
A factory function to create an async assistant event handler. Defaults to None.
If provided, the agent will use the streaming mode with the event handler.
If not provided, the agent will use the blocking mode to generate responses.
"""
def __init__(
self,
description: str,
client: openai.AsyncClient,
assistant_id: str,
thread_id: str,
assistant_event_handler_factory: Callable[[], AsyncAssistantEventHandler] | None = None,
) -> None:
super().__init__(description)
self._client = client
self._assistant_id = assistant_id
self._thread_id = thread_id
self._assistant_event_handler_factory = assistant_event_handler_factory
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
"""Handle a text message. This method adds the message to the thread."""
# Save the message to the thread.
_ = await self._client.beta.threads.messages.create(
thread_id=self._thread_id,
content=message.content,
role="user",
metadata={"sender": message.source},
)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
"""Handle a reset message. This method deletes all messages in the thread."""
# Get all messages in this thread.
all_msgs: List[str] = []
while True:
if not all_msgs:
msgs = await self._client.beta.threads.messages.list(self._thread_id)
else:
msgs = await self._client.beta.threads.messages.list(self._thread_id, after=all_msgs[-1])
for msg in msgs.data:
all_msgs.append(msg.id)
if not msgs.has_next_page():
break
# Delete all the messages.
for msg_id in all_msgs:
status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
assert status.deleted is True
@message_handler()
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
"""Handle a respond now message. This method generates a response and returns it to the sender."""
return await self._generate_response(message.response_format, cancellation_token)
@message_handler()
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
"""Handle a publish now message. This method generates a response and publishes it."""
response = await self._generate_response(message.response_format, cancellation_token)
await self.publish_message(response)
async def _generate_response(
self, requested_response_format: ResponseFormat, cancellation_token: CancellationToken
) -> TextMessage:
# Handle response format.
if requested_response_format == ResponseFormat.json_object:
response_format = AssistantResponseFormatParam(type="json_object")
else:
response_format = AssistantResponseFormatParam(type="text")
if self._assistant_event_handler_factory is not None:
# Use event handler and streaming mode if available.
async with self._client.beta.threads.runs.stream(
thread_id=self._thread_id,
assistant_id=self._assistant_id,
event_handler=self._assistant_event_handler_factory(),
response_format=response_format,
) as stream:
run = await stream.get_final_run()
else:
# Use blocking mode.
run = await self._client.beta.threads.runs.create(
thread_id=self._thread_id,
assistant_id=self._assistant_id,
response_format=response_format,
)
if run.status != "completed":
# TODO: handle other statuses.
raise ValueError(f"Run did not complete successfully: {run}")
# Get the last message from the run.
response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
last_message_content = response.data[0].content
# TODO: handle array of content.
text_content = [content for content in last_message_content if content.type == "text"]
if not text_content:
raise ValueError(f"Expected text content in the last message: {last_message_content}")
# TODO: handle multiple text content.
return TextMessage(content=text_content[0].text.value, source=self.metadata["name"])
def save_state(self) -> Mapping[str, Any]:
return {
"assistant_id": self._assistant_id,
"thread_id": self._thread_id,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._assistant_id = state["assistant_id"]
self._thread_id = state["thread_id"]

View File

@@ -0,0 +1,32 @@
import asyncio
from ...components import TypeRoutedAgent, message_handler
from ...core import CancellationToken
from ..types import PublishNow, TextMessage
class UserProxyAgent(TypeRoutedAgent):
"""An agent that proxies user input from the console. Override the `get_user_input`
method to customize how user input is retrieved.
Args:
description (str): The description of the agent.
user_input_prompt (str): The console prompt to show to the user when asking for input.
"""
def __init__(self, description: str, user_input_prompt: str) -> None:
super().__init__(description)
self._user_input_prompt = user_input_prompt
@message_handler()
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
"""Handle a publish now message. This method prompts the user for input, then publishes it."""
user_input = await self.get_user_input(self._user_input_prompt)
await self.publish_message(TextMessage(content=user_input, source=self.metadata["name"]))
async def get_user_input(self, prompt: str) -> str:
"""Get user input from the console. Override this method to customize how user input is retrieved."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, input, prompt)
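
Overriding get_user_input is the intended extension point. A minimal sketch that swaps console input for a canned reply, e.g. for tests:

from agnext.chat.agents import UserProxyAgent

class ScriptedUserAgent(UserProxyAgent):
    """A user proxy that returns a fixed reply instead of reading the console."""

    async def get_user_input(self, prompt: str) -> str:
        return "Looks good, proceed."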

View File

@@ -0,0 +1,5 @@
from ._base import ChatMemory
from ._buffered import BufferedChatMemory
from ._head_and_tail import HeadAndTailChatMemory
__all__ = ["ChatMemory", "BufferedChatMemory", "HeadAndTailChatMemory"]

View File

@@ -0,0 +1,19 @@
from typing import Any, List, Mapping, Protocol
from ..types import Message
class ChatMemory(Protocol):
"""A protocol for defining the interface of a chat memory. A chat memory
lets agents store and retrieve messages. It can be implemented with
different memory recall strategies."""
async def add_message(self, message: Message) -> None: ...
async def get_messages(self) -> List[Message]: ...
async def clear(self) -> None: ...
def save_state(self) -> Mapping[str, Any]: ...
def load_state(self, state: Mapping[str, Any]) -> None: ...

View File

@@ -0,0 +1,46 @@
from typing import Any, List, Mapping
from ...components.models import FunctionExecutionResultMessage
from ..types import Message
from ._base import ChatMemory
class BufferedChatMemory(ChatMemory):
"""A buffered chat memory that keeps a view of the last n messages,
where n is the buffer size. The buffer size is set at initialization.
Args:
buffer_size (int): The size of the buffer.
"""
def __init__(self, buffer_size: int) -> None:
self._messages: List[Message] = []
self._buffer_size = buffer_size
async def add_message(self, message: Message) -> None:
"""Add a message to the memory."""
self._messages.append(message)
async def get_messages(self) -> List[Message]:
"""Get at most `buffer_size` recent messages."""
messages = self._messages[-self._buffer_size :]
# Handle the case where the first message is a function call result message.
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
# Remove the first message from the list.
messages = messages[1:]
return messages
async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []
def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
"buffer_size": self._buffer_size,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._buffer_size = state["buffer_size"]

View File

@@ -0,0 +1,66 @@
from typing import Any, List, Mapping
from ...components.models import FunctionExecutionResultMessage
from ..types import FunctionCallMessage, Message, TextMessage
from ._base import ChatMemory
class HeadAndTailChatMemory(ChatMemory):
"""A chat memory that keeps a view of the first n and last m messages,
where n is the head size and m is the tail size. The head and tail sizes
are set at initialization.
Args:
head_size (int): The size of the head.
tail_size (int): The size of the tail.
"""
def __init__(self, head_size: int, tail_size: int) -> None:
self._messages: List[Message] = []
self._head_size = head_size
self._tail_size = tail_size
async def add_message(self, message: Message) -> None:
"""Add a message to the memory."""
self._messages.append(message)
async def get_messages(self) -> List[Message]:
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
head_messages = self._messages[: self._head_size]
# Handle the case where the head ends with a function call message.
if head_messages and isinstance(head_messages[-1], FunctionCallMessage):
# Remove the last message from the head.
head_messages = head_messages[:-1]
tail_messages = self._messages[-self._tail_size :]
# Handle the case where the tail starts with a function call result message.
if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage):
# Remove the first message from the tail.
tail_messages = tail_messages[1:]
num_skipped = len(self._messages) - self._head_size - self._tail_size
if num_skipped <= 0:
# If there are not enough messages to fill the head and tail,
# return all messages.
return self._messages
placeholder_messages = [TextMessage(content=f"Skipped {num_skipped} messages.", source="System")]
return head_messages + placeholder_messages + tail_messages
async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []
def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
"head_size": self._head_size,
"tail_size": self._tail_size,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._head_size = state["head_size"]
self._tail_size = state["tail_size"]

View File

@@ -0,0 +1,3 @@
from .group_chat_manager import GroupChatManager
__all__ = ["GroupChatManager"]

View File

@@ -0,0 +1,154 @@
import logging
from typing import Any, Callable, List, Mapping
from ...components import TypeRoutedAgent, message_handler
from ...components.models import ChatCompletionClient
from ...core import AgentId, AgentProxy, AgentRuntime, CancellationToken
from ..memory import ChatMemory
from ..types import (
MultiModalMessage,
PublishNow,
Reset,
TextMessage,
)
from .group_chat_utils import select_speaker
logger = logging.getLogger("agnext.events")
class GroupChatManager(TypeRoutedAgent):
"""An agent that manages a group chat through event-driven orchestration.
Args:
description (str): The description of the agent.
runtime (AgentRuntime): The runtime in which to register the agent.
participants (List[AgentId]): The list of participants in the group chat.
memory (ChatMemory): The memory to store and retrieve messages.
model_client (ChatCompletionClient, optional): The client to use for the model.
If provided, the agent will use the model to select the next speaker.
If not provided, the agent will select the next speaker from the list of participants
according to the order given.
termination_word (str, optional): The word that terminates the group chat. Defaults to "TERMINATE".
transitions (Mapping[AgentId, List[AgentId]], optional): The transitions between agents.
Keys are the agents, and values are the list of agents that can follow the key agent. Defaults to {}.
If provided, the group chat manager will use the transitions to select the next speaker.
If a transition is not provided for an agent, the choices fall back to all participants.
If no model client is provided, a transition must have a single value.
on_message_received (Callable[[TextMessage | MultiModalMessage], None], optional): A custom handler to call when a message is received.
Defaults to None.
"""
def __init__(
self,
description: str,
runtime: AgentRuntime,
participants: List[AgentId],
memory: ChatMemory,
model_client: ChatCompletionClient | None = None,
termination_word: str = "TERMINATE",
transitions: Mapping[AgentId, List[AgentId]] = {},
on_message_received: Callable[[TextMessage | MultiModalMessage], None] | None = None,
):
super().__init__(description)
self._memory = memory
self._client = model_client
self._participants = participants
self._participant_proxies = dict((p, AgentProxy(p, runtime)) for p in participants)
self._termination_word = termination_word
for key, value in transitions.items():
if not value:
# Make sure no empty transitions are provided.
raise ValueError(f"Empty transition list provided for {key.name}.")
if key not in participants:
# Make sure all keys are in the list of participants.
raise ValueError(f"Transition key {key.name} not found in participants.")
for v in value:
if v not in participants:
# Make sure all values are in the list of participants.
raise ValueError(f"Transition value {v.name} not found in participants.")
if self._client is None:
# Make sure there is only one transition for each key if no model client is provided.
if len(value) > 1:
raise ValueError(f"Multiple transitions provided for {key.name} but no model client is provided.")
self._transitions = transitions
self._on_message_received = on_message_received
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
"""Handle a reset message. This method clears the memory."""
await self._memory.clear()
@message_handler()
async def on_new_message(
self, message: TextMessage | MultiModalMessage, cancellation_token: CancellationToken
) -> None:
"""Handle a message. This method adds the message to the memory, selects the next speaker,
and sends a message to the selected speaker to publish a response."""
# Call the custom on_message_received handler if provided.
if self._on_message_received is not None:
self._on_message_received(message)
# Check if the message contains the termination word.
if isinstance(message, TextMessage) and self._termination_word in message.content:
# Terminate the group chat by not selecting the next speaker.
return
# Save the message to chat memory.
await self._memory.add_message(message)
# Get the last speaker.
last_speaker_name = message.source
last_speaker_index = next((i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None)
# Get the candidates for the next speaker.
if last_speaker_index is not None:
logger.debug(f"Last speaker: {last_speaker_name}")
last_speaker = self._participants[last_speaker_index]
if self._transitions.get(last_speaker) is not None:
candidates = [c for c in self._participants if c in self._transitions[last_speaker]]
else:
candidates = self._participants
else:
candidates = self._participants
logger.debug(f"Group chat manager next speaker candidates: {[c.name for c in candidates]}")
# Select speaker.
if len(candidates) == 0:
speaker = None
elif len(candidates) == 1:
speaker = candidates[0]
else:
# More than one candidate, select the next speaker.
if self._client is None:
# If no model client is provided, candidates must be the list of participants.
assert candidates == self._participants
# If no model client is provided, select the next speaker from the list of participants.
if last_speaker_index is not None:
next_speaker_index = (last_speaker_index + 1) % len(self._participants)
speaker = self._participants[next_speaker_index]
else:
# If no last speaker, select the first speaker.
speaker = candidates[0]
else:
# If a model client is provided, select the speaker based on the transitions and the model.
speaker_index = await select_speaker(
self._memory, self._client, [self._participant_proxies[c] for c in candidates]
)
speaker = candidates[speaker_index]
logger.debug(f"Group chat manager selected speaker: {speaker.name if speaker is not None else None}")
if speaker is not None:
# Send the message to the selected speaker to ask it to publish a response.
await self.send_message(PublishNow(), speaker)
def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._memory.save_state(),
"termination_word": self._termination_word,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._memory.load_state(state["memory"])
self._termination_word = state["termination_word"]

View File

@@ -0,0 +1,81 @@
"""Credit to the original authors: https://github.com/microsoft/autogen/blob/main/autogen/agentchat/groupchat.py"""
import re
from typing import Dict, List
from ...components.models import ChatCompletionClient, SystemMessage
from ...core import AgentProxy
from ..memory import ChatMemory
from ..types import TextMessage
async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agents: List[AgentProxy]) -> int:
"""Selects the next speaker in a group chat using a ChatCompletion client."""
# TODO: Handle multi-modal messages.
# Construct the formatted current message history.
history_messages: List[str] = []
for msg in await memory.get_messages():
assert isinstance(msg, TextMessage)
history_messages.append(f"{msg.source}: {msg.content}")
history = "\n".join(history_messages)
# Construct agent roles.
roles = "\n".join([f"{agent.metadata['name']}: {agent.metadata['description']}".strip() for agent in agents])
# Construct agent list.
participants = str([agent.metadata["name"] for agent in agents])
# Select the next speaker.
select_speaker_prompt = f"""You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
{history}
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
"""
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
response = await client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
mentions = mentioned_agents(response.content, agents)
if len(mentions) != 1:
raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}")
agent_name = list(mentions.keys())[0]
agent_index = next((i for i, agent in enumerate(agents) if agent.metadata["name"] == agent_name), None)
assert agent_index is not None
return agent_index
def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
"""Counts the number of times each agent is mentioned in the provided message content.
Agent names will match under any of the following conditions (all case-sensitive):
- Exact name match
- If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer')
- If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer')
Args:
message_content (str): The content of the message to search.
agents (List[AgentProxy]): A list of AgentProxy objects, each with a 'name' in its metadata to be searched for in the message content.
Returns:
Dict: a counter for mentioned agents.
"""
mentions: Dict[str, int] = dict()
for agent in agents:
# Finds agent mentions, taking word boundaries into account,
# accommodates escaping underscores and underscores as spaces
name = agent.metadata["name"]
regex = (
r"(?<=\W)("
+ re.escape(name)
+ r"|"
+ re.escape(name.replace("_", " "))
+ r"|"
+ re.escape(name.replace("_", r"\_"))
+ r")(?=\W)"
)
count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching
if count > 0:
mentions[name] = count
return mentions
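
The matching rules above can be sanity-checked with lightweight stand-ins for AgentProxy, duck-typed purely for illustration (the import path is assumed from the package layout):

from agnext.chat.patterns.group_chat_utils import mentioned_agents

class FakeProxy:
    """A stand-in exposing only the metadata that mentioned_agents reads."""

    def __init__(self, name: str) -> None:
        self.metadata = {"name": name}

proxies = [FakeProxy("Story_writer"), FakeProxy("Editor")]
# "Story writer" matches "Story_writer" via the underscore-as-space rule.
counts = mentioned_agents("Story writer, please revise. Editor: stand by.", proxies)  # type: ignore[arg-type]
assert counts == {"Story_writer": 1, "Editor": 1}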

View File

@@ -0,0 +1,392 @@
import json
from typing import Any, Sequence, Tuple
from ...components import TypeRoutedAgent, message_handler
from ...core import AgentId, AgentRuntime, CancellationToken
from ..types import Reset, RespondNow, ResponseFormat, TextMessage
__all__ = ["OrchestratorChat"]
class OrchestratorChat(TypeRoutedAgent):
def __init__(
self,
description: str,
runtime: AgentRuntime,
orchestrator: AgentId,
planner: AgentId,
specialists: Sequence[AgentId],
max_turns: int = 30,
max_stalled_turns_before_retry: int = 2,
max_retry_attempts: int = 1,
) -> None:
super().__init__(description)
self._orchestrator = orchestrator
self._planner = planner
self._specialists = specialists
self._max_turns = max_turns
self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
self._max_retry_attempts_before_educated_guess = max_retry_attempts
@property
def children(self) -> Sequence[AgentId]:
return list(self._specialists) + [self._orchestrator, self._planner]
@message_handler()
async def on_text_message(
self,
message: TextMessage,
cancellation_token: CancellationToken,
) -> TextMessage:
# A task is received.
task = message.content
# Prepare the task.
team, names, facts, plan = await self._prepare_task(task, message.source)
# Main loop.
total_turns = 0
retry_attempts = 0
while total_turns < self._max_turns:
# Reset all agents.
for agent in [*self._specialists, self._orchestrator]:
await self.send_message(Reset(), agent)
# Create the task specs.
task_specs = f"""
We are working to address the following user request:
{task}
To answer this request we have assembled the following team:
{team}
Some additional points to consider:
{facts}
{plan}
""".strip()
# Send the task specs to the orchestrator and specialists.
for agent in [*self._specialists, self._orchestrator]:
await self.send_message(TextMessage(content=task_specs, source=self.metadata["name"]), agent)
# Inner loop.
stalled_turns = 0
while total_turns < self._max_turns:
# Reflect on the task.
data = await self._reflect_on_task(task, team, names, message.source)
# Check if the request is satisfied.
if data["is_request_satisfied"]["answer"]:
return TextMessage(
content=f"The task has been successfully addressed. {data['is_request_satisfied']['reason']}",
source=self.metadata["name"],
)
# Update stalled turns.
if data["is_progress_being_made"]["answer"]:
stalled_turns = max(0, stalled_turns - 1)
else:
stalled_turns += 1
# Handle retry.
if stalled_turns > self._max_stalled_turns_before_retry:
# In a retry, we need to rewrite the facts and the plan.
# Rewrite the facts.
facts = await self._rewrite_facts(facts, message.source)
# Increment the retry attempts.
retry_attempts += 1
# Check if we should just guess.
if retry_attempts > self._max_retry_attempts_before_educated_guess:
# Make an educated guess.
educated_guess = await self._educated_guess(facts, message.source)
if educated_guess["has_educated_guesses"]["answer"]:
return TextMessage(
content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}",
source=self.metadata["name"],
)
# Come up with a new plan.
plan = await self._rewrite_plan(team, message.source)
# Exit the inner loop.
break
# Get the subtask.
subtask = data["instruction_or_question"]["answer"]
if subtask is None:
subtask = ""
# Update agents.
for agent in [*self._specialists, self._orchestrator]:
_ = await self.send_message(
TextMessage(content=subtask, source=self.metadata["name"]),
agent,
)
# Find the speaker.
try:
speaker = next(agent for agent in self._specialists if agent.name == data["next_speaker"]["answer"])
except StopIteration as e:
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
# Ask speaker to speak.
speaker_response = await self.send_message(RespondNow(), speaker)
assert speaker_response is not None
# Update all other agents with the speaker's response.
for agent in [agent for agent in self._specialists if agent != speaker] + [self._orchestrator]:
await self.send_message(
TextMessage(
content=speaker_response.content,
source=speaker_response.source,
),
agent,
)
# Increment the total turns.
total_turns += 1
return TextMessage(
content="The task was not addressed. The maximum number of turns was reached.",
source=self.metadata["name"],
)
async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]:
# Reset planner.
await self.send_message(Reset(), self._planner)
# A reusable description of the team.
team = "\n".join(
[agent.name + ": " + self.runtime.agent_metadata(agent)["description"] for agent in self._specialists]
)
names = ", ".join([agent.name for agent in self._specialists])
# A place to store relevant facts.
facts = ""
# A place to store the plan.
plan = ""
# Start by writing what we know
closed_book_prompt = f"""Below I will present you a request. Before we begin addressing the request, please answer the following pre-survey to the best of your ability. Keep in mind that you are Ken Jennings-level with trivia, and Mensa-level with puzzles, so there should be a deep well to draw from.
Here is the request:
{task}
Here is the pre-survey:
1. Please list any specific facts or figures that are GIVEN in the request itself. It is possible that there are none.
2. Please list any facts that may need to be looked up, and WHERE SPECIFICALLY they might be found. In some cases, authoritative sources are mentioned in the request itself.
3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation)
4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc.
When answering this survey, keep in mind that "facts" will typically be specific names, dates, statistics, etc. Your answer should use headings:
1. GIVEN OR VERIFIED FACTS
2. FACTS TO LOOK UP
3. FACTS TO DERIVE
4. EDUCATED GUESSES
""".strip()
# Ask the planner to obtain prior knowledge about facts.
await self.send_message(TextMessage(content=closed_book_prompt, source=sender), self._planner)
facts_response = await self.send_message(RespondNow(), self._planner)
facts = str(facts_response.content)
# Make an initial plan
plan_prompt = f"""Fantastic. To address this request we have assembled the following team:
{team}
Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task.""".strip()
# Send the second message to the planner.
await self.send_message(TextMessage(content=plan_prompt, source=sender), self._planner)
plan_response = await self.send_message(RespondNow(), self._planner)
plan = str(plan_response.content)
return team, names, facts, plan
async def _reflect_on_task(
self,
task: str,
team: str,
names: str,
sender: str,
) -> Any:
step_prompt = f"""
Recall we are working on the following request:
{task}
And we have assembled the following team:
{team}
To make progress on the request, please answer the following questions, including necessary reasoning:
- Is the request fully satisfied? (True if complete, or False if the original request has yet to be SUCCESSFULLY addressed)
- Are we making forward progress? (True if just starting, or recent messages are adding value. False if recent messages show evidence of being stuck in a reasoning or action loop, or there is evidence of significant barriers to success such as the inability to read from a required file)
- Who should speak next? (select from: {names})
- What instruction or question would you give this team member? (Phrase as if speaking directly to them, and include any specific information they may need)
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:
{{
"is_request_satisfied": {{
"reason": string,
"answer": boolean
}},
"is_progress_being_made": {{
"reason": string,
"answer": boolean
}},
"next_speaker": {{
"reason": string,
"answer": string (select from: {names})
}},
"instruction_or_question": {{
"reason": string,
"answer": string
}}
}}
""".strip()
request = step_prompt
while True:
# Send a message to the orchestrator.
await self.send_message(TextMessage(content=request, source=sender), self._orchestrator)
# Request a response.
step_response = await self.send_message(
RespondNow(response_format=ResponseFormat.json_object),
self._orchestrator,
)
# TODO: use typed dictionary.
try:
result = json.loads(str(step_response.content))
except json.JSONDecodeError as e:
request = f"Invalid JSON: {str(e)}"
continue
if "is_request_satisfied" not in result:
request = "Missing key: is_request_satisfied"
continue
elif (
not isinstance(result["is_request_satisfied"], dict)
or "answer" not in result["is_request_satisfied"]
or "reason" not in result["is_request_satisfied"]
):
request = "Invalid value for key: is_request_satisfied, expected 'answer' and 'reason'"
continue
if "is_progress_being_made" not in result:
request = "Missing key: is_progress_being_made"
continue
elif (
not isinstance(result["is_progress_being_made"], dict)
or "answer" not in result["is_progress_being_made"]
or "reason" not in result["is_progress_being_made"]
):
request = "Invalid value for key: is_progress_being_made, expected 'answer' and 'reason'"
continue
if "next_speaker" not in result:
request = "Missing key: next_speaker"
continue
elif (
not isinstance(result["next_speaker"], dict)
or "answer" not in result["next_speaker"]
or "reason" not in result["next_speaker"]
):
request = "Invalid value for key: next_speaker, expected 'answer' and 'reason'"
continue
elif result["next_speaker"]["answer"] not in names:
request = f"Invalid value for key: next_speaker, expected 'answer' in {names}"
continue
if "instruction_or_question" not in result:
request = "Missing key: instruction_or_question"
continue
elif (
not isinstance(result["instruction_or_question"], dict)
or "answer" not in result["instruction_or_question"]
or "reason" not in result["instruction_or_question"]
):
request = "Invalid value for key: instruction_or_question, expected 'answer' and 'reason'"
continue
return result
async def _rewrite_facts(self, facts: str, sender: str) -> str:
new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning).
{facts}
""".strip()
# Send a message to the orchestrator.
await self.send_message(TextMessage(content=new_facts_prompt, source=sender), self._orchestrator)
# Request a response.
new_facts_response = await self.send_message(RespondNow(), self._orchestrator)
return str(new_facts_response.content)
async def _educated_guess(self, facts: str, sender: str) -> Any:
# Make an educated guess.
        educated_guess_prompt = f"""Given the following information
{facts}
Please answer the following question, including necessary reasoning:
- Do you have two or more congruent pieces of information that will allow you to make an educated guess for the original request? The educated guess MUST answer the question.
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:
{{
"has_educated_guesses": {{
"reason": string,
"answer": boolean
}}
}}
""".strip()
        request = educated_guess_prompt
while True:
# Send a message to the orchestrator.
await self.send_message(
TextMessage(content=request, source=sender),
self._orchestrator,
)
# Request a response.
response = await self.send_message(
RespondNow(response_format=ResponseFormat.json_object),
self._orchestrator,
)
try:
result = json.loads(str(response.content))
except json.JSONDecodeError as e:
request = f"Invalid JSON: {str(e)}"
continue
# TODO: use typed dictionary.
if "has_educated_guesses" not in result:
request = "Missing key: has_educated_guesses"
continue
if (
not isinstance(result["has_educated_guesses"], dict)
or "answer" not in result["has_educated_guesses"]
or "reason" not in result["has_educated_guesses"]
):
request = "Invalid value for key: has_educated_guesses, expected 'answer' and 'reason'"
continue
return result
async def _rewrite_plan(self, team: str, sender: str) -> str:
new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else.
Team membership:
{team}
""".strip()
# Send a message to the orchestrator.
await self.send_message(TextMessage(content=new_plan_prompt, source=sender), self._orchestrator)
# Request a response.
new_plan_response = await self.send_message(RespondNow(), self._orchestrator)
return str(new_plan_response.content)
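# Editor's note: an illustrative, standalone sketch (not part of the original
# change) of the validate-and-retry pattern that _reflect_on_task and
# _educated_guess use when asking the orchestrator for JSON. The `ask` callable
# is a hypothetical stand-in for a model call that returns a raw string.
def _parse_json_with_retry_sketch(ask, prompt: str, required_keys: list[str], max_attempts: int = 3) -> dict[str, Any]:
    request = prompt
    for _ in range(max_attempts):
        raw = ask(request)
        try:
            result = json.loads(raw)
        except json.JSONDecodeError as e:
            # Feed the parse error back so the model can correct itself.
            request = f"Invalid JSON: {str(e)}"
            continue
        missing = [key for key in required_keys if key not in result]
        if missing:
            request = f"Missing keys: {', '.join(missing)}"
            continue
        return result
    raise ValueError("No valid JSON produced within the attempt budget")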

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Union
from ..components import FunctionCall, Image
from ..components.models import FunctionExecutionResultMessage
@dataclass(kw_only=True)
class BaseMessage:
# Name of the agent that sent this message
source: str
@dataclass
class TextMessage(BaseMessage):
content: str
@dataclass
class MultiModalMessage(BaseMessage):
content: List[Union[str, Image]]
@dataclass
class FunctionCallMessage(BaseMessage):
content: List[FunctionCall]
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
class ResponseFormat(Enum):
text = "text"
json_object = "json_object"
@dataclass
class RespondNow:
"""A message to request a response from the addressed agent. The sender
    expects a response upon sending and waits for it synchronously."""
response_format: ResponseFormat = field(default=ResponseFormat.text)
@dataclass
class PublishNow:
"""A message to request an event to be published to the addressed agent.
Unlike RespondNow, the sender does not expect a response upon sending."""
response_format: ResponseFormat = field(default=ResponseFormat.text)
class Reset: ...
@dataclass
class ToolApprovalRequest:
"""A message to request approval for a tool call. The sender expects a
response upon sending and waits for it synchronously."""
tool_call: FunctionCall
@dataclass
class ToolApprovalResponse:
"""A message to respond to a tool approval request. The response is sent
synchronously."""
tool_call_id: str
approved: bool
reason: str
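# Editor's note: an illustrative sketch (not part of the original change) of how
# the message types above compose.
if __name__ == "__main__":
    greeting = TextMessage(content="Hello, team.", source="user")
    ask = RespondNow(response_format=ResponseFormat.json_object)
    print(greeting.source, greeting.content, ask.response_format)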

View File

@@ -0,0 +1,98 @@
from typing import List, Optional, Union
from typing_extensions import Literal
from ..components.models import (
AssistantMessage,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
UserMessage,
)
from .types import (
FunctionCallMessage,
Message,
MultiModalMessage,
TextMessage,
)
def convert_content_message_to_assistant_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[AssistantMessage]:
match message:
case TextMessage() | FunctionCallMessage():
return AssistantMessage(content=message.content, source=message.source)
case MultiModalMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as AssistantMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
return AssistantMessage(
content="".join([x for x in message.content if isinstance(x, str)]),
source=message.source,
)
def convert_content_message_to_user_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[UserMessage]:
match message:
case TextMessage() | MultiModalMessage():
return UserMessage(content=message.content, source=message.source)
case FunctionCallMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as UserMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
# TODO: what is a sliced function call?
raise NotImplementedError("Sliced function calls not yet implemented")
def convert_tool_call_response_message(
message: FunctionExecutionResultMessage,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[FunctionExecutionResultMessage]:
match message:
case FunctionExecutionResultMessage():
return FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content]
)
def convert_messages_to_llm_messages(
messages: List[Message],
self_name: str,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> List[LLMMessage]:
result: List[LLMMessage] = []
for message in messages:
match message:
case (
TextMessage(content=_, source=source)
| MultiModalMessage(content=_, source=source)
| FunctionCallMessage(content=_, source=source)
) if source == self_name:
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
if converted_message_1 is not None:
result.append(converted_message_1)
case (
TextMessage(content=_, source=source)
| MultiModalMessage(content=_, source=source)
| FunctionCallMessage(content=_, source=source)
) if source != self_name:
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:
result.append(converted_message_2)
case FunctionExecutionResultMessage(_):
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
if converted_message_3 is not None:
result.append(converted_message_3)
case _:
raise AssertionError("unreachable")
return result
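# Editor's note: an illustrative sketch (not part of the original change) showing
# a round of conversion from the perspective of an agent named "assistant".
if __name__ == "__main__":
    history: List[Message] = [
        TextMessage(content="What is 2 + 2?", source="user"),
        TextMessage(content="2 + 2 = 4.", source="assistant"),
    ]
    # The first entry becomes a UserMessage, the second an AssistantMessage.
    print(convert_messages_to_llm_messages(history, self_name="assistant"))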

View File

@@ -0,0 +1,9 @@
"""
The :mod:`agnext.components` module provides building blocks for creating single agents
"""
from ._image import Image
from ._type_routed_agent import TypeRoutedAgent, message_handler
from ._types import FunctionCall
__all__ = ["Image", "TypeRoutedAgent", "message_handler", "FunctionCall"]

View File

@@ -0,0 +1,337 @@
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py
# Credit to original authors
import inspect
from logging import getLogger
from typing import (
Annotated,
Any,
Callable,
Dict,
ForwardRef,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)
from pydantic import BaseModel, Field, create_model # type: ignore
from pydantic_core import PydanticUndefined
from typing_extensions import Literal
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
logger = getLogger(__name__)
T = TypeVar("T")
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
"""Get the type annotation of a parameter.
Args:
annotation: The annotation of the parameter
globalns: The global namespace of the function
Returns:
The type annotation of the parameter
"""
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""Get the signature of a function with type annotations.
Args:
call: The function to get the signature for
Returns:
The signature of the function with type annotations
"""
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns),
)
for param in signature.parameters.values()
]
return_annotation = get_typed_annotation(signature.return_annotation, globalns)
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
return typed_signature
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
"""Get the return annotation of a function.
Args:
call: The function to get the return annotation for
Returns:
The return annotation of the function
"""
signature = inspect.signature(call)
annotation = signature.return_annotation
if annotation is inspect.Signature.empty:
return None
globalns = getattr(call, "__globals__", {})
return get_typed_annotation(annotation, globalns)
def get_param_annotations(
typed_signature: inspect.Signature,
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function
Args:
typed_signature: The signature of the function with type annotations
Returns:
A dictionary of the type annotations of the parameters of the function
"""
return {
k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty
}
class Parameters(BaseModel):
"""Parameters of a function as defined by the OpenAI API"""
type: Literal["object"] = "object"
properties: Dict[str, Dict[str, Any]]
required: List[str]
class Function(BaseModel):
"""A function as defined by the OpenAI API"""
description: Annotated[str, Field(description="Description of the function")]
name: Annotated[str, Field(description="Name of the function")]
parameters: Annotated[Parameters, Field(description="Parameters of the function")]
class ToolFunction(BaseModel):
"""A function under tool as defined by the OpenAI API."""
type: Literal["function"] = "function"
function: Annotated[Function, Field(description="Function under tool")]
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
# handles Annotated
if hasattr(v, "__metadata__"):
retval = v.__metadata__[0]
if isinstance(retval, str):
return retval
else:
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
else:
return k
def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]:
"""Get a JSON schema for a parameter as defined by the OpenAI API
Args:
k: The name of the parameter
v: The type of the parameter
default_values: The default values of the parameters of the function
Returns:
        A JSON schema for the parameter
"""
schema = type2schema(v)
if k in default_values:
dv = default_values[k]
schema["default"] = dv
schema["description"] = type2description(k, v)
return schema
def get_required_params(typed_signature: inspect.Signature) -> List[str]:
"""Get the required parameters of a function
Args:
signature: The signature of the function as returned by inspect.signature
Returns:
A list of the required parameters of the function
"""
return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty]
def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
"""Get default values of parameters of a function
Args:
signature: The signature of the function as returned by inspect.signature
Returns:
A dictionary of the default values of the parameters of the function
"""
return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty}
def get_parameters(
required: List[str],
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
default_values: Dict[str, Any],
) -> Parameters:
"""Get the parameters of a function as defined by the OpenAI API
Args:
required: The required parameters of the function
        param_annotations: The type annotations of the parameters of the function
        default_values: The default values of the parameters of the function
Returns:
A Pydantic model for the parameters of the function
"""
return Parameters(
properties={
k: get_parameter_json_schema(k, v, default_values)
for k, v in param_annotations.items()
if v is not inspect.Signature.empty
},
required=required,
)
def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]:
"""Get the missing annotations of a function
Ignores the parameters with default values as they are not required to be annotated, but logs a warning.
Args:
typed_signature: The signature of the function with type annotations
required: The required parameters of the function
Returns:
A set of the missing annotations of the function
"""
all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty}
missing = all_missing.intersection(set(required))
unannotated_with_default = all_missing.difference(missing)
return missing, unannotated_with_default
def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]:
"""Get a JSON schema for a function as defined by the OpenAI API
Args:
f: The function to get the JSON schema for
name: The name of the function
description: The description of the function
Returns:
A JSON schema for the function
Raises:
TypeError: If the function is not annotated
Examples:
.. code-block:: python
def f(
a: Annotated[str, "Parameter a"],
b: int = 2,
c: Annotated[float, "Parameter c"] = 0.1,
) -> None:
pass
get_function_schema(f, description="function f")
# {'type': 'function',
# 'function': {'description': 'function f',
# 'name': 'f',
# 'parameters': {'type': 'object',
# 'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
# 'b': {'type': 'int', 'description': 'b'},
# 'c': {'type': 'float', 'description': 'Parameter c'}},
# 'required': ['a']}}}
"""
typed_signature = get_typed_signature(f)
required = get_required_params(typed_signature)
default_values = get_default_values(typed_signature)
param_annotations = get_param_annotations(typed_signature)
return_annotation = get_typed_return_annotation(f)
missing, unannotated_with_default = get_missing_annotations(typed_signature, required)
if return_annotation is None:
logger.warning(
f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
+ "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
)
if unannotated_with_default != set():
unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)]
logger.warning(
f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
+ f"{', '.join(unannotated_with_default_s)}."
)
if missing != set():
missing_s = [f"'{k}'" for k in sorted(missing)]
raise TypeError(
f"All parameters of the function '{f.__name__}' without default values must be annotated. "
+ f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
)
fname = name if name else f.__name__
parameters = get_parameters(required, param_annotations, default_values=default_values)
function = ToolFunction(
function=Function(
description=description,
name=fname,
parameters=parameters,
)
)
return model_dump(function)
def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
"""Normalize typing.Annotated types to the inner type."""
if get_origin(type_hint) is Annotated:
# Extract the inner type from Annotated
return get_args(type_hint)[0] # type: ignore
return type_hint
def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
    fields: Dict[str, tuple[Type[Any], Any]] = {}
    # Use a distinct loop variable so the model name passed in as `name` is not
    # shadowed (previously the created model was named after the last parameter).
    for param_name, param in sig.parameters.items():
        # This is handled externally
        if param_name == "cancellation_token":
            continue
        if param.annotation is inspect.Parameter.empty:
            raise ValueError("No annotation")
        param_type = normalize_annotated_type(param.annotation)
        description = type2description(param_name, param.annotation)
        default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
        fields[param_name] = (param_type, Field(default=default_value, description=description))
    return cast(BaseModel, create_model(name, **fields))  # type: ignore
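# Editor's note: an illustrative sketch (not part of the original change) of
# deriving an OpenAI-style tool schema from an annotated function.
if __name__ == "__main__":
    def add(
        a: Annotated[int, "First addend"],
        b: Annotated[int, "Second addend"] = 1,
    ) -> int:
        return a + b

    # Only 'a' is required; 'b' carries a default of 1 in the generated schema.
    print(get_function_schema(add, description="Add two integers"))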

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import base64
import re
from io import BytesIO
from pathlib import Path
import aiohttp
from openai.types.chat import ChatCompletionContentPartImageParam
from PIL import Image as PILImage
from typing_extensions import Literal
class Image:
def __init__(self, image: PILImage.Image):
self.image: PILImage.Image = image.convert("RGB")
@classmethod
def from_pil(cls, pil_image: PILImage.Image) -> Image:
return cls(pil_image)
@classmethod
def from_uri(cls, uri: str) -> Image:
if not re.match(r"data:image/(?:png|jpeg);base64,", uri):
raise ValueError("Invalid URI format. It should be a base64 encoded image URI.")
# A URI. Remove the prefix and decode the base64 string.
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", uri)
return cls.from_base64(base64_data)
@classmethod
async def from_url(cls, url: str) -> Image:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
content = await response.read()
                return cls(PILImage.open(BytesIO(content)))
@classmethod
def from_base64(cls, base64_str: str) -> Image:
return cls(PILImage.open(BytesIO(base64.b64decode(base64_str))))
@classmethod
def from_file(cls, file_path: Path) -> Image:
return cls(PILImage.open(file_path))
def _repr_html_(self) -> str:
# Show the image in Jupyter notebook
return f'<img src="{self.data_uri}"/>'
@property
def data_uri(self) -> str:
buffered = BytesIO()
self.image.save(buffered, format="PNG")
content = buffered.getvalue()
return _convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> ChatCompletionContentPartImageParam:
return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}}
def _convert_base64_to_data_uri(base64_image: str) -> str:
def _get_mime_type_from_data_uri(base64_image: str) -> str:
# Decode the base64 string
image_data = base64.b64decode(base64_image)
# Check the first few bytes for known signatures
if image_data.startswith(b"\xff\xd8\xff"):
return "image/jpeg"
elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
return "image/png"
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
return "image/gif"
elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
return "image/webp"
return "image/jpeg" # use jpeg for unknown formats, best guess.
mime_type = _get_mime_type_from_data_uri(base64_image)
data_uri = f"data:{mime_type};base64,{base64_image}"
return data_uri
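# Editor's note: an illustrative sketch (not part of the original change) of
# constructing an Image and serializing it for the OpenAI API.
if __name__ == "__main__":
    pil_image = PILImage.new("RGB", (4, 4), color="red")
    image = Image.from_pil(pil_image)
    part = image.to_openai_format(detail="low")
    # part["image_url"]["url"] is a data URI such as "data:image/png;base64,...".
    print(part["type"], part["image_url"]["detail"])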

View File

@@ -0,0 +1,65 @@
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/_pydantic.py
# Credit to original authors
from typing import Any, Dict, Tuple, Type, Union, get_args
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin
__all__ = ("model_dump", "type2schema", "evaluate_forwardref")
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
def evaluate_forwardref(
value: Any,
globalns: dict[str, Any] | None = None,
localns: dict[str, Any] | None = None,
) -> Any:
if PYDANTIC_V1:
from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal
return evaluate_forwardref_internal(value, globalns, localns)
else:
from pydantic._internal._typing_extra import eval_type_lenient
return eval_type_lenient(value, globalns, localns)
def type2schema(t: Type[Any] | None) -> Dict[str, Any]:
if PYDANTIC_V1:
from pydantic import schema_of # type: ignore
if t is None:
return {"type": "null"}
elif get_origin(t) is Union:
return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
elif get_origin(t) in [Tuple, tuple]:
prefixItems = [type2schema(tt) for tt in get_args(t)]
return {
"maxItems": len(prefixItems),
"minItems": len(prefixItems),
"prefixItems": prefixItems,
"type": "array",
}
d = schema_of(t) # type: ignore
if "title" in d:
d.pop("title")
if "description" in d:
d.pop("description")
return d
else:
from pydantic import TypeAdapter
return TypeAdapter(t).json_schema()
def model_dump(model: BaseModel) -> Dict[str, Any]:
if PYDANTIC_V1:
return model.dict() # type: ignore
else:
return model.model_dump()
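# Editor's note: an illustrative sketch (not part of the original change); the
# helpers above give the same answers on Pydantic v1 and v2.
if __name__ == "__main__":
    from typing import Optional

    print(type2schema(int))            # {'type': 'integer'}
    print(type2schema(Optional[str]))  # a schema allowing string or null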

View File

@@ -0,0 +1,191 @@
import logging
from functools import wraps
from types import NoneType, UnionType
from typing import (
Any,
Callable,
Coroutine,
Dict,
Literal,
NoReturn,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
overload,
runtime_checkable,
)
from ..core import BaseAgent, CancellationToken
from ..core.exceptions import CantHandleException
logger = logging.getLogger("agnext")
ReceivesT = TypeVar("ReceivesT", contravariant=True)
ProducesT = TypeVar("ProducesT", covariant=True)
# TODO: Generic typevar bound binding U to agent type
# Can't do this because Python doesn't support it
def is_union(t: object) -> bool:
origin = get_origin(t)
return origin is Union or origin is UnionType
def is_optional(t: object) -> bool:
origin = get_origin(t)
return origin is Optional
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
class AnyType:
pass
def get_types(t: object) -> Sequence[Type[Any]] | None:
if is_union(t):
return get_args(t)
elif is_optional(t):
return tuple(list(get_args(t)) + [NoneType])
elif t is Any:
return (AnyType,)
elif isinstance(t, type):
return (t,)
elif isinstance(t, NoneType):
return (NoneType,)
else:
return None
@runtime_checkable
class MessageHandler(Protocol[ReceivesT, ProducesT]):
target_types: Sequence[type]
produces_types: Sequence[type]
is_message_handler: Literal[True]
async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ...
# NOTE: this works on concrete types and not inheritance
# TODO: Use a protocol for the outer function to check arg names
@overload
def message_handler(
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]: ...
@overload
def message_handler(
func: None = None,
*,
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]: ...
def message_handler(
func: None | Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]] = None,
*,
strict: bool = True,
) -> (
Callable[
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]
| MessageHandler[ReceivesT, ProducesT]
):
def decorator(
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
if "return" not in type_hints:
raise AssertionError("return not found in function signature")
# Get the type of the message parameter
target_types = get_types(type_hints["message"])
if target_types is None:
raise AssertionError("Message type not found")
return_types = get_types(type_hints["return"])
if return_types is None:
raise AssertionError("Return type not found")
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")
return_value = await func(self, message, cancellation_token)
if AnyType not in return_types and type(return_value) not in return_types:
if strict:
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
else:
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
return return_value
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
return wrapper_handler
    if func is None:
return decorator
elif callable(func):
return decorator(func)
else:
raise ValueError("Invalid arguments")
class TypeRoutedAgent(BaseAgent):
def __init__(self, description: str) -> None:
# Self is already bound to the handlers
self._handlers: Dict[
Type[Any],
Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]],
] = {}
for attr in dir(self):
if callable(getattr(self, attr, None)):
handler = getattr(self, attr)
if hasattr(handler, "is_message_handler"):
message_handler = cast(MessageHandler[Any, Any], handler)
for target_type in message_handler.target_types:
self._handlers[target_type] = message_handler
subscriptions = list(self._handlers.keys())
super().__init__(description, subscriptions)
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
key_type: Type[Any] = type(message) # type: ignore
handler = self._handlers.get(key_type) # type: ignore
if handler is not None:
return await handler(message, cancellation_token)
else:
return await self.on_unhandled_message(message, cancellation_token)
async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn:
raise CantHandleException(f"Unhandled message: {message}")

View File

@@ -0,0 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class FunctionCall:
id: str
# JSON args
arguments: str
# Function to call
name: str

View File

@@ -0,0 +1,17 @@
from ._base import CodeBlock, CodeExecutor, CodeResult
from ._func_with_reqs import Alias, FunctionWithRequirements, Import, ImportFromModule, with_requirements
from ._impl.command_line_code_result import CommandLineCodeResult
from ._impl.local_commandline_code_executor import LocalCommandLineCodeExecutor
__all__ = [
"LocalCommandLineCodeExecutor",
"CommandLineCodeResult",
"CodeBlock",
"CodeResult",
"CodeExecutor",
"Alias",
"ImportFromModule",
"Import",
"FunctionWithRequirements",
"with_requirements",
]

View File

@@ -0,0 +1,50 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/base.py
# Credit to original authors
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Protocol, runtime_checkable
@dataclass
class CodeBlock:
"""A code block extracted fromm an agent message."""
code: str
language: str
@dataclass
class CodeResult:
"""Result of a code execution."""
exit_code: int
output: str
@runtime_checkable
class CodeExecutor(Protocol):
"""Executes code blocks and returns the result."""
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
"""Execute code blocks and return the result.
This method should be implemented by the code executor.
Args:
code_blocks (List[CodeBlock]): The code blocks to execute.
Returns:
CodeResult: The result of the code execution.
"""
...
def restart(self) -> None:
"""Restart the code executor.
This method should be implemented by the code executor.
This method is called when the agent is reset.
"""
...
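# Editor's note: an illustrative sketch (not part of the original change); any
# class with matching methods satisfies the runtime-checkable protocol, e.g. a
# no-op executor for tests.
class _NullCodeExecutor:
    def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
        # Pretend every block succeeded without running anything.
        return CodeResult(exit_code=0, output="")

    def restart(self) -> None:
        pass


if __name__ == "__main__":
    assert isinstance(_NullCodeExecutor(), CodeExecutor)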

View File

@@ -0,0 +1,200 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/func_with_reqs.py
# Credit to original authors
from __future__ import annotations
import functools
import inspect
from dataclasses import dataclass, field
from importlib.abc import SourceLoader
from importlib.util import module_from_spec, spec_from_loader
from textwrap import dedent, indent
from typing import Any, Callable, Generic, List, Sequence, Set, TypeVar, Union
from typing_extensions import ParamSpec
T = TypeVar("T")
P = ParamSpec("P")
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
if isinstance(func, FunctionWithRequirementsStr):
return func.func
code = inspect.getsource(func)
# Strip the decorator
if code.startswith("@"):
code = code[code.index("\n") + 1 :]
return code
@dataclass
class Alias:
name: str
alias: str
@dataclass
class ImportFromModule:
module: str
imports: List[Union[str, Alias]]
Import = Union[str, ImportFromModule, Alias]
def _import_to_str(im: Import) -> str:
if isinstance(im, str):
return f"import {im}"
elif isinstance(im, Alias):
return f"import {im.name} as {im.alias}"
else:
def to_str(i: Union[str, Alias]) -> str:
if isinstance(i, str):
return i
else:
return f"{i.name} as {i.alias}"
imports = ", ".join(map(to_str, im.imports))
return f"from {im.module} import {imports}"
class _StringLoader(SourceLoader):
def __init__(self, data: str):
self.data = data
def get_source(self, fullname: str) -> str:
return self.data
def get_data(self, path: str) -> bytes:
return self.data.encode("utf-8")
def get_filename(self, fullname: str) -> str:
return "<not a real path>/" + fullname + ".py"
@dataclass
class FunctionWithRequirementsStr:
func: str
compiled_func: Callable[..., Any]
_func_name: str
python_packages: Sequence[str] = field(default_factory=list)
global_imports: Sequence[Import] = field(default_factory=list)
def __init__(self, func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []):
self.func = func
self.python_packages = python_packages
self.global_imports = global_imports
module_name = "func_module"
loader = _StringLoader(func)
spec = spec_from_loader(module_name, loader)
if spec is None:
raise ValueError("Could not create spec")
module = module_from_spec(spec)
if spec.loader is None:
raise ValueError("Could not create loader")
try:
spec.loader.exec_module(module)
except Exception as e:
raise ValueError(f"Could not compile function: {e}") from e
functions = inspect.getmembers(module, inspect.isfunction)
if len(functions) != 1:
raise ValueError("The string must contain exactly one function")
self._func_name, self.compiled_func = functions[0]
def __call__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("String based function with requirement objects are not directly callable")
@dataclass
class FunctionWithRequirements(Generic[T, P]):
func: Callable[P, T]
python_packages: Sequence[str] = field(default_factory=list)
global_imports: Sequence[Import] = field(default_factory=list)
@classmethod
def from_callable(
cls, func: Callable[P, T], python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
) -> FunctionWithRequirements[T, P]:
return cls(python_packages=python_packages, global_imports=global_imports, func=func)
@staticmethod
def from_str(
func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
) -> FunctionWithRequirementsStr:
return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports)
# Type this based on F
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return self.func(*args, **kwargs)
def with_requirements(
python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
"""Decorate a function with package and import requirements
Args:
        python_packages (List[str], optional): Packages required by the function. Can include version info. Defaults to [].
global_imports (List[Import], optional): Required imports. Defaults to [].
Returns:
Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: The decorated function
"""
def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
func_with_reqs = FunctionWithRequirements(
python_packages=python_packages, global_imports=global_imports, func=func
)
functools.update_wrapper(func_with_reqs, func)
return func_with_reqs
return wrapper
def build_python_functions_file(
funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
) -> str:
# First collect all global imports
global_imports: Set[Import] = set()
for func in funcs:
if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
global_imports.update(func.global_imports)
content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
for func in funcs:
content += _to_code(func) + "\n\n"
return content
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
"""Generate a stub for a function as a string
Args:
func (Callable[..., Any]): The function to generate a stub for
Returns:
str: The stub for the function
"""
if isinstance(func, FunctionWithRequirementsStr):
return to_stub(func.compiled_func)
content = f"def {func.__name__}{inspect.signature(func)}:\n"
docstring = func.__doc__
if docstring:
docstring = dedent(docstring)
docstring = '"""' + docstring + '"""'
docstring = indent(docstring, " ")
content += docstring + "\n"
content += " ..."
return content
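# Editor's note: an illustrative sketch (not part of the original change) of
# declaring a function with requirements and rendering the artifacts the
# executor consumes.
if __name__ == "__main__":
    @with_requirements(python_packages=["numpy"], global_imports=[ImportFromModule("typing", ["List"])])
    def mean(values: List[float]) -> float:
        """Return the arithmetic mean of values."""
        return sum(values) / len(values)

    print(to_stub(mean))
    print(build_python_functions_file([mean]))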

View File

@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Optional
from .._base import CodeResult
@dataclass
class CommandLineCodeResult(CodeResult):
"""A code result class for command line code executor."""
code_file: Optional[str]

View File

@@ -0,0 +1,269 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/local_commandline_code_executor.py
# Credit to original authors
import logging
import subprocess
import sys
import warnings
from hashlib import md5
from pathlib import Path
from string import Template
from typing import Any, Callable, ClassVar, List, Sequence, Union
from typing_extensions import ParamSpec
from .._base import CodeBlock, CodeExecutor
from .._func_with_reqs import (
FunctionWithRequirements,
FunctionWithRequirementsStr,
build_python_functions_file,
to_stub,
)
from .command_line_code_result import CommandLineCodeResult
from .utils import PYTHON_VARIANTS, get_file_name_from_content, lang_to_cmd, silence_pip # type: ignore
__all__ = ("LocalCommandLineCodeExecutor",)
A = ParamSpec("A")
class LocalCommandLineCodeExecutor(CodeExecutor):
SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
"bash",
"shell",
"sh",
"pwsh",
"powershell",
"ps1",
"python",
]
FUNCTION_PROMPT_TEMPLATE: ClassVar[
str
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`
$functions"""
def __init__(
self,
timeout: int = 60,
work_dir: Union[Path, str] = Path("."),
functions: Sequence[
Union[
FunctionWithRequirements[Any, A],
Callable[..., Any],
FunctionWithRequirementsStr,
]
] = [],
functions_module: str = "functions",
):
"""(Experimental) A code executor class that executes code through a local command line
environment.
**This will execute LLM generated code on the local machine.**
Each code block is saved as a file and executed in a separate process in
the working directory, and a unique file is generated and saved in the
working directory for each code block.
The code blocks are executed in the order they are received.
        Command line code is sanitized using a regular expression match against a list of dangerous commands in order to
        prevent self-destructive commands from being executed, which may potentially affect the user's environment.
        Currently the only supported languages are Python and shell scripts.
For Python code, use the language "python" for the code block.
For shell scripts, use the language "bash", "shell", or "sh" for the code
block.
Args:
timeout (int): The timeout for code execution. Default is 60.
work_dir (str): The working directory for the code execution. If None,
a default working directory will be used. The default working
directory is the current directory ".".
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
"""
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
if isinstance(work_dir, str):
work_dir = Path(work_dir)
if not functions_module.isidentifier():
raise ValueError("Module name must be a valid Python identifier")
self._functions_module = functions_module
work_dir.mkdir(exist_ok=True)
self._timeout = timeout
self._work_dir: Path = work_dir
self._functions = functions
# Setup could take some time so we intentionally wait for the first code block to do it.
if len(functions) > 0:
self._setup_functions_complete = False
else:
self._setup_functions_complete = True
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
"""(Experimental) Format the functions for a prompt.
The template includes two variables:
- `$module_name`: The module name.
- `$functions`: The functions formatted as stubs with two newlines between each function.
Args:
prompt_template (str): The prompt template. Default is the class default.
Returns:
str: The formatted prompt.
"""
template = Template(prompt_template)
return template.substitute(
module_name=self._functions_module,
functions="\n\n".join([to_stub(func) for func in self._functions]),
)
@property
def functions_module(self) -> str:
"""(Experimental) The module name for the functions."""
return self._functions_module
@property
def functions(self) -> List[str]:
raise NotImplementedError
@property
def timeout(self) -> int:
"""(Experimental) The timeout for code execution."""
return self._timeout
@property
def work_dir(self) -> Path:
"""(Experimental) The working directory for the code execution."""
return self._work_dir
def _setup_functions(self) -> None:
func_file_content = build_python_functions_file(self._functions)
func_file = self._work_dir / f"{self._functions_module}.py"
func_file.write_text(func_file_content)
# Collect requirements
lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
flattened_packages = [item for sublist in lists_of_packages for item in sublist]
required_packages = list(set(flattened_packages))
if len(required_packages) > 0:
logging.info("Ensuring packages are installed in executor.")
cmd = [sys.executable, "-m", "pip", "install"]
cmd.extend(required_packages)
try:
result = subprocess.run(
cmd,
cwd=self._work_dir,
capture_output=True,
text=True,
timeout=float(self._timeout),
)
except subprocess.TimeoutExpired as e:
raise ValueError("Pip install timed out") from e
if result.returncode != 0:
raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}")
# Attempt to load the function file to check for syntax errors, imports etc.
exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")])
if exec_result.exit_code != 0:
raise ValueError(f"Functions failed to load: {exec_result.output}")
self._setup_functions_complete = True
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
"""(Experimental) Execute the code blocks and return the result.
Args:
code_blocks (List[CodeBlock]): The code blocks to execute.
Returns:
CommandLineCodeResult: The result of the code execution."""
if not self._setup_functions_complete:
self._setup_functions()
return self._execute_code_dont_check_setup(code_blocks)
def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
logs_all: str = ""
file_names: List[Path] = []
exitcode = 0
for code_block in code_blocks:
lang, code = code_block.language, code_block.code
lang = lang.lower()
code = silence_pip(code, lang)
if lang in PYTHON_VARIANTS:
lang = "python"
if lang not in self.SUPPORTED_LANGUAGES:
# In case the language is not supported, we return an error message.
exitcode = 1
logs_all += "\n" + f"unknown language {lang}"
break
try:
# Check if there is a filename comment
filename = get_file_name_from_content(code, self._work_dir)
except ValueError:
return CommandLineCodeResult(
exit_code=1,
output="Filename is not in the workspace",
code_file=None,
)
if filename is None:
# create a file with an automatically generated name
code_hash = md5(code.encode()).hexdigest()
filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
written_file = (self._work_dir / filename).resolve()
with written_file.open("w", encoding="utf-8") as f:
f.write(code)
file_names.append(written_file)
program = sys.executable if lang.startswith("python") else lang_to_cmd(lang)
cmd = [program, str(written_file.absolute())]
try:
result = subprocess.run(
cmd,
cwd=self._work_dir,
capture_output=True,
text=True,
timeout=float(self._timeout),
)
except subprocess.TimeoutExpired:
logs_all += "\n Timeout"
# Same exit code as the timeout command on linux.
exitcode = 124
break
logs_all += result.stderr
logs_all += result.stdout
exitcode = result.returncode
if exitcode != 0:
break
code_file = str(file_names[0]) if len(file_names) > 0 else None
return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file)
def restart(self) -> None:
"""(Experimental) Restart the code executor."""
warnings.warn(
"Restarting local command line code executor is not supported. No action is taken.",
stacklevel=2,
)
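# Editor's note: an illustrative sketch (not part of the original change) of
# running a single Python block in a throwaway working directory.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        executor = LocalCommandLineCodeExecutor(timeout=30, work_dir=tmp)
        result = executor.execute_code_blocks(
            [CodeBlock(language="python", code="print('hello from the executor')")]
        )
        print(result.exit_code, result.output, result.code_file)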

View File

@@ -0,0 +1,88 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/utils.py
# Credit to original authors
# Will return the filename relative to the workspace path
import re
from pathlib import Path
from typing import Optional
# Raises ValueError if the file is not in the workspace
def get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
first_line = code.split("\n")[0]
# TODO - support other languages
if first_line.startswith("# filename:"):
filename = first_line.split(":")[1].strip()
# Handle relative paths in the filename
path = Path(filename)
if not path.is_absolute():
path = workspace_path / path
path = path.resolve()
# Throws an error if the file is not in the workspace
relative = path.relative_to(workspace_path.resolve())
return str(relative)
return None
def silence_pip(code: str, lang: str) -> str:
"""Apply -qqq flag to pip install commands."""
if lang == "python":
regex = r"^! ?pip install"
elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]:
regex = r"^pip install"
else:
return code
# Find lines that start with pip install and make sure "-qqq" flag is added.
lines = code.split("\n")
for i, line in enumerate(lines):
# use regex to find lines that start with pip install.
match = re.search(regex, line)
if match is not None:
if "-qqq" not in line:
lines[i] = line.replace(match.group(0), match.group(0) + " -qqq")
return "\n".join(lines)
PYTHON_VARIANTS = ["python", "Python", "py"]
def lang_to_cmd(lang: str) -> str:
if lang in PYTHON_VARIANTS:
return "python"
if lang.startswith("python") or lang in ["bash", "sh"]:
return lang
if lang in ["shell"]:
return "sh"
else:
raise ValueError(f"Unsupported language: {lang}")
# Regular expression for finding a code block
# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)[ \t]*\r?\n``` Matches multi-line code blocks.
# The [ \t]* matches the potential spaces before language name.
# The (\w+)? matches the language, where the ? indicates it is optional.
# The [ \t]* matches the potential spaces (not newlines) after language name.
# The \r?\n makes sure there is a linebreak after ```.
# The (.*?) matches the code itself (non-greedy).
# The \r?\n makes sure there is a linebreak before ```.
# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation).
CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```"
def infer_lang(code: str) -> str:
"""infer the language for the code.
TODO: make it robust.
"""
if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "):
return "sh"
# check if code is a valid python code
try:
compile(code, "test", "exec")
return "python"
except SyntaxError:
# not a valid python code
return "unknown"

View File

@@ -0,0 +1,32 @@
from ._model_client import ChatCompletionClient, ModelCapabilities
from ._openai_client import (
AzureOpenAI,
OpenAI,
)
from ._types import (
AssistantMessage,
CreateResult,
FinishReasons,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
__all__ = [
"AzureOpenAI",
"OpenAI",
"ModelCapabilities",
"ChatCompletionClient",
"SystemMessage",
"UserMessage",
"AssistantMessage",
"FunctionExecutionResult",
"FunctionExecutionResultMessage",
"LLMMessage",
"RequestUsage",
"FinishReasons",
"CreateResult",
]

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
from typing import Mapping, Optional, Sequence, runtime_checkable
from typing_extensions import (
Any,
AsyncGenerator,
Protocol,
Required,
TypedDict,
Union,
)
from ..tools import Tool
from ._types import CreateResult, LLMMessage, RequestUsage
class ModelCapabilities(TypedDict, total=False):
vision: Required[bool]
function_calling: Required[bool]
json_output: Required[bool]
@runtime_checkable
class ChatCompletionClient(Protocol):
    # Caching has to be handled internally as it can depend on the create args stored in the constructor
async def create(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> CreateResult: ...
def create_stream(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> AsyncGenerator[Union[str, CreateResult], None]: ...
def actual_usage(self) -> RequestUsage: ...
def total_usage(self) -> RequestUsage: ...
@property
def capabilities(self) -> ModelCapabilities: ...

View File

@@ -0,0 +1,89 @@
from typing import Dict
from ._model_client import ModelCapabilities
# Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
# This is a moving target, so correctness is checked at runtime by comparing the model value returned by OpenAI against expected values
_MODEL_POINTERS = {
"gpt-4o": "gpt-4o-2024-05-13",
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview": "gpt-4-0125-preview",
"gpt-4": "gpt-4-0613",
"gpt-4-32k": "gpt-4-32k-0613",
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
}
_MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
"gpt-4o-2024-05-13": {
"vision": True,
"function_calling": True,
"json_output": True,
},
"gpt-4-turbo-2024-04-09": {
"vision": True,
"function_calling": True,
"json_output": True,
},
"gpt-4-0125-preview": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-4-1106-preview": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-4-1106-vision-preview": {
"vision": True,
"function_calling": False,
"json_output": False,
},
"gpt-4-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-4-32k-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-0125": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-1106": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-instruct": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
},
"gpt-3.5-turbo-16k-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
},
}
def resolve_model(model: str) -> str:
if model in _MODEL_POINTERS:
return _MODEL_POINTERS[model]
return model
def get_capabilties(model: str) -> ModelCapabilities:
resolved_model = resolve_model(model)
return _MODEL_CAPABILITIES[resolved_model]
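# Editor's note: an illustrative sketch (not part of the original change) of alias
# resolution and capability lookup against the tables above.
if __name__ == "__main__":
    print(resolve_model("gpt-4o"))  # "gpt-4o-2024-05-13"
    print(get_capabilties("gpt-4o")["vision"])  # True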

View File

@@ -0,0 +1,569 @@
import inspect
import logging
import re
import warnings
from typing import (
Any,
AsyncGenerator,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Union,
cast,
)
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionRole,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
completion_create_params,
)
from openai.types.shared_params import FunctionDefinition, FunctionParameters
from typing_extensions import Unpack
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
from .. import (
FunctionCall,
Image,
)
from ..tools import Tool
from . import _model_info
from ._model_client import ChatCompletionClient, ModelCapabilities
from ._types import (
AssistantMessage,
CreateResult,
FunctionExecutionResultMessage,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration
logger = logging.getLogger(EVENT_LOGGER_NAME)
openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs)
aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs)
create_kwargs = set(completion_create_params.CompletionCreateParamsBase.__annotations__.keys()) | set(
("timeout", "stream")
)
# Only single choice allowed
disallowed_create_args = set(["stream", "messages", "function_call", "functions", "n"])
required_create_args: Set[str] = set(["model"])
def _azure_openai_client_from_config(config: Mapping[str, Any]) -> AsyncAzureOpenAI:
# Take a copy
    copied_config = dict(config)
# Do some fixups
copied_config["azure_deployment"] = copied_config.get("azure_deployment", config.get("model"))
if copied_config["azure_deployment"] is not None:
copied_config["azure_deployment"] = copied_config["azure_deployment"].replace(".", "")
copied_config["azure_endpoint"] = copied_config.get("azure_endpoint", copied_config.pop("base_url", None))
# Shave down the config to just the AzureOpenAI kwargs
azure_config = {k: v for k, v in copied_config.items() if k in aopenai_init_kwargs}
return AsyncAzureOpenAI(**azure_config)
def _openai_client_from_config(config: Mapping[str, Any]) -> AsyncOpenAI:
# Shave down the config to just the OpenAI kwargs
openai_config = {k: v for k, v in config.items() if k in openai_init_kwargs}
return AsyncOpenAI(**openai_config)
def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
create_args = {k: v for k, v in config.items() if k in create_kwargs}
create_args_keys = set(create_args.keys())
if not required_create_args.issubset(create_args_keys):
raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
if disallowed_create_args.intersection(create_args_keys):
raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
return create_args
# TODO check types
# oai_system_message_schema = type2schema(ChatCompletionSystemMessageParam)
# oai_user_message_schema = type2schema(ChatCompletionUserMessageParam)
# oai_assistant_message_schema = type2schema(ChatCompletionAssistantMessageParam)
# oai_tool_message_schema = type2schema(ChatCompletionToolMessageParam)
def type_to_role(message: LLMMessage) -> ChatCompletionRole:
if isinstance(message, SystemMessage):
return "system"
elif isinstance(message, UserMessage):
return "user"
elif isinstance(message, AssistantMessage):
return "assistant"
else:
return "tool"
def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam:
if isinstance(message.content, str):
return ChatCompletionUserMessageParam(
content=message.content,
role="user",
name=message.source,
)
else:
parts: List[ChatCompletionContentPartParam] = []
for part in message.content:
if isinstance(part, str):
oai_part = ChatCompletionContentPartTextParam(
text=part,
type="text",
)
parts.append(oai_part)
elif isinstance(part, Image):
# TODO: support url based images
# TODO: support specifying details
parts.append(part.to_openai_format())
else:
raise ValueError(f"Unknown content type: {part}")
return ChatCompletionUserMessageParam(
content=parts,
role="user",
name=message.source,
)
def system_message_to_oai(message: SystemMessage) -> ChatCompletionSystemMessageParam:
return ChatCompletionSystemMessageParam(
content=message.content,
role="system",
)
def func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam:
return ChatCompletionMessageToolCallParam(
id=message.id,
function={
"arguments": message.arguments,
"name": message.name,
},
type="function",
)
def tool_message_to_oai(
message: FunctionExecutionResultMessage,
) -> Sequence[ChatCompletionToolMessageParam]:
return [
ChatCompletionToolMessageParam(content=x.content, role="tool", tool_call_id=x.call_id) for x in message.content
]
def assistant_message_to_oai(
message: AssistantMessage,
) -> ChatCompletionAssistantMessageParam:
if isinstance(message.content, list):
return ChatCompletionAssistantMessageParam(
tool_calls=[func_call_to_oai(x) for x in message.content],
role="assistant",
name=message.source,
)
else:
return ChatCompletionAssistantMessageParam(
content=message.content,
role="assistant",
name=message.source,
)
def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
if isinstance(message, SystemMessage):
return [system_message_to_oai(message)]
elif isinstance(message, UserMessage):
return [user_message_to_oai(message)]
elif isinstance(message, AssistantMessage):
return [assistant_message_to_oai(message)]
else:
return tool_message_to_oai(message)
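# A small sketch of the flattening performed by to_oai_type: every LLMMessage
# maps to one or more OpenAI message params (_example_to_oai_type is an
# illustrative addition, not part of the module's API).
def _example_to_oai_type() -> None:
    [oai_message] = to_oai_type(UserMessage(content="Hi", source="user"))
    assert oai_message == {"content": "Hi", "role": "user", "name": "user"}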
def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
return RequestUsage(
prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
completion_tokens=usage1.completion_tokens + usage2.completion_tokens,
)
def convert_tools(
tools: Sequence[Tool],
) -> List[ChatCompletionToolParam]:
result: List[ChatCompletionToolParam] = []
for tool in tools:
tool_schema = tool.schema
result.append(
ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=tool_schema["name"],
description=tool_schema["description"] if "description" in tool_schema else "",
parameters=cast(FunctionParameters, tool_schema["parameters"])
if "parameters" in tool_schema
else {},
),
)
)
# Check if all tools have valid names.
for tool_param in result:
assert_valid_name(tool_param["function"]["name"])
return result
def normalize_name(name: str) -> str:
    """
    LLMs sometimes emit function names that ignore their own format requirements; use this to replace invalid characters with "_" and truncate to 64 characters.
    Prefer `assert_valid_name` for validating user configuration or input.
    """
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
def assert_valid_name(name: str) -> str:
    """
    Ensure that configured names are valid; raises ValueError if not.
    For munging LLM responses, use `normalize_name` so that LLM-specified names don't break the API.
    """
    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
    if len(name) > 64:
        raise ValueError(f"Invalid name: {name}. Name must be 64 characters or fewer.")
return name
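# An illustrative contrast between the two helpers (sketch only): normalize_name
# munges whatever the model produced, while assert_valid_name rejects bad user
# configuration outright.
def _example_name_handling() -> None:
    assert normalize_name("get weather!") == "get_weather_"
    assert assert_valid_name("get_weather") == "get_weather"
    try:
        assert_valid_name("get weather!")
    except ValueError:
        pass  # spaces and "!" are rejected rather than munged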
class BaseOpenAI(ChatCompletionClient):
def __init__(
self,
client: Union[AsyncOpenAI, AsyncAzureOpenAI],
create_args: Dict[str, Any],
model_capabilities: Optional[ModelCapabilities] = None,
):
self._client = client
if model_capabilities is None and isinstance(client, AsyncAzureOpenAI):
raise ValueError("AzureOpenAI requires explicit model capabilities")
elif model_capabilities is None:
self._model_capabilities = _model_info.get_capabilties(create_args["model"])
else:
self._model_capabilities = model_capabilities
self._resolved_model: Optional[str] = None
if "model" in create_args:
self._resolved_model = _model_info.resolve_model(create_args["model"])
if (
"response_format" in create_args
and create_args["response_format"]["type"] == "json_object"
and not self._model_capabilities["json_output"]
):
raise ValueError("Model does not support JSON output")
self._create_args = create_args
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
@classmethod
def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
return OpenAI(**config)
async def create(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> CreateResult:
# Make sure all extra_create_args are valid
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
create_args.update(extra_create_args)
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
if self.capabilities["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
        if json_output is not None:
            if self.capabilities["json_output"] is False and json_output is True:
                raise ValueError("Model does not support JSON output")
            if json_output is True:
                create_args["response_format"] = {"type": "json_object"}
            else:
                create_args["response_format"] = {"type": "text"}
oai_messages_nested = [to_oai_type(m) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
if self.capabilities["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling")
if len(tools) > 0:
converted_tools = convert_tools(tools)
result = await self._client.chat.completions.create(
messages=oai_messages,
stream=False,
tools=converted_tools,
**create_args,
)
else:
result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
if result.usage is not None:
logger.info(
LLMCallEvent(
prompt_tokens=result.usage.prompt_tokens,
completion_tokens=result.usage.completion_tokens,
)
)
usage = RequestUsage(
# TODO backup token counting
prompt_tokens=result.usage.prompt_tokens if result.usage is not None else 0,
completion_tokens=(result.usage.completion_tokens if result.usage is not None else 0),
)
if self._resolved_model is not None:
if self._resolved_model != result.model:
warnings.warn(
f"Resolved model mismatch: {self._resolved_model} != {result.model}. AutoGen model mapping may be incorrect.",
stacklevel=2,
)
# Limited to a single choice currently.
choice = result.choices[0]
if choice.finish_reason == "function_call":
raise ValueError("Function calls are not supported in this context")
content: Union[str, List[FunctionCall]]
if choice.finish_reason == "tool_calls":
assert choice.message.tool_calls is not None
assert choice.message.function_call is None
# NOTE: If OAI response type changes, this will need to be updated
content = [
FunctionCall(
id=x.id,
arguments=x.function.arguments,
name=normalize_name(x.function.name),
)
for x in choice.message.tool_calls
]
finish_reason = "function_calls"
else:
finish_reason = choice.finish_reason
content = choice.message.content or ""
response = CreateResult(finish_reason=finish_reason, content=content, usage=usage, cached=False) # type: ignore
        self._actual_usage = _add_usage(self._actual_usage, usage)
        self._total_usage = _add_usage(self._total_usage, usage)
# TODO - why is this cast needed?
return response
async def create_stream(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> AsyncGenerator[Union[str, CreateResult], None]:
# Make sure all extra_create_args are valid
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
create_args.update(extra_create_args)
oai_messages_nested = [to_oai_type(m) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
if self.capabilities["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.capabilities["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
else:
create_args["response_format"] = {"type": "text"}
if len(tools) > 0:
converted_tools = convert_tools(tools)
stream = await self._client.chat.completions.create(
messages=oai_messages, stream=True, tools=converted_tools, **create_args
)
else:
stream = await self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
stop_reason = None
maybe_model = None
content_deltas: List[str] = []
full_tool_calls: Dict[int, FunctionCall] = {}
completion_tokens = 0
async for chunk in stream:
choice = chunk.choices[0]
stop_reason = choice.finish_reason
maybe_model = chunk.model
            # First, try to get content
if choice.delta.content is not None:
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
continue
# Otherwise, get tool calls
if choice.delta.tool_calls is not None:
for tool_call_chunk in choice.delta.tool_calls:
idx = tool_call_chunk.index
if idx not in full_tool_calls:
                        # Start with an empty call and fill in the id/name/arguments as deltas arrive
                        full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
if tool_call_chunk.id is not None:
full_tool_calls[idx].id += tool_call_chunk.id
if tool_call_chunk.function is not None:
if tool_call_chunk.function.name is not None:
full_tool_calls[idx].name += tool_call_chunk.function.name
if tool_call_chunk.function.arguments is not None:
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
model = maybe_model or create_args["model"]
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
# TODO fix count token
prompt_tokens = 0
# prompt_tokens = count_token(messages, model=model)
if stop_reason is None:
raise ValueError("No stop reason found")
content: Union[str, List[FunctionCall]]
        if len(content_deltas) > 0:
content = "".join(content_deltas)
completion_tokens = 0
# completion_tokens = count_token(content, model=model)
else:
completion_tokens = 0
# TODO: fix assumption that dict values were added in order and actually order by int index
# for tool_call in full_tool_calls.values():
# # value = json.dumps(tool_call)
# # completion_tokens += count_token(value, model=model)
# completion_tokens += 0
content = list(full_tool_calls.values())
usage = RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
if stop_reason == "function_call":
raise ValueError("Function calls are not supported in this context")
if stop_reason == "tool_calls":
stop_reason = "function_calls"
result = CreateResult(finish_reason=stop_reason, content=content, usage=usage, cached=False)
        self._actual_usage = _add_usage(self._actual_usage, usage)
        self._total_usage = _add_usage(self._total_usage, usage)
yield result
def actual_usage(self) -> RequestUsage:
return self._actual_usage
def total_usage(self) -> RequestUsage:
return self._total_usage
@property
def capabilities(self) -> ModelCapabilities:
return self._model_capabilities
class OpenAI(BaseOpenAI):
def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
if "model" not in kwargs:
raise ValueError("model is required for OpenAI")
model_capabilities: Optional[ModelCapabilities] = None
        copied_args = dict(kwargs)
if "model_capabilities" in kwargs:
model_capabilities = kwargs["model_capabilities"]
del copied_args["model_capabilities"]
client = _openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config = copied_args
super().__init__(client, create_args, model_capabilities)
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_client"] = None
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
self._client = _openai_client_from_config(state["_raw_config"])
class AzureOpenAI(BaseOpenAI):
def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
if "model" not in kwargs:
raise ValueError("model is required for OpenAI")
model_capabilities: Optional[ModelCapabilities] = None
        copied_args = dict(kwargs)
if "model_capabilities" in kwargs:
model_capabilities = kwargs["model_capabilities"]
del copied_args["model_capabilities"]
client = _azure_openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config = copied_args
super().__init__(client, create_args, model_capabilities)
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_client"] = None
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
self._client = _azure_openai_client_from_config(state["_raw_config"])
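# A minimal end-to-end sketch of the client (illustrative only): it assumes a
# valid OPENAI_API_KEY in the environment and would be driven with
# asyncio.run(_example_usage()).
async def _example_usage() -> None:
    client = OpenAI(model="gpt-4o")
    result = await client.create([UserMessage(content="What is 2 + 2?", source="user")])
    print(result.finish_reason, result.content)
    print(client.total_usage())  # cumulative RequestUsage across calls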

View File

@@ -0,0 +1,56 @@
from dataclasses import dataclass
from typing import List, Literal, Union
from .. import FunctionCall, Image
@dataclass
class SystemMessage:
content: str
@dataclass
class UserMessage:
content: Union[str, List[Union[str, Image]]]
# Name of the agent that sent this message
source: str
@dataclass
class AssistantMessage:
content: Union[str, List[FunctionCall]]
# Name of the agent that sent this message
source: str
@dataclass
class FunctionExecutionResult:
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage:
content: List[FunctionExecutionResult]
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
@dataclass
class RequestUsage:
prompt_tokens: int
completion_tokens: int
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
@dataclass
class CreateResult:
finish_reason: FinishReasons
content: Union[str, List[FunctionCall]]
usage: RequestUsage
cached: bool
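# A short sketch of a conversation expressed with these types; the source
# values are arbitrary placeholder agent names.
_example_history: List[LLMMessage] = [
    SystemMessage(content="You are a helpful assistant."),
    UserMessage(content="What is the capital of France?", source="user"),
    AssistantMessage(content="Paris.", source="assistant"),
]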

View File

@@ -0,0 +1,52 @@
from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union
from typing_extensions import Required, TypedDict
from .._model_client import ModelCapabilities
class ResponseFormat(TypedDict):
type: Literal["text", "json_object"]
class CreateArguments(TypedDict, total=False):
frequency_penalty: Optional[float]
logit_bias: Optional[Dict[str, int]]
max_tokens: Optional[int]
n: Optional[int]
presence_penalty: Optional[float]
response_format: ResponseFormat
seed: Optional[int]
stop: Union[Optional[str], List[str]]
temperature: Optional[float]
top_p: Optional[float]
user: str
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
model: str
api_key: str
timeout: Union[float, None]
max_retries: int
# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
organization: str
base_url: str
# Not required
model_capabilities: ModelCapabilities
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
# Azure specific
azure_endpoint: Required[str]
azure_deployment: str
api_version: Required[str]
azure_ad_token: str
azure_ad_token_provider: AsyncAzureADTokenProvider
# Must be provided
model_capabilities: Required[ModelCapabilities]
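# Illustrative configurations (sketches only; the key, endpoint and version
# values are placeholders, not working credentials).
_example_openai_config: OpenAIClientConfiguration = {
    "model": "gpt-4o",
    "api_key": "sk-placeholder",
    "temperature": 0.0,
}
_example_azure_config: AzureOpenAIClientConfiguration = {
    "model": "gpt-4o",
    "azure_endpoint": "https://example.openai.azure.com",
    "azure_deployment": "gpt-4o",
    "api_version": "2024-02-01",
    "azure_ad_token_provider": lambda: "token-placeholder",
    "model_capabilities": {"vision": True, "function_calling": True, "json_output": True},
}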

View File

@@ -0,0 +1,13 @@
from ._base import BaseTool, BaseToolWithState, Tool
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
from ._function_tool import FunctionTool
__all__ = [
"Tool",
"BaseTool",
"BaseToolWithState",
"PythonCodeExecutionTool",
"CodeExecutionInput",
"CodeExecutionResult",
"FunctionTool",
]

View File

@@ -0,0 +1,151 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
from pydantic import BaseModel
from typing_extensions import NotRequired
from ...core import CancellationToken
from .._function_utils import normalize_annotated_type
T = TypeVar("T", bound=BaseModel, contravariant=True)
class ParametersSchema(TypedDict):
type: str
properties: Dict[str, Any]
required: NotRequired[Sequence[str]]
class ToolSchema(TypedDict):
parameters: NotRequired[ParametersSchema]
name: str
description: NotRequired[str]
class Tool(Protocol):
@property
def name(self) -> str: ...
@property
def description(self) -> str: ...
@property
def schema(self) -> ToolSchema: ...
def args_type(self) -> Type[BaseModel]: ...
def return_type(self) -> Type[Any]: ...
def state_type(self) -> Type[BaseModel] | None: ...
def return_value_as_string(self, value: Any) -> str: ...
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any: ...
def save_state_json(self) -> Mapping[str, Any]: ...
def load_state_json(self, state: Mapping[str, Any]) -> None: ...
ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True)
ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True)
StateT = TypeVar("StateT", bound=BaseModel)
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
def __init__(
self,
args_type: Type[ArgsT],
return_type: Type[ReturnT],
name: str,
description: str,
) -> None:
self._args_type = args_type
# Normalize Annotated to the base type.
self._return_type = normalize_annotated_type(return_type)
self._name = name
self._description = description
@property
def schema(self) -> ToolSchema:
model_schema = self._args_type.model_json_schema()
tool_schema = ToolSchema(
name=self._name,
description=self._description,
parameters=ParametersSchema(
type="object",
properties=model_schema["properties"],
),
)
if "required" in model_schema:
assert "parameters" in tool_schema
tool_schema["parameters"]["required"] = model_schema["required"]
return tool_schema
@property
def name(self) -> str:
return self._name
@property
def description(self) -> str:
return self._description
def args_type(self) -> Type[BaseModel]:
return self._args_type
def return_type(self) -> Type[Any]:
return self._return_type
def state_type(self) -> Type[BaseModel] | None:
return None
def return_value_as_string(self, value: Any) -> str:
if isinstance(value, BaseModel):
dumped = value.model_dump()
if isinstance(dumped, dict):
return json.dumps(dumped)
return str(dumped)
return str(value)
@abstractmethod
async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any:
return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
return return_value
def save_state_json(self) -> Mapping[str, Any]:
return {}
def load_state_json(self, state: Mapping[str, Any]) -> None:
pass
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
def __init__(
self,
args_type: Type[ArgsT],
return_type: Type[ReturnT],
state_type: Type[StateT],
name: str,
description: str,
) -> None:
super().__init__(args_type, return_type, name, description)
self._state_type = state_type
@abstractmethod
def save_state(self) -> StateT: ...
@abstractmethod
def load_state(self, state: StateT) -> None: ...
def save_state_json(self) -> Mapping[str, Any]:
return self.save_state().model_dump()
def load_state_json(self, state: Mapping[str, Any]) -> None:
self.load_state(self._state_type.model_validate(state))
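# A minimal concrete tool built on BaseTool (all names below are illustrative
# additions): pydantic models define the argument and return schemas, and run()
# carries the tool logic.
class _EchoArgs(BaseModel):
    text: str
class _EchoResult(BaseModel):
    text: str
class _EchoTool(BaseTool[_EchoArgs, _EchoResult]):
    def __init__(self) -> None:
        super().__init__(_EchoArgs, _EchoResult, "echo", "Echo the input text back.")
    async def run(self, args: _EchoArgs, cancellation_token: CancellationToken) -> _EchoResult:
        return _EchoResult(text=args.text)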

View File

@@ -0,0 +1,37 @@
import asyncio
import functools
from pydantic import BaseModel, Field, model_serializer
from ...core import CancellationToken
from ..code_executor import CodeBlock, CodeExecutor
from ._base import BaseTool
class CodeExecutionInput(BaseModel):
code: str = Field(description="The contents of the Python code block that should be executed")
class CodeExecutionResult(BaseModel):
success: bool
output: str
@model_serializer
def ser_model(self) -> str:
return self.output
class PythonCodeExecutionTool(BaseTool[CodeExecutionInput, CodeExecutionResult]):
def __init__(self, executor: CodeExecutor):
super().__init__(CodeExecutionInput, CodeExecutionResult, "CodeExecutor", "Execute Python code blocks.")
self._executor = executor
async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult:
code_blocks = [CodeBlock(code=args.code, language="python")]
future = asyncio.get_event_loop().run_in_executor(
None, functools.partial(self._executor.execute_code_blocks, code_blocks=code_blocks)
)
cancellation_token.link_future(future)
result = await future
return CodeExecutionResult(success=result.exit_code == 0, output=result.output)

View File

@@ -0,0 +1,50 @@
import asyncio
import functools
from typing import Any, Callable
from pydantic import BaseModel
from ...core import CancellationToken
from .._function_utils import (
args_base_model_from_signature,
get_typed_signature,
)
from ._base import BaseTool
class FunctionTool(BaseTool[BaseModel, BaseModel]):
def __init__(self, func: Callable[..., Any], description: str, name: str | None = None) -> None:
self._func = func
signature = get_typed_signature(func)
func_name = name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", signature)
return_type = signature.return_annotation
self._has_cancellation_support = "cancellation_token" in signature.parameters
super().__init__(args_model, return_type, func_name, description)
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
if asyncio.iscoroutinefunction(self._func):
if self._has_cancellation_support:
result = await self._func(**args.model_dump(), cancellation_token=cancellation_token)
else:
result = await self._func(**args.model_dump())
else:
if self._has_cancellation_support:
result = await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
self._func,
**args.model_dump(),
cancellation_token=cancellation_token,
),
)
else:
future = asyncio.get_event_loop().run_in_executor(
None, functools.partial(self._func, **args.model_dump())
)
cancellation_token.link_future(future)
result = await future
assert isinstance(result, self.return_type())
return result
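# A usage sketch: FunctionTool derives the argument schema from the signature
# and type hints of a plain function (_add and the "add" name are illustrative).
def _add(x: int, y: int) -> int:
    """Add two integers."""
    return x + y
_add_tool = FunctionTool(_add, description="Add two integers.", name="add")
# await _add_tool.run_json({"x": 2, "y": 3}, CancellationToken()) would return 5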

View File

@@ -0,0 +1,24 @@
"""
The :mod:`agnext.core` module provides the foundational generic interfaces upon which all else is built. This module must not depend on any other module.
"""
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_props import AgentChildren
from ._agent_proxy import AgentProxy
from ._agent_runtime import AgentRuntime, AllNamespaces
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
__all__ = [
"Agent",
"AgentId",
"AgentProxy",
"AgentMetadata",
"AgentRuntime",
"AllNamespaces",
"BaseAgent",
"CancellationToken",
"AgentChildren",
]

View File

@@ -0,0 +1,46 @@
from typing import Any, Mapping, Protocol, runtime_checkable
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._cancellation_token import CancellationToken
@runtime_checkable
class Agent(Protocol):
@property
def metadata(self) -> AgentMetadata:
"""Metadata of the agent."""
...
@property
def id(self) -> AgentId:
"""ID of the agent."""
...
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
"""Message handler for the agent. This should only be called by the runtime, not by other agents.
Args:
message (Any): Received message. Type is one of the types in `subscriptions`.
cancellation_token (CancellationToken): Cancellation token for the message.
Returns:
Any: Response to the message. Can be None.
Notes:
If there was a cancellation, this function should raise a `CancelledError`.
"""
...
def save_state(self) -> Mapping[str, Any]:
"""Save the state of the agent. The result must be JSON serializable."""
...
def load_state(self, state: Mapping[str, Any]) -> None:
"""Load in the state of the agent obtained from `save_state`.
Args:
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
"""
...

View File

@@ -0,0 +1,31 @@
from typing_extensions import Self
class AgentId:
def __init__(self, name: str, namespace: str) -> None:
self._name = name
self._namespace = namespace
def __str__(self) -> str:
return f"{self._namespace}/{self._name}"
def __hash__(self) -> int:
return hash((self._namespace, self._name))
def __eq__(self, value: object) -> bool:
if not isinstance(value, AgentId):
return False
return self._name == value.name and self._namespace == value.namespace
@classmethod
def from_str(cls, agent_id: str) -> Self:
namespace, name = agent_id.split("/")
return cls(name, namespace)
@property
def namespace(self) -> str:
return self._namespace
@property
def name(self) -> str:
return self._name
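# A round-trip sketch (illustrative only): the string form is "<namespace>/<name>".
def _example_agent_id() -> None:
    aid = AgentId(name="chat_agent", namespace="default")
    assert str(aid) == "default/chat_agent"
    assert AgentId.from_str("default/chat_agent") == aid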

View File

@@ -0,0 +1,8 @@
from typing import Sequence, TypedDict
class AgentMetadata(TypedDict):
name: str
namespace: str
description: str
subscriptions: Sequence[type]

View File

@@ -0,0 +1,11 @@
from typing import Protocol, Sequence, runtime_checkable
from ._agent_id import AgentId
@runtime_checkable
class AgentChildren(Protocol):
@property
def children(self) -> Sequence[AgentId]:
"""Ids of the children of the agent."""
...

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from asyncio import Future
from typing import TYPE_CHECKING, Any, Mapping
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._cancellation_token import CancellationToken
if TYPE_CHECKING:
from ._agent_runtime import AgentRuntime
class AgentProxy:
def __init__(self, agent: AgentId, runtime: AgentRuntime):
self._agent = agent
self._runtime = runtime
@property
def id(self) -> AgentId:
"""Target agent for this proxy"""
return self._agent
@property
def metadata(self) -> AgentMetadata:
"""Metadata of the agent."""
return self._runtime.agent_metadata(self._agent)
def send_message(
self,
message: Any,
*,
sender: AgentId,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
return self._runtime.send_message(
message,
recipient=self._agent,
sender=sender,
cancellation_token=cancellation_token,
)
def save_state(self) -> Mapping[str, Any]:
"""Save the state of the agent. The result must be JSON serializable."""
return self._runtime.agent_save_state(self._agent)
def load_state(self, state: Mapping[str, Any]) -> None:
"""Load in the state of the agent obtained from `save_state`.
Args:
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
"""
self._runtime.agent_load_state(self._agent, state)

View File

@@ -0,0 +1,162 @@
from __future__ import annotations
from asyncio import Future
from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_proxy import AgentProxy
from ._cancellation_token import CancellationToken
# Undeliverable - error
T = TypeVar("T", bound=Agent)
class AllNamespaces:
pass
@runtime_checkable
class AgentRuntime(Protocol):
# Returns the response of the message
def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]: ...
# No responses from publishing
def publish_message(
self,
message: Any,
*,
namespace: str | None = None,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[None]: ...
@overload
def register(
self, name: str, agent_factory: Callable[[], T], *, valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...
) -> None: ...
@overload
def register(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> None: ...
def register(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> None:
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
Args:
name (str): The name of the type agent this factory creates.
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent.
valid_namespaces (Sequence[str] | Type[AllNamespaces], optional): Valid namespaces for this type. Defaults to AllNamespaces.
Example:
.. code-block:: python
runtime.register(
"chat_agent",
lambda: ChatCompletionAgent(
description="A generic chat agent.",
system_messages=[SystemMessage("You are a helpful assistant")],
model_client=OpenAI(model="gpt-4o"),
memory=BufferedChatMemory(buffer_size=10),
),
)
"""
...
def get(self, name: str, *, namespace: str = "default") -> AgentId: ...
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: ...
@overload
def register_and_get(
self,
name: str,
agent_factory: Callable[[], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentId: ...
@overload
def register_and_get(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentId: ...
def register_and_get(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> AgentId:
        self.register(name, agent_factory, valid_namespaces=valid_namespaces)
return self.get(name, namespace=namespace)
@overload
def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentProxy: ...
@overload
def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
) -> AgentProxy: ...
def register_and_get_proxy(
self,
name: str,
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
*,
namespace: str = "default",
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
) -> AgentProxy:
        self.register(name, agent_factory, valid_namespaces=valid_namespaces)
return self.get_proxy(name, namespace=namespace)
def save_state(self) -> Mapping[str, Any]: ...
def load_state(self, state: Mapping[str, Any]) -> None: ...
def agent_metadata(self, agent: AgentId) -> AgentMetadata: ...
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: ...
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: ...

View File

@@ -0,0 +1,106 @@
import warnings
from abc import ABC, abstractmethod
from asyncio import Future
from typing import Any, Mapping, Sequence
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_runtime import AgentRuntime
from ._cancellation_token import CancellationToken
class BaseAgent(ABC, Agent):
@property
def metadata(self) -> AgentMetadata:
assert self._id is not None
return AgentMetadata(
namespace=self._id.namespace,
name=self._id.name,
description=self._description,
subscriptions=self._subscriptions,
)
def __init__(self, description: str, subscriptions: Sequence[type]) -> None:
self._runtime: AgentRuntime | None = None
self._id: AgentId | None = None
self._description = description
self._subscriptions = subscriptions
def bind_runtime(self, runtime: AgentRuntime) -> None:
if self._runtime is not None:
raise RuntimeError("Agent has already been bound to a runtime.")
self._runtime = runtime
def bind_id(self, agent_id: AgentId) -> None:
if self._id is not None:
raise RuntimeError("Agent has already been bound to an id.")
self._id = agent_id
@property
def name(self) -> str:
return self.id.name
@property
def id(self) -> AgentId:
if self._id is None:
raise RuntimeError("Agent has not been bound to an id.")
return self._id
@property
def runtime(self) -> AgentRuntime:
if self._runtime is None:
raise RuntimeError("Agent has not been bound to a runtime.")
return self._runtime
@abstractmethod
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ...
# Returns the response of the message
def send_message(
self,
message: Any,
recipient: AgentId,
*,
cancellation_token: CancellationToken | None = None,
) -> Future[Any]:
if self._runtime is None:
raise RuntimeError("Agent has not been bound to a runtime.")
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._runtime.send_message(
message,
sender=self.id,
recipient=recipient,
cancellation_token=cancellation_token,
)
cancellation_token.link_future(future)
return future
def publish_message(
self,
message: Any,
*,
cancellation_token: CancellationToken | None = None,
) -> Future[None]:
if self._runtime is None:
raise RuntimeError("Agent has not been bound to a runtime.")
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._runtime.publish_message(message, sender=self.id, cancellation_token=cancellation_token)
return future
def save_state(self) -> Mapping[str, Any]:
warnings.warn("save_state not implemented", stacklevel=2)
return {}
    def load_state(self, state: Mapping[str, Any]) -> None:
        warnings.warn("load_state not implemented", stacklevel=2)
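# A minimal concrete agent sketch (_EchoAgent is an illustrative addition):
# implementations supply a description, their subscriptions, and on_message.
class _EchoAgent(BaseAgent):
    def __init__(self) -> None:
        super().__init__(description="Echoes messages back.", subscriptions=[str])
    async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
        return message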

View File

@@ -0,0 +1,39 @@
import threading
from asyncio import Future
from typing import Any, Callable, List
class CancellationToken:
def __init__(self) -> None:
self._cancelled: bool = False
self._lock: threading.Lock = threading.Lock()
self._callbacks: List[Callable[[], None]] = []
def cancel(self) -> None:
with self._lock:
if not self._cancelled:
self._cancelled = True
for callback in self._callbacks:
callback()
def is_cancelled(self) -> bool:
with self._lock:
return self._cancelled
def add_callback(self, callback: Callable[[], None]) -> None:
with self._lock:
if self._cancelled:
callback()
else:
self._callbacks.append(callback)
def link_future(self, future: Future[Any]) -> None:
with self._lock:
if self._cancelled:
future.cancel()
else:
def _cancel() -> None:
future.cancel()
self._callbacks.append(_cancel)
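# A small sketch of the token/future linkage (illustrative only): cancelling the
# token cancels every linked future; drive with asyncio.run(_example_cancellation()).
async def _example_cancellation() -> None:
    future: Future[int] = Future()
    token = CancellationToken()
    token.link_future(future)
    token.cancel()
    assert future.cancelled()
    assert token.is_cancelled()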

View File

@@ -0,0 +1,17 @@
__all__ = [
"CantHandleException",
"UndeliverableException",
"MessageDroppedException",
]
class CantHandleException(Exception):
"""Raised when a handler can't handle the exception."""
class UndeliverableException(Exception):
"""Raised when a message can't be delivered."""
class MessageDroppedException(Exception):
"""Raised when a message is dropped."""

View File

@@ -0,0 +1,36 @@
from typing import Any, Awaitable, Callable, Protocol, final
from agnext.core import AgentId
__all__ = [
"DropMessage",
"InterventionFunction",
"InterventionHandler",
"DefaultInterventionHandler",
]
@final
class DropMessage: ...
InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]]
class InterventionHandler(Protocol):
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
async def on_response(
self, message: Any, *, sender: AgentId, recipient: AgentId | None
) -> Any | type[DropMessage]: ...
class DefaultInterventionHandler(InterventionHandler):
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
return message
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
return message
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
return message
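# A sketch of a custom handler (an illustrative addition) that vetoes direct
# sends containing a marker string; returning DropMessage drops the message.
class _BlockMarkedMessages(DefaultInterventionHandler):
    async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
        if isinstance(message, str) and "do-not-send" in message:
            return DropMessage
        return message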

View File