Fix mypy errors in agenthub directory (#6811)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
Graham Neubig
2025-02-21 08:55:27 -05:00
committed by GitHub
parent 9d3a0a02b8
commit f093c14ad3
10 changed files with 77 additions and 42 deletions

View File

@@ -10,17 +10,24 @@ from openhands.events.action import (
class BrowsingResponseParser(ResponseParser):
def __init__(self):
def __init__(self) -> None:
# Need to pay attention to the item order in self.action_parsers
super().__init__()
self.action_parsers = [BrowsingActionParserMessage()]
self.default_parser = BrowsingActionParserBrowseInteractive()
def parse(self, response: str) -> Action:
action_str = self.parse_response(response)
def parse(
self, response: str | dict[str, list[dict[str, dict[str, str | None]]]]
) -> Action:
if isinstance(response, str):
action_str = response
else:
action_str = self.parse_response(response)
return self.parse_action(action_str)
def parse_response(self, response) -> str:
def parse_response(
self, response: dict[str, list[dict[str, dict[str, str | None]]]]
) -> str:
action_str = response['choices'][0]['message']['content']
if action_str is None:
return ''
@@ -47,9 +54,7 @@ class BrowsingActionParserMessage(ActionParser):
- BrowseInteractiveAction(browser_actions) - unexpected response format, message back to user
"""
def __init__(
self,
):
def __init__(self) -> None:
pass
def check_condition(self, action_str: str) -> bool:
@@ -69,9 +74,7 @@ class BrowsingActionParserBrowseInteractive(ActionParser):
- BrowseInteractiveAction(browser_actions) - handle send message to user function call in BrowserGym
"""
def __init__(
self,
):
def __init__(self) -> None:
pass
def check_condition(self, action_str: str) -> bool:

View File

@@ -5,7 +5,7 @@ from warnings import warn
import yaml
def yaml_parser(message):
def yaml_parser(message: str) -> tuple[dict, bool, str]:
"""Parse a yaml message for the retry function."""
# saves gpt-3.5 from some yaml parsing errors
message = re.sub(r':\s*\n(?=\S|\n)', ': ', message)
@@ -22,7 +22,9 @@ def yaml_parser(message):
return value, valid, retry_message
def _compress_chunks(text, identifier, skip_list, split_regex='\n\n+'):
def _compress_chunks(
text: str, identifier: str, skip_list: list[str], split_regex: str = '\n\n+'
) -> tuple[dict[str, str], str]:
"""Compress a string by replacing redundant chunks by identifiers. Chunks are defined by the split_regex."""
text_list = re.split(split_regex, text)
text_list = [chunk.strip() for chunk in text_list]
@@ -44,7 +46,7 @@ def _compress_chunks(text, identifier, skip_list, split_regex='\n\n+'):
return def_dict, compressed_text
def compress_string(text):
def compress_string(text: str) -> str:
"""Compress a string by replacing redundant paragraphs and lines with identifiers."""
# Perform paragraph-level compression
def_dict, compressed_text = _compress_chunks(
@@ -67,7 +69,7 @@ def compress_string(text):
return definitions + '\n' + compressed_text
def extract_html_tags(text, keys):
def extract_html_tags(text: str, keys: list[str]) -> dict[str, list[str]]:
"""Extract the content within HTML tags for a list of keys.
Parameters
@@ -102,7 +104,12 @@ class ParseError(Exception):
pass
def parse_html_tags_raise(text, keys=(), optional_keys=(), merge_multiple=False):
def parse_html_tags_raise(
text: str,
keys: list[str] | None = None,
optional_keys: list[str] | None = None,
merge_multiple: bool = False,
) -> dict[str, str]:
"""A version of parse_html_tags that raises an exception if the parsing is not successful."""
content_dict, valid, retry_message = parse_html_tags(
text, keys, optional_keys, merge_multiple=merge_multiple
@@ -112,7 +119,12 @@ def parse_html_tags_raise(text, keys=(), optional_keys=(), merge_multiple=False)
return content_dict
def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
def parse_html_tags(
text: str,
keys: list[str] | None = None,
optional_keys: list[str] | None = None,
merge_multiple: bool = False,
) -> tuple[dict[str, str], bool, str]:
"""Satisfy the parse api, extracts 1 match per key and validates that all keys are present
Parameters
@@ -133,9 +145,12 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
str
A message to be displayed to the agent if the parsing was not successful.
"""
all_keys = tuple(keys) + tuple(optional_keys)
keys = keys or []
optional_keys = optional_keys or []
all_keys = list(keys) + list(optional_keys)
content_dict = extract_html_tags(text, all_keys)
retry_messages = []
result_dict: dict[str, str] = {}
for key in all_keys:
if key not in content_dict:
@@ -143,7 +158,6 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
retry_messages.append(f'Missing the key <{key}> in the answer.')
else:
val = content_dict[key]
content_dict[key] = val[0]
if len(val) > 1:
if not merge_multiple:
retry_messages.append(
@@ -151,8 +165,10 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
)
else:
# merge the multiple instances
content_dict[key] = '\n'.join(val)
result_dict[key] = '\n'.join(val)
else:
result_dict[key] = val[0]
valid = len(retry_messages) == 0
retry_message = '\n'.join(retry_messages)
return content_dict, valid, retry_message
return result_dict, valid, retry_message

View File

@@ -475,8 +475,9 @@ def combine_thought(action: Action, thought: str) -> Action:
def response_to_actions(response: ModelResponse) -> list[Action]:
actions: list[Action] = []
assert len(response.choices) == 1, 'Only one choice is supported for now'
assistant_msg = response.choices[0].message
if assistant_msg.tool_calls:
choice = response.choices[0]
assistant_msg = choice.message
if hasattr(assistant_msg, 'tool_calls') and assistant_msg.tool_calls:
# Check if there's assistant_msg.content. If so, add it to the thought
thought = ''
if isinstance(assistant_msg.content, str):
@@ -592,7 +593,10 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
actions.append(action)
else:
actions.append(
MessageAction(content=assistant_msg.content, wait_for_response=True)
MessageAction(
content=str(assistant_msg.content) if assistant_msg.content else '',
wait_for_response=True,
)
)
assert len(actions) >= 1

View File

@@ -22,7 +22,7 @@ def parse_response(orig_response: str) -> Action:
return action_from_dict(action_dict)
def to_json(obj, **kwargs):
def to_json(obj: object, **kwargs: dict) -> str:
"""Serialize an object to str format"""
return json.dumps(obj, **kwargs)
@@ -32,7 +32,9 @@ class MicroAgent(Agent):
prompt = ''
agent_definition: dict = {}
def history_to_json(self, history: list[Event], max_events: int = 20, **kwargs):
def history_to_json(
self, history: list[Event], max_events: int = 20, **kwargs: dict
) -> str:
"""
Serialize and simplify history to str format
"""
@@ -60,7 +62,7 @@ class MicroAgent(Agent):
super().__init__(llm, config)
if 'name' not in self.agent_definition:
raise ValueError('Agent definition must contain a name')
self.prompt_template = Environment(loader=BaseLoader).from_string(self.prompt)
self.prompt_template = Environment(loader=BaseLoader()).from_string(self.prompt)
self.delegates = all_microagents.copy()
del self.delegates[self.agent_definition['name']]
@@ -74,7 +76,7 @@ class MicroAgent(Agent):
delegates=self.delegates,
latest_user_message=last_user_message,
)
content = [TextContent(text=prompt)]
content: list[TextContent | ImageContent] = [TextContent(text=prompt)]
if self.llm.vision_is_active() and last_image_urls:
content.append(ImageContent(image_urls=last_image_urls))
message = Message(role='user', content=content)

View File

@@ -29,7 +29,9 @@ def get_error_prefix(obs: BrowserOutputObservation) -> str:
return f'## Error from previous action:\n{obs.last_browser_action_error}\n'
def create_goal_prompt(goal: str, image_urls: list[str] | None):
def create_goal_prompt(
goal: str, image_urls: list[str] | None
) -> tuple[str, list[str]]:
goal_txt: str = f"""\
# Instructions
Review the current state of the page and all other information to find the best possible next action to accomplish your goal. Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions.
@@ -52,7 +54,7 @@ def create_observation_prompt(
focused_element: str,
error_prefix: str,
som_screenshot: str | None,
):
) -> tuple[str, str | None]:
txt_observation = f"""
# Observation of current step:
{tabs}{axtree_txt}{focused_element}{error_prefix}
@@ -273,7 +275,9 @@ Note:
observation_txt, som_screenshot = create_observation_prompt(
cur_axtree_txt, tabs, focused_element, error_prefix, set_of_marks
)
human_prompt = [TextContent(type='text', text=goal_txt)]
human_prompt: list[TextContent | ImageContent] = [
TextContent(type='text', text=goal_txt)
]
if len(goal_images) > 0:
human_prompt.append(ImageContent(image_urls=goal_images))
human_prompt.append(TextContent(type='text', text=observation_txt))

View File

@@ -168,10 +168,12 @@ class DockerRuntimeBuilder(RuntimeBuilder):
)
except subprocess.CalledProcessError as e:
logger.error(f'Image build failed:\n{e}') # TODO: {e} is empty
logger.error(f'Image build failed:\n{e}') # TODO: {e} is empty
logger.error(f'Command output:\n{e.output}')
if self.rolling_logger.is_enabled():
logger.error("Docker build output:\n" + self.rolling_logger.all_lines) # Show the error
logger.error(
'Docker build output:\n' + self.rolling_logger.all_lines
) # Show the error
raise
except subprocess.TimeoutExpired:

View File

@@ -270,7 +270,10 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
tunnel = self.sandbox.tunnels()[self._vscode_port]
tunnel_url = tunnel.url
self._vscode_url = tunnel_url + f'/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}'
self._vscode_url = (
tunnel_url
+ f'/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}'
)
self.log(
'debug',

View File

@@ -13,4 +13,4 @@ function deactivate() {}
module.exports = {
activate,
deactivate
}
}

View File

@@ -20,4 +20,4 @@
"title": "Hello World from OpenHands"
}]
}
}
}

View File

@@ -189,19 +189,20 @@ async def search_conversations(
config, get_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
# Filter out conversations older than max_age
now = datetime.now(timezone.utc)
max_age = config.conversation_max_age_seconds
filtered_results = [
conversation for conversation in conversation_metadata_result_set.results
if hasattr(conversation, 'created_at') and
(now - conversation.created_at.replace(tzinfo=timezone.utc)).total_seconds() <= max_age
conversation
for conversation in conversation_metadata_result_set.results
if hasattr(conversation, 'created_at')
and (now - conversation.created_at.replace(tzinfo=timezone.utc)).total_seconds()
<= max_age
]
conversation_ids = set(
conversation.conversation_id
for conversation in filtered_results
conversation.conversation_id for conversation in filtered_results
)
running_conversations = await conversation_manager.get_running_agent_loops(
get_user_id(request), set(conversation_ids)