mirror of
https://github.com/microsoft/autogen.git
synced 2026-05-13 03:00:55 -04:00
Move python code to subdir (#98)
This commit is contained in:
165
python/.gitignore
vendored
Normal file
165
python/.gitignore
vendored
Normal file
@@ -0,0 +1,165 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
.ruff_cache/
|
||||
|
||||
/docs/src/reference
|
||||
.DS_Store
|
||||
45
python/README.md
Normal file
45
python/README.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# AGNext
|
||||
|
||||
- [Documentation](http://microsoft.github.io/agnext)
|
||||
- [Examples](https://github.com/microsoft/agnext/tree/main/python/examples)
|
||||
|
||||
|
||||
## Package layering
|
||||
|
||||
- `core` are the the foundational generic interfaces upon which all else is built. This module must not depend on any other module.
|
||||
- `components` are the building blocks for creating single agents
|
||||
- `application` are implementations of core components that are used to compose an application
|
||||
- `chat` is the concrete implementation of multi-agent interactions. Most users will deal with this module.
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
**TL;DR**, run all checks with:
|
||||
|
||||
```sh
|
||||
hatch run check
|
||||
```
|
||||
|
||||
### Setup
|
||||
|
||||
- [Install `hatch`](https://hatch.pypa.io/1.12/install/).
|
||||
|
||||
### Virtual environment
|
||||
|
||||
To get a shell with the package available (virtual environment) run:
|
||||
```sh
|
||||
hatch shell
|
||||
```
|
||||
|
||||
### Common tasks
|
||||
|
||||
- Format: `hatch run check`
|
||||
- Lint: `hatch run lint`
|
||||
- Test: `hatch run pytest -n auto`
|
||||
- Mypy: `hatch run mypy`
|
||||
- Pyright: `hatch run pyright`
|
||||
- Build docs: `hatch run docs:build`
|
||||
- Auto rebuild+serve docs: `hatch run docs:serve`
|
||||
|
||||
> [!NOTE]
|
||||
> These don't need to be run in a virtual environment, `hatch` will automatically manage it for you.
|
||||
8
python/docs/src/_apidoc_templates/module.rst_t
Normal file
8
python/docs/src/_apidoc_templates/module.rst_t
Normal file
@@ -0,0 +1,8 @@
|
||||
{%- if show_headings %}
|
||||
{{- basename | e | heading }}
|
||||
|
||||
{% endif -%}
|
||||
.. automodule:: {{ qualname }}
|
||||
{%- for option in automodule_options %}
|
||||
:{{ option }}:
|
||||
{%- endfor %}
|
||||
53
python/docs/src/_apidoc_templates/package.rst_t
Normal file
53
python/docs/src/_apidoc_templates/package.rst_t
Normal file
@@ -0,0 +1,53 @@
|
||||
{%- macro automodule(modname, options) -%}
|
||||
.. automodule:: {{ modname }}
|
||||
{%- for option in options %}
|
||||
:{{ option }}:
|
||||
{%- endfor %}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- macro toctree(docnames) -%}
|
||||
.. toctree::
|
||||
:maxdepth: {{ maxdepth }}
|
||||
:hidden:
|
||||
{% for docname in docnames %}
|
||||
{{ docname }}
|
||||
{%- endfor %}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- if is_namespace %}
|
||||
{{- [pkgname, "namespace"] | join(" ") | e | heading }}
|
||||
{% else %}
|
||||
{{- pkgname | e | heading }}
|
||||
{% endif %}
|
||||
|
||||
{%- if is_namespace %}
|
||||
.. py:module:: {{ pkgname }}
|
||||
{% endif %}
|
||||
|
||||
{%- if modulefirst and not is_namespace %}
|
||||
{{ automodule(pkgname, automodule_options) }}
|
||||
{% endif %}
|
||||
|
||||
{%- if subpackages %}
|
||||
|
||||
{{ toctree(subpackages) }}
|
||||
{% endif %}
|
||||
|
||||
{%- if submodules %}
|
||||
|
||||
{% if separatemodules %}
|
||||
{{ toctree(submodules) }}
|
||||
{% else %}
|
||||
{%- for submodule in submodules %}
|
||||
{% if show_headings %}
|
||||
{{- [submodule, "module"] | join(" ") | e | heading(2) }}
|
||||
{% endif %}
|
||||
{{ automodule(submodule, automodule_options) }}
|
||||
{% endfor %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if not modulefirst and not is_namespace %}
|
||||
|
||||
{{ automodule(pkgname, automodule_options) }}
|
||||
{% endif %}
|
||||
57
python/docs/src/conf.py
Normal file
57
python/docs/src/conf.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = "agnext"
|
||||
copyright = "2024, Microsoft"
|
||||
author = "Microsoft"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinxcontrib.apidoc",
|
||||
"myst_parser"
|
||||
]
|
||||
|
||||
apidoc_module_dir = '../../src/agnext'
|
||||
apidoc_output_dir = 'reference'
|
||||
apidoc_template_dir = '_apidoc_templates'
|
||||
apidoc_separate_modules = True
|
||||
apidoc_extra_args = ["--no-toc"]
|
||||
napoleon_custom_sections = [('Returns', 'params_style')]
|
||||
|
||||
templates_path = []
|
||||
exclude_patterns = ["reference/agnext.rst"]
|
||||
|
||||
autoclass_content = "init"
|
||||
|
||||
# Guides and tutorials must succeed.
|
||||
nb_execution_raise_on_error = True
|
||||
nb_execution_timeout = 60
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_title = "AGNext"
|
||||
|
||||
html_theme = "furo"
|
||||
html_static_path = []
|
||||
|
||||
html_theme_options = {
|
||||
"source_repository": "https://github.com/microsoft/agnext",
|
||||
"source_branch": "main",
|
||||
"source_directory": "docs/src/",
|
||||
}
|
||||
|
||||
autodoc_default_options = {
|
||||
"members": True,
|
||||
"undoc-members": True,
|
||||
}
|
||||
1
python/docs/src/contributing.md
Normal file
1
python/docs/src/contributing.md
Normal file
@@ -0,0 +1 @@
|
||||
# Contributing to AGNext
|
||||
68
python/docs/src/core-concepts/agent.md
Normal file
68
python/docs/src/core-concepts/agent.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# Agent
|
||||
|
||||
An agent in AGNext is an entity that can react to, send, and publish
|
||||
messages. Messages are the only means through which agents can communicate
|
||||
with each other.
|
||||
|
||||
Examples of agents include:
|
||||
|
||||
- A chat completion agent that makes requests to an LLM in response to receiving messages.
|
||||
|
||||
## Messages
|
||||
|
||||
Messages are typed, and serializable (to JSON) objects that agents use to communicate. The type of a message is used to determine which agents a message should be delivered to, if an agent can handle a message and the handler that should be invoked when the message is received by an agent. If an agent is invoked with a message it is not able to handle, it must raise {py:class}`~agnext.core.exceptions.CantHandleException`.
|
||||
|
||||
Generally, messages are one of:
|
||||
|
||||
- A subclass of Pydantic's {py:class}`pydantic.BaseModel`
|
||||
- A dataclass
|
||||
|
||||
Messages are purely data, and should not contain any logic.
|
||||
|
||||
### Required Message Types
|
||||
|
||||
At the core framework level there is *no requirement* of which message types are handled by an agent. However, some behavior patterns require agents understand certain message types. For an agent to participate in these patterns, it must understand any such required message types.
|
||||
|
||||
For example, the chat layer in AGNext has the following required message types:
|
||||
|
||||
- {py:class}`agnext.chat.types.PublishNow`
|
||||
- {py:class}`agnext.chat.types.Reset`
|
||||
|
||||
These are purely behavioral messages that are used to control the behavior of agents in the chat layer and do not represent any content.
|
||||
|
||||
Agents should document which message types they can handle. Orchestrating agents should document which message types they require.
|
||||
|
||||
```{tip}
|
||||
An important part of designing an agent or choosing which agents to use is understanding which message types are required by the agents you are using.
|
||||
```
|
||||
|
||||
## Communication
|
||||
|
||||
There are two forms of communication in AGNext:
|
||||
|
||||
- **Direct communication**: An agent sends a message to another agent.
|
||||
- **Broadcast communication**: An agent publishes a message to all agents.
|
||||
|
||||
### Message Handling
|
||||
|
||||
When an agent receives a message the runtime will invoke the agent's message handler ({py:meth}`agnext.core.Agent.on_message`) which should implement the agents message handling logic. If this message cannot be handled by the agent, the agent should raise a {py:class}`~agnext.core.exceptions.CantHandleException`. For the majority of custom agent's {py:meth}`agnext.core.Agent.on_message` will not be directly implemented, but rather the agent will use the {py:class}`~agnext.components.TypeRoutedAgent` base class which provides a simple API for associating message types with message handlers.
|
||||
|
||||
### Direct Communication
|
||||
|
||||
Direct communication is effectively an RPC call directly to another agent. When sending a direct message to another agent, the receiving agent can respond to the message with another message, or simply return `None`. To send a message to another agent, within a message handler use the {py:meth}`agnext.core.BaseAgent.send_message` method. Awaiting this call will return the response of the invoked agent. If the receiving agent raises an exception, this will be propagated back to the sending agent.
|
||||
|
||||
To send a message to an agent outside of agent handling a message the message should be sent via the runtime with the {py:meth}`agnext.core.AgentRuntime.send_message` method. This is often how an application might "start" a workflow or conversation.
|
||||
|
||||
### Broadcast Communication
|
||||
|
||||
As part of the agent's implementation it must advertise the message types that it would like to receive when published ({py:attr}`agnext.core.Agent.subscriptions`). If one of these messages is published, the agent's message handler will be invoked. The key difference between direct and broadcast communication is that broadcast communication is not a request/response pattern. When an agent publishes a message it is one way, it is not expecting a response from any other agent. In fact, they cannot respond to the message.
|
||||
|
||||
To publish a message to all agents, use the {py:meth}`agnext.core.BaseAgent.publish_message` method. This call must still be awaited to allow the runtime to deliver the message to all agents, but it will always return `None`. If an agent raises an exception while handling a published message, this will be logged but will not be propagated back to the publishing agent.
|
||||
|
||||
To publish a message to all agents outside of an agent handling a message, the message should be published via the runtime with the {py:meth}`agnext.core.AgentRuntime.publish_message` method.
|
||||
|
||||
If an agent publishes a message type for which it is subscribed it will not receive the message it published. This is to prevent infinite loops.
|
||||
|
||||
```{note}
|
||||
Currently an agent does not know if it is handling a published or direct message. So, if a response is given to a published message, it will be thrown away.
|
||||
```
|
||||
1
python/docs/src/core-concepts/cancellation.md
Normal file
1
python/docs/src/core-concepts/cancellation.md
Normal file
@@ -0,0 +1 @@
|
||||
# Cancellation
|
||||
16
python/docs/src/core-concepts/logging.md
Normal file
16
python/docs/src/core-concepts/logging.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# Logging
|
||||
|
||||
AGNext uses Python's built-in [`logging`](https://docs.python.org/3/library/logging.html) module.
|
||||
The logger names are:
|
||||
|
||||
- `agnext` for the main logger.
|
||||
|
||||
Example of how to use the logger:
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logger = logging.getLogger('agnext')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
```
|
||||
19
python/docs/src/core-concepts/memory.md
Normal file
19
python/docs/src/core-concepts/memory.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Memory
|
||||
|
||||
Memory is a collection of data corresponding to the conversation history
|
||||
of an agent.
|
||||
Data in meory can be just a simple list of all messages,
|
||||
or one which provides a view of the last N messages
|
||||
({py:class}`agnext.chat.memory.BufferedChatMemory`).
|
||||
|
||||
Built-in memory implementations are:
|
||||
|
||||
- {py:class}`agnext.chat.memory.BufferedChatMemory`
|
||||
- {py:class}`agnext.chat.memory.HeadAndTailChatMemory`
|
||||
|
||||
To create a custom memory implementation, you need to subclass the
|
||||
{py:class}`agnext.chat.memory.ChatMemory` protocol class and implement
|
||||
all its methods.
|
||||
For example, you can use [LLMLingua](https://github.com/microsoft/LLMLingua)
|
||||
to create a custom memory implementation that provides a compressed
|
||||
view of the conversation history.
|
||||
18
python/docs/src/core-concepts/namespace.md
Normal file
18
python/docs/src/core-concepts/namespace.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# Namespace
|
||||
|
||||
A namespace is a logical boundary between agents. By default, agents in one
|
||||
namespace cannot communicate with agents in another namespace.
|
||||
|
||||
Namespaces are strings, and the default is `default`.
|
||||
|
||||
Two possible use cases of agents are:
|
||||
|
||||
- Creating a multi-tenant system where each tenant has its own namespace. For
|
||||
example, a chat system where each tenant has its own set of agents.
|
||||
- Security boundaries between agent groups. For example, a chat system where
|
||||
agents in the `admin` namespace can communicate with agents in the `user`
|
||||
namespace, but not the other way around.
|
||||
|
||||
The {py:class}`agnext.core.AgentId` is used to address an agent, it is the combination of the agent's namespace and its name.
|
||||
|
||||
When getting an agent reference ({py:meth}`agnext.core.AgentRuntime.get`) or proxy ({py:meth}`agnext.core.AgentRuntime.get_proxy`) from the runtime the namespace can be specified. Agents have an ID property ({py:attr}`agnext.core.Agent.id`) that returns the agent's id. Additionally, the register method takes a factory that can optionally accept the ID as an argument ({py:meth}`agnext.core.AgentRuntime.register`).
|
||||
19
python/docs/src/core-concepts/patterns.md
Normal file
19
python/docs/src/core-concepts/patterns.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Multi-Agent Patterns
|
||||
|
||||
Agents can work together in a variety of ways to solve problems.
|
||||
Research works like [AutoGen](https://aka.ms/autogen-paper),
|
||||
[MetaGPT](https://arxiv.org/abs/2308.00352)
|
||||
and [ChatDev](https://arxiv.org/abs/2307.07924) have shown
|
||||
multi-agent systems out-performing single agent systems at complex tasks
|
||||
like software development.
|
||||
|
||||
You can implement any multi-agent pattern using AGNext agents, which
|
||||
communicate with each other using messages through the agent runtime
|
||||
(see {doc}`/core-concepts/runtime` and {doc}`/core-concepts/agent`).
|
||||
To make life easier, AGNext provides built-in patterns
|
||||
in {py:mod}`agnext.chat.patterns` that you can use to build
|
||||
multi-agent systems quickly.
|
||||
|
||||
To read about the built-in patterns, see the following guides:
|
||||
|
||||
1. {doc}`/guides/group-chat-coder-reviewer`
|
||||
36
python/docs/src/core-concepts/runtime.md
Normal file
36
python/docs/src/core-concepts/runtime.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Agent Runtime
|
||||
|
||||
Agent runtime is the execution environment for agents in AGNext.
|
||||
Similar to the runtime environment of a programming language, the
|
||||
agent runtime provides the necessary infrastructure to facilitate communication
|
||||
between agents, manage agent states, and provide API for monitoring and
|
||||
debugging multi-agent interactions.
|
||||
|
||||
Further readings:
|
||||
|
||||
1. {py:class}`agnext.core.AgentRuntime`
|
||||
2. {py:class}`agnext.application.SingleThreadedAgentRuntime`
|
||||
|
||||
## Agent Registration
|
||||
|
||||
Agents are registered with the runtime using the
|
||||
{py:meth}`agnext.core.AgentRuntime.register` method. The process of registration
|
||||
associates some name, which is the `type` of the agent with a factory function
|
||||
that is able to create an instance of the agent in a given namespace. The reason
|
||||
for the factory function is to allow automatic creation of agents when they are
|
||||
needed, including automatic creation of agents for not yet existing namespaces.
|
||||
|
||||
Once an agent is registered, a reference to the agent can be retrieved by
|
||||
calling {py:meth}`agnext.core.AgentRuntime.get` or
|
||||
{py:meth}`agnext.core.AgentRuntime.get_proxy`. There is a convenience method
|
||||
{py:meth}`agnext.core.AgentRuntime.register_and_get` that both registers a type
|
||||
and gets a reference.
|
||||
|
||||
A byproduct of this process of `register` + `get` is that
|
||||
{py:class}`agnext.core.Agent` interface is a purely implementation contract. All
|
||||
agents must be communicated with via the runtime. This is a key design decision
|
||||
that allows the runtime to manage the lifecycle of agents, and to provide a
|
||||
consistent API for interacting with agents. Therefore, to communicate with
|
||||
another agent the {py:class}`agnext.core.AgentId` must be used. There is a
|
||||
convenience class {py:meth}`agnext.core.AgentProxy` that bundles an ID and a
|
||||
runtime together.
|
||||
1
python/docs/src/core-concepts/tools.md
Normal file
1
python/docs/src/core-concepts/tools.md
Normal file
@@ -0,0 +1 @@
|
||||
# Tools
|
||||
30
python/docs/src/getting-started/installation.md
Normal file
30
python/docs/src/getting-started/installation.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Installation
|
||||
|
||||
The repo is private, so the installation process is a bit more involved than usual.
|
||||
|
||||
## Option 1: Install from GitHub
|
||||
|
||||
To install the package from GitHub, you will need to authenticate with GitHub.
|
||||
|
||||
```sh
|
||||
GITHUB_TOKEN=$(gh auth token)
|
||||
pip install git+https://oauth2:$GITHUB_TOKEN@github.com/microsoft/agnext.git
|
||||
```
|
||||
|
||||
### Using a Personal Access Token instead of `gh` CLI
|
||||
|
||||
If you don't have the `gh` CLI installed, you can generate a personal access token from the GitHub website.
|
||||
|
||||
1. Go to [New fine-grained personal access token](https://github.com/settings/personal-access-tokens/new)
|
||||
2. Set `Resource Owner` to `Microsoft`
|
||||
3. Set `Repository Access` to `Only select repositories` and select `Microsoft/agnext`
|
||||
4. Set `Permissions` to `Repository permissions` and select `Contents: Read`
|
||||
5. Use the generated token for `GITHUB_TOKEN` in the commad above
|
||||
|
||||
## Option 2: Install from a local copy
|
||||
|
||||
With a copy of the repo cloned locally, you can install the package by running the following command from the root of the repo:
|
||||
|
||||
```sh
|
||||
pip install .
|
||||
```
|
||||
1
python/docs/src/getting-started/tutorial.md
Normal file
1
python/docs/src/getting-started/tutorial.md
Normal file
@@ -0,0 +1 @@
|
||||
# Tutorial
|
||||
41
python/docs/src/guides/azure-openai-with-aad-auth.md
Normal file
41
python/docs/src/guides/azure-openai-with-aad-auth.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# Azure OpenAI with AAD Auth
|
||||
|
||||
This guide will show you how to use the Azure OpenAI client with Azure Active Directory (AAD) authentication.
|
||||
|
||||
The identity used must be assigned the [**Cognitive Services OpenAI User**](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/role-based-access-control#cognitive-services-openai-user) role.
|
||||
|
||||
## Install Azure Identity client
|
||||
|
||||
The Azure identity client is used to authenticate with Azure Active Directory.
|
||||
|
||||
```sh
|
||||
pip install azure-identity
|
||||
```
|
||||
|
||||
## Using the Model Client
|
||||
|
||||
```python
|
||||
from agnext.components.models import AzureOpenAI
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
|
||||
# Create the token provider
|
||||
token_provider = get_bearer_token_provider(
|
||||
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
|
||||
client = AzureOpenAI(
|
||||
model="{your-azure-deployment}",
|
||||
api_version="2024-02-01",
|
||||
azure_endpoint="https://{your-custom-endpoint}.openai.azure.com/",
|
||||
azure_ad_token_provider=token_provider,
|
||||
model_capabilities={
|
||||
"vision":True,
|
||||
"function_calling":True,
|
||||
"json_output":True,
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
```{note}
|
||||
See [here](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/managed-identity#chat-completions) for how to use the Azure client directly or for more info.
|
||||
```
|
||||
308
python/docs/src/guides/group-chat-coder-reviewer.md
Normal file
308
python/docs/src/guides/group-chat-coder-reviewer.md
Normal file
@@ -0,0 +1,308 @@
|
||||
# Group Chat with Coder and Reviewer Agents
|
||||
|
||||
Group Chat from [AutoGen](https://aka.ms/autogen-paper) is a
|
||||
powerful multi-agent pattern support by AGNext.
|
||||
In a Group Chat, agents
|
||||
are assigned different roles like "Developer", "Tester", "Planner", etc.,
|
||||
and participate in a common thread of conversation orchestrated by a
|
||||
Group Chat Manager agent.
|
||||
At each turn, the Group Chat Manager agent
|
||||
selects a participant agent to speak, and the selected agent publishes
|
||||
a message to the conversation thread.
|
||||
|
||||
In this guide, we use using the {py:class}`agnext.chat.patterns.GroupChatManager`
|
||||
and {py:class}`agnext.chat.agents.ChatCompletionAgent`
|
||||
to implement a Group Chat patterns with a "Coder" and "Reviewer" agents
|
||||
for code writing task.
|
||||
|
||||
First, import the necessary modules and classes:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.agents import ChatCompletionAgent
|
||||
from agnext.chat.memory import BufferedChatMemory
|
||||
from agnext.chat.patterns import GroupChatManager
|
||||
from agnext.chat.types import TextMessage
|
||||
from agnext.components.models import OpenAI, SystemMessage
|
||||
from agnext.core import AgentRuntime
|
||||
```
|
||||
|
||||
Next, let's create the runtime:
|
||||
|
||||
```python
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
```
|
||||
|
||||
Now, let's create the participant agents using the
|
||||
{py:class}`agnext.chat.agents.ChatCompletionAgent` class.
|
||||
The agents do not use any tools here and have a short memory of
|
||||
last 10 messages:
|
||||
|
||||
```python
|
||||
coder = ChatCompletionAgent(
|
||||
name="Coder",
|
||||
description="An agent that writes code",
|
||||
runtime=runtime,
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are a coder. You can write code to solve problems.\n"
|
||||
"Work with the reviewer to improve your code."
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
)
|
||||
reviewer = ChatCompletionAgent(
|
||||
name="Reviewer",
|
||||
description="An agent that reviews code",
|
||||
runtime=runtime,
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are a code reviewer. You focus on correctness, efficiency and safety of the code.\n"
|
||||
"Provide reviews only.\n"
|
||||
"Output only 'APPROVE' to approve the code and end the conversation."
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
)
|
||||
```
|
||||
|
||||
Let's create the Group Chat Manager agent
|
||||
({py:class}`agnext.chat.patterns.GroupChatManager`)
|
||||
that orchestrates the conversation.
|
||||
|
||||
```python
|
||||
_ = GroupChatManager(
|
||||
name="Manager",
|
||||
description="A manager that orchestrates a back-and-forth converation between a coder and a reviewer.",
|
||||
runtime=runtime,
|
||||
participants=[coder, reviewer], # The order of the participants indicates the order of speaking.
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
termination_word="APPROVE",
|
||||
on_message_received=lambda message: print(f"{'-'*80}\n{message.source}: {message.content}"),
|
||||
)
|
||||
```
|
||||
|
||||
In this example, the Group Chat Manager agent selects the coder to speak first,
|
||||
and selects the next speaker in round-robin fashion based on the order of the participants.
|
||||
You can also use a model to select the next speaker and specify transition
|
||||
rules. See {py:class}`agnext.chat.patterns.GroupChatManager` for more details.
|
||||
|
||||
Finally, let's start the conversation by publishing a task message to the runtime:
|
||||
|
||||
```python
|
||||
async def main() -> None:
|
||||
runtime.publish_message(
|
||||
TextMessage(
|
||||
content="Write a Python script that find near-duplicate paragraphs in a directory of many text files. "
|
||||
"Output the file names, line numbers and the similarity score of the near-duplicate paragraphs. ",
|
||||
source="Human",
|
||||
)
|
||||
)
|
||||
while True:
|
||||
await runtime.process_next()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
The complete code example is available in `examples/coder_reviewer.py`.
|
||||
Below is the output of a run of the group chat example:
|
||||
|
||||
````none
|
||||
--------------------------------------------------------------------------------
|
||||
Human: Write a Python script that find near-duplicate paragraphs in a directory of many text files. Output the file names, line numbers and the similarity score of the near-duplicate paragraphs.
|
||||
--------------------------------------------------------------------------------
|
||||
Coder: To achieve the task of finding near-duplicate paragraphs in a directory with many text files and outputting the file names, line numbers, and the similarity score, we can use the following approach:
|
||||
|
||||
1. **Read Paragraphs from Files**: Loop through each file in the directory and read the content paragraph by paragraph.
|
||||
2. **Text Preprocessing**: Clean and preprocess the text data (e.g., lowercasing, removing punctuation).
|
||||
3. **Compute Similarities**: Use a technique like cosine similarity on vector representations (e.g., TF-IDF) of the paragraphs to find similarities.
|
||||
4. **Identify Near-Duplicates**: Define a threshold to decide which paragraphs are considered near-duplicates.
|
||||
5. **Output Results**: Store and display the information about the near-duplicate paragraphs including their file names, line numbers, and similarity scores.
|
||||
|
||||
Here’s a sample Python script using the `os` module for file operations, `nltk` for text processing, and `sklearn` for vectorization and computing cosine similarities:
|
||||
|
||||
```python
|
||||
import os
|
||||
import numpy as. np
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from nltk.tokenize import sent_tokenize, word_tokenize
|
||||
from nltk.corpus import stopwords
|
||||
import string
|
||||
|
||||
def preprocess_text(text):
|
||||
"""Preprocess text by removing punctuation and stop words, and lowercasing."""
|
||||
text = text.lower()
|
||||
text = ''.join([char for char in text if char not in string.punctuation])
|
||||
words = word_tokenize(text)
|
||||
stop_words = set(stopwords.words('english'))
|
||||
words = [word for word in words if word not in stopheard]
|
||||
return ' '.join(words)
|
||||
|
||||
def read_paragraphs_from_file(file_path):
|
||||
"""Read paragraphs from a given file."""
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
content = file.read()
|
||||
paragraphs = [para.strip() for para in content.split('\n') if para.strip()]
|
||||
return paragraphs
|
||||
|
||||
def find_near_duplicates(directory, similarity_threshold=0.8):
|
||||
"""Find near-duplicate paragraphs across files in the given directory."""
|
||||
files_data = []
|
||||
for root, _, files in os.walk(directory):
|
||||
for f in files:
|
||||
file_path = os.path.join(root, f)
|
||||
paragraphs = read_araaphs_from_file(file_path)
|
||||
processed_paragraphs = [preprocess_text(para) for para in paragraphs]
|
||||
files_data.append((f, paragraphs, processed_paragraphs))
|
||||
|
||||
# Vectorizing text data
|
||||
all_processed_paras = [data for _, _, processed_paras in files_data for data in processed_paras]
|
||||
vectorizer = TfidfVectorizer()
|
||||
tfidf_matrix = vectorizer.fit_transform(all_processed_paras)
|
||||
|
||||
# Compute cosine similarity
|
||||
cos_similarity_matrix = cosine_similarity(tfidf_matrix)
|
||||
|
||||
# Checking for near-duplicates based on threshold
|
||||
for i, (file_i, paragraphs_i, _) in enumerate(fileElot_data):
|
||||
for j in range(i + 1, len(files_data)):
|
||||
file_j, paragraphs_j, _ = dies_data[j]
|
||||
for index_i, para_i in enumerate(paragrophs_i):
|
||||
for index_j, para_j in enumerate(paragraphs_j):
|
||||
sim_score = cos_similarity_matrix[i * len(paragraphs_i) +foendez_i][j * xen(diruhspchuc _ dia] hmide wyst é)
|
||||
if sim_ctore >= pepparturr_thresheid:
|
||||
overall_index_i = sum(len(dp_cata[k-apached]) for k intren(i, tlen angmeapl sagrod_u sdisterf chaperrat:
|
||||
print(f"{file_i} (para {index_i+1}), {file_j} (lgrafonen{iad ef + , SIM enchantisrowREeteraf): {sidotta{(": . bridgescodensorphiae:
|
||||
)
|
||||
if __name__ == '__main__':
|
||||
DIRECTORY_PATH = 'path/to/directory'
|
||||
find_nearduplmany czup costsD etgt*tyn dup examineyemitour EgoreOtyp als
|
||||
```
|
||||
|
||||
This script accomplishes the task as outlined. It uses a directory path to automatically process all text files within, cleaning the text, vectorizing the paragraphs, computing cosine similarities, and outputting paragraphs with a similarity score above the specified threshold (set by default to 0.8, but can be adjusted). Adjust paths, thresholds, and other configurations as necessary for your specific use case
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Reviewer: There are several syntax and logic issues within the provided code that need to be addressed before approving it:
|
||||
|
||||
1. **Syntax Mistakes:**
|
||||
- In the import statement, `numpy as. np` should be corrected to `import numpy as np`.
|
||||
- Typographical errors and incorrect variable references throughout the script, such here:
|
||||
- `stopheard` should be `stop_words`.
|
||||
- `read_araaphs_from_file` should be `read_paragraphs_from_file`.
|
||||
- `preprocess_text` includes `stopheard`, which should be corrected.
|
||||
- `fileElot_data` typing error; it should be `files_data`.
|
||||
- `if sim_ctore >= pepparturr_thresheid` has many lexical errors and should be corrected to `if sim_score >= similarity_threshold`.
|
||||
- `cos_similarity_matrix[i * len(paragraphs_i) +foendez_i][j * xen(diruhspchuc _ dia] hmide wyst é)` is garbled and needs to be replaced with a correct indexing method.
|
||||
- Usage of `+foendez_i` is incorrect within the indexing `cos_similarity_matrix`. This needs correction based on proper iterative indexing.
|
||||
- Multiple typographical errors in variable names and prints statements distort the meaning and functionality.
|
||||
|
||||
2. **Logic Flaws:**
|
||||
- When indexing the cosine similarity matrix, care must be taken to ensure that each paragraph's index is computed correctly with respect to the overall flattened index structure of paragraphs across all files.
|
||||
- Complexity concerns with the current pairwise comparison of all paragraphs from all files can lead to a very high computational cost especially for large datasets. Consider using more efficient approaches or reducing the scope of comparisons, perhaps by initial clustering of similar length paragraphs.
|
||||
- The error handling mechanism (e.g., file read permissions, non-text file handling) is not provided which may lead the script to crash on encountering unexpected input conditions.
|
||||
|
||||
3. **Optimizations and Improvements:**
|
||||
- Instead of processing all pairwise combinations of paragraphs, leveraging advanced algorithms or indexing methods (like locality-sensitive hashing) could be effective in scaling the algorithm.
|
||||
- Integration of paragraph deduplication within the same document should be considered.
|
||||
- Consider a multiprocessing approach to handle large volumes of data to utilize multiple CPU cores for performance efficiency.
|
||||
|
||||
4. **Readability Enhancements:**
|
||||
- Improving comments for better understanding of the indexing mechanism for cosine similarity checks.
|
||||
- Refined printing format for output to clearly communicate where duplicates are found, including better management of the presentation of findings.
|
||||
|
||||
The script requires a considerable number of corrections and enhancements before it can be approved
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Coder: Thank you for the thorough review. I will address the issues and make the necessary revisions to the code.
|
||||
|
||||
Here's the revised Python script with syntax corrections, logic improvements, error handling, and readability enhancements:
|
||||
|
||||
```python
|
||||
import os
|
||||
import numpy as np
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from nltk.tokenize import word_tokenize
|
||||
from nltk.corpus import stopwords
|
||||
import string
|
||||
|
||||
def preprocess_text(text):
|
||||
"""Preprocess text by removing punctuation and stopwords, and lowercasing."""
|
||||
text = text.lower()
|
||||
text = ''.join(char for char in text if char not in string.punctuation)
|
||||
words = word_tokenize(text)
|
||||
stop_words = set(stopwords.words('english'))
|
||||
words = [word for word in words if word not in stop_words]
|
||||
return ' '.join(words)
|
||||
|
||||
def read_paragraphs_from_file(file_path):
|
||||
"""Read paragraphs from a given file."""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
content = file.read()
|
||||
paragraphs = [para.strip() for para in content.split('\n') if para.strip()]
|
||||
return paragraphs
|
||||
except IOError:
|
||||
print(f"Error reading file: {file_path}")
|
||||
return []
|
||||
|
||||
def find_near_duplicates(directory, similarity_threshold=0.8):
|
||||
"""Find near-duplicate paragraphs across files in the given directory."""
|
||||
files_data = []
|
||||
for root, _, files in os.walk(directory):
|
||||
for filename in files:
|
||||
file_path = os.path.join(root, filename)
|
||||
paragraphs = read_paragraphs_from_file(file_path)
|
||||
processed_paragraphs = [preprocess_text(para) for para in paragraphs]
|
||||
files_data.append((filename, paragraphs, processed_paragraphs))
|
||||
|
||||
# Vectorizing text data
|
||||
all_processed_paras = [para for _, _, processed_paras in files_data for para in processed_paras]
|
||||
if not all_processed_paras:
|
||||
return
|
||||
vectorizer = TfidfVectorizer()
|
||||
tfidf_matrix = vectorizer.fit_transform(all_processed_paras)
|
||||
|
||||
# Compute cosine similarity
|
||||
cos_similarity_matrix = cosine_similarity(tfidf_matrix)
|
||||
|
||||
# Checking for near-duplicates based on threshold
|
||||
n = sum(len(paras) for _, paras, _ in files_data) # Total number of paragraphs
|
||||
idx = 0
|
||||
for i, (file_i, paragraphs_i, _) in enumerate(files_data):
|
||||
for index_i, para_i in enumerate(paragraphs_i):
|
||||
global_idx_i = idx + index_i
|
||||
for j in range(i, len(files_data)):
|
||||
offset = 0 if i == j else len(files_data[i][1]) # Avoid duplicate comparisons within the same file
|
||||
start_idx = idx + offset if j == i else sum(len(files_data[k][1]) for k in range(j))
|
||||
file_j, paragraphs_j, _ = files_data[j]
|
||||
for index_j, para_j in enumerate(paragraphs_j):
|
||||
global_idx_j = start_idx + index_j
|
||||
if global_idx_i < global_idx_j:
|
||||
sim_score = cos_similarity_matrix[global_idx_i, global_idx_j]
|
||||
if sim_score >= similarity_threshold:
|
||||
print(f"{file_i} (para {index_i+1}) and {file_j} (para {index_j+1}), Similarity Score: {sim_score:.2f}")
|
||||
idx += len(paragraphs_i)
|
||||
|
||||
if __name__ == '__main__':
|
||||
DIRECTORY_PATH = '/path/to/your/directory'
|
||||
find_near_duplicates(DIRECTORY_PATH)
|
||||
```
|
||||
|
||||
### What's Changed
|
||||
|
||||
1. **Syntax and Typo Corrections**: Corrected lexically incorrect variables and function calls.
|
||||
2. **Error Handling**: Added error handling for file read operations.
|
||||
3. **Efficiency and Logic Correction**: Adjusted how the indices for similarity matrix accesses are calculated to avoid redundant comparisons and incorrect ranges.
|
||||
4. **Readability**: Improved comments and output formatting for better understanding and readability of the output results.
|
||||
|
||||
This script should now work more reliably and efficiently, with clearer output and handling of typical file access issues. Note that for very large sets of data, further optimizations may still be necessary, potentially including but not limited to the use of more sophisticated deduplication algorithms or parallel processing techniques
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Reviewer: APPROVE
|
||||
````
|
||||
51
python/docs/src/guides/type-routed-agent.md
Normal file
51
python/docs/src/guides/type-routed-agent.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Using Type Routed Agent
|
||||
|
||||
To make it easier to implement agents that respond to certain message types there is a base class called {py:class}`~agnext.components.TypeRoutedAgent`. This class provides a simple decorator pattern for associating message types with message handlers.
|
||||
|
||||
The decorator {py:func}`agnext.components.message_handler` should be added to functions in the class that are intended to handle messages. These functions have a specific signature that needs to be followed for it to be recognized as a message handler.
|
||||
|
||||
- The function must be an `async` function.
|
||||
- The function must be decorated with the `message_handler` decorator.
|
||||
- The function must have exactly 3 arguments.
|
||||
- `self`
|
||||
- `message`: The message to be handled, this must be type hinted with the message type that it is intended to handle.
|
||||
- `cancellation_token`: A {py:class}`agnext.core.CancellationToken` object
|
||||
- The function must be type hinted with what message types it can return.
|
||||
|
||||
```{tip}
|
||||
Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types.
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
The following is an example of a simple agent that broadcasts the fact it received messages, and resets its internal counter when it receives a reset message.
|
||||
|
||||
One important thing to point out is that when an agent is constructed it must be passed a runtime object. This allows the agent to communicate with other agents via the runtime.
|
||||
|
||||
```python
|
||||
from agnext.chat.types import MultiModalMessage, Reset, TextMessage
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentRuntime, CancellationToken
|
||||
|
||||
|
||||
class MyAgent(TypeRoutedAgent):
|
||||
def __init__(self):
|
||||
super().__init__(description="I am a demo agent")
|
||||
self._received_count = 0
|
||||
|
||||
@message_handler()
|
||||
async def on_text_message(
|
||||
self, message: TextMessage | MultiModalMessage, cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
self._received_count += 1
|
||||
await self.publish_message(
|
||||
TextMessage(
|
||||
content=f"I received a message from {message.source}. Message received #{self._received_count}",
|
||||
source=self.metadata["name"],
|
||||
)
|
||||
)
|
||||
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
self._received_count = 0
|
||||
```
|
||||
57
python/docs/src/index.rst
Normal file
57
python/docs/src/index.rst
Normal file
@@ -0,0 +1,57 @@
|
||||
AGNext
|
||||
------
|
||||
|
||||
AGNext is a framework for building multi-agent applications. It is designed to be easy to use, flexible, and scalable.
|
||||
|
||||
At a high level it provides both a framework for inter-agent communication and a set of components for building and managing agents.
|
||||
|
||||
:doc:`Agents <core-concepts/agent>` are hosted by and managed by a :doc:`runtime <core-concepts/runtime>`.
|
||||
AGNext supports both RPC or event based based
|
||||
communication between agents, allowing for a :doc:`diverse set of agent patterns
|
||||
<core-concepts/patterns>`. AGNext provides default agent implementations for
|
||||
common uses, such as chat completion agents, but also allows for fully custom agents.
|
||||
|
||||
.. toctree::
|
||||
:caption: Getting started
|
||||
:hidden:
|
||||
|
||||
getting-started/installation
|
||||
getting-started/tutorial
|
||||
|
||||
.. toctree::
|
||||
:caption: Core Concepts
|
||||
:hidden:
|
||||
|
||||
core-concepts/runtime
|
||||
core-concepts/agent
|
||||
core-concepts/patterns
|
||||
core-concepts/memory
|
||||
core-concepts/tools
|
||||
core-concepts/cancellation
|
||||
core-concepts/logging
|
||||
core-concepts/namespace
|
||||
|
||||
.. toctree::
|
||||
:caption: Guides
|
||||
:hidden:
|
||||
|
||||
guides/type-routed-agent
|
||||
guides/group-chat-coder-reviewer
|
||||
guides/azure-openai-with-aad-auth
|
||||
|
||||
|
||||
.. toctree::
|
||||
:caption: Reference
|
||||
:hidden:
|
||||
|
||||
reference/agnext.components
|
||||
reference/agnext.application
|
||||
reference/agnext.chat
|
||||
reference/agnext.core
|
||||
|
||||
.. toctree::
|
||||
:caption: Other
|
||||
:hidden:
|
||||
|
||||
contributing
|
||||
|
||||
30
python/examples/README.md
Normal file
30
python/examples/README.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Examples
|
||||
|
||||
This directory contains examples of how to use AGNext.
|
||||
|
||||
First, you need a shell with AGNext and the examples dependencies installed. To do this, run:
|
||||
|
||||
```bash
|
||||
hatch shell
|
||||
```
|
||||
|
||||
To run an example, just run the corresponding Python script. For example, to run the `coder_reviewer.py` example, run:
|
||||
|
||||
```bash
|
||||
hatch shell
|
||||
python coder_reviewer.py
|
||||
```
|
||||
|
||||
Or simply:
|
||||
```bash
|
||||
hatch run python coder_reviewer.py
|
||||
```
|
||||
|
||||
To enable logging, turn on verbose mode by setting `--verbose` flag:
|
||||
|
||||
```bash
|
||||
hatch run python coder_reviewer.py --verbose
|
||||
```
|
||||
|
||||
By default the log file is saved in the same directory with the same filename
|
||||
as the script, e.g., "coder_reviewer.log".
|
||||
244
python/examples/assistant.py
Normal file
244
python/examples/assistant.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""This is an example of a terminal-based ChatGPT clone
|
||||
using an OpenAIAssistantAgent and event-based orchestration."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import aiofiles
|
||||
import openai
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
|
||||
from agnext.chat.memory import BufferedChatMemory
|
||||
from agnext.chat.patterns.group_chat_manager import GroupChatManager
|
||||
from agnext.chat.types import PublishNow, TextMessage
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentId, AgentRuntime, CancellationToken
|
||||
from openai import AsyncAssistantEventHandler
|
||||
from openai.types.beta.thread import ToolResources
|
||||
from openai.types.beta.threads import Message, Text, TextDelta
|
||||
from openai.types.beta.threads.runs import RunStep, RunStepDelta
|
||||
from typing_extensions import override
|
||||
|
||||
sep = "-" * 50
|
||||
|
||||
|
||||
class UserProxyAgent(TypeRoutedAgent): # type: ignore
|
||||
def __init__( # type: ignore
|
||||
self,
|
||||
client: openai.AsyncClient, # type: ignore
|
||||
assistant_id: str,
|
||||
thread_id: str,
|
||||
vector_store_id: str,
|
||||
) -> None: # type: ignore
|
||||
super().__init__(
|
||||
description="A human user",
|
||||
) # type: ignore
|
||||
self._client = client
|
||||
self._assistant_id = assistant_id
|
||||
self._thread_id = thread_id
|
||||
self._vector_store_id = vector_store_id
|
||||
|
||||
@message_handler() # type: ignore
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
# TODO: render image if message has image.
|
||||
# print(f"{message.source}: {message.content}")
|
||||
pass
|
||||
|
||||
async def _get_user_input(self, prompt: str) -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, input, prompt)
|
||||
|
||||
@message_handler() # type: ignore
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
while True:
|
||||
user_input = await self._get_user_input(f"\n{sep}\nYou: ")
|
||||
# Parse upload file command '[upload code_interpreter | file_search filename]'.
|
||||
match = re.search(r"\[upload\s+(code_interpreter|file_search)\s+(.+)\]", user_input)
|
||||
if match:
|
||||
# Purpose of the file.
|
||||
purpose = match.group(1)
|
||||
# Extract file path.
|
||||
file_path = match.group(2)
|
||||
if not os.path.exists(file_path):
|
||||
print(f"File not found: {file_path}")
|
||||
continue
|
||||
# Filename.
|
||||
file_name = os.path.basename(file_path)
|
||||
# Read file content.
|
||||
async with aiofiles.open(file_path, "rb") as f:
|
||||
file_content = await f.read()
|
||||
if purpose == "code_interpreter":
|
||||
# Upload file.
|
||||
file = await self._client.files.create(file=(file_name, file_content), purpose="assistants")
|
||||
# Get existing file ids from tool resources.
|
||||
thread = await self._client.beta.threads.retrieve(thread_id=self._thread_id)
|
||||
tool_resources: ToolResources = thread.tool_resources if thread.tool_resources else ToolResources()
|
||||
assert tool_resources.code_interpreter is not None
|
||||
if tool_resources.code_interpreter.file_ids:
|
||||
file_ids = tool_resources.code_interpreter.file_ids
|
||||
else:
|
||||
file_ids = [file.id]
|
||||
# Update thread with new file.
|
||||
await self._client.beta.threads.update(
|
||||
thread_id=self._thread_id,
|
||||
tool_resources={"code_interpreter": {"file_ids": file_ids}},
|
||||
)
|
||||
elif purpose == "file_search":
|
||||
# Upload file to vector store.
|
||||
file_batch = await self._client.beta.vector_stores.file_batches.upload_and_poll(
|
||||
vector_store_id=self._vector_store_id,
|
||||
files=[(file_name, file_content)],
|
||||
)
|
||||
assert file_batch.status == "completed"
|
||||
print(f"Uploaded file: {file_name}")
|
||||
continue
|
||||
elif user_input.startswith("[upload"):
|
||||
print("Invalid upload command. Please use '[upload code_interpreter | file_search filename]'.")
|
||||
continue
|
||||
elif user_input.strip().lower() == "exit":
|
||||
# Exit handler.
|
||||
return
|
||||
else:
|
||||
# Publish user input and exit handler.
|
||||
await self.publish_message(TextMessage(content=user_input, source=self.metadata["name"]))
|
||||
return
|
||||
|
||||
|
||||
class EventHandler(AsyncAssistantEventHandler):
|
||||
@override
|
||||
async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
|
||||
print(delta.value, end="", flush=True)
|
||||
|
||||
@override
|
||||
async def on_run_step_created(self, run_step: RunStep) -> None:
|
||||
details = run_step.step_details
|
||||
if details.type == "tool_calls":
|
||||
for tool in details.tool_calls:
|
||||
if tool.type == "code_interpreter":
|
||||
print("\nGenerating code to interpret:\n\n```python")
|
||||
|
||||
@override
|
||||
async def on_run_step_done(self, run_step: RunStep) -> None:
|
||||
details = run_step.step_details
|
||||
if details.type == "tool_calls":
|
||||
for tool in details.tool_calls:
|
||||
if tool.type == "code_interpreter":
|
||||
print("\n```\nExecuting code...")
|
||||
|
||||
@override
|
||||
async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
|
||||
details = delta.step_details
|
||||
if details is not None and details.type == "tool_calls":
|
||||
for tool in details.tool_calls or []:
|
||||
if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input:
|
||||
print(tool.code_interpreter.input, end="", flush=True)
|
||||
|
||||
@override
|
||||
async def on_message_created(self, message: Message) -> None:
|
||||
print(f"{sep}\nAssistant:\n")
|
||||
|
||||
@override
|
||||
async def on_message_done(self, message: Message) -> None:
|
||||
# print a citation to the file searched
|
||||
if not message.content:
|
||||
return
|
||||
content = message.content[0]
|
||||
if not content.type == "text":
|
||||
return
|
||||
text_content = content.text
|
||||
annotations = text_content.annotations
|
||||
citations: List[str] = []
|
||||
for index, annotation in enumerate(annotations):
|
||||
text_content.value = text_content.value.replace(annotation.text, f"[{index}]")
|
||||
if file_citation := getattr(annotation, "file_citation", None):
|
||||
client = openai.AsyncClient()
|
||||
cited_file = await client.files.retrieve(file_citation.file_id)
|
||||
citations.append(f"[{index}] {cited_file.filename}")
|
||||
if citations:
|
||||
print("\n".join(citations))
|
||||
|
||||
|
||||
def assistant_chat(runtime: AgentRuntime) -> AgentId:
|
||||
oai_assistant = openai.beta.assistants.create(
|
||||
model="gpt-4-turbo",
|
||||
description="An AI assistant that helps with everyday tasks.",
|
||||
instructions="Help the user with their task.",
|
||||
tools=[{"type": "code_interpreter"}, {"type": "file_search"}],
|
||||
)
|
||||
vector_store = openai.beta.vector_stores.create()
|
||||
thread = openai.beta.threads.create(
|
||||
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
|
||||
)
|
||||
assistant = runtime.register_and_get(
|
||||
"Assistant",
|
||||
lambda: OpenAIAssistantAgent(
|
||||
description="An AI assistant that helps with everyday tasks.",
|
||||
client=openai.AsyncClient(),
|
||||
assistant_id=oai_assistant.id,
|
||||
thread_id=thread.id,
|
||||
assistant_event_handler_factory=lambda: EventHandler(),
|
||||
),
|
||||
)
|
||||
|
||||
user = runtime.register_and_get(
|
||||
"User",
|
||||
lambda: UserProxyAgent(
|
||||
client=openai.AsyncClient(),
|
||||
assistant_id=oai_assistant.id,
|
||||
thread_id=thread.id,
|
||||
vector_store_id=vector_store.id,
|
||||
),
|
||||
)
|
||||
# Create a group chat manager to facilitate a turn-based conversation.
|
||||
runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: GroupChatManager(
|
||||
description="A group chat manager.",
|
||||
runtime=runtime,
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
participants=[assistant, user],
|
||||
),
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
usage = """Chat with an AI assistant backed by OpenAI Assistant API.
|
||||
You can upload files to the assistant using the command:
|
||||
|
||||
[upload code_interpreter | file_search filename]
|
||||
|
||||
where 'code_interpreter' or 'file_search' is the purpose of the file and
|
||||
'filename' is the path to the file. For example:
|
||||
|
||||
[upload code_interpreter data.csv]
|
||||
|
||||
This will upload data.csv to the assistant for use with the code interpreter tool.
|
||||
|
||||
Type "exit" to exit the chat.
|
||||
"""
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
user = assistant_chat(runtime)
|
||||
print(usage)
|
||||
# Request the user to start the conversation.
|
||||
runtime.send_message(PublishNow(), user)
|
||||
while True:
|
||||
# TODO: have a way to exit the loop.
|
||||
await runtime.process_next()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Chat with an AI assistant.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
handler = logging.FileHandler("assistant.log")
|
||||
logging.getLogger("agnext").addHandler(handler)
|
||||
asyncio.run(main())
|
||||
153
python/examples/chat_room.py
Normal file
153
python/examples/chat_room.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.memory import BufferedChatMemory, ChatMemory
|
||||
from agnext.chat.types import TextMessage
|
||||
from agnext.chat.utils import convert_messages_to_llm_messages
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.models import ChatCompletionClient, OpenAI, SystemMessage
|
||||
from agnext.core import AgentRuntime, CancellationToken
|
||||
from utils import TextualChatApp, TextualUserAgent, start_runtime
|
||||
|
||||
|
||||
# Define a custom agent that can handle chat room messages.
|
||||
class ChatRoomAgent(TypeRoutedAgent): # type: ignore
|
||||
def __init__( # type: ignore
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
background_story: str,
|
||||
memory: ChatMemory, # type: ignore
|
||||
model_client: ChatCompletionClient, # type: ignore
|
||||
) -> None: # type: ignore
|
||||
super().__init__(description)
|
||||
system_prompt = f"""Your name is {name}.
|
||||
Your background story is:
|
||||
{background_story}
|
||||
|
||||
Now you are in a chat room with other users.
|
||||
You can send messages to the chat room by typing your message below.
|
||||
You do not need to respond to every message.
|
||||
Use the following JSON format to provide your thought on the latest message and choose whether to respond:
|
||||
{{
|
||||
"thought": "Your thought on the message",
|
||||
"respond": <true/false>,
|
||||
"response": "Your response to the message or None if you choose not to respond."
|
||||
}}
|
||||
"""
|
||||
self._system_messages = [SystemMessage(system_prompt)]
|
||||
self._memory = memory
|
||||
self._client = model_client
|
||||
|
||||
@message_handler() # type: ignore
|
||||
async def on_chat_room_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
# Save the message to memory as structured JSON.
|
||||
from_message = TextMessage(
|
||||
content=json.dumps({"sender": message.source, "content": message.content}), source=message.source
|
||||
)
|
||||
await self._memory.add_message(from_message)
|
||||
|
||||
# Get a response from the model.
|
||||
raw_response = await self._client.create(
|
||||
self._system_messages
|
||||
+ convert_messages_to_llm_messages(await self._memory.get_messages(), self_name=self.metadata["name"]),
|
||||
json_output=True,
|
||||
)
|
||||
assert isinstance(raw_response.content, str)
|
||||
|
||||
# Save the response to memory.
|
||||
await self._memory.add_message(TextMessage(source=self.metadata["name"], content=raw_response.content))
|
||||
|
||||
# Parse the response.
|
||||
data = json.loads(raw_response.content)
|
||||
respond = data.get("respond")
|
||||
response = data.get("response")
|
||||
|
||||
# Publish the response if needed.
|
||||
if respond is True or str(respond).lower().strip() == "true":
|
||||
await self.publish_message(TextMessage(source=self.metadata["name"], content=str(response)))
|
||||
|
||||
|
||||
class ChatRoomUserAgent(TextualUserAgent): # type: ignore
|
||||
"""An agent that is used to receive messages from the runtime."""
|
||||
|
||||
@message_handler # type: ignore
|
||||
async def on_chat_room_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
await self._app.post_runtime_message(message)
|
||||
|
||||
|
||||
# Define a chat room with participants -- the runtime is the chat room.
|
||||
def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore
|
||||
runtime.register(
|
||||
"User",
|
||||
lambda: ChatRoomUserAgent(
|
||||
description="The user in the chat room.",
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
alice = runtime.register_and_get_proxy(
|
||||
"Alice",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
description="Alice in the chat room.",
|
||||
background_story="Alice is a software engineer who loves to code.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"), # type: ignore
|
||||
),
|
||||
)
|
||||
bob = runtime.register_and_get_proxy(
|
||||
"Bob",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
description="Bob in the chat room.",
|
||||
background_story="Bob is a data scientist who loves to analyze data.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"), # type: ignore
|
||||
),
|
||||
)
|
||||
charlie = runtime.register_and_get_proxy(
|
||||
"Charlie",
|
||||
lambda rt, id: ChatRoomAgent(
|
||||
name=id.name,
|
||||
description="Charlie in the chat room.",
|
||||
background_story="Charlie is a designer who loves to create art.",
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"), # type: ignore
|
||||
),
|
||||
)
|
||||
app.welcoming_notice = f"""Welcome to the chat room demo with the following participants:
|
||||
1. 👧 {alice.id.name}: {alice.metadata['description']}
|
||||
2. 👱🏼♂️ {bob.id.name}: {bob.metadata['description']}
|
||||
3. 👨🏾🦳 {charlie.id.name}: {charlie.metadata['description']}
|
||||
|
||||
Each participant decides on its own whether to respond to the latest message.
|
||||
|
||||
You can greet the chat room by typing your first message below.
|
||||
"""
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
chat_room(runtime, app)
|
||||
asyncio.create_task(start_runtime(runtime))
|
||||
await app.run_async()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Chat room demo with self-driving AI agents.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
handler = logging.FileHandler("chat_room.log")
|
||||
logging.getLogger("agnext").addHandler(handler)
|
||||
asyncio.run(main())
|
||||
220
python/examples/chess_game.py
Normal file
220
python/examples/chess_game.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""This is an example of simulating a chess game with two agents
|
||||
that play against each other, using tools to reason about the game state
|
||||
and make moves, and using a group chat manager to orchestrate the conversation."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
|
||||
from agnext.chat.memory import BufferedChatMemory
|
||||
from agnext.chat.patterns.group_chat_manager import GroupChatManager
|
||||
from agnext.chat.types import TextMessage
|
||||
from agnext.components.models import OpenAI, SystemMessage
|
||||
from agnext.components.tools import FunctionTool
|
||||
from agnext.core import AgentRuntime
|
||||
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
|
||||
from chess import piece_name as get_piece_name
|
||||
|
||||
|
||||
def validate_turn(board: Board, player: Literal["white", "black"]) -> None:
|
||||
"""Validate that it is the player's turn to move."""
|
||||
last_move = board.peek() if board.move_stack else None
|
||||
if last_move is not None:
|
||||
if player == "white" and board.color_at(last_move.to_square) == WHITE:
|
||||
raise ValueError("It is not your turn to move. Wait for black to move.")
|
||||
if player == "black" and board.color_at(last_move.to_square) == BLACK:
|
||||
raise ValueError("It is not your turn to move. Wait for white to move.")
|
||||
elif last_move is None and player != "white":
|
||||
raise ValueError("It is not your turn to move. Wait for white to move first.")
|
||||
|
||||
|
||||
def get_legal_moves(
|
||||
board: Board, player: Literal["white", "black"]
|
||||
) -> Annotated[str, "A list of legal moves in UCI format."]:
|
||||
"""Get legal moves for the given player."""
|
||||
validate_turn(board, player)
|
||||
legal_moves = list(board.legal_moves)
|
||||
if player == "black":
|
||||
legal_moves = [move for move in legal_moves if board.color_at(move.from_square) == BLACK]
|
||||
elif player == "white":
|
||||
legal_moves = [move for move in legal_moves if board.color_at(move.from_square) == WHITE]
|
||||
else:
|
||||
raise ValueError("Invalid player, must be either 'black' or 'white'.")
|
||||
if not legal_moves:
|
||||
return "No legal moves. The game is over."
|
||||
|
||||
return "Possible moves are: " + ", ".join([move.uci() for move in legal_moves])
|
||||
|
||||
|
||||
def get_board(board: Board) -> str:
|
||||
return str(board)
|
||||
|
||||
|
||||
def make_move(
|
||||
board: Board,
|
||||
player: Literal["white", "black"],
|
||||
thinking: Annotated[str, "Thinking for the move."],
|
||||
move: Annotated[str, "A move in UCI format."],
|
||||
) -> Annotated[str, "Result of the move."]:
|
||||
"""Make a move on the board."""
|
||||
validate_turn(board, player)
|
||||
newMove = Move.from_uci(move)
|
||||
board.push(newMove)
|
||||
|
||||
# Print the move.
|
||||
print("-" * 50)
|
||||
print("Player:", player)
|
||||
print("Move:", newMove.uci())
|
||||
print("Thinking:", thinking)
|
||||
print("Board:")
|
||||
print(board.unicode(borders=True))
|
||||
|
||||
# Get the piece name.
|
||||
piece = board.piece_at(newMove.to_square)
|
||||
assert piece is not None
|
||||
piece_symbol = piece.unicode_symbol()
|
||||
piece_name = get_piece_name(piece.piece_type)
|
||||
if piece_symbol.isupper():
|
||||
piece_name = piece_name.capitalize()
|
||||
return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[newMove.from_square]} to {SQUARE_NAMES[newMove.to_square]}."
|
||||
|
||||
|
||||
def chess_game(runtime: AgentRuntime) -> None: # type: ignore
|
||||
"""Create agents for a chess game and return the group chat."""
|
||||
|
||||
# Create the board.
|
||||
board = Board()
|
||||
|
||||
# Create tools for each player.
|
||||
# @functools.wraps(get_legal_moves)
|
||||
def get_legal_moves_black() -> str:
|
||||
return get_legal_moves(board, "black")
|
||||
|
||||
# @functools.wraps(get_legal_moves)
|
||||
def get_legal_moves_white() -> str:
|
||||
return get_legal_moves(board, "white")
|
||||
|
||||
# @functools.wraps(make_move)
|
||||
def make_move_black(
|
||||
thinking: Annotated[str, "Thinking for the move"],
|
||||
move: Annotated[str, "A move in UCI format"],
|
||||
) -> str:
|
||||
return make_move(board, "black", thinking, move)
|
||||
|
||||
# @functools.wraps(make_move)
|
||||
def make_move_white(
|
||||
thinking: Annotated[str, "Thinking for the move"],
|
||||
move: Annotated[str, "A move in UCI format"],
|
||||
) -> str:
|
||||
return make_move(board, "white", thinking, move)
|
||||
|
||||
def get_board_text() -> Annotated[str, "The current board state"]:
|
||||
return get_board(board)
|
||||
|
||||
black_tools = [
|
||||
FunctionTool(
|
||||
get_legal_moves_black,
|
||||
name="get_legal_moves",
|
||||
description="Get legal moves.",
|
||||
),
|
||||
FunctionTool(
|
||||
make_move_black,
|
||||
name="make_move",
|
||||
description="Make a move.",
|
||||
),
|
||||
FunctionTool(
|
||||
get_board_text,
|
||||
name="get_board",
|
||||
description="Get the current board state.",
|
||||
),
|
||||
]
|
||||
|
||||
white_tools = [
|
||||
FunctionTool(
|
||||
get_legal_moves_white,
|
||||
name="get_legal_moves",
|
||||
description="Get legal moves.",
|
||||
),
|
||||
FunctionTool(
|
||||
make_move_white,
|
||||
name="make_move",
|
||||
description="Make a move.",
|
||||
),
|
||||
FunctionTool(
|
||||
get_board_text,
|
||||
name="get_board",
|
||||
description="Get the current board state.",
|
||||
),
|
||||
]
|
||||
|
||||
black = runtime.register_and_get(
|
||||
"PlayerBlack",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Player playing black.",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
content="You are a chess player and you play as black. "
|
||||
"Use get_legal_moves() to get list of legal moves. "
|
||||
"Use get_board() to get the current board state. "
|
||||
"Think about your strategy and call make_move(thinking, move) to make a move."
|
||||
),
|
||||
],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
tools=black_tools,
|
||||
),
|
||||
)
|
||||
white = runtime.register_and_get(
|
||||
"PlayerWhite",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="Player playing white.",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
content="You are a chess player and you play as white. "
|
||||
"Use get_legal_moves() to get list of legal moves. "
|
||||
"Use get_board() to get the current board state. "
|
||||
"Think about your strategy and call make_move(thinking, move) to make a move."
|
||||
),
|
||||
],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
tools=white_tools,
|
||||
),
|
||||
)
|
||||
# Create a group chat manager for the chess game to orchestrate a turn-based
|
||||
# conversation between the two agents.
|
||||
runtime.register(
|
||||
"ChessGame",
|
||||
lambda: GroupChatManager(
|
||||
description="A chess game between two agents.",
|
||||
runtime=runtime,
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
participants=[white, black], # white goes first
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
chess_game(runtime)
|
||||
# Publish an initial message to trigger the group chat manager to start orchestration.
|
||||
runtime.publish_message(TextMessage(content="Game started.", source="System"), namespace="default")
|
||||
while True:
|
||||
await runtime.process_next()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run a chess game between two agents.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
handler = logging.FileHandler("chess_game.log")
|
||||
logging.getLogger("agnext").addHandler(handler)
|
||||
|
||||
asyncio.run(main())
|
||||
97
python/examples/coder_reviewer.py
Normal file
97
python/examples/coder_reviewer.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.agents import ChatCompletionAgent
|
||||
from agnext.chat.memory import BufferedChatMemory
|
||||
from agnext.chat.patterns import GroupChatManager
|
||||
from agnext.components.models import OpenAI, SystemMessage
|
||||
from agnext.core import AgentRuntime
|
||||
from utils import TextualChatApp, TextualUserAgent, start_runtime
|
||||
|
||||
|
||||
def coder_reviewer(runtime: AgentRuntime, app: TextualChatApp) -> None:
|
||||
runtime.register(
|
||||
"Human",
|
||||
lambda: TextualUserAgent(
|
||||
description="A human user that provides a problem statement.",
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
coder = runtime.register_and_get_proxy(
|
||||
"Coder",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An agent that writes code",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are a coder. You can write code to solve problems.\n"
|
||||
"Work with the reviewer to improve your code."
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
),
|
||||
)
|
||||
reviewer = runtime.register_and_get_proxy(
|
||||
"Reviewer",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An agent that reviews code",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are a code reviewer. You focus on correctness, efficiency and safety of the code.\n"
|
||||
"Respond using the following format:\n"
|
||||
"Code Review:\n"
|
||||
"Correctness: <Your comments>\n"
|
||||
"Efficiency: <Your comments>\n"
|
||||
"Safety: <Your comments>\n"
|
||||
"Approval: <APPROVE or REVISE>\n"
|
||||
"Suggested Changes: <Your comments>"
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
"Manager",
|
||||
lambda: GroupChatManager(
|
||||
description="A manager that orchestrates a back-and-forth converation between a coder and a reviewer.",
|
||||
runtime=runtime,
|
||||
participants=[coder.id, reviewer.id], # The order of the participants indicates the order of speaking.
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
termination_word="APPROVE",
|
||||
),
|
||||
)
|
||||
app.welcoming_notice = f"""Welcome to the coder-reviewer demo with the following roles:
|
||||
1. 🤖 {coder.metadata['name']}: {coder.metadata['description']}
|
||||
2. 🧐 {reviewer.metadata['name']}: {reviewer.metadata['description']}
|
||||
The coder will write code to solve a problem, and the reviewer will review the code.
|
||||
The conversation will end when the reviewer approves the code.
|
||||
Let's get started by providing a problem statement.
|
||||
"""
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
coder_reviewer(runtime, app)
|
||||
asyncio.create_task(start_runtime(runtime))
|
||||
await app.run_async()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Coder-reviewer pattern for code writing and review.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
handler = logging.FileHandler("coder_reviewer.log")
|
||||
logging.getLogger("agnext").addHandler(handler)
|
||||
asyncio.run(main())
|
||||
112
python/examples/illustrator_critics.py
Normal file
112
python/examples/illustrator_critics.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
|
||||
import openai
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.agents import ChatCompletionAgent, ImageGenerationAgent
|
||||
from agnext.chat.memory import BufferedChatMemory
|
||||
from agnext.chat.patterns.group_chat_manager import GroupChatManager
|
||||
from agnext.components.models import OpenAI, SystemMessage
|
||||
from agnext.core import AgentRuntime
|
||||
from utils import TextualChatApp, TextualUserAgent, start_runtime
|
||||
|
||||
|
||||
def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> str: # type: ignore
|
||||
runtime.register(
|
||||
"User",
|
||||
lambda: TextualUserAgent(
|
||||
description="A user looking for illustration.",
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
descriptor = runtime.register_and_get_proxy(
|
||||
"Descriptor",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An AI agent that provides a description of the image.",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You create short description for image. \n"
|
||||
"In this conversation, you will be given either: \n"
|
||||
"1. Request for new image. \n"
|
||||
"2. Feedback on some image created. \n"
|
||||
"In both cases, you will provide a description of a new image to be created. \n"
|
||||
"Only provide the description of the new image and nothing else. \n"
|
||||
"Be succinct and precise."
|
||||
),
|
||||
],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo", max_tokens=500),
|
||||
),
|
||||
)
|
||||
illustrator = runtime.register_and_get_proxy(
|
||||
"Illustrator",
|
||||
lambda: ImageGenerationAgent(
|
||||
description="An AI agent that generates images.",
|
||||
client=openai.AsyncOpenAI(),
|
||||
model="dall-e-3",
|
||||
memory=BufferedChatMemory(buffer_size=1),
|
||||
),
|
||||
)
|
||||
critic = runtime.register_and_get_proxy(
|
||||
"Critic",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An AI agent that provides feedback on images given user's requirements.",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are an expert in image understanding. \n"
|
||||
"In this conversation, you will judge an image given the description and provide feedback. \n"
|
||||
"Pay attention to the details like the spelling of words and number of objects. \n"
|
||||
"Use the following format in your response: \n"
|
||||
"Number of each object type in the image: <Type 1 (e.g., Husky Dog)>: 1, <Type 2>: 2, ...\n"
|
||||
"Feedback: <Your feedback here> \n"
|
||||
"Approval: <APPROVE or REVISE> \n"
|
||||
),
|
||||
],
|
||||
memory=BufferedChatMemory(buffer_size=2),
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: GroupChatManager(
|
||||
description="A chat manager that handles group chat.",
|
||||
runtime=runtime,
|
||||
memory=BufferedChatMemory(buffer_size=5),
|
||||
participants=[illustrator.id, critic.id, descriptor.id],
|
||||
termination_word="APPROVE",
|
||||
),
|
||||
)
|
||||
|
||||
app.welcoming_notice = f"""You are now in a group chat with the following agents:
|
||||
|
||||
1. 🤖 {descriptor.metadata['name']}: {descriptor.metadata.get('description')}
|
||||
2. 🤖 {illustrator.metadata['name']}: {illustrator.metadata.get('description')}
|
||||
3. 🤖 {critic.metadata['name']}: {critic.metadata.get('description')}
|
||||
|
||||
Provide a prompt for the illustrator to generate an image.
|
||||
"""
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
illustrator_critics(runtime, app)
|
||||
asyncio.create_task(start_runtime(runtime))
|
||||
await app.run_async()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Illustrator-critics pattern for image generation demo.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
handler = logging.FileHandler("illustrator_critics.log")
|
||||
logging.getLogger("agnext").addHandler(handler)
|
||||
asyncio.run(main())
|
||||
60
python/examples/inner_outer.py
Normal file
60
python/examples/inner_outer.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentId, CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageType:
|
||||
body: str
|
||||
sender: str
|
||||
|
||||
|
||||
class Inner(TypeRoutedAgent): # type: ignore
|
||||
def __init__(self) -> None: # type: ignore
|
||||
super().__init__("The inner agent")
|
||||
|
||||
@message_handler() # type: ignore
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
|
||||
return MessageType(body=f"Inner: {message.body}", sender=self.metadata["name"])
|
||||
|
||||
|
||||
class Outer(TypeRoutedAgent): # type: ignore
|
||||
def __init__(self, inner: AgentId) -> None: # type: ignore
|
||||
super().__init__("The outer agent")
|
||||
self._inner = inner
|
||||
|
||||
@message_handler() # type: ignore
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
|
||||
inner_response = self.send_message(message, self._inner)
|
||||
inner_message = await inner_response
|
||||
assert isinstance(inner_message, MessageType)
|
||||
return MessageType(body=f"Outer: {inner_message.body}", sender=self.metadata["name"])
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
inner = runtime.register_and_get("inner", Inner)
|
||||
outer = runtime.register_and_get("outer", lambda: Outer(inner))
|
||||
response = runtime.send_message(MessageType(body="Hello", sender="external"), outer)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
print(await response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Inner-Outter agent example.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
handler = logging.FileHandler("inner_outter.log")
|
||||
logging.getLogger("agnext").addHandler(handler)
|
||||
asyncio.run(main())
|
||||
186
python/examples/orchestrator.py
Normal file
186
python/examples/orchestrator.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
import openai
|
||||
from agnext.application import (
|
||||
SingleThreadedAgentRuntime,
|
||||
)
|
||||
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
|
||||
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
|
||||
from agnext.chat.memory import BufferedChatMemory
|
||||
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
|
||||
from agnext.chat.types import TextMessage
|
||||
from agnext.components.models import OpenAI, SystemMessage
|
||||
from agnext.components.tools import BaseTool
|
||||
from agnext.core import Agent, AgentRuntime, CancellationToken
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from tavily import TavilyClient # type: ignore
|
||||
from typing_extensions import Any, override
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class SearchQuery(BaseModel):
|
||||
query: str = Field(description="The search query.")
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
result: str = Field(description="The search results.")
|
||||
|
||||
|
||||
class SearchTool(BaseTool[SearchQuery, SearchResult]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
args_type=SearchQuery,
|
||||
return_type=SearchResult,
|
||||
name="search",
|
||||
description="Search the web.",
|
||||
)
|
||||
|
||||
async def run(self, args: SearchQuery, cancellation_token: CancellationToken) -> SearchResult:
|
||||
client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) # type: ignore
|
||||
result = await asyncio.create_task(client.search(args.query)) # type: ignore
|
||||
if result:
|
||||
return SearchResult(result=json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
return SearchResult(result="No results found.")
|
||||
|
||||
|
||||
class LoggingHandler(DefaultInterventionHandler): # type: ignore
|
||||
send_color = "\033[31m"
|
||||
response_color = "\033[34m"
|
||||
reset_color = "\033[0m"
|
||||
|
||||
@override
|
||||
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: # type: ignore
|
||||
if sender is None:
|
||||
print(f"{self.send_color}Sending message to {recipient.metadata['name']}:{self.reset_color} {message}")
|
||||
else:
|
||||
print(
|
||||
f"{self.send_color}Sending message from {sender.metadata['name']} to {recipient.metadata['name']}:{self.reset_color} {message}"
|
||||
)
|
||||
return message
|
||||
|
||||
@override
|
||||
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: # type: ignore
|
||||
if recipient is None:
|
||||
print(f"{self.response_color}Received response from {sender.metadata['name']}:{self.reset_color} {message}")
|
||||
else:
|
||||
print(
|
||||
f"{self.response_color}Received response from {sender.metadata['name']} to {recipient.metadata['name']}:{self.reset_color} {message}"
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ignore
|
||||
developer = runtime.register_and_get_proxy(
|
||||
"Developer",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A developer that writes code.",
|
||||
system_messages=[SystemMessage("You are a Python developer.")],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
|
||||
tester_oai_assistant = openai.beta.assistants.create(
|
||||
model="gpt-4-turbo",
|
||||
description="A software tester that runs test cases and reports results.",
|
||||
instructions="You are a software tester that runs test cases and reports results.",
|
||||
)
|
||||
tester_oai_thread = openai.beta.threads.create()
|
||||
tester = runtime.register_and_get_proxy(
|
||||
"Tester",
|
||||
lambda: OpenAIAssistantAgent(
|
||||
description="A software tester that runs test cases and reports results.",
|
||||
client=openai.AsyncClient(),
|
||||
assistant_id=tester_oai_assistant.id,
|
||||
thread_id=tester_oai_thread.id,
|
||||
),
|
||||
)
|
||||
|
||||
product_manager = runtime.register_and_get_proxy(
|
||||
"ProductManager",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A product manager that performs research and comes up with specs.",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are a product manager good at translating customer needs into software specifications."
|
||||
),
|
||||
SystemMessage("You can use the search tool to find information on the web."),
|
||||
],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
tools=[SearchTool()],
|
||||
),
|
||||
)
|
||||
|
||||
planner = runtime.register_and_get_proxy(
|
||||
"Planner",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A planner that organizes and schedules tasks.",
|
||||
system_messages=[SystemMessage("You are a planner of complex tasks.")],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
|
||||
orchestrator = runtime.register_and_get_proxy(
|
||||
"Orchestrator",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An orchestrator that coordinates the team.",
|
||||
system_messages=[
|
||||
SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
|
||||
],
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
),
|
||||
)
|
||||
|
||||
return OrchestratorChat(
|
||||
"A software development team.",
|
||||
runtime,
|
||||
orchestrator=orchestrator.id,
|
||||
planner=planner.id,
|
||||
specialists=[developer.id, product_manager.id, tester.id],
|
||||
)
|
||||
|
||||
|
||||
async def run(message: str, user: str, scenario: Callable[[AgentRuntime], OrchestratorChat]) -> None: # type: ignore
|
||||
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
|
||||
chat = scenario(runtime)
|
||||
response = runtime.send_message(TextMessage(content=message, source=user), chat.id)
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
print((await response).content) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run a orchestrator demo.")
|
||||
choices = {"software_development": software_development}
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
parser.add_argument(
|
||||
"--scenario",
|
||||
choices=list(choices.keys()),
|
||||
help="The scenario to demo.",
|
||||
default="software_development",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user",
|
||||
default="Customer",
|
||||
help="The user to send the message. Default is 'Customer'.",
|
||||
)
|
||||
parser.add_argument("--message", help="The message to send.", required=True)
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
handler = logging.FileHandler("inner_outter.log")
|
||||
logging.getLogger("agnext").addHandler(handler)
|
||||
asyncio.run(run(args.message, args.user, choices[args.scenario]))
|
||||
296
python/examples/software_consultancy.py
Normal file
296
python/examples/software_consultancy.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""This is an example demonstrates event-driven orchestration using a
|
||||
group chat manager agnent.
|
||||
|
||||
WARNING: do not run this example in your local machine as it involves
|
||||
executing arbitrary code. Use a secure environment like a docker container
|
||||
or GitHub Codespaces to run this example.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
import openai
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.agents import ChatCompletionAgent
|
||||
from agnext.chat.memory import HeadAndTailChatMemory
|
||||
from agnext.chat.patterns.group_chat_manager import GroupChatManager
|
||||
from agnext.components.models import OpenAI, SystemMessage
|
||||
from agnext.components.tools import FunctionTool
|
||||
from agnext.core import AgentRuntime
|
||||
from markdownify import markdownify # type: ignore
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import Annotated
|
||||
from utils import TextualChatApp, TextualUserAgent, start_runtime
|
||||
|
||||
|
||||
async def write_file(filename: str, content: str) -> str:
|
||||
async with aiofiles.open(filename, "w") as file:
|
||||
await file.write(content)
|
||||
return f"Content written to {filename}."
|
||||
|
||||
|
||||
async def execute_command(command: str) -> Annotated[str, "The standard output and error of the executed command."]:
|
||||
process = await asyncio.subprocess.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
return f"stdout: {stdout.decode()}\nstderr: {stderr.decode()}"
|
||||
|
||||
|
||||
async def read_file(filename: str) -> Annotated[str, "The content of the file."]:
|
||||
async with aiofiles.open(filename, "r") as file:
|
||||
return await file.read()
|
||||
|
||||
|
||||
async def remove_file(filename: str) -> str:
|
||||
process = await asyncio.subprocess.create_subprocess_exec("rm", filename)
|
||||
await process.wait()
|
||||
if process.returncode != 0:
|
||||
raise ValueError(f"Error occurred while removing file: {filename}")
|
||||
return f"File removed: {filename}."
|
||||
|
||||
|
||||
async def list_files(directory: str) -> Annotated[str, "The list of files in the directory."]:
|
||||
# Ask for confirmation first.
|
||||
# await confirm(f"Are you sure you want to list files in {directory}?")
|
||||
process = await asyncio.subprocess.create_subprocess_exec(
|
||||
"ls",
|
||||
directory,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if stderr:
|
||||
raise ValueError(f"Error occurred while listing files: {stderr.decode()}")
|
||||
return stdout.decode()
|
||||
|
||||
|
||||
async def browse_web(url: str) -> Annotated[str, "The content of the web page in Markdown format."]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
html = await response.text()
|
||||
markdown = markdownify(html) # type: ignore
|
||||
if isinstance(markdown, str):
|
||||
return markdown
|
||||
return f"Unable to parse content from {url}."
|
||||
|
||||
|
||||
async def create_image(
|
||||
description: Annotated[str, "Describe the image to create"],
|
||||
filename: Annotated[str, "The path to save the created image"],
|
||||
) -> str:
|
||||
# Use Dalle to generate an image from the description.
|
||||
with tqdm(desc="Generating image...", leave=False) as pbar:
|
||||
client = openai.AsyncClient()
|
||||
response = await client.images.generate(model="dall-e-2", prompt=description, response_format="b64_json")
|
||||
pbar.close()
|
||||
assert len(response.data) > 0 and response.data[0].b64_json is not None
|
||||
# Save the image to a file.
|
||||
async with aiofiles.open(filename, "wb") as file:
|
||||
image_data = base64.b64decode(response.data[0].b64_json)
|
||||
await file.write(image_data)
|
||||
return f"Image created and saved to {filename}."
|
||||
|
||||
|
||||
def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore
|
||||
user_agent = runtime.register_and_get(
|
||||
"Customer",
|
||||
lambda: TextualUserAgent(
|
||||
description="A customer looking for help.",
|
||||
app=app,
|
||||
),
|
||||
)
|
||||
developer = runtime.register_and_get(
|
||||
"Developer",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A Python software developer.",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"Your are a Python developer. \n"
|
||||
"You can read, write, and execute code. \n"
|
||||
"You can browse files and directories. \n"
|
||||
"You can also browse the web for documentation. \n"
|
||||
"You are entering a work session with the customer, product manager, UX designer, and illustrator. \n"
|
||||
"When you are given a task, you should immediately start working on it. \n"
|
||||
"Be concise and deliver now."
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
|
||||
tools=[
|
||||
FunctionTool(
|
||||
write_file,
|
||||
name="write_file",
|
||||
description="Write code to a file.",
|
||||
),
|
||||
FunctionTool(
|
||||
read_file,
|
||||
name="read_file",
|
||||
description="Read code from a file.",
|
||||
),
|
||||
FunctionTool(
|
||||
execute_command,
|
||||
name="execute_command",
|
||||
description="Execute a unix shell command.",
|
||||
),
|
||||
FunctionTool(list_files, name="list_files", description="List files in a directory."),
|
||||
FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
|
||||
],
|
||||
tool_approver=user_agent,
|
||||
),
|
||||
)
|
||||
|
||||
product_manager = runtime.register_and_get(
|
||||
"ProductManager",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A product manager. "
|
||||
"Responsible for interfacing with the customer, planning and managing the project. ",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are a product manager. \n"
|
||||
"You can browse files and directories. \n"
|
||||
"You are entering a work session with the customer, developer, UX designer, and illustrator. \n"
|
||||
"Keep the project on track. Don't hire any more people. \n"
|
||||
"When a milestone is reached, stop and ask for customer feedback. Make sure the customer is satisfied. \n"
|
||||
"Be VERY concise."
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
|
||||
tools=[
|
||||
FunctionTool(
|
||||
read_file,
|
||||
name="read_file",
|
||||
description="Read from a file.",
|
||||
),
|
||||
FunctionTool(list_files, name="list_files", description="List files in a directory."),
|
||||
FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
|
||||
],
|
||||
tool_approver=user_agent,
|
||||
),
|
||||
)
|
||||
ux_designer = runtime.register_and_get(
|
||||
"UserExperienceDesigner",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A user experience designer for creating user interfaces.",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are a user experience designer. \n"
|
||||
"You can create user interfaces from descriptions. \n"
|
||||
"You can browse files and directories. \n"
|
||||
"You are entering a work session with the customer, developer, product manager, and illustrator. \n"
|
||||
"When you are given a task, you should immediately start working on it. \n"
|
||||
"Be concise and deliver now."
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
|
||||
tools=[
|
||||
FunctionTool(
|
||||
write_file,
|
||||
name="write_file",
|
||||
description="Write code to a file.",
|
||||
),
|
||||
FunctionTool(
|
||||
read_file,
|
||||
name="read_file",
|
||||
description="Read code from a file.",
|
||||
),
|
||||
FunctionTool(list_files, name="list_files", description="List files in a directory."),
|
||||
],
|
||||
tool_approver=user_agent,
|
||||
),
|
||||
)
|
||||
|
||||
illustrator = runtime.register_and_get(
|
||||
"Illustrator",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="An illustrator for creating images.",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"You are an illustrator. "
|
||||
"You can create images from descriptions. "
|
||||
"You are entering a work session with the customer, developer, product manager, and UX designer. \n"
|
||||
"When you are given a task, you should immediately start working on it. \n"
|
||||
"Be concise and deliver now."
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
|
||||
tools=[
|
||||
FunctionTool(
|
||||
create_image,
|
||||
name="create_image",
|
||||
description="Create an image from a description.",
|
||||
),
|
||||
],
|
||||
tool_approver=user_agent,
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
"GroupChatManager",
|
||||
lambda: GroupChatManager(
|
||||
description="A group chat manager.",
|
||||
runtime=runtime,
|
||||
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
participants=[developer, product_manager, ux_designer, illustrator, user_agent],
|
||||
),
|
||||
)
|
||||
art = r"""
|
||||
+----------------------------------------------------------+
|
||||
| ____ __ _ |
|
||||
| / ___| ___ / _| |___ ____ _ _ __ ___ |
|
||||
| \___ \ / _ \| |_| __\ \ /\ / / _` | '__/ _ \ |
|
||||
| ___) | (_) | _| |_ \ V V / (_| | | | __/ |
|
||||
| |____/ \___/|_| \__| \_/\_/ \__,_|_| \___| |
|
||||
| |
|
||||
| ____ _ _ |
|
||||
| / ___|___ _ __ ___ _ _| | |_ __ _ _ __ ___ _ _ |
|
||||
| | | / _ \| '_ \/ __| | | | | __/ _` | '_ \ / __| | | | |
|
||||
| | |__| (_) | | | \__ \ |_| | | || (_| | | | | (__| |_| | |
|
||||
| \____\___/|_| |_|___/\__,_|_|\__\__,_|_| |_|\___|\__, | |
|
||||
| |___/ |
|
||||
| |
|
||||
| Work with a software development consultancy to create |
|
||||
| your own Python application. You are working with a team |
|
||||
| of the following agents: |
|
||||
| 1. 🤖 Developer: A Python software developer. |
|
||||
| 2. 🤖 ProductManager: A product manager. |
|
||||
| 3. 🤖 UserExperienceDesigner: A user experience designer. |
|
||||
| 4. 🤖 Illustrator: An illustrator. |
|
||||
+----------------------------------------------------------+
|
||||
"""
|
||||
app.welcoming_notice = art
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
app = TextualChatApp(runtime, user_name="You")
|
||||
software_consultancy(runtime, app)
|
||||
# Start the runtime.
|
||||
asyncio.create_task(start_runtime(runtime))
|
||||
# Start the app.
|
||||
await app.run_async()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Software consultancy demo.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
handler = logging.FileHandler("software_consultancy.log")
|
||||
logging.getLogger("agnext").addHandler(handler)
|
||||
asyncio.run(main())
|
||||
192
python/examples/utils.py
Normal file
192
python/examples/utils.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import asyncio
|
||||
import random
|
||||
from asyncio import Future
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.types import (
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
RespondNow,
|
||||
TextMessage,
|
||||
ToolApprovalRequest,
|
||||
ToolApprovalResponse,
|
||||
)
|
||||
from agnext.components import Image, TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentRuntime, CancellationToken
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.widgets import Button, Footer, Header, Input, Markdown, Static
|
||||
from textual_imageview.viewer import ImageViewer
|
||||
|
||||
|
||||
class ChatAppMessage(Static):
|
||||
def __init__(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore
|
||||
self._message = message
|
||||
super().__init__()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.styles.margin = 1
|
||||
self.styles.padding = 1
|
||||
self.styles.border = ("solid", "blue")
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
if isinstance(self._message, TextMessage):
|
||||
yield Markdown(f"{self._message.source}:")
|
||||
yield Markdown(self._message.content)
|
||||
else:
|
||||
yield Markdown(f"{self._message.source}:")
|
||||
for content in self._message.content:
|
||||
if isinstance(content, str):
|
||||
yield Markdown(content)
|
||||
elif isinstance(content, Image):
|
||||
viewer = ImageViewer(content.image)
|
||||
viewer.styles.min_width = 50
|
||||
viewer.styles.min_height = 50
|
||||
yield viewer
|
||||
|
||||
|
||||
class WelcomeMessage(Static):
|
||||
def on_mount(self) -> None:
|
||||
self.styles.margin = 1
|
||||
self.styles.padding = 1
|
||||
self.styles.border = ("solid", "blue")
|
||||
|
||||
|
||||
class ChatInput(Input):
|
||||
def on_mount(self) -> None:
|
||||
self.focus()
|
||||
|
||||
def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||
self.clear()
|
||||
|
||||
|
||||
class ToolApprovalRequestNotice(Static):
|
||||
def __init__(self, request: ToolApprovalRequest, response_future: Future[ToolApprovalResponse]) -> None: # type: ignore
|
||||
self._tool_call = request.tool_call
|
||||
self._future = response_future
|
||||
super().__init__()
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Static(f"Tool call: {self._tool_call.name}, arguments: {self._tool_call.arguments[:50]}")
|
||||
yield Button("Approve", id="approve", variant="warning")
|
||||
yield Button("Deny", id="deny", variant="default")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.styles.margin = 1
|
||||
self.styles.padding = 1
|
||||
self.styles.border = ("solid", "red")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
button_id = event.button.id
|
||||
assert button_id is not None
|
||||
if button_id == "approve":
|
||||
self._future.set_result(ToolApprovalResponse(tool_call_id=self._tool_call.id, approved=True, reason=""))
|
||||
else:
|
||||
self._future.set_result(ToolApprovalResponse(tool_call_id=self._tool_call.id, approved=False, reason=""))
|
||||
self.remove()
|
||||
|
||||
|
||||
class TextualChatApp(App): # type: ignore
|
||||
"""A Textual app for a chat interface."""
|
||||
|
||||
def __init__(self, runtime: AgentRuntime, welcoming_notice: str | None = None, user_name: str = "User") -> None: # type: ignore
|
||||
self._runtime = runtime
|
||||
self._welcoming_notice = welcoming_notice
|
||||
self._user_name = user_name
|
||||
super().__init__()
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header()
|
||||
yield Footer()
|
||||
yield ScrollableContainer(id="chat-messages")
|
||||
yield ChatInput()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
if self._welcoming_notice is not None:
|
||||
chat_messages = self.query_one("#chat-messages")
|
||||
notice = WelcomeMessage(self._welcoming_notice, id="welcome")
|
||||
chat_messages.mount(notice)
|
||||
|
||||
@property
|
||||
def welcoming_notice(self) -> str | None:
|
||||
return self._welcoming_notice
|
||||
|
||||
@welcoming_notice.setter
|
||||
def welcoming_notice(self, value: str) -> None:
|
||||
self._welcoming_notice = value
|
||||
|
||||
async def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||
user_input = event.value
|
||||
await self.publish_user_message(user_input)
|
||||
|
||||
async def post_request_user_input_notice(self) -> None:
|
||||
chat_messages = self.query_one("#chat-messages")
|
||||
notice = Static("Please enter your input.", id="typing")
|
||||
chat_messages.mount(notice)
|
||||
notice.scroll_visible()
|
||||
|
||||
async def publish_user_message(self, user_input: str) -> None:
|
||||
chat_messages = self.query_one("#chat-messages")
|
||||
# Remove all typing messages.
|
||||
chat_messages.query("#typing").remove()
|
||||
# Publish the user message to the runtime.
|
||||
await self._runtime.publish_message(
|
||||
TextMessage(source=self._user_name, content=user_input), namespace="default"
|
||||
)
|
||||
|
||||
async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore
|
||||
"""Post a message from the agent runtime to the message list."""
|
||||
chat_messages = self.query_one("#chat-messages")
|
||||
msg = ChatAppMessage(message)
|
||||
chat_messages.mount(msg)
|
||||
msg.scroll_visible()
|
||||
|
||||
async def handle_tool_approval_request(self, message: ToolApprovalRequest) -> ToolApprovalResponse: # type: ignore
|
||||
chat_messages = self.query_one("#chat-messages")
|
||||
future: Future[ToolApprovalResponse] = asyncio.get_event_loop().create_future() # type: ignore
|
||||
tool_call_approval_notice = ToolApprovalRequestNotice(message, future)
|
||||
chat_messages.mount(tool_call_approval_notice)
|
||||
tool_call_approval_notice.scroll_visible()
|
||||
return await future
|
||||
|
||||
|
||||
class TextualUserAgent(TypeRoutedAgent): # type: ignore
|
||||
"""An agent that is used to receive messages from the runtime."""
|
||||
|
||||
def __init__(self, description: str, app: TextualChatApp) -> None: # type: ignore
|
||||
super().__init__(description)
|
||||
self._app = app
|
||||
|
||||
@message_handler # type: ignore
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
await self._app.post_runtime_message(message)
|
||||
|
||||
@message_handler # type: ignore
|
||||
async def on_multi_modal_message(self, message: MultiModalMessage, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
# Save the message to file.
|
||||
# Generate a ramdom file name.
|
||||
for content in message.content:
|
||||
if isinstance(content, Image):
|
||||
filename = f"{self.metadata['name']}_{message.source}_{random.randbytes(16).hex()}.png"
|
||||
content.image.save(filename)
|
||||
await self._app.post_runtime_message(message)
|
||||
|
||||
@message_handler # type: ignore
|
||||
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
await self._app.post_request_user_input_notice()
|
||||
|
||||
@message_handler # type: ignore
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: # type: ignore
|
||||
await self._app.post_request_user_input_notice()
|
||||
|
||||
@message_handler # type: ignore
|
||||
async def on_tool_approval_request(
|
||||
self, message: ToolApprovalRequest, cancellation_token: CancellationToken
|
||||
) -> ToolApprovalResponse:
|
||||
return await self._app.handle_tool_approval_request(message)
|
||||
|
||||
|
||||
async def start_runtime(runtime: SingleThreadedAgentRuntime) -> None: # type: ignore
|
||||
"""Run the runtime in a loop."""
|
||||
while True:
|
||||
await runtime.process_next()
|
||||
118
python/pyproject.toml
Normal file
118
python/pyproject.toml
Normal file
@@ -0,0 +1,118 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "agnext"
|
||||
version = "0.0.1"
|
||||
description = "A small example package"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
dependencies = [
|
||||
"openai>=1.3",
|
||||
"pillow",
|
||||
"aiohttp",
|
||||
"typing-extensions",
|
||||
"pydantic>=1.10,<3",
|
||||
]
|
||||
|
||||
[tool.hatch.envs.default]
|
||||
dependencies = [
|
||||
"pyright==1.1.368",
|
||||
"mypy==1.10.0",
|
||||
"ruff==0.4.8",
|
||||
"types-Pillow",
|
||||
"polars",
|
||||
"chess",
|
||||
"tavily-python",
|
||||
"aiofiles",
|
||||
"types-aiofiles",
|
||||
"colorama",
|
||||
"textual",
|
||||
"textual-dev",
|
||||
"textual-imageview",
|
||||
"pytest-asyncio",
|
||||
"pip",
|
||||
"pytest",
|
||||
"pytest-xdist",
|
||||
]
|
||||
|
||||
[tool.hatch.envs.default.scripts]
|
||||
fmt = "ruff format"
|
||||
lint = "ruff check"
|
||||
test = "pytest -n auto"
|
||||
check = [
|
||||
"ruff format",
|
||||
"ruff check --fix",
|
||||
"pyright",
|
||||
"mypy",
|
||||
"pytest -n auto",
|
||||
]
|
||||
|
||||
[tool.hatch.envs.test-matrix]
|
||||
template = "default"
|
||||
|
||||
[[tool.hatch.envs.test-matrix.matrix]]
|
||||
python = ["3.10", "3.11", "3.12"]
|
||||
|
||||
[tool.hatch.envs.docs]
|
||||
dependencies = [
|
||||
"sphinx", "furo", "sphinxcontrib-apidoc", "myst-parser", "sphinx-autobuild"
|
||||
]
|
||||
|
||||
[tool.hatch.envs.docs.scripts]
|
||||
build = "sphinx-build docs/src docs/build"
|
||||
serve = "sphinx-autobuild --watch src docs/src docs/build"
|
||||
check = "sphinx-build --fail-on-warning docs/src docs/build"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
fix = true
|
||||
exclude = ["build", "dist", "my_project/__init__.py", "my_project/main.py"]
|
||||
target-version = "py310"
|
||||
include = ["src/**", "examples/*.py"]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "B", "Q", "I", "ASYNC"]
|
||||
ignore = ["F401", "E501"]
|
||||
|
||||
[tool.ruff.lint.flake8-tidy-imports]
|
||||
[tool.ruff.lint.flake8-tidy-imports.banned-api]
|
||||
"unittest".msg = "Use `pytest` instead."
|
||||
|
||||
[tool.mypy]
|
||||
files = ["src", "examples", "tests"]
|
||||
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
|
||||
# from https://blog.wolt.com/engineering/2021/09/30/professional-grade-mypy-configuration/
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_any_unimported = true
|
||||
|
||||
[tool.pyright]
|
||||
include = ["src", "tests", "examples"]
|
||||
typeCheckingMode = "strict"
|
||||
reportUnnecessaryIsInstance = false
|
||||
reportMissingTypeStubs = false
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
testpaths = ["tests"]
|
||||
0
python/src/agnext/__init__.py
Normal file
0
python/src/agnext/__init__.py
Normal file
7
python/src/agnext/application/__init__.py
Normal file
7
python/src/agnext/application/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
The :mod:`agnext.application` module provides implementations of core components that are used to compose an application
|
||||
"""
|
||||
|
||||
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
|
||||
__all__ = ["SingleThreadedAgentRuntime"]
|
||||
459
python/src/agnext/application/_single_threaded_agent_runtime.py
Normal file
459
python/src/agnext/application/_single_threaded_agent_runtime.py
Normal file
@@ -0,0 +1,459 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from asyncio import Future
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from ..core import Agent, AgentId, AgentMetadata, AgentProxy, AgentRuntime, AllNamespaces, BaseAgent, CancellationToken
|
||||
from ..core.exceptions import MessageDroppedException
|
||||
from ..core.intervention import DropMessage, InterventionHandler
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
event_logger = logging.getLogger("agnext.events")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PublishMessageEnvelope:
|
||||
"""A message envelope for publishing messages to all agents that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
cancellation_token: CancellationToken
|
||||
sender: AgentId | None
|
||||
namespace: str
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SendMessageEnvelope:
|
||||
"""A message envelope for sending a message to a specific agent that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
sender: AgentId | None
|
||||
recipient: AgentId
|
||||
future: Future[Any]
|
||||
cancellation_token: CancellationToken
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ResponseMessageEnvelope:
|
||||
"""A message envelope for sending a response to a message."""
|
||||
|
||||
message: Any
|
||||
future: Future[Any]
|
||||
sender: AgentId
|
||||
recipient: AgentId | None
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self) -> None:
|
||||
self._count: int = 0
|
||||
self.threadLock = threading.Lock()
|
||||
|
||||
def increment(self) -> None:
|
||||
self.threadLock.acquire()
|
||||
self._count += 1
|
||||
self.threadLock.release()
|
||||
|
||||
def get(self) -> int:
|
||||
return self._count
|
||||
|
||||
def decrement(self) -> None:
|
||||
self.threadLock.acquire()
|
||||
self._count -= 1
|
||||
self.threadLock.release()
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(self, *, before_send: InterventionHandler | None = None) -> None:
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set)
|
||||
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
|
||||
# If empty, then all namespaces are valid for that agent type
|
||||
self._valid_namespaces: Dict[str, Sequence[str]] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._before_send = before_send
|
||||
self._known_namespaces: set[str] = set()
|
||||
self._outstanding_tasks = Counter()
|
||||
|
||||
@property
|
||||
def unprocessed_messages(
|
||||
self,
|
||||
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
|
||||
return self._message_queue
|
||||
|
||||
@property
|
||||
def outstanding_tasks(self) -> int:
|
||||
return self._outstanding_tasks.get()
|
||||
|
||||
@property
|
||||
def _known_agent_names(self) -> Set[str]:
|
||||
return set(self._agent_factories.keys())
|
||||
|
||||
# Returns the response of the message
|
||||
def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any | None]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
if recipient.name not in self._known_agent_names:
|
||||
future.set_exception(Exception("Recipient not found"))
|
||||
|
||||
if sender is not None and sender.namespace != recipient.namespace:
|
||||
raise ValueError("Sender and recipient must be in the same namespace to communicate.")
|
||||
|
||||
self._process_seen_namespace(recipient.namespace)
|
||||
|
||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}")
|
||||
|
||||
self._message_queue.append(
|
||||
SendMessageEnvelope(
|
||||
message=message,
|
||||
recipient=recipient,
|
||||
future=future,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
)
|
||||
)
|
||||
|
||||
return future
|
||||
|
||||
def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
namespace: str | None = None,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[None]:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {message.__dict__}")
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=None,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
|
||||
if sender is None and namespace is None:
|
||||
raise ValueError("Namespace must be provided if sender is not provided.")
|
||||
|
||||
sender_namespace = sender.namespace if sender is not None else None
|
||||
explicit_namespace = namespace
|
||||
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
||||
raise ValueError(
|
||||
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
|
||||
)
|
||||
|
||||
assert explicit_namespace is not None or sender_namespace is not None
|
||||
namespace = cast(str, explicit_namespace or sender_namespace)
|
||||
self._process_seen_namespace(namespace)
|
||||
|
||||
self._message_queue.append(
|
||||
PublishMessageEnvelope(
|
||||
message=message,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
namespace=namespace,
|
||||
)
|
||||
)
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
future.set_result(None)
|
||||
return future
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
state: Dict[str, Dict[str, Any]] = {}
|
||||
for agent_id in self._instantiated_agents:
|
||||
state[str(agent_id)] = dict(self._get_agent(agent_id).save_state())
|
||||
return state
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
for agent_id_str in state:
|
||||
agent_id = AgentId.from_str(agent_id_str)
|
||||
if agent_id.name in self._known_agent_names:
|
||||
self._get_agent(agent_id).load_state(state[str(agent_id)])
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
recipient = message_envelope.recipient
|
||||
# todo: check if recipient is in the known namespaces
|
||||
# assert recipient in self._agents
|
||||
|
||||
try:
|
||||
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
recipient_agent = self._get_agent(recipient)
|
||||
response = await recipient_agent.on_message(
|
||||
message_envelope.message,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
)
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
return
|
||||
|
||||
self._message_queue.append(
|
||||
ResponseMessageEnvelope(
|
||||
message=response,
|
||||
future=message_envelope.future,
|
||||
sender=message_envelope.recipient,
|
||||
recipient=message_envelope.sender,
|
||||
)
|
||||
)
|
||||
self._outstanding_tasks.decrement()
|
||||
|
||||
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||
responses: List[Awaitable[Any]] = []
|
||||
target_namespace = message_envelope.namespace
|
||||
for agent_id in self._per_type_subscribers[(target_namespace, type(message_envelope.message))]:
|
||||
if message_envelope.sender is not None and agent_id.name == message_envelope.sender.name:
|
||||
continue
|
||||
|
||||
sender_agent = self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
|
||||
sender_name = sender_agent.metadata["name"] if sender_agent is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {agent_id.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=agent,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
|
||||
agent = self._get_agent(agent_id)
|
||||
future = agent.on_message(
|
||||
message_envelope.message,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
)
|
||||
responses.append(future)
|
||||
|
||||
try:
|
||||
_all_responses = await asyncio.gather(*responses)
|
||||
except BaseException:
|
||||
logger.error("Error processing publish message", exc_info=True)
|
||||
return
|
||||
|
||||
self._outstanding_tasks.decrement()
|
||||
# TODO if responses are given for a publish
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
||||
content = (
|
||||
message_envelope.message.__dict__
|
||||
if hasattr(message_envelope.message, "__dict__")
|
||||
else message_envelope.message
|
||||
)
|
||||
logger.info(
|
||||
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.name}: {content}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=message_envelope.recipient,
|
||||
# kind=MessageKind.RESPOND,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
self._outstanding_tasks.decrement()
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
|
||||
async def process_next(self) -> None:
|
||||
if len(self._message_queue) == 0:
|
||||
# Yield control to the event loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
return
|
||||
|
||||
message_envelope = self._message_queue.pop(0)
|
||||
|
||||
match message_envelope:
|
||||
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
|
||||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_send(message_envelope))
|
||||
case PublishMessageEnvelope(
|
||||
message=message,
|
||||
sender=sender,
|
||||
):
|
||||
if self._before_send is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_publish(message, sender=sender)
|
||||
except BaseException as e:
|
||||
# TODO: we should raise the intervention exception to the publisher.
|
||||
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
# TODO log message dropped
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_publish(message_envelope))
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._before_send is not None:
|
||||
try:
|
||||
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
|
||||
except BaseException as e:
|
||||
# TODO: should we raise the exception to sender of the response instead?
|
||||
future.set_exception(e)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
asyncio.create_task(self._process_response(message_envelope))
|
||||
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return self._get_agent(agent).metadata
|
||||
|
||||
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
return self._get_agent(agent).save_state()
|
||||
|
||||
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
self._get_agent(agent).load_state(state)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> None:
|
||||
if name in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {name} already exists.")
|
||||
self._agent_factories[name] = agent_factory
|
||||
if valid_namespaces is not AllNamespaces:
|
||||
self._valid_namespaces[name] = cast(Sequence[str], valid_namespaces)
|
||||
else:
|
||||
self._valid_namespaces[name] = []
|
||||
|
||||
# For all already prepared namespaces we need to prepare this agent
|
||||
for namespace in self._known_namespaces:
|
||||
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
||||
def _invoke_agent_factory(
|
||||
self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId
|
||||
) -> T:
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
# TODO: should this be part of the base agent interface?
|
||||
if isinstance(agent, BaseAgent):
|
||||
agent.bind_id(agent_id)
|
||||
agent.bind_runtime(self)
|
||||
|
||||
return agent
|
||||
|
||||
def _type_valid_for_namespace(self, agent_id: AgentId) -> bool:
|
||||
if agent_id.name not in self._agent_factories:
|
||||
raise KeyError(f"Agent with name {agent_id.name} not found.")
|
||||
|
||||
valid_namespaces = self._valid_namespaces[agent_id.name]
|
||||
if len(valid_namespaces) == 0:
|
||||
return True
|
||||
|
||||
return agent_id.namespace in valid_namespaces
|
||||
|
||||
def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
self._process_seen_namespace(agent_id.namespace)
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
if not self._type_valid_for_namespace(agent_id):
|
||||
raise ValueError(f"Agent with name {agent_id.name} not valid for namespace {agent_id.namespace}.")
|
||||
|
||||
if agent_id.name not in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {agent_id.name} not found.")
|
||||
|
||||
agent_factory = self._agent_factories[agent_id.name]
|
||||
|
||||
agent = self._invoke_agent_factory(agent_factory, agent_id)
|
||||
for message_type in agent.metadata["subscriptions"]:
|
||||
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
|
||||
self._instantiated_agents[agent_id] = agent
|
||||
return agent
|
||||
|
||||
def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
return self._get_agent(AgentId(name=name, namespace=namespace)).id
|
||||
|
||||
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
id = self.get(name, namespace=namespace)
|
||||
return AgentProxy(id, self)
|
||||
|
||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||
# to ensure message type subscriptions are set up.
|
||||
def _process_seen_namespace(self, namespace: str) -> None:
|
||||
if namespace in self._known_namespaces:
|
||||
return
|
||||
|
||||
self._known_namespaces.add(namespace)
|
||||
for name in self._known_agent_names:
|
||||
if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)):
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
13
python/src/agnext/application/logging/__init__.py
Normal file
13
python/src/agnext/application/logging/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from ._events import DeliveryStage, LLMCallEvent, MessageEvent, MessageKind
|
||||
from ._llm_usage import LLMUsageTracker
|
||||
|
||||
EVENT_LOGGER_NAME = "agnext.events"
|
||||
|
||||
__all__ = [
|
||||
"LLMCallEvent",
|
||||
"EVENT_LOGGER_NAME",
|
||||
"LLMUsageTracker",
|
||||
"MessageEvent",
|
||||
"MessageKind",
|
||||
"DeliveryStage",
|
||||
]
|
||||
84
python/src/agnext/application/logging/_events.py
Normal file
84
python/src/agnext/application/logging/_events.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
from ...core import Agent
|
||||
|
||||
|
||||
class LLMCallEvent:
|
||||
def __init__(self, *, prompt_tokens: int, completion_tokens: int, **kwargs: Any) -> None:
|
||||
"""To be used by model clients to log the call to the LLM.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): Number of tokens used in the prompt.
|
||||
completion_tokens (int): Number of tokens used in the completion.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agnext.application.logging import LLMCallEvent, EVENT_LOGGER_NAME
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20))
|
||||
|
||||
"""
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["prompt_tokens"] = prompt_tokens
|
||||
self.kwargs["completion_tokens"] = completion_tokens
|
||||
self.kwargs["type"] = "LLMCall"
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["prompt_tokens"])
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["completion_tokens"])
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageKind(Enum):
|
||||
DIRECT = 1
|
||||
PUBLISH = 2
|
||||
RESPOND = 3
|
||||
|
||||
|
||||
class DeliveryStage(Enum):
|
||||
SEND = 1
|
||||
DELIVER = 2
|
||||
|
||||
|
||||
class MessageEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: Any,
|
||||
sender: Agent | None,
|
||||
receiver: Agent | None,
|
||||
kind: MessageKind,
|
||||
delivery_stage: DeliveryStage,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else sender.metadata["name"]
|
||||
self.kwargs["receiver"] = None if receiver is None else receiver.metadata["name"]
|
||||
self.kwargs["kind"] = kind
|
||||
self.kwargs["delivery_stage"] = delivery_stage
|
||||
self.kwargs["type"] = "Message"
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["prompt_tokens"])
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["completion_tokens"])
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
57
python/src/agnext/application/logging/_llm_usage.py
Normal file
57
python/src/agnext/application/logging/_llm_usage.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import logging
|
||||
|
||||
from ._events import LLMCallEvent
|
||||
|
||||
|
||||
class LLMUsageTracker(logging.Handler):
|
||||
def __init__(self) -> None:
|
||||
"""Logging handler that tracks the number of tokens used in the prompt and completion.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agnext.application.logging import LLMUsageTracker, EVENT_LOGGER_NAME
|
||||
|
||||
# Set up the logging configuration to use the custom handler
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.setLevel(logging.INFO)
|
||||
llm_usage = LLMUsageTracker()
|
||||
logger.handlers = [llm_usage]
|
||||
|
||||
# ...
|
||||
|
||||
print(llm_usage.prompt_tokens)
|
||||
print(llm_usage.completion_tokens)
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
|
||||
@property
|
||||
def tokens(self) -> int:
|
||||
return self._prompt_tokens + self._completion_tokens
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return self._prompt_tokens
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return self._completion_tokens
|
||||
|
||||
def reset(self) -> None:
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
"""Emit the log record. To be used by the logging module."""
|
||||
try:
|
||||
# Use the StructuredMessage if the message is an instance of it
|
||||
if isinstance(record.msg, LLMCallEvent):
|
||||
event = record.msg
|
||||
self._prompt_tokens += event.prompt_tokens
|
||||
self._completion_tokens += event.completion_tokens
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
3
python/src/agnext/chat/__init__.py
Normal file
3
python/src/agnext/chat/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
The :mod:`agnext.chat` module is the concrete implementation of multi-agent interaction patterns
|
||||
"""
|
||||
6
python/src/agnext/chat/agents/__init__.py
Normal file
6
python/src/agnext/chat/agents/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .chat_completion_agent import ChatCompletionAgent
|
||||
from .image_generation_agent import ImageGenerationAgent
|
||||
from .oai_assistant import OpenAIAssistantAgent
|
||||
from .user_proxy import UserProxyAgent
|
||||
|
||||
__all__ = ["ChatCompletionAgent", "OpenAIAssistantAgent", "UserProxyAgent", "ImageGenerationAgent"]
|
||||
264
python/src/agnext/chat/agents/chat_completion_agent.py
Normal file
264
python/src/agnext/chat/agents/chat_completion_agent.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
|
||||
|
||||
from ...components import (
|
||||
FunctionCall,
|
||||
TypeRoutedAgent,
|
||||
message_handler,
|
||||
)
|
||||
from ...components.models import (
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from ...components.tools import Tool
|
||||
from ...core import AgentId, CancellationToken
|
||||
from ..memory import ChatMemory
|
||||
from ..types import (
|
||||
FunctionCallMessage,
|
||||
Message,
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
Reset,
|
||||
RespondNow,
|
||||
ResponseFormat,
|
||||
TextMessage,
|
||||
ToolApprovalRequest,
|
||||
ToolApprovalResponse,
|
||||
)
|
||||
from ..utils import convert_messages_to_llm_messages
|
||||
|
||||
|
||||
class ChatCompletionAgent(TypeRoutedAgent):
|
||||
"""An agent implementation that uses the ChatCompletion API to gnenerate
|
||||
responses and execute tools.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
system_messages (List[SystemMessage]): The system messages to use for
|
||||
the ChatCompletion API.
|
||||
memory (ChatMemory): The memory to store and retrieve messages.
|
||||
model_client (ChatCompletionClient): The client to use for the
|
||||
ChatCompletion API.
|
||||
tools (Sequence[Tool], optional): The tools used by the agent. Defaults
|
||||
to []. If no tools are provided, the agent cannot handle tool calls.
|
||||
If tools are provided, and the response from the model is a list of
|
||||
tool calls, the agent will call itselfs with the tool calls until it
|
||||
gets a response that is not a list of tool calls, and then use that
|
||||
response as the final response.
|
||||
tool_approver (Agent | None, optional): The agent that approves tool
|
||||
calls. Defaults to None. If no tool approver is provided, the agent
|
||||
will execute the tools without approval. If a tool approver is
|
||||
provided, the agent will send a request to the tool approver before
|
||||
executing the tools.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
system_messages: List[SystemMessage],
|
||||
memory: ChatMemory,
|
||||
model_client: ChatCompletionClient,
|
||||
tools: Sequence[Tool] = [],
|
||||
tool_approver: AgentId | None = None,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._description = description
|
||||
self._system_messages = system_messages
|
||||
self._client = model_client
|
||||
self._memory = memory
|
||||
self._tools = tools
|
||||
self._tool_approver = tool_approver
|
||||
|
||||
@message_handler()
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a text message. This method adds the message to the memory and
|
||||
does not generate any message."""
|
||||
# Add a user message.
|
||||
await self._memory.add_message(message)
|
||||
|
||||
@message_handler()
|
||||
async def on_multi_modal_message(self, message: MultiModalMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a multimodal message. This method adds the message to the memory
|
||||
and does not generate any message."""
|
||||
# Add a user message.
|
||||
await self._memory.add_message(message)
|
||||
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a reset message. This method clears the memory."""
|
||||
# Reset the chat messages.
|
||||
await self._memory.clear()
|
||||
|
||||
@message_handler()
|
||||
async def on_respond_now(
|
||||
self, message: RespondNow, cancellation_token: CancellationToken
|
||||
) -> TextMessage | FunctionCallMessage:
|
||||
"""Handle a respond now message. This method generates a response and
|
||||
returns it to the sender."""
|
||||
# Generate a response.
|
||||
response = await self._generate_response(message.response_format, cancellation_token)
|
||||
|
||||
# Return the response.
|
||||
return response
|
||||
|
||||
@message_handler()
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a publish now message. This method generates a response and
|
||||
publishes it."""
|
||||
# Generate a response.
|
||||
response = await self._generate_response(message.response_format, cancellation_token)
|
||||
|
||||
# Publish the response.
|
||||
await self.publish_message(response)
|
||||
|
||||
@message_handler()
|
||||
async def on_tool_call_message(
|
||||
self, message: FunctionCallMessage, cancellation_token: CancellationToken
|
||||
) -> FunctionExecutionResultMessage:
|
||||
"""Handle a tool call message. This method executes the tools and
|
||||
returns the results."""
|
||||
if len(self._tools) == 0:
|
||||
raise ValueError("No tools available")
|
||||
|
||||
# Add a tool call message.
|
||||
await self._memory.add_message(message)
|
||||
|
||||
# Execute the tool calls.
|
||||
results: List[FunctionExecutionResult] = []
|
||||
execution_futures: List[Coroutine[Any, Any, Tuple[str, str]]] = []
|
||||
for function_call in message.content:
|
||||
# Parse the arguments.
|
||||
try:
|
||||
arguments = json.loads(function_call.arguments)
|
||||
except json.JSONDecodeError:
|
||||
results.append(
|
||||
FunctionExecutionResult(
|
||||
content=f"Error: Could not parse arguments for function {function_call.name}.",
|
||||
call_id=function_call.id,
|
||||
)
|
||||
)
|
||||
continue
|
||||
# Execute the function.
|
||||
future = self._execute_function(
|
||||
function_call.name,
|
||||
arguments,
|
||||
function_call.id,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Append the async result.
|
||||
execution_futures.append(future)
|
||||
if execution_futures:
|
||||
# Wait for all async results.
|
||||
execution_results = await asyncio.gather(*execution_futures)
|
||||
# Add the results.
|
||||
for execution_result, call_id in execution_results:
|
||||
results.append(FunctionExecutionResult(content=execution_result, call_id=call_id))
|
||||
|
||||
# Create a tool call result message.
|
||||
tool_call_result_msg = FunctionExecutionResultMessage(content=results)
|
||||
|
||||
# Add tool call result message.
|
||||
await self._memory.add_message(tool_call_result_msg)
|
||||
|
||||
# Return the results.
|
||||
return tool_call_result_msg
|
||||
|
||||
async def _generate_response(
|
||||
self,
|
||||
response_format: ResponseFormat,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> TextMessage | FunctionCallMessage:
|
||||
# Get a response from the model.
|
||||
hisorical_messages = await self._memory.get_messages()
|
||||
response = await self._client.create(
|
||||
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["name"]),
|
||||
tools=self._tools,
|
||||
json_output=response_format == ResponseFormat.json_object,
|
||||
)
|
||||
|
||||
# If the agent has function executor, and the response is a list of
|
||||
# tool calls, iterate with itself until we get a response that is not a
|
||||
# list of tool calls.
|
||||
while (
|
||||
len(self._tools) > 0
|
||||
and isinstance(response.content, list)
|
||||
and all(isinstance(x, FunctionCall) for x in response.content)
|
||||
):
|
||||
# Send a function call message to itself.
|
||||
response = await self.send_message(
|
||||
message=FunctionCallMessage(content=response.content, source=self.metadata["name"]),
|
||||
recipient=self.id,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Make an assistant message from the response.
|
||||
hisorical_messages = await self._memory.get_messages()
|
||||
response = await self._client.create(
|
||||
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["name"]),
|
||||
tools=self._tools,
|
||||
json_output=response_format == ResponseFormat.json_object,
|
||||
)
|
||||
|
||||
final_response: Message
|
||||
if isinstance(response.content, str):
|
||||
# If the response is a string, return a text message.
|
||||
final_response = TextMessage(content=response.content, source=self.metadata["name"])
|
||||
elif isinstance(response.content, list) and all(isinstance(x, FunctionCall) for x in response.content):
|
||||
# If the response is a list of function calls, return a function call message.
|
||||
final_response = FunctionCallMessage(content=response.content, source=self.metadata["name"])
|
||||
else:
|
||||
raise ValueError(f"Unexpected response: {response.content}")
|
||||
|
||||
# Add the response to the chat messages.
|
||||
await self._memory.add_message(final_response)
|
||||
|
||||
return final_response
|
||||
|
||||
async def _execute_function(
|
||||
self,
|
||||
name: str,
|
||||
args: Dict[str, Any],
|
||||
call_id: str,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Tuple[str, str]:
|
||||
# Find tool
|
||||
tool = next((t for t in self._tools if t.name == name), None)
|
||||
if tool is None:
|
||||
return (f"Error: tool {name} not found.", call_id)
|
||||
|
||||
# Check if the tool needs approval
|
||||
if self._tool_approver is not None:
|
||||
# Send a tool approval request.
|
||||
approval_request = ToolApprovalRequest(
|
||||
tool_call=FunctionCall(id=call_id, arguments=json.dumps(args), name=name)
|
||||
)
|
||||
approval_response = await self.send_message(
|
||||
message=approval_request,
|
||||
recipient=self._tool_approver,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
if not isinstance(approval_response, ToolApprovalResponse):
|
||||
raise ValueError(f"Expecting {ToolApprovalResponse.__name__}, received: {type(approval_response)}")
|
||||
if not approval_response.approved:
|
||||
return (f"Error: tool {name} approved, reason: {approval_response.reason}", call_id)
|
||||
|
||||
try:
|
||||
result = await tool.run_json(args, cancellation_token)
|
||||
result_as_str = tool.return_value_as_string(result)
|
||||
except Exception as e:
|
||||
result_as_str = f"Error: {str(e)}"
|
||||
return (result_as_str, call_id)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._memory.save_state(),
|
||||
"system_messages": self._system_messages,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._memory.load_state(state["memory"])
|
||||
self._system_messages = state["system_messages"]
|
||||
62
python/src/agnext/chat/agents/image_generation_agent.py
Normal file
62
python/src/agnext/chat/agents/image_generation_agent.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Literal
|
||||
|
||||
import openai
|
||||
|
||||
from ...components import (
|
||||
Image,
|
||||
TypeRoutedAgent,
|
||||
message_handler,
|
||||
)
|
||||
from ...core import CancellationToken
|
||||
from ..memory import ChatMemory
|
||||
from ..types import (
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
Reset,
|
||||
TextMessage,
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationAgent(TypeRoutedAgent):
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
memory: ChatMemory,
|
||||
client: openai.AsyncClient,
|
||||
model: Literal["dall-e-2", "dall-e-3"] = "dall-e-2",
|
||||
):
|
||||
super().__init__(description)
|
||||
self._client = client
|
||||
self._model = model
|
||||
self._memory = memory
|
||||
|
||||
@message_handler
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
|
||||
await self._memory.add_message(message)
|
||||
|
||||
@message_handler
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
await self._memory.clear()
|
||||
|
||||
@message_handler
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
|
||||
response = await self._generate_response(cancellation_token)
|
||||
self.publish_message(response)
|
||||
|
||||
async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage:
|
||||
messages = await self._memory.get_messages()
|
||||
if len(messages) == 0:
|
||||
return MultiModalMessage(
|
||||
content=["I need more information to generate an image."], source=self.metadata["name"]
|
||||
)
|
||||
prompt = ""
|
||||
for m in messages:
|
||||
assert isinstance(m, TextMessage)
|
||||
prompt += m.content + "\n"
|
||||
prompt.strip()
|
||||
response = await self._client.images.generate(model=self._model, prompt=prompt, response_format="b64_json")
|
||||
assert len(response.data) > 0 and response.data[0].b64_json is not None
|
||||
# Create a MultiModalMessage with the image.
|
||||
image = Image.from_base64(response.data[0].b64_json)
|
||||
multi_modal_message = MultiModalMessage(content=[image], source=self.metadata["name"])
|
||||
return multi_modal_message
|
||||
134
python/src/agnext/chat/agents/oai_assistant.py
Normal file
134
python/src/agnext/chat/agents/oai_assistant.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from typing import Any, Callable, List, Mapping
|
||||
|
||||
import openai
|
||||
from openai import AsyncAssistantEventHandler
|
||||
from openai.types.beta import AssistantResponseFormatParam
|
||||
|
||||
from ...components import TypeRoutedAgent, message_handler
|
||||
from ...core import CancellationToken
|
||||
from ..types import PublishNow, Reset, RespondNow, ResponseFormat, TextMessage
|
||||
|
||||
|
||||
class OpenAIAssistantAgent(TypeRoutedAgent):
|
||||
"""An agent implementation that uses the OpenAI Assistant API to generate
|
||||
responses.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
client (openai.AsyncClient): The client to use for the OpenAI API.
|
||||
assistant_id (str): The assistant ID to use for the OpenAI API.
|
||||
thread_id (str): The thread ID to use for the OpenAI API.
|
||||
assistant_event_handler_factory (Callable[[], AsyncAssistantEventHandler], optional):
|
||||
A factory function to create an async assistant event handler. Defaults to None.
|
||||
If provided, the agent will use the streaming mode with the event handler.
|
||||
If not provided, the agent will use the blocking mode to generate responses.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
client: openai.AsyncClient,
|
||||
assistant_id: str,
|
||||
thread_id: str,
|
||||
assistant_event_handler_factory: Callable[[], AsyncAssistantEventHandler] | None = None,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._client = client
|
||||
self._assistant_id = assistant_id
|
||||
self._thread_id = thread_id
|
||||
self._assistant_event_handler_factory = assistant_event_handler_factory
|
||||
|
||||
@message_handler()
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a text message. This method adds the message to the thread."""
|
||||
# Save the message to the thread.
|
||||
_ = await self._client.beta.threads.messages.create(
|
||||
thread_id=self._thread_id,
|
||||
content=message.content,
|
||||
role="user",
|
||||
metadata={"sender": message.source},
|
||||
)
|
||||
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a reset message. This method deletes all messages in the thread."""
|
||||
# Get all messages in this thread.
|
||||
all_msgs: List[str] = []
|
||||
while True:
|
||||
if not all_msgs:
|
||||
msgs = await self._client.beta.threads.messages.list(self._thread_id)
|
||||
else:
|
||||
msgs = await self._client.beta.threads.messages.list(self._thread_id, after=all_msgs[-1])
|
||||
for msg in msgs.data:
|
||||
all_msgs.append(msg.id)
|
||||
if not msgs.has_next_page():
|
||||
break
|
||||
# Delete all the messages.
|
||||
for msg_id in all_msgs:
|
||||
status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
|
||||
assert status.deleted is True
|
||||
|
||||
@message_handler()
|
||||
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
|
||||
"""Handle a respond now message. This method generates a response and returns it to the sender."""
|
||||
return await self._generate_response(message.response_format, cancellation_token)
|
||||
|
||||
@message_handler()
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a publish now message. This method generates a response and publishes it."""
|
||||
response = await self._generate_response(message.response_format, cancellation_token)
|
||||
await self.publish_message(response)
|
||||
|
||||
async def _generate_response(
|
||||
self, requested_response_format: ResponseFormat, cancellation_token: CancellationToken
|
||||
) -> TextMessage:
|
||||
# Handle response format.
|
||||
if requested_response_format == ResponseFormat.json_object:
|
||||
response_format = AssistantResponseFormatParam(type="json_object")
|
||||
else:
|
||||
response_format = AssistantResponseFormatParam(type="text")
|
||||
|
||||
if self._assistant_event_handler_factory is not None:
|
||||
# Use event handler and streaming mode if available.
|
||||
async with self._client.beta.threads.runs.stream(
|
||||
thread_id=self._thread_id,
|
||||
assistant_id=self._assistant_id,
|
||||
event_handler=self._assistant_event_handler_factory(),
|
||||
response_format=response_format,
|
||||
) as stream:
|
||||
run = await stream.get_final_run()
|
||||
else:
|
||||
# Use blocking mode.
|
||||
run = await self._client.beta.threads.runs.create(
|
||||
thread_id=self._thread_id,
|
||||
assistant_id=self._assistant_id,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
if run.status != "completed":
|
||||
# TODO: handle other statuses.
|
||||
raise ValueError(f"Run did not complete successfully: {run}")
|
||||
|
||||
# Get the last message from the run.
|
||||
response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
|
||||
last_message_content = response.data[0].content
|
||||
|
||||
# TODO: handle array of content.
|
||||
text_content = [content for content in last_message_content if content.type == "text"]
|
||||
if not text_content:
|
||||
raise ValueError(f"Expected text content in the last message: {last_message_content}")
|
||||
|
||||
# TODO: handle multiple text content.
|
||||
return TextMessage(content=text_content[0].text.value, source=self.metadata["name"])
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"assistant_id": self._assistant_id,
|
||||
"thread_id": self._thread_id,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._assistant_id = state["assistant_id"]
|
||||
self._thread_id = state["thread_id"]
|
||||
32
python/src/agnext/chat/agents/user_proxy.py
Normal file
32
python/src/agnext/chat/agents/user_proxy.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import asyncio
|
||||
|
||||
from ...components import TypeRoutedAgent, message_handler
|
||||
from ...core import CancellationToken
|
||||
from ..types import PublishNow, TextMessage
|
||||
|
||||
|
||||
class UserProxyAgent(TypeRoutedAgent):
|
||||
"""An agent that proxies user input from the console. Override the `get_user_input`
|
||||
method to customize how user input is retrieved.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
user_input_prompt (str): The console prompt to show to the user when asking for input.
|
||||
"""
|
||||
|
||||
def __init__(self, description: str, user_input_prompt: str) -> None:
|
||||
super().__init__(description)
|
||||
self._user_input_prompt = user_input_prompt
|
||||
|
||||
@message_handler()
|
||||
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a publish now message. This method prompts the user for input, then publishes it."""
|
||||
user_input = await self.get_user_input(self._user_input_prompt)
|
||||
await self.publish_message(TextMessage(content=user_input, source=self.metadata["name"]))
|
||||
|
||||
async def get_user_input(self, prompt: str) -> str:
|
||||
"""Get user input from the console. Override this method to customize how user input is retrieved."""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, input, prompt)
|
||||
5
python/src/agnext/chat/memory/__init__.py
Normal file
5
python/src/agnext/chat/memory/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from ._base import ChatMemory
|
||||
from ._buffered import BufferedChatMemory
|
||||
from ._head_and_tail import HeadAndTailChatMemory
|
||||
|
||||
__all__ = ["ChatMemory", "BufferedChatMemory", "HeadAndTailChatMemory"]
|
||||
19
python/src/agnext/chat/memory/_base.py
Normal file
19
python/src/agnext/chat/memory/_base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Any, List, Mapping, Protocol
|
||||
|
||||
from ..types import Message
|
||||
|
||||
|
||||
class ChatMemory(Protocol):
|
||||
"""A protocol for defining the interface of a chat memory. A chat memory
|
||||
lets agents to store and retrieve messages. It can be implemented with
|
||||
different memory recall strategies."""
|
||||
|
||||
async def add_message(self, message: Message) -> None: ...
|
||||
|
||||
async def get_messages(self) -> List[Message]: ...
|
||||
|
||||
async def clear(self) -> None: ...
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None: ...
|
||||
46
python/src/agnext/chat/memory/_buffered.py
Normal file
46
python/src/agnext/chat/memory/_buffered.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from ...components.models import FunctionExecutionResultMessage
|
||||
from ..types import Message
|
||||
from ._base import ChatMemory
|
||||
|
||||
|
||||
class BufferedChatMemory(ChatMemory):
|
||||
"""A buffered chat memory that keeps a view of the last n messages,
|
||||
where n is the buffer size. The buffer size is set at initialization.
|
||||
|
||||
Args:
|
||||
buffer_size (int): The size of the buffer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int) -> None:
|
||||
self._messages: List[Message] = []
|
||||
self._buffer_size = buffer_size
|
||||
|
||||
async def add_message(self, message: Message) -> None:
|
||||
"""Add a message to the memory."""
|
||||
self._messages.append(message)
|
||||
|
||||
async def get_messages(self) -> List[Message]:
|
||||
"""Get at most `buffer_size` recent messages."""
|
||||
messages = self._messages[-self._buffer_size :]
|
||||
# Handle the first message is a function call result message.
|
||||
if messages and isinstance(messages[0], FunctionExecutionResultMessage):
|
||||
# Remove the first message from the list.
|
||||
messages = messages[1:]
|
||||
return messages
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the message memory."""
|
||||
self._messages = []
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"messages": [message for message in self._messages],
|
||||
"buffer_size": self._buffer_size,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._messages = state["messages"]
|
||||
self._buffer_size = state["buffer_size"]
|
||||
66
python/src/agnext/chat/memory/_head_and_tail.py
Normal file
66
python/src/agnext/chat/memory/_head_and_tail.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from ...components.models import FunctionExecutionResultMessage
|
||||
from ..types import FunctionCallMessage, Message, TextMessage
|
||||
from ._base import ChatMemory
|
||||
|
||||
|
||||
class HeadAndTailChatMemory(ChatMemory):
|
||||
"""A chat memory that keeps a view of the first n and last m messages,
|
||||
where n is the head size and m is the tail size. The head and tail sizes
|
||||
are set at initialization.
|
||||
|
||||
Args:
|
||||
head_size (int): The size of the head.
|
||||
tail_size (int): The size of the tail.
|
||||
"""
|
||||
|
||||
def __init__(self, head_size: int, tail_size: int) -> None:
|
||||
self._messages: List[Message] = []
|
||||
self._head_size = head_size
|
||||
self._tail_size = tail_size
|
||||
|
||||
async def add_message(self, message: Message) -> None:
|
||||
"""Add a message to the memory."""
|
||||
self._messages.append(message)
|
||||
|
||||
async def get_messages(self) -> List[Message]:
|
||||
"""Get at most `head_size` recent messages and `tail_size` oldest messages."""
|
||||
head_messages = self._messages[: self._head_size]
|
||||
# Handle the last message is a function call message.
|
||||
if head_messages and isinstance(head_messages[-1], FunctionCallMessage):
|
||||
# Remove the last message from the head.
|
||||
head_messages = head_messages[:-1]
|
||||
|
||||
tail_messages = self._messages[-self._tail_size :]
|
||||
# Handle the first message is a function call result message.
|
||||
if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage):
|
||||
# Remove the first message from the tail.
|
||||
tail_messages = tail_messages[1:]
|
||||
|
||||
num_skipped = len(self._messages) - self._head_size - self._tail_size
|
||||
if num_skipped <= 0:
|
||||
# If there are not enough messages to fill the head and tail,
|
||||
# return all messages.
|
||||
return self._messages
|
||||
|
||||
placeholder_messages = [TextMessage(content=f"Skipped {num_skipped} messages.", source="System")]
|
||||
return head_messages + placeholder_messages + tail_messages
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the message memory."""
|
||||
self._messages = []
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"messages": [message for message in self._messages],
|
||||
"head_size": self._head_size,
|
||||
"tail_size": self._tail_size,
|
||||
"placeholder_message": self._placeholder_message,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._messages = state["messages"]
|
||||
self._head_size = state["head_size"]
|
||||
self._tail_size = state["tail_size"]
|
||||
self._placeholder_message = state["placeholder_message"]
|
||||
3
python/src/agnext/chat/patterns/__init__.py
Normal file
3
python/src/agnext/chat/patterns/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .group_chat_manager import GroupChatManager
|
||||
|
||||
__all__ = ["GroupChatManager"]
|
||||
154
python/src/agnext/chat/patterns/group_chat_manager.py
Normal file
154
python/src/agnext/chat/patterns/group_chat_manager.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import logging
|
||||
from typing import Any, Callable, List, Mapping
|
||||
|
||||
from ...components import TypeRoutedAgent, message_handler
|
||||
from ...components.models import ChatCompletionClient
|
||||
from ...core import AgentId, AgentProxy, AgentRuntime, CancellationToken
|
||||
from ..memory import ChatMemory
|
||||
from ..types import (
|
||||
MultiModalMessage,
|
||||
PublishNow,
|
||||
Reset,
|
||||
TextMessage,
|
||||
)
|
||||
from .group_chat_utils import select_speaker
|
||||
|
||||
logger = logging.getLogger("agnext.events")
|
||||
|
||||
|
||||
class GroupChatManager(TypeRoutedAgent):
|
||||
"""An agent that manages a group chat through event-driven orchestration.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
participants (List[AgentId]): The list of participants in the group chat.
|
||||
memory (ChatMemory): The memory to store and retrieve messages.
|
||||
model_client (ChatCompletionClient, optional): The client to use for the model.
|
||||
If provided, the agent will use the model to select the next speaker.
|
||||
If not provided, the agent will select the next speaker from the list of participants
|
||||
according to the order given.
|
||||
termination_word (str, optional): The word that terminates the group chat. Defaults to "TERMINATE".
|
||||
transitions (Mapping[AgentId, List[AgentId]], optional): The transitions between agents.
|
||||
Keys are the agents, and values are the list of agents that can follow the key agent. Defaults to {}.
|
||||
If provided, the group chat manager will use the transitions to select the next speaker.
|
||||
If a transition is not provided for an agent, the choices fallback to all participants.
|
||||
If no model client is provided, a transition must have a single value.
|
||||
on_message_received (Callable[[TextMessage], None], optional): A custom handler to call when a message is received.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
runtime: AgentRuntime,
|
||||
participants: List[AgentId],
|
||||
memory: ChatMemory,
|
||||
model_client: ChatCompletionClient | None = None,
|
||||
termination_word: str = "TERMINATE",
|
||||
transitions: Mapping[AgentId, List[AgentId]] = {},
|
||||
on_message_received: Callable[[TextMessage | MultiModalMessage], None] | None = None,
|
||||
):
|
||||
super().__init__(description)
|
||||
self._memory = memory
|
||||
self._client = model_client
|
||||
self._participants = participants
|
||||
self._participant_proxies = dict((p, AgentProxy(p, runtime)) for p in participants)
|
||||
self._termination_word = termination_word
|
||||
for key, value in transitions.items():
|
||||
if not value:
|
||||
# Make sure no empty transitions are provided.
|
||||
raise ValueError(f"Empty transition list provided for {key.name}.")
|
||||
if key not in participants:
|
||||
# Make sure all keys are in the list of participants.
|
||||
raise ValueError(f"Transition key {key.name} not found in participants.")
|
||||
for v in value:
|
||||
if v not in participants:
|
||||
# Make sure all values are in the list of participants.
|
||||
raise ValueError(f"Transition value {v.name} not found in participants.")
|
||||
if self._client is None:
|
||||
# Make sure there is only one transition for each key if no model client is provided.
|
||||
if len(value) > 1:
|
||||
raise ValueError(f"Multiple transitions provided for {key.name} but no model client is provided.")
|
||||
self._tranistions = transitions
|
||||
self._on_message_received = on_message_received
|
||||
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a reset message. This method clears the memory."""
|
||||
await self._memory.clear()
|
||||
|
||||
@message_handler()
|
||||
async def on_new_message(
|
||||
self, message: TextMessage | MultiModalMessage, cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
"""Handle a message. This method adds the message to the memory, selects the next speaker,
|
||||
and sends a message to the selected speaker to publish a response."""
|
||||
# Call the custom on_message_received handler if provided.
|
||||
if self._on_message_received is not None:
|
||||
self._on_message_received(message)
|
||||
|
||||
# Check if the message contains the termination word.
|
||||
if isinstance(message, TextMessage) and self._termination_word in message.content:
|
||||
# Terminate the group chat by not selecting the next speaker.
|
||||
return
|
||||
|
||||
# Save the message to chat memory.
|
||||
await self._memory.add_message(message)
|
||||
|
||||
# Get the last speaker.
|
||||
last_speaker_name = message.source
|
||||
last_speaker_index = next((i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None)
|
||||
|
||||
# Get the candidates for the next speaker.
|
||||
if last_speaker_index is not None:
|
||||
logger.debug(f"Last speaker: {last_speaker_name}")
|
||||
last_speaker = self._participants[last_speaker_index]
|
||||
if self._tranistions.get(last_speaker) is not None:
|
||||
candidates = [c for c in self._participants if c in self._tranistions[last_speaker]]
|
||||
else:
|
||||
candidates = self._participants
|
||||
else:
|
||||
candidates = self._participants
|
||||
logger.debug(f"Group chat manager next speaker candidates: {[c.name for c in candidates]}")
|
||||
|
||||
# Select speaker.
|
||||
if len(candidates) == 0:
|
||||
speaker = None
|
||||
elif len(candidates) == 1:
|
||||
speaker = candidates[0]
|
||||
else:
|
||||
# More than one candidate, select the next speaker.
|
||||
if self._client is None:
|
||||
# If no model client is provided, candidates must be the list of participants.
|
||||
assert candidates == self._participants
|
||||
# If no model client is provided, select the next speaker from the list of participants.
|
||||
if last_speaker_index is not None:
|
||||
next_speaker_index = (last_speaker_index + 1) % len(self._participants)
|
||||
speaker = self._participants[next_speaker_index]
|
||||
else:
|
||||
# If no last speaker, select the first speaker.
|
||||
speaker = candidates[0]
|
||||
else:
|
||||
# If a model client is provided, select the speaker based on the transitions and the model.
|
||||
speaker_index = await select_speaker(
|
||||
self._memory, self._client, [self._participant_proxies[c] for c in candidates]
|
||||
)
|
||||
speaker = candidates[speaker_index]
|
||||
|
||||
logger.debug(f"Group chat manager selected speaker: {speaker.name if speaker is not None else None}")
|
||||
|
||||
if speaker is not None:
|
||||
# Send the message to the selected speaker to ask it to publish a response.
|
||||
await self.send_message(PublishNow(), speaker)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
"memory": self._memory.save_state(),
|
||||
"termination_word": self._termination_word,
|
||||
}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self._memory.load_state(state["memory"])
|
||||
self._termination_word = state["termination_word"]
|
||||
81
python/src/agnext/chat/patterns/group_chat_utils.py
Normal file
81
python/src/agnext/chat/patterns/group_chat_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Credit to the original authors: https://github.com/microsoft/autogen/blob/main/autogen/agentchat/groupchat.py"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from ...components.models import ChatCompletionClient, SystemMessage
|
||||
from ...core import AgentProxy
|
||||
from ..memory import ChatMemory
|
||||
from ..types import TextMessage
|
||||
|
||||
|
||||
async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agents: List[AgentProxy]) -> int:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client."""
|
||||
# TODO: Handle multi-modal messages.
|
||||
|
||||
# Construct formated current message history.
|
||||
history_messages: List[str] = []
|
||||
for msg in await memory.get_messages():
|
||||
assert isinstance(msg, TextMessage)
|
||||
history_messages.append(f"{msg.source}: {msg.content}")
|
||||
history = "\n".join(history_messages)
|
||||
|
||||
# Construct agent roles.
|
||||
roles = "\n".join([f"{agent.metadata['name']}: {agent.metadata['description']}".strip() for agent in agents])
|
||||
|
||||
# Construct agent list.
|
||||
participants = str([agent.metadata["name"] for agent in agents])
|
||||
|
||||
# Select the next speaker.
|
||||
select_speaker_prompt = f"""You are in a role play game. The following roles are available:
|
||||
{roles}.
|
||||
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
|
||||
{history}
|
||||
|
||||
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
"""
|
||||
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
|
||||
response = await client.create(messages=select_speaker_messages)
|
||||
assert isinstance(response.content, str)
|
||||
mentions = mentioned_agents(response.content, agents)
|
||||
if len(mentions) != 1:
|
||||
raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}")
|
||||
agent_name = list(mentions.keys())[0]
|
||||
agent_index = next((i for i, agent in enumerate(agents) if agent.metadata["name"] == agent_name), None)
|
||||
assert agent_index is not None
|
||||
return agent_index
|
||||
|
||||
|
||||
def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
|
||||
"""Counts the number of times each agent is mentioned in the provided message content.
|
||||
Agent names will match under any of the following conditions (all case-sensitive):
|
||||
- Exact name match
|
||||
- If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer')
|
||||
- If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer')
|
||||
|
||||
Args:
|
||||
message_content (Union[str, List]): The content of the message, either as a single string or a list of strings.
|
||||
agents (List[Agent]): A list of Agent objects, each having a 'name' attribute to be searched in the message content.
|
||||
|
||||
Returns:
|
||||
Dict: a counter for mentioned agents.
|
||||
"""
|
||||
mentions: Dict[str, int] = dict()
|
||||
for agent in agents:
|
||||
# Finds agent mentions, taking word boundaries into account,
|
||||
# accommodates escaping underscores and underscores as spaces
|
||||
name = agent.metadata["name"]
|
||||
regex = (
|
||||
r"(?<=\W)("
|
||||
+ re.escape(name)
|
||||
+ r"|"
|
||||
+ re.escape(name.replace("_", " "))
|
||||
+ r"|"
|
||||
+ re.escape(name.replace("_", r"\_"))
|
||||
+ r")(?=\W)"
|
||||
)
|
||||
count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching
|
||||
if count > 0:
|
||||
mentions[name] = count
|
||||
return mentions
|
||||
392
python/src/agnext/chat/patterns/orchestrator_chat.py
Normal file
392
python/src/agnext/chat/patterns/orchestrator_chat.py
Normal file
@@ -0,0 +1,392 @@
|
||||
import json
|
||||
from typing import Any, Sequence, Tuple
|
||||
|
||||
from ...components import TypeRoutedAgent, message_handler
|
||||
from ...core import AgentId, AgentRuntime, CancellationToken
|
||||
from ..types import Reset, RespondNow, ResponseFormat, TextMessage
|
||||
|
||||
__all__ = ["OrchestratorChat"]
|
||||
|
||||
|
||||
class OrchestratorChat(TypeRoutedAgent):
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
runtime: AgentRuntime,
|
||||
orchestrator: AgentId,
|
||||
planner: AgentId,
|
||||
specialists: Sequence[AgentId],
|
||||
max_turns: int = 30,
|
||||
max_stalled_turns_before_retry: int = 2,
|
||||
max_retry_attempts: int = 1,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._orchestrator = orchestrator
|
||||
self._planner = planner
|
||||
self._specialists = specialists
|
||||
self._max_turns = max_turns
|
||||
self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
|
||||
self._max_retry_attempts_before_educated_guess = max_retry_attempts
|
||||
|
||||
@property
|
||||
def children(self) -> Sequence[AgentId]:
|
||||
return list(self._specialists) + [self._orchestrator, self._planner]
|
||||
|
||||
@message_handler()
|
||||
async def on_text_message(
|
||||
self,
|
||||
message: TextMessage,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> TextMessage:
|
||||
# A task is received.
|
||||
task = message.content
|
||||
|
||||
# Prepare the task.
|
||||
team, names, facts, plan = await self._prepare_task(task, message.source)
|
||||
|
||||
# Main loop.
|
||||
total_turns = 0
|
||||
retry_attempts = 0
|
||||
while total_turns < self._max_turns:
|
||||
# Reset all agents.
|
||||
for agent in [*self._specialists, self._orchestrator]:
|
||||
await self.send_message(Reset(), agent)
|
||||
|
||||
# Create the task specs.
|
||||
task_specs = f"""
|
||||
We are working to address the following user request:
|
||||
|
||||
{task}
|
||||
|
||||
|
||||
To answer this request we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
Some additional points to consider:
|
||||
|
||||
{facts}
|
||||
|
||||
{plan}
|
||||
""".strip()
|
||||
|
||||
# Send the task specs to the orchestrator and specialists.
|
||||
for agent in [*self._specialists, self._orchestrator]:
|
||||
await self.send_message(TextMessage(content=task_specs, source=self.metadata["name"]), agent)
|
||||
|
||||
# Inner loop.
|
||||
stalled_turns = 0
|
||||
while total_turns < self._max_turns:
|
||||
# Reflect on the task.
|
||||
data = await self._reflect_on_task(task, team, names, message.source)
|
||||
|
||||
# Check if the request is satisfied.
|
||||
if data["is_request_satisfied"]["answer"]:
|
||||
return TextMessage(
|
||||
content=f"The task has been successfully addressed. {data['is_request_satisfied']['reason']}",
|
||||
source=self.metadata["name"],
|
||||
)
|
||||
|
||||
# Update stalled turns.
|
||||
if data["is_progress_being_made"]["answer"]:
|
||||
stalled_turns = max(0, stalled_turns - 1)
|
||||
else:
|
||||
stalled_turns += 1
|
||||
|
||||
# Handle retry.
|
||||
if stalled_turns > self._max_stalled_turns_before_retry:
|
||||
# In a retry, we need to rewrite the facts and the plan.
|
||||
|
||||
# Rewrite the facts.
|
||||
facts = await self._rewrite_facts(facts, message.source)
|
||||
|
||||
# Increment the retry attempts.
|
||||
retry_attempts += 1
|
||||
|
||||
# Check if we should just guess.
|
||||
if retry_attempts > self._max_retry_attempts_before_educated_guess:
|
||||
# Make an educated guess.
|
||||
educated_guess = await self._educated_guess(facts, message.source)
|
||||
if educated_guess["has_educated_guesses"]["answer"]:
|
||||
return TextMessage(
|
||||
content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}",
|
||||
source=self.metadata["name"],
|
||||
)
|
||||
|
||||
# Come up with a new plan.
|
||||
plan = await self._rewrite_plan(team, message.source)
|
||||
|
||||
# Exit the inner loop.
|
||||
break
|
||||
|
||||
# Get the subtask.
|
||||
subtask = data["instruction_or_question"]["answer"]
|
||||
if subtask is None:
|
||||
subtask = ""
|
||||
|
||||
# Update agents.
|
||||
for agent in [*self._specialists, self._orchestrator]:
|
||||
_ = await self.send_message(
|
||||
TextMessage(content=subtask, source=self.metadata["name"]),
|
||||
agent,
|
||||
)
|
||||
|
||||
# Find the speaker.
|
||||
try:
|
||||
speaker = next(agent for agent in self._specialists if agent.name == data["next_speaker"]["answer"])
|
||||
except StopIteration as e:
|
||||
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
|
||||
|
||||
# Ask speaker to speak.
|
||||
speaker_response = await self.send_message(RespondNow(), speaker)
|
||||
assert speaker_response is not None
|
||||
|
||||
# Update all other agents with the speaker's response.
|
||||
for agent in [agent for agent in self._specialists if agent != speaker] + [self._orchestrator]:
|
||||
await self.send_message(
|
||||
TextMessage(
|
||||
content=speaker_response.content,
|
||||
source=speaker_response.source,
|
||||
),
|
||||
agent,
|
||||
)
|
||||
|
||||
# Increment the total turns.
|
||||
total_turns += 1
|
||||
|
||||
return TextMessage(
|
||||
content="The task was not addressed. The maximum number of turns was reached.",
|
||||
source=self.metadata["name"],
|
||||
)
|
||||
|
||||
async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]:
|
||||
# Reset planner.
|
||||
await self.send_message(Reset(), self._planner)
|
||||
|
||||
# A reusable description of the team.
|
||||
team = "\n".join(
|
||||
[agent.name + ": " + self.runtime.agent_metadata(agent)["description"] for agent in self._specialists]
|
||||
)
|
||||
names = ", ".join([agent.name for agent in self._specialists])
|
||||
|
||||
# A place to store relevant facts.
|
||||
facts = ""
|
||||
|
||||
# A plance to store the plan.
|
||||
plan = ""
|
||||
|
||||
# Start by writing what we know
|
||||
closed_book_prompt = f"""Below I will present you a request. Before we begin addressing the request, please answer the following pre-survey to the best of your ability. Keep in mind that you are Ken Jennings-level with trivia, and Mensa-level with puzzles, so there should be a deep well to draw from.
|
||||
|
||||
Here is the request:
|
||||
|
||||
{task}
|
||||
|
||||
Here is the pre-survey:
|
||||
|
||||
1. Please list any specific facts or figures that are GIVEN in the request itself. It is possible that there are none.
|
||||
2. Please list any facts that may need to be looked up, and WHERE SPECIFICALLY they might be found. In some cases, authoritative sources are mentioned in the request itself.
|
||||
3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation)
|
||||
4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc.
|
||||
|
||||
When answering this survey, keep in mind that "facts" will typically be specific names, dates, statistics, etc. Your answer should use headings:
|
||||
|
||||
1. GIVEN OR VERIFIED FACTS
|
||||
2. FACTS TO LOOK UP
|
||||
3. FACTS TO DERIVE
|
||||
4. EDUCATED GUESSES
|
||||
""".strip()
|
||||
|
||||
# Ask the planner to obtain prior knowledge about facts.
|
||||
await self.send_message(TextMessage(content=closed_book_prompt, source=sender), self._planner)
|
||||
facts_response = await self.send_message(RespondNow(), self._planner)
|
||||
|
||||
facts = str(facts_response.content)
|
||||
|
||||
# Make an initial plan
|
||||
plan_prompt = f"""Fantastic. To address this request we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task.""".strip()
|
||||
|
||||
# Send second messag eto the planner.
|
||||
await self.send_message(TextMessage(content=plan_prompt, source=sender), self._planner)
|
||||
plan_response = await self.send_message(RespondNow(), self._planner)
|
||||
plan = str(plan_response.content)
|
||||
|
||||
return team, names, facts, plan
|
||||
|
||||
async def _reflect_on_task(
|
||||
self,
|
||||
task: str,
|
||||
team: str,
|
||||
names: str,
|
||||
sender: str,
|
||||
) -> Any:
|
||||
step_prompt = f"""
|
||||
Recall we are working on the following request:
|
||||
|
||||
{task}
|
||||
|
||||
And we have assembled the following team:
|
||||
|
||||
{team}
|
||||
|
||||
To make progress on the request, please answer the following questions, including necessary reasoning:
|
||||
|
||||
- Is the request fully satisfied? (True if complete, or False if the original request has yet to be SUCCESSFULLY addressed)
|
||||
- Are we making forward progress? (True if just starting, or recent messages are adding value. False if recent messages show evidence of being stuck in a reasoning or action loop, or there is evidence of significant barriers to success such as the inability to read from a required file)
|
||||
- Who should speak next? (select from: {names})
|
||||
- What instruction or question would you give this team member? (Phrase as if speaking directly to them, and include any specific information they may need)
|
||||
|
||||
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:
|
||||
|
||||
{{
|
||||
"is_request_satisfied": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"is_progress_being_made": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}},
|
||||
"next_speaker": {{
|
||||
"reason": string,
|
||||
"answer": string (select from: {names})
|
||||
}},
|
||||
"instruction_or_question": {{
|
||||
"reason": string,
|
||||
"answer": string
|
||||
}}
|
||||
}}
|
||||
""".strip()
|
||||
request = step_prompt
|
||||
while True:
|
||||
# Send a message to the orchestrator.
|
||||
await self.send_message(TextMessage(content=request, source=sender), self._orchestrator)
|
||||
# Request a response.
|
||||
step_response = await self.send_message(
|
||||
RespondNow(response_format=ResponseFormat.json_object),
|
||||
self._orchestrator,
|
||||
)
|
||||
# TODO: use typed dictionary.
|
||||
try:
|
||||
result = json.loads(str(step_response.content))
|
||||
except json.JSONDecodeError as e:
|
||||
request = f"Invalid JSON: {str(e)}"
|
||||
continue
|
||||
if "is_request_satisfied" not in result:
|
||||
request = "Missing key: is_request_satisfied"
|
||||
continue
|
||||
elif (
|
||||
not isinstance(result["is_request_satisfied"], dict)
|
||||
or "answer" not in result["is_request_satisfied"]
|
||||
or "reason" not in result["is_request_satisfied"]
|
||||
):
|
||||
request = "Invalid value for key: is_request_satisfied, expected 'answer' and 'reason'"
|
||||
continue
|
||||
if "is_progress_being_made" not in result:
|
||||
request = "Missing key: is_progress_being_made"
|
||||
continue
|
||||
elif (
|
||||
not isinstance(result["is_progress_being_made"], dict)
|
||||
or "answer" not in result["is_progress_being_made"]
|
||||
or "reason" not in result["is_progress_being_made"]
|
||||
):
|
||||
request = "Invalid value for key: is_progress_being_made, expected 'answer' and 'reason'"
|
||||
continue
|
||||
if "next_speaker" not in result:
|
||||
request = "Missing key: next_speaker"
|
||||
continue
|
||||
elif (
|
||||
not isinstance(result["next_speaker"], dict)
|
||||
or "answer" not in result["next_speaker"]
|
||||
or "reason" not in result["next_speaker"]
|
||||
):
|
||||
request = "Invalid value for key: next_speaker, expected 'answer' and 'reason'"
|
||||
continue
|
||||
elif result["next_speaker"]["answer"] not in names:
|
||||
request = f"Invalid value for key: next_speaker, expected 'answer' in {names}"
|
||||
continue
|
||||
if "instruction_or_question" not in result:
|
||||
request = "Missing key: instruction_or_question"
|
||||
continue
|
||||
elif (
|
||||
not isinstance(result["instruction_or_question"], dict)
|
||||
or "answer" not in result["instruction_or_question"]
|
||||
or "reason" not in result["instruction_or_question"]
|
||||
):
|
||||
request = "Invalid value for key: instruction_or_question, expected 'answer' and 'reason'"
|
||||
continue
|
||||
return result
|
||||
|
||||
async def _rewrite_facts(self, facts: str, sender: str) -> str:
|
||||
new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning).
|
||||
|
||||
{facts}
|
||||
""".strip()
|
||||
# Send a message to the orchestrator.
|
||||
await self.send_message(TextMessage(content=new_facts_prompt, source=sender), self._orchestrator)
|
||||
# Request a response.
|
||||
new_facts_response = await self.send_message(RespondNow(), self._orchestrator)
|
||||
return str(new_facts_response.content)
|
||||
|
||||
async def _educated_guess(self, facts: str, sender: str) -> Any:
|
||||
# Make an educated guess.
|
||||
educated_guess_promt = f"""Given the following information
|
||||
|
||||
{facts}
|
||||
|
||||
Please answer the following question, including necessary reasoning:
|
||||
- Do you have two or more congruent pieces of information that will allow you to make an educated guess for the original request? The educated guess MUST answer the question.
|
||||
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:
|
||||
|
||||
{{
|
||||
"has_educated_guesses": {{
|
||||
"reason": string,
|
||||
"answer": boolean
|
||||
}}
|
||||
}}
|
||||
""".strip()
|
||||
request = educated_guess_promt
|
||||
while True:
|
||||
# Send a message to the orchestrator.
|
||||
await self.send_message(
|
||||
TextMessage(content=request, source=sender),
|
||||
self._orchestrator,
|
||||
)
|
||||
# Request a response.
|
||||
response = await self.send_message(
|
||||
RespondNow(response_format=ResponseFormat.json_object),
|
||||
self._orchestrator,
|
||||
)
|
||||
try:
|
||||
result = json.loads(str(response.content))
|
||||
except json.JSONDecodeError as e:
|
||||
request = f"Invalid JSON: {str(e)}"
|
||||
continue
|
||||
# TODO: use typed dictionary.
|
||||
if "has_educated_guesses" not in result:
|
||||
request = "Missing key: has_educated_guesses"
|
||||
continue
|
||||
if (
|
||||
not isinstance(result["has_educated_guesses"], dict)
|
||||
or "answer" not in result["has_educated_guesses"]
|
||||
or "reason" not in result["has_educated_guesses"]
|
||||
):
|
||||
request = "Invalid value for key: has_educated_guesses, expected 'answer' and 'reason'"
|
||||
continue
|
||||
return result
|
||||
|
||||
async def _rewrite_plan(self, team: str, sender: str) -> str:
|
||||
new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else.
|
||||
|
||||
Team membership:
|
||||
{team}
|
||||
""".strip()
|
||||
# Send a message to the orchestrator.
|
||||
await self.send_message(TextMessage(content=new_plan_prompt, source=sender), self._orchestrator)
|
||||
# Request a response.
|
||||
new_plan_response = await self.send_message(RespondNow(), self._orchestrator)
|
||||
return str(new_plan_response.content)
|
||||
74
python/src/agnext/chat/types.py
Normal file
74
python/src/agnext/chat/types.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
from ..components import FunctionCall, Image
|
||||
from ..components.models import FunctionExecutionResultMessage
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class BaseMessage:
|
||||
# Name of the agent that sent this message
|
||||
source: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextMessage(BaseMessage):
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiModalMessage(BaseMessage):
|
||||
content: List[Union[str, Image]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallMessage(BaseMessage):
|
||||
content: List[FunctionCall]
|
||||
|
||||
|
||||
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
|
||||
|
||||
|
||||
class ResponseFormat(Enum):
|
||||
text = "text"
|
||||
json_object = "json_object"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RespondNow:
|
||||
"""A message to request a response from the addressed agent. The sender
|
||||
expects a response upon sening and waits for it synchronously."""
|
||||
|
||||
response_format: ResponseFormat = field(default=ResponseFormat.text)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublishNow:
|
||||
"""A message to request an event to be published to the addressed agent.
|
||||
Unlike RespondNow, the sender does not expect a response upon sending."""
|
||||
|
||||
response_format: ResponseFormat = field(default=ResponseFormat.text)
|
||||
|
||||
|
||||
class Reset: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolApprovalRequest:
|
||||
"""A message to request approval for a tool call. The sender expects a
|
||||
response upon sending and waits for it synchronously."""
|
||||
|
||||
tool_call: FunctionCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolApprovalResponse:
|
||||
"""A message to respond to a tool approval request. The response is sent
|
||||
synchronously."""
|
||||
|
||||
tool_call_id: str
|
||||
approved: bool
|
||||
reason: str
|
||||
98
python/src/agnext/chat/utils.py
Normal file
98
python/src/agnext/chat/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..components.models import (
|
||||
AssistantMessage,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from .types import (
|
||||
FunctionCallMessage,
|
||||
Message,
|
||||
MultiModalMessage,
|
||||
TextMessage,
|
||||
)
|
||||
|
||||
|
||||
def convert_content_message_to_assistant_message(
|
||||
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
|
||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
||||
) -> Optional[AssistantMessage]:
|
||||
match message:
|
||||
case TextMessage() | FunctionCallMessage():
|
||||
return AssistantMessage(content=message.content, source=message.source)
|
||||
case MultiModalMessage():
|
||||
if handle_unrepresentable == "error":
|
||||
raise ValueError("Cannot represent multimodal message as AssistantMessage")
|
||||
elif handle_unrepresentable == "ignore":
|
||||
return None
|
||||
elif handle_unrepresentable == "try_slice":
|
||||
return AssistantMessage(
|
||||
content="".join([x for x in message.content if isinstance(x, str)]),
|
||||
source=message.source,
|
||||
)
|
||||
|
||||
|
||||
def convert_content_message_to_user_message(
|
||||
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
|
||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
||||
) -> Optional[UserMessage]:
|
||||
match message:
|
||||
case TextMessage() | MultiModalMessage():
|
||||
return UserMessage(content=message.content, source=message.source)
|
||||
case FunctionCallMessage():
|
||||
if handle_unrepresentable == "error":
|
||||
raise ValueError("Cannot represent multimodal message as UserMessage")
|
||||
elif handle_unrepresentable == "ignore":
|
||||
return None
|
||||
elif handle_unrepresentable == "try_slice":
|
||||
# TODO: what is a sliced function call?
|
||||
raise NotImplementedError("Sliced function calls not yet implemented")
|
||||
|
||||
|
||||
def convert_tool_call_response_message(
|
||||
message: FunctionExecutionResultMessage,
|
||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
||||
) -> Optional[FunctionExecutionResultMessage]:
|
||||
match message:
|
||||
case FunctionExecutionResultMessage():
|
||||
return FunctionExecutionResultMessage(
|
||||
content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content]
|
||||
)
|
||||
|
||||
|
||||
def convert_messages_to_llm_messages(
|
||||
messages: List[Message],
|
||||
self_name: str,
|
||||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
|
||||
) -> List[LLMMessage]:
|
||||
result: List[LLMMessage] = []
|
||||
for message in messages:
|
||||
match message:
|
||||
case (
|
||||
TextMessage(content=_, source=source)
|
||||
| MultiModalMessage(content=_, source=source)
|
||||
| FunctionCallMessage(content=_, source=source)
|
||||
) if source == self_name:
|
||||
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
|
||||
if converted_message_1 is not None:
|
||||
result.append(converted_message_1)
|
||||
case (
|
||||
TextMessage(content=_, source=source)
|
||||
| MultiModalMessage(content=_, source=source)
|
||||
| FunctionCallMessage(content=_, source=source)
|
||||
) if source != self_name:
|
||||
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
|
||||
if converted_message_2 is not None:
|
||||
result.append(converted_message_2)
|
||||
case FunctionExecutionResultMessage(_):
|
||||
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
|
||||
if converted_message_3 is not None:
|
||||
result.append(converted_message_3)
|
||||
case _:
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
return result
|
||||
9
python/src/agnext/components/__init__.py
Normal file
9
python/src/agnext/components/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
The :mod:`agnext.components` module provides building blocks for creating single agents
|
||||
"""
|
||||
|
||||
from ._image import Image
|
||||
from ._type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from ._types import FunctionCall
|
||||
|
||||
__all__ = ["Image", "TypeRoutedAgent", "message_handler", "FunctionCall"]
|
||||
337
python/src/agnext/components/_function_utils.py
Normal file
337
python/src/agnext/components/_function_utils.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py
|
||||
# Credit to original authors
|
||||
|
||||
import inspect
|
||||
from logging import getLogger
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, create_model # type: ignore
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
|
||||
"""Get the type annotation of a parameter.
|
||||
|
||||
Args:
|
||||
annotation: The annotation of the parameter
|
||||
globalns: The global namespace of the function
|
||||
|
||||
Returns:
|
||||
The type annotation of the parameter
|
||||
"""
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
return annotation
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
"""Get the signature of a function with type annotations.
|
||||
|
||||
Args:
|
||||
call: The function to get the signature for
|
||||
|
||||
Returns:
|
||||
The signature of the function with type annotations
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_typed_annotation(param.annotation, globalns),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
return_annotation = get_typed_annotation(signature.return_annotation, globalns)
|
||||
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
|
||||
return typed_signature
|
||||
|
||||
|
||||
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
"""Get the return annotation of a function.
|
||||
|
||||
Args:
|
||||
call: The function to get the return annotation for
|
||||
|
||||
Returns:
|
||||
The return annotation of the function
|
||||
"""
|
||||
signature = inspect.signature(call)
|
||||
annotation = signature.return_annotation
|
||||
|
||||
if annotation is inspect.Signature.empty:
|
||||
return None
|
||||
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
return get_typed_annotation(annotation, globalns)
|
||||
|
||||
|
||||
def get_param_annotations(
|
||||
typed_signature: inspect.Signature,
|
||||
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
|
||||
"""Get the type annotations of the parameters of a function
|
||||
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
|
||||
Returns:
|
||||
A dictionary of the type annotations of the parameters of the function
|
||||
"""
|
||||
return {
|
||||
k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty
|
||||
}
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
"""Parameters of a function as defined by the OpenAI API"""
|
||||
|
||||
type: Literal["object"] = "object"
|
||||
properties: Dict[str, Dict[str, Any]]
|
||||
required: List[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""A function as defined by the OpenAI API"""
|
||||
|
||||
description: Annotated[str, Field(description="Description of the function")]
|
||||
name: Annotated[str, Field(description="Name of the function")]
|
||||
parameters: Annotated[Parameters, Field(description="Parameters of the function")]
|
||||
|
||||
|
||||
class ToolFunction(BaseModel):
|
||||
"""A function under tool as defined by the OpenAI API."""
|
||||
|
||||
type: Literal["function"] = "function"
|
||||
function: Annotated[Function, Field(description="Function under tool")]
|
||||
|
||||
|
||||
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
|
||||
# handles Annotated
|
||||
if hasattr(v, "__metadata__"):
|
||||
retval = v.__metadata__[0]
|
||||
if isinstance(retval, str):
|
||||
return retval
|
||||
else:
|
||||
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
|
||||
else:
|
||||
return k
|
||||
|
||||
|
||||
def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get a JSON schema for a parameter as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
k: The name of the parameter
|
||||
v: The type of the parameter
|
||||
default_values: The default values of the parameters of the function
|
||||
|
||||
Returns:
|
||||
A Pydanitc model for the parameter
|
||||
"""
|
||||
|
||||
schema = type2schema(v)
|
||||
if k in default_values:
|
||||
dv = default_values[k]
|
||||
schema["default"] = dv
|
||||
|
||||
schema["description"] = type2description(k, v)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_required_params(typed_signature: inspect.Signature) -> List[str]:
|
||||
"""Get the required parameters of a function
|
||||
|
||||
Args:
|
||||
signature: The signature of the function as returned by inspect.signature
|
||||
|
||||
Returns:
|
||||
A list of the required parameters of the function
|
||||
"""
|
||||
return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty]
|
||||
|
||||
|
||||
def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
|
||||
"""Get default values of parameters of a function
|
||||
|
||||
Args:
|
||||
signature: The signature of the function as returned by inspect.signature
|
||||
|
||||
Returns:
|
||||
A dictionary of the default values of the parameters of the function
|
||||
"""
|
||||
return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty}
|
||||
|
||||
|
||||
def get_parameters(
|
||||
required: List[str],
|
||||
param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
|
||||
default_values: Dict[str, Any],
|
||||
) -> Parameters:
|
||||
"""Get the parameters of a function as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
required: The required parameters of the function
|
||||
hints: The type hints of the function as returned by typing.get_type_hints
|
||||
|
||||
Returns:
|
||||
A Pydantic model for the parameters of the function
|
||||
"""
|
||||
return Parameters(
|
||||
properties={
|
||||
k: get_parameter_json_schema(k, v, default_values)
|
||||
for k, v in param_annotations.items()
|
||||
if v is not inspect.Signature.empty
|
||||
},
|
||||
required=required,
|
||||
)
|
||||
|
||||
|
||||
def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]:
|
||||
"""Get the missing annotations of a function
|
||||
|
||||
Ignores the parameters with default values as they are not required to be annotated, but logs a warning.
|
||||
Args:
|
||||
typed_signature: The signature of the function with type annotations
|
||||
required: The required parameters of the function
|
||||
|
||||
Returns:
|
||||
A set of the missing annotations of the function
|
||||
"""
|
||||
all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty}
|
||||
missing = all_missing.intersection(set(required))
|
||||
unannotated_with_default = all_missing.difference(missing)
|
||||
return missing, unannotated_with_default
|
||||
|
||||
|
||||
def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]:
|
||||
"""Get a JSON schema for a function as defined by the OpenAI API
|
||||
|
||||
Args:
|
||||
f: The function to get the JSON schema for
|
||||
name: The name of the function
|
||||
description: The description of the function
|
||||
|
||||
Returns:
|
||||
A JSON schema for the function
|
||||
|
||||
Raises:
|
||||
TypeError: If the function is not annotated
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def f(
|
||||
a: Annotated[str, "Parameter a"],
|
||||
b: int = 2,
|
||||
c: Annotated[float, "Parameter c"] = 0.1,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
get_function_schema(f, description="function f")
|
||||
|
||||
# {'type': 'function',
|
||||
# 'function': {'description': 'function f',
|
||||
# 'name': 'f',
|
||||
# 'parameters': {'type': 'object',
|
||||
# 'properties': {'a': {'type': 'str', 'description': 'Parameter a'},
|
||||
# 'b': {'type': 'int', 'description': 'b'},
|
||||
# 'c': {'type': 'float', 'description': 'Parameter c'}},
|
||||
# 'required': ['a']}}}
|
||||
|
||||
"""
|
||||
typed_signature = get_typed_signature(f)
|
||||
required = get_required_params(typed_signature)
|
||||
default_values = get_default_values(typed_signature)
|
||||
param_annotations = get_param_annotations(typed_signature)
|
||||
return_annotation = get_typed_return_annotation(f)
|
||||
missing, unannotated_with_default = get_missing_annotations(typed_signature, required)
|
||||
|
||||
if return_annotation is None:
|
||||
logger.warning(
|
||||
f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
|
||||
+ "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
|
||||
)
|
||||
|
||||
if unannotated_with_default != set():
|
||||
unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)]
|
||||
logger.warning(
|
||||
f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
|
||||
+ f"{', '.join(unannotated_with_default_s)}."
|
||||
)
|
||||
|
||||
if missing != set():
|
||||
missing_s = [f"'{k}'" for k in sorted(missing)]
|
||||
raise TypeError(
|
||||
f"All parameters of the function '{f.__name__}' without default values must be annotated. "
|
||||
+ f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
|
||||
)
|
||||
|
||||
fname = name if name else f.__name__
|
||||
|
||||
parameters = get_parameters(required, param_annotations, default_values=default_values)
|
||||
|
||||
function = ToolFunction(
|
||||
function=Function(
|
||||
description=description,
|
||||
name=fname,
|
||||
parameters=parameters,
|
||||
)
|
||||
)
|
||||
|
||||
return model_dump(function)
|
||||
|
||||
|
||||
def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
|
||||
"""Normalize typing.Annotated types to the inner type."""
|
||||
if get_origin(type_hint) is Annotated:
|
||||
# Extract the inner type from Annotated
|
||||
return get_args(type_hint)[0] # type: ignore
|
||||
return type_hint
|
||||
|
||||
|
||||
def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
|
||||
fields: Dict[str, tuple[Type[Any], Any]] = {}
|
||||
for name, param in sig.parameters.items():
|
||||
# This is handled externally
|
||||
if name == "cancellation_token":
|
||||
continue
|
||||
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
raise ValueError("No annotation")
|
||||
|
||||
type = normalize_annotated_type(param.annotation)
|
||||
description = type2description(name, param.annotation)
|
||||
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
|
||||
|
||||
fields[name] = (type, Field(default=default_value, description=description))
|
||||
|
||||
return cast(BaseModel, create_model(name, **fields)) # type: ignore
|
||||
78
python/src/agnext/components/_image.py
Normal file
78
python/src/agnext/components/_image.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
from openai.types.chat import ChatCompletionContentPartImageParam
|
||||
from PIL import Image as PILImage
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class Image:
|
||||
def __init__(self, image: PILImage.Image):
|
||||
self.image: PILImage.Image = image.convert("RGB")
|
||||
|
||||
@classmethod
|
||||
def from_pil(cls, pil_image: PILImage.Image) -> Image:
|
||||
return cls(pil_image)
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, uri: str) -> Image:
|
||||
if not re.match(r"data:image/(?:png|jpeg);base64,", uri):
|
||||
raise ValueError("Invalid URI format. It should be a base64 encoded image URI.")
|
||||
|
||||
# A URI. Remove the prefix and decode the base64 string.
|
||||
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", uri)
|
||||
return cls.from_base64(base64_data)
|
||||
|
||||
@classmethod
|
||||
async def from_url(cls, url: str) -> Image:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
content = await response.read()
|
||||
return cls(PILImage.open(content))
|
||||
|
||||
@classmethod
|
||||
def from_base64(cls, base64_str: str) -> Image:
|
||||
return cls(PILImage.open(BytesIO(base64.b64decode(base64_str))))
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file_path: Path) -> Image:
|
||||
return cls(PILImage.open(file_path))
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
# Show the image in Jupyter notebook
|
||||
return f'<img src="{self.data_uri}"/>'
|
||||
|
||||
@property
|
||||
def data_uri(self) -> str:
|
||||
buffered = BytesIO()
|
||||
self.image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
return _convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
|
||||
|
||||
def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> ChatCompletionContentPartImageParam:
|
||||
return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}}
|
||||
|
||||
|
||||
def _convert_base64_to_data_uri(base64_image: str) -> str:
|
||||
def _get_mime_type_from_data_uri(base64_image: str) -> str:
|
||||
# Decode the base64 string
|
||||
image_data = base64.b64decode(base64_image)
|
||||
# Check the first few bytes for known signatures
|
||||
if image_data.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
|
||||
return "image/gif"
|
||||
elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
return "image/jpeg" # use jpeg for unknown formats, best guess.
|
||||
|
||||
mime_type = _get_mime_type_from_data_uri(base64_image)
|
||||
data_uri = f"data:{mime_type};base64,{base64_image}"
|
||||
return data_uri
|
||||
65
python/src/agnext/components/_pydantic_compat.py
Normal file
65
python/src/agnext/components/_pydantic_compat.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/_pydantic.py
|
||||
# Credit to original authors
|
||||
|
||||
|
||||
from typing import Any, Dict, Tuple, Type, Union, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from typing_extensions import get_origin
|
||||
|
||||
__all__ = ("model_dump", "type2schema", "evaluate_forwardref")
|
||||
|
||||
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
|
||||
|
||||
|
||||
def evaluate_forwardref(
|
||||
value: Any,
|
||||
globalns: dict[str, Any] | None = None,
|
||||
localns: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if PYDANTIC_V1:
|
||||
from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal
|
||||
|
||||
return evaluate_forwardref_internal(value, globalns, localns)
|
||||
else:
|
||||
from pydantic._internal._typing_extra import eval_type_lenient
|
||||
|
||||
return eval_type_lenient(value, globalns, localns)
|
||||
|
||||
|
||||
def type2schema(t: Type[Any] | None) -> Dict[str, Any]:
|
||||
if PYDANTIC_V1:
|
||||
from pydantic import schema_of # type: ignore
|
||||
|
||||
if t is None:
|
||||
return {"type": "null"}
|
||||
elif get_origin(t) is Union:
|
||||
return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
|
||||
elif get_origin(t) in [Tuple, tuple]:
|
||||
prefixItems = [type2schema(tt) for tt in get_args(t)]
|
||||
return {
|
||||
"maxItems": len(prefixItems),
|
||||
"minItems": len(prefixItems),
|
||||
"prefixItems": prefixItems,
|
||||
"type": "array",
|
||||
}
|
||||
|
||||
d = schema_of(t) # type: ignore
|
||||
if "title" in d:
|
||||
d.pop("title")
|
||||
if "description" in d:
|
||||
d.pop("description")
|
||||
|
||||
return d
|
||||
else:
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
return TypeAdapter(t).json_schema()
|
||||
|
||||
|
||||
def model_dump(model: BaseModel) -> Dict[str, Any]:
|
||||
if PYDANTIC_V1:
|
||||
return model.dict() # type: ignore
|
||||
else:
|
||||
return model.model_dump()
|
||||
191
python/src/agnext/components/_type_routed_agent.py
Normal file
191
python/src/agnext/components/_type_routed_agent.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Literal,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ..core import BaseAgent, CancellationToken
|
||||
from ..core.exceptions import CantHandleException
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
|
||||
ReceivesT = TypeVar("ReceivesT", contravariant=True)
|
||||
ProducesT = TypeVar("ProducesT", covariant=True)
|
||||
|
||||
# TODO: Generic typevar bound binding U to agent type
|
||||
# Can't do because python doesnt support it
|
||||
|
||||
|
||||
def is_union(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Union or origin is UnionType
|
||||
|
||||
|
||||
def is_optional(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Optional
|
||||
|
||||
|
||||
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
|
||||
class AnyType:
|
||||
pass
|
||||
|
||||
|
||||
def get_types(t: object) -> Sequence[Type[Any]] | None:
|
||||
if is_union(t):
|
||||
return get_args(t)
|
||||
elif is_optional(t):
|
||||
return tuple(list(get_args(t)) + [NoneType])
|
||||
elif t is Any:
|
||||
return (AnyType,)
|
||||
elif isinstance(t, type):
|
||||
return (t,)
|
||||
elif isinstance(t, NoneType):
|
||||
return (NoneType,)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessageHandler(Protocol[ReceivesT, ProducesT]):
|
||||
target_types: Sequence[type]
|
||||
produces_types: Sequence[type]
|
||||
is_message_handler: Literal[True]
|
||||
|
||||
async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ...
|
||||
|
||||
|
||||
# NOTE: this works on concrete types and not inheritance
|
||||
# TODO: Use a protocl for the outer function to check checked arg names
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: None = None,
|
||||
*,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
def message_handler(
|
||||
func: None | Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
]
|
||||
| MessageHandler[ReceivesT, ProducesT]
|
||||
):
|
||||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
# print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, cancellation_token)
|
||||
|
||||
if AnyType not in return_types and type(return_value) not in return_types:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
class TypeRoutedAgent(BaseAgent):
|
||||
def __init__(self, description: str) -> None:
|
||||
# Self is already bound to the handlers
|
||||
self._handlers: Dict[
|
||||
Type[Any],
|
||||
Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]],
|
||||
] = {}
|
||||
|
||||
for attr in dir(self):
|
||||
if callable(getattr(self, attr, None)):
|
||||
handler = getattr(self, attr)
|
||||
if hasattr(handler, "is_message_handler"):
|
||||
message_handler = cast(MessageHandler[Any, Any], handler)
|
||||
for target_type in message_handler.target_types:
|
||||
self._handlers[target_type] = message_handler
|
||||
subscriptions = list(self._handlers.keys())
|
||||
super().__init__(description, subscriptions)
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
|
||||
key_type: Type[Any] = type(message) # type: ignore
|
||||
handler = self._handlers.get(key_type) # type: ignore
|
||||
if handler is not None:
|
||||
return await handler(message, cancellation_token)
|
||||
else:
|
||||
return await self.on_unhandled_message(message, cancellation_token)
|
||||
|
||||
async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn:
|
||||
raise CantHandleException(f"Unhandled message: {message}")
|
||||
12
python/src/agnext/components/_types.py
Normal file
12
python/src/agnext/components/_types.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCall:
|
||||
id: str
|
||||
# JSON args
|
||||
arguments: str
|
||||
# Function to call
|
||||
name: str
|
||||
17
python/src/agnext/components/code_executor/__init__.py
Normal file
17
python/src/agnext/components/code_executor/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from ._base import CodeBlock, CodeExecutor, CodeResult
|
||||
from ._func_with_reqs import Alias, FunctionWithRequirements, Import, ImportFromModule, with_requirements
|
||||
from ._impl.command_line_code_result import CommandLineCodeResult
|
||||
from ._impl.local_commandline_code_executor import LocalCommandLineCodeExecutor
|
||||
|
||||
__all__ = [
|
||||
"LocalCommandLineCodeExecutor",
|
||||
"CommandLineCodeResult",
|
||||
"CodeBlock",
|
||||
"CodeResult",
|
||||
"CodeExecutor",
|
||||
"Alias",
|
||||
"ImportFromModule",
|
||||
"Import",
|
||||
"FunctionWithRequirements",
|
||||
"with_requirements",
|
||||
]
|
||||
50
python/src/agnext/components/code_executor/_base.py
Normal file
50
python/src/agnext/components/code_executor/_base.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/base.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeBlock:
|
||||
"""A code block extracted fromm an agent message."""
|
||||
|
||||
code: str
|
||||
language: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeResult:
|
||||
"""Result of a code execution."""
|
||||
|
||||
exit_code: int
|
||||
output: str
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CodeExecutor(Protocol):
|
||||
"""Executes code blocks and returns the result."""
|
||||
|
||||
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
|
||||
"""Execute code blocks and return the result.
|
||||
|
||||
This method should be implemented by the code executor.
|
||||
|
||||
Args:
|
||||
code_blocks (List[CodeBlock]): The code blocks to execute.
|
||||
|
||||
Returns:
|
||||
CodeResult: The result of the code execution.
|
||||
"""
|
||||
...
|
||||
|
||||
def restart(self) -> None:
|
||||
"""Restart the code executor.
|
||||
|
||||
This method should be implemented by the code executor.
|
||||
|
||||
This method is called when the agent is reset.
|
||||
"""
|
||||
...
|
||||
200
python/src/agnext/components/code_executor/_func_with_reqs.py
Normal file
200
python/src/agnext/components/code_executor/_func_with_reqs.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/func_with_reqs.py
|
||||
# Credit to original authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from importlib.abc import SourceLoader
|
||||
from importlib.util import module_from_spec, spec_from_loader
|
||||
from textwrap import dedent, indent
|
||||
from typing import Any, Callable, Generic, List, Sequence, Set, TypeVar, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
|
||||
if isinstance(func, FunctionWithRequirementsStr):
|
||||
return func.func
|
||||
|
||||
code = inspect.getsource(func)
|
||||
# Strip the decorator
|
||||
if code.startswith("@"):
|
||||
code = code[code.index("\n") + 1 :]
|
||||
return code
|
||||
|
||||
|
||||
@dataclass
|
||||
class Alias:
|
||||
name: str
|
||||
alias: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportFromModule:
|
||||
module: str
|
||||
imports: List[Union[str, Alias]]
|
||||
|
||||
|
||||
Import = Union[str, ImportFromModule, Alias]
|
||||
|
||||
|
||||
def _import_to_str(im: Import) -> str:
|
||||
if isinstance(im, str):
|
||||
return f"import {im}"
|
||||
elif isinstance(im, Alias):
|
||||
return f"import {im.name} as {im.alias}"
|
||||
else:
|
||||
|
||||
def to_str(i: Union[str, Alias]) -> str:
|
||||
if isinstance(i, str):
|
||||
return i
|
||||
else:
|
||||
return f"{i.name} as {i.alias}"
|
||||
|
||||
imports = ", ".join(map(to_str, im.imports))
|
||||
return f"from {im.module} import {imports}"
|
||||
|
||||
|
||||
class _StringLoader(SourceLoader):
|
||||
def __init__(self, data: str):
|
||||
self.data = data
|
||||
|
||||
def get_source(self, fullname: str) -> str:
|
||||
return self.data
|
||||
|
||||
def get_data(self, path: str) -> bytes:
|
||||
return self.data.encode("utf-8")
|
||||
|
||||
def get_filename(self, fullname: str) -> str:
|
||||
return "<not a real path>/" + fullname + ".py"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionWithRequirementsStr:
|
||||
func: str
|
||||
compiled_func: Callable[..., Any]
|
||||
_func_name: str
|
||||
python_packages: Sequence[str] = field(default_factory=list)
|
||||
global_imports: Sequence[Import] = field(default_factory=list)
|
||||
|
||||
def __init__(self, func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []):
|
||||
self.func = func
|
||||
self.python_packages = python_packages
|
||||
self.global_imports = global_imports
|
||||
|
||||
module_name = "func_module"
|
||||
loader = _StringLoader(func)
|
||||
spec = spec_from_loader(module_name, loader)
|
||||
if spec is None:
|
||||
raise ValueError("Could not create spec")
|
||||
module = module_from_spec(spec)
|
||||
if spec.loader is None:
|
||||
raise ValueError("Could not create loader")
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not compile function: {e}") from e
|
||||
|
||||
functions = inspect.getmembers(module, inspect.isfunction)
|
||||
if len(functions) != 1:
|
||||
raise ValueError("The string must contain exactly one function")
|
||||
|
||||
self._func_name, self.compiled_func = functions[0]
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("String based function with requirement objects are not directly callable")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionWithRequirements(Generic[T, P]):
|
||||
func: Callable[P, T]
|
||||
python_packages: Sequence[str] = field(default_factory=list)
|
||||
global_imports: Sequence[Import] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_callable(
|
||||
cls, func: Callable[P, T], python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> FunctionWithRequirements[T, P]:
|
||||
return cls(python_packages=python_packages, global_imports=global_imports, func=func)
|
||||
|
||||
@staticmethod
|
||||
def from_str(
|
||||
func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> FunctionWithRequirementsStr:
|
||||
return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports)
|
||||
|
||||
# Type this based on F
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def with_requirements(
|
||||
python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []
|
||||
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
|
||||
"""Decorate a function with package and import requirements
|
||||
|
||||
Args:
|
||||
python_packages (List[str], optional): Packages required to function. Can include version info.. Defaults to [].
|
||||
global_imports (List[Import], optional): Required imports. Defaults to [].
|
||||
|
||||
Returns:
|
||||
Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: The decorated function
|
||||
"""
|
||||
|
||||
def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
|
||||
func_with_reqs = FunctionWithRequirements(
|
||||
python_packages=python_packages, global_imports=global_imports, func=func
|
||||
)
|
||||
|
||||
functools.update_wrapper(func_with_reqs, func)
|
||||
return func_with_reqs
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def build_python_functions_file(
|
||||
funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]],
|
||||
) -> str:
|
||||
# First collect all global imports
|
||||
global_imports: Set[Import] = set()
|
||||
for func in funcs:
|
||||
if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
|
||||
global_imports.update(func.global_imports)
|
||||
|
||||
content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
|
||||
|
||||
for func in funcs:
|
||||
content += _to_code(func) + "\n\n"
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
|
||||
"""Generate a stub for a function as a string
|
||||
|
||||
Args:
|
||||
func (Callable[..., Any]): The function to generate a stub for
|
||||
|
||||
Returns:
|
||||
str: The stub for the function
|
||||
"""
|
||||
if isinstance(func, FunctionWithRequirementsStr):
|
||||
return to_stub(func.compiled_func)
|
||||
|
||||
content = f"def {func.__name__}{inspect.signature(func)}:\n"
|
||||
docstring = func.__doc__
|
||||
|
||||
if docstring:
|
||||
docstring = dedent(docstring)
|
||||
docstring = '"""' + docstring + '"""'
|
||||
docstring = indent(docstring, " ")
|
||||
content += docstring + "\n"
|
||||
|
||||
content += " ..."
|
||||
return content
|
||||
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from .._base import CodeResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandLineCodeResult(CodeResult):
|
||||
"""A code result class for command line code executor."""
|
||||
|
||||
code_file: Optional[str]
|
||||
@@ -0,0 +1,269 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/local_commandline_code_executor.py
|
||||
# Credit to original authors
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Callable, ClassVar, List, Sequence, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .._base import CodeBlock, CodeExecutor
|
||||
from .._func_with_reqs import (
|
||||
FunctionWithRequirements,
|
||||
FunctionWithRequirementsStr,
|
||||
build_python_functions_file,
|
||||
to_stub,
|
||||
)
|
||||
from .command_line_code_result import CommandLineCodeResult
|
||||
from .utils import PYTHON_VARIANTS, get_file_name_from_content, lang_to_cmd, silence_pip # type: ignore
|
||||
|
||||
__all__ = ("LocalCommandLineCodeExecutor",)
|
||||
|
||||
A = ParamSpec("A")
|
||||
|
||||
|
||||
class LocalCommandLineCodeExecutor(CodeExecutor):
|
||||
SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
|
||||
"bash",
|
||||
"shell",
|
||||
"sh",
|
||||
"pwsh",
|
||||
"powershell",
|
||||
"ps1",
|
||||
"python",
|
||||
]
|
||||
FUNCTION_PROMPT_TEMPLATE: ClassVar[
|
||||
str
|
||||
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
|
||||
|
||||
For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`
|
||||
|
||||
$functions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int = 60,
|
||||
work_dir: Union[Path, str] = Path("."),
|
||||
functions: Sequence[
|
||||
Union[
|
||||
FunctionWithRequirements[Any, A],
|
||||
Callable[..., Any],
|
||||
FunctionWithRequirementsStr,
|
||||
]
|
||||
] = [],
|
||||
functions_module: str = "functions",
|
||||
):
|
||||
"""(Experimental) A code executor class that executes code through a local command line
|
||||
environment.
|
||||
|
||||
**This will execute LLM generated code on the local machine.**
|
||||
|
||||
Each code block is saved as a file and executed in a separate process in
|
||||
the working directory, and a unique file is generated and saved in the
|
||||
working directory for each code block.
|
||||
The code blocks are executed in the order they are received.
|
||||
Command line code is sanitized using regular expression match against a list of dangerous commands in order to prevent self-destructive
|
||||
commands from being executed which may potentially affect the users environment.
|
||||
Currently the only supported languages is Python and shell scripts.
|
||||
For Python code, use the language "python" for the code block.
|
||||
For shell scripts, use the language "bash", "shell", or "sh" for the code
|
||||
block.
|
||||
|
||||
Args:
|
||||
timeout (int): The timeout for code execution. Default is 60.
|
||||
work_dir (str): The working directory for the code execution. If None,
|
||||
a default working directory will be used. The default working
|
||||
directory is the current directory ".".
|
||||
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
|
||||
"""
|
||||
|
||||
if timeout < 1:
|
||||
raise ValueError("Timeout must be greater than or equal to 1.")
|
||||
|
||||
if isinstance(work_dir, str):
|
||||
work_dir = Path(work_dir)
|
||||
|
||||
if not functions_module.isidentifier():
|
||||
raise ValueError("Module name must be a valid Python identifier")
|
||||
|
||||
self._functions_module = functions_module
|
||||
|
||||
work_dir.mkdir(exist_ok=True)
|
||||
|
||||
self._timeout = timeout
|
||||
self._work_dir: Path = work_dir
|
||||
|
||||
self._functions = functions
|
||||
# Setup could take some time so we intentionally wait for the first code block to do it.
|
||||
if len(functions) > 0:
|
||||
self._setup_functions_complete = False
|
||||
else:
|
||||
self._setup_functions_complete = True
|
||||
|
||||
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
|
||||
"""(Experimental) Format the functions for a prompt.
|
||||
|
||||
The template includes two variables:
|
||||
- `$module_name`: The module name.
|
||||
- `$functions`: The functions formatted as stubs with two newlines between each function.
|
||||
|
||||
Args:
|
||||
prompt_template (str): The prompt template. Default is the class default.
|
||||
|
||||
Returns:
|
||||
str: The formatted prompt.
|
||||
"""
|
||||
|
||||
template = Template(prompt_template)
|
||||
return template.substitute(
|
||||
module_name=self._functions_module,
|
||||
functions="\n\n".join([to_stub(func) for func in self._functions]),
|
||||
)
|
||||
|
||||
@property
|
||||
def functions_module(self) -> str:
|
||||
"""(Experimental) The module name for the functions."""
|
||||
return self._functions_module
|
||||
|
||||
@property
|
||||
def functions(self) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
"""(Experimental) The timeout for code execution."""
|
||||
return self._timeout
|
||||
|
||||
@property
|
||||
def work_dir(self) -> Path:
|
||||
"""(Experimental) The working directory for the code execution."""
|
||||
return self._work_dir
|
||||
|
||||
def _setup_functions(self) -> None:
|
||||
func_file_content = build_python_functions_file(self._functions)
|
||||
func_file = self._work_dir / f"{self._functions_module}.py"
|
||||
func_file.write_text(func_file_content)
|
||||
|
||||
# Collect requirements
|
||||
lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
|
||||
flattened_packages = [item for sublist in lists_of_packages for item in sublist]
|
||||
required_packages = list(set(flattened_packages))
|
||||
if len(required_packages) > 0:
|
||||
logging.info("Ensuring packages are installed in executor.")
|
||||
|
||||
cmd = [sys.executable, "-m", "pip", "install"]
|
||||
cmd.extend(required_packages)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=self._work_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=float(self._timeout),
|
||||
)
|
||||
except subprocess.TimeoutExpired as e:
|
||||
raise ValueError("Pip install timed out") from e
|
||||
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}")
|
||||
|
||||
# Attempt to load the function file to check for syntax errors, imports etc.
|
||||
exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")])
|
||||
|
||||
if exec_result.exit_code != 0:
|
||||
raise ValueError(f"Functions failed to load: {exec_result.output}")
|
||||
|
||||
self._setup_functions_complete = True
|
||||
|
||||
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
|
||||
"""(Experimental) Execute the code blocks and return the result.
|
||||
|
||||
Args:
|
||||
code_blocks (List[CodeBlock]): The code blocks to execute.
|
||||
|
||||
Returns:
|
||||
CommandLineCodeResult: The result of the code execution."""
|
||||
|
||||
if not self._setup_functions_complete:
|
||||
self._setup_functions()
|
||||
|
||||
return self._execute_code_dont_check_setup(code_blocks)
|
||||
|
||||
def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
|
||||
logs_all: str = ""
|
||||
file_names: List[Path] = []
|
||||
exitcode = 0
|
||||
for code_block in code_blocks:
|
||||
lang, code = code_block.language, code_block.code
|
||||
lang = lang.lower()
|
||||
|
||||
code = silence_pip(code, lang)
|
||||
|
||||
if lang in PYTHON_VARIANTS:
|
||||
lang = "python"
|
||||
|
||||
if lang not in self.SUPPORTED_LANGUAGES:
|
||||
# In case the language is not supported, we return an error message.
|
||||
exitcode = 1
|
||||
logs_all += "\n" + f"unknown language {lang}"
|
||||
break
|
||||
|
||||
try:
|
||||
# Check if there is a filename comment
|
||||
filename = get_file_name_from_content(code, self._work_dir)
|
||||
except ValueError:
|
||||
return CommandLineCodeResult(
|
||||
exit_code=1,
|
||||
output="Filename is not in the workspace",
|
||||
code_file=None,
|
||||
)
|
||||
|
||||
if filename is None:
|
||||
# create a file with an automatically generated name
|
||||
code_hash = md5(code.encode()).hexdigest()
|
||||
filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
|
||||
|
||||
written_file = (self._work_dir / filename).resolve()
|
||||
with written_file.open("w", encoding="utf-8") as f:
|
||||
f.write(code)
|
||||
file_names.append(written_file)
|
||||
|
||||
program = sys.executable if lang.startswith("python") else lang_to_cmd(lang)
|
||||
cmd = [program, str(written_file.absolute())]
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=self._work_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=float(self._timeout),
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logs_all += "\n Timeout"
|
||||
# Same exit code as the timeout command on linux.
|
||||
exitcode = 124
|
||||
break
|
||||
|
||||
logs_all += result.stderr
|
||||
logs_all += result.stdout
|
||||
exitcode = result.returncode
|
||||
|
||||
if exitcode != 0:
|
||||
break
|
||||
|
||||
code_file = str(file_names[0]) if len(file_names) > 0 else None
|
||||
return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file)
|
||||
|
||||
def restart(self) -> None:
|
||||
"""(Experimental) Restart the code executor."""
|
||||
warnings.warn(
|
||||
"Restarting local command line code executor is not supported. No action is taken.",
|
||||
stacklevel=2,
|
||||
)
|
||||
88
python/src/agnext/components/code_executor/_impl/utils.py
Normal file
88
python/src/agnext/components/code_executor/_impl/utils.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/utils.py
|
||||
# Credit to original authors
|
||||
|
||||
# Will return the filename relative to the workspace path
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Raises ValueError if the file is not in the workspace
|
||||
def get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
|
||||
first_line = code.split("\n")[0]
|
||||
# TODO - support other languages
|
||||
if first_line.startswith("# filename:"):
|
||||
filename = first_line.split(":")[1].strip()
|
||||
|
||||
# Handle relative paths in the filename
|
||||
path = Path(filename)
|
||||
if not path.is_absolute():
|
||||
path = workspace_path / path
|
||||
path = path.resolve()
|
||||
# Throws an error if the file is not in the workspace
|
||||
relative = path.relative_to(workspace_path.resolve())
|
||||
return str(relative)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def silence_pip(code: str, lang: str) -> str:
|
||||
"""Apply -qqq flag to pip install commands."""
|
||||
if lang == "python":
|
||||
regex = r"^! ?pip install"
|
||||
elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]:
|
||||
regex = r"^pip install"
|
||||
else:
|
||||
return code
|
||||
|
||||
# Find lines that start with pip install and make sure "-qqq" flag is added.
|
||||
lines = code.split("\n")
|
||||
for i, line in enumerate(lines):
|
||||
# use regex to find lines that start with pip install.
|
||||
match = re.search(regex, line)
|
||||
if match is not None:
|
||||
if "-qqq" not in line:
|
||||
lines[i] = line.replace(match.group(0), match.group(0) + " -qqq")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
PYTHON_VARIANTS = ["python", "Python", "py"]
|
||||
|
||||
|
||||
def lang_to_cmd(lang: str) -> str:
|
||||
if lang in PYTHON_VARIANTS:
|
||||
return "python"
|
||||
if lang.startswith("python") or lang in ["bash", "sh"]:
|
||||
return lang
|
||||
if lang in ["shell"]:
|
||||
return "sh"
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {lang}")
|
||||
|
||||
|
||||
# Regular expression for finding a code block
|
||||
# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)[ \t]*\r?\n``` Matches multi-line code blocks.
|
||||
# The [ \t]* matches the potential spaces before language name.
|
||||
# The (\w+)? matches the language, where the ? indicates it is optional.
|
||||
# The [ \t]* matches the potential spaces (not newlines) after language name.
|
||||
# The \r?\n makes sure there is a linebreak after ```.
|
||||
# The (.*?) matches the code itself (non-greedy).
|
||||
# The \r?\n makes sure there is a linebreak before ```.
|
||||
# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation).
|
||||
CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```"
|
||||
|
||||
|
||||
def infer_lang(code: str) -> str:
|
||||
"""infer the language for the code.
|
||||
TODO: make it robust.
|
||||
"""
|
||||
if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "):
|
||||
return "sh"
|
||||
|
||||
# check if code is a valid python code
|
||||
try:
|
||||
compile(code, "test", "exec")
|
||||
return "python"
|
||||
except SyntaxError:
|
||||
# not a valid python code
|
||||
return "unknown"
|
||||
32
python/src/agnext/components/models/__init__.py
Normal file
32
python/src/agnext/components/models/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities
|
||||
from ._openai_client import (
|
||||
AzureOpenAI,
|
||||
OpenAI,
|
||||
)
|
||||
from ._types import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FinishReasons,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureOpenAI",
|
||||
"OpenAI",
|
||||
"ModelCapabilities",
|
||||
"ChatCompletionClient",
|
||||
"SystemMessage",
|
||||
"UserMessage",
|
||||
"AssistantMessage",
|
||||
"FunctionExecutionResult",
|
||||
"FunctionExecutionResultMessage",
|
||||
"LLMMessage",
|
||||
"RequestUsage",
|
||||
"FinishReasons",
|
||||
"CreateResult",
|
||||
]
|
||||
52
python/src/agnext/components/models/_model_client.py
Normal file
52
python/src/agnext/components/models/_model_client.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Mapping, Optional, Sequence, runtime_checkable
|
||||
|
||||
from typing_extensions import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Protocol,
|
||||
Required,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
from ..tools import Tool
|
||||
from ._types import CreateResult, LLMMessage, RequestUsage
|
||||
|
||||
|
||||
class ModelCapabilities(TypedDict, total=False):
|
||||
vision: Required[bool]
|
||||
function_calling: Required[bool]
|
||||
json_output: Required[bool]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ChatCompletionClient(Protocol):
|
||||
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> CreateResult: ...
|
||||
|
||||
def create_stream(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]: ...
|
||||
|
||||
def actual_usage(self) -> RequestUsage: ...
|
||||
|
||||
def total_usage(self) -> RequestUsage: ...
|
||||
|
||||
@property
|
||||
def capabilities(self) -> ModelCapabilities: ...
|
||||
89
python/src/agnext/components/models/_model_info.py
Normal file
89
python/src/agnext/components/models/_model_info.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Dict
|
||||
|
||||
from ._model_client import ModelCapabilities
|
||||
|
||||
# Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
|
||||
# This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime``
|
||||
_MODEL_POINTERS = {
|
||||
"gpt-4o": "gpt-4o-2024-05-13",
|
||||
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
||||
"gpt-4": "gpt-4-0613",
|
||||
"gpt-4-32k": "gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
|
||||
}
|
||||
|
||||
_MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
|
||||
"gpt-4o-2024-05-13": {
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-4-turbo-2024-04-09": {
|
||||
"vision": True,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-4-0125-preview": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-4-1106-preview": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-4-1106-vision-preview": {
|
||||
"vision": True,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
},
|
||||
"gpt-4-0613": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-4-32k-0613": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-3.5-turbo-0125": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-3.5-turbo-1106": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-3.5-turbo-instruct": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-3.5-turbo-0613": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
"gpt-3.5-turbo-16k-0613": {
|
||||
"vision": False,
|
||||
"function_calling": True,
|
||||
"json_output": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def resolve_model(model: str) -> str:
|
||||
if model in _MODEL_POINTERS:
|
||||
return _MODEL_POINTERS[model]
|
||||
return model
|
||||
|
||||
|
||||
def get_capabilties(model: str) -> ModelCapabilities:
|
||||
resolved_model = resolve_model(model)
|
||||
return _MODEL_CAPABILITIES[resolved_model]
|
||||
569
python/src/agnext/components/models/_openai_client.py
Normal file
569
python/src/agnext/components/models/_openai_client.py
Normal file
@@ -0,0 +1,569 @@
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionRole,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
completion_create_params,
|
||||
)
|
||||
from openai.types.shared_params import FunctionDefinition, FunctionParameters
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
|
||||
from .. import (
|
||||
FunctionCall,
|
||||
Image,
|
||||
)
|
||||
from ..tools import Tool
|
||||
from . import _model_info
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities
|
||||
from ._types import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
RequestUsage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
openai_init_kwargs = set(inspect.getfullargspec(AsyncOpenAI.__init__).kwonlyargs)
|
||||
aopenai_init_kwargs = set(inspect.getfullargspec(AsyncAzureOpenAI.__init__).kwonlyargs)
|
||||
|
||||
create_kwargs = set(completion_create_params.CompletionCreateParamsBase.__annotations__.keys()) | set(
|
||||
("timeout", "stream")
|
||||
)
|
||||
# Only single choice allowed
|
||||
disallowed_create_args = set(["stream", "messages", "function_call", "functions", "n"])
|
||||
required_create_args: Set[str] = set(["model"])
|
||||
|
||||
|
||||
def _azure_openai_client_from_config(config: Mapping[str, Any]) -> AsyncAzureOpenAI:
|
||||
# Take a copy
|
||||
copied_config = dict(config).copy()
|
||||
|
||||
# Do some fixups
|
||||
copied_config["azure_deployment"] = copied_config.get("azure_deployment", config.get("model"))
|
||||
if copied_config["azure_deployment"] is not None:
|
||||
copied_config["azure_deployment"] = copied_config["azure_deployment"].replace(".", "")
|
||||
copied_config["azure_endpoint"] = copied_config.get("azure_endpoint", copied_config.pop("base_url", None))
|
||||
|
||||
# Shave down the config to just the AzureOpenAI kwargs
|
||||
azure_config = {k: v for k, v in copied_config.items() if k in aopenai_init_kwargs}
|
||||
return AsyncAzureOpenAI(**azure_config)
|
||||
|
||||
|
||||
def _openai_client_from_config(config: Mapping[str, Any]) -> AsyncOpenAI:
|
||||
# Shave down the config to just the OpenAI kwargs
|
||||
openai_config = {k: v for k, v in config.items() if k in openai_init_kwargs}
|
||||
return AsyncOpenAI(**openai_config)
|
||||
|
||||
|
||||
def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
create_args = {k: v for k, v in config.items() if k in create_kwargs}
|
||||
create_args_keys = set(create_args.keys())
|
||||
if not required_create_args.issubset(create_args_keys):
|
||||
raise ValueError(f"Required create args are missing: {required_create_args - create_args_keys}")
|
||||
if disallowed_create_args.intersection(create_args_keys):
|
||||
raise ValueError(f"Disallowed create args are present: {disallowed_create_args.intersection(create_args_keys)}")
|
||||
return create_args
|
||||
|
||||
|
||||
# TODO check types
|
||||
# oai_system_message_schema = type2schema(ChatCompletionSystemMessageParam)
|
||||
# oai_user_message_schema = type2schema(ChatCompletionUserMessageParam)
|
||||
# oai_assistant_message_schema = type2schema(ChatCompletionAssistantMessageParam)
|
||||
# oai_tool_message_schema = type2schema(ChatCompletionToolMessageParam)
|
||||
|
||||
|
||||
def type_to_role(message: LLMMessage) -> ChatCompletionRole:
|
||||
if isinstance(message, SystemMessage):
|
||||
return "system"
|
||||
elif isinstance(message, UserMessage):
|
||||
return "user"
|
||||
elif isinstance(message, AssistantMessage):
|
||||
return "assistant"
|
||||
else:
|
||||
return "tool"
|
||||
|
||||
|
||||
def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam:
|
||||
if isinstance(message.content, str):
|
||||
return ChatCompletionUserMessageParam(
|
||||
content=message.content,
|
||||
role="user",
|
||||
name=message.source,
|
||||
)
|
||||
else:
|
||||
parts: List[ChatCompletionContentPartParam] = []
|
||||
for part in message.content:
|
||||
if isinstance(part, str):
|
||||
oai_part = ChatCompletionContentPartTextParam(
|
||||
text=part,
|
||||
type="text",
|
||||
)
|
||||
parts.append(oai_part)
|
||||
elif isinstance(part, Image):
|
||||
# TODO: support url based images
|
||||
# TODO: support specifying details
|
||||
parts.append(part.to_openai_format())
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {part}")
|
||||
return ChatCompletionUserMessageParam(
|
||||
content=parts,
|
||||
role="user",
|
||||
name=message.source,
|
||||
)
|
||||
|
||||
|
||||
def system_message_to_oai(message: SystemMessage) -> ChatCompletionSystemMessageParam:
|
||||
return ChatCompletionSystemMessageParam(
|
||||
content=message.content,
|
||||
role="system",
|
||||
)
|
||||
|
||||
|
||||
def func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam:
|
||||
return ChatCompletionMessageToolCallParam(
|
||||
id=message.id,
|
||||
function={
|
||||
"arguments": message.arguments,
|
||||
"name": message.name,
|
||||
},
|
||||
type="function",
|
||||
)
|
||||
|
||||
|
||||
def tool_message_to_oai(
|
||||
message: FunctionExecutionResultMessage,
|
||||
) -> Sequence[ChatCompletionToolMessageParam]:
|
||||
return [
|
||||
ChatCompletionToolMessageParam(content=x.content, role="tool", tool_call_id=x.call_id) for x in message.content
|
||||
]
|
||||
|
||||
|
||||
def assistant_message_to_oai(
|
||||
message: AssistantMessage,
|
||||
) -> ChatCompletionAssistantMessageParam:
|
||||
if isinstance(message.content, list):
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
tool_calls=[func_call_to_oai(x) for x in message.content],
|
||||
role="assistant",
|
||||
name=message.source,
|
||||
)
|
||||
else:
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
content=message.content,
|
||||
role="assistant",
|
||||
name=message.source,
|
||||
)
|
||||
|
||||
|
||||
def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
|
||||
if isinstance(message, SystemMessage):
|
||||
return [system_message_to_oai(message)]
|
||||
elif isinstance(message, UserMessage):
|
||||
return [user_message_to_oai(message)]
|
||||
elif isinstance(message, AssistantMessage):
|
||||
return [assistant_message_to_oai(message)]
|
||||
else:
|
||||
return tool_message_to_oai(message)
|
||||
|
||||
|
||||
def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
|
||||
return RequestUsage(
|
||||
prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
|
||||
completion_tokens=usage1.completion_tokens + usage2.completion_tokens,
|
||||
)
|
||||
|
||||
|
||||
def convert_tools(
|
||||
tools: Sequence[Tool],
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
result: List[ChatCompletionToolParam] = []
|
||||
for tool in tools:
|
||||
tool_schema = tool.schema
|
||||
result.append(
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=FunctionDefinition(
|
||||
name=tool_schema["name"],
|
||||
description=tool_schema["description"] if "description" in tool_schema else "",
|
||||
parameters=cast(FunctionParameters, tool_schema["parameters"])
|
||||
if "parameters" in tool_schema
|
||||
else {},
|
||||
),
|
||||
)
|
||||
)
|
||||
# Check if all tools have valid names.
|
||||
for tool_param in result:
|
||||
assert_valid_name(tool_param["function"]["name"])
|
||||
return result
|
||||
|
||||
|
||||
def normalize_name(name: str) -> str:
|
||||
"""
|
||||
LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
|
||||
|
||||
Prefer _assert_valid_name for validating user configuration or input
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
|
||||
|
||||
|
||||
def assert_valid_name(name: str) -> str:
|
||||
"""
|
||||
Ensure that configured names are valid, raises ValueError if not.
|
||||
|
||||
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
|
||||
"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
||||
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
|
||||
if len(name) > 64:
|
||||
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
|
||||
return name
|
||||
|
||||
|
||||
class BaseOpenAI(ChatCompletionClient):
|
||||
def __init__(
|
||||
self,
|
||||
client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
create_args: Dict[str, Any],
|
||||
model_capabilities: Optional[ModelCapabilities] = None,
|
||||
):
|
||||
self._client = client
|
||||
if model_capabilities is None and isinstance(client, AsyncAzureOpenAI):
|
||||
raise ValueError("AzureOpenAI requires explicit model capabilities")
|
||||
elif model_capabilities is None:
|
||||
self._model_capabilities = _model_info.get_capabilties(create_args["model"])
|
||||
else:
|
||||
self._model_capabilities = model_capabilities
|
||||
|
||||
self._resolved_model: Optional[str] = None
|
||||
if "model" in create_args:
|
||||
self._resolved_model = _model_info.resolve_model(create_args["model"])
|
||||
|
||||
if (
|
||||
"response_format" in create_args
|
||||
and create_args["response_format"]["type"] == "json_object"
|
||||
and not self._model_capabilities["json_output"]
|
||||
):
|
||||
raise ValueError("Model does not support JSON output")
|
||||
|
||||
self._create_args = create_args
|
||||
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
|
||||
@classmethod
|
||||
def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
|
||||
return OpenAI(**config)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> CreateResult:
|
||||
# Make sure all extra_create_args are valid
|
||||
extra_create_args_keys = set(extra_create_args.keys())
|
||||
if not create_kwargs.issuperset(extra_create_args_keys):
|
||||
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
|
||||
|
||||
# Copy the create args and overwrite anything in extra_create_args
|
||||
create_args = self._create_args.copy()
|
||||
create_args.update(extra_create_args)
|
||||
|
||||
# TODO: allow custom handling.
|
||||
# For now we raise an error if images are present and vision is not supported
|
||||
if self.capabilities["vision"] is False:
|
||||
for message in messages:
|
||||
if isinstance(message, UserMessage):
|
||||
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
|
||||
raise ValueError("Model does not support vision and image was provided")
|
||||
|
||||
if json_output is not None:
|
||||
if self.capabilities["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output")
|
||||
|
||||
if json_output is True:
|
||||
create_args["response_format"] = {"type": "json_object"}
|
||||
else:
|
||||
create_args["response_format"] = {"type": "text"}
|
||||
|
||||
if self.capabilities["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output")
|
||||
|
||||
oai_messages_nested = [to_oai_type(m) for m in messages]
|
||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||
|
||||
if self.capabilities["function_calling"] is False and len(tools) > 0:
|
||||
raise ValueError("Model does not support function calling")
|
||||
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
result = await self._client.chat.completions.create(
|
||||
messages=oai_messages,
|
||||
stream=False,
|
||||
tools=converted_tools,
|
||||
**create_args,
|
||||
)
|
||||
else:
|
||||
result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
|
||||
|
||||
if result.usage is not None:
|
||||
logger.info(
|
||||
LLMCallEvent(
|
||||
prompt_tokens=result.usage.prompt_tokens,
|
||||
completion_tokens=result.usage.completion_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
usage = RequestUsage(
|
||||
# TODO backup token counting
|
||||
prompt_tokens=result.usage.prompt_tokens if result.usage is not None else 0,
|
||||
completion_tokens=(result.usage.completion_tokens if result.usage is not None else 0),
|
||||
)
|
||||
|
||||
if self._resolved_model is not None:
|
||||
if self._resolved_model != result.model:
|
||||
warnings.warn(
|
||||
f"Resolved model mismatch: {self._resolved_model} != {result.model}. AutoGen model mapping may be incorrect.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Limited to a single choice currently.
|
||||
choice = result.choices[0]
|
||||
if choice.finish_reason == "function_call":
|
||||
raise ValueError("Function calls are not supported in this context")
|
||||
|
||||
content: Union[str, List[FunctionCall]]
|
||||
if choice.finish_reason == "tool_calls":
|
||||
assert choice.message.tool_calls is not None
|
||||
assert choice.message.function_call is None
|
||||
|
||||
# NOTE: If OAI response type changes, this will need to be updated
|
||||
content = [
|
||||
FunctionCall(
|
||||
id=x.id,
|
||||
arguments=x.function.arguments,
|
||||
name=normalize_name(x.function.name),
|
||||
)
|
||||
for x in choice.message.tool_calls
|
||||
]
|
||||
finish_reason = "function_calls"
|
||||
else:
|
||||
finish_reason = choice.finish_reason
|
||||
content = choice.message.content or ""
|
||||
|
||||
response = CreateResult(finish_reason=finish_reason, content=content, usage=usage, cached=False) # type: ignore
|
||||
|
||||
_add_usage(self._actual_usage, usage)
|
||||
_add_usage(self._total_usage, usage)
|
||||
|
||||
# TODO - why is this cast needed?
|
||||
return response
|
||||
|
||||
async def create_stream(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
# Make sure all extra_create_args are valid
|
||||
extra_create_args_keys = set(extra_create_args.keys())
|
||||
if not create_kwargs.issuperset(extra_create_args_keys):
|
||||
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
|
||||
|
||||
# Copy the create args and overwrite anything in extra_create_args
|
||||
create_args = self._create_args.copy()
|
||||
create_args.update(extra_create_args)
|
||||
|
||||
oai_messages_nested = [to_oai_type(m) for m in messages]
|
||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||
|
||||
# TODO: allow custom handling.
|
||||
# For now we raise an error if images are present and vision is not supported
|
||||
if self.capabilities["vision"] is False:
|
||||
for message in messages:
|
||||
if isinstance(message, UserMessage):
|
||||
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
|
||||
raise ValueError("Model does not support vision and image was provided")
|
||||
|
||||
if json_output is not None:
|
||||
if self.capabilities["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output")
|
||||
|
||||
if json_output is True:
|
||||
create_args["response_format"] = {"type": "json_object"}
|
||||
else:
|
||||
create_args["response_format"] = {"type": "text"}
|
||||
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
stream = await self._client.chat.completions.create(
|
||||
messages=oai_messages, stream=True, tools=converted_tools, **create_args
|
||||
)
|
||||
else:
|
||||
stream = await self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
|
||||
|
||||
stop_reason = None
|
||||
maybe_model = None
|
||||
content_deltas: List[str] = []
|
||||
full_tool_calls: Dict[int, FunctionCall] = {}
|
||||
completion_tokens = 0
|
||||
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
stop_reason = choice.finish_reason
|
||||
maybe_model = chunk.model
|
||||
# First try get content
|
||||
if choice.delta.content is not None:
|
||||
content_deltas.append(choice.delta.content)
|
||||
if len(choice.delta.content) > 0:
|
||||
yield choice.delta.content
|
||||
continue
|
||||
|
||||
# Otherwise, get tool calls
|
||||
if choice.delta.tool_calls is not None:
|
||||
for tool_call_chunk in choice.delta.tool_calls:
|
||||
idx = tool_call_chunk.index
|
||||
if idx not in full_tool_calls:
|
||||
# We ignore the type hint here because we want to fill in type when the delta provides it
|
||||
full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
|
||||
|
||||
if tool_call_chunk.id is not None:
|
||||
full_tool_calls[idx].id += tool_call_chunk.id
|
||||
|
||||
if tool_call_chunk.function is not None:
|
||||
if tool_call_chunk.function.name is not None:
|
||||
full_tool_calls[idx].name += tool_call_chunk.function.name
|
||||
if tool_call_chunk.function.arguments is not None:
|
||||
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
|
||||
|
||||
model = maybe_model or create_args["model"]
|
||||
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
|
||||
|
||||
# TODO fix count token
|
||||
prompt_tokens = 0
|
||||
# prompt_tokens = count_token(messages, model=model)
|
||||
if stop_reason is None:
|
||||
raise ValueError("No stop reason found")
|
||||
|
||||
content: Union[str, List[FunctionCall]]
|
||||
if len(content_deltas) > 1:
|
||||
content = "".join(content_deltas)
|
||||
completion_tokens = 0
|
||||
# completion_tokens = count_token(content, model=model)
|
||||
else:
|
||||
completion_tokens = 0
|
||||
# TODO: fix assumption that dict values were added in order and actually order by int index
|
||||
# for tool_call in full_tool_calls.values():
|
||||
# # value = json.dumps(tool_call)
|
||||
# # completion_tokens += count_token(value, model=model)
|
||||
# completion_tokens += 0
|
||||
content = list(full_tool_calls.values())
|
||||
|
||||
usage = RequestUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
if stop_reason == "function_call":
|
||||
raise ValueError("Function calls are not supported in this context")
|
||||
if stop_reason == "tool_calls":
|
||||
stop_reason = "function_calls"
|
||||
|
||||
result = CreateResult(finish_reason=stop_reason, content=content, usage=usage, cached=False)
|
||||
|
||||
_add_usage(self._actual_usage, usage)
|
||||
_add_usage(self._total_usage, usage)
|
||||
|
||||
yield result
|
||||
|
||||
def actual_usage(self) -> RequestUsage:
|
||||
return self._actual_usage
|
||||
|
||||
def total_usage(self) -> RequestUsage:
|
||||
return self._total_usage
|
||||
|
||||
@property
|
||||
def capabilities(self) -> ModelCapabilities:
|
||||
return self._model_capabilities
|
||||
|
||||
|
||||
class OpenAI(BaseOpenAI):
|
||||
def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
|
||||
if "model" not in kwargs:
|
||||
raise ValueError("model is required for OpenAI")
|
||||
|
||||
model_capabilities: Optional[ModelCapabilities] = None
|
||||
copied_args = dict(kwargs).copy()
|
||||
if "model_capabilities" in kwargs:
|
||||
model_capabilities = kwargs["model_capabilities"]
|
||||
del copied_args["model_capabilities"]
|
||||
|
||||
client = _openai_client_from_config(copied_args)
|
||||
create_args = _create_args_from_config(copied_args)
|
||||
self._raw_config = copied_args
|
||||
super().__init__(client, create_args, model_capabilities)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
state = self.__dict__.copy()
|
||||
state["_client"] = None
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
self.__dict__.update(state)
|
||||
self._client = _openai_client_from_config(state["_raw_config"])
|
||||
|
||||
|
||||
class AzureOpenAI(BaseOpenAI):
|
||||
def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
|
||||
if "model" not in kwargs:
|
||||
raise ValueError("model is required for OpenAI")
|
||||
|
||||
model_capabilities: Optional[ModelCapabilities] = None
|
||||
copied_args = dict(kwargs).copy()
|
||||
if "model_capabilities" in kwargs:
|
||||
model_capabilities = kwargs["model_capabilities"]
|
||||
del copied_args["model_capabilities"]
|
||||
|
||||
client = _azure_openai_client_from_config(copied_args)
|
||||
create_args = _create_args_from_config(copied_args)
|
||||
self._raw_config = copied_args
|
||||
super().__init__(client, create_args, model_capabilities)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
state = self.__dict__.copy()
|
||||
state["_client"] = None
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
self.__dict__.update(state)
|
||||
self._client = _azure_openai_client_from_config(state["_raw_config"])
|
||||
56
python/src/agnext/components/models/_types.py
Normal file
56
python/src/agnext/components/models/_types.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from .. import FunctionCall, Image
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMessage:
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserMessage:
|
||||
content: Union[str, List[Union[str, Image]]]
|
||||
|
||||
# Name of the agent that sent this message
|
||||
source: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantMessage:
|
||||
content: Union[str, List[FunctionCall]]
|
||||
|
||||
# Name of the agent that sent this message
|
||||
source: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionExecutionResult:
|
||||
content: str
|
||||
call_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionExecutionResultMessage:
|
||||
content: List[FunctionExecutionResult]
|
||||
|
||||
|
||||
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestUsage:
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
|
||||
|
||||
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateResult:
|
||||
finish_reason: FinishReasons
|
||||
content: Union[str, List[FunctionCall]]
|
||||
usage: RequestUsage
|
||||
cached: bool
|
||||
52
python/src/agnext/components/models/config/__init__.py
Normal file
52
python/src/agnext/components/models/config/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from .._model_client import ModelCapabilities
|
||||
|
||||
|
||||
class ResponseFormat(TypedDict):
|
||||
type: Literal["text", "json_object"]
|
||||
|
||||
|
||||
class CreateArguments(TypedDict, total=False):
|
||||
frequency_penalty: Optional[float]
|
||||
logit_bias: Optional[Dict[str, int]]
|
||||
max_tokens: Optional[int]
|
||||
n: Optional[int]
|
||||
presence_penalty: Optional[float]
|
||||
response_format: ResponseFormat
|
||||
seed: Optional[int]
|
||||
stop: Union[Optional[str], List[str]]
|
||||
temperature: Optional[float]
|
||||
top_p: Optional[float]
|
||||
user: str
|
||||
|
||||
|
||||
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
|
||||
|
||||
|
||||
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
|
||||
model: str
|
||||
api_key: str
|
||||
timeout: Union[float, None]
|
||||
max_retries: int
|
||||
|
||||
|
||||
# See OpenAI docs for explanation of these parameters
|
||||
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
|
||||
organization: str
|
||||
base_url: str
|
||||
# Not required
|
||||
model_capabilities: ModelCapabilities
|
||||
|
||||
|
||||
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
|
||||
# Azure specific
|
||||
azure_endpoint: Required[str]
|
||||
azure_deployment: str
|
||||
api_version: Required[str]
|
||||
azure_ad_token: str
|
||||
azure_ad_token_provider: AsyncAzureADTokenProvider
|
||||
# Must be provided
|
||||
model_capabilities: Required[ModelCapabilities]
|
||||
13
python/src/agnext/components/tools/__init__.py
Normal file
13
python/src/agnext/components/tools/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from ._base import BaseTool, BaseToolWithState, Tool
|
||||
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
|
||||
from ._function_tool import FunctionTool
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"BaseTool",
|
||||
"BaseToolWithState",
|
||||
"PythonCodeExecutionTool",
|
||||
"CodeExecutionInput",
|
||||
"CodeExecutionResult",
|
||||
"FunctionTool",
|
||||
]
|
||||
151
python/src/agnext/components/tools/_base.py
Normal file
151
python/src/agnext/components/tools/_base.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from ...core import CancellationToken
|
||||
from .._function_utils import normalize_annotated_type
|
||||
|
||||
T = TypeVar("T", bound=BaseModel, contravariant=True)
|
||||
|
||||
|
||||
class ParametersSchema(TypedDict):
|
||||
type: str
|
||||
properties: Dict[str, Any]
|
||||
required: NotRequired[Sequence[str]]
|
||||
|
||||
|
||||
class ToolSchema(TypedDict):
|
||||
parameters: NotRequired[ParametersSchema]
|
||||
name: str
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class Tool(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def schema(self) -> ToolSchema: ...
|
||||
|
||||
def args_type(self) -> Type[BaseModel]: ...
|
||||
|
||||
def return_type(self) -> Type[Any]: ...
|
||||
|
||||
def state_type(self) -> Type[BaseModel] | None: ...
|
||||
|
||||
def return_value_as_string(self, value: Any) -> str: ...
|
||||
|
||||
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any: ...
|
||||
|
||||
def save_state_json(self) -> Mapping[str, Any]: ...
|
||||
|
||||
def load_state_json(self, state: Mapping[str, Any]) -> None: ...
|
||||
|
||||
|
||||
ArgsT = TypeVar("ArgsT", bound=BaseModel, contravariant=True)
|
||||
ReturnT = TypeVar("ReturnT", bound=BaseModel, covariant=True)
|
||||
StateT = TypeVar("StateT", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
|
||||
def __init__(
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
self._args_type = args_type
|
||||
# Normalize Annotated to the base type.
|
||||
self._return_type = normalize_annotated_type(return_type)
|
||||
self._name = name
|
||||
self._description = description
|
||||
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
model_schema = self._args_type.model_json_schema()
|
||||
|
||||
tool_schema = ToolSchema(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
parameters=ParametersSchema(
|
||||
type="object",
|
||||
properties=model_schema["properties"],
|
||||
),
|
||||
)
|
||||
if "required" in model_schema:
|
||||
assert "parameters" in tool_schema
|
||||
tool_schema["parameters"]["required"] = model_schema["required"]
|
||||
|
||||
return tool_schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
def args_type(self) -> Type[BaseModel]:
|
||||
return self._args_type
|
||||
|
||||
def return_type(self) -> Type[Any]:
|
||||
return self._return_type
|
||||
|
||||
def state_type(self) -> Type[BaseModel] | None:
|
||||
return None
|
||||
|
||||
def return_value_as_string(self, value: Any) -> str:
|
||||
if isinstance(value, BaseModel):
|
||||
dumped = value.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return json.dumps(dumped)
|
||||
return str(dumped)
|
||||
|
||||
return str(value)
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...
|
||||
|
||||
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any:
|
||||
return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
|
||||
return return_value
|
||||
|
||||
def save_state_json(self) -> Mapping[str, Any]:
|
||||
return {}
|
||||
|
||||
def load_state_json(self, state: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
|
||||
def __init__(
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
state_type: Type[StateT],
|
||||
name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
super().__init__(args_type, return_type, name, description)
|
||||
self._state_type = state_type
|
||||
|
||||
@abstractmethod
|
||||
def save_state(self) -> StateT: ...
|
||||
|
||||
@abstractmethod
|
||||
def load_state(self, state: StateT) -> None: ...
|
||||
|
||||
def save_state_json(self) -> Mapping[str, Any]:
|
||||
return self.save_state().model_dump()
|
||||
|
||||
def load_state_json(self, state: Mapping[str, Any]) -> None:
|
||||
self.load_state(self._state_type.model_validate(state))
|
||||
37
python/src/agnext/components/tools/_code_execution.py
Normal file
37
python/src/agnext/components/tools/_code_execution.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
|
||||
from ...core import CancellationToken
|
||||
from ..code_executor import CodeBlock, CodeExecutor
|
||||
from ._base import BaseTool
|
||||
|
||||
|
||||
class CodeExecutionInput(BaseModel):
|
||||
code: str = Field(description="The contents of the Python code block that should be executed")
|
||||
|
||||
|
||||
class CodeExecutionResult(BaseModel):
|
||||
success: bool
|
||||
output: str
|
||||
|
||||
@model_serializer
|
||||
def ser_model(self) -> str:
|
||||
return self.output
|
||||
|
||||
|
||||
class PythonCodeExecutionTool(BaseTool[CodeExecutionInput, CodeExecutionResult]):
|
||||
def __init__(self, executor: CodeExecutor):
|
||||
super().__init__(CodeExecutionInput, CodeExecutionResult, "CodeExecutor", "Execute Python code blocks.")
|
||||
self._executor = executor
|
||||
|
||||
async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult:
|
||||
code_blocks = [CodeBlock(code=args.code, language="python")]
|
||||
future = asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(self._executor.execute_code_blocks, code_blocks=code_blocks)
|
||||
)
|
||||
cancellation_token.link_future(future)
|
||||
result = await future
|
||||
|
||||
return CodeExecutionResult(success=result.exit_code == 0, output=result.output)
|
||||
50
python/src/agnext/components/tools/_function_tool.py
Normal file
50
python/src/agnext/components/tools/_function_tool.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...core import CancellationToken
|
||||
from .._function_utils import (
|
||||
args_base_model_from_signature,
|
||||
get_typed_signature,
|
||||
)
|
||||
from ._base import BaseTool
|
||||
|
||||
|
||||
class FunctionTool(BaseTool[BaseModel, BaseModel]):
|
||||
def __init__(self, func: Callable[..., Any], description: str, name: str | None = None) -> None:
|
||||
self._func = func
|
||||
signature = get_typed_signature(func)
|
||||
func_name = name or func.__name__
|
||||
args_model = args_base_model_from_signature(func_name + "args", signature)
|
||||
return_type = signature.return_annotation
|
||||
self._has_cancellation_support = "cancellation_token" in signature.parameters
|
||||
|
||||
super().__init__(args_model, return_type, func_name, description)
|
||||
|
||||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
|
||||
if asyncio.iscoroutinefunction(self._func):
|
||||
if self._has_cancellation_support:
|
||||
result = await self._func(**args.model_dump(), cancellation_token=cancellation_token)
|
||||
else:
|
||||
result = await self._func(**args.model_dump())
|
||||
else:
|
||||
if self._has_cancellation_support:
|
||||
result = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self._func,
|
||||
**args.model_dump(),
|
||||
cancellation_token=cancellation_token,
|
||||
),
|
||||
)
|
||||
else:
|
||||
future = asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(self._func, **args.model_dump())
|
||||
)
|
||||
cancellation_token.link_future(future)
|
||||
result = await future
|
||||
|
||||
assert isinstance(result, self.return_type())
|
||||
return result
|
||||
24
python/src/agnext/core/__init__.py
Normal file
24
python/src/agnext/core/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
The :mod:`agnext.core` module provides the foundational generic interfaces upon which all else is built. This module must not depend on any other module.
|
||||
"""
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_props import AgentChildren
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._agent_runtime import AgentRuntime, AllNamespaces
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"AgentId",
|
||||
"AgentProxy",
|
||||
"AgentMetadata",
|
||||
"AgentRuntime",
|
||||
"AllNamespaces",
|
||||
"BaseAgent",
|
||||
"CancellationToken",
|
||||
"AgentChildren",
|
||||
]
|
||||
46
python/src/agnext/core/_agent.py
Normal file
46
python/src/agnext/core/_agent.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Any, Mapping, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agent(Protocol):
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
"""Metadata of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
"""ID of the agent."""
|
||||
...
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
|
||||
"""Message handler for the agent. This should only be called by the runtime, not by other agents.
|
||||
|
||||
Args:
|
||||
message (Any): Received message. Type is one of the types in `subscriptions`.
|
||||
cancellation_token (CancellationToken): Cancellation token for the message.
|
||||
|
||||
Returns:
|
||||
Any: Response to the message. Can be None.
|
||||
|
||||
Notes:
|
||||
If there was a cancellation, this function should raise a `CancelledError`.
|
||||
"""
|
||||
...
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the agent. The result must be JSON serializable."""
|
||||
...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load in the state of the agent obtained from `save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
|
||||
"""
|
||||
|
||||
...
|
||||
31
python/src/agnext/core/_agent_id.py
Normal file
31
python/src/agnext/core/_agent_id.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class AgentId:
|
||||
def __init__(self, name: str, namespace: str) -> None:
|
||||
self._name = name
|
||||
self._namespace = namespace
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self._namespace}/{self._name}"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._namespace, self._name))
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
if not isinstance(value, AgentId):
|
||||
return False
|
||||
return self._name == value.name and self._namespace == value.namespace
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, agent_id: str) -> Self:
|
||||
namespace, name = agent_id.split("/")
|
||||
return cls(name, namespace)
|
||||
|
||||
@property
|
||||
def namespace(self) -> str:
|
||||
return self._namespace
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
8
python/src/agnext/core/_agent_metadata.py
Normal file
8
python/src/agnext/core/_agent_metadata.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import Sequence, TypedDict
|
||||
|
||||
|
||||
class AgentMetadata(TypedDict):
|
||||
name: str
|
||||
namespace: str
|
||||
description: str
|
||||
subscriptions: Sequence[type]
|
||||
11
python/src/agnext/core/_agent_props.py
Normal file
11
python/src/agnext/core/_agent_props.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Protocol, Sequence, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentChildren(Protocol):
|
||||
@property
|
||||
def children(self) -> Sequence[AgentId]:
|
||||
"""Ids of the children of the agent."""
|
||||
...
|
||||
53
python/src/agnext/core/_agent_proxy.py
Normal file
53
python/src/agnext/core/_agent_proxy.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from typing import TYPE_CHECKING, Any, Mapping
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
class AgentProxy:
|
||||
def __init__(self, agent: AgentId, runtime: AgentRuntime):
|
||||
self._agent = agent
|
||||
self._runtime = runtime
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
"""Target agent for this proxy"""
|
||||
return self._agent
|
||||
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
"""Metadata of the agent."""
|
||||
return self._runtime.agent_metadata(self._agent)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
sender: AgentId,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]:
|
||||
return self._runtime.send_message(
|
||||
message,
|
||||
recipient=self._agent,
|
||||
sender=sender,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the agent. The result must be JSON serializable."""
|
||||
return self._runtime.agent_save_state(self._agent)
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load in the state of the agent obtained from `save_state`.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): State of the agent. Must be JSON serializable.
|
||||
"""
|
||||
self._runtime.agent_load_state(self._agent, state)
|
||||
162
python/src/agnext/core/_agent_runtime.py
Normal file
162
python/src/agnext/core/_agent_runtime.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Future
|
||||
from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
# Undeliverable - error
|
||||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
class AllNamespaces:
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentRuntime(Protocol):
|
||||
# Returns the response of the message
|
||||
def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]: ...
|
||||
|
||||
# No responses from publishing
|
||||
def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
namespace: str | None = None,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[None]: ...
|
||||
|
||||
@overload
|
||||
def register(
|
||||
self, name: str, agent_factory: Callable[[], T], *, valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> None: ...
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> None:
|
||||
"""Register an agent factory with the runtime associated with a specific name. The name must be unique.
|
||||
|
||||
Args:
|
||||
name (str): The name of the type agent this factory creates.
|
||||
agent_factory (Callable[[], T] | Callable[[AgentRuntime, AgentId], T]): The factory that creates the agent.
|
||||
valid_namespaces (Sequence[str] | Type[AllNamespaces], optional): Valid namespaces for this type. Defaults to AllNamespaces.
|
||||
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
runtime.register(
|
||||
"chat_agent",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A generic chat agent.",
|
||||
system_messages=[SystemMessage("You are a helpful assistant")],
|
||||
model_client=OpenAI(model="gpt-4o"),
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
),
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def get(self, name: str, *, namespace: str = "default") -> AgentId: ...
|
||||
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: ...
|
||||
|
||||
@overload
|
||||
def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentId: ...
|
||||
|
||||
@overload
|
||||
def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentId: ...
|
||||
|
||||
def register_and_get(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> AgentId:
|
||||
self.register(name, agent_factory)
|
||||
return self.get(name, namespace=namespace)
|
||||
|
||||
@overload
|
||||
def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentProxy: ...
|
||||
|
||||
@overload
|
||||
def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = ...,
|
||||
) -> AgentProxy: ...
|
||||
|
||||
def register_and_get_proxy(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
*,
|
||||
namespace: str = "default",
|
||||
valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces,
|
||||
) -> AgentProxy:
|
||||
self.register(name, agent_factory)
|
||||
return self.get_proxy(name, namespace=namespace)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]: ...
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None: ...
|
||||
|
||||
def agent_metadata(self, agent: AgentId) -> AgentMetadata: ...
|
||||
|
||||
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: ...
|
||||
|
||||
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: ...
|
||||
106
python/src/agnext/core/_base_agent.py
Normal file
106
python/src/agnext/core/_base_agent.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Future
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._cancellation_token import CancellationToken
|
||||
|
||||
|
||||
class BaseAgent(ABC, Agent):
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
assert self._id is not None
|
||||
return AgentMetadata(
|
||||
namespace=self._id.namespace,
|
||||
name=self._id.name,
|
||||
description=self._description,
|
||||
subscriptions=self._subscriptions,
|
||||
)
|
||||
|
||||
def __init__(self, description: str, subscriptions: Sequence[type]) -> None:
|
||||
self._runtime: AgentRuntime | None = None
|
||||
self._id: AgentId | None = None
|
||||
self._description = description
|
||||
self._subscriptions = subscriptions
|
||||
|
||||
def bind_runtime(self, runtime: AgentRuntime) -> None:
|
||||
if self._runtime is not None:
|
||||
raise RuntimeError("Agent has already been bound to a runtime.")
|
||||
|
||||
self._runtime = runtime
|
||||
|
||||
def bind_id(self, agent_id: AgentId) -> None:
|
||||
if self._id is not None:
|
||||
raise RuntimeError("Agent has already been bound to an id.")
|
||||
self._id = agent_id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.id.name
|
||||
|
||||
@property
|
||||
def id(self) -> AgentId:
|
||||
if self._id is None:
|
||||
raise RuntimeError("Agent has not been bound to an id.")
|
||||
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def runtime(self) -> AgentRuntime:
|
||||
if self._runtime is None:
|
||||
raise RuntimeError("Agent has not been bound to a runtime.")
|
||||
|
||||
return self._runtime
|
||||
|
||||
@abstractmethod
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ...
|
||||
|
||||
# Returns the response of the message
|
||||
def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[Any]:
|
||||
if self._runtime is None:
|
||||
raise RuntimeError("Agent has not been bound to a runtime.")
|
||||
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
future = self._runtime.send_message(
|
||||
message,
|
||||
sender=self.id,
|
||||
recipient=recipient,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
cancellation_token.link_future(future)
|
||||
return future
|
||||
|
||||
def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Future[None]:
|
||||
if self._runtime is None:
|
||||
raise RuntimeError("Agent has not been bound to a runtime.")
|
||||
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
future = self._runtime.publish_message(message, sender=self.id, cancellation_token=cancellation_token)
|
||||
return future
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
warnings.warn("save_state not implemented", stacklevel=2)
|
||||
return {}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
warnings.warn("load_state not implemented", stacklevel=2)
|
||||
pass
|
||||
39
python/src/agnext/core/_cancellation_token.py
Normal file
39
python/src/agnext/core/_cancellation_token.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import threading
|
||||
from asyncio import Future
|
||||
from typing import Any, Callable, List
|
||||
|
||||
|
||||
class CancellationToken:
|
||||
def __init__(self) -> None:
|
||||
self._cancelled: bool = False
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
self._callbacks: List[Callable[[], None]] = []
|
||||
|
||||
def cancel(self) -> None:
|
||||
with self._lock:
|
||||
if not self._cancelled:
|
||||
self._cancelled = True
|
||||
for callback in self._callbacks:
|
||||
callback()
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
with self._lock:
|
||||
return self._cancelled
|
||||
|
||||
def add_callback(self, callback: Callable[[], None]) -> None:
|
||||
with self._lock:
|
||||
if self._cancelled:
|
||||
callback()
|
||||
else:
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def link_future(self, future: Future[Any]) -> None:
|
||||
with self._lock:
|
||||
if self._cancelled:
|
||||
future.cancel()
|
||||
else:
|
||||
|
||||
def _cancel() -> None:
|
||||
future.cancel()
|
||||
|
||||
self._callbacks.append(_cancel)
|
||||
17
python/src/agnext/core/exceptions.py
Normal file
17
python/src/agnext/core/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
__all__ = [
|
||||
"CantHandleException",
|
||||
"UndeliverableException",
|
||||
"MessageDroppedException",
|
||||
]
|
||||
|
||||
|
||||
class CantHandleException(Exception):
|
||||
"""Raised when a handler can't handle the exception."""
|
||||
|
||||
|
||||
class UndeliverableException(Exception):
|
||||
"""Raised when a message can't be delivered."""
|
||||
|
||||
|
||||
class MessageDroppedException(Exception):
|
||||
"""Raised when a message is dropped."""
|
||||
36
python/src/agnext/core/intervention.py
Normal file
36
python/src/agnext/core/intervention.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Any, Awaitable, Callable, Protocol, final
|
||||
|
||||
from agnext.core import AgentId
|
||||
|
||||
__all__ = [
|
||||
"DropMessage",
|
||||
"InterventionFunction",
|
||||
"InterventionHandler",
|
||||
"DefaultInterventionHandler",
|
||||
]
|
||||
|
||||
|
||||
@final
|
||||
class DropMessage: ...
|
||||
|
||||
|
||||
InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]]
|
||||
|
||||
|
||||
class InterventionHandler(Protocol):
|
||||
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
|
||||
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
|
||||
async def on_response(
|
||||
self, message: Any, *, sender: AgentId, recipient: AgentId | None
|
||||
) -> Any | type[DropMessage]: ...
|
||||
|
||||
|
||||
class DefaultInterventionHandler(InterventionHandler):
|
||||
async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
|
||||
return message
|
||||
|
||||
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
|
||||
return message
|
||||
0
python/src/agnext/py.typed
Normal file
0
python/src/agnext/py.typed
Normal file
104
python/tests/execution/test_commandline_code_executor.py
Normal file
104
python/tests/execution/test_commandline_code_executor.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/test/coding/test_commandline_code_executor.py
|
||||
# Credit to original authors
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from agnext.components.code_executor import CodeBlock, LocalCommandLineCodeExecutor
|
||||
|
||||
UNIX_SHELLS = ["bash", "sh", "shell"]
|
||||
WINDOWS_SHELLS = ["ps1", "pwsh", "powershell"]
|
||||
PYTHON_VARIANTS = ["python", "Python", "py"]
|
||||
|
||||
|
||||
def test_execute_code() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir)
|
||||
|
||||
|
||||
# Test single code block.
|
||||
code_blocks = [CodeBlock(code="import sys; print('hello world!')", language="python")]
|
||||
code_result = executor.execute_code_blocks(code_blocks)
|
||||
assert code_result.exit_code == 0 and "hello world!" in code_result.output and code_result.code_file is not None
|
||||
|
||||
# Test multiple code blocks.
|
||||
code_blocks = [
|
||||
CodeBlock(code="import sys; print('hello world!')", language="python"),
|
||||
CodeBlock(code="a = 100 + 100; print(a)", language="python"),
|
||||
]
|
||||
code_result = executor.execute_code_blocks(code_blocks)
|
||||
assert (
|
||||
code_result.exit_code == 0
|
||||
and "hello world!" in code_result.output
|
||||
and "200" in code_result.output
|
||||
and code_result.code_file is not None
|
||||
)
|
||||
|
||||
# Test bash script.
|
||||
if sys.platform not in ["win32"]:
|
||||
code_blocks = [CodeBlock(code="echo 'hello world!'", language="bash")]
|
||||
code_result = executor.execute_code_blocks(code_blocks)
|
||||
assert code_result.exit_code == 0 and "hello world!" in code_result.output and code_result.code_file is not None
|
||||
|
||||
# Test running code.
|
||||
file_lines = ["import sys", "print('hello world!')", "a = 100 + 100", "print(a)"]
|
||||
code_blocks = [CodeBlock(code="\n".join(file_lines), language="python")]
|
||||
code_result = executor.execute_code_blocks(code_blocks)
|
||||
assert (
|
||||
code_result.exit_code == 0
|
||||
and "hello world!" in code_result.output
|
||||
and "200" in code_result.output
|
||||
and code_result.code_file is not None
|
||||
)
|
||||
|
||||
# Check saved code file.
|
||||
with open(code_result.code_file) as f:
|
||||
code_lines = f.readlines()
|
||||
for file_line, code_line in zip(file_lines, code_lines):
|
||||
assert file_line.strip() == code_line.strip()
|
||||
|
||||
|
||||
def test_commandline_code_executor_timeout() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(timeout=1, work_dir=temp_dir)
|
||||
code_blocks = [CodeBlock(code="import time; time.sleep(10); print('hello world!')", language="python")]
|
||||
code_result = executor.execute_code_blocks(code_blocks)
|
||||
assert code_result.exit_code and "Timeout" in code_result.output
|
||||
|
||||
|
||||
def test_local_commandline_code_executor_restart() -> None:
|
||||
executor = LocalCommandLineCodeExecutor()
|
||||
with pytest.warns(UserWarning, match=r".*No action is taken."):
|
||||
executor.restart()
|
||||
|
||||
|
||||
|
||||
|
||||
def test_invalid_relative_path() -> None:
|
||||
executor = LocalCommandLineCodeExecutor()
|
||||
code = """# filename: /tmp/test.py
|
||||
|
||||
print("hello world")
|
||||
"""
|
||||
result = executor.execute_code_blocks([CodeBlock(code=code, language="python")])
|
||||
assert result.exit_code == 1 and "Filename is not in the workspace" in result.output
|
||||
|
||||
|
||||
def test_valid_relative_path() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
temp_dir = Path(temp_dir_str)
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir)
|
||||
code = """# filename: test.py
|
||||
|
||||
print("hello world")
|
||||
"""
|
||||
result = executor.execute_code_blocks([CodeBlock(code=code, language="python")])
|
||||
assert result.exit_code == 0
|
||||
assert "hello world" in result.output
|
||||
assert result.code_file is not None
|
||||
assert "test.py" in result.code_file
|
||||
assert (temp_dir / Path("test.py")).resolve() == Path(result.code_file).resolve()
|
||||
assert (temp_dir / Path("test.py")).exists()
|
||||
|
||||
210
python/tests/execution/test_user_defined_functions.py
Normal file
210
python/tests/execution/test_user_defined_functions.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# File based from: https://github.com/microsoft/autogen/blob/main/test/coding/test_user_defined_functions.py
|
||||
# Credit to original authors
|
||||
|
||||
import tempfile
|
||||
|
||||
import polars
|
||||
import pytest
|
||||
from agnext.components.code_executor import (
|
||||
CodeBlock,
|
||||
FunctionWithRequirements,
|
||||
LocalCommandLineCodeExecutor,
|
||||
with_requirements,
|
||||
)
|
||||
|
||||
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
return a + b
|
||||
|
||||
|
||||
@with_requirements(python_packages=["polars"], global_imports=["polars"])
|
||||
def load_data() -> polars.DataFrame:
|
||||
"""Load some sample data.
|
||||
|
||||
Returns:
|
||||
polars.DataFrame: A DataFrame with the following columns: name(str), location(str), age(int)
|
||||
"""
|
||||
data = {
|
||||
"name": ["John", "Anna", "Peter", "Linda"],
|
||||
"location": ["New York", "Paris", "Berlin", "London"],
|
||||
"age": [24, 13, 53, 33],
|
||||
}
|
||||
return polars.DataFrame(data)
|
||||
|
||||
|
||||
@with_requirements(global_imports=["NOT_A_REAL_PACKAGE"])
|
||||
def function_incorrect_import() -> "polars.DataFrame":
|
||||
return polars.DataFrame()
|
||||
|
||||
|
||||
@with_requirements(python_packages=["NOT_A_REAL_PACKAGE"])
|
||||
def function_incorrect_dep() -> "polars.DataFrame":
|
||||
return polars.DataFrame()
|
||||
|
||||
|
||||
def function_missing_reqs() -> "polars.DataFrame":
|
||||
return polars.DataFrame()
|
||||
|
||||
|
||||
def test_can_load_function_with_reqs() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[load_data]
|
||||
)
|
||||
code = f"""from {executor.functions_module} import load_data
|
||||
import polars
|
||||
|
||||
# Get first row's name
|
||||
data = load_data()
|
||||
print(data['name'][0])"""
|
||||
|
||||
result = executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
)
|
||||
assert result.output == "John\n"
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_can_load_function() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[add_two_numbers]
|
||||
)
|
||||
code = f"""from {executor.functions_module} import add_two_numbers
|
||||
print(add_two_numbers(1, 2))"""
|
||||
|
||||
result = executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
)
|
||||
assert result.output == "3\n"
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_fails_for_function_incorrect_import() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[function_incorrect_import]
|
||||
)
|
||||
code = f"""from {executor.functions_module} import function_incorrect_import
|
||||
function_incorrect_import()"""
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_fails_for_function_incorrect_dep() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[function_incorrect_dep]
|
||||
)
|
||||
code = f"""from {executor.functions_module} import function_incorrect_dep
|
||||
function_incorrect_dep()"""
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_formatted_prompt() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[add_two_numbers]
|
||||
)
|
||||
|
||||
result = executor.format_functions_for_prompt()
|
||||
assert (
|
||||
'''def add_two_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
'''
|
||||
in result
|
||||
)
|
||||
|
||||
|
||||
def test_formatted_prompt_str_func() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
func = FunctionWithRequirements.from_str(
|
||||
'''
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
return a + b
|
||||
'''
|
||||
)
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[func])
|
||||
|
||||
result = executor.format_functions_for_prompt()
|
||||
assert (
|
||||
'''def add_two_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
'''
|
||||
in result
|
||||
)
|
||||
|
||||
|
||||
def test_can_load_str_function_with_reqs() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
func = FunctionWithRequirements.from_str(
|
||||
'''
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
return a + b
|
||||
'''
|
||||
)
|
||||
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[func])
|
||||
code = f"""from {executor.functions_module} import add_two_numbers
|
||||
print(add_two_numbers(1, 2))"""
|
||||
|
||||
result = executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
)
|
||||
assert result.output == "3\n"
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_cant_load_broken_str_function_with_reqs() -> None:
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = FunctionWithRequirements.from_str(
|
||||
'''
|
||||
invaliddef add_two_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
return a + b
|
||||
'''
|
||||
)
|
||||
|
||||
|
||||
def test_cant_run_broken_str_function_with_reqs() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
func = FunctionWithRequirements.from_str(
|
||||
'''
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
return a + b
|
||||
'''
|
||||
)
|
||||
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[func])
|
||||
code = f"""from {executor.functions_module} import add_two_numbers
|
||||
print(add_two_numbers(object(), False))"""
|
||||
|
||||
result = executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
)
|
||||
assert "TypeError: unsupported operand type(s) for +:" in result.output
|
||||
assert result.exit_code == 1
|
||||
128
python/tests/test_cancellation.py
Normal file
128
python/tests/test_cancellation.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import AgentId, CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageType:
|
||||
...
|
||||
|
||||
# Note for future reader:
|
||||
# To do cancellation, only the token should be interacted with as a user
|
||||
# If you cancel a future, it may not work as you expect.
|
||||
|
||||
class LongRunningAgent(TypeRoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("A long running agent")
|
||||
self.called = False
|
||||
self.cancelled = False
|
||||
|
||||
@message_handler
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
sleep = asyncio.ensure_future(asyncio.sleep(100))
|
||||
cancellation_token.link_future(sleep)
|
||||
try:
|
||||
await sleep
|
||||
return MessageType()
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
class NestingLongRunningAgent(TypeRoutedAgent):
|
||||
def __init__(self, nested_agent: AgentId) -> None:
|
||||
super().__init__("A nesting long running agent")
|
||||
self.called = False
|
||||
self.cancelled = False
|
||||
self._nested_agent = nested_agent
|
||||
|
||||
@message_handler
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
response = self.send_message(message, self._nested_agent, cancellation_token=cancellation_token)
|
||||
try:
|
||||
val = await response
|
||||
assert isinstance(val, MessageType)
|
||||
return val
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_with_token() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = runtime.register_and_get("long_running", LongRunningAgent)
|
||||
token = CancellationToken()
|
||||
response = runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token)
|
||||
assert not response.done()
|
||||
|
||||
await runtime.process_next()
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.called
|
||||
assert long_running_agent.cancelled
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_cancellation_only_outer_called() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = runtime.register_and_get("long_running", LongRunningAgent)
|
||||
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
|
||||
token = CancellationToken()
|
||||
response = runtime.send_message(MessageType(), nested, cancellation_token=token)
|
||||
assert not response.done()
|
||||
|
||||
await runtime.process_next()
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore
|
||||
assert nested_agent.called
|
||||
assert nested_agent.cancelled
|
||||
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.called is False
|
||||
assert long_running_agent.cancelled is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_cancellation_inner_called() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
long_running = runtime.register_and_get("long_running", LongRunningAgent )
|
||||
nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running))
|
||||
|
||||
token = CancellationToken()
|
||||
response = runtime.send_message(MessageType(), nested, cancellation_token=token)
|
||||
assert not response.done()
|
||||
|
||||
await runtime.process_next()
|
||||
# allow the inner agent to process
|
||||
await runtime.process_next()
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await response
|
||||
|
||||
assert response.done()
|
||||
nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore
|
||||
assert nested_agent.called
|
||||
assert nested_agent.cancelled
|
||||
long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.called
|
||||
assert long_running_agent.cancelled
|
||||
124
python/tests/test_intervention.py
Normal file
124
python/tests/test_intervention.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.core import AgentId
|
||||
from agnext.core.exceptions import MessageDroppedException
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
from test_utils import LoopbackAgent, MessageType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_count_messages() -> None:
|
||||
|
||||
class DebugInterventionHandler(DefaultInterventionHandler):
|
||||
def __init__(self) -> None:
|
||||
self.num_messages = 0
|
||||
|
||||
async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
|
||||
self.num_messages += 1
|
||||
return message
|
||||
|
||||
handler = DebugInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
assert handler.num_messages == 1
|
||||
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
|
||||
assert loopback_agent.num_calls == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_drop_send() -> None:
|
||||
|
||||
class DropSendInterventionHandler(DefaultInterventionHandler):
|
||||
async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]:
|
||||
return DropMessage
|
||||
|
||||
handler = DropSendInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
await response
|
||||
|
||||
loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore
|
||||
assert loopback_agent.num_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_drop_response() -> None:
|
||||
|
||||
class DropResponseInterventionHandler(DefaultInterventionHandler):
|
||||
async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]:
|
||||
return DropMessage
|
||||
|
||||
handler = DropResponseInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
|
||||
loopback = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=loopback)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
with pytest.raises(MessageDroppedException):
|
||||
await response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_raise_exception_on_send() -> None:
|
||||
|
||||
class InterventionException(Exception):
|
||||
pass
|
||||
|
||||
class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
|
||||
async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]: # type: ignore
|
||||
raise InterventionException
|
||||
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
with pytest.raises(InterventionException):
|
||||
await response
|
||||
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intervention_raise_exception_on_respond() -> None:
|
||||
|
||||
class InterventionException(Exception):
|
||||
pass
|
||||
|
||||
class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
|
||||
async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]: # type: ignore
|
||||
raise InterventionException
|
||||
|
||||
handler = ExceptionInterventionHandler()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=handler)
|
||||
|
||||
long_running = runtime.register_and_get("name", LoopbackAgent)
|
||||
response = runtime.send_message(MessageType(), recipient=long_running)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
with pytest.raises(InterventionException):
|
||||
await response
|
||||
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore
|
||||
assert long_running_agent.num_calls == 1
|
||||
32
python/tests/test_llm_usage.py
Normal file
32
python/tests/test_llm_usage.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import logging
|
||||
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME, LLMCallEvent, LLMUsageTracker
|
||||
|
||||
|
||||
def test_llm_usage() -> None:
|
||||
|
||||
# Set up the logging configuration to use the custom handler
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.setLevel(logging.INFO)
|
||||
llm_usage = LLMUsageTracker()
|
||||
logger.handlers = [llm_usage]
|
||||
|
||||
logger.info(LLMCallEvent(prompt_tokens=10, completion_tokens=20))
|
||||
|
||||
assert llm_usage.prompt_tokens == 10
|
||||
assert llm_usage.completion_tokens == 20
|
||||
|
||||
logger.info(LLMCallEvent(prompt_tokens=1, completion_tokens=1))
|
||||
|
||||
assert llm_usage.prompt_tokens == 11
|
||||
assert llm_usage.completion_tokens == 21
|
||||
|
||||
llm_usage.reset()
|
||||
|
||||
assert llm_usage.prompt_tokens == 0
|
||||
assert llm_usage.completion_tokens == 0
|
||||
|
||||
logger.info(LLMCallEvent(prompt_tokens=1, completion_tokens=1))
|
||||
|
||||
assert llm_usage.prompt_tokens == 1
|
||||
assert llm_usage.completion_tokens == 1
|
||||
75
python/tests/test_runtime.py
Normal file
75
python/tests/test_runtime.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.core import BaseAgent, CancellationToken
|
||||
from test_utils import LoopbackAgent, MessageType
|
||||
|
||||
|
||||
class NoopAgent(BaseAgent): # type: ignore
|
||||
def __init__(self) -> None: # type: ignore
|
||||
super().__init__("A no op agent", [])
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore
|
||||
raise NotImplementedError
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_names_must_be_unique() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
_agent1 = runtime.register_and_get("name1", NoopAgent)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_agent1 = runtime.register_and_get("name1", NoopAgent)
|
||||
|
||||
_agent1 = runtime.register_and_get("name3", NoopAgent)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
runtime.register("name", LoopbackAgent)
|
||||
await runtime.publish_message(MessageType(), namespace="default")
|
||||
|
||||
while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0:
|
||||
await runtime.process_next()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
|
||||
assert long_running_agent.num_calls == 1
|
||||
|
||||
# Agent in other namespace should not have received the message
|
||||
other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore
|
||||
assert other_long_running_agent.num_calls == 0
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_try_instantiate_agent_invalid_namespace() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
runtime.register("name", LoopbackAgent, valid_namespaces=["default"])
|
||||
await runtime.publish_message(MessageType(), namespace="non_default")
|
||||
|
||||
while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0:
|
||||
await runtime.process_next()
|
||||
|
||||
# Agent in default namespace should have received the message
|
||||
long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore
|
||||
assert long_running_agent.num_calls == 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_agent = runtime.get("name", namespace="non_default")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_crosses_namepace() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
runtime.register("name", LoopbackAgent)
|
||||
|
||||
default_ns_agent = runtime.get("name")
|
||||
non_default_ns_agent = runtime.get("name", namespace="non_default")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await runtime.send_message(MessageType(), default_ns_agent, sender=non_default_ns_agent)
|
||||
|
||||
65
python/tests/test_state.py
Normal file
65
python/tests/test_state.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
import pytest
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.core import BaseAgent, CancellationToken
|
||||
|
||||
|
||||
class StatefulAgent(BaseAgent): # type: ignore
|
||||
def __init__(self) -> None: # type: ignore
|
||||
super().__init__("A stateful agent", [])
|
||||
self.state = 0
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
return []
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {"state": self.state}
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
self.state = state["state"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1_id = runtime.register_and_get("name1", StatefulAgent)
|
||||
agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
|
||||
agent1_state = agent1.save_state()
|
||||
|
||||
agent1.state = 2
|
||||
assert agent1.state == 2
|
||||
|
||||
agent1.load_state(agent1_state)
|
||||
assert agent1.state == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_can_save_state() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
agent1_id = runtime.register_and_get("name1", StatefulAgent)
|
||||
agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore
|
||||
assert agent1.state == 0
|
||||
agent1.state = 1
|
||||
assert agent1.state == 1
|
||||
|
||||
runtime_state = runtime.save_state()
|
||||
|
||||
runtime2 = SingleThreadedAgentRuntime()
|
||||
agent2_id = runtime2.register_and_get("name1", StatefulAgent)
|
||||
agent2: StatefulAgent = runtime2._get_agent(agent2_id) # type: ignore
|
||||
|
||||
runtime2.load_state(runtime_state)
|
||||
assert agent2.state == 1
|
||||
|
||||
|
||||
|
||||
288
python/tests/test_tools.py
Normal file
288
python/tests/test_tools.py
Normal file
@@ -0,0 +1,288 @@
|
||||
|
||||
import inspect
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from agnext.components._function_utils import get_typed_signature
|
||||
from agnext.components.tools import BaseTool, FunctionTool
|
||||
from agnext.core import CancellationToken
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
|
||||
class MyArgs(BaseModel):
|
||||
query: str = Field(description="The description.")
|
||||
|
||||
|
||||
class MyResult(BaseModel):
|
||||
result: str = Field(description="The other description.")
|
||||
|
||||
|
||||
class MyTool(BaseTool[MyArgs, MyResult]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
args_type=MyArgs,
|
||||
return_type=MyResult,
|
||||
name="TestTool",
|
||||
description="Description of test tool.",
|
||||
)
|
||||
self.called_count = 0
|
||||
|
||||
async def run(self, args: MyArgs, cancellation_token: CancellationToken) -> MyResult:
|
||||
self.called_count += 1
|
||||
return MyResult(result="value")
|
||||
|
||||
def test_tool_schema_generation() -> None:
|
||||
schema = MyTool().schema
|
||||
|
||||
assert schema["name"] == "TestTool"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "Description of test tool."
|
||||
assert "parameters" in schema
|
||||
assert schema["parameters"]["type"] == "object"
|
||||
assert "properties" in schema["parameters"]
|
||||
assert schema["parameters"]["properties"]["query"]["description"] == "The description."
|
||||
assert schema["parameters"]["properties"]["query"]["type"] == "string"
|
||||
assert "required" in schema["parameters"]
|
||||
assert schema["parameters"]["required"] == ["query"]
|
||||
assert len(schema["parameters"]["properties"]) == 1
|
||||
|
||||
def test_func_tool_schema_generation() -> None:
|
||||
def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
|
||||
return MyResult(result="test")
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
schema = tool.schema
|
||||
|
||||
assert schema["name"] == "my_function"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "Function tool."
|
||||
assert "parameters" in schema
|
||||
assert schema["parameters"]["type"] == "object"
|
||||
assert schema["parameters"]["properties"].keys() == {"arg", "other", "nonrequired"}
|
||||
assert schema["parameters"]["properties"]["arg"]["type"] == "string"
|
||||
assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
|
||||
assert schema["parameters"]["properties"]["other"]["type"] == "integer"
|
||||
assert schema["parameters"]["properties"]["other"]["description"] == "int arg"
|
||||
assert schema["parameters"]["properties"]["nonrequired"]["type"] == "integer"
|
||||
assert schema["parameters"]["properties"]["nonrequired"]["description"] == "nonrequired"
|
||||
assert "required" in schema["parameters"]
|
||||
assert schema["parameters"]["required"] == ["arg", "other"]
|
||||
assert len(schema["parameters"]["properties"]) == 3
|
||||
|
||||
def test_func_tool_schema_generation_only_default_arg() -> None:
|
||||
def my_function(arg: str = "default") -> MyResult:
|
||||
return MyResult(result="test")
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
schema = tool.schema
|
||||
|
||||
assert schema["name"] == "my_function"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "Function tool."
|
||||
assert "parameters" in schema
|
||||
assert len(schema["parameters"]["properties"]) == 1
|
||||
assert schema["parameters"]["properties"]["arg"]["type"] == "string"
|
||||
assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
|
||||
assert "required" not in schema["parameters"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run()-> None:
|
||||
tool = MyTool()
|
||||
result = await tool.run_json({"query": "test"}, CancellationToken())
|
||||
|
||||
assert isinstance(result, MyResult)
|
||||
assert result.result == "value"
|
||||
assert tool.called_count == 1
|
||||
|
||||
result = await tool.run_json({"query": "test"}, CancellationToken())
|
||||
result = await tool.run_json({"query": "test"}, CancellationToken())
|
||||
|
||||
assert tool.called_count == 3
|
||||
|
||||
|
||||
def test_tool_properties()-> None:
|
||||
tool = MyTool()
|
||||
|
||||
assert tool.name == "TestTool"
|
||||
assert tool.description == "Description of test tool."
|
||||
assert tool.args_type() == MyArgs
|
||||
assert tool.return_type() == MyResult
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_get_typed_signature()-> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 0
|
||||
assert sig.return_annotation == str
|
||||
|
||||
def test_get_typed_signature_annotated()-> None:
|
||||
def my_function() -> Annotated[str, "The return type"]:
|
||||
return "result"
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 0
|
||||
assert sig.return_annotation == Annotated[str, "The return type"]
|
||||
|
||||
def test_get_typed_signature_string()-> None:
|
||||
def my_function() -> "str":
|
||||
return "result"
|
||||
|
||||
sig = get_typed_signature(my_function)
|
||||
assert isinstance(sig, inspect.Signature)
|
||||
assert len(sig.parameters) == 0
|
||||
assert sig.return_annotation == str
|
||||
|
||||
|
||||
def test_func_tool()-> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert issubclass(tool.return_type(), str)
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_annotated_arg()-> None:
|
||||
def my_function(my_arg: Annotated[str, "test description"]) -> str:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert issubclass(tool.return_type(), str)
|
||||
assert tool.args_type().model_fields["my_arg"].description == "test description"
|
||||
assert tool.args_type().model_fields["my_arg"].annotation == str
|
||||
assert tool.args_type().model_fields["my_arg"].is_required() is True
|
||||
assert tool.args_type().model_fields["my_arg"].default is PydanticUndefined
|
||||
assert len(tool.args_type().model_fields) == 1
|
||||
assert tool.return_type() == str
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_return_annotated()-> None:
|
||||
def my_function() -> Annotated[str, "test description"]:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert tool.return_type() == str
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_no_args()-> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert len(tool.args_type().model_fields) == 0
|
||||
assert tool.return_type() == str
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_return_none()-> None:
|
||||
def my_function() -> None:
|
||||
return None
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert tool.return_type() is None
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_return_base_model()-> None:
|
||||
def my_function() -> MyResult:
|
||||
return MyResult(result="value")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert tool.return_type() is MyResult
|
||||
assert tool.state_type() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_call_tool()-> None:
|
||||
def my_function() -> str:
|
||||
return "result"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({}, CancellationToken())
|
||||
assert result == "result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_call_tool_base_model()-> None:
|
||||
def my_function() -> MyResult:
|
||||
return MyResult(result="value")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({}, CancellationToken())
|
||||
assert isinstance(result, MyResult)
|
||||
assert result.result == "value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_call_tool_with_arg_base_model()-> None:
|
||||
def my_function(arg: str) -> MyResult:
|
||||
return MyResult(result="value")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": "test"}, CancellationToken())
|
||||
assert isinstance(result, MyResult)
|
||||
assert result.result == "value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_str_res()-> None:
|
||||
def my_function(arg: str) -> str:
|
||||
return "test"
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": "test"}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_base_model_res()-> None:
|
||||
|
||||
|
||||
def my_function(arg: str) -> MyResult:
|
||||
return MyResult(result="test")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": "test"}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == '{"result": "test"}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_base_model_custom_dump_res()-> None:
|
||||
|
||||
class MyResultCustomDump(BaseModel):
|
||||
result: str = Field(description="The other description.")
|
||||
|
||||
@model_serializer
|
||||
def ser_model(self) -> str:
|
||||
return "custom: " + self.result
|
||||
|
||||
|
||||
def my_function(arg: str) -> MyResultCustomDump:
|
||||
return MyResultCustomDump(result="test")
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": "test"}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == "custom: test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_func_int_res()-> None:
|
||||
def my_function(arg: int) -> int:
|
||||
return arg
|
||||
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": 5}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == "5"
|
||||
39
python/tests/test_types.py
Normal file
39
python/tests/test_types.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from types import NoneType
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from agnext.components._type_routed_agent import AnyType, get_types, message_handler
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
|
||||
def test_get_types() -> None:
|
||||
assert get_types(Union[int, str]) == (int, str)
|
||||
assert get_types(int | str) == (int, str)
|
||||
assert get_types(int) == (int,)
|
||||
assert get_types(str) == (str,)
|
||||
assert get_types("test") is None
|
||||
assert get_types(Optional[int]) == (int, NoneType)
|
||||
assert get_types(NoneType) == (NoneType,)
|
||||
assert get_types(None) == (NoneType,)
|
||||
|
||||
|
||||
def test_handler() -> None:
|
||||
|
||||
class HandlerClass:
|
||||
@message_handler()
|
||||
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
|
||||
return None
|
||||
|
||||
@message_handler()
|
||||
async def handler2(self, message: str | bool, cancellation_token: CancellationToken) -> None:
|
||||
return None
|
||||
|
||||
assert HandlerClass.handler.target_types == [int]
|
||||
assert HandlerClass.handler.produces_types == [AnyType]
|
||||
|
||||
assert HandlerClass.handler2.target_types == [str, bool]
|
||||
assert HandlerClass.handler2.produces_types == [NoneType]
|
||||
|
||||
class HandlerClass:
|
||||
@message_handler()
|
||||
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
|
||||
return None
|
||||
20
python/tests/test_utils/__init__.py
Normal file
20
python/tests/test_utils/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageType:
|
||||
...
|
||||
|
||||
class LoopbackAgent(TypeRoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("A loop back agent.")
|
||||
self.num_calls = 0
|
||||
|
||||
|
||||
@message_handler
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.num_calls += 1
|
||||
return message
|
||||
Reference in New Issue
Block a user