Fix type checking errors in resolver directory (#6738)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig
2025-02-18 20:13:33 -05:00
committed by GitHub
parent 1a7003a705
commit f4e5fb2873
13 changed files with 209 additions and 230 deletions

View File

@@ -22,28 +22,28 @@ class GithubIssueHandler(IssueHandlerInterface):
self.clone_url = self.get_clone_url() self.clone_url = self.get_clone_url()
self.headers = self.get_headers() self.headers = self.get_headers()
def set_owner(self, owner: str): def set_owner(self, owner: str) -> None:
self.owner = owner self.owner = owner
def get_headers(self): def get_headers(self) -> dict[str, str]:
return { return {
'Authorization': f'token {self.token}', 'Authorization': f'token {self.token}',
'Accept': 'application/vnd.github.v3+json', 'Accept': 'application/vnd.github.v3+json',
} }
def get_base_url(self): def get_base_url(self) -> str:
return f'https://api.github.com/repos/{self.owner}/{self.repo}' return f'https://api.github.com/repos/{self.owner}/{self.repo}'
def get_authorize_url(self): def get_authorize_url(self) -> str:
return f'https://{self.username}:{self.token}@github.com/' return f'https://{self.username}:{self.token}@github.com/'
def get_branch_url(self, branch_name: str): def get_branch_url(self, branch_name: str) -> str:
return self.get_base_url() + f'/branches/{branch_name}' return self.get_base_url() + f'/branches/{branch_name}'
def get_download_url(self): def get_download_url(self) -> str:
return f'{self.base_url}/issues' return f'{self.base_url}/issues'
def get_clone_url(self): def get_clone_url(self) -> str:
username_and_token = ( username_and_token = (
f'{self.username}:{self.token}' f'{self.username}:{self.token}'
if self.username if self.username
@@ -51,10 +51,10 @@ class GithubIssueHandler(IssueHandlerInterface):
) )
return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git' return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git'
def get_graphql_url(self): def get_graphql_url(self) -> str:
return 'https://api.github.com/graphql' return 'https://api.github.com/graphql'
def get_compare_url(self, branch_name: str): def get_compare_url(self, branch_name: str) -> str:
return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1' return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
def get_converted_issues( def get_converted_issues(
@@ -186,7 +186,7 @@ class GithubIssueHandler(IssueHandlerInterface):
print(f'Branch {branch_name} exists: {exists}') print(f'Branch {branch_name} exists: {exists}')
return exists return exists
def get_branch_name(self, base_branch_name: str): def get_branch_name(self, base_branch_name: str) -> str:
branch_name = base_branch_name branch_name = base_branch_name
attempt = 1 attempt = 1
while self.branch_exists(branch_name): while self.branch_exists(branch_name):
@@ -194,7 +194,7 @@ class GithubIssueHandler(IssueHandlerInterface):
branch_name = f'{base_branch_name}-try{attempt}' branch_name = f'{base_branch_name}-try{attempt}'
return branch_name return branch_name
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str): def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
# Opting for graphql as REST API doesn't allow reply to replies in comment threads # Opting for graphql as REST API doesn't allow reply to replies in comment threads
query = """ query = """
mutation($body: String!, $pullRequestReviewThreadId: ID!) { mutation($body: String!, $pullRequestReviewThreadId: ID!) {
@@ -221,15 +221,18 @@ class GithubIssueHandler(IssueHandlerInterface):
) )
response.raise_for_status() response.raise_for_status()
def get_pull_url(self, pr_number: int): def get_pull_url(self, pr_number: int) -> str:
return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}' return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}'
def get_default_branch_name(self) -> str: def get_default_branch_name(self) -> str:
response = requests.get(f'{self.base_url}', headers=self.headers) response = requests.get(f'{self.base_url}', headers=self.headers)
response.raise_for_status() response.raise_for_status()
return response.json()['default_branch'] data = response.json()
return str(data['default_branch'])
def create_pull_request(self, data=dict) -> dict: def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
if data is None:
data = {}
response = requests.post( response = requests.post(
f'{self.base_url}/pulls', headers=self.headers, json=data f'{self.base_url}/pulls', headers=self.headers, json=data
) )
@@ -240,9 +243,9 @@ class GithubIssueHandler(IssueHandlerInterface):
) )
response.raise_for_status() response.raise_for_status()
pr_data = response.json() pr_data = response.json()
return pr_data return dict(pr_data)
def request_reviewers(self, reviewer: str, pr_number: int): def request_reviewers(self, reviewer: str, pr_number: int) -> None:
review_data = {'reviewers': [reviewer]} review_data = {'reviewers': [reviewer]}
review_response = requests.post( review_response = requests.post(
f'{self.base_url}/pulls/{pr_number}/requested_reviewers', f'{self.base_url}/pulls/{pr_number}/requested_reviewers',
@@ -254,7 +257,7 @@ class GithubIssueHandler(IssueHandlerInterface):
f'Warning: Failed to request review from {reviewer}: {review_response.text}' f'Warning: Failed to request review from {reviewer}: {review_response.text}'
) )
def send_comment_msg(self, issue_number: int, msg: str): def send_comment_msg(self, issue_number: int, msg: str) -> None:
"""Send a comment message to a GitHub issue or pull request. """Send a comment message to a GitHub issue or pull request.
Args: Args:
@@ -282,8 +285,8 @@ class GithubIssueHandler(IssueHandlerInterface):
review_comments: list[str] | None, review_comments: list[str] | None,
review_threads: list[ReviewThread], review_threads: list[ReviewThread],
thread_comments: list[str] | None, thread_comments: list[str] | None,
): ) -> list[str]:
pass return []
class GithubPRHandler(GithubIssueHandler): class GithubPRHandler(GithubIssueHandler):
@@ -487,7 +490,7 @@ class GithubPRHandler(GithubIssueHandler):
review_comments: list[str] | None, review_comments: list[str] | None,
review_threads: list[ReviewThread], review_threads: list[ReviewThread],
thread_comments: list[str] | None, thread_comments: list[str] | None,
): ) -> list[str]:
new_issue_references = [] new_issue_references = []
if issue_body: if issue_body:

