Fix mypy errors in agenthub directory (#6811)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
@@ -10,17 +10,24 @@ from openhands.events.action import (
 class BrowsingResponseParser(ResponseParser):
-    def __init__(self):
+    def __init__(self) -> None:
         # Need to pay attention to the item order in self.action_parsers
         super().__init__()
         self.action_parsers = [BrowsingActionParserMessage()]
         self.default_parser = BrowsingActionParserBrowseInteractive()
 
-    def parse(self, response: str) -> Action:
-        action_str = self.parse_response(response)
+    def parse(
+        self, response: str | dict[str, list[dict[str, dict[str, str | None]]]]
+    ) -> Action:
+        if isinstance(response, str):
+            action_str = response
+        else:
+            action_str = self.parse_response(response)
         return self.parse_action(action_str)
 
-    def parse_response(self, response) -> str:
+    def parse_response(
+        self, response: dict[str, list[dict[str, dict[str, str | None]]]]
+    ) -> str:
         action_str = response['choices'][0]['message']['content']
         if action_str is None:
             return ''
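
The widened parse signature above is a common way to satisfy mypy when a method must accept either a raw string or a structured completion dict: annotate the parameter as a union and let isinstance narrow it. A minimal, self-contained sketch of the pattern (illustrative names only, not the OpenHands classes):

    def parse(response: str | dict) -> str:
        # isinstance() narrows the union, so both branches type-check
        if isinstance(response, str):
            return response
        return response['choices'][0]['message']['content'] or ''

    print(parse('click(12)'))
    print(parse({'choices': [{'message': {'content': 'noop()'}}]}))
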
@@ -47,9 +54,7 @@ class BrowsingActionParserMessage(ActionParser):
     - BrowseInteractiveAction(browser_actions) - unexpected response format, message back to user
     """
 
-    def __init__(
-        self,
-    ):
+    def __init__(self) -> None:
         pass
 
     def check_condition(self, action_str: str) -> bool:
@@ -69,9 +74,7 @@ class BrowsingActionParserBrowseInteractive(ActionParser):
     - BrowseInteractiveAction(browser_actions) - handle send message to user function call in BrowserGym
     """
 
-    def __init__(
-        self,
-    ):
+    def __init__(self) -> None:
         pass
 
     def check_condition(self, action_str: str) -> bool:
@@ -5,7 +5,7 @@ from warnings import warn
 import yaml
 
 
-def yaml_parser(message):
+def yaml_parser(message: str) -> tuple[dict, bool, str]:
     """Parse a yaml message for the retry function."""
     # saves gpt-3.5 from some yaml parsing errors
     message = re.sub(r':\s*\n(?=\S|\n)', ': ', message)
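
Most signature hunks in this commit follow the same recipe: add parameter and return annotations so that mypy, when configured to require typed definitions (presumably via a strictness option such as disallow_untyped_defs; the project's exact mypy configuration is not shown here), stops reporting these functions. A tiny before/after sketch with a stubbed body:

    # before: an untyped def, which strict mypy flags as
    # "Function is missing a type annotation"
    # def yaml_parser(message):
    #     ...

    # after: fully annotated, so call sites can be checked as well
    def yaml_parser(message: str) -> tuple[dict, bool, str]:
        return {}, False, 'stub: parsing omitted in this sketch'

    print(yaml_parser('key: value'))
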
@@ -22,7 +22,9 @@ def yaml_parser(message):
     return value, valid, retry_message
 
 
-def _compress_chunks(text, identifier, skip_list, split_regex='\n\n+'):
+def _compress_chunks(
+    text: str, identifier: str, skip_list: list[str], split_regex: str = '\n\n+'
+) -> tuple[dict[str, str], str]:
     """Compress a string by replacing redundant chunks by identifiers. Chunks are defined by the split_regex."""
     text_list = re.split(split_regex, text)
     text_list = [chunk.strip() for chunk in text_list]
@@ -44,7 +46,7 @@ def _compress_chunks(text, identifier, skip_list, split_regex='\n\n+'):
     return def_dict, compressed_text
 
 
-def compress_string(text):
+def compress_string(text: str) -> str:
     """Compress a string by replacing redundant paragraphs and lines with identifiers."""
     # Perform paragraph-level compression
     def_dict, compressed_text = _compress_chunks(
@@ -67,7 +69,7 @@ def compress_string(text):
     return definitions + '\n' + compressed_text
 
 
-def extract_html_tags(text, keys):
+def extract_html_tags(text: str, keys: list[str]) -> dict[str, list[str]]:
     """Extract the content within HTML tags for a list of keys.
 
     Parameters
@@ -102,7 +104,12 @@ class ParseError(Exception):
     pass
 
 
-def parse_html_tags_raise(text, keys=(), optional_keys=(), merge_multiple=False):
+def parse_html_tags_raise(
+    text: str,
+    keys: list[str] | None = None,
+    optional_keys: list[str] | None = None,
+    merge_multiple: bool = False,
+) -> dict[str, str]:
     """A version of parse_html_tags that raises an exception if the parsing is not successful."""
     content_dict, valid, retry_message = parse_html_tags(
         text, keys, optional_keys, merge_multiple=merge_multiple
@@ -112,7 +119,12 @@ def parse_html_tags_raise(text, keys=(), optional_keys=(), merge_multiple=False)
     return content_dict
 
 
-def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
+def parse_html_tags(
+    text: str,
+    keys: list[str] | None = None,
+    optional_keys: list[str] | None = None,
+    merge_multiple: bool = False,
+) -> tuple[dict[str, str], bool, str]:
     """Satisfy the parse api, extracts 1 match per key and validates that all keys are present
 
     Parameters
@@ -133,9 +145,12 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
     str
         A message to be displayed to the agent if the parsing was not successful.
     """
-    all_keys = tuple(keys) + tuple(optional_keys)
+    keys = keys or []
+    optional_keys = optional_keys or []
+    all_keys = list(keys) + list(optional_keys)
     content_dict = extract_html_tags(text, all_keys)
     retry_messages = []
+    result_dict: dict[str, str] = {}
 
     for key in all_keys:
         if key not in content_dict:
@@ -143,7 +158,6 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
             retry_messages.append(f'Missing the key <{key}> in the answer.')
         else:
             val = content_dict[key]
-            content_dict[key] = val[0]
             if len(val) > 1:
                 if not merge_multiple:
                     retry_messages.append(
@@ -151,8 +165,10 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
                     )
                 else:
                     # merge the multiple instances
-                    content_dict[key] = '\n'.join(val)
+                    result_dict[key] = '\n'.join(val)
+            else:
+                result_dict[key] = val[0]
 
     valid = len(retry_messages) == 0
     retry_message = '\n'.join(retry_messages)
-    return content_dict, valid, retry_message
+    return result_dict, valid, retry_message
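
The result_dict introduced above sidesteps a typical mypy complaint: content_dict maps each key to a list[str], so reassigning content_dict[key] = val[0] (a str) is an incompatible-assignment error. Collecting the single values in a separately annotated dict keeps both types honest. A rough sketch of the idea, not the project's code:

    content_dict: dict[str, list[str]] = {'answer': ['42', '43'], 'note': ['ok']}
    result_dict: dict[str, str] = {}

    for key, val in content_dict.items():
        # writing into a str-valued dict avoids mutating the list-valued one
        result_dict[key] = '\n'.join(val) if len(val) > 1 else val[0]

    print(result_dict)
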
@@ -475,8 +475,9 @@ def combine_thought(action: Action, thought: str) -> Action:
 def response_to_actions(response: ModelResponse) -> list[Action]:
     actions: list[Action] = []
     assert len(response.choices) == 1, 'Only one choice is supported for now'
-    assistant_msg = response.choices[0].message
-    if assistant_msg.tool_calls:
+    choice = response.choices[0]
+    assistant_msg = choice.message
+    if hasattr(assistant_msg, 'tool_calls') and assistant_msg.tool_calls:
         # Check if there's assistant_msg.content. If so, add it to the thought
         thought = ''
         if isinstance(assistant_msg.content, str):
@@ -592,7 +593,10 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
             actions.append(action)
     else:
         actions.append(
-            MessageAction(content=assistant_msg.content, wait_for_response=True)
+            MessageAction(
+                content=str(assistant_msg.content) if assistant_msg.content else '',
+                wait_for_response=True,
+            )
         )
 
     assert len(actions) >= 1
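
In the hunk above, assistant_msg.content may be None (or a non-string content object), while MessageAction expects a plain str, so the value is coerced with str(...) and defaulted to '' when empty. A minimal sketch of the same guard, using a stand-in dataclass rather than the real OpenHands/litellm types:

    from dataclasses import dataclass

    @dataclass
    class MessageAction:
        content: str  # annotated as str, so None must be filtered out
        wait_for_response: bool = False

    def to_message(raw_content: str | None) -> MessageAction:
        return MessageAction(
            content=str(raw_content) if raw_content else '',
            wait_for_response=True,
        )

    print(to_message(None))
    print(to_message('Done.'))
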
@@ -22,7 +22,7 @@ def parse_response(orig_response: str) -> Action:
     return action_from_dict(action_dict)
 
 
-def to_json(obj, **kwargs):
+def to_json(obj: object, **kwargs: dict) -> str:
     """Serialize an object to str format"""
     return json.dumps(obj, **kwargs)
 
@@ -32,7 +32,9 @@ class MicroAgent(Agent):
     prompt = ''
     agent_definition: dict = {}
 
-    def history_to_json(self, history: list[Event], max_events: int = 20, **kwargs):
+    def history_to_json(
+        self, history: list[Event], max_events: int = 20, **kwargs: dict
+    ) -> str:
         """
         Serialize and simplify history to str format
         """
@@ -60,7 +62,7 @@ class MicroAgent(Agent):
         super().__init__(llm, config)
         if 'name' not in self.agent_definition:
             raise ValueError('Agent definition must contain a name')
-        self.prompt_template = Environment(loader=BaseLoader).from_string(self.prompt)
+        self.prompt_template = Environment(loader=BaseLoader()).from_string(self.prompt)
         self.delegates = all_microagents.copy()
         del self.delegates[self.agent_definition['name']]
 
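
The BaseLoader change above is more than an annotation fix: Jinja2's Environment expects a loader instance rather than the loader class, so BaseLoader() is what both the runtime API and the type stubs want. For reference, a small standalone usage of that pattern:

    from jinja2 import BaseLoader, Environment

    template = Environment(loader=BaseLoader()).from_string('Hello {{ name }}!')
    print(template.render(name='OpenHands'))
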
@@ -74,7 +76,7 @@ class MicroAgent(Agent):
             delegates=self.delegates,
             latest_user_message=last_user_message,
         )
-        content = [TextContent(text=prompt)]
+        content: list[TextContent | ImageContent] = [TextContent(text=prompt)]
         if self.llm.vision_is_active() and last_image_urls:
             content.append(ImageContent(image_urls=last_image_urls))
         message = Message(role='user', content=content)
@@ -29,7 +29,9 @@ def get_error_prefix(obs: BrowserOutputObservation) -> str:
     return f'## Error from previous action:\n{obs.last_browser_action_error}\n'
 
 
-def create_goal_prompt(goal: str, image_urls: list[str] | None):
+def create_goal_prompt(
+    goal: str, image_urls: list[str] | None
+) -> tuple[str, list[str]]:
     goal_txt: str = f"""\
 # Instructions
 Review the current state of the page and all other information to find the best possible next action to accomplish your goal. Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions.
@@ -52,7 +54,7 @@ def create_observation_prompt(
     focused_element: str,
     error_prefix: str,
     som_screenshot: str | None,
-):
+) -> tuple[str, str | None]:
     txt_observation = f"""
 # Observation of current step:
 {tabs}{axtree_txt}{focused_element}{error_prefix}
@@ -273,7 +275,9 @@ Note:
     observation_txt, som_screenshot = create_observation_prompt(
         cur_axtree_txt, tabs, focused_element, error_prefix, set_of_marks
     )
-    human_prompt = [TextContent(type='text', text=goal_txt)]
+    human_prompt: list[TextContent | ImageContent] = [
+        TextContent(type='text', text=goal_txt)
+    ]
     if len(goal_images) > 0:
         human_prompt.append(ImageContent(image_urls=goal_images))
     human_prompt.append(TextContent(type='text', text=observation_txt))
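
Annotating human_prompt (and content in the MicroAgent hunk earlier) as list[TextContent | ImageContent] matters because mypy infers list[TextContent] from the literal and then rejects appending an ImageContent. A generic sketch of that inference rule, with stand-in classes:

    class TextContent:
        def __init__(self, text: str) -> None:
            self.text = text

    class ImageContent:
        def __init__(self, image_urls: list[str]) -> None:
            self.image_urls = image_urls

    # Without the explicit union annotation, mypy infers list[TextContent]
    # and flags the append below; the annotation allows both element types.
    content: list[TextContent | ImageContent] = [TextContent(text='goal')]
    content.append(ImageContent(image_urls=['https://example.com/a.png']))
    print(len(content))
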
@@ -168,10 +168,12 @@ class DockerRuntimeBuilder(RuntimeBuilder):
             )
 
         except subprocess.CalledProcessError as e:
-            logger.error(f'Image build failed:\n{e}') # TODO: {e} is empty
+            logger.error(f'Image build failed:\n{e}')  # TODO: {e} is empty
             logger.error(f'Command output:\n{e.output}')
             if self.rolling_logger.is_enabled():
-                logger.error("Docker build output:\n" + self.rolling_logger.all_lines) # Show the error
+                logger.error(
+                    'Docker build output:\n' + self.rolling_logger.all_lines
+                )  # Show the error
             raise
 
         except subprocess.TimeoutExpired:
@@ -270,7 +270,10 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
 
         tunnel = self.sandbox.tunnels()[self._vscode_port]
         tunnel_url = tunnel.url
-        self._vscode_url = tunnel_url + f'/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}'
+        self._vscode_url = (
+            tunnel_url
+            + f'/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}'
+        )
 
         self.log(
             'debug',
@@ -13,4 +13,4 @@ function deactivate() {}
 module.exports = {
   activate,
   deactivate
-}
\ No newline at end of file
+}
@@ -20,4 +20,4 @@
       "title": "Hello World from OpenHands"
     }]
   }
-}
\ No newline at end of file
+}
@@ -189,19 +189,20 @@ async def search_conversations(
         config, get_user_id(request)
     )
     conversation_metadata_result_set = await conversation_store.search(page_id, limit)
 
     # Filter out conversations older than max_age
     now = datetime.now(timezone.utc)
     max_age = config.conversation_max_age_seconds
     filtered_results = [
-        conversation for conversation in conversation_metadata_result_set.results
-        if hasattr(conversation, 'created_at') and
-        (now - conversation.created_at.replace(tzinfo=timezone.utc)).total_seconds() <= max_age
+        conversation
+        for conversation in conversation_metadata_result_set.results
+        if hasattr(conversation, 'created_at')
+        and (now - conversation.created_at.replace(tzinfo=timezone.utc)).total_seconds()
+        <= max_age
     ]
 
     conversation_ids = set(
-        conversation.conversation_id
-        for conversation in filtered_results
+        conversation.conversation_id for conversation in filtered_results
     )
     running_conversations = await conversation_manager.get_running_agent_loops(
         get_user_id(request), set(conversation_ids)
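
The reformatted filter above keeps the same behaviour: keep only conversations whose created_at is no older than config.conversation_max_age_seconds, comparing in UTC. A small self-contained sketch of that age check (illustrative values, not the server's configuration):

    from datetime import datetime, timedelta, timezone

    max_age = 3 * 24 * 60 * 60  # e.g. three days, expressed in seconds
    now = datetime.now(timezone.utc)
    created_at = now - timedelta(days=1)

    age_seconds = (now - created_at.replace(tzinfo=timezone.utc)).total_seconds()
    print(age_seconds <= max_age)  # True: one day old is within the window
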