mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
Add CLI mode (#3564)
* set log levels * basic cli flow * basic display * better exits * set log level * fix messages * clean up logs * better exits * better printing * add todo
This commit is contained in:
138
openhands/core/cli.py
Normal file
138
openhands/core/cli.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
import agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import (
|
||||
load_app_config,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ChangeAgentStateAction,
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
)
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.runtime import Runtime
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
def display_message(message: str):
|
||||
print(colored('🤖 ' + message + '\n', 'yellow'))
|
||||
|
||||
|
||||
def display_command(command: str):
|
||||
print('❯ ' + colored(command + '\n', 'green'))
|
||||
|
||||
|
||||
def display_command_output(output: str):
|
||||
lines = output.split('\n')
|
||||
for line in lines:
|
||||
if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
|
||||
# TODO: clean this up once we clean up terminal output
|
||||
continue
|
||||
print(colored(line, 'blue'))
|
||||
print('\n')
|
||||
|
||||
|
||||
def display_event(event: Event):
|
||||
if isinstance(event, Action):
|
||||
if hasattr(event, 'thought'):
|
||||
display_message(event.thought)
|
||||
if isinstance(event, MessageAction):
|
||||
if event.source != EventSource.USER:
|
||||
display_message(event.content)
|
||||
if isinstance(event, CmdRunAction):
|
||||
display_command(event.command)
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
display_command_output(event.content)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Runs the agent in CLI mode"""
|
||||
logger.setLevel(logging.WARNING)
|
||||
config = load_app_config()
|
||||
sid = 'cli'
|
||||
|
||||
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
|
||||
agent_config = config.get_agent_config(config.default_agent)
|
||||
llm_config = config.get_llm_config_from_agent(config.default_agent)
|
||||
agent = agent_cls(
|
||||
llm=LLM(config=llm_config),
|
||||
config=agent_config,
|
||||
)
|
||||
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
event_stream = EventStream(sid, file_store)
|
||||
|
||||
runtime_cls = get_runtime_cls(config.runtime)
|
||||
runtime: Runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid=sid,
|
||||
plugins=agent_cls.sandbox_plugins,
|
||||
)
|
||||
await runtime.ainit()
|
||||
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
max_iterations=config.max_iterations,
|
||||
max_budget_per_task=config.max_budget_per_task,
|
||||
agent_to_llm_config=config.get_agent_to_llm_config_map(),
|
||||
event_stream=event_stream,
|
||||
)
|
||||
|
||||
async def prompt_for_next_task():
|
||||
next_message = input('How can I help? >> ')
|
||||
if next_message == 'exit':
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
|
||||
)
|
||||
return
|
||||
action = MessageAction(content=next_message)
|
||||
event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
async def on_event(event: Event):
|
||||
display_event(event)
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state == AgentState.ERROR:
|
||||
print('An error occurred. Please try again.')
|
||||
if event.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
AgentState.ERROR,
|
||||
]:
|
||||
await prompt_for_next_task()
|
||||
|
||||
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
|
||||
|
||||
await prompt_for_next_task()
|
||||
|
||||
while controller.state.agent_state not in [
|
||||
AgentState.STOPPED,
|
||||
]:
|
||||
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
|
||||
|
||||
print('Exiting...')
|
||||
await controller.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
pass
|
||||
@@ -8,9 +8,13 @@ from typing import Literal, Mapping
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
DISABLE_COLOR_PRINTING = False
|
||||
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
|
||||
DEBUG = os.getenv('DEBUG', 'False').lower() in ['true', '1', 'yes']
|
||||
if DEBUG:
|
||||
LOG_LEVEL = 'DEBUG'
|
||||
|
||||
LOG_TO_FILE = os.getenv('LOG_TO_FILE', 'False').lower() in ['true', '1', 'yes']
|
||||
DISABLE_COLOR_PRINTING = False
|
||||
|
||||
ColorType = Literal[
|
||||
'red',
|
||||
@@ -116,9 +120,7 @@ class SensitiveDataFilter(logging.Filter):
|
||||
def get_console_handler():
|
||||
"""Returns a console handler for logging."""
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
if DEBUG:
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_handler.setLevel(logging.getLevelName(LOG_LEVEL))
|
||||
console_handler.setFormatter(console_formatter)
|
||||
return console_handler
|
||||
|
||||
@@ -129,8 +131,7 @@ def get_file_handler(log_dir):
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d')
|
||||
file_name = f'openhands_{timestamp}.log'
|
||||
file_handler = logging.FileHandler(os.path.join(log_dir, file_name))
|
||||
if DEBUG:
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setLevel(logging.getLevelName(LOG_LEVEL))
|
||||
file_handler.setFormatter(file_formatter)
|
||||
return file_handler
|
||||
|
||||
@@ -157,19 +158,16 @@ def log_uncaught_exceptions(ex_cls, ex, tb):
|
||||
sys.excepthook = log_uncaught_exceptions
|
||||
|
||||
openhands_logger = logging.getLogger('openhands')
|
||||
openhands_logger.setLevel(logging.INFO)
|
||||
openhands_logger.setLevel(logging.getLevelName(LOG_LEVEL))
|
||||
LOG_DIR = os.path.join(
|
||||
# parent dir of openhands/core (i.e., root of the repo)
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
|
||||
'logs',
|
||||
)
|
||||
|
||||
if DEBUG:
|
||||
openhands_logger.setLevel(logging.DEBUG)
|
||||
|
||||
if LOG_TO_FILE:
|
||||
# default log to project root
|
||||
openhands_logger.info('Logging to file is enabled. Logging to %s', LOG_DIR)
|
||||
openhands_logger.debug('Logging to file is enabled. Logging to %s', LOG_DIR)
|
||||
openhands_logger.addHandler(get_file_handler(LOG_DIR))
|
||||
|
||||
openhands_logger.addHandler(get_console_handler())
|
||||
@@ -233,21 +231,21 @@ class LlmFileHandler(logging.FileHandler):
|
||||
self.message_counter += 1
|
||||
|
||||
|
||||
def _get_llm_file_handler(name, debug_level=logging.DEBUG):
|
||||
def _get_llm_file_handler(name, log_level=logging.DEBUG):
|
||||
# The 'delay' parameter, when set to True, postpones the opening of the log file
|
||||
# until the first log message is emitted.
|
||||
llm_file_handler = LlmFileHandler(name, delay=True)
|
||||
llm_file_handler.setFormatter(llm_formatter)
|
||||
llm_file_handler.setLevel(debug_level)
|
||||
llm_file_handler.setLevel(log_level)
|
||||
return llm_file_handler
|
||||
|
||||
|
||||
def _setup_llm_logger(name, debug_level=logging.DEBUG):
|
||||
def _setup_llm_logger(name, log_level=logging.DEBUG):
|
||||
logger = logging.getLogger(name)
|
||||
logger.propagate = False
|
||||
logger.setLevel(debug_level)
|
||||
logger.setLevel(log_level)
|
||||
if LOG_TO_FILE:
|
||||
logger.addHandler(_get_llm_file_handler(name, debug_level))
|
||||
logger.addHandler(_get_llm_file_handler(name, log_level))
|
||||
return logger
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class DockerRuntimeBuilder(RuntimeBuilder):
|
||||
|
||||
for log in build_logs:
|
||||
if 'stream' in log:
|
||||
print(log['stream'].strip())
|
||||
logger.info(log['stream'].strip())
|
||||
elif 'error' in log:
|
||||
logger.error(log['error'].strip())
|
||||
else:
|
||||
|
||||
@@ -42,8 +42,16 @@ def _create_project_source_dist():
|
||||
|
||||
# run "python -m build -s" on project_root to create project tarball
|
||||
result = subprocess.run(
|
||||
f'python -m build -s ' + project_root.replace(" ", r"\ "), shell=True
|
||||
'python -m build -s ' + project_root.replace(' ', r'\ '),
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
logger.info(result.stdout.decode())
|
||||
err_logs = result.stderr.decode()
|
||||
if err_logs:
|
||||
logger.error(err_logs)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f'Build failed: {result}')
|
||||
raise Exception(f'Build failed: {result}')
|
||||
|
||||
Reference in New Issue
Block a user