mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
Fix type checking errors in resolver directory (#6738)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -22,28 +22,28 @@ class GithubIssueHandler(IssueHandlerInterface):
|
|||||||
self.clone_url = self.get_clone_url()
|
self.clone_url = self.get_clone_url()
|
||||||
self.headers = self.get_headers()
|
self.headers = self.get_headers()
|
||||||
|
|
||||||
def set_owner(self, owner: str):
|
def set_owner(self, owner: str) -> None:
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
|
|
||||||
def get_headers(self):
|
def get_headers(self) -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
'Authorization': f'token {self.token}',
|
'Authorization': f'token {self.token}',
|
||||||
'Accept': 'application/vnd.github.v3+json',
|
'Accept': 'application/vnd.github.v3+json',
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_base_url(self):
|
def get_base_url(self) -> str:
|
||||||
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
|
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
|
||||||
|
|
||||||
def get_authorize_url(self):
|
def get_authorize_url(self) -> str:
|
||||||
return f'https://{self.username}:{self.token}@github.com/'
|
return f'https://{self.username}:{self.token}@github.com/'
|
||||||
|
|
||||||
def get_branch_url(self, branch_name: str):
|
def get_branch_url(self, branch_name: str) -> str:
|
||||||
return self.get_base_url() + f'/branches/{branch_name}'
|
return self.get_base_url() + f'/branches/{branch_name}'
|
||||||
|
|
||||||
def get_download_url(self):
|
def get_download_url(self) -> str:
|
||||||
return f'{self.base_url}/issues'
|
return f'{self.base_url}/issues'
|
||||||
|
|
||||||
def get_clone_url(self):
|
def get_clone_url(self) -> str:
|
||||||
username_and_token = (
|
username_and_token = (
|
||||||
f'{self.username}:{self.token}'
|
f'{self.username}:{self.token}'
|
||||||
if self.username
|
if self.username
|
||||||
@@ -51,10 +51,10 @@ class GithubIssueHandler(IssueHandlerInterface):
|
|||||||
)
|
)
|
||||||
return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git'
|
return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git'
|
||||||
|
|
||||||
def get_graphql_url(self):
|
def get_graphql_url(self) -> str:
|
||||||
return 'https://api.github.com/graphql'
|
return 'https://api.github.com/graphql'
|
||||||
|
|
||||||
def get_compare_url(self, branch_name: str):
|
def get_compare_url(self, branch_name: str) -> str:
|
||||||
return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
|
return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1'
|
||||||
|
|
||||||
def get_converted_issues(
|
def get_converted_issues(
|
||||||
@@ -186,7 +186,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
|||||||
print(f'Branch {branch_name} exists: {exists}')
|
print(f'Branch {branch_name} exists: {exists}')
|
||||||
return exists
|
return exists
|
||||||
|
|
||||||
def get_branch_name(self, base_branch_name: str):
|
def get_branch_name(self, base_branch_name: str) -> str:
|
||||||
branch_name = base_branch_name
|
branch_name = base_branch_name
|
||||||
attempt = 1
|
attempt = 1
|
||||||
while self.branch_exists(branch_name):
|
while self.branch_exists(branch_name):
|
||||||
@@ -194,7 +194,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
|||||||
branch_name = f'{base_branch_name}-try{attempt}'
|
branch_name = f'{base_branch_name}-try{attempt}'
|
||||||
return branch_name
|
return branch_name
|
||||||
|
|
||||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
|
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||||
# Opting for graphql as REST API doesn't allow reply to replies in comment threads
|
# Opting for graphql as REST API doesn't allow reply to replies in comment threads
|
||||||
query = """
|
query = """
|
||||||
mutation($body: String!, $pullRequestReviewThreadId: ID!) {
|
mutation($body: String!, $pullRequestReviewThreadId: ID!) {
|
||||||
@@ -221,15 +221,18 @@ class GithubIssueHandler(IssueHandlerInterface):
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
def get_pull_url(self, pr_number: int):
|
def get_pull_url(self, pr_number: int) -> str:
|
||||||
return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}'
|
return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}'
|
||||||
|
|
||||||
def get_default_branch_name(self) -> str:
|
def get_default_branch_name(self) -> str:
|
||||||
response = requests.get(f'{self.base_url}', headers=self.headers)
|
response = requests.get(f'{self.base_url}', headers=self.headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()['default_branch']
|
data = response.json()
|
||||||
|
return str(data['default_branch'])
|
||||||
|
|
||||||
def create_pull_request(self, data=dict) -> dict:
|
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f'{self.base_url}/pulls', headers=self.headers, json=data
|
f'{self.base_url}/pulls', headers=self.headers, json=data
|
||||||
)
|
)
|
||||||
@@ -240,9 +243,9 @@ class GithubIssueHandler(IssueHandlerInterface):
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
pr_data = response.json()
|
pr_data = response.json()
|
||||||
return pr_data
|
return dict(pr_data)
|
||||||
|
|
||||||
def request_reviewers(self, reviewer: str, pr_number: int):
|
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||||
review_data = {'reviewers': [reviewer]}
|
review_data = {'reviewers': [reviewer]}
|
||||||
review_response = requests.post(
|
review_response = requests.post(
|
||||||
f'{self.base_url}/pulls/{pr_number}/requested_reviewers',
|
f'{self.base_url}/pulls/{pr_number}/requested_reviewers',
|
||||||
@@ -254,7 +257,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
|||||||
f'Warning: Failed to request review from {reviewer}: {review_response.text}'
|
f'Warning: Failed to request review from {reviewer}: {review_response.text}'
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_comment_msg(self, issue_number: int, msg: str):
|
def send_comment_msg(self, issue_number: int, msg: str) -> None:
|
||||||
"""Send a comment message to a GitHub issue or pull request.
|
"""Send a comment message to a GitHub issue or pull request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -282,8 +285,8 @@ class GithubIssueHandler(IssueHandlerInterface):
|
|||||||
review_comments: list[str] | None,
|
review_comments: list[str] | None,
|
||||||
review_threads: list[ReviewThread],
|
review_threads: list[ReviewThread],
|
||||||
thread_comments: list[str] | None,
|
thread_comments: list[str] | None,
|
||||||
):
|
) -> list[str]:
|
||||||
pass
|
return []
|
||||||
|
|
||||||
|
|
||||||
class GithubPRHandler(GithubIssueHandler):
|
class GithubPRHandler(GithubIssueHandler):
|
||||||
@@ -487,7 +490,7 @@ class GithubPRHandler(GithubIssueHandler):
|
|||||||
review_comments: list[str] | None,
|
review_comments: list[str] | None,
|
||||||
review_threads: list[ReviewThread],
|
review_threads: list[ReviewThread],
|
||||||
thread_comments: list[str] | None,
|
thread_comments: list[str] | None,
|
||||||
):
|
) -> list[str]:
|
||||||
new_issue_references = []
|
new_issue_references = []
|
||||||
|
|
||||||
if issue_body:
|
if issue_body:
|
||||||
|
|||||||
@@ -23,38 +23,38 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
|||||||
self.clone_url = self.get_clone_url()
|
self.clone_url = self.get_clone_url()
|
||||||
self.headers = self.get_headers()
|
self.headers = self.get_headers()
|
||||||
|
|
||||||
def set_owner(self, owner: str):
|
def set_owner(self, owner: str) -> None:
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
|
|
||||||
def get_headers(self):
|
def get_headers(self) -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
'Authorization': f'Bearer {self.token}',
|
'Authorization': f'Bearer {self.token}',
|
||||||
'Accept': 'application/json',
|
'Accept': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_base_url(self):
|
def get_base_url(self) -> str:
|
||||||
project_path = quote(f'{self.owner}/{self.repo}', safe="")
|
project_path = quote(f'{self.owner}/{self.repo}', safe='')
|
||||||
return f'https://gitlab.com/api/v4/projects/{project_path}'
|
return f'https://gitlab.com/api/v4/projects/{project_path}'
|
||||||
|
|
||||||
def get_authorize_url(self):
|
def get_authorize_url(self) -> str:
|
||||||
return f'https://{self.username}:{self.token}@gitlab.com/'
|
return f'https://{self.username}:{self.token}@gitlab.com/'
|
||||||
|
|
||||||
def get_branch_url(self, branch_name: str):
|
def get_branch_url(self, branch_name: str) -> str:
|
||||||
return self.get_base_url() + f'/repository/branches/{branch_name}'
|
return self.get_base_url() + f'/repository/branches/{branch_name}'
|
||||||
|
|
||||||
def get_download_url(self):
|
def get_download_url(self) -> str:
|
||||||
return f'{self.base_url}/issues'
|
return f'{self.base_url}/issues'
|
||||||
|
|
||||||
def get_clone_url(self):
|
def get_clone_url(self) -> str:
|
||||||
username_and_token = self.token
|
username_and_token = self.token
|
||||||
if self.username:
|
if self.username:
|
||||||
username_and_token = f'{self.username}:{self.token}'
|
username_and_token = f'{self.username}:{self.token}'
|
||||||
return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git'
|
return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git'
|
||||||
|
|
||||||
def get_graphql_url(self):
|
def get_graphql_url(self) -> str:
|
||||||
return 'https://gitlab.com/api/graphql'
|
return 'https://gitlab.com/api/graphql'
|
||||||
|
|
||||||
def get_compare_url(self, branch_name: str):
|
def get_compare_url(self, branch_name: str) -> str:
|
||||||
return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
|
return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}'
|
||||||
|
|
||||||
def get_converted_issues(
|
def get_converted_issues(
|
||||||
@@ -189,7 +189,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
|||||||
print(f'Branch {branch_name} exists: {exists}')
|
print(f'Branch {branch_name} exists: {exists}')
|
||||||
return exists
|
return exists
|
||||||
|
|
||||||
def get_branch_name(self, base_branch_name: str):
|
def get_branch_name(self, base_branch_name: str) -> str:
|
||||||
branch_name = base_branch_name
|
branch_name = base_branch_name
|
||||||
attempt = 1
|
attempt = 1
|
||||||
while self.branch_exists(branch_name):
|
while self.branch_exists(branch_name):
|
||||||
@@ -197,7 +197,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
|||||||
branch_name = f'{base_branch_name}-try{attempt}'
|
branch_name = f'{base_branch_name}-try{attempt}'
|
||||||
return branch_name
|
return branch_name
|
||||||
|
|
||||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
|
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}',
|
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}',
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
@@ -216,7 +216,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
def get_pull_url(self, pr_number: int):
|
def get_pull_url(self, pr_number: int) -> str:
|
||||||
return (
|
return (
|
||||||
f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
|
f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}'
|
||||||
)
|
)
|
||||||
@@ -224,9 +224,12 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
|||||||
def get_default_branch_name(self) -> str:
|
def get_default_branch_name(self) -> str:
|
||||||
response = requests.get(f'{self.base_url}', headers=self.headers)
|
response = requests.get(f'{self.base_url}', headers=self.headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()['default_branch']
|
data = response.json()
|
||||||
|
return str(data['default_branch'])
|
||||||
|
|
||||||
def create_pull_request(self, data=dict) -> dict:
|
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f'{self.base_url}/merge_requests', headers=self.headers, json=data
|
f'{self.base_url}/merge_requests', headers=self.headers, json=data
|
||||||
)
|
)
|
||||||
@@ -243,9 +246,9 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
|||||||
if 'iid' in pr_data:
|
if 'iid' in pr_data:
|
||||||
pr_data['number'] = pr_data['iid']
|
pr_data['number'] = pr_data['iid']
|
||||||
|
|
||||||
return pr_data
|
return dict(pr_data)
|
||||||
|
|
||||||
def request_reviewers(self, reviewer: str, pr_number: int):
|
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f'https://gitlab.com/api/v4/users?username={reviewer}',
|
f'https://gitlab.com/api/v4/users?username={reviewer}',
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
@@ -264,7 +267,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
|||||||
f'Warning: Failed to request review from {reviewer}: {review_response.text}'
|
f'Warning: Failed to request review from {reviewer}: {review_response.text}'
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_comment_msg(self, issue_number: int, msg: str):
|
def send_comment_msg(self, issue_number: int, msg: str) -> None:
|
||||||
"""Send a comment message to a GitHub issue or pull request.
|
"""Send a comment message to a GitHub issue or pull request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -292,8 +295,8 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
|||||||
review_comments: list[str] | None,
|
review_comments: list[str] | None,
|
||||||
review_threads: list[ReviewThread],
|
review_threads: list[ReviewThread],
|
||||||
thread_comments: list[str] | None,
|
thread_comments: list[str] | None,
|
||||||
):
|
) -> list[str]:
|
||||||
pass
|
return []
|
||||||
|
|
||||||
|
|
||||||
class GitlabPRHandler(GitlabIssueHandler):
|
class GitlabPRHandler(GitlabIssueHandler):
|
||||||
@@ -479,7 +482,7 @@ class GitlabPRHandler(GitlabIssueHandler):
|
|||||||
review_comments: list[str] | None,
|
review_comments: list[str] | None,
|
||||||
review_threads: list[ReviewThread],
|
review_threads: list[ReviewThread],
|
||||||
thread_comments: list[str] | None,
|
thread_comments: list[str] | None,
|
||||||
):
|
) -> list[str]:
|
||||||
new_issue_references = []
|
new_issue_references = []
|
||||||
|
|
||||||
if issue_body:
|
if issue_body:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class Issue(BaseModel):
|
|||||||
|
|
||||||
class IssueHandlerInterface(ABC):
|
class IssueHandlerInterface(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_owner(self, owner: str):
|
def set_owner(self, owner: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -40,43 +40,43 @@ class IssueHandlerInterface(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_base_url(self):
|
def get_base_url(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_branch_url(self, branch_name):
|
def get_branch_url(self, branch_name: str) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_download_url(self):
|
def get_download_url(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_clone_url(self):
|
def get_clone_url(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pull_url(self, pr_number: int):
|
def get_pull_url(self, pr_number: int) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_graphql_url(self):
|
def get_graphql_url(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_headers(self):
|
def get_headers(self) -> dict[str, str]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_compare_url(self, branch_name):
|
def get_compare_url(self, branch_name: str) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_branch_name(self, base_branch_name: str):
|
def get_branch_name(self, base_branch_name: str) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_default_branch_name(self):
|
def get_default_branch_name(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -84,23 +84,25 @@ class IssueHandlerInterface(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str):
|
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def send_comment_msg(self, issue_number: int, msg: str):
|
def send_comment_msg(self, issue_number: int, msg: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_authorize_url(self):
|
def get_authorize_url(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_pull_request(self, data=dict) -> dict:
|
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
pass
|
if data is None:
|
||||||
|
data = {}
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def request_reviewers(self, reviewer: str, pr_number: int):
|
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -112,7 +114,7 @@ class IssueHandlerInterface(ABC):
|
|||||||
review_comments: list[str] | None,
|
review_comments: list[str] | None,
|
||||||
review_threads: list[ReviewThread],
|
review_threads: list[ReviewThread],
|
||||||
thread_comments: list[str] | None,
|
thread_comments: list[str] | None,
|
||||||
):
|
) -> list[str]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class ServiceContext:
|
|||||||
if llm_config is not None:
|
if llm_config is not None:
|
||||||
self.llm = LLM(llm_config)
|
self.llm = LLM(llm_config)
|
||||||
|
|
||||||
def set_strategy(self, strategy):
|
def set_strategy(self, strategy: IssueHandlerInterface) -> None:
|
||||||
self._strategy = strategy
|
self._strategy = strategy
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class ServiceContextPR(ServiceContext):
|
|||||||
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig):
|
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig):
|
||||||
super().__init__(strategy, llm_config)
|
super().__init__(strategy, llm_config)
|
||||||
|
|
||||||
def get_clone_url(self):
|
def get_clone_url(self) -> str:
|
||||||
return self._strategy.get_clone_url()
|
return self._strategy.get_clone_url()
|
||||||
|
|
||||||
def download_issues(self) -> list[Any]:
|
def download_issues(self) -> list[Any]:
|
||||||
@@ -266,31 +266,31 @@ class ServiceContextIssue(ServiceContext):
|
|||||||
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
|
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
|
||||||
super().__init__(strategy, llm_config)
|
super().__init__(strategy, llm_config)
|
||||||
|
|
||||||
def get_base_url(self):
|
def get_base_url(self) -> str:
|
||||||
return self._strategy.get_base_url()
|
return self._strategy.get_base_url()
|
||||||
|
|
||||||
def get_branch_url(self, branch_name):
|
def get_branch_url(self, branch_name: str) -> str:
|
||||||
return self._strategy.get_branch_url(branch_name)
|
return self._strategy.get_branch_url(branch_name)
|
||||||
|
|
||||||
def get_download_url(self):
|
def get_download_url(self) -> str:
|
||||||
return self._strategy.get_download_url()
|
return self._strategy.get_download_url()
|
||||||
|
|
||||||
def get_clone_url(self):
|
def get_clone_url(self) -> str:
|
||||||
return self._strategy.get_clone_url()
|
return self._strategy.get_clone_url()
|
||||||
|
|
||||||
def get_graphql_url(self):
|
def get_graphql_url(self) -> str:
|
||||||
return self._strategy.get_graphql_url()
|
return self._strategy.get_graphql_url()
|
||||||
|
|
||||||
def get_headers(self):
|
def get_headers(self) -> dict[str, str]:
|
||||||
return self._strategy.get_headers()
|
return self._strategy.get_headers()
|
||||||
|
|
||||||
def get_authorize_url(self):
|
def get_authorize_url(self) -> str:
|
||||||
return self._strategy.get_authorize_url()
|
return self._strategy.get_authorize_url()
|
||||||
|
|
||||||
def get_pull_url(self, pr_number: int):
|
def get_pull_url(self, pr_number: int) -> str:
|
||||||
return self._strategy.get_pull_url(pr_number)
|
return self._strategy.get_pull_url(pr_number)
|
||||||
|
|
||||||
def get_compare_url(self, branch_name: str):
|
def get_compare_url(self, branch_name: str) -> str:
|
||||||
return self._strategy.get_compare_url(branch_name)
|
return self._strategy.get_compare_url(branch_name)
|
||||||
|
|
||||||
def download_issues(self) -> list[Any]:
|
def download_issues(self) -> list[Any]:
|
||||||
@@ -299,25 +299,27 @@ class ServiceContextIssue(ServiceContext):
|
|||||||
def get_branch_name(
|
def get_branch_name(
|
||||||
self,
|
self,
|
||||||
base_branch_name: str,
|
base_branch_name: str,
|
||||||
):
|
) -> str:
|
||||||
return self._strategy.get_branch_name(base_branch_name)
|
return self._strategy.get_branch_name(base_branch_name)
|
||||||
|
|
||||||
def branch_exists(self, branch_name: str):
|
def branch_exists(self, branch_name: str) -> bool:
|
||||||
return self._strategy.branch_exists(branch_name)
|
return self._strategy.branch_exists(branch_name)
|
||||||
|
|
||||||
def get_default_branch_name(self) -> str:
|
def get_default_branch_name(self) -> str:
|
||||||
return self._strategy.get_default_branch_name()
|
return self._strategy.get_default_branch_name()
|
||||||
|
|
||||||
def create_pull_request(self, data=dict):
|
def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
return self._strategy.create_pull_request(data)
|
return self._strategy.create_pull_request(data)
|
||||||
|
|
||||||
def request_reviewers(self, reviewer: str, pr_number: int):
|
def request_reviewers(self, reviewer: str, pr_number: int) -> None:
|
||||||
return self._strategy.request_reviewers(reviewer, pr_number)
|
return self._strategy.request_reviewers(reviewer, pr_number)
|
||||||
|
|
||||||
def reply_to_comment(self, pr_number, comment_id, reply):
|
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||||
return self._strategy.reply_to_comment(pr_number, comment_id, reply)
|
return self._strategy.reply_to_comment(pr_number, comment_id, reply)
|
||||||
|
|
||||||
def send_comment_msg(self, issue_number: int, msg: str):
|
def send_comment_msg(self, issue_number: int, msg: str) -> None:
|
||||||
return self._strategy.send_comment_msg(issue_number, msg)
|
return self._strategy.send_comment_msg(issue_number, msg)
|
||||||
|
|
||||||
def get_issue_comments(
|
def get_issue_comments(
|
||||||
|
|||||||
@@ -5,10 +5,13 @@ import subprocess
|
|||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from .exceptions import HunkApplyException, SubprocessException
|
from .exceptions import HunkApplyException, SubprocessException
|
||||||
|
from .patch import Change, diffobj
|
||||||
from .snippets import remove, which
|
from .snippets import remove, which
|
||||||
|
|
||||||
|
|
||||||
def _apply_diff_with_subprocess(diff, lines, reverse=False):
|
def _apply_diff_with_subprocess(
|
||||||
|
diff: diffobj, lines: list[str], reverse: bool = False
|
||||||
|
) -> tuple[list[str], list[str] | None]:
|
||||||
# call out to patch program
|
# call out to patch program
|
||||||
patchexec = which('patch')
|
patchexec = which('patch')
|
||||||
if not patchexec:
|
if not patchexec:
|
||||||
@@ -63,21 +66,21 @@ def _apply_diff_with_subprocess(diff, lines, reverse=False):
|
|||||||
return lines, rejlines
|
return lines, rejlines
|
||||||
|
|
||||||
|
|
||||||
def _reverse(changes):
|
def _reverse(changes: list[Change]) -> list[Change]:
|
||||||
def _reverse_change(c):
|
def _reverse_change(c: Change) -> Change:
|
||||||
return c._replace(old=c.new, new=c.old)
|
return c._replace(old=c.new, new=c.old)
|
||||||
|
|
||||||
return [_reverse_change(c) for c in changes]
|
return [_reverse_change(c) for c in changes]
|
||||||
|
|
||||||
|
|
||||||
def apply_diff(diff, text, reverse=False, use_patch=False):
|
def apply_diff(
|
||||||
try:
|
diff: diffobj, text: str | list[str], reverse: bool = False, use_patch: bool = False
|
||||||
lines = text.splitlines()
|
) -> list[str]:
|
||||||
except AttributeError:
|
lines = text.splitlines() if isinstance(text, str) else list(text)
|
||||||
lines = list(text)
|
|
||||||
|
|
||||||
if use_patch:
|
if use_patch:
|
||||||
return _apply_diff_with_subprocess(diff, lines, reverse)
|
lines, _ = _apply_diff_with_subprocess(diff, lines, reverse)
|
||||||
|
return lines
|
||||||
|
|
||||||
n_lines = len(lines)
|
n_lines = len(lines)
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +1,31 @@
|
|||||||
class PatchingException(Exception):
|
class PatchingException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HunkException(PatchingException):
|
class HunkException(PatchingException):
|
||||||
def __init__(self, msg, hunk=None):
|
def __init__(self, msg: str, hunk: int | None = None) -> None:
|
||||||
self.hunk = hunk
|
self.hunk = hunk
|
||||||
if hunk is not None:
|
if hunk is not None:
|
||||||
super(HunkException, self).__init__(
|
super(HunkException, self).__init__(
|
||||||
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
|
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
super(HunkException, self).__init__(msg)
|
super(HunkException, self).__init__(msg)
|
||||||
|
|
||||||
|
|
||||||
class ApplyException(PatchingException):
|
class ApplyException(PatchingException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SubprocessException(ApplyException):
|
class SubprocessException(ApplyException):
|
||||||
def __init__(self, msg, code):
|
def __init__(self, msg: str, code: int) -> None:
|
||||||
super(SubprocessException, self).__init__(msg)
|
super(SubprocessException, self).__init__(msg)
|
||||||
self.code = code
|
self.code = code
|
||||||
|
|
||||||
|
|
||||||
class HunkApplyException(HunkException, ApplyException, ValueError):
|
class HunkApplyException(HunkException, ApplyException, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ParseException(HunkException, ValueError):
|
class ParseException(HunkException, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import base64
|
|||||||
import re
|
import re
|
||||||
import zlib
|
import zlib
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
from . import exceptions
|
from . import exceptions
|
||||||
from .snippets import findall_regex, split_by_regex
|
from .snippets import findall_regex, split_by_regex
|
||||||
@@ -71,11 +72,8 @@ cvs_header_timestamp_colon = re.compile(r':([\d.]+)\t(.+)')
|
|||||||
old_cvs_diffcmd_header = re.compile('^diff.* (.+):(.*) (.+):(.*)$')
|
old_cvs_diffcmd_header = re.compile('^diff.* (.+):(.*) (.+):(.*)$')
|
||||||
|
|
||||||
|
|
||||||
def parse_patch(text):
|
def parse_patch(text: str | list[str]) -> Iterable[diffobj]:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
# maybe use this to nuke all of those line endings?
|
# maybe use this to nuke all of those line endings?
|
||||||
# lines = [x.splitlines()[0] for x in lines]
|
# lines = [x.splitlines()[0] for x in lines]
|
||||||
@@ -104,18 +102,15 @@ def parse_patch(text):
|
|||||||
yield diffobj(header=h, changes=d, text=difftext)
|
yield diffobj(header=h, changes=d, text=difftext)
|
||||||
|
|
||||||
|
|
||||||
def parse_header(text):
|
def parse_header(text: str | list[str]) -> header | None:
|
||||||
h = parse_scm_header(text)
|
h = parse_scm_header(text)
|
||||||
if h is None:
|
if h is None:
|
||||||
h = parse_diff_header(text)
|
h = parse_diff_header(text)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
def parse_scm_header(text):
|
def parse_scm_header(text: str | list[str]) -> header | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
check = [
|
check = [
|
||||||
(git_header_index, parse_git_header),
|
(git_header_index, parse_git_header),
|
||||||
@@ -154,11 +149,8 @@ def parse_scm_header(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_diff_header(text):
|
def parse_diff_header(text: str | list[str]) -> header | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
check = [
|
check = [
|
||||||
(unified_header_new_line, parse_unified_header),
|
(unified_header_new_line, parse_unified_header),
|
||||||
@@ -178,10 +170,10 @@ def parse_diff_header(text):
|
|||||||
return None # no header?
|
return None # no header?
|
||||||
|
|
||||||
|
|
||||||
def parse_diff(text):
|
def parse_diff(text: str | list[str]) -> list[Change] | None:
|
||||||
try:
|
if isinstance(text, str):
|
||||||
lines = text.splitlines()
|
lines = text.splitlines()
|
||||||
except AttributeError:
|
else:
|
||||||
lines = text
|
lines = text
|
||||||
|
|
||||||
check = [
|
check = [
|
||||||
@@ -200,11 +192,8 @@ def parse_diff(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_git_header(text):
|
def parse_git_header(text: str | list[str]) -> header | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
old_version = None
|
old_version = None
|
||||||
new_version = None
|
new_version = None
|
||||||
@@ -275,11 +264,8 @@ def parse_git_header(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_svn_header(text):
|
def parse_svn_header(text: str | list[str]) -> header | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
headers = findall_regex(lines, svn_header_index)
|
headers = findall_regex(lines, svn_header_index)
|
||||||
if len(headers) == 0:
|
if len(headers) == 0:
|
||||||
@@ -346,11 +332,8 @@ def parse_svn_header(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_cvs_header(text):
|
def parse_cvs_header(text: str | list[str]) -> header | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
headers = findall_regex(lines, cvs_header_rcs)
|
headers = findall_regex(lines, cvs_header_rcs)
|
||||||
headers_old = findall_regex(lines, old_cvs_diffcmd_header)
|
headers_old = findall_regex(lines, old_cvs_diffcmd_header)
|
||||||
@@ -430,11 +413,8 @@ def parse_cvs_header(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_diffcmd_header(text):
|
def parse_diffcmd_header(text: str | list[str]) -> header | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
headers = findall_regex(lines, diffcmd_header)
|
headers = findall_regex(lines, diffcmd_header)
|
||||||
if len(headers) == 0:
|
if len(headers) == 0:
|
||||||
@@ -454,11 +434,8 @@ def parse_diffcmd_header(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_unified_header(text):
|
def parse_unified_header(text: str | list[str]) -> header | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
headers = findall_regex(lines, unified_header_new_line)
|
headers = findall_regex(lines, unified_header_new_line)
|
||||||
if len(headers) == 0:
|
if len(headers) == 0:
|
||||||
@@ -490,11 +467,8 @@ def parse_unified_header(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_context_header(text):
|
def parse_context_header(text: str | list[str]) -> header | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
headers = findall_regex(lines, context_header_old_line)
|
headers = findall_regex(lines, context_header_old_line)
|
||||||
if len(headers) == 0:
|
if len(headers) == 0:
|
||||||
@@ -526,11 +500,8 @@ def parse_context_header(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_default_diff(text):
|
def parse_default_diff(text: str | list[str]) -> list[Change] | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
old = 0
|
old = 0
|
||||||
new = 0
|
new = 0
|
||||||
@@ -582,11 +553,8 @@ def parse_default_diff(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_unified_diff(text):
|
def parse_unified_diff(text: str | list[str]) -> list[Change] | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
old = 0
|
old = 0
|
||||||
new = 0
|
new = 0
|
||||||
@@ -652,11 +620,8 @@ def parse_unified_diff(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_context_diff(text):
|
def parse_context_diff(text: str | list[str]) -> list[Change] | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
old = 0
|
old = 0
|
||||||
new = 0
|
new = 0
|
||||||
@@ -795,11 +760,8 @@ def parse_context_diff(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_ed_diff(text):
|
def parse_ed_diff(text: str | list[str]) -> list[Change] | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
old = 0
|
old = 0
|
||||||
j = 0
|
j = 0
|
||||||
@@ -878,12 +840,9 @@ def parse_ed_diff(text):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_rcs_ed_diff(text):
|
def parse_rcs_ed_diff(text: str | list[str]) -> list[Change] | None:
|
||||||
# much like forward ed, but no 'c' type
|
# much like forward ed, but no 'c' type
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
old = 0
|
old = 0
|
||||||
j = 0
|
j = 0
|
||||||
@@ -905,7 +864,7 @@ def parse_rcs_ed_diff(text):
|
|||||||
|
|
||||||
hunk_kind = o.group(1)
|
hunk_kind = o.group(1)
|
||||||
old = int(o.group(2))
|
old = int(o.group(2))
|
||||||
size = int(o.group(3))
|
size = int(o.group(3)) if o.group(3) else 0
|
||||||
|
|
||||||
if hunk_kind == 'a':
|
if hunk_kind == 'a':
|
||||||
old += total_change_size + 1
|
old += total_change_size + 1
|
||||||
@@ -926,15 +885,11 @@ def parse_rcs_ed_diff(text):
|
|||||||
|
|
||||||
if len(changes) > 0:
|
if len(changes) > 0:
|
||||||
return changes
|
return changes
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_git_binary_diff(text):
|
def parse_git_binary_diff(text: str | list[str]) -> list[Change] | None:
|
||||||
try:
|
lines = text.splitlines() if isinstance(text, str) else text
|
||||||
lines = text.splitlines()
|
|
||||||
except AttributeError:
|
|
||||||
lines = text
|
|
||||||
|
|
||||||
changes: list[Change] = list()
|
changes: list[Change] = list()
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
|
|
||||||
|
|
||||||
def remove(path):
|
def remove(path: str) -> None:
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
rmtree(path)
|
rmtree(path)
|
||||||
@@ -13,7 +14,7 @@ def remove(path):
|
|||||||
|
|
||||||
|
|
||||||
# find all indices of a list of strings that match a regex
|
# find all indices of a list of strings that match a regex
|
||||||
def findall_regex(items, regex):
|
def findall_regex(items: list[str], regex: re.Pattern[str]) -> list[int]:
|
||||||
found = list()
|
found = list()
|
||||||
for i in range(0, len(items)):
|
for i in range(0, len(items)):
|
||||||
k = regex.match(items[i])
|
k = regex.match(items[i])
|
||||||
@@ -24,7 +25,7 @@ def findall_regex(items, regex):
|
|||||||
return found
|
return found
|
||||||
|
|
||||||
|
|
||||||
def split_by_regex(items, regex):
|
def split_by_regex(items: list[str], regex: re.Pattern[str]) -> list[list[str]]:
|
||||||
splits = list()
|
splits = list()
|
||||||
indices = findall_regex(items, regex)
|
indices = findall_regex(items, regex)
|
||||||
if not indices:
|
if not indices:
|
||||||
@@ -45,8 +46,8 @@ def split_by_regex(items, regex):
|
|||||||
|
|
||||||
|
|
||||||
# http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python
|
# http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python
|
||||||
def which(program):
|
def which(program: str) -> str | None:
|
||||||
def is_exe(fpath):
|
def is_exe(fpath: str) -> bool:
|
||||||
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
|
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
|
||||||
|
|
||||||
fpath, fname = os.path.split(program)
|
fpath, fname = os.path.split(program)
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ import multiprocessing as mp
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Awaitable, TextIO
|
from typing import Any, Awaitable, TextIO
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import openhands
|
import openhands
|
||||||
@@ -25,7 +26,7 @@ from openhands.resolver.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cleanup():
|
def cleanup() -> None:
|
||||||
print('Cleaning up child processes...')
|
print('Cleaning up child processes...')
|
||||||
for process in mp.active_children():
|
for process in mp.active_children():
|
||||||
print(f'Terminating child process: {process.name}')
|
print(f'Terminating child process: {process.name}')
|
||||||
@@ -214,7 +215,7 @@ async def resolve_issues(
|
|||||||
# Use asyncio.gather with a semaphore to limit concurrency
|
# Use asyncio.gather with a semaphore to limit concurrency
|
||||||
sem = asyncio.Semaphore(num_workers)
|
sem = asyncio.Semaphore(num_workers)
|
||||||
|
|
||||||
async def run_with_semaphore(task):
|
async def run_with_semaphore(task: Awaitable[Any]) -> Any:
|
||||||
async with sem:
|
async with sem:
|
||||||
return await task
|
return await task
|
||||||
|
|
||||||
@@ -228,7 +229,7 @@ async def resolve_issues(
|
|||||||
logger.info('Finished.')
|
logger.info('Finished.')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Resolve multiple issues from Github or Gitlab.'
|
description='Resolve multiple issues from Github or Gitlab.'
|
||||||
)
|
)
|
||||||
@@ -349,7 +350,7 @@ def main():
|
|||||||
|
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||||
api_key=str(api_key) if api_key else None,
|
api_key=SecretStr(api_key) if api_key else None,
|
||||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import subprocess
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
import openhands
|
import openhands
|
||||||
@@ -18,6 +19,7 @@ from openhands.core.config import AgentConfig, AppConfig, LLMConfig, SandboxConf
|
|||||||
from openhands.core.logger import openhands_logger as logger
|
from openhands.core.logger import openhands_logger as logger
|
||||||
from openhands.core.main import create_runtime, run_controller
|
from openhands.core.main import create_runtime, run_controller
|
||||||
from openhands.events.action import CmdRunAction, MessageAction
|
from openhands.events.action import CmdRunAction, MessageAction
|
||||||
|
from openhands.events.event import Event
|
||||||
from openhands.events.observation import (
|
from openhands.events.observation import (
|
||||||
CmdOutputObservation,
|
CmdOutputObservation,
|
||||||
ErrorObservation,
|
ErrorObservation,
|
||||||
@@ -48,7 +50,7 @@ AGENT_CLASS = 'CodeActAgent'
|
|||||||
def initialize_runtime(
|
def initialize_runtime(
|
||||||
runtime: Runtime,
|
runtime: Runtime,
|
||||||
platform: Platform,
|
platform: Platform,
|
||||||
):
|
) -> None:
|
||||||
"""Initialize the runtime for the agent.
|
"""Initialize the runtime for the agent.
|
||||||
|
|
||||||
This function is called before the runtime is used to run the agent.
|
This function is called before the runtime is used to run the agent.
|
||||||
@@ -192,26 +194,28 @@ async def process_issue(
|
|||||||
# This code looks unnecessary because these are default values in the config class
|
# This code looks unnecessary because these are default values in the config class
|
||||||
# they're set by default if nothing else overrides them
|
# they're set by default if nothing else overrides them
|
||||||
# FIXME we should remove them here
|
# FIXME we should remove them here
|
||||||
kwargs = {}
|
sandbox_config = SandboxConfig(
|
||||||
|
runtime_container_image=runtime_container_image,
|
||||||
|
enable_auto_lint=False,
|
||||||
|
use_host_network=False,
|
||||||
|
# large enough timeout, since some testcases take very long to run
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
|
||||||
if os.getenv('GITLAB_CI') == 'True':
|
if os.getenv('GITLAB_CI') == 'True':
|
||||||
kwargs['local_runtime_url'] = os.getenv('LOCAL_RUNTIME_URL', 'http://localhost')
|
sandbox_config.local_runtime_url = os.getenv(
|
||||||
|
'LOCAL_RUNTIME_URL', 'http://localhost'
|
||||||
|
)
|
||||||
user_id = os.getuid() if hasattr(os, 'getuid') else 1000
|
user_id = os.getuid() if hasattr(os, 'getuid') else 1000
|
||||||
if user_id == 0:
|
if user_id == 0:
|
||||||
kwargs['user_id'] = get_unique_uid()
|
sandbox_config.user_id = get_unique_uid()
|
||||||
|
|
||||||
config = AppConfig(
|
config = AppConfig(
|
||||||
default_agent='CodeActAgent',
|
default_agent='CodeActAgent',
|
||||||
runtime='docker',
|
runtime='docker',
|
||||||
max_budget_per_task=4,
|
max_budget_per_task=4,
|
||||||
max_iterations=max_iterations,
|
max_iterations=max_iterations,
|
||||||
sandbox=SandboxConfig(
|
sandbox=sandbox_config,
|
||||||
runtime_container_image=runtime_container_image,
|
|
||||||
enable_auto_lint=False,
|
|
||||||
use_host_network=False,
|
|
||||||
# large enough timeout, since some testcases take very long to run
|
|
||||||
timeout=300,
|
|
||||||
**kwargs,
|
|
||||||
),
|
|
||||||
# do not mount workspace
|
# do not mount workspace
|
||||||
workspace_base=workspace_base,
|
workspace_base=workspace_base,
|
||||||
workspace_mount_path=workspace_base,
|
workspace_mount_path=workspace_base,
|
||||||
@@ -222,7 +226,7 @@ async def process_issue(
|
|||||||
runtime = create_runtime(config)
|
runtime = create_runtime(config)
|
||||||
await runtime.connect()
|
await runtime.connect()
|
||||||
|
|
||||||
def on_event(evt):
|
def on_event(evt: Event) -> None:
|
||||||
logger.info(evt)
|
logger.info(evt)
|
||||||
|
|
||||||
runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
|
||||||
@@ -524,10 +528,10 @@ async def resolve_issue(
|
|||||||
logger.info('Finished.')
|
logger.info('Finished.')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
def int_or_none(value):
|
def int_or_none(value: str) -> int | None:
|
||||||
if value.lower() == 'none':
|
if value.lower() == 'none':
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
@@ -654,7 +658,7 @@ def main():
|
|||||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||||
api_key=str(api_key) if api_key else None,
|
api_key=SecretStr(api_key) if api_key else None,
|
||||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import shutil
|
|||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from openhands.core.config import LLMConfig
|
from openhands.core.config import LLMConfig
|
||||||
from openhands.core.logger import openhands_logger as logger
|
from openhands.core.logger import openhands_logger as logger
|
||||||
@@ -543,7 +544,7 @@ def process_all_successful_issues(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Send a pull request to Github or Gitlab.'
|
description='Send a pull request to Github or Gitlab.'
|
||||||
)
|
)
|
||||||
@@ -641,7 +642,7 @@ def main():
|
|||||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||||
api_key=str(api_key) if api_key else None,
|
api_key=SecretStr(api_key) if api_key else None,
|
||||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ def codeact_user_response(
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def cleanup():
|
def cleanup() -> None:
|
||||||
print('Cleaning up child processes...')
|
print('Cleaning up child processes...')
|
||||||
for process in mp.active_children():
|
for process in mp.active_children():
|
||||||
print(f'Terminating child process: {process.name}')
|
print(f'Terminating child process: {process.name}')
|
||||||
@@ -115,7 +115,9 @@ def cleanup():
|
|||||||
process.join()
|
process.join()
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
|
def prepare_dataset(
|
||||||
|
dataset: pd.DataFrame, output_file: str, eval_n_limit: int
|
||||||
|
) -> pd.DataFrame:
|
||||||
assert 'instance_id' in dataset.columns, (
|
assert 'instance_id' in dataset.columns, (
|
||||||
"Expected 'instance_id' column in the dataset. You should define your own "
|
"Expected 'instance_id' column in the dataset. You should define your own "
|
||||||
"unique identifier for each instance and use it as the 'instance_id' column."
|
"unique identifier for each instance and use it as the 'instance_id' column."
|
||||||
@@ -152,7 +154,7 @@ def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
|
|||||||
|
|
||||||
def reset_logger_for_multiprocessing(
|
def reset_logger_for_multiprocessing(
|
||||||
logger: logging.Logger, instance_id: str, log_dir: str
|
logger: logging.Logger, instance_id: str, log_dir: str
|
||||||
):
|
) -> None:
|
||||||
"""Reset the logger for multiprocessing.
|
"""Reset the logger for multiprocessing.
|
||||||
|
|
||||||
Save logs to a separate file for each process, instead of trying to write to the
|
Save logs to a separate file for each process, instead of trying to write to the
|
||||||
@@ -208,7 +210,7 @@ def extract_issue_references(body: str) -> list[int]:
|
|||||||
return [int(match) for match in re.findall(pattern, body)]
|
return [int(match) for match in re.findall(pattern, body)]
|
||||||
|
|
||||||
|
|
||||||
def get_unique_uid(start_uid=1000):
|
def get_unique_uid(start_uid: int = 1000) -> int:
|
||||||
existing_uids = set()
|
existing_uids = set()
|
||||||
with open('/etc/passwd', 'r') as passwd_file:
|
with open('/etc/passwd', 'r') as passwd_file:
|
||||||
for line in passwd_file:
|
for line in passwd_file:
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import os
|
|||||||
from openhands.resolver.io_utils import load_single_resolver_output
|
from openhands.resolver.io_utils import load_single_resolver_output
|
||||||
|
|
||||||
|
|
||||||
def visualize_resolver_output(issue_number: int, output_dir: str, vis_method: str):
|
def visualize_resolver_output(
|
||||||
|
issue_number: int, output_dir: str, vis_method: str
|
||||||
|
) -> None:
|
||||||
output_jsonl = os.path.join(output_dir, 'output.jsonl')
|
output_jsonl = os.path.join(output_dir, 'output.jsonl')
|
||||||
resolver_output = load_single_resolver_output(output_jsonl, issue_number)
|
resolver_output = load_single_resolver_output(output_jsonl, issue_number)
|
||||||
if vis_method == 'json':
|
if vis_method == 'json':
|
||||||
|
|||||||
Reference in New Issue
Block a user