View File

@@ -23,38 +23,38 @@ class GitlabIssueHandler(IssueHandlerInterface):
self.clone_url = self.get_clone_url() self.clone_url = self.get_clone_url()
self.headers = self.get_headers() self.headers = self.get_headers()
def set_owner(self, owner: str): def set_owner(self, owner: str) -> None:
self.owner = owner self.owner = owner
def get_headers(self): def get_headers(self) -> dict[str, str]:
return { return {
'Authorization': f'Bearer {self.token}', 'Authorization': f'Bearer {self.token}',
'Accept': 'application/json', 'Accept': 'application/json',
} }
def get_base_url(self): def get_base_url(self) -> str:
project_path = quote(f'{self.owner}/{self.repo}', safe="") project_path = quote(f'{self.owner}/{self.repo}', safe='')
return f'https://gitlab.com/api/v4/projects/{project_path}' return f'https://gitlab.com/api/v4/projects/{project_path}'
def get_authorize_url(self): def get_authorize_url(self) -> str:
return f'https://{self.username}:{self.token}@gitlab.com/' return f'https://{self.username}:{self.token}@gitlab.com/'
def get_branch_url(self, branch_name: str): def get_branch_url(self, branch_name: str) -> str:
return self.get_base_url() + f'/repository/branches/{branch_name}' return self.get_base_url() + f'/repository/branches/{branch_name}'
def get_download_url(self): def get_download_url(self) -> str:
return f'{self.base_url}/issues' return f'{self.base_url}/issues'
def get_clone_url(self): def get_clone_url(self) -> str:
username_and_token = self.token username_and_token = self.token
if self.username: if self.username:
username_and_token = f'{self.username}:{self.token}' username_and_token = f'{self.username}:{self.token}'
return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git' return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git'
def get_graphql_url(self): def get_graphql_url(self) -> str:
return 'https://gitlab.com/api/graphql' return 'https://gitlab.com/api/graphql'
def get_compare_url(self, branch_name: str): def get_compare_url(self, branch_name: str) -> str:
return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}' return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
def get_converted_issues( def get_converted_issues(
@@ -189,7 +189,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
print(f'Branch {branch_name} exists: {exists}') print(f'Branch {branch_name} exists: {exists}')
return exists return exists
def get_branch_name(self, base_branch_name: str): def get_branch_name(self, base_branch_name: str) -> str:
branch_name = base_branch_name branch_name = base_branch_name
attempt = 1 attempt = 1
while self.branch_exists(branch_name): while self.branch_exists(branch_name):
@@ -197,7 +197,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
branch_name = f'{base_branch_name}-try{attempt}' branch_name = f'{base_branch_name}-try{attempt}'
return branch_name return branch_name
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str): def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
response = requests.get( response = requests.get(
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}', f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}',
headers=self.headers, headers=self.headers,
@@ -216,7 +216,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
) )
response.raise_for_status() response.raise_for_status()
def get_pull_url(self, pr_number: int): def get_pull_url(self, pr_number: int) -> str:
return ( return (
f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}' f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
) )
@@ -224,9 +224,12 @@ class GitlabIssueHandler(IssueHandlerInterface):
def get_default_branch_name(self) -> str: def get_default_branch_name(self) -> str:
response = requests.get(f'{self.base_url}', headers=self.headers) response = requests.get(f'{self.base_url}', headers=self.headers)
response.raise_for_status() response.raise_for_status()
return response.json()['default_branch'] data = response.json()
return str(data['default_branch'])
def create_pull_request(self, data=dict) -> dict: def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
if data is None:
data = {}
response = requests.post( response = requests.post(
f'{self.base_url}/merge_requests', headers=self.headers, json=data f'{self.base_url}/merge_requests', headers=self.headers, json=data
) )
@@ -243,9 +246,9 @@ class GitlabIssueHandler(IssueHandlerInterface):
if 'iid' in pr_data: if 'iid' in pr_data:
pr_data['number'] = pr_data['iid'] pr_data['number'] = pr_data['iid']
return pr_data return dict(pr_data)
def request_reviewers(self, reviewer: str, pr_number: int): def request_reviewers(self, reviewer: str, pr_number: int) -> None:
response = requests.get( response = requests.get(
f'https://gitlab.com/api/v4/users?username={reviewer}', f'https://gitlab.com/api/v4/users?username={reviewer}',
headers=self.headers, headers=self.headers,
@@ -264,7 +267,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
f'Warning: Failed to request review from {reviewer}: {review_response.text}' f'Warning: Failed to request review from {reviewer}: {review_response.text}'
) )
def send_comment_msg(self, issue_number: int, msg: str): def send_comment_msg(self, issue_number: int, msg: str) -> None:
"""Send a comment message to a GitHub issue or pull request. """Send a comment message to a GitHub issue or pull request.
Args: Args:
@@ -292,8 +295,8 @@ class GitlabIssueHandler(IssueHandlerInterface):
review_comments: list[str] | None, review_comments: list[str] | None,
review_threads: list[ReviewThread], review_threads: list[ReviewThread],
thread_comments: list[str] | None, thread_comments: list[str] | None,
): ) -> list[str]:
pass return []
class GitlabPRHandler(GitlabIssueHandler): class GitlabPRHandler(GitlabIssueHandler):
@@ -479,7 +482,7 @@ class GitlabPRHandler(GitlabIssueHandler):
review_comments: list[str] | None, review_comments: list[str] | None,
review_threads: list[ReviewThread], review_threads: list[ReviewThread],
thread_comments: list[str] | None, thread_comments: list[str] | None,
): ) -> list[str]:
new_issue_references = [] new_issue_references = []
if issue_body: if issue_body:

