mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
6 Commits
fix/utils-
...
fix-push-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa5473ac04 | ||
|
|
32979a4864 | ||
|
|
e951da7a25 | ||
|
|
f830d5814c | ||
|
|
0519e9e3c2 | ||
|
|
9b8a628395 |
@@ -11,6 +11,14 @@ interface ActionSuggestionsProps {
|
||||
onSuggestionsClick: (value: string) => void;
|
||||
}
|
||||
|
||||
// Define button configurations to reduce duplication
|
||||
interface ButtonConfig {
|
||||
label: string;
|
||||
value: string;
|
||||
eventName: string;
|
||||
callback?: () => void;
|
||||
}
|
||||
|
||||
export function ActionSuggestions({
|
||||
onSuggestionsClick,
|
||||
}: ActionSuggestionsProps) {
|
||||
@@ -30,16 +38,37 @@ export function ActionSuggestions({
|
||||
const pr = isGitLab ? "merge request" : "pull request";
|
||||
const prShort = isGitLab ? "MR" : "PR";
|
||||
|
||||
const terms = {
|
||||
pr,
|
||||
prShort,
|
||||
pushToBranch: `Please push the changes to a remote branch on ${
|
||||
// Define the button configurations
|
||||
const PUSH_TO_BRANCH: ButtonConfig = {
|
||||
label: t(I18nKey.ACTION$PUSH_TO_BRANCH),
|
||||
value: `Please push the changes to a remote branch on ${
|
||||
isGitLab ? "GitLab" : "GitHub"
|
||||
}, but do NOT create a ${pr}. Please use the exact SAME branch name as the one you are currently on.`,
|
||||
createPR: `Please push the changes to ${
|
||||
eventName: "push_to_branch_button_clicked",
|
||||
};
|
||||
|
||||
const PUSH_AND_CREATE_PR: ButtonConfig = {
|
||||
label: t(I18nKey.ACTION$PUSH_CREATE_PR),
|
||||
value: `Please push the changes to ${
|
||||
isGitLab ? "GitLab" : "GitHub"
|
||||
} and open a ${pr}. Please create a meaningful branch name that describes the changes. If a ${pr} template exists in the repository, please follow it when creating the ${prShort} description.`,
|
||||
pushToPR: `Please push the latest changes to the existing ${pr}.`,
|
||||
eventName: "create_pr_button_clicked",
|
||||
callback: () => setHasPullRequest(true),
|
||||
};
|
||||
|
||||
const PUSH_TO_PR: ButtonConfig = {
|
||||
label: t(I18nKey.ACTION$PUSH_CHANGES_TO_PR),
|
||||
value: `Please push the latest changes to the existing ${pr}.`,
|
||||
eventName: "push_to_pr_button_clicked",
|
||||
};
|
||||
|
||||
// Helper function to handle button clicks
|
||||
const handleButtonClick = (config: ButtonConfig) => {
|
||||
posthog.capture(config.eventName);
|
||||
onSuggestionsClick(config.value);
|
||||
if (config.callback) {
|
||||
config.callback();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
@@ -50,36 +79,26 @@ export function ActionSuggestions({
|
||||
<>
|
||||
<SuggestionItem
|
||||
suggestion={{
|
||||
label: t(I18nKey.ACTION$PUSH_TO_BRANCH),
|
||||
value: terms.pushToBranch,
|
||||
}}
|
||||
onClick={(value) => {
|
||||
posthog.capture("push_to_branch_button_clicked");
|
||||
onSuggestionsClick(value);
|
||||
label: PUSH_TO_BRANCH.label,
|
||||
value: PUSH_TO_BRANCH.value,
|
||||
}}
|
||||
onClick={() => handleButtonClick(PUSH_TO_BRANCH)}
|
||||
/>
|
||||
<SuggestionItem
|
||||
suggestion={{
|
||||
label: t(I18nKey.ACTION$PUSH_CREATE_PR),
|
||||
value: terms.createPR,
|
||||
}}
|
||||
onClick={(value) => {
|
||||
posthog.capture("create_pr_button_clicked");
|
||||
onSuggestionsClick(value);
|
||||
setHasPullRequest(true);
|
||||
label: PUSH_AND_CREATE_PR.label,
|
||||
value: PUSH_AND_CREATE_PR.value,
|
||||
}}
|
||||
onClick={() => handleButtonClick(PUSH_AND_CREATE_PR)}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<SuggestionItem
|
||||
suggestion={{
|
||||
label: t(I18nKey.ACTION$PUSH_CHANGES_TO_PR),
|
||||
value: terms.pushToPR,
|
||||
}}
|
||||
onClick={(value) => {
|
||||
posthog.capture("push_to_pr_button_clicked");
|
||||
onSuggestionsClick(value);
|
||||
label: PUSH_TO_PR.label,
|
||||
value: PUSH_TO_PR.value,
|
||||
}}
|
||||
onClick={() => handleButtonClick(PUSH_TO_PR)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from litellm import acompletion as litellm_acompletion
|
||||
|
||||
@@ -17,7 +17,7 @@ from openhands.utils.shutdown_listener import should_continue
|
||||
class AsyncLLM(LLM):
|
||||
"""Asynchronous LLM class."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._async_completion = partial(
|
||||
@@ -46,7 +46,7 @@ class AsyncLLM(LLM):
|
||||
retry_max_wait=self.config.retry_max_wait,
|
||||
retry_multiplier=self.config.retry_multiplier,
|
||||
)
|
||||
async def async_completion_wrapper(*args, **kwargs):
|
||||
async def async_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Wrapper for the litellm acompletion function that adds logging and cost tracking."""
|
||||
messages: list[dict[str, Any]] | dict[str, Any] = []
|
||||
|
||||
@@ -77,7 +77,7 @@ class AsyncLLM(LLM):
|
||||
|
||||
self.log_prompt(messages)
|
||||
|
||||
async def check_stopped():
|
||||
async def check_stopped() -> None:
|
||||
while should_continue():
|
||||
if (
|
||||
hasattr(self.config, 'on_cancel_requested_fn')
|
||||
@@ -117,14 +117,14 @@ class AsyncLLM(LLM):
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._async_completion = async_completion_wrapper # type: ignore
|
||||
self._async_completion = async_completion_wrapper
|
||||
|
||||
async def _call_acompletion(self, *args, **kwargs):
|
||||
async def _call_acompletion(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Wrapper for the litellm acompletion function."""
|
||||
# Used in testing?
|
||||
return await litellm_acompletion(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def async_completion(self):
|
||||
def async_completion(self) -> Callable:
|
||||
"""Decorator for the async litellm acompletion function."""
|
||||
return self._async_completion
|
||||
|
||||
@@ -28,5 +28,5 @@ def list_foundation_models(
|
||||
return []
|
||||
|
||||
|
||||
def remove_error_modelId(model_list):
|
||||
def remove_error_modelId(model_list: list[str]) -> list[str]:
|
||||
return list(filter(lambda m: not m.startswith('bedrock'), model_list))
|
||||
|
||||
@@ -7,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'
|
||||
|
||||
|
||||
class DebugMixin:
|
||||
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
|
||||
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
|
||||
if not messages:
|
||||
logger.debug('No completion messages!')
|
||||
return
|
||||
@@ -24,11 +24,11 @@ class DebugMixin:
|
||||
else:
|
||||
logger.debug('No completion messages!')
|
||||
|
||||
def log_response(self, message_back: str):
|
||||
def log_response(self, message_back: str) -> None:
|
||||
if message_back:
|
||||
llm_response_logger.debug(message_back)
|
||||
|
||||
def _format_message_content(self, message: dict[str, Any]):
|
||||
def _format_message_content(self, message: dict[str, Any]) -> str:
|
||||
content = message['content']
|
||||
if isinstance(content, list):
|
||||
return '\n'.join(
|
||||
@@ -36,18 +36,18 @@ class DebugMixin:
|
||||
)
|
||||
return str(content)
|
||||
|
||||
def _format_content_element(self, element: dict[str, Any]):
|
||||
def _format_content_element(self, element: dict[str, Any] | Any) -> str:
|
||||
if isinstance(element, dict):
|
||||
if 'text' in element:
|
||||
return element['text']
|
||||
return str(element['text'])
|
||||
if (
|
||||
self.vision_is_active()
|
||||
and 'image_url' in element
|
||||
and 'url' in element['image_url']
|
||||
):
|
||||
return element['image_url']['url']
|
||||
return str(element['image_url']['url'])
|
||||
return str(element)
|
||||
|
||||
# This method should be implemented in the class that uses DebugMixin
|
||||
def vision_is_active(self):
|
||||
def vision_is_active(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -186,7 +186,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
retry_multiplier=self.config.retry_multiplier,
|
||||
retry_listener=self.retry_listener,
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
|
||||
from openhands.io import json
|
||||
|
||||
@@ -355,14 +355,14 @@ class LLM(RetryMixin, DebugMixin):
|
||||
self._completion = wrapper
|
||||
|
||||
@property
|
||||
def completion(self):
|
||||
def completion(self) -> Callable:
|
||||
"""Decorator for the litellm completion function.
|
||||
|
||||
Check the complete documentation at https://litellm.vercel.app/docs/completion
|
||||
"""
|
||||
return self._completion
|
||||
|
||||
def init_model_info(self):
|
||||
def init_model_info(self) -> None:
|
||||
if self._tried_model_info:
|
||||
return
|
||||
self._tried_model_info = True
|
||||
@@ -622,10 +622,12 @@ class LLM(RetryMixin, DebugMixin):
|
||||
# try to get the token count with the default litellm tokenizers
|
||||
# or the custom tokenizer if set for this LLM configuration
|
||||
try:
|
||||
return litellm.token_counter(
|
||||
model=self.config.model,
|
||||
messages=messages,
|
||||
custom_tokenizer=self.tokenizer,
|
||||
return int(
|
||||
litellm.token_counter(
|
||||
model=self.config.model,
|
||||
messages=messages,
|
||||
custom_tokenizer=self.tokenizer,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# limit logspam in case token count is not supported
|
||||
@@ -654,7 +656,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _completion_cost(self, response) -> float:
|
||||
def _completion_cost(self, response: Any) -> float:
|
||||
"""Calculate completion cost and update metrics with running total.
|
||||
|
||||
Calculate the cost of a completion response based on the model. Local models are treated as free.
|
||||
@@ -707,21 +709,21 @@ class LLM(RetryMixin, DebugMixin):
|
||||
logger.debug(
|
||||
f'Using fallback model name {_model_name} to get cost: {cost}'
|
||||
)
|
||||
self.metrics.add_cost(cost)
|
||||
return cost
|
||||
self.metrics.add_cost(float(cost))
|
||||
return float(cost)
|
||||
except Exception:
|
||||
self.cost_metric_supported = False
|
||||
logger.debug('Cost calculation not supported for this model.')
|
||||
return 0.0
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.config.api_version:
|
||||
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
|
||||
elif self.config.base_url:
|
||||
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
|
||||
return f'LLM(model={self.config.model})'
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -177,7 +177,7 @@ class Metrics:
|
||||
'token_usages': [usage.model_dump() for usage in self._token_usages],
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self._accumulated_cost = 0.0
|
||||
self._costs = []
|
||||
self._response_latencies = []
|
||||
@@ -192,7 +192,7 @@ class Metrics:
|
||||
response_id='',
|
||||
)
|
||||
|
||||
def log(self):
|
||||
def log(self) -> str:
|
||||
"""Log the metrics."""
|
||||
metrics = self.get()
|
||||
logs = ''
|
||||
@@ -200,5 +200,5 @@ class Metrics:
|
||||
logs += f'{key}: {value}\n'
|
||||
return logs
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'Metrics({self.get()}'
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@@ -13,7 +15,7 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
class RetryMixin:
|
||||
"""Mixin class for retry logic."""
|
||||
|
||||
def retry_decorator(self, **kwargs):
|
||||
def retry_decorator(self, **kwargs: Any) -> Callable:
|
||||
"""
|
||||
Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes.
|
||||
|
||||
@@ -31,7 +33,7 @@ class RetryMixin:
|
||||
retry_multiplier = kwargs.get('retry_multiplier')
|
||||
retry_listener = kwargs.get('retry_listener')
|
||||
|
||||
def before_sleep(retry_state):
|
||||
def before_sleep(retry_state: Any) -> None:
|
||||
self.log_retry_attempt(retry_state)
|
||||
if retry_listener:
|
||||
retry_listener(retry_state.attempt_number, num_retries)
|
||||
@@ -52,7 +54,7 @@ class RetryMixin:
|
||||
f'LLMNoResponseError detected with temperature={current_temp}, keeping original temperature'
|
||||
)
|
||||
|
||||
return retry(
|
||||
retry_decorator: Callable = retry(
|
||||
before_sleep=before_sleep,
|
||||
stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
|
||||
reraise=True,
|
||||
@@ -65,8 +67,9 @@ class RetryMixin:
|
||||
max=retry_max_wait,
|
||||
),
|
||||
)
|
||||
return retry_decorator
|
||||
|
||||
def log_retry_attempt(self, retry_state):
|
||||
def log_retry_attempt(self, retry_state: Any) -> None:
|
||||
"""Log retry attempts."""
|
||||
exception = retry_state.outcome.exception()
|
||||
logger.error(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from openhands.core.exceptions import UserCancelledError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -11,7 +11,7 @@ from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS
|
||||
class StreamingLLM(AsyncLLM):
|
||||
"""Streaming LLM class."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._async_streaming_completion = partial(
|
||||
@@ -40,7 +40,7 @@ class StreamingLLM(AsyncLLM):
|
||||
retry_max_wait=self.config.retry_max_wait,
|
||||
retry_multiplier=self.config.retry_multiplier,
|
||||
)
|
||||
async def async_streaming_completion_wrapper(*args, **kwargs):
|
||||
async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
messages: list[dict[str, Any]] | dict[str, Any] = []
|
||||
|
||||
# some callers might send the model and messages directly
|
||||
@@ -108,6 +108,6 @@ class StreamingLLM(AsyncLLM):
|
||||
self._async_streaming_completion = async_streaming_completion_wrapper
|
||||
|
||||
@property
|
||||
def async_streaming_completion(self):
|
||||
def async_streaming_completion(self) -> Callable:
|
||||
"""Decorator for the async litellm acompletion function with streaming."""
|
||||
return self._async_streaming_completion
|
||||
|
||||
@@ -585,7 +585,10 @@ if __name__ == '__main__':
|
||||
logger.error(f'Validation error occurred: {exc}')
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={'detail': 'Invalid request parameters', 'errors': exc.errors()},
|
||||
content={
|
||||
'detail': 'Invalid request parameters',
|
||||
'errors': str(exc.errors()),
|
||||
},
|
||||
)
|
||||
|
||||
@app.middleware('http')
|
||||
|
||||
@@ -31,81 +31,37 @@ async def browse(
|
||||
try:
|
||||
# obs provided by BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/env.py#L396
|
||||
obs = await call_sync_from_async(browser.step, action_str)
|
||||
|
||||
# Extract values with type checking
|
||||
text_content = obs.get('text_content', '')
|
||||
if not isinstance(text_content, str):
|
||||
text_content = str(text_content)
|
||||
|
||||
url = obs.get('url', '')
|
||||
if not isinstance(url, str):
|
||||
url = str(url)
|
||||
|
||||
image_content = obs.get('image_content', [])
|
||||
if not isinstance(image_content, list):
|
||||
image_content = []
|
||||
|
||||
open_pages_urls = obs.get('open_pages_urls', [])
|
||||
if not isinstance(open_pages_urls, list):
|
||||
open_pages_urls = []
|
||||
|
||||
active_page_index = obs.get('active_page_index', -1)
|
||||
if not isinstance(active_page_index, int):
|
||||
try:
|
||||
active_page_index = int(active_page_index)
|
||||
except (ValueError, TypeError):
|
||||
active_page_index = -1
|
||||
|
||||
dom_object = obs.get('dom_object', {})
|
||||
if not isinstance(dom_object, dict):
|
||||
dom_object = {}
|
||||
|
||||
axtree_object = obs.get('axtree_object', {})
|
||||
if not isinstance(axtree_object, dict):
|
||||
axtree_object = {}
|
||||
|
||||
extra_element_properties = obs.get('extra_element_properties', {})
|
||||
if not isinstance(extra_element_properties, dict):
|
||||
extra_element_properties = {}
|
||||
|
||||
last_action = obs.get('last_action', '')
|
||||
if not isinstance(last_action, str):
|
||||
last_action = str(last_action)
|
||||
|
||||
last_action_error = obs.get('last_action_error', '')
|
||||
if not isinstance(last_action_error, str):
|
||||
last_action_error = str(last_action_error)
|
||||
|
||||
return BrowserOutputObservation(
|
||||
content=text_content, # text content of the page
|
||||
url=url, # URL of the page
|
||||
content=obs['text_content'], # text content of the page
|
||||
url=obs.get('url', ''), # URL of the page
|
||||
screenshot=obs.get('screenshot', None), # base64-encoded screenshot, png
|
||||
set_of_marks=obs.get(
|
||||
'set_of_marks', None
|
||||
), # base64-encoded Set-of-Marks annotated screenshot, png,
|
||||
goal_image_urls=image_content,
|
||||
open_pages_urls=open_pages_urls, # list of open pages
|
||||
active_page_index=active_page_index, # index of the active page
|
||||
dom_object=dom_object, # DOM object
|
||||
axtree_object=axtree_object, # accessibility tree object
|
||||
extra_element_properties=extra_element_properties,
|
||||
goal_image_urls=obs.get('image_content', []),
|
||||
open_pages_urls=obs.get('open_pages_urls', []), # list of open pages
|
||||
active_page_index=obs.get(
|
||||
'active_page_index', -1
|
||||
), # index of the active page
|
||||
dom_object=obs.get('dom_object', {}), # DOM object
|
||||
axtree_object=obs.get('axtree_object', {}), # accessibility tree object
|
||||
extra_element_properties=obs.get('extra_element_properties', {}),
|
||||
focused_element_bid=obs.get(
|
||||
'focused_element_bid', None
|
||||
), # focused element bid
|
||||
last_browser_action=last_action, # last browser env action performed
|
||||
last_browser_action_error=last_action_error,
|
||||
error=bool(last_action_error), # error flag
|
||||
last_browser_action=obs.get(
|
||||
'last_action', ''
|
||||
), # last browser env action performed
|
||||
last_browser_action_error=obs.get('last_action_error', ''),
|
||||
error=True if obs.get('last_action_error', '') else False, # error flag
|
||||
trigger_by_action=action.action,
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
url_value = asked_url if action.action == ActionType.BROWSE else ''
|
||||
|
||||
return BrowserOutputObservation(
|
||||
content=error_message,
|
||||
content=str(e),
|
||||
screenshot='',
|
||||
error=True,
|
||||
last_browser_action_error=error_message,
|
||||
url=url_value,
|
||||
last_browser_action_error=str(e),
|
||||
url=asked_url if action.action == ActionType.BROWSE else '',
|
||||
trigger_by_action=action.action,
|
||||
)
|
||||
|
||||
@@ -184,8 +184,6 @@ class DaytonaRuntime(ActionExecutionClient):
|
||||
|
||||
self.api_url = self._construct_api_url(self._sandbox_port)
|
||||
|
||||
# Ensure workspace is not None before accessing its attributes
|
||||
assert self.workspace is not None, 'Workspace should not be None at this point'
|
||||
state = self.workspace.instance.state
|
||||
|
||||
if state == 'stopping':
|
||||
|
||||
@@ -33,14 +33,6 @@ from openhands.server.file_config import (
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
def _write_to_file(file_path: str, contents: bytes) -> None:
|
||||
"""Helper function to write contents to a file."""
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(contents)
|
||||
file.flush()
|
||||
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
|
||||
|
||||
@@ -203,8 +195,9 @@ async def upload_file(request: Request, conversation_id: str, files: list[Upload
|
||||
# copy the file to the runtime
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_file_path = os.path.join(tmp_dir, safe_filename)
|
||||
# Use a helper function to write to the file
|
||||
await call_sync_from_async(_write_to_file, tmp_file_path, file_contents)
|
||||
with open(tmp_file_path, 'wb') as tmp_file:
|
||||
tmp_file.write(file_contents)
|
||||
tmp_file.flush()
|
||||
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
try:
|
||||
|
||||
@@ -112,9 +112,7 @@ async def _create_new_conversation(
|
||||
title=conversation_title,
|
||||
user_id=user_id,
|
||||
github_user_id=None,
|
||||
selected_repository=selected_repository.full_name
|
||||
if selected_repository
|
||||
else selected_repository,
|
||||
selected_repository=selected_repository.full_name if selected_repository else selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
)
|
||||
)
|
||||
@@ -382,7 +380,7 @@ async def delete_conversation(
|
||||
async def _get_conversation_info(
|
||||
conversation: ConversationMetadata,
|
||||
is_running: bool,
|
||||
) -> ConversationInfo:
|
||||
) -> ConversationInfo | None:
|
||||
try:
|
||||
title = conversation.title
|
||||
if not title:
|
||||
@@ -402,12 +400,4 @@ async def _get_conversation_info(
|
||||
f'Error loading conversation {conversation.conversation_id}: {str(e)}',
|
||||
extra={'session_id': conversation.conversation_id},
|
||||
)
|
||||
# Create a default ConversationInfo object instead of returning None
|
||||
return ConversationInfo(
|
||||
conversation_id=conversation.conversation_id,
|
||||
title=get_default_conversation_title(conversation.conversation_id),
|
||||
last_updated_at=conversation.last_updated_at,
|
||||
created_at=conversation.created_at,
|
||||
selected_repository='',
|
||||
status=ConversationStatus.STOPPED,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -15,13 +15,6 @@ from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.llm import bedrock
|
||||
from openhands.server.shared import config, server_config
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
def _get_ollama_models(url: str) -> list:
|
||||
"""Helper function to get Ollama models."""
|
||||
return httpx.get(url, timeout=3).json()['models']
|
||||
|
||||
|
||||
app = APIRouter(prefix='/api/options')
|
||||
|
||||
@@ -67,9 +60,7 @@ async def get_litellm_models() -> list[str]:
|
||||
if ollama_base_url:
|
||||
ollama_url = ollama_base_url.strip('/') + '/api/tags'
|
||||
try:
|
||||
ollama_models_list = await call_sync_from_async(
|
||||
_get_ollama_models, ollama_url
|
||||
)
|
||||
ollama_models_list = httpx.get(ollama_url, timeout=3).json()['models']
|
||||
for model in ollama_models_list:
|
||||
model_list.append('ollama/' + model['name'])
|
||||
break
|
||||
|
||||
@@ -3,14 +3,14 @@ import os
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
IN_MEMORY_FILES: dict = {}
|
||||
|
||||
|
||||
class InMemoryFileStore(FileStore):
|
||||
files: dict[str, str]
|
||||
|
||||
def __init__(self, files: dict[str, str] = IN_MEMORY_FILES):
|
||||
self.files = files
|
||||
def __init__(self, files: dict[str, str] | None = None) -> None:
|
||||
self.files = {}
|
||||
if files is not None:
|
||||
self.files = files
|
||||
|
||||
def write(self, path: str, contents: str | bytes) -> None:
|
||||
if isinstance(contents, bytes):
|
||||
|
||||
@@ -1,22 +1,17 @@
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Callable, Coroutine, Iterable, List, TypeVar
|
||||
|
||||
T = TypeVar('T') # Return type of the async function
|
||||
R = TypeVar('R') # Return type of the sync function
|
||||
from typing import Callable, Coroutine, Iterable, List
|
||||
|
||||
GENERAL_TIMEOUT: int = 15
|
||||
EXECUTOR = ThreadPoolExecutor()
|
||||
|
||||
|
||||
async def call_sync_from_async(fn: Callable[..., R], *args: Any, **kwargs: Any) -> R:
|
||||
async def call_sync_from_async(fn: Callable, *args, **kwargs):
|
||||
"""
|
||||
Shorthand for running a function in the default background thread pool executor
|
||||
and awaiting the result. The nature of synchronous code is that the future
|
||||
returned by this function is not cancellable.
|
||||
|
||||
Preserves the return type of the original function.
|
||||
returned by this function is not cancellable
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
coro = loop.run_in_executor(None, lambda: fn(*args, **kwargs))
|
||||
@@ -25,28 +20,24 @@ async def call_sync_from_async(fn: Callable[..., R], *args: Any, **kwargs: Any)
|
||||
|
||||
|
||||
def call_async_from_sync(
|
||||
corofn: Callable[..., Coroutine[Any, Any, T]],
|
||||
timeout: float = GENERAL_TIMEOUT,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Shorthand for running a coroutine in the default background thread pool executor
|
||||
and awaiting the result.
|
||||
|
||||
Preserves the return type of the original coroutine function.
|
||||
and awaiting the result
|
||||
"""
|
||||
|
||||
if corofn is None:
|
||||
raise ValueError('corofn is None')
|
||||
if not asyncio.iscoroutinefunction(corofn):
|
||||
raise ValueError('corofn is not a coroutine function')
|
||||
|
||||
async def arun() -> T:
|
||||
async def arun():
|
||||
coro = corofn(*args, **kwargs)
|
||||
result = await coro
|
||||
return result
|
||||
|
||||
def run() -> T:
|
||||
def run():
|
||||
loop_for_thread = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop_for_thread)
|
||||
@@ -61,31 +52,20 @@ def call_async_from_sync(
|
||||
|
||||
|
||||
async def call_coro_in_bg_thread(
|
||||
corofn: Callable[..., Coroutine[Any, Any, T]],
|
||||
timeout: float = GENERAL_TIMEOUT,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""
|
||||
Function for running a coroutine in a background thread.
|
||||
|
||||
Preserves the return type of the original coroutine function.
|
||||
"""
|
||||
return await call_sync_from_async(
|
||||
call_async_from_sync, corofn, timeout, *args, **kwargs
|
||||
)
|
||||
corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
|
||||
):
|
||||
"""Function for running a coroutine in a background thread."""
|
||||
await call_sync_from_async(call_async_from_sync, corofn, timeout, *args, **kwargs)
|
||||
|
||||
|
||||
async def wait_all(
|
||||
iterable: Iterable[Coroutine[Any, Any, T]], timeout: int = GENERAL_TIMEOUT
|
||||
) -> List[T]:
|
||||
iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT
|
||||
) -> List:
|
||||
"""
|
||||
Shorthand for waiting for all the coroutines in the iterable given in parallel. Creates
|
||||
a task for each coroutine.
|
||||
Returns a list of results in the original order. If any single task raised an exception, this is raised.
|
||||
If multiple tasks raised exceptions, an AsyncException is raised containing all exceptions.
|
||||
|
||||
Preserves the return type of the original coroutines.
|
||||
"""
|
||||
tasks = [asyncio.create_task(c) for c in iterable]
|
||||
if not tasks:
|
||||
@@ -110,8 +90,8 @@ async def wait_all(
|
||||
|
||||
|
||||
class AsyncException(Exception):
|
||||
def __init__(self, exceptions: list[Exception]) -> None:
|
||||
def __init__(self, exceptions):
|
||||
self.exceptions = exceptions
|
||||
|
||||
def __str__(self) -> str:
|
||||
def __str__(self):
|
||||
return '\n'.join(str(e) for e in self.exceptions)
|
||||
|
||||
@@ -25,7 +25,7 @@ class Chunk(BaseModel):
|
||||
return ret
|
||||
|
||||
|
||||
def _create_chunks_from_raw_string(content: str, size: int) -> list[Chunk]:
|
||||
def _create_chunks_from_raw_string(content: str, size: int):
|
||||
lines = content.split('\n')
|
||||
ret = []
|
||||
for i in range(0, len(lines), size):
|
||||
@@ -65,7 +65,7 @@ def normalized_lcs(chunk: str, query: str) -> float:
|
||||
"""
|
||||
if len(chunk) == 0:
|
||||
return 0.0
|
||||
_score = float(pylcs.lcs_sequence_length(chunk, query))
|
||||
_score = pylcs.lcs_sequence_length(chunk, query)
|
||||
return _score / len(chunk)
|
||||
|
||||
|
||||
|
||||
@@ -15,17 +15,15 @@ Hopefully, this will be fixed soon and we can remove this abomination.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from typing import Any, Callable, Iterator, TypeVar
|
||||
from typing import Callable
|
||||
|
||||
import httpx
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ensure_httpx_close() -> Iterator[None]:
|
||||
def ensure_httpx_close():
|
||||
wrapped_class = httpx.Client
|
||||
proxys: list[Any] = []
|
||||
proxys = []
|
||||
|
||||
class ClientProxy:
|
||||
"""
|
||||
@@ -34,24 +32,24 @@ def ensure_httpx_close() -> Iterator[None]:
|
||||
where a client is reused, we need to be able to reuse the client even after closing it.
|
||||
"""
|
||||
|
||||
client_constructor: Callable[..., Any]
|
||||
args: tuple[Any, ...]
|
||||
kwargs: dict[str, Any]
|
||||
client: httpx.Client | None
|
||||
client_constructor: Callable
|
||||
args: tuple
|
||||
kwargs: dict
|
||||
client: httpx.Client
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.client = wrapped_class(*self.args, **self.kwargs)
|
||||
proxys.append(self)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
def __getattr__(self, name):
|
||||
# Invoke a method on the proxied client - create one if required
|
||||
if self.client is None:
|
||||
self.client = wrapped_class(*self.args, **self.kwargs)
|
||||
return getattr(self.client, name)
|
||||
|
||||
def close(self) -> None:
|
||||
def close(self):
|
||||
# Close the client if it is open
|
||||
if self.client:
|
||||
self.client.close()
|
||||
@@ -64,21 +62,17 @@ def ensure_httpx_close() -> Iterator[None]:
|
||||
return object.__getattribute__(self, 'iter')(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
def is_closed(self):
|
||||
# Check if closed
|
||||
if self.client is None:
|
||||
return True
|
||||
# Convert to bool to ensure we return a bool
|
||||
return bool(self.client.is_closed)
|
||||
return self.client.is_closed
|
||||
|
||||
# We need to monkey patch the Client class to track instances
|
||||
# This is a hack until LiteLLM fixes their client lifecycle management
|
||||
original_client = httpx.Client
|
||||
httpx.Client = ClientProxy
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
httpx.Client = original_client
|
||||
httpx.Client = wrapped_class
|
||||
while proxys:
|
||||
proxy = proxys.pop()
|
||||
proxy.close()
|
||||
|
||||
@@ -5,15 +5,12 @@ from typing import Type, TypeVar
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def import_from(qual_name: str) -> type:
|
||||
def import_from(qual_name: str):
|
||||
"""Import the value from the qualified name given"""
|
||||
parts = qual_name.split('.')
|
||||
module_name = '.'.join(parts[:-1])
|
||||
module = importlib.import_module(module_name)
|
||||
result = getattr(module, parts[-1])
|
||||
assert isinstance(
|
||||
result, type
|
||||
), f'Expected {qual_name} to be a type, got {type(result)}'
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import base64
|
||||
from typing import Any, AsyncIterator, Callable
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
|
||||
def offset_to_page_id(offset: int, has_next: bool) -> str | None:
|
||||
@@ -16,7 +16,7 @@ def page_id_to_offset(page_id: str | None) -> int:
|
||||
return offset
|
||||
|
||||
|
||||
async def iterate(fn: Callable[..., Any], **kwargs: Any) -> AsyncIterator[Any]:
|
||||
async def iterate(fn: Callable, **kwargs) -> AsyncIterator:
|
||||
"""Iterate over paged result sets. Assumes that the results sets contain an array of result objects, and a next_page_id"""
|
||||
kwargs = {**kwargs}
|
||||
kwargs['page_id'] = None
|
||||
|
||||
@@ -22,7 +22,4 @@ def colorize(text: str, color: TermColor = TermColor.WARNING) -> str:
|
||||
Returns:
|
||||
str: Colored text
|
||||
"""
|
||||
# colored() returns a string with ANSI color codes
|
||||
result = colored(text, color.value)
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
return colored(text, color.value)
|
||||
|
||||
@@ -59,7 +59,6 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
mock_runtime = MagicMock(spec=Runtime)
|
||||
mock_runtime.get_microagents_from_selected_repo.return_value = []
|
||||
|
||||
# Mock the runtime creation to set up the runtime attribute
|
||||
async def mock_create_runtime(*args, **kwargs):
|
||||
@@ -143,7 +142,6 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
mock_runtime = MagicMock(spec=Runtime)
|
||||
mock_runtime.get_microagents_from_selected_repo.return_value = []
|
||||
|
||||
# Mock the runtime creation to set up the runtime attribute
|
||||
async def mock_create_runtime(*args, **kwargs):
|
||||
|
||||
@@ -28,7 +28,7 @@ from openhands.storage.memory import InMemoryFileStore
|
||||
@pytest.fixture
|
||||
def file_store():
|
||||
"""Create a temporary file store for testing."""
|
||||
return InMemoryFileStore()
|
||||
return InMemoryFileStore({})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -190,10 +190,9 @@ async def test_memory_with_microagents():
|
||||
assert 'magic word' in observation.microagent_knowledge[0].content
|
||||
|
||||
|
||||
def test_memory_repository_info(prompt_dir):
|
||||
def test_memory_repository_info(prompt_dir, file_store):
|
||||
"""Test that Memory adds repository info to RecallObservations."""
|
||||
# Create an in-memory file store and real event stream
|
||||
file_store = InMemoryFileStore()
|
||||
# real event stream
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
|
||||
# Create a test repo microagent first
|
||||
@@ -321,10 +320,9 @@ async def test_memory_with_agent_microagents():
|
||||
assert 'magic word' in observation.microagent_knowledge[0].content
|
||||
|
||||
|
||||
def test_memory_multiple_repo_microagents(prompt_dir):
|
||||
def test_memory_multiple_repo_microagents(prompt_dir, file_store):
|
||||
"""Test that Memory loads and concatenates multiple repo microagents correctly."""
|
||||
# Create an in-memory file store and real event stream
|
||||
file_store = InMemoryFileStore()
|
||||
# Create real event stream
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
|
||||
# Create two test repo microagents
|
||||
|
||||
Reference in New Issue
Block a user