From 2dc7af87efb090da06c7e6626b656331b8a9ce78 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 29 May 2024 17:12:02 -0400 Subject: [PATCH] Add function and code execution (#34) * WIP code execution * add tests, reorganize * fix polars test * credit statements * attributions --- .github/workflows/checks.yml | 5 +- pyproject.toml | 27 +- .../code_executor/__init__.py | 14 + .../agent_components/code_executor/_base.py | 66 ++++ .../code_executor/_impl/__init__.py | 0 .../_impl/command_line_code_result.py | 11 + .../_impl/local_commandline_code_executor.py | 272 +++++++++++++++++ .../_impl/markdown_code_extractor.py | 37 +++ .../code_executor/_impl/utils.py | 88 ++++++ src/agnext/agent_components/func_with_reqs.py | 200 ++++++++++++ .../function_executor/__init__.py | 4 + .../function_executor/_base.py | 36 +++ .../function_executor/_impl/__init__.py | 0 .../_impl/in_process_function_executor.py | 40 +++ src/agnext/agent_components/function_utils.py | 284 ++++++++++++++++++ .../agent_components/pydantic_compat.py | 63 ++++ test.sh | 2 +- .../test_commandline_code_executor.py | 105 +++++++ .../execution/test_markdown_code_extractor.py | 118 ++++++++ .../execution/test_user_defined_functions.py | 199 ++++++++++++ 20 files changed, 1564 insertions(+), 7 deletions(-) create mode 100644 src/agnext/agent_components/code_executor/__init__.py create mode 100644 src/agnext/agent_components/code_executor/_base.py create mode 100644 src/agnext/agent_components/code_executor/_impl/__init__.py create mode 100644 src/agnext/agent_components/code_executor/_impl/command_line_code_result.py create mode 100644 src/agnext/agent_components/code_executor/_impl/local_commandline_code_executor.py create mode 100644 src/agnext/agent_components/code_executor/_impl/markdown_code_extractor.py create mode 100644 src/agnext/agent_components/code_executor/_impl/utils.py create mode 100644 src/agnext/agent_components/func_with_reqs.py create mode 100644 src/agnext/agent_components/function_executor/__init__.py create mode 100644 src/agnext/agent_components/function_executor/_base.py create mode 100644 src/agnext/agent_components/function_executor/_impl/__init__.py create mode 100644 src/agnext/agent_components/function_executor/_impl/in_process_function_executor.py create mode 100644 src/agnext/agent_components/function_utils.py create mode 100644 src/agnext/agent_components/pydantic_compat.py create mode 100644 tests/execution/test_commandline_code_executor.py create mode 100644 tests/execution/test_markdown_code_extractor.py create mode 100644 tests/execution/test_user_defined_functions.py diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index df6236d68..5f781d060 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -55,14 +55,15 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["pypy3.10", "3.10", "3.11", "3.12"] + # "pypy3.10" disabled until better example than polars used in tests + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - run: pip install ".[dev]" - - run: pytest + - run: pytest -n auto docs: runs-on: ubuntu-latest diff --git a/pyproject.toml b/pyproject.toml index 0ae9b954c..e79ae9535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,11 +13,26 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = ["openai>=1.3", "pillow", "aiohttp", "typing-extensions"] +dependencies = [ + "openai>=1.3", + "pillow", + "aiohttp", + "typing-extensions", + "pydantic>=1.10,<3", +] [project.optional-dependencies] -dev = ["ruff==0.4.6", "pyright", "mypy", "pytest", "pytest-asyncio", "types-Pillow"] -docs = [ "sphinx", "furo", "sphinxcontrib-apidoc"] +dev = [ + "ruff==0.4.6", + "pyright", + "mypy", + "pytest", + "pytest-asyncio", + "pytest-xdist", + "types-Pillow", + "polars", +] +docs = ["sphinx", "furo", "sphinxcontrib-apidoc"] [tool.setuptools.package-data] agnext = ["py.typed"] @@ -30,9 +45,13 @@ target-version = "py310" include = ["src/**", "examples/**"] [tool.ruff.lint] -select = ["E", "F", "W", "B", "Q", "I"] +select = ["E", "F", "W", "B", "Q", "I", "ASYNC"] ignore = ["F401", "E501"] +[tool.ruff.lint.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"unittest".msg = "Use `pytest` instead." + [tool.mypy] files = ["src", "examples", "tests"] diff --git a/src/agnext/agent_components/code_executor/__init__.py b/src/agnext/agent_components/code_executor/__init__.py new file mode 100644 index 000000000..5df725747 --- /dev/null +++ b/src/agnext/agent_components/code_executor/__init__.py @@ -0,0 +1,14 @@ +from ._base import CodeBlock, CodeExecutor, CodeExtractor, CodeResult +from ._impl.command_line_code_result import CommandLineCodeResult +from ._impl.local_commandline_code_executor import LocalCommandLineCodeExecutor +from ._impl.markdown_code_extractor import MarkdownCodeExtractor + +__all__ = [ + "LocalCommandLineCodeExecutor", + "MarkdownCodeExtractor", + "CommandLineCodeResult", + "CodeBlock", + "CodeResult", + "CodeExecutor", + "CodeExtractor", +] diff --git a/src/agnext/agent_components/code_executor/_base.py b/src/agnext/agent_components/code_executor/_base.py new file mode 100644 index 000000000..46fabb3f3 --- /dev/null +++ b/src/agnext/agent_components/code_executor/_base.py @@ -0,0 +1,66 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/base.py +# Credit to original authors + +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Protocol, runtime_checkable + + +@dataclass +class CodeBlock: + """A code block extracted fromm an agent message.""" + + code: str + language: str + + +@dataclass +class CodeResult: + """Result of a code execution.""" + + exit_code: int + output: str + + +class CodeExtractor(Protocol): + """Extracts code blocks from a message.""" + + # TODO support text or multimodal message directly + def extract_code_blocks(self, message: str) -> List[CodeBlock]: + """Extract code blocks from a message. + + Args: + message (str): The message to extract code blocks from. + + Returns: + List[CodeBlock]: The extracted code blocks. + """ + ... + + +@runtime_checkable +class CodeExecutor(Protocol): + """Executes code blocks and returns the result.""" + + def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult: + """Execute code blocks and return the result. + + This method should be implemented by the code executor. + + Args: + code_blocks (List[CodeBlock]): The code blocks to execute. + + Returns: + CodeResult: The result of the code execution. + """ + ... + + def restart(self) -> None: + """Restart the code executor. + + This method should be implemented by the code executor. + + This method is called when the agent is reset. + """ + ... diff --git a/src/agnext/agent_components/code_executor/_impl/__init__.py b/src/agnext/agent_components/code_executor/_impl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/agnext/agent_components/code_executor/_impl/command_line_code_result.py b/src/agnext/agent_components/code_executor/_impl/command_line_code_result.py new file mode 100644 index 000000000..cafc73f14 --- /dev/null +++ b/src/agnext/agent_components/code_executor/_impl/command_line_code_result.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Optional + +from .._base import CodeResult + + +@dataclass +class CommandLineCodeResult(CodeResult): + """A code result class for command line code executor.""" + + code_file: Optional[str] diff --git a/src/agnext/agent_components/code_executor/_impl/local_commandline_code_executor.py b/src/agnext/agent_components/code_executor/_impl/local_commandline_code_executor.py new file mode 100644 index 000000000..927395732 --- /dev/null +++ b/src/agnext/agent_components/code_executor/_impl/local_commandline_code_executor.py @@ -0,0 +1,272 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/local_commandline_code_executor.py +# Credit to original authors + +import logging +import re +import subprocess +import sys +import warnings +from hashlib import md5 +from pathlib import Path +from string import Template +from typing import Any, Callable, ClassVar, List, Sequence, Union + +from typing_extensions import ParamSpec + +from ...func_with_reqs import ( + FunctionWithRequirements, + FunctionWithRequirementsStr, + build_python_functions_file, + to_stub, +) +from .._base import CodeBlock, CodeExecutor, CodeExtractor +from .command_line_code_result import CommandLineCodeResult +from .markdown_code_extractor import MarkdownCodeExtractor +from .utils import PYTHON_VARIANTS, get_file_name_from_content, lang_to_cmd, silence_pip # type: ignore + +__all__ = ("LocalCommandLineCodeExecutor",) + +A = ParamSpec("A") + + +class LocalCommandLineCodeExecutor(CodeExecutor): + SUPPORTED_LANGUAGES: ClassVar[List[str]] = ["bash", "shell", "sh", "pwsh", "powershell", "ps1", "python"] + FUNCTION_PROMPT_TEMPLATE: ClassVar[ + str + ] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names. + +For example, if there was a function called `foo` you could import it by writing `from $module_name import foo` + +$functions""" + + def __init__( + self, + timeout: int = 60, + work_dir: Union[Path, str] = Path("."), + functions: Sequence[ + Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr] + ] = [], + functions_module: str = "functions", + ): + """(Experimental) A code executor class that executes code through a local command line + environment. + + **This will execute LLM generated code on the local machine.** + + Each code block is saved as a file and executed in a separate process in + the working directory, and a unique file is generated and saved in the + working directory for each code block. + The code blocks are executed in the order they are received. + Command line code is sanitized using regular expression match against a list of dangerous commands in order to prevent self-destructive + commands from being executed which may potentially affect the users environment. + Currently the only supported languages is Python and shell scripts. + For Python code, use the language "python" for the code block. + For shell scripts, use the language "bash", "shell", or "sh" for the code + block. + + Args: + timeout (int): The timeout for code execution. Default is 60. + work_dir (str): The working directory for the code execution. If None, + a default working directory will be used. The default working + directory is the current directory ".". + functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list. + """ + + if timeout < 1: + raise ValueError("Timeout must be greater than or equal to 1.") + + if isinstance(work_dir, str): + work_dir = Path(work_dir) + + if not functions_module.isidentifier(): + raise ValueError("Module name must be a valid Python identifier") + + self._functions_module = functions_module + + work_dir.mkdir(exist_ok=True) + + self._timeout = timeout + self._work_dir: Path = work_dir + + self._functions = functions + # Setup could take some time so we intentionally wait for the first code block to do it. + if len(functions) > 0: + self._setup_functions_complete = False + else: + self._setup_functions_complete = True + + def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str: + """(Experimental) Format the functions for a prompt. + + The template includes two variables: + - `$module_name`: The module name. + - `$functions`: The functions formatted as stubs with two newlines between each function. + + Args: + prompt_template (str): The prompt template. Default is the class default. + + Returns: + str: The formatted prompt. + """ + + template = Template(prompt_template) + return template.substitute( + module_name=self._functions_module, + functions="\n\n".join([to_stub(func) for func in self._functions]), + ) + + @property + def functions_module(self) -> str: + """(Experimental) The module name for the functions.""" + return self._functions_module + + @property + def functions(self) -> List[str]: + raise NotImplementedError + + @property + def timeout(self) -> int: + """(Experimental) The timeout for code execution.""" + return self._timeout + + @property + def work_dir(self) -> Path: + """(Experimental) The working directory for the code execution.""" + return self._work_dir + + @property + def code_extractor(self) -> CodeExtractor: + """(Experimental) Export a code extractor that can be used by an agent.""" + return MarkdownCodeExtractor() + + @staticmethod + def sanitize_command(lang: str, code: str) -> None: + """ + Sanitize the code block to prevent dangerous commands. + This approach acknowledges that while Docker or similar + containerization/sandboxing technologies provide a robust layer of security, + not all users may have Docker installed or may choose not to use it. + Therefore, having a baseline level of protection helps mitigate risks for users who, + either out of choice or necessity, run code outside of a sandboxed environment. + """ + dangerous_patterns = [ + (r"\brm\s+-rf\b", "Use of 'rm -rf' command is not allowed."), + (r"\bmv\b.*?\s+/dev/null", "Moving files to /dev/null is not allowed."), + (r"\bdd\b", "Use of 'dd' command is not allowed."), + (r">\s*/dev/sd[a-z][1-9]?", "Overwriting disk blocks directly is not allowed."), + (r":\(\)\{\s*:\|\:&\s*\};:", "Fork bombs are not allowed."), + ] + if lang in ["bash", "shell", "sh"]: + for pattern, message in dangerous_patterns: + if re.search(pattern, code): + raise ValueError(f"Potentially dangerous command detected: {message}") + + def _setup_functions(self) -> None: + func_file_content = build_python_functions_file(self._functions) + func_file = self._work_dir / f"{self._functions_module}.py" + func_file.write_text(func_file_content) + + # Collect requirements + lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)] + flattened_packages = [item for sublist in lists_of_packages for item in sublist] + required_packages = list(set(flattened_packages)) + if len(required_packages) > 0: + logging.info("Ensuring packages are installed in executor.") + + cmd = [sys.executable, "-m", "pip", "install"] + cmd.extend(required_packages) + + try: + result = subprocess.run( + cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout) + ) + except subprocess.TimeoutExpired as e: + raise ValueError("Pip install timed out") from e + + if result.returncode != 0: + raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}") + + # Attempt to load the function file to check for syntax errors, imports etc. + exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")]) + + if exec_result.exit_code != 0: + raise ValueError(f"Functions failed to load: {exec_result.output}") + + self._setup_functions_complete = True + + def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult: + """(Experimental) Execute the code blocks and return the result. + + Args: + code_blocks (List[CodeBlock]): The code blocks to execute. + + Returns: + CommandLineCodeResult: The result of the code execution.""" + + if not self._setup_functions_complete: + self._setup_functions() + + return self._execute_code_dont_check_setup(code_blocks) + + def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult: + logs_all: str = "" + file_names: List[Path] = [] + exitcode = 0 + for code_block in code_blocks: + lang, code = code_block.language, code_block.code + lang = lang.lower() + + LocalCommandLineCodeExecutor.sanitize_command(lang, code) + code = silence_pip(code, lang) + + if lang in PYTHON_VARIANTS: + lang = "python" + + if lang not in self.SUPPORTED_LANGUAGES: + # In case the language is not supported, we return an error message. + exitcode = 1 + logs_all += "\n" + f"unknown language {lang}" + break + + try: + # Check if there is a filename comment + filename = get_file_name_from_content(code, self._work_dir) + except ValueError: + return CommandLineCodeResult(exit_code=1, output="Filename is not in the workspace", code_file=None) + + if filename is None: + # create a file with an automatically generated name + code_hash = md5(code.encode()).hexdigest() + filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}" + + written_file = (self._work_dir / filename).resolve() + with written_file.open("w", encoding="utf-8") as f: + f.write(code) + file_names.append(written_file) + + program = sys.executable if lang.startswith("python") else lang_to_cmd(lang) + cmd = [program, str(written_file.absolute())] + + try: + result = subprocess.run( + cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout) + ) + except subprocess.TimeoutExpired: + logs_all += "\n Timeout" + # Same exit code as the timeout command on linux. + exitcode = 124 + break + + logs_all += result.stderr + logs_all += result.stdout + exitcode = result.returncode + + if exitcode != 0: + break + + code_file = str(file_names[0]) if len(file_names) > 0 else None + return CommandLineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_file) + + def restart(self) -> None: + """(Experimental) Restart the code executor.""" + warnings.warn("Restarting local command line code executor is not supported. No action is taken.", stacklevel=2) diff --git a/src/agnext/agent_components/code_executor/_impl/markdown_code_extractor.py b/src/agnext/agent_components/code_executor/_impl/markdown_code_extractor.py new file mode 100644 index 000000000..18239ecd6 --- /dev/null +++ b/src/agnext/agent_components/code_executor/_impl/markdown_code_extractor.py @@ -0,0 +1,37 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/markdown_code_extractor.py +# Credit to original authors + +import re +from typing import List + +from .._base import CodeBlock, CodeExtractor +from .utils import CODE_BLOCK_PATTERN, infer_lang + +__all__ = ("MarkdownCodeExtractor",) + + +class MarkdownCodeExtractor(CodeExtractor): + """(Experimental) A class that extracts code blocks from a message using Markdown syntax.""" + + def extract_code_blocks(self, message: str) -> List[CodeBlock]: + """(Experimental) Extract code blocks from a message. If no code blocks are found, + return an empty list. + + Args: + message (str): The message to extract code blocks from. + + Returns: + List[CodeBlock]: The extracted code blocks or an empty list. + """ + + match = re.findall(CODE_BLOCK_PATTERN, message, flags=re.DOTALL) + if not match: + return [] + code_blocks: List[CodeBlock] = [] + for lang, code in match: + if lang == "": + lang = infer_lang(code) + if lang == "unknown": + lang = "" + code_blocks.append(CodeBlock(code=code, language=lang)) + return code_blocks diff --git a/src/agnext/agent_components/code_executor/_impl/utils.py b/src/agnext/agent_components/code_executor/_impl/utils.py new file mode 100644 index 000000000..3fc39054a --- /dev/null +++ b/src/agnext/agent_components/code_executor/_impl/utils.py @@ -0,0 +1,88 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/utils.py +# Credit to original authors + +# Will return the filename relative to the workspace path +import re +from pathlib import Path +from typing import Optional + + +# Raises ValueError if the file is not in the workspace +def get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]: + first_line = code.split("\n")[0] + # TODO - support other languages + if first_line.startswith("# filename:"): + filename = first_line.split(":")[1].strip() + + # Handle relative paths in the filename + path = Path(filename) + if not path.is_absolute(): + path = workspace_path / path + path = path.resolve() + # Throws an error if the file is not in the workspace + relative = path.relative_to(workspace_path.resolve()) + return str(relative) + + return None + + +def silence_pip(code: str, lang: str) -> str: + """Apply -qqq flag to pip install commands.""" + if lang == "python": + regex = r"^! ?pip install" + elif lang in ["bash", "shell", "sh", "pwsh", "powershell", "ps1"]: + regex = r"^pip install" + else: + return code + + # Find lines that start with pip install and make sure "-qqq" flag is added. + lines = code.split("\n") + for i, line in enumerate(lines): + # use regex to find lines that start with pip install. + match = re.search(regex, line) + if match is not None: + if "-qqq" not in line: + lines[i] = line.replace(match.group(0), match.group(0) + " -qqq") + return "\n".join(lines) + + +PYTHON_VARIANTS = ["python", "Python", "py"] + + +def lang_to_cmd(lang: str) -> str: + if lang in PYTHON_VARIANTS: + return "python" + if lang.startswith("python") or lang in ["bash", "sh"]: + return lang + if lang in ["shell"]: + return "sh" + else: + raise ValueError(f"Unsupported language: {lang}") + + +# Regular expression for finding a code block +# ```[ \t]*(\w+)?[ \t]*\r?\n(.*?)[ \t]*\r?\n``` Matches multi-line code blocks. +# The [ \t]* matches the potential spaces before language name. +# The (\w+)? matches the language, where the ? indicates it is optional. +# The [ \t]* matches the potential spaces (not newlines) after language name. +# The \r?\n makes sure there is a linebreak after ```. +# The (.*?) matches the code itself (non-greedy). +# The \r?\n makes sure there is a linebreak before ```. +# The [ \t]* matches the potential spaces before closing ``` (the spec allows indentation). +CODE_BLOCK_PATTERN = r"```[ \t]*(\w+)?[ \t]*\r?\n(.*?)\r?\n[ \t]*```" + + +def infer_lang(code: str) -> str: + """infer the language for the code. + TODO: make it robust. + """ + if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "): + return "sh" + + # check if code is a valid python code + try: + compile(code, "test", "exec") + return "python" + except SyntaxError: + # not a valid python code + return "unknown" diff --git a/src/agnext/agent_components/func_with_reqs.py b/src/agnext/agent_components/func_with_reqs.py new file mode 100644 index 000000000..1ef01cedd --- /dev/null +++ b/src/agnext/agent_components/func_with_reqs.py @@ -0,0 +1,200 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/coding/func_with_reqs.py +# Credit to original authors + +from __future__ import annotations + +import functools +import inspect +from dataclasses import dataclass, field +from importlib.abc import SourceLoader +from importlib.util import module_from_spec, spec_from_loader +from textwrap import dedent, indent +from typing import Any, Callable, Generic, List, Sequence, Set, TypeVar, Union + +from typing_extensions import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + + +def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str: + if isinstance(func, FunctionWithRequirementsStr): + return func.func + + code = inspect.getsource(func) + # Strip the decorator + if code.startswith("@"): + code = code[code.index("\n") + 1 :] + return code + + +@dataclass +class Alias: + name: str + alias: str + + +@dataclass +class ImportFromModule: + module: str + imports: List[Union[str, Alias]] + + +Import = Union[str, ImportFromModule, Alias] + + +def _import_to_str(im: Import) -> str: + if isinstance(im, str): + return f"import {im}" + elif isinstance(im, Alias): + return f"import {im.name} as {im.alias}" + else: + + def to_str(i: Union[str, Alias]) -> str: + if isinstance(i, str): + return i + else: + return f"{i.name} as {i.alias}" + + imports = ", ".join(map(to_str, im.imports)) + return f"from {im.module} import {imports}" + + +class _StringLoader(SourceLoader): + def __init__(self, data: str): + self.data = data + + def get_source(self, fullname: str) -> str: + return self.data + + def get_data(self, path: str) -> bytes: + return self.data.encode("utf-8") + + def get_filename(self, fullname: str) -> str: + return "/" + fullname + ".py" + + +@dataclass +class FunctionWithRequirementsStr: + func: str + compiled_func: Callable[..., Any] + _func_name: str + python_packages: Sequence[str] = field(default_factory=list) + global_imports: Sequence[Import] = field(default_factory=list) + + def __init__(self, func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = []): + self.func = func + self.python_packages = python_packages + self.global_imports = global_imports + + module_name = "func_module" + loader = _StringLoader(func) + spec = spec_from_loader(module_name, loader) + if spec is None: + raise ValueError("Could not create spec") + module = module_from_spec(spec) + if spec.loader is None: + raise ValueError("Could not create loader") + + try: + spec.loader.exec_module(module) + except Exception as e: + raise ValueError(f"Could not compile function: {e}") from e + + functions = inspect.getmembers(module, inspect.isfunction) + if len(functions) != 1: + raise ValueError("The string must contain exactly one function") + + self._func_name, self.compiled_func = functions[0] + + def __call__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("String based function with requirement objects are not directly callable") + + +@dataclass +class FunctionWithRequirements(Generic[T, P]): + func: Callable[P, T] + python_packages: Sequence[str] = field(default_factory=list) + global_imports: Sequence[Import] = field(default_factory=list) + + @classmethod + def from_callable( + cls, func: Callable[P, T], python_packages: Sequence[str] = [], global_imports: Sequence[Import] = [] + ) -> FunctionWithRequirements[T, P]: + return cls(python_packages=python_packages, global_imports=global_imports, func=func) + + @staticmethod + def from_str( + func: str, python_packages: Sequence[str] = [], global_imports: Sequence[Import] = [] + ) -> FunctionWithRequirementsStr: + return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports) + + # Type this based on F + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return self.func(*args, **kwargs) + + +def with_requirements( + python_packages: Sequence[str] = [], global_imports: Sequence[Import] = [] +) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: + """Decorate a function with package and import requirements + + Args: + python_packages (List[str], optional): Packages required to function. Can include version info.. Defaults to []. + global_imports (List[Import], optional): Required imports. Defaults to []. + + Returns: + Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: The decorated function + """ + + def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]: + func_with_reqs = FunctionWithRequirements( + python_packages=python_packages, global_imports=global_imports, func=func + ) + + functools.update_wrapper(func_with_reqs, func) + return func_with_reqs + + return wrapper + + +def build_python_functions_file( + funcs: Sequence[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]], +) -> str: + # First collect all global imports + global_imports: Set[Import] = set() + for func in funcs: + if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)): + global_imports.update(func.global_imports) + + content = "\n".join(map(_import_to_str, global_imports)) + "\n\n" + + for func in funcs: + content += _to_code(func) + "\n\n" + + return content + + +def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str: + """Generate a stub for a function as a string + + Args: + func (Callable[..., Any]): The function to generate a stub for + + Returns: + str: The stub for the function + """ + if isinstance(func, FunctionWithRequirementsStr): + return to_stub(func.compiled_func) + + content = f"def {func.__name__}{inspect.signature(func)}:\n" + docstring = func.__doc__ + + if docstring: + docstring = dedent(docstring) + docstring = '"""' + docstring + '"""' + docstring = indent(docstring, " ") + content += docstring + "\n" + + content += " ..." + return content diff --git a/src/agnext/agent_components/function_executor/__init__.py b/src/agnext/agent_components/function_executor/__init__.py new file mode 100644 index 000000000..7c3a41bd8 --- /dev/null +++ b/src/agnext/agent_components/function_executor/__init__.py @@ -0,0 +1,4 @@ +from ._base import FunctionExecutor, FunctionInfo, into_function_definition +from ._impl.in_process_function_executor import InProcessFunctionExecutor + +__all__ = ["FunctionExecutor", "FunctionInfo", "into_function_definition", "InProcessFunctionExecutor"] diff --git a/src/agnext/agent_components/function_executor/_base.py b/src/agnext/agent_components/function_executor/_base.py new file mode 100644 index 000000000..afad5d6b7 --- /dev/null +++ b/src/agnext/agent_components/function_executor/_base.py @@ -0,0 +1,36 @@ +from collections.abc import Sequence +from typing import Any, Callable, Dict, Protocol, TypedDict, Union, runtime_checkable + +from typing_extensions import NotRequired, Required + +from ..function_utils import get_function_schema +from ..types import FunctionDefinition + + +@runtime_checkable +class FunctionExecutor(Protocol): + async def execute_function(self, function_name: str, arguments: Dict[str, Any]) -> str: ... + + @property + def functions(self) -> Sequence[str]: ... + + +class FunctionInfo(TypedDict): + func: Required[Callable[..., Any]] + name: NotRequired[str] + description: NotRequired[str] + + +def into_function_definition( + func_info: Union[FunctionInfo, FunctionDefinition, Callable[..., Any]], +) -> FunctionDefinition: + if isinstance(func_info, FunctionDefinition): + return func_info + elif isinstance(func_info, dict): + name = func_info.get("name", func_info["func"].__name__) + description = func_info.get("description", "") + parameters = get_function_schema(func_info["func"], description="", name="")["function"]["parameters"] + return FunctionDefinition(name=name, description=description, parameters=parameters) + else: + parameters = get_function_schema(func_info, description="", name="")["function"]["parameters"] + return FunctionDefinition(name=func_info.__name__, description="", parameters=parameters) diff --git a/src/agnext/agent_components/function_executor/_impl/__init__.py b/src/agnext/agent_components/function_executor/_impl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/agnext/agent_components/function_executor/_impl/in_process_function_executor.py b/src/agnext/agent_components/function_executor/_impl/in_process_function_executor.py new file mode 100644 index 000000000..7fddbf601 --- /dev/null +++ b/src/agnext/agent_components/function_executor/_impl/in_process_function_executor.py @@ -0,0 +1,40 @@ +import asyncio +import functools +from collections.abc import Sequence +from typing import Any, Callable, Union + +from .._base import FunctionExecutor, FunctionInfo + + +class InProcessFunctionExecutor(FunctionExecutor): + def __init__( + self, + functions: Sequence[Union[Callable[..., Any], FunctionInfo]] = [], + ) -> None: + def _name(func: Union[Callable[..., Any], FunctionInfo]) -> str: + if isinstance(func, dict): + return func.get("name", func["func"].__name__) + else: + return func.__name__ + + def _func(func: Union[Callable[..., Any], FunctionInfo]) -> Any: + if isinstance(func, dict): + return func.get("func") + else: + return func + + self._functions = dict([(_name(x), _func(x)) for x in functions]) + + async def execute_function(self, function_name: str, arguments: dict[str, Any]) -> str: + if function_name in self._functions: + function = self._functions[function_name] + if asyncio.iscoroutinefunction(function): + return str(function(**arguments)) + else: + return await asyncio.get_event_loop().run_in_executor(None, functools.partial(function, **arguments)) + + raise ValueError(f"Function {function_name} not found") + + @property + def functions(self) -> Sequence[str]: + return list(self._functions.keys()) diff --git a/src/agnext/agent_components/function_utils.py b/src/agnext/agent_components/function_utils.py new file mode 100644 index 000000000..a6e1c7dc0 --- /dev/null +++ b/src/agnext/agent_components/function_utils.py @@ -0,0 +1,284 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/function_utils.py +# Credit to original authors + +import inspect +from logging import getLogger +from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, TypeVar, Union + +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Literal + +from .pydantic_compat import evaluate_forwardref, model_dump, type2schema + +logger = getLogger(__name__) + +T = TypeVar("T") + + +def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: + """Get the type annotation of a parameter. + + Args: + annotation: The annotation of the parameter + globalns: The global namespace of the function + + Returns: + The type annotation of the parameter + """ + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + return annotation + + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """Get the signature of a function with type annotations. + + Args: + call: The function to get the signature for + + Returns: + The signature of the function with type annotations + """ + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature + + +def get_typed_return_annotation(call: Callable[..., Any]) -> Any: + """Get the return annotation of a function. + + Args: + call: The function to get the return annotation for + + Returns: + The return annotation of the function + """ + signature = inspect.signature(call) + annotation = signature.return_annotation + + if annotation is inspect.Signature.empty: + return None + + globalns = getattr(call, "__globals__", {}) + return get_typed_annotation(annotation, globalns) + + +def get_param_annotations(typed_signature: inspect.Signature) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]: + """Get the type annotations of the parameters of a function + + Args: + typed_signature: The signature of the function with type annotations + + Returns: + A dictionary of the type annotations of the parameters of the function + """ + return { + k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty + } + + +class Parameters(BaseModel): + """Parameters of a function as defined by the OpenAI API""" + + type: Literal["object"] = "object" + properties: Dict[str, Dict[str, Any]] + required: List[str] + + +class Function(BaseModel): + """A function as defined by the OpenAI API""" + + description: Annotated[str, Field(description="Description of the function")] + name: Annotated[str, Field(description="Name of the function")] + parameters: Annotated[Parameters, Field(description="Parameters of the function")] + + +class ToolFunction(BaseModel): + """A function under tool as defined by the OpenAI API.""" + + type: Literal["function"] = "function" + function: Annotated[Function, Field(description="Function under tool")] + + +def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]: + """Get a JSON schema for a parameter as defined by the OpenAI API + + Args: + k: The name of the parameter + v: The type of the parameter + default_values: The default values of the parameters of the function + + Returns: + A Pydanitc model for the parameter + """ + + def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str: + # handles Annotated + if hasattr(v, "__metadata__"): + retval = v.__metadata__[0] + if isinstance(retval, str): + return retval + else: + raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.") + else: + return k + + schema = type2schema(v) + if k in default_values: + dv = default_values[k] + schema["default"] = dv + + schema["description"] = type2description(k, v) + + return schema + + +def get_required_params(typed_signature: inspect.Signature) -> List[str]: + """Get the required parameters of a function + + Args: + signature: The signature of the function as returned by inspect.signature + + Returns: + A list of the required parameters of the function + """ + return [k for k, v in typed_signature.parameters.items() if v.default == inspect.Signature.empty] + + +def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]: + """Get default values of parameters of a function + + Args: + signature: The signature of the function as returned by inspect.signature + + Returns: + A dictionary of the default values of the parameters of the function + """ + return {k: v.default for k, v in typed_signature.parameters.items() if v.default != inspect.Signature.empty} + + +def get_parameters( + required: List[str], + param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]], + default_values: Dict[str, Any], +) -> Parameters: + """Get the parameters of a function as defined by the OpenAI API + + Args: + required: The required parameters of the function + hints: The type hints of the function as returned by typing.get_type_hints + + Returns: + A Pydantic model for the parameters of the function + """ + return Parameters( + properties={ + k: get_parameter_json_schema(k, v, default_values) + for k, v in param_annotations.items() + if v is not inspect.Signature.empty + }, + required=required, + ) + + +def get_missing_annotations(typed_signature: inspect.Signature, required: List[str]) -> Tuple[Set[str], Set[str]]: + """Get the missing annotations of a function + + Ignores the parameters with default values as they are not required to be annotated, but logs a warning. + Args: + typed_signature: The signature of the function with type annotations + required: The required parameters of the function + + Returns: + A set of the missing annotations of the function + """ + all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty} + missing = all_missing.intersection(set(required)) + unannotated_with_default = all_missing.difference(missing) + return missing, unannotated_with_default + + +def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> Dict[str, Any]: + """Get a JSON schema for a function as defined by the OpenAI API + + Args: + f: The function to get the JSON schema for + name: The name of the function + description: The description of the function + + Returns: + A JSON schema for the function + + Raises: + TypeError: If the function is not annotated + + Examples: + + .. code-block:: python + + def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None: + pass + + get_function_schema(f, description="function f") + + # {'type': 'function', + # 'function': {'description': 'function f', + # 'name': 'f', + # 'parameters': {'type': 'object', + # 'properties': {'a': {'type': 'str', 'description': 'Parameter a'}, + # 'b': {'type': 'int', 'description': 'b'}, + # 'c': {'type': 'float', 'description': 'Parameter c'}}, + # 'required': ['a']}}} + + """ + typed_signature = get_typed_signature(f) + required = get_required_params(typed_signature) + default_values = get_default_values(typed_signature) + param_annotations = get_param_annotations(typed_signature) + return_annotation = get_typed_return_annotation(f) + missing, unannotated_with_default = get_missing_annotations(typed_signature, required) + + if return_annotation is None: + logger.warning( + f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is " + + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'." + ) + + if unannotated_with_default != set(): + unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)] + logger.warning( + f"The following parameters of the function '{f.__name__}' with default values are not annotated: " + + f"{', '.join(unannotated_with_default_s)}." + ) + + if missing != set(): + missing_s = [f"'{k}'" for k in sorted(missing)] + raise TypeError( + f"All parameters of the function '{f.__name__}' without default values must be annotated. " + + f"The annotations are missing for the following parameters: {', '.join(missing_s)}" + ) + + fname = name if name else f.__name__ + + parameters = get_parameters(required, param_annotations, default_values=default_values) + + function = ToolFunction( + function=Function( + description=description, + name=fname, + parameters=parameters, + ) + ) + + return model_dump(function) diff --git a/src/agnext/agent_components/pydantic_compat.py b/src/agnext/agent_components/pydantic_compat.py new file mode 100644 index 000000000..d554d83d2 --- /dev/null +++ b/src/agnext/agent_components/pydantic_compat.py @@ -0,0 +1,63 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/autogen/_pydantic.py +# Credit to original authors + + +from typing import Any, Dict, Tuple, Type, Union, get_args + +from pydantic import BaseModel +from pydantic.version import VERSION as PYDANTIC_VERSION +from typing_extensions import get_origin + +__all__ = ("model_dump", "type2schema", "evaluate_forwardref") + +PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.") + + +def evaluate_forwardref( + value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None +) -> Any: + if PYDANTIC_V1: + from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal + + return evaluate_forwardref_internal(value, globalns, localns) + else: + from pydantic._internal._typing_extra import eval_type_lenient + + return eval_type_lenient(value, globalns, localns) + + +def type2schema(t: Type[Any] | None) -> Dict[str, Any]: + if PYDANTIC_V1: + from pydantic import schema_of # type: ignore + + if t is None: + return {"type": "null"} + elif get_origin(t) is Union: + return {"anyOf": [type2schema(tt) for tt in get_args(t)]} + elif get_origin(t) in [Tuple, tuple]: + prefixItems = [type2schema(tt) for tt in get_args(t)] + return { + "maxItems": len(prefixItems), + "minItems": len(prefixItems), + "prefixItems": prefixItems, + "type": "array", + } + + d = schema_of(t) # type: ignore + if "title" in d: + d.pop("title") + if "description" in d: + d.pop("description") + + return d + else: + from pydantic import TypeAdapter + + return TypeAdapter(t).json_schema() + + +def model_dump(model: BaseModel) -> Dict[str, Any]: + if PYDANTIC_V1: + return model.dict() # type: ignore + else: + return model.model_dump() diff --git a/test.sh b/test.sh index 636d66382..9236c93ff 100755 --- a/test.sh +++ b/test.sh @@ -20,4 +20,4 @@ pyright echo "--- Running mypy ---" mypy echo "--- Running pytest ---" -pytest \ No newline at end of file +pytest -n auto \ No newline at end of file diff --git a/tests/execution/test_commandline_code_executor.py b/tests/execution/test_commandline_code_executor.py new file mode 100644 index 000000000..2df31d6e7 --- /dev/null +++ b/tests/execution/test_commandline_code_executor.py @@ -0,0 +1,105 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/test/coding/test_commandline_code_executor.py +# Credit to original authors + +import sys +import tempfile +from pathlib import Path + +import pytest + +from agnext.agent_components.code_executor import LocalCommandLineCodeExecutor, CodeBlock + +UNIX_SHELLS = ["bash", "sh", "shell"] +WINDOWS_SHELLS = ["ps1", "pwsh", "powershell"] +PYTHON_VARIANTS = ["python", "Python", "py"] + + +def test_execute_code() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir) + + + # Test single code block. + code_blocks = [CodeBlock(code="import sys; print('hello world!')", language="python")] + code_result = executor.execute_code_blocks(code_blocks) + assert code_result.exit_code == 0 and "hello world!" in code_result.output and code_result.code_file is not None + + # Test multiple code blocks. + code_blocks = [ + CodeBlock(code="import sys; print('hello world!')", language="python"), + CodeBlock(code="a = 100 + 100; print(a)", language="python"), + ] + code_result = executor.execute_code_blocks(code_blocks) + assert ( + code_result.exit_code == 0 + and "hello world!" in code_result.output + and "200" in code_result.output + and code_result.code_file is not None + ) + + # Test bash script. + if sys.platform not in ["win32"]: + code_blocks = [CodeBlock(code="echo 'hello world!'", language="bash")] + code_result = executor.execute_code_blocks(code_blocks) + assert code_result.exit_code == 0 and "hello world!" in code_result.output and code_result.code_file is not None + + # Test running code. + file_lines = ["import sys", "print('hello world!')", "a = 100 + 100", "print(a)"] + code_blocks = [CodeBlock(code="\n".join(file_lines), language="python")] + code_result = executor.execute_code_blocks(code_blocks) + assert ( + code_result.exit_code == 0 + and "hello world!" in code_result.output + and "200" in code_result.output + and code_result.code_file is not None + ) + + # Check saved code file. + with open(code_result.code_file) as f: + code_lines = f.readlines() + for file_line, code_line in zip(file_lines, code_lines): + assert file_line.strip() == code_line.strip() + + +def test_commandline_code_executor_timeout() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + executor = LocalCommandLineCodeExecutor(timeout=1, work_dir=temp_dir) + code_blocks = [CodeBlock(code="import time; time.sleep(10); print('hello world!')", language="python")] + code_result = executor.execute_code_blocks(code_blocks) + assert code_result.exit_code and "Timeout" in code_result.output + + +def test_local_commandline_code_executor_restart() -> None: + executor = LocalCommandLineCodeExecutor() + with pytest.warns(UserWarning, match=r".*No action is taken."): + executor.restart() + + + + +def test_invalid_relative_path() -> None: + executor = LocalCommandLineCodeExecutor() + code = """# filename: /tmp/test.py + +print("hello world") +""" + result = executor.execute_code_blocks([CodeBlock(code=code, language="python")]) + assert result.exit_code == 1 and "Filename is not in the workspace" in result.output + + +def test_valid_relative_path() -> None: + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir) + code = """# filename: test.py + +print("hello world") +""" + result = executor.execute_code_blocks([CodeBlock(code=code, language="python")]) + assert result.exit_code == 0 + assert "hello world" in result.output + assert result.code_file is not None + assert "test.py" in result.code_file + assert (temp_dir / Path("test.py")).resolve() == Path(result.code_file).resolve() + assert (temp_dir / Path("test.py")).exists() + diff --git a/tests/execution/test_markdown_code_extractor.py b/tests/execution/test_markdown_code_extractor.py new file mode 100644 index 000000000..13894b3b5 --- /dev/null +++ b/tests/execution/test_markdown_code_extractor.py @@ -0,0 +1,118 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/test/coding/test_markdown_code_extractor.py +# Credit to original authors + +from agnext.agent_components.code_executor import MarkdownCodeExtractor + +_message_1 = """ +Example: +``` +print("hello extract code") +``` +""" + +_message_2 = """Example: +```python +def scrape(url): + import requests + from bs4 import BeautifulSoup + response = requests.get(url) + soup = BeautifulSoup(response.text, "html.parser") + title = soup.find("title").text + text = soup.find("div", {"id": "bodyContent"}).text + return title, text +``` +Test: +```python +url = "https://en.wikipedia.org/wiki/Web_scraping" +title, text = scrape(url) +print(f"Title: {title}") +print(f"Text: {text}") +``` +""" + +_message_3 = """ +Example: + ```python + def scrape(url): + import requests + from bs4 import BeautifulSoup + response = requests.get(url) + soup = BeautifulSoup(response.text, "html.parser") + title = soup.find("title").text + text = soup.find("div", {"id": "bodyContent"}).text + return title, text + ``` +""" + +_message_4 = """ +Example: +``` python +def scrape(url): + import requests + from bs4 import BeautifulSoup + response = requests.get(url) + soup = BeautifulSoup(response.text, "html.parser") + title = soup.find("title").text + text = soup.find("div", {"id": "bodyContent"}).text + return title, text +``` +""".replace( + "\n", "\r\n" +) + +_message_5 = """ +Test bash script: +```bash +echo 'hello world!' +``` +""" + +_message_6 = """ +Test some C# code, expecting "" +``` +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace ConsoleApplication1 +{ + class Program + { + static void Main(string[] args) + { + Console.WriteLine("Hello World"); + } + } +} +``` +""" + +_message_7 = """ +Test some message that has no code block. +""" + + +def test_extract_code() -> None: + extractor = MarkdownCodeExtractor() + + code_blocks = extractor.extract_code_blocks(_message_1) + assert len(code_blocks) == 1 and code_blocks[0].language == "python" + + code_blocks = extractor.extract_code_blocks(_message_2) + assert len(code_blocks) == 2 and code_blocks[0].language == "python" and code_blocks[1].language == "python" + + code_blocks = extractor.extract_code_blocks(_message_3) + assert len(code_blocks) == 1 and code_blocks[0].language == "python" + + code_blocks = extractor.extract_code_blocks(_message_4) + assert len(code_blocks) == 1 and code_blocks[0].language == "python" + + code_blocks = extractor.extract_code_blocks(_message_5) + assert len(code_blocks) == 1 and code_blocks[0].language == "bash" + + code_blocks = extractor.extract_code_blocks(_message_6) + assert len(code_blocks) == 1 and code_blocks[0].language == "" + + code_blocks = extractor.extract_code_blocks(_message_7) + assert len(code_blocks) == 0 diff --git a/tests/execution/test_user_defined_functions.py b/tests/execution/test_user_defined_functions.py new file mode 100644 index 000000000..382ecca2c --- /dev/null +++ b/tests/execution/test_user_defined_functions.py @@ -0,0 +1,199 @@ +# File based from: https://github.com/microsoft/autogen/blob/main/test/coding/test_user_defined_functions.py +# Credit to original authors + +import tempfile + +import pytest + +from agnext.agent_components.code_executor import LocalCommandLineCodeExecutor, CodeBlock +from agnext.agent_components.func_with_reqs import FunctionWithRequirements, with_requirements + +import polars + +def add_two_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + +@with_requirements(python_packages=["polars"], global_imports=["polars"]) +def load_data() -> polars.DataFrame: + """Load some sample data. + + Returns: + polars.DataFrame: A DataFrame with the following columns: name(str), location(str), age(int) + """ + data = { + "name": ["John", "Anna", "Peter", "Linda"], + "location": ["New York", "Paris", "Berlin", "London"], + "age": [24, 13, 53, 33], + } + return polars.DataFrame(data) + + +@with_requirements(global_imports=["NOT_A_REAL_PACKAGE"]) +def function_incorrect_import() -> "polars.DataFrame": + return polars.DataFrame() + + +@with_requirements(python_packages=["NOT_A_REAL_PACKAGE"]) +def function_incorrect_dep() -> "polars.DataFrame": + return polars.DataFrame() + + +def function_missing_reqs() -> "polars.DataFrame": + return polars.DataFrame() + + +def test_can_load_function_with_reqs() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[load_data]) + code = f"""from {executor.functions_module} import load_data +import polars + +# Get first row's name +data = load_data() +print(data['name'][0])""" + + result = executor.execute_code_blocks( + code_blocks=[ + CodeBlock(language="python", code=code), + ] + ) + assert result.output == "John\n" + assert result.exit_code == 0 + + +def test_can_load_function() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[add_two_numbers]) + code = f"""from {executor.functions_module} import add_two_numbers +print(add_two_numbers(1, 2))""" + + result = executor.execute_code_blocks( + code_blocks=[ + CodeBlock(language="python", code=code), + ] + ) + assert result.output == "3\n" + assert result.exit_code == 0 + + +def test_fails_for_function_incorrect_import() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[function_incorrect_import]) + code = f"""from {executor.functions_module} import function_incorrect_import +function_incorrect_import()""" + + with pytest.raises(ValueError): + executor.execute_code_blocks( + code_blocks=[ + CodeBlock(language="python", code=code), + ] + ) + + +def test_fails_for_function_incorrect_dep() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[function_incorrect_dep]) + code = f"""from {executor.functions_module} import function_incorrect_dep +function_incorrect_dep()""" + + with pytest.raises(ValueError): + executor.execute_code_blocks( + code_blocks=[ + CodeBlock(language="python", code=code), + ] + ) + + + +def test_formatted_prompt() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[add_two_numbers]) + + result = executor.format_functions_for_prompt() + assert ( + '''def add_two_numbers(a: int, b: int) -> int: + """Add two numbers together.""" +''' + in result + ) + + +def test_formatted_prompt_str_func() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + func = FunctionWithRequirements.from_str( + ''' +def add_two_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b +''' + ) + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[func]) + + result = executor.format_functions_for_prompt() + assert ( + '''def add_two_numbers(a: int, b: int) -> int: + """Add two numbers together.""" +''' + in result + ) + + + +def test_can_load_str_function_with_reqs() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + func = FunctionWithRequirements.from_str( + ''' +def add_two_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b +''' + ) + + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[func]) + code = f"""from {executor.functions_module} import add_two_numbers +print(add_two_numbers(1, 2))""" + + result = executor.execute_code_blocks( + code_blocks=[ + CodeBlock(language="python", code=code), + ] + ) + assert result.output == "3\n" + assert result.exit_code == 0 + + +def test_cant_load_broken_str_function_with_reqs() -> None: + + with pytest.raises(ValueError): + _ = FunctionWithRequirements.from_str( + ''' +invaliddef add_two_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b +''' + ) + + +def test_cant_run_broken_str_function_with_reqs() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + func = FunctionWithRequirements.from_str( + ''' +def add_two_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b +''' + ) + + executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[func]) + code = f"""from {executor.functions_module} import add_two_numbers +print(add_two_numbers(object(), False))""" + + result = executor.execute_code_blocks( + code_blocks=[ + CodeBlock(language="python", code=code), + ] + ) + assert "TypeError: unsupported operand type(s) for +:" in result.output + assert result.exit_code == 1