View File

@@ -26,7 +26,7 @@ class Issue(BaseModel):
class IssueHandlerInterface(ABC): class IssueHandlerInterface(ABC):
@abstractmethod @abstractmethod
def set_owner(self, owner: str): def set_owner(self, owner: str) -> None:
pass pass
@abstractmethod @abstractmethod
@@ -40,43 +40,43 @@ class IssueHandlerInterface(ABC):
pass pass
@abstractmethod @abstractmethod
def get_base_url(self): def get_base_url(self) -> str:
pass pass
@abstractmethod @abstractmethod
def get_branch_url(self, branch_name): def get_branch_url(self, branch_name: str) -> str:
pass pass
@abstractmethod @abstractmethod
def get_download_url(self): def get_download_url(self) -> str:
pass pass
@abstractmethod @abstractmethod
def get_clone_url(self): def get_clone_url(self) -> str:
pass pass
@abstractmethod @abstractmethod
def get_pull_url(self, pr_number: int): def get_pull_url(self, pr_number: int) -> str:
pass pass
@abstractmethod @abstractmethod
def get_graphql_url(self): def get_graphql_url(self) -> str:
pass pass
@abstractmethod @abstractmethod
def get_headers(self): def get_headers(self) -> dict[str, str]:
pass pass
@abstractmethod @abstractmethod
def get_compare_url(self, branch_name): def get_compare_url(self, branch_name: str) -> str:
pass pass
@abstractmethod @abstractmethod
def get_branch_name(self, base_branch_name: str): def get_branch_name(self, base_branch_name: str) -> str:
pass pass
@abstractmethod @abstractmethod
def get_default_branch_name(self): def get_default_branch_name(self) -> str:
pass pass
@abstractmethod @abstractmethod
@@ -84,23 +84,25 @@ class IssueHandlerInterface(ABC):
pass pass
@abstractmethod @abstractmethod
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str): def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
pass pass
@abstractmethod @abstractmethod
def send_comment_msg(self, issue_number: int, msg: str): def send_comment_msg(self, issue_number: int, msg: str) -> None:
pass pass
@abstractmethod @abstractmethod
def get_authorize_url(self): def get_authorize_url(self) -> str:
pass pass
@abstractmethod @abstractmethod
def create_pull_request(self, data=dict) -> dict: def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
pass if data is None:
data = {}
raise NotImplementedError
@abstractmethod @abstractmethod
def request_reviewers(self, reviewer: str, pr_number: int): def request_reviewers(self, reviewer: str, pr_number: int) -> None:
pass pass
@abstractmethod @abstractmethod
@@ -112,7 +114,7 @@ class IssueHandlerInterface(ABC):
review_comments: list[str] | None, review_comments: list[str] | None,
review_threads: list[ReviewThread], review_threads: list[ReviewThread],
thread_comments: list[str] | None, thread_comments: list[str] | None,
): ) -> list[str]:
pass pass
@abstractmethod @abstractmethod

