Fix type checking errors in resolver directory (#6738)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig
2025-02-18 20:13:33 -05:00
committed by GitHub
parent 1a7003a705
commit f4e5fb2873
13 changed files with 209 additions and 230 deletions

View File

@@ -5,10 +5,13 @@ import subprocess
import tempfile
from .exceptions import HunkApplyException, SubprocessException
from .patch import Change, diffobj
from .snippets import remove, which
def _apply_diff_with_subprocess(diff, lines, reverse=False):
def _apply_diff_with_subprocess(
diff: diffobj, lines: list[str], reverse: bool = False
) -> tuple[list[str], list[str] | None]:
# call out to patch program
patchexec = which('patch')
if not patchexec:
@@ -63,21 +66,21 @@ def _apply_diff_with_subprocess(diff, lines, reverse=False):
return lines, rejlines
def _reverse(changes):
def _reverse_change(c):
def _reverse(changes: list[Change]) -> list[Change]:
def _reverse_change(c: Change) -> Change:
return c._replace(old=c.new, new=c.old)
return [_reverse_change(c) for c in changes]
def apply_diff(diff, text, reverse=False, use_patch=False):
try:
lines = text.splitlines()
except AttributeError:
lines = list(text)
def apply_diff(
diff: diffobj, text: str | list[str], reverse: bool = False, use_patch: bool = False
) -> list[str]:
lines = text.splitlines() if isinstance(text, str) else list(text)
if use_patch:
return _apply_diff_with_subprocess(diff, lines, reverse)
lines, _ = _apply_diff_with_subprocess(diff, lines, reverse)
return lines
n_lines = len(lines)

View File

@@ -1,31 +1,31 @@
class PatchingException(Exception):
pass
class HunkException(PatchingException):
def __init__(self, msg, hunk=None):
self.hunk = hunk
if hunk is not None:
super(HunkException, self).__init__(
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
)
else:
super(HunkException, self).__init__(msg)
class ApplyException(PatchingException):
pass
class SubprocessException(ApplyException):
def __init__(self, msg, code):
super(SubprocessException, self).__init__(msg)
self.code = code
class HunkApplyException(HunkException, ApplyException, ValueError):
pass
class ParseException(HunkException, ValueError):
pass
class PatchingException(Exception):
pass
class HunkException(PatchingException):
def __init__(self, msg: str, hunk: int | None = None) -> None:
self.hunk = hunk
if hunk is not None:
super(HunkException, self).__init__(
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
)
else:
super(HunkException, self).__init__(msg)
class ApplyException(PatchingException):
pass
class SubprocessException(ApplyException):
def __init__(self, msg: str, code: int) -> None:
super(SubprocessException, self).__init__(msg)
self.code = code
class HunkApplyException(HunkException, ApplyException, ValueError):
pass
class ParseException(HunkException, ValueError):
pass

View File

@@ -3,6 +3,7 @@ import base64
import re
import zlib
from collections import namedtuple
from typing import Iterable
from . import exceptions
from .snippets import findall_regex, split_by_regex
@@ -71,11 +72,8 @@ cvs_header_timestamp_colon = re.compile(r':([\d.]+)\t(.+)')
old_cvs_diffcmd_header = re.compile('^diff.* (.+):(.*) (.+):(.*)$')
def parse_patch(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_patch(text: str | list[str]) -> Iterable[diffobj]:
lines = text.splitlines() if isinstance(text, str) else text
# maybe use this to nuke all of those line endings?
# lines = [x.splitlines()[0] for x in lines]
@@ -104,18 +102,15 @@ def parse_patch(text):
yield diffobj(header=h, changes=d, text=difftext)
def parse_header(text):
def parse_header(text: str | list[str]) -> header | None:
h = parse_scm_header(text)
if h is None:
h = parse_diff_header(text)
return h
def parse_scm_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_scm_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
check = [
(git_header_index, parse_git_header),
@@ -154,11 +149,8 @@ def parse_scm_header(text):
return None
def parse_diff_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_diff_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
check = [
(unified_header_new_line, parse_unified_header),
@@ -178,10 +170,10 @@ def parse_diff_header(text):
return None # no header?
def parse_diff(text):
try:
def parse_diff(text: str | list[str]) -> list[Change] | None:
if isinstance(text, str):
lines = text.splitlines()
except AttributeError:
else:
lines = text
check = [
@@ -200,11 +192,8 @@ def parse_diff(text):
return None
def parse_git_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_git_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
old_version = None
new_version = None
@@ -275,11 +264,8 @@ def parse_git_header(text):
return None
def parse_svn_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_svn_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, svn_header_index)
if len(headers) == 0:
@@ -346,11 +332,8 @@ def parse_svn_header(text):
return None
def parse_cvs_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_cvs_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, cvs_header_rcs)
headers_old = findall_regex(lines, old_cvs_diffcmd_header)
@@ -430,11 +413,8 @@ def parse_cvs_header(text):
return None
def parse_diffcmd_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_diffcmd_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, diffcmd_header)
if len(headers) == 0:
@@ -454,11 +434,8 @@ def parse_diffcmd_header(text):
return None
def parse_unified_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_unified_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, unified_header_new_line)
if len(headers) == 0:
@@ -490,11 +467,8 @@ def parse_unified_header(text):
return None
def parse_context_header(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_context_header(text: str | list[str]) -> header | None:
lines = text.splitlines() if isinstance(text, str) else text
headers = findall_regex(lines, context_header_old_line)
if len(headers) == 0:
@@ -526,11 +500,8 @@ def parse_context_header(text):
return None
def parse_default_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_default_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
old = 0
new = 0
@@ -582,11 +553,8 @@ def parse_default_diff(text):
return None
def parse_unified_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_unified_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
old = 0
new = 0
@@ -652,11 +620,8 @@ def parse_unified_diff(text):
return None
def parse_context_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_context_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
old = 0
new = 0
@@ -795,11 +760,8 @@ def parse_context_diff(text):
return None
def parse_ed_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_ed_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
old = 0
j = 0
@@ -878,12 +840,9 @@ def parse_ed_diff(text):
return None
def parse_rcs_ed_diff(text):
def parse_rcs_ed_diff(text: str | list[str]) -> list[Change] | None:
# much like forward ed, but no 'c' type
try:
lines = text.splitlines()
except AttributeError:
lines = text
lines = text.splitlines() if isinstance(text, str) else text
old = 0
j = 0
@@ -905,7 +864,7 @@ def parse_rcs_ed_diff(text):
hunk_kind = o.group(1)
old = int(o.group(2))
size = int(o.group(3))
size = int(o.group(3)) if o.group(3) else 0
if hunk_kind == 'a':
old += total_change_size + 1
@@ -926,15 +885,11 @@ def parse_rcs_ed_diff(text):
if len(changes) > 0:
return changes
return None
def parse_git_binary_diff(text):
try:
lines = text.splitlines()
except AttributeError:
lines = text
def parse_git_binary_diff(text: str | list[str]) -> list[Change] | None:
lines = text.splitlines() if isinstance(text, str) else text
changes: list[Change] = list()

View File

@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
import os
import re
from shutil import rmtree
def remove(path):
def remove(path: str) -> None:
if os.path.exists(path):
if os.path.isdir(path):
rmtree(path)
@@ -13,7 +14,7 @@ def remove(path):
# find all indices of a list of strings that match a regex
def findall_regex(items, regex):
def findall_regex(items: list[str], regex: re.Pattern[str]) -> list[int]:
found = list()
for i in range(0, len(items)):
k = regex.match(items[i])
@@ -24,7 +25,7 @@ def findall_regex(items, regex):
return found
def split_by_regex(items, regex):
def split_by_regex(items: list[str], regex: re.Pattern[str]) -> list[list[str]]:
splits = list()
indices = findall_regex(items, regex)
if not indices:
@@ -45,8 +46,8 @@ def split_by_regex(items, regex):
# http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python
def which(program):
def is_exe(fpath):
def which(program: str) -> str | None:
def is_exe(fpath: str) -> bool:
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
fpath, fname = os.path.split(program)