Mirror of https://github.com/All-Hands-AI/OpenHands.git
Add type annotations to CLI directory (#8291)
Co-authored-by: openhands <openhands@all-hands.dev>
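
The diff below applies one pattern throughout the CLI package: helper and display functions gain explicit parameter annotations and, where they only print or mutate state, a -> None return annotation. A minimal before/after sketch of that pattern, using a hypothetical helper rather than any function from this commit:

# Hypothetical helper, shown only to illustrate the annotation pattern in this commit.
def display_greeting(name):  # before: untyped signature
    print(f'Hello, {name}!')


def display_greeting_annotated(name: str) -> None:  # after: typed parameter, explicit None return
    print(f'Hello, {name}!')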
@@ -100,7 +100,7 @@ def handle_exit_command(
     return close_repl


-def handle_help_command():
+def handle_help_command() -> None:
     display_help()


@@ -135,7 +135,7 @@ async def handle_init_command(
     return close_repl, reload_microagents


-def handle_status_command(usage_metrics: UsageMetrics, sid: str):
+def handle_status_command(usage_metrics: UsageMetrics, sid: str) -> None:
     display_status(usage_metrics, sid)


@@ -168,7 +168,7 @@ def handle_new_command(
 async def handle_settings_command(
     config: AppConfig,
     settings_store: FileSettingsStore,
-):
+) -> None:
     display_settings(config)
     modify_settings = cli_confirm(
         '\nWhich settings would you like to modify?',
@@ -213,6 +213,7 @@ async def init_repository(current_dir: str) -> bool:

     if repo_file_path.exists():
         try:
+            # Path.exists() ensures repo_file_path is not None, so we can safely pass it to read_file
             content = await asyncio.get_event_loop().run_in_executor(
                 None, read_file, repo_file_path
             )
@@ -263,7 +264,7 @@ async def init_repository(current_dir: str) -> bool:
     return init_repo


-def check_folder_security_agreement(config: AppConfig, current_dir):
+def check_folder_security_agreement(config: AppConfig, current_dir: str) -> bool:
     # Directories trusted by user for the CLI to use as workspace
     # Config from ~/.openhands/config.toml overrides the app config

@@ -68,7 +68,7 @@ async def cleanup_session(
     agent: Agent,
     runtime: Runtime,
     controller: AgentController,
-):
+) -> None:
     """Clean up all resources from the current session."""
     try:
         # Cancel all running tasks except the current one
@@ -126,7 +126,7 @@ async def run_session(

     usage_metrics = UsageMetrics()

-    async def prompt_for_next_task(agent_state: str):
+    async def prompt_for_next_task(agent_state: str) -> None:
         nonlocal reload_microagents, new_session_requested
         while True:
             next_message = await read_prompt_input(
@@ -271,7 +271,7 @@ async def run_session(
     return new_session_requested


-async def main(loop: asyncio.AbstractEventLoop):
+async def main(loop: asyncio.AbstractEventLoop) -> None:
    """Runs the agent in CLI mode."""

     args = parse_arguments()

@@ -29,7 +29,7 @@ from openhands.storage.settings.file_settings_store import FileSettingsStore
 from openhands.utils.llm import get_supported_llm_models


-def display_settings(config: AppConfig):
+def display_settings(config: AppConfig) -> None:
     llm_config = config.get_llm_config()
     advanced_llm_settings = True if llm_config.base_url else False

@@ -108,8 +108,8 @@ async def get_validated_input(
     prompt_text: str,
     completer=None,
     validator=None,
-    error_message='Input cannot be empty',
-):
+    error_message: str = 'Input cannot be empty',
+) -> str:
     session.completer = completer
     value = None

@@ -146,7 +146,7 @@ def save_settings_confirmation() -> bool:

 async def modify_llm_settings_basic(
     config: AppConfig, settings_store: FileSettingsStore
-):
+) -> None:
     model_list = get_supported_llm_models(config)
     organized_models = organize_models_and_providers(model_list)

@@ -171,20 +171,24 @@ async def modify_llm_settings_basic(
         error_message='Invalid provider selected',
     )

-    model_list = organized_models[provider]['models']
+    provider_models = organized_models[provider]['models']
     if provider == 'openai':
-        model_list = [m for m in model_list if m not in VERIFIED_OPENAI_MODELS]
-        model_list = VERIFIED_OPENAI_MODELS + model_list
+        provider_models = [
+            m for m in provider_models if m not in VERIFIED_OPENAI_MODELS
+        ]
+        provider_models = VERIFIED_OPENAI_MODELS + provider_models
     if provider == 'anthropic':
-        model_list = [m for m in model_list if m not in VERIFIED_ANTHROPIC_MODELS]
-        model_list = VERIFIED_ANTHROPIC_MODELS + model_list
+        provider_models = [
+            m for m in provider_models if m not in VERIFIED_ANTHROPIC_MODELS
+        ]
+        provider_models = VERIFIED_ANTHROPIC_MODELS + provider_models

-    model_completer = FuzzyWordCompleter(model_list)
+    model_completer = FuzzyWordCompleter(provider_models)
     model = await get_validated_input(
         session,
         '(Step 2/3) Select LLM Model (TAB for options, CTRL-c to cancel): ',
         completer=model_completer,
-        validator=lambda x: x in organized_models[provider]['models'],
+        validator=lambda x: x in provider_models,
         error_message=f'Invalid model selected for provider {provider}',
     )

@@ -201,10 +205,8 @@ async def modify_llm_settings_basic(
     ):
         return  # Return on exception

-    # TODO: check for empty string inputs?
-    # Handle case where a prompt might return None unexpectedly
-    if provider is None or model is None or api_key is None:
-        return
+    # The try-except block above ensures we either have valid inputs or we've already returned
+    # No need to check for None values here

     save_settings = save_settings_confirmation()

@@ -212,7 +214,7 @@ async def modify_llm_settings_basic(
         return

     llm_config = config.get_llm_config()
-    llm_config.model = provider + organized_models[provider]['separator'] + model
+    llm_config.model = f'{provider}{organized_models[provider]["separator"]}{model}'
     llm_config.api_key = SecretStr(api_key)
     llm_config.base_url = None
     config.set_llm_config(llm_config)
@@ -232,7 +234,7 @@ async def modify_llm_settings_basic(
     if not settings:
         settings = Settings()

-    settings.llm_model = provider + organized_models[provider]['separator'] + model
+    settings.llm_model = f'{provider}{organized_models[provider]["separator"]}{model}'
     settings.llm_api_key = SecretStr(api_key)
     settings.llm_base_url = None
     settings.agent = OH_DEFAULT_AGENT
@@ -244,7 +246,7 @@ async def modify_llm_settings_basic(

 async def modify_llm_settings_advanced(
     config: AppConfig, settings_store: FileSettingsStore
-):
+) -> None:
     session = PromptSession(key_bindings=kb_cancel())

     custom_model = None
@@ -304,10 +306,8 @@ async def modify_llm_settings_advanced(
     ):
         return  # Return on exception

-    # TODO: check for empty string inputs?
-    # Handle case where a prompt might return None unexpectedly
-    if custom_model is None or base_url is None or api_key is None or agent is None:
-        return
+    # The try-except block above ensures we either have valid inputs or we've already returned
+    # No need to check for None values here

     save_settings = save_settings_confirmation()

@@ -6,10 +6,12 @@ import asyncio
 import sys
 import threading
 import time
+from typing import Generator

 from prompt_toolkit import PromptSession, print_formatted_text
 from prompt_toolkit.application import Application
-from prompt_toolkit.completion import Completer, Completion
+from prompt_toolkit.completion import CompleteEvent, Completer, Completion
+from prompt_toolkit.document import Document
 from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples
 from prompt_toolkit.input import create_input
 from prompt_toolkit.key_binding import KeyBindings
@@ -96,7 +98,7 @@ class CustomDiffLexer(Lexer):


 # CLI initialization and startup display functions
-def display_runtime_initialization_message(runtime: str):
+def display_runtime_initialization_message(runtime: str) -> None:
     print_formatted_text('')
     if runtime == 'local':
         print_formatted_text(HTML('<grey>⚙️ Starting local runtime...</grey>'))
@@ -105,7 +107,7 @@ def display_runtime_initialization_message(runtime: str):
     print_formatted_text('')


-def display_initialization_animation(text, is_loaded: asyncio.Event):
+def display_initialization_animation(text: str, is_loaded: asyncio.Event) -> None:
     ANIMATION_FRAMES = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']

     i = 0
@@ -122,7 +124,7 @@ def display_initialization_animation(text, is_loaded: asyncio.Event):
     sys.stdout.flush()


-def display_banner(session_id: str):
+def display_banner(session_id: str) -> None:
     print_formatted_text(
         HTML(r"""<gold>
     ___ _ _ _
@@ -142,7 +144,7 @@ def display_banner(session_id: str):
     print_formatted_text('')


-def display_welcome_message():
+def display_welcome_message() -> None:
     print_formatted_text(
         HTML("<gold>Let's start building!</gold>\n"), style=DEFAULT_STYLE
     )
@@ -152,7 +154,7 @@ def display_welcome_message():
     )


-def display_initial_user_prompt(prompt: str):
+def display_initial_user_prompt(prompt: str) -> None:
     print_formatted_text(
         FormattedText(
             [
@@ -187,14 +189,14 @@ def display_event(event: Event, config: AppConfig) -> None:
         display_agent_state_change_message(event.agent_state)


-def display_message(message: str):
+def display_message(message: str) -> None:
     message = message.strip()

     if message:
         print_formatted_text(f'\n{message}')


-def display_command(event: CmdRunAction):
+def display_command(event: CmdRunAction) -> None:
     if event.confirmation_state == ActionConfirmationStatus.AWAITING_CONFIRMATION:
         container = Frame(
             TextArea(
@@ -210,7 +212,7 @@ def display_command(event: CmdRunAction):
         print_container(container)


-def display_command_output(output: str):
+def display_command_output(output: str) -> None:
     lines = output.split('\n')
     formatted_lines = []
     for line in lines:
@@ -238,7 +240,7 @@ def display_command_output(output: str):
     print_container(container)


-def display_file_edit(event: FileEditObservation):
+def display_file_edit(event: FileEditObservation) -> None:
     container = Frame(
         TextArea(
             text=event.visualize_diff(n_context_lines=4),
@@ -253,7 +255,7 @@ def display_file_edit(event: FileEditObservation):
     print_container(container)


-def display_file_read(event: FileReadObservation):
+def display_file_read(event: FileReadObservation) -> None:
     content = event.content.replace('\t', ' ')
     container = Frame(
         TextArea(
@@ -270,7 +272,7 @@ def display_file_read(event: FileReadObservation):


 # Interactive command output display functions
-def display_help():
+def display_help() -> None:
     # Version header and introduction
     print_formatted_text(
         HTML(
@@ -314,7 +316,7 @@ def display_help():
     )


-def display_usage_metrics(usage_metrics: UsageMetrics):
+def display_usage_metrics(usage_metrics: UsageMetrics) -> None:
     cost_str = f'${usage_metrics.metrics.accumulated_cost:.6f}'
     input_tokens_str = (
         f'{usage_metrics.metrics.accumulated_token_usage.prompt_tokens:,}'
@@ -375,7 +377,7 @@ def get_session_duration(session_init_time: float) -> str:
     return f'{int(hours)}h {int(minutes)}m {int(seconds)}s'


-def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str):
+def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str) -> None:
     duration_str = get_session_duration(usage_metrics.session_init_time)

     print_formatted_text(HTML('<grey>Closing current conversation...</grey>'))
@@ -388,7 +390,7 @@ def display_shutdown_message(usage_metrics: UsageMetrics, session_id: str):
     print_formatted_text('')


-def display_status(usage_metrics: UsageMetrics, session_id: str):
+def display_status(usage_metrics: UsageMetrics, session_id: str) -> None:
     duration_str = get_session_duration(usage_metrics.session_init_time)

     print_formatted_text('')
@@ -398,14 +400,14 @@ def display_status(usage_metrics: UsageMetrics, session_id: str):
     display_usage_metrics(usage_metrics)


-def display_agent_running_message():
+def display_agent_running_message() -> None:
     print_formatted_text('')
     print_formatted_text(
         HTML('<gold>Agent running...</gold> <grey>(Press Ctrl-P to pause)</grey>')
     )


-def display_agent_state_change_message(agent_state: str):
+def display_agent_state_change_message(agent_state: str) -> None:
     if agent_state == AgentState.PAUSED:
         print_formatted_text('')
         print_formatted_text(
@@ -429,7 +431,9 @@ class CommandCompleter(Completer):
         super().__init__()
         self.agent_state = agent_state

-    def get_completions(self, document, complete_event):
+    def get_completions(
+        self, document: Document, complete_event: CompleteEvent
+    ) -> Generator[Completion, None, None]:
         text = document.text_before_cursor.lstrip()
         if text.startswith('/'):
             available_commands = dict(COMMANDS)
@@ -446,11 +450,11 @@ class CommandCompleter(Completer):
                     )


-def create_prompt_session():
+def create_prompt_session() -> PromptSession:
     return PromptSession(style=DEFAULT_STYLE)


-async def read_prompt_input(agent_state: str, multiline=False):
+async def read_prompt_input(agent_state: str, multiline: bool = False) -> str:
     try:
         prompt_session = create_prompt_session()
         prompt_session.completer = (
@@ -461,7 +465,7 @@ async def read_prompt_input(agent_state: str, multiline=False):
             kb = KeyBindings()

             @kb.add('c-d')
-            def _(event):
+            def _(event) -> None:
                 event.current_buffer.validate_and_handle()

             with patch_stdout():
@@ -511,7 +515,7 @@ async def read_confirmation_input() -> str:
 async def process_agent_pause(done: asyncio.Event, event_stream: EventStream) -> None:
     input = create_input()

-    def keys_ready():
+    def keys_ready() -> None:
         for key_press in input.read_keys():
             if (
                 key_press.key == Keys.ControlP
@@ -543,7 +547,7 @@ def cli_confirm(
         choices = ['Yes', 'No']
     selected = [0]  # Using list to allow modification in closure

-    def get_choice_text():
+    def get_choice_text() -> list:
         return [
             ('class:question', f'{question}\n\n'),
         ] + [
@@ -557,15 +561,15 @@ def cli_confirm(
     kb = KeyBindings()

     @kb.add('up')
-    def _(event):
+    def _(event) -> None:
         selected[0] = (selected[0] - 1) % len(choices)

     @kb.add('down')
-    def _(event):
+    def _(event) -> None:
         selected[0] = (selected[0] + 1) % len(choices)

     @kb.add('enter')
-    def _(event):
+    def _(event) -> None:
         event.app.exit(result=selected[0])

     style = Style.from_dict({'selected': COLOR_GOLD, 'unselected': ''})
@@ -592,12 +596,12 @@ def cli_confirm(
     return app.run(in_thread=True)


-def kb_cancel():
+def kb_cancel() -> KeyBindings:
     """Custom key bindings to handle ESC as a user cancellation."""
     bindings = KeyBindings()

     @bindings.add('escape')
-    def _(event):
+    def _(event) -> None:
         event.app.exit(exception=UserCancelledError, style='class:aborting')

     return bindings

@@ -1,6 +1,7 @@
 from pathlib import Path

 import toml
+from pydantic import BaseModel, Field

 from openhands.cli.tui import (
     UsageMetrics,
@@ -24,7 +25,7 @@ def get_local_config_trusted_dirs() -> list[str]:
     return []


-def add_local_config_trusted_dir(folder_path: str):
+def add_local_config_trusted_dir(folder_path: str) -> None:
     config = _DEFAULT_CONFIG
     if _LOCAL_CONFIG_FILE_PATH.exists():
         try:
@@ -47,7 +48,7 @@ def add_local_config_trusted_dir(folder_path: str):
         toml.dump(config, f)


-def update_usage_metrics(event: Event, usage_metrics: UsageMetrics):
+def update_usage_metrics(event: Event, usage_metrics: UsageMetrics) -> None:
     if not hasattr(event, 'llm_metrics'):
         return

@@ -58,7 +59,34 @@ def update_usage_metrics(event: Event, usage_metrics: UsageMetrics):
     usage_metrics.metrics = llm_metrics


-def extract_model_and_provider(model):
+class ModelInfo(BaseModel):
+    """Information about a model and its provider."""
+
+    provider: str = Field(description='The provider of the model')
+    model: str = Field(description='The model identifier')
+    separator: str = Field(description='The separator used in the model identifier')
+
+    def __getitem__(self, key: str) -> str:
+        """Allow dictionary-like access to fields."""
+        if key == 'provider':
+            return self.provider
+        elif key == 'model':
+            return self.model
+        elif key == 'separator':
+            return self.separator
+        raise KeyError(f'ModelInfo has no key {key}')
+
+
+def extract_model_and_provider(model: str) -> ModelInfo:
+    """
+    Extract provider and model information from a model identifier.
+
+    Args:
+        model: The model identifier string
+
+    Returns:
+        A ModelInfo object containing provider, model, and separator information
+    """
     separator = '/'
     split = model.split(separator)

@@ -72,25 +100,36 @@ def extract_model_and_provider(model):
     if len(split) == 1:
         # no "/" or "." separator found
         if split[0] in VERIFIED_OPENAI_MODELS:
-            return {'provider': 'openai', 'model': split[0], 'separator': '/'}
+            return ModelInfo(provider='openai', model=split[0], separator='/')
         if split[0] in VERIFIED_ANTHROPIC_MODELS:
-            return {'provider': 'anthropic', 'model': split[0], 'separator': '/'}
+            return ModelInfo(provider='anthropic', model=split[0], separator='/')
         # return as model only
-        return {'provider': '', 'model': model, 'separator': ''}
+        return ModelInfo(provider='', model=model, separator='')

     provider = split[0]
     model_id = separator.join(split[1:])
-    return {'provider': provider, 'model': model_id, 'separator': separator}
+    return ModelInfo(provider=provider, model=model_id, separator=separator)


-def organize_models_and_providers(models):
-    result = {}
+def organize_models_and_providers(
+    models: list[str],
+) -> dict[str, 'ProviderInfo']:
+    """
+    Organize a list of model identifiers by provider.
+
+    Args:
+        models: List of model identifiers
+
+    Returns:
+        A mapping of providers to their information and models
+    """
+    result_dict: dict[str, ProviderInfo] = {}

     for model in models:
         extracted = extract_model_and_provider(model)
-        separator = extracted['separator']
-        provider = extracted['provider']
-        model_id = extracted['model']
+        separator = extracted.separator
+        provider = extracted.provider
+        model_id = extracted.model

         # Ignore "anthropic" providers with a separator of "."
         # These are outdated and incompatible providers.
@@ -98,12 +137,12 @@ def organize_models_and_providers(models):
             continue

         key = provider or 'other'
-        if key not in result:
-            result[key] = {'separator': separator, 'models': []}
+        if key not in result_dict:
+            result_dict[key] = ProviderInfo(separator=separator, models=[])

-        result[key]['models'].append(model_id)
+        result_dict[key].models.append(model_id)

-    return result
+    return result_dict


 VERIFIED_PROVIDERS = ['openai', 'azure', 'anthropic', 'deepseek']
@@ -133,19 +172,48 @@ VERIFIED_ANTHROPIC_MODELS = [
 ]


-def is_number(char):
+class ProviderInfo(BaseModel):
+    """Information about a provider and its models."""
+
+    separator: str = Field(description='The separator used in model identifiers')
+    models: list[str] = Field(
+        default_factory=list, description='List of model identifiers'
+    )
+
+    def __getitem__(self, key: str) -> str | list[str]:
+        """Allow dictionary-like access to fields."""
+        if key == 'separator':
+            return self.separator
+        elif key == 'models':
+            return self.models
+        raise KeyError(f'ProviderInfo has no key {key}')
+
+    def get(self, key: str, default=None) -> str | list[str] | None:
+        """Dictionary-like get method with default value."""
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+
+def is_number(char: str) -> bool:
     return char.isdigit()


-def split_is_actually_version(split):
-    return len(split) > 1 and split[1] and split[1][0] and is_number(split[1][0])
+def split_is_actually_version(split: list[str]) -> bool:
+    return (
+        len(split) > 1
+        and bool(split[1])
+        and bool(split[1][0])
+        and is_number(split[1][0])
+    )


-def read_file(file_path):
+def read_file(file_path: str | Path) -> str:
     with open(file_path, 'r') as f:
         return f.read()


-def write_to_file(file_path, content):
+def write_to_file(file_path: str | Path, content: str) -> None:
     with open(file_path, 'w') as f:
         f.write(content)
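
Beyond the annotations, the commit replaces the plain dicts returned by extract_model_and_provider and organize_models_and_providers with the ModelInfo and ProviderInfo pydantic models above, while keeping __getitem__ (and a get method on ProviderInfo) so existing dict-style callers keep working. A minimal, self-contained sketch of that backwards-compatible pattern (illustrative only, not code from this commit):

from pydantic import BaseModel, Field


class InfoSketch(BaseModel):
    # Typed fields replace the old dict keys.
    provider: str = Field(description='The provider of the model')
    model: str = Field(description='The model identifier')

    def __getitem__(self, key: str) -> str:
        # Preserve the old dict-style interface: info['provider'] still works.
        if key in ('provider', 'model'):
            return getattr(self, key)
        raise KeyError(f'InfoSketch has no key {key}')


info = InfoSketch(provider='example-provider', model='example-model')
assert info.provider == info['provider']  # attribute and key access stay in sync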