View File

@@ -25,7 +25,7 @@ class ServiceContext:
if llm_config is not None: if llm_config is not None:
self.llm = LLM(llm_config) self.llm = LLM(llm_config)
def set_strategy(self, strategy): def set_strategy(self, strategy: IssueHandlerInterface) -> None:
self._strategy = strategy self._strategy = strategy
@@ -36,7 +36,7 @@ class ServiceContextPR(ServiceContext):
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig): def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig):
super().__init__(strategy, llm_config) super().__init__(strategy, llm_config)
def get_clone_url(self): def get_clone_url(self) -> str:
return self._strategy.get_clone_url() return self._strategy.get_clone_url()
def download_issues(self) -> list[Any]: def download_issues(self) -> list[Any]:
@@ -266,31 +266,31 @@ class ServiceContextIssue(ServiceContext):
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None): def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
super().__init__(strategy, llm_config) super().__init__(strategy, llm_config)
def get_base_url(self): def get_base_url(self) -> str:
return self._strategy.get_base_url() return self._strategy.get_base_url()
def get_branch_url(self, branch_name): def get_branch_url(self, branch_name: str) -> str:
return self._strategy.get_branch_url(branch_name) return self._strategy.get_branch_url(branch_name)
def get_download_url(self): def get_download_url(self) -> str:
return self._strategy.get_download_url() return self._strategy.get_download_url()
def get_clone_url(self): def get_clone_url(self) -> str:
return self._strategy.get_clone_url() return self._strategy.get_clone_url()
def get_graphql_url(self): def get_graphql_url(self) -> str:
return self._strategy.get_graphql_url() return self._strategy.get_graphql_url()
def get_headers(self): def get_headers(self) -> dict[str, str]:
return self._strategy.get_headers() return self._strategy.get_headers()
def get_authorize_url(self): def get_authorize_url(self) -> str:
return self._strategy.get_authorize_url() return self._strategy.get_authorize_url()
def get_pull_url(self, pr_number: int): def get_pull_url(self, pr_number: int) -> str:
return self._strategy.get_pull_url(pr_number) return self._strategy.get_pull_url(pr_number)
def get_compare_url(self, branch_name: str): def get_compare_url(self, branch_name: str) -> str:
return self._strategy.get_compare_url(branch_name) return self._strategy.get_compare_url(branch_name)
def download_issues(self) -> list[Any]: def download_issues(self) -> list[Any]:
@@ -299,25 +299,27 @@ class ServiceContextIssue(ServiceContext):
def get_branch_name( def get_branch_name(
self, self,
base_branch_name: str, base_branch_name: str,
): ) -> str:
return self._strategy.get_branch_name(base_branch_name) return self._strategy.get_branch_name(base_branch_name)
def branch_exists(self, branch_name: str): def branch_exists(self, branch_name: str) -> bool:
return self._strategy.branch_exists(branch_name) return self._strategy.branch_exists(branch_name)
def get_default_branch_name(self) -> str: def get_default_branch_name(self) -> str:
return self._strategy.get_default_branch_name() return self._strategy.get_default_branch_name()
def create_pull_request(self, data=dict): def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
if data is None:
data = {}
return self._strategy.create_pull_request(data) return self._strategy.create_pull_request(data)
def request_reviewers(self, reviewer: str, pr_number: int): def request_reviewers(self, reviewer: str, pr_number: int) -> None:
return self._strategy.request_reviewers(reviewer, pr_number) return self._strategy.request_reviewers(reviewer, pr_number)
def reply_to_comment(self, pr_number, comment_id, reply): def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
return self._strategy.reply_to_comment(pr_number, comment_id, reply) return self._strategy.reply_to_comment(pr_number, comment_id, reply)
def send_comment_msg(self, issue_number: int, msg: str): def send_comment_msg(self, issue_number: int, msg: str) -> None:
return self._strategy.send_comment_msg(issue_number, msg) return self._strategy.send_comment_msg(issue_number, msg)
def get_issue_comments( def get_issue_comments(

View File

@@ -5,10 +5,13 @@ import subprocess
import tempfile import tempfile
from .exceptions import HunkApplyException, SubprocessException from .exceptions import HunkApplyException, SubprocessException
from .patch import Change, diffobj
from .snippets import remove, which from .snippets import remove, which
def _apply_diff_with_subprocess(diff, lines, reverse=False): def _apply_diff_with_subprocess(
diff: diffobj, lines: list[str], reverse: bool = False
) -> tuple[list[str], list[str] | None]:
# call out to patch program # call out to patch program
patchexec = which('patch') patchexec = which('patch')
if not patchexec: if not patchexec:
@@ -63,21 +66,21 @@ def _apply_diff_with_subprocess(diff, lines, reverse=False):
return lines, rejlines return lines, rejlines
def _reverse(changes): def _reverse(changes: list[Change]) -> list[Change]:
def _reverse_change(c): def _reverse_change(c: Change) -> Change:
return c._replace(old=c.new, new=c.old) return c._replace(old=c.new, new=c.old)
return [_reverse_change(c) for c in changes] return [_reverse_change(c) for c in changes]
def apply_diff(diff, text, reverse=False, use_patch=False): def apply_diff(
try: diff: diffobj, text: str | list[str], reverse: bool = False, use_patch: bool = False
lines = text.splitlines() ) -> list[str]:
except AttributeError: lines = text.splitlines() if isinstance(text, str) else list(text)
lines = list(text)
if use_patch: if use_patch:
return _apply_diff_with_subprocess(diff, lines, reverse) lines, _ = _apply_diff_with_subprocess(diff, lines, reverse)
return lines
n_lines = len(lines) n_lines = len(lines)

View File

@@ -1,31 +1,31 @@
class PatchingException(Exception): class PatchingException(Exception):
pass pass
class HunkException(PatchingException): class HunkException(PatchingException):
def __init__(self, msg, hunk=None): def __init__(self, msg: str, hunk: int | None = None) -> None:
self.hunk = hunk self.hunk = hunk
if hunk is not None: if hunk is not None:
super(HunkException, self).__init__( super(HunkException, self).__init__(
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk) '{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
) )
else: else:
super(HunkException, self).__init__(msg) super(HunkException, self).__init__(msg)
class ApplyException(PatchingException): class ApplyException(PatchingException):
pass pass
class SubprocessException(ApplyException): class SubprocessException(ApplyException):
def __init__(self, msg, code): def __init__(self, msg: str, code: int) -> None:
super(SubprocessException, self).__init__(msg) super(SubprocessException, self).__init__(msg)
self.code = code self.code = code
class HunkApplyException(HunkException, ApplyException, ValueError): class HunkApplyException(HunkException, ApplyException, ValueError):
pass pass
class ParseException(HunkException, ValueError): class ParseException(HunkException, ValueError):
pass pass

View File

@@ -3,6 +3,7 @@ import base64
import re import re
import zlib import zlib
from collections import namedtuple from collections import namedtuple
from typing import Iterable
from . import exceptions from . import exceptions
from .snippets import findall_regex, split_by_regex from .snippets import findall_regex, split_by_regex
@@ -71,11 +72,8 @@ cvs_header_timestamp_colon = re.compile(r':([\d.]+)\t(.+)')
old_cvs_diffcmd_header = re.compile('^diff.* (.+):(.*) (.+):(.*)$') old_cvs_diffcmd_header = re.compile('^diff.* (.+):(.*) (.+):(.*)$')
def parse_patch(text): def parse_patch(text: str | list[str]) -> Iterable[diffobj]:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
# maybe use this to nuke all of those line endings? # maybe use this to nuke all of those line endings?
# lines = [x.splitlines()[0] for x in lines] # lines = [x.splitlines()[0] for x in lines]
@@ -104,18 +102,15 @@ def parse_patch(text):
yield diffobj(header=h, changes=d, text=difftext) yield diffobj(header=h, changes=d, text=difftext)
def parse_header(text): def parse_header(text: str | list[str]) -> header | None:
h = parse_scm_header(text) h = parse_scm_header(text)
if h is None: if h is None:
h = parse_diff_header(text) h = parse_diff_header(text)
return h return h
def parse_scm_header(text): def parse_scm_header(text: str | list[str]) -> header | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
check = [ check = [
(git_header_index, parse_git_header), (git_header_index, parse_git_header),
@@ -154,11 +149,8 @@ def parse_scm_header(text):
return None return None
def parse_diff_header(text): def parse_diff_header(text: str | list[str]) -> header | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
check = [ check = [
(unified_header_new_line, parse_unified_header), (unified_header_new_line, parse_unified_header),
@@ -178,10 +170,10 @@ def parse_diff_header(text):
return None # no header? return None # no header?
def parse_diff(text): def parse_diff(text: str | list[str]) -> list[Change] | None:
try: if isinstance(text, str):
lines = text.splitlines() lines = text.splitlines()
except AttributeError: else:
lines = text lines = text
check = [ check = [
@@ -200,11 +192,8 @@ def parse_diff(text):
return None return None
def parse_git_header(text): def parse_git_header(text: str | list[str]) -> header | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
old_version = None old_version = None
new_version = None new_version = None
@@ -275,11 +264,8 @@ def parse_git_header(text):
return None return None
def parse_svn_header(text): def parse_svn_header(text: str | list[str]) -> header | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
headers = findall_regex(lines, svn_header_index) headers = findall_regex(lines, svn_header_index)
if len(headers) == 0: if len(headers) == 0:
@@ -346,11 +332,8 @@ def parse_svn_header(text):
return None return None
def parse_cvs_header(text): def parse_cvs_header(text: str | list[str]) -> header | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
headers = findall_regex(lines, cvs_header_rcs) headers = findall_regex(lines, cvs_header_rcs)
headers_old = findall_regex(lines, old_cvs_diffcmd_header) headers_old = findall_regex(lines, old_cvs_diffcmd_header)
@@ -430,11 +413,8 @@ def parse_cvs_header(text):
return None return None
def parse_diffcmd_header(text): def parse_diffcmd_header(text: str | list[str]) -> header | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
headers = findall_regex(lines, diffcmd_header) headers = findall_regex(lines, diffcmd_header)
if len(headers) == 0: if len(headers) == 0:
@@ -454,11 +434,8 @@ def parse_diffcmd_header(text):
return None return None
def parse_unified_header(text): def parse_unified_header(text: str | list[str]) -> header | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
headers = findall_regex(lines, unified_header_new_line) headers = findall_regex(lines, unified_header_new_line)
if len(headers) == 0: if len(headers) == 0:
@@ -490,11 +467,8 @@ def parse_unified_header(text):
return None return None
def parse_context_header(text): def parse_context_header(text: str | list[str]) -> header | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
headers = findall_regex(lines, context_header_old_line) headers = findall_regex(lines, context_header_old_line)
if len(headers) == 0: if len(headers) == 0:
@@ -526,11 +500,8 @@ def parse_context_header(text):
return None return None
def parse_default_diff(text): def parse_default_diff(text: str | list[str]) -> list[Change] | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
old = 0 old = 0
new = 0 new = 0
@@ -582,11 +553,8 @@ def parse_default_diff(text):
return None return None
def parse_unified_diff(text): def parse_unified_diff(text: str | list[str]) -> list[Change] | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
old = 0 old = 0
new = 0 new = 0
@@ -652,11 +620,8 @@ def parse_unified_diff(text):
return None return None
def parse_context_diff(text): def parse_context_diff(text: str | list[str]) -> list[Change] | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
old = 0 old = 0
new = 0 new = 0
@@ -795,11 +760,8 @@ def parse_context_diff(text):
return None return None
def parse_ed_diff(text): def parse_ed_diff(text: str | list[str]) -> list[Change] | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
old = 0 old = 0
j = 0 j = 0
@@ -878,12 +840,9 @@ def parse_ed_diff(text):
return None return None
def parse_rcs_ed_diff(text): def parse_rcs_ed_diff(text: str | list[str]) -> list[Change] | None:
# much like forward ed, but no 'c' type # much like forward ed, but no 'c' type
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
old = 0 old = 0
j = 0 j = 0
@@ -905,7 +864,7 @@ def parse_rcs_ed_diff(text):
hunk_kind = o.group(1) hunk_kind = o.group(1)
old = int(o.group(2)) old = int(o.group(2))
size = int(o.group(3)) size = int(o.group(3)) if o.group(3) else 0
if hunk_kind == 'a': if hunk_kind == 'a':
old += total_change_size + 1 old += total_change_size + 1
@@ -926,15 +885,11 @@ def parse_rcs_ed_diff(text):
if len(changes) > 0: if len(changes) > 0:
return changes return changes
return None return None
def parse_git_binary_diff(text): def parse_git_binary_diff(text: str | list[str]) -> list[Change] | None:
try: lines = text.splitlines() if isinstance(text, str) else text
lines = text.splitlines()
except AttributeError:
lines = text
changes: list[Change] = list() changes: list[Change] = list()

View File

@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import re
from shutil import rmtree from shutil import rmtree
def remove(path): def remove(path: str) -> None:
if os.path.exists(path): if os.path.exists(path):
if os.path.isdir(path): if os.path.isdir(path):
rmtree(path) rmtree(path)
@@ -13,7 +14,7 @@ def remove(path):
# find all indices of a list of strings that match a regex # find all indices of a list of strings that match a regex
def findall_regex(items, regex): def findall_regex(items: list[str], regex: re.Pattern[str]) -> list[int]:
found = list() found = list()
for i in range(0, len(items)): for i in range(0, len(items)):
k = regex.match(items[i]) k = regex.match(items[i])
@@ -24,7 +25,7 @@ def findall_regex(items, regex):
return found return found
def split_by_regex(items, regex): def split_by_regex(items: list[str], regex: re.Pattern[str]) -> list[list[str]]:
splits = list() splits = list()
indices = findall_regex(items, regex) indices = findall_regex(items, regex)
if not indices: if not indices:
@@ -45,8 +46,8 @@ def split_by_regex(items, regex):
# http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python # http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python
def which(program): def which(program: str) -> str | None:
def is_exe(fpath): def is_exe(fpath: str) -> bool:
return os.path.isfile(fpath) and os.access(fpath, os.X_OK) return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
fpath, fname = os.path.split(program) fpath, fname = os.path.split(program)

View File

@@ -6,8 +6,9 @@ import multiprocessing as mp
import os import os
import pathlib import pathlib
import subprocess import subprocess
from typing import Awaitable, TextIO from typing import Any, Awaitable, TextIO
from pydantic import SecretStr
from tqdm import tqdm from tqdm import tqdm
import openhands import openhands
@@ -25,7 +26,7 @@ from openhands.resolver.utils import (
) )
def cleanup(): def cleanup() -> None:
print('Cleaning up child processes...') print('Cleaning up child processes...')
for process in mp.active_children(): for process in mp.active_children():
print(f'Terminating child process: {process.name}') print(f'Terminating child process: {process.name}')
@@ -214,7 +215,7 @@ async def resolve_issues(
# Use asyncio.gather with a semaphore to limit concurrency # Use asyncio.gather with a semaphore to limit concurrency
sem = asyncio.Semaphore(num_workers) sem = asyncio.Semaphore(num_workers)
async def run_with_semaphore(task): async def run_with_semaphore(task: Awaitable[Any]) -> Any:
async with sem: async with sem:
return await task return await task
@@ -228,7 +229,7 @@ async def resolve_issues(
logger.info('Finished.') logger.info('Finished.')
def main(): def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Resolve multiple issues from Github or Gitlab.' description='Resolve multiple issues from Github or Gitlab.'
) )
@@ -349,7 +350,7 @@ def main():
llm_config = LLMConfig( llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'], model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None, api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None), base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
) )

View File

@@ -10,6 +10,7 @@ import subprocess
from typing import Any from typing import Any
from uuid import uuid4 from uuid import uuid4
from pydantic import SecretStr
from termcolor import colored from termcolor import colored
import openhands import openhands
@@ -18,6 +19,7 @@ from openhands.core.config import AgentConfig, AppConfig, LLMConfig, SandboxConf
from openhands.core.logger import openhands_logger as logger from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.event import Event
from openhands.events.observation import ( from openhands.events.observation import (
CmdOutputObservation, CmdOutputObservation,
ErrorObservation, ErrorObservation,
@@ -48,7 +50,7 @@ AGENT_CLASS = 'CodeActAgent'
def initialize_runtime( def initialize_runtime(
runtime: Runtime, runtime: Runtime,
platform: Platform, platform: Platform,
): ) -> None:
"""Initialize the runtime for the agent. """Initialize the runtime for the agent.
This function is called before the runtime is used to run the agent. This function is called before the runtime is used to run the agent.
@@ -192,26 +194,28 @@ async def process_issue(
# This code looks unnecessary because these are default values in the config class # This code looks unnecessary because these are default values in the config class
# they're set by default if nothing else overrides them # they're set by default if nothing else overrides them
# FIXME we should remove them here # FIXME we should remove them here
kwargs = {} sandbox_config = SandboxConfig(
runtime_container_image=runtime_container_image,
enable_auto_lint=False,
use_host_network=False,
# large enough timeout, since some testcases take very long to run
timeout=300,
)
if os.getenv('GITLAB_CI') == 'True': if os.getenv('GITLAB_CI') == 'True':
kwargs['local_runtime_url'] = os.getenv('LOCAL_RUNTIME_URL', 'http://localhost') sandbox_config.local_runtime_url = os.getenv(
'LOCAL_RUNTIME_URL', 'http://localhost'
)
user_id = os.getuid() if hasattr(os, 'getuid') else 1000 user_id = os.getuid() if hasattr(os, 'getuid') else 1000
if user_id == 0: if user_id == 0:
kwargs['user_id'] = get_unique_uid() sandbox_config.user_id = get_unique_uid()
config = AppConfig( config = AppConfig(
default_agent='CodeActAgent', default_agent='CodeActAgent',
runtime='docker', runtime='docker',
max_budget_per_task=4, max_budget_per_task=4,
max_iterations=max_iterations, max_iterations=max_iterations,
sandbox=SandboxConfig( sandbox=sandbox_config,
runtime_container_image=runtime_container_image,
enable_auto_lint=False,
use_host_network=False,
# large enough timeout, since some testcases take very long to run
timeout=300,
**kwargs,
),
# do not mount workspace # do not mount workspace
workspace_base=workspace_base, workspace_base=workspace_base,
workspace_mount_path=workspace_base, workspace_mount_path=workspace_base,
@@ -222,7 +226,7 @@ async def process_issue(
runtime = create_runtime(config) runtime = create_runtime(config)
await runtime.connect() await runtime.connect()
def on_event(evt): def on_event(evt: Event) -> None:
logger.info(evt) logger.info(evt)
runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4())) runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
@@ -524,10 +528,10 @@ async def resolve_issue(
logger.info('Finished.') logger.info('Finished.')
def main(): def main() -> None:
import argparse import argparse
def int_or_none(value): def int_or_none(value: str) -> int | None:
if value.lower() == 'none': if value.lower() == 'none':
return None return None
else: else:
@@ -654,7 +658,7 @@ def main():
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY'] api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig( llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'], model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None, api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None), base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
) )

View File

@@ -5,6 +5,7 @@ import shutil
import subprocess import subprocess
import jinja2 import jinja2
from pydantic import SecretStr
from openhands.core.config import LLMConfig from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger from openhands.core.logger import openhands_logger as logger
@@ -543,7 +544,7 @@ def process_all_successful_issues(
) )
def main(): def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Send a pull request to Github or Gitlab.' description='Send a pull request to Github or Gitlab.'
) )
@@ -641,7 +642,7 @@ def main():
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY'] api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig( llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'], model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None, api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None), base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
) )

