Add type annotations to CLI directory (#8291)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig
2025-05-08 21:25:02 -04:00
committed by GitHub
parent 4c33b28dad
commit 7d356cad47
5 changed files with 151 additions and 78 deletions

View File

@@ -100,7 +100,7 @@ def handle_exit_command(
return close_repl
def handle_help_command():
def handle_help_command() -> None:
display_help()
@@ -135,7 +135,7 @@ async def handle_init_command(
return close_repl, reload_microagents
def handle_status_command(usage_metrics: UsageMetrics, sid: str):
def handle_status_command(usage_metrics: UsageMetrics, sid: str) -> None:
display_status(usage_metrics, sid)
@@ -168,7 +168,7 @@ def handle_new_command(
async def handle_settings_command(
config: AppConfig,
settings_store: FileSettingsStore,
):
) -> None:
display_settings(config)
modify_settings = cli_confirm(
'\nWhich settings would you like to modify?',
@@ -213,6 +213,7 @@ async def init_repository(current_dir: str) -> bool:
if repo_file_path.exists():
try:
# Path.exists() ensures repo_file_path is not None, so we can safely pass it to read_file
content = await asyncio.get_event_loop().run_in_executor(
None, read_file, repo_file_path
)
@@ -263,7 +264,7 @@ async def init_repository(current_dir: str) -> bool:
return init_repo
def check_folder_security_agreement(config: AppConfig, current_dir):
def check_folder_security_agreement(config: AppConfig, current_dir: str) -> bool:
# Directories trusted by user for the CLI to use as workspace
# Config from ~/.openhands/config.toml overrides the app config

View File

@@ -68,7 +68,7 @@ async def cleanup_session(
agent: Agent,
runtime: Runtime,
controller: AgentController,
):
) -> None:
"""Clean up all resources from the current session."""
try:
# Cancel all running tasks except the current one
@@ -126,7 +126,7 @@ async def run_session(
usage_metrics = UsageMetrics()
async def prompt_for_next_task(agent_state: str):
async def prompt_for_next_task(agent_state: str) -> None:
nonlocal reload_microagents, new_session_requested
while True:
next_message = await read_prompt_input(
@@ -271,7 +271,7 @@ async def run_session(
return new_session_requested
async def main(loop: asyncio.AbstractEventLoop):
async def main(loop: asyncio.AbstractEventLoop) -> None:
"""Runs the agent in CLI mode."""
args = parse_arguments()

View File

@@ -29,7 +29,7 @@ from openhands.storage.settings.file_settings_store import FileSettingsStore
from openhands.utils.llm import get_supported_llm_models
def display_settings(config: AppConfig):
def display_settings(config: AppConfig) -> None:
llm_config = config.get_llm_config()
advanced_llm_settings = True if llm_config.base_url else False
@@ -108,8 +108,8 @@ async def get_validated_input(
prompt_text: str,
completer=None,
validator=None,
error_message='Input cannot be empty',
):
error_message: str = 'Input cannot be empty',
) -> str:
session.completer = completer
value = None
@@ -146,7 +146,7 @@ def save_settings_confirmation() -> bool:
async def modify_llm_settings_basic(
config: AppConfig, settings_store: FileSettingsStore
):
) -> None:
model_list = get_supported_llm_models(config)
organized_models = organize_models_and_providers(model_list)
@@ -171,20 +171,24 @@ async def modify_llm_settings_basic(
error_message='Invalid provider selected',
)
model_list = organized_models[provider]['models']
provider_models = organized_models[provider]['models']
if provider == 'openai':
model_list = [m for m in model_list if m not in VERIFIED_OPENAI_MODELS]
model_list = VERIFIED_OPENAI_MODELS + model_list
provider_models = [
m for m in provider_models if m not in VERIFIED_OPENAI_MODELS
]
provider_models = VERIFIED_OPENAI_MODELS + provider_models
if provider == 'anthropic':
model_list = [m for m in model_list if m not in VERIFIED_ANTHROPIC_MODELS]
model_list = VERIFIED_ANTHROPIC_MODELS + model_list
provider_models = [
m for m in provider_models if m not in VERIFIED_ANTHROPIC_MODELS
]
provider_models = VERIFIED_ANTHROPIC_MODELS + provider_models
model_completer = FuzzyWordCompleter(model_list)
model_completer = FuzzyWordCompleter(provider_models)
model = await get_validated_input(
session,
'(Step 2/3) Select LLM Model (TAB for options, CTRL-c to cancel): ',
completer=model_completer,
validator=lambda x: x in organized_models[provider]['models'],
validator=lambda x: x in provider_models,
error_message=f'Invalid model selected for provider {provider}',
)
@@ -201,10 +205,8 @@ async def modify_llm_settings_basic(
):
return # Return on exception
# TODO: check for empty string inputs?
# Handle case where a prompt might return None unexpectedly
if provider is None or model is None or api_key is None:
return
# The try-except block above ensures we either have valid inputs or we've already returned
# No need to check for None values here
save_settings = save_settings_confirmation()
@@ -212,7 +214,7 @@ async def modify_llm_settings_basic(
return
llm_config = config.get_llm_config()
llm_config.model = provider + organized_models[provider]['separator'] + model
llm_config.model = f'{provider}{organized_models[provider]["separator"]}{model}'
llm_config.api_key = SecretStr(api_key)
llm_config.base_url = None
config.set_llm_config(llm_config)
@@ -232,7 +234,7 @@ async def modify_llm_settings_basic(
if not settings:
settings = Settings()
settings.llm_model = provider + organized_models[provider]['separator'] + model
settings.llm_model = f'{provider}{organized_models[provider]["separator"]}{model}'
settings.llm_api_key = SecretStr(api_key)
settings.llm_base_url = None
settings.agent = OH_DEFAULT_AGENT
@@ -244,7 +246,7 @@ async def modify_llm_settings_basic(
async def modify_llm_settings_advanced(
config: AppConfig, settings_store: FileSettingsStore
):
) -> None:
session = PromptSession(key_bindings=kb_cancel())
custom_model = None
@@ -304,10 +306,8 @@ async def modify_llm_settings_advanced(
):
return # Return on exception
# TODO: check for empty string inputs?
# Handle case where a prompt might return None unexpectedly
if custom_model is None or base_url is None or api_key is None or agent is None:
return
# The try-except block above ensures we either have valid inputs or we've already returned
# No need to check for None values here
save_settings = save_settings_confirmation()

View File

@@ -6,10 +6,12 @@ import asyncio
import sys
import threading
import time
from typing import Generator
from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.application import Application
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.completion import CompleteEvent, Completer, Completion
from prompt_toolkit.document import Document
from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples
from prompt_toolkit.input import create_input
from prompt_toolkit.key_binding import KeyBindings
@@ -96,7 +98,7 @@ class CustomDiffLexer(Lexer):
# CLI initialization and startup display functions
def display_runtime_initialization_message(runtime: str):
def display_runtime_initialization_message(runtime: str) -> None:
print_formatted_text('')
if runtime == 'local':
print_formatted_text(HTML('<grey>⚙️ Starting local runtime...</grey>'))
@@ -105,7 +107,7 @@ def display_runtime_initialization_message(runtime: str):
print_formatted_text('')
def display_initialization_animation(text, is_loaded: asyncio.Event):
def display_initialization_animation(text: str, is_loaded: asyncio.Event) -> None:
ANIMATION_FRAMES = ['', '', '', '', '', '', '', '', '', '']
i = 0
@@ -122,7 +124,7 @@ def display_initialization_animation(text, is_loaded: asyncio.Event):
sys.stdout.flush()
def display_banner(session_id: str):
def display_banner(session_id: str) -> None:
print_formatted_text(
HTML(r"""<gold>
___ _ _ _
@@ -142,7 +144,7 @@ def display_banner(session_id: str):
print_formatted_text('')
def display_welcome_message():
def display_welcome_message() -> None:
print_formatted_text(
HTML("<gold>Let's start building!</gold>\n"), style=DEFAULT_STYLE
)
@@ -152,7 +154,7 @@ def display_welcome_message():
)
def display_initial_user_prompt(prompt: str):
def display_initial_user_prompt(prompt: str) -> None:
print_formatted_text(
FormattedText(
[
@@ -187,14 +189,14 @@ def display_event(event: Event, config: AppConfig) -> None:
display_agent_state_change_message(event.agent_state)
def display_message(message: str):
def display_message(message: str) -> None:
message = message.strip()
if message:
print_formatted_text(f'\n{message}')
def display_command(event: CmdRunAction):
def display_command(event: CmdRunAction) -> None:
if event.confirmation_state == ActionConfirmationStatus.AWAITING_CONFIRMATION:
container = Frame(
TextArea(
@@ -210,7 +212,7 @@ def display_command(event: CmdRunAction):
print_container(container)
def display_command_output(output: str):
def display_command_output(output: str) -> None:
lines = output.split('\n')
formatted_lines = []
for line in lines:
@@ -238,7 +240,7 @@ def display_command_output(output: str):
print_container(container)
def display_file_edit(event: FileEditObservation):
def display_file_edit(event: FileEditObservation) -> None:
container = Frame(
TextArea(
text=event.visualize_diff(n_context_lines=4),
@@ -253,7 +255,7 @@ def display_file_edit(event: FileEditObservation):
print_container(container)
def display_file_read(event: FileReadObservation):
def display_file_read(event: FileReadObservation) -> None:
content = event.content.replace('\t', ' ')
container = Frame(
TextArea(
@@ -270,7 +272,7 @@ def display_file_read(event: FileReadObservation):
# Interactive command output display functions
def display_help():
def display_help() -> None:
# Version header and introduction
print_formatted_text(
HTML(
@@ -314,7 +316,7 @@ def display_help():
)
def display_usage_metrics(usage_metrics: UsageMetrics):
def display_usage_metrics(usage_metrics: UsageMetrics) -> None:
cost_str = f'${usage_metrics.metrics.accumulated_cost:.6f}'
input_tokens_str = (
f'{usage_metrics.metrics.accumulated_token_usage.prompt_tokens:,}'
@@ -375,7 +377,7 @@ def get_session_duration(session_init_time: float) -> str:
return f'{int(hours)}h {int(minutes)}m {int(seconds)}s'
def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str):
def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str) -> None:
duration_str = get_session_duration(usage_metrics.session_init_time)
print_formatted_text(HTML('<grey>Closing current conversation...</grey>'))
@@ -388,7 +390,7 @@ def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str):
print_formatted_text('')
def display_status(usage_metrics: UsageMetrics, session_id: str):
def display_status(usage_metrics: UsageMetrics, session_id: str) -> None:
duration_str = get_session_duration(usage_metrics.session_init_time)
print_formatted_text('')
@@ -398,14 +400,14 @@ def display_status(usage_metrics: UsageMetrics, session_id: str):
display_usage_metrics(usage_metrics)
def display_agent_running_message():
def display_agent_running_message() -> None:
print_formatted_text('')
print_formatted_text(
HTML('<gold>Agent running...</gold> <grey>(Press Ctrl-P to pause)</grey>')
)
def display_agent_state_change_message(agent_state: str):
def display_agent_state_change_message(agent_state: str) -> None:
if agent_state == AgentState.PAUSED:
print_formatted_text('')
print_formatted_text(
@@ -429,7 +431,9 @@ class CommandCompleter(Completer):
super().__init__()
self.agent_state = agent_state
def get_completions(self, document, complete_event):
def get_completions(
self, document: Document, complete_event: CompleteEvent
) -> Generator[Completion, None, None]:
text = document.text_before_cursor.lstrip()
if text.startswith('/'):
available_commands = dict(COMMANDS)
@@ -446,11 +450,11 @@ class CommandCompleter(Completer):
)
def create_prompt_session():
def create_prompt_session() -> PromptSession:
return PromptSession(style=DEFAULT_STYLE)
async def read_prompt_input(agent_state: str, multiline=False):
async def read_prompt_input(agent_state: str, multiline: bool = False) -> str:
try:
prompt_session = create_prompt_session()
prompt_session.completer = (
@@ -461,7 +465,7 @@ async def read_prompt_input(agent_state: str, multiline=False):
kb = KeyBindings()
@kb.add('c-d')
def _(event):
def _(event) -> None:
event.current_buffer.validate_and_handle()
with patch_stdout():
@@ -511,7 +515,7 @@ async def read_confirmation_input() -> str:
async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) -> None:
input = create_input()
def keys_ready():
def keys_ready() -> None:
for key_press in input.read_keys():
if (
key_press.key == Keys.ControlP
@@ -543,7 +547,7 @@ def cli_confirm(
choices = ['Yes', 'No']
selected = [0] # Using list to allow modification in closure
def get_choice_text():
def get_choice_text() -> list:
return [
('class:question', f'{question}\n\n'),
] + [
@@ -557,15 +561,15 @@ def cli_confirm(
kb = KeyBindings()
@kb.add('up')
def _(event):
def _(event) -> None:
selected[0] = (selected[0] - 1) % len(choices)
@kb.add('down')
def _(event):
def _(event) -> None:
selected[0] = (selected[0] + 1) % len(choices)
@kb.add('enter')
def _(event):
def _(event) -> None:
event.app.exit(result=selected[0])
style = Style.from_dict({'selected': COLOR_GOLD, 'unselected': ''})
@@ -592,12 +596,12 @@ def cli_confirm(
return app.run(in_thread=True)
def kb_cancel():
def kb_cancel() -> KeyBindings:
"""Custom key bindings to handle ESC as a user cancellation."""
bindings = KeyBindings()
@bindings.add('escape')
def _(event):
def _(event) -> None:
event.app.exit(exception=UserCancelledError, style='class:aborting')
return bindings

View File

@@ -1,6 +1,7 @@
from pathlib import Path
import toml
from pydantic import BaseModel, Field
from openhands.cli.tui import (
UsageMetrics,
@@ -24,7 +25,7 @@ def get_local_config_trusted_dirs() -> list[str]:
return []
def add_local_config_trusted_dir(folder_path: str):
def add_local_config_trusted_dir(folder_path: str) -> None:
config = _DEFAULT_CONFIG
if _LOCAL_CONFIG_FILE_PATH.exists():
try:
@@ -47,7 +48,7 @@ def add_local_config_trusted_dir(folder_path: str):
toml.dump(config, f)
def update_usage_metrics(event: Event, usage_metrics: UsageMetrics):
def update_usage_metrics(event: Event, usage_metrics: UsageMetrics) -> None:
if not hasattr(event, 'llm_metrics'):
return
@@ -58,7 +59,34 @@ def update_usage_metrics(event: Event, usage_metrics: UsageMetrics):
usage_metrics.metrics = llm_metrics
def extract_model_and_provider(model):
class ModelInfo(BaseModel):
"""Information about a model and its provider."""
provider: str = Field(description='The provider of the model')
model: str = Field(description='The model identifier')
separator: str = Field(description='The separator used in the model identifier')
def __getitem__(self, key: str) -> str:
"""Allow dictionary-like access to fields."""
if key == 'provider':
return self.provider
elif key == 'model':
return self.model
elif key == 'separator':
return self.separator
raise KeyError(f'ModelInfo has no key {key}')
def extract_model_and_provider(model: str) -> ModelInfo:
"""
Extract provider and model information from a model identifier.
Args:
model: The model identifier string
Returns:
A ModelInfo object containing provider, model, and separator information
"""
separator = '/'
split = model.split(separator)
@@ -72,25 +100,36 @@ def extract_model_and_provider(model):
if len(split) == 1:
# no "/" or "." separator found
if split[0] in VERIFIED_OPENAI_MODELS:
return {'provider': 'openai', 'model': split[0], 'separator': '/'}
return ModelInfo(provider='openai', model=split[0], separator='/')
if split[0] in VERIFIED_ANTHROPIC_MODELS:
return {'provider': 'anthropic', 'model': split[0], 'separator': '/'}
return ModelInfo(provider='anthropic', model=split[0], separator='/')
# return as model only
return {'provider': '', 'model': model, 'separator': ''}
return ModelInfo(provider='', model=model, separator='')
provider = split[0]
model_id = separator.join(split[1:])
return {'provider': provider, 'model': model_id, 'separator': separator}
return ModelInfo(provider=provider, model=model_id, separator=separator)
def organize_models_and_providers(models):
result = {}
def organize_models_and_providers(
models: list[str],
) -> dict[str, 'ProviderInfo']:
"""
Organize a list of model identifiers by provider.
Args:
models: List of model identifiers
Returns:
A mapping of providers to their information and models
"""
result_dict: dict[str, ProviderInfo] = {}
for model in models:
extracted = extract_model_and_provider(model)
separator = extracted['separator']
provider = extracted['provider']
model_id = extracted['model']
separator = extracted.separator
provider = extracted.provider
model_id = extracted.model
# Ignore "anthropic" providers with a separator of "."
# These are outdated and incompatible providers.
@@ -98,12 +137,12 @@ def organize_models_and_providers(models):
continue
key = provider or 'other'
if key not in result:
result[key] = {'separator': separator, 'models': []}
if key not in result_dict:
result_dict[key] = ProviderInfo(separator=separator, models=[])
result[key]['models'].append(model_id)
result_dict[key].models.append(model_id)
return result
return result_dict
VERIFIED_PROVIDERS = ['openai', 'azure', 'anthropic', 'deepseek']
@@ -133,19 +172,48 @@ VERIFIED_ANTHROPIC_MODELS = [
]
def is_number(char):
class ProviderInfo(BaseModel):
"""Information about a provider and its models."""
separator: str = Field(description='The separator used in model identifiers')
models: list[str] = Field(
default_factory=list, description='List of model identifiers'
)
def __getitem__(self, key: str) -> str | list[str]:
"""Allow dictionary-like access to fields."""
if key == 'separator':
return self.separator
elif key == 'models':
return self.models
raise KeyError(f'ProviderInfo has no key {key}')
def get(self, key: str, default=None) -> str | list[str] | None:
"""Dictionary-like get method with default value."""
try:
return self[key]
except KeyError:
return default
def is_number(char: str) -> bool:
return char.isdigit()
def split_is_actually_version(split):
return len(split) > 1 and split[1] and split[1][0] and is_number(split[1][0])
def split_is_actually_version(split: list[str]) -> bool:
return (
len(split) > 1
and bool(split[1])
and bool(split[1][0])
and is_number(split[1][0])
)
def read_file(file_path):
def read_file(file_path: str | Path) -> str:
with open(file_path, 'r') as f:
return f.read()
def write_to_file(file_path, content):
def write_to_file(file_path: str | Path, content: str) -> None:
with open(file_path, 'w') as f:
f.write(content)