Compare commits

..

6 Commits

Author SHA1 Message Date
openhands
aa5473ac04 Refactor action-suggestions.tsx to reduce duplication 2025-04-07 15:55:58 +00:00
openhands
32979a4864 Fix 'Push and create PR' button sending wrong query 2025-04-07 15:55:10 +00:00
Boxuan Li
e951da7a25 Fix action execution server JSONResponse (#7721) 2025-04-07 22:49:39 +08:00
Carlos Freund
f830d5814c fix(unittest): Parallel Test failure because of shared memory (#7729)
Co-authored-by: Carlos Freund <carlosfreund@gmail.com>
2025-04-07 09:29:22 -04:00
Carlos Freund
0519e9e3c2 fix(test) test_memory: initialize in fixture with new dict. (#7733)
Co-authored-by: Carlos Freund <carlosfreund@gmail.com>
2025-04-06 23:52:14 +02:00
Graham Neubig
9b8a628395 Add more extensive typing to openhands/llm/ directory (#7727)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-04-06 17:59:25 +00:00
23 changed files with 161 additions and 242 deletions

View File

@@ -11,6 +11,14 @@ interface ActionSuggestionsProps {
onSuggestionsClick: (value: string) => void;
}
// Define button configurations to reduce duplication
interface ButtonConfig {
label: string;
value: string;
eventName: string;
callback?: () => void;
}
export function ActionSuggestions({
onSuggestionsClick,
}: ActionSuggestionsProps) {
@@ -30,16 +38,37 @@ export function ActionSuggestions({
const pr = isGitLab ? "merge request" : "pull request";
const prShort = isGitLab ? "MR" : "PR";
const terms = {
pr,
prShort,
pushToBranch: `Please push the changes to a remote branch on ${
// Define the button configurations
const PUSH_TO_BRANCH: ButtonConfig = {
label: t(I18nKey.ACTION$PUSH_TO_BRANCH),
value: `Please push the changes to a remote branch on ${
isGitLab ? "GitLab" : "GitHub"
}, but do NOT create a ${pr}. Please use the exact SAME branch name as the one you are currently on.`,
createPR: `Please push the changes to ${
eventName: "push_to_branch_button_clicked",
};
const PUSH_AND_CREATE_PR: ButtonConfig = {
label: t(I18nKey.ACTION$PUSH_CREATE_PR),
value: `Please push the changes to ${
isGitLab ? "GitLab" : "GitHub"
} and open a ${pr}. Please create a meaningful branch name that describes the changes. If a ${pr} template exists in the repository, please follow it when creating the ${prShort} description.`,
pushToPR: `Please push the latest changes to the existing ${pr}.`,
eventName: "create_pr_button_clicked",
callback: () => setHasPullRequest(true),
};
const PUSH_TO_PR: ButtonConfig = {
label: t(I18nKey.ACTION$PUSH_CHANGES_TO_PR),
value: `Please push the latest changes to the existing ${pr}.`,
eventName: "push_to_pr_button_clicked",
};
// Helper function to handle button clicks
const handleButtonClick = (config: ButtonConfig) => {
posthog.capture(config.eventName);
onSuggestionsClick(config.value);
if (config.callback) {
config.callback();
}
};
return (
@@ -50,36 +79,26 @@ export function ActionSuggestions({
<>
<SuggestionItem
suggestion={{
label: t(I18nKey.ACTION$PUSH_TO_BRANCH),
value: terms.pushToBranch,
}}
onClick={(value) => {
posthog.capture("push_to_branch_button_clicked");
onSuggestionsClick(value);
label: PUSH_TO_BRANCH.label,
value: PUSH_TO_BRANCH.value,
}}
onClick={() => handleButtonClick(PUSH_TO_BRANCH)}
/>
<SuggestionItem
suggestion={{
label: t(I18nKey.ACTION$PUSH_CREATE_PR),
value: terms.createPR,
}}
onClick={(value) => {
posthog.capture("create_pr_button_clicked");
onSuggestionsClick(value);
setHasPullRequest(true);
label: PUSH_AND_CREATE_PR.label,
value: PUSH_AND_CREATE_PR.value,
}}
onClick={() => handleButtonClick(PUSH_AND_CREATE_PR)}
/>
</>
) : (
<SuggestionItem
suggestion={{
label: t(I18nKey.ACTION$PUSH_CHANGES_TO_PR),
value: terms.pushToPR,
}}
onClick={(value) => {
posthog.capture("push_to_pr_button_clicked");
onSuggestionsClick(value);
label: PUSH_TO_PR.label,
value: PUSH_TO_PR.value,
}}
onClick={() => handleButtonClick(PUSH_TO_PR)}
/>
)}
</div>

View File

@@ -1,6 +1,6 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, Callable
from litellm import acompletion as litellm_acompletion
@@ -17,7 +17,7 @@ from openhands.utils.shutdown_listener import should_continue
class AsyncLLM(LLM):
"""Asynchronous LLM class."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_completion = partial(
@@ -46,7 +46,7 @@ class AsyncLLM(LLM):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_completion_wrapper(*args, **kwargs):
async def async_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for the litellm acompletion function that adds logging and cost tracking."""
messages: list[dict[str, Any]] | dict[str, Any] = []
@@ -77,7 +77,7 @@ class AsyncLLM(LLM):
self.log_prompt(messages)
async def check_stopped():
async def check_stopped() -> None:
while should_continue():
if (
hasattr(self.config, 'on_cancel_requested_fn')
@@ -117,14 +117,14 @@ class AsyncLLM(LLM):
except asyncio.CancelledError:
pass
self._async_completion = async_completion_wrapper # type: ignore
self._async_completion = async_completion_wrapper
async def _call_acompletion(self, *args, **kwargs):
async def _call_acompletion(self, *args: Any, **kwargs: Any) -> Any:
"""Wrapper for the litellm acompletion function."""
# Used in testing?
return await litellm_acompletion(*args, **kwargs)
@property
def async_completion(self):
def async_completion(self) -> Callable:
"""Decorator for the async litellm acompletion function."""
return self._async_completion

View File

@@ -28,5 +28,5 @@ def list_foundation_models(
return []
def remove_error_modelId(model_list):
def remove_error_modelId(model_list: list[str]) -> list[str]:
return list(filter(lambda m: not m.startswith('bedrock'), model_list))

View File

@@ -7,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'
class DebugMixin:
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
if not messages:
logger.debug('No completion messages!')
return
@@ -24,11 +24,11 @@ class DebugMixin:
else:
logger.debug('No completion messages!')
def log_response(self, message_back: str):
def log_response(self, message_back: str) -> None:
if message_back:
llm_response_logger.debug(message_back)
def _format_message_content(self, message: dict[str, Any]):
def _format_message_content(self, message: dict[str, Any]) -> str:
content = message['content']
if isinstance(content, list):
return '\n'.join(
@@ -36,18 +36,18 @@ class DebugMixin:
)
return str(content)
def _format_content_element(self, element: dict[str, Any]):
def _format_content_element(self, element: dict[str, Any] | Any) -> str:
if isinstance(element, dict):
if 'text' in element:
return element['text']
return str(element['text'])
if (
self.vision_is_active()
and 'image_url' in element
and 'url' in element['image_url']
):
return element['image_url']['url']
return str(element['image_url']['url'])
return str(element)
# This method should be implemented in the class that uses DebugMixin
def vision_is_active(self):
def vision_is_active(self) -> bool:
raise NotImplementedError

View File

@@ -186,7 +186,7 @@ class LLM(RetryMixin, DebugMixin):
retry_multiplier=self.config.retry_multiplier,
retry_listener=self.retry_listener,
)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.io import json
@@ -355,14 +355,14 @@ class LLM(RetryMixin, DebugMixin):
self._completion = wrapper
@property
def completion(self):
def completion(self) -> Callable:
"""Decorator for the litellm completion function.
Check the complete documentation at https://litellm.vercel.app/docs/completion
"""
return self._completion
def init_model_info(self):
def init_model_info(self) -> None:
if self._tried_model_info:
return
self._tried_model_info = True
@@ -622,10 +622,12 @@ class LLM(RetryMixin, DebugMixin):
# try to get the token count with the default litellm tokenizers
# or the custom tokenizer if set for this LLM configuration
try:
return litellm.token_counter(
model=self.config.model,
messages=messages,
custom_tokenizer=self.tokenizer,
return int(
litellm.token_counter(
model=self.config.model,
messages=messages,
custom_tokenizer=self.tokenizer,
)
)
except Exception as e:
# limit logspam in case token count is not supported
@@ -654,7 +656,7 @@ class LLM(RetryMixin, DebugMixin):
return True
return False
def _completion_cost(self, response) -> float:
def _completion_cost(self, response: Any) -> float:
"""Calculate completion cost and update metrics with running total.
Calculate the cost of a completion response based on the model. Local models are treated as free.
@@ -707,21 +709,21 @@ class LLM(RetryMixin, DebugMixin):
logger.debug(
f'Using fallback model name {_model_name} to get cost: {cost}'
)
self.metrics.add_cost(cost)
return cost
self.metrics.add_cost(float(cost))
return float(cost)
except Exception:
self.cost_metric_supported = False
logger.debug('Cost calculation not supported for this model.')
return 0.0
def __str__(self):
def __str__(self) -> str:
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'
def __repr__(self):
def __repr__(self) -> str:
return str(self)
def reset(self) -> None:

View File

@@ -177,7 +177,7 @@ class Metrics:
'token_usages': [usage.model_dump() for usage in self._token_usages],
}
def reset(self):
def reset(self) -> None:
self._accumulated_cost = 0.0
self._costs = []
self._response_latencies = []
@@ -192,7 +192,7 @@ class Metrics:
response_id='',
)
def log(self):
def log(self) -> str:
"""Log the metrics."""
metrics = self.get()
logs = ''
@@ -200,5 +200,5 @@ class Metrics:
logs += f'{key}: {value}\n'
return logs
def __repr__(self):
def __repr__(self) -> str:
return f'Metrics({self.get()}'

View File

@@ -1,3 +1,5 @@
from typing import Any, Callable
from tenacity import (
retry,
retry_if_exception_type,
@@ -13,7 +15,7 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
class RetryMixin:
"""Mixin class for retry logic."""
def retry_decorator(self, **kwargs):
def retry_decorator(self, **kwargs: Any) -> Callable:
"""
Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes.
@@ -31,7 +33,7 @@ class RetryMixin:
retry_multiplier = kwargs.get('retry_multiplier')
retry_listener = kwargs.get('retry_listener')
def before_sleep(retry_state):
def before_sleep(retry_state: Any) -> None:
self.log_retry_attempt(retry_state)
if retry_listener:
retry_listener(retry_state.attempt_number, num_retries)
@@ -52,7 +54,7 @@ class RetryMixin:
f'LLMNoResponseError detected with temperature={current_temp}, keeping original temperature'
)
return retry(
retry_decorator: Callable = retry(
before_sleep=before_sleep,
stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
reraise=True,
@@ -65,8 +67,9 @@ class RetryMixin:
max=retry_max_wait,
),
)
return retry_decorator
def log_retry_attempt(self, retry_state):
def log_retry_attempt(self, retry_state: Any) -> None:
"""Log retry attempts."""
exception = retry_state.outcome.exception()
logger.error(

View File

@@ -1,6 +1,6 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, Callable
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
@@ -11,7 +11,7 @@ from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS
class StreamingLLM(AsyncLLM):
"""Streaming LLM class."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_streaming_completion = partial(
@@ -40,7 +40,7 @@ class StreamingLLM(AsyncLLM):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_streaming_completion_wrapper(*args, **kwargs):
async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
@@ -108,6 +108,6 @@ class StreamingLLM(AsyncLLM):
self._async_streaming_completion = async_streaming_completion_wrapper
@property
def async_streaming_completion(self):
def async_streaming_completion(self) -> Callable:
"""Decorator for the async litellm acompletion function with streaming."""
return self._async_streaming_completion

View File

@@ -585,7 +585,10 @@ if __name__ == '__main__':
logger.error(f'Validation error occurred: {exc}')
return JSONResponse(
status_code=422,
content={'detail': 'Invalid request parameters', 'errors': exc.errors()},
content={
'detail': 'Invalid request parameters',
'errors': str(exc.errors()),
},
)
@app.middleware('http')

View File

@@ -31,81 +31,37 @@ async def browse(
try:
# obs provided by BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/env.py#L396
obs = await call_sync_from_async(browser.step, action_str)
# Extract values with type checking
text_content = obs.get('text_content', '')
if not isinstance(text_content, str):
text_content = str(text_content)
url = obs.get('url', '')
if not isinstance(url, str):
url = str(url)
image_content = obs.get('image_content', [])
if not isinstance(image_content, list):
image_content = []
open_pages_urls = obs.get('open_pages_urls', [])
if not isinstance(open_pages_urls, list):
open_pages_urls = []
active_page_index = obs.get('active_page_index', -1)
if not isinstance(active_page_index, int):
try:
active_page_index = int(active_page_index)
except (ValueError, TypeError):
active_page_index = -1
dom_object = obs.get('dom_object', {})
if not isinstance(dom_object, dict):
dom_object = {}
axtree_object = obs.get('axtree_object', {})
if not isinstance(axtree_object, dict):
axtree_object = {}
extra_element_properties = obs.get('extra_element_properties', {})
if not isinstance(extra_element_properties, dict):
extra_element_properties = {}
last_action = obs.get('last_action', '')
if not isinstance(last_action, str):
last_action = str(last_action)
last_action_error = obs.get('last_action_error', '')
if not isinstance(last_action_error, str):
last_action_error = str(last_action_error)
return BrowserOutputObservation(
content=text_content, # text content of the page
url=url, # URL of the page
content=obs['text_content'], # text content of the page
url=obs.get('url', ''), # URL of the page
screenshot=obs.get('screenshot', None), # base64-encoded screenshot, png
set_of_marks=obs.get(
'set_of_marks', None
), # base64-encoded Set-of-Marks annotated screenshot, png,
goal_image_urls=image_content,
open_pages_urls=open_pages_urls, # list of open pages
active_page_index=active_page_index, # index of the active page
dom_object=dom_object, # DOM object
axtree_object=axtree_object, # accessibility tree object
extra_element_properties=extra_element_properties,
goal_image_urls=obs.get('image_content', []),
open_pages_urls=obs.get('open_pages_urls', []), # list of open pages
active_page_index=obs.get(
'active_page_index', -1
), # index of the active page
dom_object=obs.get('dom_object', {}), # DOM object
axtree_object=obs.get('axtree_object', {}), # accessibility tree object
extra_element_properties=obs.get('extra_element_properties', {}),
focused_element_bid=obs.get(
'focused_element_bid', None
), # focused element bid
last_browser_action=last_action, # last browser env action performed
last_browser_action_error=last_action_error,
error=bool(last_action_error), # error flag
last_browser_action=obs.get(
'last_action', ''
), # last browser env action performed
last_browser_action_error=obs.get('last_action_error', ''),
error=True if obs.get('last_action_error', '') else False, # error flag
trigger_by_action=action.action,
)
except Exception as e:
error_message = str(e)
url_value = asked_url if action.action == ActionType.BROWSE else ''
return BrowserOutputObservation(
content=error_message,
content=str(e),
screenshot='',
error=True,
last_browser_action_error=error_message,
url=url_value,
last_browser_action_error=str(e),
url=asked_url if action.action == ActionType.BROWSE else '',
trigger_by_action=action.action,
)

View File

@@ -184,8 +184,6 @@ class DaytonaRuntime(ActionExecutionClient):
self.api_url = self._construct_api_url(self._sandbox_port)
# Ensure workspace is not None before accessing its attributes
assert self.workspace is not None, 'Workspace should not be None at this point'
state = self.workspace.instance.state
if state == 'stopping':

View File

@@ -33,14 +33,6 @@ from openhands.server.file_config import (
)
from openhands.utils.async_utils import call_sync_from_async
def _write_to_file(file_path: str, contents: bytes) -> None:
"""Helper function to write contents to a file."""
with open(file_path, 'wb') as file:
file.write(contents)
file.flush()
app = APIRouter(prefix='/api/conversations/{conversation_id}')
@@ -203,8 +195,9 @@ async def upload_file(request: Request, conversation_id: str, files: list[Upload
# copy the file to the runtime
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_file_path = os.path.join(tmp_dir, safe_filename)
# Use a helper function to write to the file
await call_sync_from_async(_write_to_file, tmp_file_path, file_contents)
with open(tmp_file_path, 'wb') as tmp_file:
tmp_file.write(file_contents)
tmp_file.flush()
runtime: Runtime = request.state.conversation.runtime
try:

View File

@@ -112,9 +112,7 @@ async def _create_new_conversation(
title=conversation_title,
user_id=user_id,
github_user_id=None,
selected_repository=selected_repository.full_name
if selected_repository
else selected_repository,
selected_repository=selected_repository.full_name if selected_repository else selected_repository,
selected_branch=selected_branch,
)
)
@@ -382,7 +380,7 @@ async def delete_conversation(
async def _get_conversation_info(
conversation: ConversationMetadata,
is_running: bool,
) -> ConversationInfo:
) -> ConversationInfo | None:
try:
title = conversation.title
if not title:
@@ -402,12 +400,4 @@ async def _get_conversation_info(
f'Error loading conversation {conversation.conversation_id}: {str(e)}',
extra={'session_id': conversation.conversation_id},
)
# Create a default ConversationInfo object instead of returning None
return ConversationInfo(
conversation_id=conversation.conversation_id,
title=get_default_conversation_title(conversation.conversation_id),
last_updated_at=conversation.last_updated_at,
created_at=conversation.created_at,
selected_repository='',
status=ConversationStatus.STOPPED,
)
return None

View File

@@ -15,13 +15,6 @@ from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.llm import bedrock
from openhands.server.shared import config, server_config
from openhands.utils.async_utils import call_sync_from_async
def _get_ollama_models(url: str) -> list:
"""Helper function to get Ollama models."""
return httpx.get(url, timeout=3).json()['models']
app = APIRouter(prefix='/api/options')
@@ -67,9 +60,7 @@ async def get_litellm_models() -> list[str]:
if ollama_base_url:
ollama_url = ollama_base_url.strip('/') + '/api/tags'
try:
ollama_models_list = await call_sync_from_async(
_get_ollama_models, ollama_url
)
ollama_models_list = httpx.get(ollama_url, timeout=3).json()['models']
for model in ollama_models_list:
model_list.append('ollama/' + model['name'])
break

View File

@@ -3,14 +3,14 @@ import os
from openhands.core.logger import openhands_logger as logger
from openhands.storage.files import FileStore
IN_MEMORY_FILES: dict = {}
class InMemoryFileStore(FileStore):
files: dict[str, str]
def __init__(self, files: dict[str, str] = IN_MEMORY_FILES):
self.files = files
def __init__(self, files: dict[str, str] | None = None) -> None:
self.files = {}
if files is not None:
self.files = files
def write(self, path: str, contents: str | bytes) -> None:
if isinstance(contents, bytes):

View File

@@ -1,22 +1,17 @@
import asyncio
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Coroutine, Iterable, List, TypeVar
T = TypeVar('T') # Return type of the async function
R = TypeVar('R') # Return type of the sync function
from typing import Callable, Coroutine, Iterable, List
GENERAL_TIMEOUT: int = 15
EXECUTOR = ThreadPoolExecutor()
async def call_sync_from_async(fn: Callable[..., R], *args: Any, **kwargs: Any) -> R:
async def call_sync_from_async(fn: Callable, *args, **kwargs):
"""
Shorthand for running a function in the default background thread pool executor
and awaiting the result. The nature of synchronous code is that the future
returned by this function is not cancellable.
Preserves the return type of the original function.
returned by this function is not cancellable
"""
loop = asyncio.get_event_loop()
coro = loop.run_in_executor(None, lambda: fn(*args, **kwargs))
@@ -25,28 +20,24 @@ async def call_sync_from_async(fn: Callable[..., R], *args: Any, **kwargs: Any)
def call_async_from_sync(
corofn: Callable[..., Coroutine[Any, Any, T]],
timeout: float = GENERAL_TIMEOUT,
*args: Any,
**kwargs: Any,
) -> T:
corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
):
"""
Shorthand for running a coroutine in the default background thread pool executor
and awaiting the result.
Preserves the return type of the original coroutine function.
and awaiting the result
"""
if corofn is None:
raise ValueError('corofn is None')
if not asyncio.iscoroutinefunction(corofn):
raise ValueError('corofn is not a coroutine function')
async def arun() -> T:
async def arun():
coro = corofn(*args, **kwargs)
result = await coro
return result
def run() -> T:
def run():
loop_for_thread = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop_for_thread)
@@ -61,31 +52,20 @@ def call_async_from_sync(
async def call_coro_in_bg_thread(
corofn: Callable[..., Coroutine[Any, Any, T]],
timeout: float = GENERAL_TIMEOUT,
*args: Any,
**kwargs: Any,
) -> T:
"""
Function for running a coroutine in a background thread.
Preserves the return type of the original coroutine function.
"""
return await call_sync_from_async(
call_async_from_sync, corofn, timeout, *args, **kwargs
)
corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
):
"""Function for running a coroutine in a background thread."""
await call_sync_from_async(call_async_from_sync, corofn, timeout, *args, **kwargs)
async def wait_all(
iterable: Iterable[Coroutine[Any, Any, T]], timeout: int = GENERAL_TIMEOUT
) -> List[T]:
iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT
) -> List:
"""
Shorthand for waiting for all the coroutines in the iterable given in parallel. Creates
a task for each coroutine.
Returns a list of results in the original order. If any single task raised an exception, this is raised.
If multiple tasks raised exceptions, an AsyncException is raised containing all exceptions.
Preserves the return type of the original coroutines.
"""
tasks = [asyncio.create_task(c) for c in iterable]
if not tasks:
@@ -110,8 +90,8 @@ async def wait_all(
class AsyncException(Exception):
def __init__(self, exceptions: list[Exception]) -> None:
def __init__(self, exceptions):
self.exceptions = exceptions
def __str__(self) -> str:
def __str__(self):
return '\n'.join(str(e) for e in self.exceptions)

View File

@@ -25,7 +25,7 @@ class Chunk(BaseModel):
return ret
def _create_chunks_from_raw_string(content: str, size: int) -> list[Chunk]:
def _create_chunks_from_raw_string(content: str, size: int):
lines = content.split('\n')
ret = []
for i in range(0, len(lines), size):
@@ -65,7 +65,7 @@ def normalized_lcs(chunk: str, query: str) -> float:
"""
if len(chunk) == 0:
return 0.0
_score = float(pylcs.lcs_sequence_length(chunk, query))
_score = pylcs.lcs_sequence_length(chunk, query)
return _score / len(chunk)

View File

@@ -15,17 +15,15 @@ Hopefully, this will be fixed soon and we can remove this abomination.
"""
import contextlib
from typing import Any, Callable, Iterator, TypeVar
from typing import Callable
import httpx
T = TypeVar('T')
@contextlib.contextmanager
def ensure_httpx_close() -> Iterator[None]:
def ensure_httpx_close():
wrapped_class = httpx.Client
proxys: list[Any] = []
proxys = []
class ClientProxy:
"""
@@ -34,24 +32,24 @@ def ensure_httpx_close() -> Iterator[None]:
where a client is reused, we need to be able to reuse the client even after closing it.
"""
client_constructor: Callable[..., Any]
args: tuple[Any, ...]
kwargs: dict[str, Any]
client: httpx.Client | None
client_constructor: Callable
args: tuple
kwargs: dict
client: httpx.Client
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.client = wrapped_class(*self.args, **self.kwargs)
proxys.append(self)
def __getattr__(self, name: str) -> Any:
def __getattr__(self, name):
# Invoke a method on the proxied client - create one if required
if self.client is None:
self.client = wrapped_class(*self.args, **self.kwargs)
return getattr(self.client, name)
def close(self) -> None:
def close(self):
# Close the client if it is open
if self.client:
self.client.close()
@@ -64,21 +62,17 @@ def ensure_httpx_close() -> Iterator[None]:
return object.__getattribute__(self, 'iter')(*args, **kwargs)
@property
def is_closed(self) -> bool:
def is_closed(self):
# Check if closed
if self.client is None:
return True
# Convert to bool to ensure we return a bool
return bool(self.client.is_closed)
return self.client.is_closed
# We need to monkey patch the Client class to track instances
# This is a hack until LiteLLM fixes their client lifecycle management
original_client = httpx.Client
httpx.Client = ClientProxy
try:
yield
finally:
httpx.Client = original_client
httpx.Client = wrapped_class
while proxys:
proxy = proxys.pop()
proxy.close()

View File

@@ -5,15 +5,12 @@ from typing import Type, TypeVar
T = TypeVar('T')
def import_from(qual_name: str) -> type:
def import_from(qual_name: str):
"""Import the value from the qualified name given"""
parts = qual_name.split('.')
module_name = '.'.join(parts[:-1])
module = importlib.import_module(module_name)
result = getattr(module, parts[-1])
assert isinstance(
result, type
), f'Expected {qual_name} to be a type, got {type(result)}'
return result

View File

@@ -1,5 +1,5 @@
import base64
from typing import Any, AsyncIterator, Callable
from typing import AsyncIterator, Callable
def offset_to_page_id(offset: int, has_next: bool) -> str | None:
@@ -16,7 +16,7 @@ def page_id_to_offset(page_id: str | None) -> int:
return offset
async def iterate(fn: Callable[..., Any], **kwargs: Any) -> AsyncIterator[Any]:
async def iterate(fn: Callable, **kwargs) -> AsyncIterator:
"""Iterate over paged result sets. Assumes that the results sets contain an array of result objects, and a next_page_id"""
kwargs = {**kwargs}
kwargs['page_id'] = None

View File

@@ -22,7 +22,4 @@ def colorize(text: str, color: TermColor = TermColor.WARNING) -> str:
Returns:
str: Colored text
"""
# colored() returns a string with ANSI color codes
result = colored(text, color.value)
assert isinstance(result, str)
return result
return colored(text, color.value)

View File

@@ -59,7 +59,6 @@ async def test_agent_session_start_with_no_state(mock_agent):
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)
mock_runtime.get_microagents_from_selected_repo.return_value = []
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
@@ -143,7 +142,6 @@ async def test_agent_session_start_with_restored_state(mock_agent):
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)
mock_runtime.get_microagents_from_selected_repo.return_value = []
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):

View File

@@ -28,7 +28,7 @@ from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def file_store():
"""Create a temporary file store for testing."""
return InMemoryFileStore()
return InMemoryFileStore({})
@pytest.fixture
@@ -190,10 +190,9 @@ async def test_memory_with_microagents():
assert 'magic word' in observation.microagent_knowledge[0].content
def test_memory_repository_info(prompt_dir):
def test_memory_repository_info(prompt_dir, file_store):
"""Test that Memory adds repository info to RecallObservations."""
# Create an in-memory file store and real event stream
file_store = InMemoryFileStore()
# real event stream
event_stream = EventStream(sid='test-session', file_store=file_store)
# Create a test repo microagent first
@@ -321,10 +320,9 @@ async def test_memory_with_agent_microagents():
assert 'magic word' in observation.microagent_knowledge[0].content
def test_memory_multiple_repo_microagents(prompt_dir):
def test_memory_multiple_repo_microagents(prompt_dir, file_store):
"""Test that Memory loads and concatenates multiple repo microagents correctly."""
# Create an in-memory file store and real event stream
file_store = InMemoryFileStore()
# Create real event stream
event_stream = EventStream(sid='test-session', file_store=file_store)
# Create two test repo microagents