Files
OS-Copilot/working_dir/code/python_interpreter.py
2023-12-15 00:45:06 +08:00

119 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import copy
import io
from contextlib import redirect_stdout
from typing import Any, Optional
from jarvis.action.base_action import BaseAction
from jarvis.core.schema import ActionReturn
class GenericRuntime:
    """Tiny stateful namespace in which code snippets are executed.

    Subclasses may override the class attributes to pre-populate the
    namespace:

    * ``GLOBAL_DICT`` -- shallow-copied into the per-instance globals.
    * ``LOCAL_DICT``  -- shallow-copied into ``_local_vars`` when truthy;
      note that ``exec_code`` / ``eval_code`` only use the globals.
    * ``HEADERS``     -- code snippets executed once at construction time.
    """

    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []

    def __init__(self):
        # Each instance works on its own (shallow) copy so that executed
        # code does not add names to the class-level template dicts.
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        if self.LOCAL_DICT:
            self._local_vars = copy.copy(self.LOCAL_DICT)
        else:
            self._local_vars = None

        # Run the preamble snippets inside the fresh namespace.
        for header in self.HEADERS:
            self.exec_code(header)

    def exec_code(self, code_piece: str) -> None:
        """Execute ``code_piece`` as statements in the instance globals."""
        namespace = self._global_vars
        exec(code_piece, namespace)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` in the instance globals and return its value."""
        namespace = self._global_vars
        value = eval(expr, namespace)
        return value
# Default description passed to the action (and thereby shown to the model).
# The text is a runtime string and is deliberately left in its original
# language. English gist: "Used to execute Python code. The code must be a
# function whose name must be 'solution'; the code corresponds to your
# reasoning process. Example format:" followed by a fenced ```python sample
# that defines ``solution()``.
DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数
函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
```python
# import 依赖包
import time
def solution():
# 初始化一些变量
print("hello world!")
# 步骤一
#mid_variable = func(variable_names_with_real_meaning)
# 步骤 x
#mid_variable = func(mid_variable)
# 最后结果
#final_answer = func(mid_variable)
return "return!"
```"""
class python_interpreter(BaseAction):
    """A Python executor that can execute Python scripts.

    Args:
        description (str): The description of the action. Defaults to
            DEFAULT_DESCRIPTION.
        answer_symbol (str, Optional): Name of the global variable that
            holds the answer after the script has been executed.
        answer_expr (str, Optional): Expression evaluated after the script
            has been executed to obtain the answer. Defaults to
            'solution()'.
        answer_from_stdout (boolean): Whether the execution result is taken
            from the last line the script wrote to stdout.
        name (str, optional): The name of the action. If None, the name
            will be the class name. Defaults to None.
        timeout (int): Upper bound of waiting time for Python script
            execution. It is forwarded to ``BaseAction``; nothing in this
            class enforces it.
    """

    def __init__(self,
                 description: str = DEFAULT_DESCRIPTION,
                 answer_symbol: Optional[str] = None,
                 answer_expr: Optional[str] = 'solution()',
                 answer_from_stdout: bool = False,
                 name: Optional[str] = None,
                 timeout: int = 20) -> None:
        super().__init__(description, name, timeout)
        self.answer_symbol = answer_symbol
        self.answer_expr = answer_expr
        self.answer_from_stdout = answer_from_stdout

    def __call__(self, command: str) -> ActionReturn:
        """Execute ``command`` in a fresh runtime and return the outcome.

        Args:
            command (str): Python source, optionally wrapped in a
                ```python ... ``` (or plain ``` ... ```) fence.

        Returns:
            ActionReturn: carries ``result`` on success or ``errmsg`` when
            execution (or stringifying the result) raised.
        """
        # A new runtime per call so state never leaks between invocations.
        self.runtime = GenericRuntime()
        # Fix: the result used to be dropped, so callers always got None.
        return self._call(command)

    def _call(self, command: str) -> ActionReturn:
        """Extract the code from ``command``, run it and collect the answer.

        The answer is picked by the first configured strategy, in order:
        last stdout line (``answer_from_stdout``), a global variable
        (``answer_symbol``), an evaluated expression (``answer_expr``), or
        otherwise the value of the script's final line.

        SECURITY NOTE: the code is run with ``exec``/``eval`` in-process and
        without any sandbox; only feed it input you trust.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        try:
            # Keep only the content of the first fenced code block, if any.
            if '```python' in command:
                command = command.split('```python')[1].split('```')[0]
            elif '```' in command:
                command = command.split('```')[1].split('```')[0]
            tool_return.args = dict(text='```python\n' + command + '\n```')
            command = command.split('\n')

            if self.answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    self.runtime.exec_code('\n'.join(command))
                program_io.seek(0)
                # Raises IndexError when nothing was printed; that is
                # reported through ``errmsg`` by the handler below.
                res = program_io.readlines()[-1]
            elif self.answer_symbol:
                self.runtime.exec_code('\n'.join(command))
                res = self.runtime._global_vars[self.answer_symbol]
            elif self.answer_expr:
                self.runtime.exec_code('\n'.join(command))
                res = self.runtime.eval_code(self.answer_expr)
            else:
                # Run everything but the last line, then evaluate the last
                # line as the answer expression.
                self.runtime.exec_code('\n'.join(command[:-1]))
                res = self.runtime.eval_code(command[-1])
        except Exception as e:
            # Boundary handler: any failure of the user script is reported
            # to the caller instead of propagating.
            tool_return.errmsg = repr(e)
            tool_return.type = self.name
            # tool_return.state = ActionStatusCode.API_ERROR
            return tool_return
        try:
            # str() runs user-defined __str__/__repr__, which may raise.
            tool_return.result = dict(text=str(res))
            # tool_return.state = ActionStatusCode.SUCCESS
        except Exception as e:
            tool_return.errmsg = repr(e)
            tool_return.type = self.name
            # tool_return.state = ActionStatusCode.API_ERROR
        return tool_return