Add CLI mode (#3564)

* set log levels

* basic cli flow

* basic display

* better exits

* set log level

* fix messages

* clean up logs

* better exits

* better printing

* add todo
This commit is contained in:
Robert Brennan
2024-08-25 18:10:21 -04:00
committed by GitHub
parent 7589be671e
commit 356d9b34be
4 changed files with 162 additions and 18 deletions

138
openhands/core/cli.py Normal file
View File

@@ -0,0 +1,138 @@
import asyncio
import logging
from typing import Type
from termcolor import colored
import agenthub # noqa F401 (we import this to get the agents registered)
from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.core.config import (
load_app_config,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
Action,
ChangeAgentStateAction,
CmdRunAction,
MessageAction,
)
from openhands.events.event import Event
from openhands.events.observation import (
AgentStateChangedObservation,
CmdOutputObservation,
)
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.runtime import Runtime
from openhands.storage import get_file_store
def display_message(message: str):
print(colored('🤖 ' + message + '\n', 'yellow'))
def display_command(command: str):
print(' ' + colored(command + '\n', 'green'))
def display_command_output(output: str):
lines = output.split('\n')
for line in lines:
if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
# TODO: clean this up once we clean up terminal output
continue
print(colored(line, 'blue'))
print('\n')
def display_event(event: Event):
if isinstance(event, Action):
if hasattr(event, 'thought'):
display_message(event.thought)
if isinstance(event, MessageAction):
if event.source != EventSource.USER:
display_message(event.content)
if isinstance(event, CmdRunAction):
display_command(event.command)
if isinstance(event, CmdOutputObservation):
display_command_output(event.content)
async def main():
"""Runs the agent in CLI mode"""
logger.setLevel(logging.WARNING)
config = load_app_config()
sid = 'cli'
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
agent_config = config.get_agent_config(config.default_agent)
llm_config = config.get_llm_config_from_agent(config.default_agent)
agent = agent_cls(
llm=LLM(config=llm_config),
config=agent_config,
)
file_store = get_file_store(config.file_store, config.file_store_path)
event_stream = EventStream(sid, file_store)
runtime_cls = get_runtime_cls(config.runtime)
runtime: Runtime = runtime_cls(
config=config,
event_stream=event_stream,
sid=sid,
plugins=agent_cls.sandbox_plugins,
)
await runtime.ainit()
controller = AgentController(
agent=agent,
max_iterations=config.max_iterations,
max_budget_per_task=config.max_budget_per_task,
agent_to_llm_config=config.get_agent_to_llm_config_map(),
event_stream=event_stream,
)
async def prompt_for_next_task():
next_message = input('How can I help? >> ')
if next_message == 'exit':
event_stream.add_event(
ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
)
return
action = MessageAction(content=next_message)
event_stream.add_event(action, EventSource.USER)
async def on_event(event: Event):
display_event(event)
if isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.ERROR:
print('An error occurred. Please try again.')
if event.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
AgentState.ERROR,
]:
await prompt_for_next_task()
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
await prompt_for_next_task()
while controller.state.agent_state not in [
AgentState.STOPPED,
]:
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
print('Exiting...')
await controller.close()
if __name__ == '__main__':
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(main())
finally:
pass

View File

@@ -8,9 +8,13 @@ from typing import Literal, Mapping
from termcolor import colored
DISABLE_COLOR_PRINTING = False
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
DEBUG = os.getenv('DEBUG', 'False').lower() in ['true', '1', 'yes']
if DEBUG:
LOG_LEVEL = 'DEBUG'
LOG_TO_FILE = os.getenv('LOG_TO_FILE', 'False').lower() in ['true', '1', 'yes']
DISABLE_COLOR_PRINTING = False
ColorType = Literal[
'red',
@@ -116,9 +120,7 @@ class SensitiveDataFilter(logging.Filter):
def get_console_handler():
"""Returns a console handler for logging."""
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
if DEBUG:
console_handler.setLevel(logging.DEBUG)
console_handler.setLevel(logging.getLevelName(LOG_LEVEL))
console_handler.setFormatter(console_formatter)
return console_handler
@@ -129,8 +131,7 @@ def get_file_handler(log_dir):
timestamp = datetime.now().strftime('%Y-%m-%d')
file_name = f'openhands_{timestamp}.log'
file_handler = logging.FileHandler(os.path.join(log_dir, file_name))
if DEBUG:
file_handler.setLevel(logging.DEBUG)
file_handler.setLevel(logging.getLevelName(LOG_LEVEL))
file_handler.setFormatter(file_formatter)
return file_handler
@@ -157,19 +158,16 @@ def log_uncaught_exceptions(ex_cls, ex, tb):
sys.excepthook = log_uncaught_exceptions
openhands_logger = logging.getLogger('openhands')
openhands_logger.setLevel(logging.INFO)
openhands_logger.setLevel(logging.getLevelName(LOG_LEVEL))
LOG_DIR = os.path.join(
# parent dir of openhands/core (i.e., root of the repo)
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
'logs',
)
if DEBUG:
openhands_logger.setLevel(logging.DEBUG)
if LOG_TO_FILE:
# default log to project root
openhands_logger.info('Logging to file is enabled. Logging to %s', LOG_DIR)
openhands_logger.debug('Logging to file is enabled. Logging to %s', LOG_DIR)
openhands_logger.addHandler(get_file_handler(LOG_DIR))
openhands_logger.addHandler(get_console_handler())
@@ -233,21 +231,21 @@ class LlmFileHandler(logging.FileHandler):
self.message_counter += 1
def _get_llm_file_handler(name, debug_level=logging.DEBUG):
def _get_llm_file_handler(name, log_level=logging.DEBUG):
# The 'delay' parameter, when set to True, postpones the opening of the log file
# until the first log message is emitted.
llm_file_handler = LlmFileHandler(name, delay=True)
llm_file_handler.setFormatter(llm_formatter)
llm_file_handler.setLevel(debug_level)
llm_file_handler.setLevel(log_level)
return llm_file_handler
def _setup_llm_logger(name, debug_level=logging.DEBUG):
def _setup_llm_logger(name, log_level=logging.DEBUG):
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(debug_level)
logger.setLevel(log_level)
if LOG_TO_FILE:
logger.addHandler(_get_llm_file_handler(name, debug_level))
logger.addHandler(_get_llm_file_handler(name, log_level))
return logger

View File

@@ -26,7 +26,7 @@ class DockerRuntimeBuilder(RuntimeBuilder):
for log in build_logs:
if 'stream' in log:
print(log['stream'].strip())
logger.info(log['stream'].strip())
elif 'error' in log:
logger.error(log['error'].strip())
else:

View File

@@ -42,8 +42,16 @@ def _create_project_source_dist():
# run "python -m build -s" on project_root to create project tarball
result = subprocess.run(
f'python -m build -s ' + project_root.replace(" ", r"\ "), shell=True
'python -m build -s ' + project_root.replace(' ', r'\ '),
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
logger.info(result.stdout.decode())
err_logs = result.stderr.decode()
if err_logs:
logger.error(err_logs)
if result.returncode != 0:
logger.error(f'Build failed: {result}')
raise Exception(f'Build failed: {result}')