View File

@@ -107,7 +107,7 @@ def codeact_user_response(
return msg return msg
def cleanup(): def cleanup() -> None:
print('Cleaning up child processes...') print('Cleaning up child processes...')
for process in mp.active_children(): for process in mp.active_children():
print(f'Terminating child process: {process.name}') print(f'Terminating child process: {process.name}')
@@ -115,7 +115,9 @@ def cleanup():
process.join() process.join()
def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int): def prepare_dataset(
dataset: pd.DataFrame, output_file: str, eval_n_limit: int
) -> pd.DataFrame:
assert 'instance_id' in dataset.columns, ( assert 'instance_id' in dataset.columns, (
"Expected 'instance_id' column in the dataset. You should define your own " "Expected 'instance_id' column in the dataset. You should define your own "
"unique identifier for each instance and use it as the 'instance_id' column." "unique identifier for each instance and use it as the 'instance_id' column."
@@ -152,7 +154,7 @@ def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
def reset_logger_for_multiprocessing( def reset_logger_for_multiprocessing(
logger: logging.Logger, instance_id: str, log_dir: str logger: logging.Logger, instance_id: str, log_dir: str
): ) -> None:
"""Reset the logger for multiprocessing. """Reset the logger for multiprocessing.
Save logs to a separate file for each process, instead of trying to write to the Save logs to a separate file for each process, instead of trying to write to the
@@ -208,7 +210,7 @@ def extract_issue_references(body: str) -> list[int]:
return [int(match) for match in re.findall(pattern, body)] return [int(match) for match in re.findall(pattern, body)]
def get_unique_uid(start_uid=1000): def get_unique_uid(start_uid: int = 1000) -> int:
existing_uids = set() existing_uids = set()
with open('/etc/passwd', 'r') as passwd_file: with open('/etc/passwd', 'r') as passwd_file:
for line in passwd_file: for line in passwd_file:

View File

@@ -4,7 +4,9 @@ import os
from openhands.resolver.io_utils import load_single_resolver_output from openhands.resolver.io_utils import load_single_resolver_output
def visualize_resolver_output(issue_number: int, output_dir: str, vis_method: str): def visualize_resolver_output(
issue_number: int, output_dir: str, vis_method: str
) -> None:
output_jsonl = os.path.join(output_dir, 'output.jsonl') output_jsonl = os.path.join(output_dir, 'output.jsonl')
resolver_output = load_single_resolver_output(output_jsonl, issue_number) resolver_output = load_single_resolver_output(output_jsonl, issue_number)
if vis_method == 'json': if vis_method == 'json':