Compare commits

...

1 Commits

Author SHA1 Message Date
openhands
a0eb4f7f0a 🔒 Add type hints to utils modules 2025-03-25 12:23:37 +00:00
6 changed files with 55 additions and 32 deletions

View File

@@ -7,7 +7,9 @@ GENERAL_TIMEOUT: int = 15
EXECUTOR = ThreadPoolExecutor()
async def call_sync_from_async(fn: Callable, *args, **kwargs):
async def call_sync_from_async(
fn: Callable[..., object], *args: object, **kwargs: object
) -> object:
"""
Shorthand for running a function in the default background thread pool executor
and awaiting the result. The nature of synchronous code is that the future
@@ -20,8 +22,11 @@ async def call_sync_from_async(fn: Callable, *args, **kwargs):
def call_async_from_sync(
corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
):
corofn: Callable[..., Coroutine[object, object, object]],
timeout: float = GENERAL_TIMEOUT,
*args: object,
**kwargs: object,
) -> object:
"""
Shorthand for running a coroutine in the default background thread pool executor
and awaiting the result
@@ -32,12 +37,12 @@ def call_async_from_sync(
if not asyncio.iscoroutinefunction(corofn):
raise ValueError('corofn is not a coroutine function')
async def arun():
async def arun() -> object:
coro = corofn(*args, **kwargs)
result = await coro
return result
def run():
def run() -> object:
loop_for_thread = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop_for_thread)
@@ -52,10 +57,15 @@ def call_async_from_sync(
async def call_coro_in_bg_thread(
corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
):
corofn: Callable[..., Coroutine[object, object, object]],
timeout: float = GENERAL_TIMEOUT,
*args: object,
**kwargs: object,
) -> object:
"""Function for running a coroutine in a background thread."""
await call_sync_from_async(call_async_from_sync, corofn, timeout, *args, **kwargs)
return await call_sync_from_async(
call_async_from_sync, corofn, timeout, *args, **kwargs
)
async def wait_all(
@@ -90,8 +100,8 @@ async def wait_all(
class AsyncException(Exception):
def __init__(self, exceptions):
def __init__(self, exceptions: list[Exception]) -> None:
self.exceptions = exceptions
def __str__(self):
def __str__(self) -> str:
return '\n'.join(str(e) for e in self.exceptions)

View File

@@ -25,7 +25,7 @@ class Chunk(BaseModel):
return ret
def _create_chunks_from_raw_string(content: str, size: int):
def _create_chunks_from_raw_string(content: str, size: int) -> list[Chunk]:
lines = content.split('\n')
ret = []
for i in range(0, len(lines), size):
@@ -65,7 +65,7 @@ def normalized_lcs(chunk: str, query: str) -> float:
"""
if len(chunk) == 0:
return 0.0
_score = pylcs.lcs_sequence_length(chunk, query)
_score = float(pylcs.lcs_sequence_length(chunk, query))
return _score / len(chunk)

View File

@@ -15,15 +15,17 @@ Hopefully, this will be fixed soon and we can remove this abomination.
"""
import contextlib
from typing import Callable
from typing import Any, Callable, Iterator, TypeVar
import httpx
T = TypeVar('T')
@contextlib.contextmanager
def ensure_httpx_close():
def ensure_httpx_close() -> Iterator[None]:
wrapped_class = httpx.Client
proxys = []
proxys: list[Any] = []
class ClientProxy:
"""
@@ -32,47 +34,52 @@ def ensure_httpx_close():
where a client is reused, we need to be able to reuse the client even after closing it.
"""
client_constructor: Callable
args: tuple
kwargs: dict
client: httpx.Client
client_constructor: Callable[..., Any]
args: tuple[Any, ...]
kwargs: dict[str, Any]
client: httpx.Client | None
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.args = args
self.kwargs = kwargs
self.client = wrapped_class(*self.args, **self.kwargs)
proxys.append(self)
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
# Invoke a method on the proxied client - create one if required
if self.client is None:
self.client = wrapped_class(*self.args, **self.kwargs)
return getattr(self.client, name)
def close(self):
def close(self) -> None:
# Close the client if it is open
if self.client:
self.client.close()
self.client = None
def __iter__(self, *args, **kwargs):
def __iter__(self) -> Iterator[Any]:
# We have to override this as debuggers invoke it causing the client to reopen
if self.client:
return self.client.iter(*args, **kwargs)
return object.__getattribute__(self, 'iter')(*args, **kwargs)
# Convert client to list first since it's not directly iterable
return iter(list(self.client.__dict__.items()))
return iter([])
@property
def is_closed(self):
def is_closed(self) -> bool:
# Check if closed
if self.client is None:
return True
return self.client.is_closed
# Convert to bool to ensure we return a bool
return bool(self.client.is_closed)
# We need to monkey patch the Client class to track instances
# This is a hack until LiteLLM fixes their client lifecycle management
original_client = httpx.Client
httpx.Client = ClientProxy
try:
yield
finally:
httpx.Client = wrapped_class
httpx.Client = original_client
while proxys:
proxy = proxys.pop()
proxy.close()

View File

@@ -4,12 +4,15 @@ from typing import Type, TypeVar
T = TypeVar('T')
def import_from(qual_name: str):
def import_from(qual_name: str) -> type:
"""Import the value from the qualified name given"""
parts = qual_name.split('.')
module_name = '.'.join(parts[:-1])
module = importlib.import_module(module_name)
result = getattr(module, parts[-1])
assert isinstance(
result, type
), f'Expected {qual_name} to be a type, got {type(result)}'
return result

View File

@@ -1,5 +1,5 @@
import base64
from typing import AsyncIterator, Callable
from typing import Any, AsyncIterator, Callable
def offset_to_page_id(offset: int, has_next: bool) -> str | None:
@@ -16,7 +16,7 @@ def page_id_to_offset(page_id: str | None) -> int:
return offset
async def iterate(fn: Callable, **kwargs) -> AsyncIterator:
async def iterate(fn: Callable[..., Any], **kwargs: Any) -> AsyncIterator[Any]:
"""Iterate over paged result sets. Assumes that the results sets contain an array of result objects, and a next_page_id"""
kwargs = {**kwargs}
kwargs['page_id'] = None

View File

@@ -22,4 +22,7 @@ def colorize(text: str, color: TermColor = TermColor.WARNING) -> str:
Returns:
str: Colored text
"""
return colored(text, color.value)
# colored() returns a string with ANSI color codes
result = colored(text, color.value)
assert isinstance(result, str)
return result