mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
12 Commits
fix-cli-co
...
delegates-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
497fd4a02c | ||
|
|
7fac7d6dd0 | ||
|
|
4155b8f801 | ||
|
|
e712b013f9 | ||
|
|
29b137e9b1 | ||
|
|
1292f0c2ea | ||
|
|
da8c946078 | ||
|
|
43cef1f969 | ||
|
|
ab661b485b | ||
|
|
9632914bf0 | ||
|
|
3db780ef93 | ||
|
|
43c16516e8 |
@@ -162,7 +162,7 @@ class BrowsingAgent(Agent):
|
||||
last_action = event
|
||||
elif isinstance(event, MessageAction) and event.source == EventSource.AGENT:
|
||||
# agent has responded, task finished.
|
||||
return AgentFinishAction(outputs={'content': event.content})
|
||||
return AgentFinishAction(final_thought=event.content)
|
||||
elif isinstance(event, Observation):
|
||||
last_obs = event
|
||||
|
||||
@@ -201,10 +201,8 @@ class BrowsingAgent(Agent):
|
||||
)
|
||||
return MessageAction('Error encountered when browsing.')
|
||||
|
||||
goal, _ = state.get_current_user_intent()
|
||||
|
||||
if goal is None:
|
||||
goal = state.inputs['task']
|
||||
user_message_action = state.get_current_user_intent()
|
||||
goal = user_message_action.content
|
||||
|
||||
system_msg = get_system_message(
|
||||
goal,
|
||||
|
||||
@@ -105,7 +105,8 @@ def response_to_actions(
|
||||
elif tool_call.function.name == 'delegate_to_browsing_agent':
|
||||
action = AgentDelegateAction(
|
||||
agent='BrowsingAgent',
|
||||
inputs=arguments,
|
||||
prompt=arguments.get('prompt', ''),
|
||||
inputs={},
|
||||
)
|
||||
|
||||
# ================================================
|
||||
@@ -113,8 +114,10 @@ def response_to_actions(
|
||||
# ================================================
|
||||
elif tool_call.function.name == FinishTool['function']['name']:
|
||||
action = AgentFinishAction(
|
||||
final_thought=arguments.get('message', ''),
|
||||
outputs=arguments.get('outputs', {}),
|
||||
thought=arguments.get('thought', ''),
|
||||
task_completed=arguments.get('task_completed', None),
|
||||
final_thought=arguments.get('final_thought', ''),
|
||||
)
|
||||
|
||||
# ================================================
|
||||
|
||||
@@ -216,7 +216,7 @@ Note:
|
||||
last_action = event
|
||||
elif isinstance(event, MessageAction) and event.source == EventSource.AGENT:
|
||||
# agent has responded, task finished.
|
||||
return AgentFinishAction(outputs={'content': event.content})
|
||||
return AgentFinishAction(final_thought=event.content)
|
||||
elif isinstance(event, Observation):
|
||||
# Only process BrowserOutputObservation and skip other observation types
|
||||
if not isinstance(event, BrowserOutputObservation):
|
||||
@@ -271,10 +271,10 @@ Note:
|
||||
)
|
||||
return MessageAction('Error encountered when browsing.')
|
||||
set_of_marks = last_obs.set_of_marks
|
||||
goal, image_urls = state.get_current_user_intent()
|
||||
user_message_action = state.get_current_user_intent()
|
||||
goal = user_message_action.content
|
||||
image_urls = user_message_action.image_urls
|
||||
|
||||
if goal is None:
|
||||
goal = state.inputs['task']
|
||||
goal_txt, goal_images = create_goal_prompt(goal, image_urls)
|
||||
observation_txt, som_screenshot = create_observation_prompt(
|
||||
cur_axtree_txt, tabs, focused_element, error_prefix, set_of_marks
|
||||
|
||||
@@ -438,12 +438,13 @@ class AgentController:
|
||||
elif isinstance(action, AgentDelegateAction):
|
||||
await self.start_delegate(action)
|
||||
assert self.delegate is not None
|
||||
# Post a MessageAction with the task for the delegate
|
||||
if 'task' in action.inputs:
|
||||
# Post a MessageAction with the prompt for the delegate
|
||||
if action.prompt:
|
||||
self.event_stream.add_event(
|
||||
MessageAction(content='TASK: ' + action.inputs['task']),
|
||||
EventSource.USER,
|
||||
MessageAction(content=action.prompt),
|
||||
EventSource.USER, # Source is USER, as it represents the task prompt for the delegate
|
||||
)
|
||||
# Delegate starts in RUNNING state as it receives the prompt immediately
|
||||
await self.delegate.set_agent_state_to(AgentState.RUNNING)
|
||||
return
|
||||
|
||||
@@ -727,34 +728,22 @@ class AgentController:
|
||||
# close the delegate controller before adding new events
|
||||
asyncio.get_event_loop().run_until_complete(self.delegate.close())
|
||||
|
||||
if delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
# retrieve delegate result
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
# prepare delegate result observation
|
||||
delegate_outputs = self.delegate.state.outputs if self.delegate.state else {}
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in delegate_outputs.items()
|
||||
)
|
||||
|
||||
# prepare delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in delegate_outputs.items()
|
||||
)
|
||||
if delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
)
|
||||
else:
|
||||
# delegate state is ERROR
|
||||
# emit AgentDelegateObservation with error content
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} encountered an error during execution.'
|
||||
)
|
||||
|
||||
content = f'Delegated agent finished with result:\n\n{content}'
|
||||
content = f'{self.delegate.agent.name} encountered an error during execution. Known results: {delegate_outputs}'
|
||||
|
||||
# emit the delegate result observation
|
||||
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
|
||||
obs = AgentDelegateObservation(content=content, outputs={})
|
||||
|
||||
# associate the delegate action with the initiating tool call
|
||||
for event in reversed(self.state.history):
|
||||
|
||||
@@ -188,19 +188,39 @@ class State:
|
||||
if not hasattr(self, 'history'):
|
||||
self.history = []
|
||||
|
||||
def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
|
||||
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
last_user_message = None
|
||||
last_user_message_image_urls: list[str] | None = []
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
last_user_message = event.content
|
||||
last_user_message_image_urls = event.image_urls
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
if last_user_message is not None:
|
||||
return last_user_message, None
|
||||
def get_current_user_intent(self) -> MessageAction:
|
||||
"""Returns the latest user MessageAction that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
likely_task: MessageAction | None = None
|
||||
|
||||
return last_user_message, last_user_message_image_urls
|
||||
# Search in the view for the latest user message after the last finish action
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
likely_task = event
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
# If a FinishAction is found, the user message after it is the one we just found (if any)
|
||||
break
|
||||
|
||||
# If a user message was found in the view after the last finish action, return it
|
||||
if likely_task is not None:
|
||||
return likely_task
|
||||
|
||||
# If no user message was found in the view after the last finish action,
|
||||
# it means either there were no user messages in the view, or the last event in the view was a FinishAction
|
||||
# In this case, we fall back to finding the very first user message in the full history.
|
||||
logger.warning(
|
||||
'No user message found in the view after the last FinishAction. Returning the first message in history.'
|
||||
)
|
||||
if self.history:
|
||||
# Look for the very first user message in the full history
|
||||
for event in self.history:
|
||||
if (
|
||||
isinstance(event, MessageAction)
|
||||
and event.source == EventSource.USER
|
||||
):
|
||||
return event
|
||||
|
||||
# If no user message is found in the entire history, raise an error
|
||||
raise ValueError('No user message found in history. This should not happen.')
|
||||
|
||||
def get_last_agent_message(self) -> MessageAction | None:
|
||||
for event in reversed(self.view):
|
||||
|
||||
@@ -86,6 +86,13 @@ class AgentRejectAction(Action):
|
||||
class AgentDelegateAction(Action):
|
||||
agent: str
|
||||
inputs: dict
|
||||
"""Deprecated.
|
||||
Delegate agents run similarly to the main agent:
|
||||
- start from a prompt (passed in the 'prompt' field)
|
||||
- end with an AgentFinishAction.
|
||||
"""
|
||||
prompt: str
|
||||
"""The prompt/task for the delegate agent"""
|
||||
thought: str = ''
|
||||
action: str = ActionType.DELEGATE
|
||||
|
||||
|
||||
@@ -10,13 +10,18 @@ class AgentDelegateObservation(Observation):
|
||||
|
||||
Attributes:
|
||||
content (str): The content of the observation.
|
||||
outputs (dict): The outputs of the delegated agent.
|
||||
outputs (dict): The outputs of the delegated agent. (deprecated)
|
||||
observation (str): The type of observation.
|
||||
"""
|
||||
|
||||
outputs: dict
|
||||
"""Deprecated.
|
||||
Delegate agents run similarly to the main agent:
|
||||
- start from a prompt (passed in the 'prompt' field)
|
||||
- end with an AgentFinishAction.
|
||||
"""
|
||||
observation: str = ObservationType.DELEGATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return ''
|
||||
return self.content
|
||||
|
||||
47
openhands/integrations/github/queries.py
Normal file
47
openhands/integrations/github/queries.py
Normal file
@@ -0,0 +1,47 @@
|
||||
suggested_task_pr_graphql_query = """
|
||||
query GetUserPRs($login: String!) {
|
||||
user(login: $login) {
|
||||
pullRequests(first: 50, states: [OPEN], orderBy: {field: UPDATED_AT, direction: DESC}) {
|
||||
nodes {
|
||||
number
|
||||
title
|
||||
repository {
|
||||
nameWithOwner
|
||||
}
|
||||
mergeable
|
||||
commits(last: 1) {
|
||||
nodes {
|
||||
commit {
|
||||
statusCheckRollup {
|
||||
state
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
reviews(first: 50, states: [CHANGES_REQUESTED, COMMENTED]) {
|
||||
nodes {
|
||||
state
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
suggested_task_issue_graphql_query = """
|
||||
query GetUserIssues($login: String!) {
|
||||
user(login: $login) {
|
||||
issues(first: 50, states: [OPEN], filterBy: {assignee: $login}, orderBy: {field: UPDATED_AT, direction: DESC}) {
|
||||
nodes {
|
||||
number
|
||||
title
|
||||
repository {
|
||||
nameWithOwner
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
@@ -6,6 +6,7 @@ from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.service_types import (
|
||||
BaseGitService,
|
||||
Branch,
|
||||
GitService,
|
||||
ProviderType,
|
||||
Repository,
|
||||
@@ -131,7 +132,7 @@ class GitLabService(BaseGitService, GitService):
|
||||
|
||||
payload = {
|
||||
'query': query,
|
||||
'variables': variables,
|
||||
'variables': variables if variables is not None else {},
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
@@ -195,6 +196,7 @@ class GitLabService(BaseGitService, GitService):
|
||||
full_name=repo.get('path_with_namespace'),
|
||||
stargazers_count=repo.get('star_count'),
|
||||
git_provider=ProviderType.GITLAB,
|
||||
is_public=True,
|
||||
)
|
||||
for repo in response
|
||||
]
|
||||
@@ -398,6 +400,44 @@ class GitLabService(BaseGitService, GitService):
|
||||
is_public=repo.get('visibility') == 'public',
|
||||
)
|
||||
|
||||
async def get_branches(self, repository: str) -> list[Branch]:
|
||||
"""Get branches for a repository"""
|
||||
encoded_name = repository.replace('/', '%2F')
|
||||
url = f'{self.BASE_URL}/projects/{encoded_name}/repository/branches'
|
||||
|
||||
# Set maximum branches to fetch (10 pages with 100 per page)
|
||||
MAX_BRANCHES = 1000
|
||||
PER_PAGE = 100
|
||||
|
||||
all_branches: list[Branch] = []
|
||||
page = 1
|
||||
|
||||
# Fetch up to 10 pages of branches
|
||||
while page <= 10 and len(all_branches) < MAX_BRANCHES:
|
||||
params = {'per_page': str(PER_PAGE), 'page': str(page)}
|
||||
response, headers = await self._make_request(url, params)
|
||||
|
||||
if not response: # No more branches
|
||||
break
|
||||
|
||||
for branch_data in response:
|
||||
branch = Branch(
|
||||
name=branch_data.get('name'),
|
||||
commit_sha=branch_data.get('commit', {}).get('id', ''),
|
||||
protected=branch_data.get('protected', False),
|
||||
last_push_date=branch_data.get('commit', {}).get('committed_date'),
|
||||
)
|
||||
all_branches.append(branch)
|
||||
|
||||
page += 1
|
||||
|
||||
# Check if we've reached the last page
|
||||
link_header = headers.get('Link', '')
|
||||
if 'rel="next"' not in link_header:
|
||||
break
|
||||
|
||||
return all_branches
|
||||
|
||||
|
||||
gitlab_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS',
|
||||
|
||||
@@ -18,6 +18,7 @@ from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
Branch,
|
||||
GitService,
|
||||
ProviderType,
|
||||
Repository,
|
||||
@@ -30,6 +31,7 @@ from openhands.server.types import AppMode
|
||||
class ProviderToken(BaseModel):
|
||||
token: SecretStr | None = Field(default=None)
|
||||
user_id: str | None = Field(default=None)
|
||||
host: str | None = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
'frozen': True, # Makes the entire model immutable
|
||||
@@ -39,15 +41,20 @@ class ProviderToken(BaseModel):
|
||||
@classmethod
|
||||
def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
|
||||
"""Factory method to create a ProviderToken from various input types"""
|
||||
if isinstance(token_value, ProviderToken):
|
||||
if isinstance(token_value, cls):
|
||||
return token_value
|
||||
elif isinstance(token_value, dict):
|
||||
token_str = token_value.get('token')
|
||||
token_str = token_value.get('token', '')
|
||||
# Override with emtpy string if it was set to None
|
||||
# Cannot pass None to SecretStr
|
||||
if token_str is None:
|
||||
token_str = ''
|
||||
user_id = token_value.get('user_id')
|
||||
return cls(token=SecretStr(token_str), user_id=user_id)
|
||||
host = token_value.get('host')
|
||||
return cls(token=SecretStr(token_str), user_id=user_id, host=host)
|
||||
|
||||
else:
|
||||
raise ValueError('Unsupport Provider token type')
|
||||
raise ValueError('Unsupported Provider token type')
|
||||
|
||||
|
||||
PROVIDER_TOKEN_TYPE = MappingProxyType[ProviderType, ProviderToken]
|
||||
@@ -165,7 +172,8 @@ class ProviderHandler:
|
||||
query, per_page, sort, order
|
||||
)
|
||||
all_repos.extend(service_repos)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f'Error searching repos from {provider}: {e}')
|
||||
continue
|
||||
|
||||
return all_repos
|
||||
@@ -305,3 +313,56 @@ class ProviderHandler:
|
||||
pass
|
||||
|
||||
raise AuthenticationError(f'Unable to access repo {repository}')
|
||||
|
||||
async def get_branches(
|
||||
self, repository: str, specified_provider: ProviderType | None = None
|
||||
) -> list[Branch]:
|
||||
"""
|
||||
Get branches for a repository
|
||||
|
||||
Args:
|
||||
repository: The repository name
|
||||
specified_provider: Optional provider type to use
|
||||
|
||||
Returns:
|
||||
A list of branches for the repository
|
||||
"""
|
||||
all_branches: list[Branch] = []
|
||||
|
||||
if specified_provider:
|
||||
try:
|
||||
service = self._get_service(specified_provider)
|
||||
branches = await service.get_branches(repository)
|
||||
return branches
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Error fetching branches from {specified_provider}: {e}'
|
||||
)
|
||||
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self._get_service(provider)
|
||||
branches = await service.get_branches(repository)
|
||||
all_branches.extend(branches)
|
||||
# If we found branches, no need to check other providers
|
||||
if all_branches:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f'Error fetching branches from {provider}: {e}')
|
||||
|
||||
# Sort branches by last push date (newest first)
|
||||
all_branches.sort(
|
||||
key=lambda b: b.last_push_date if b.last_push_date else '', reverse=True
|
||||
)
|
||||
|
||||
# Move main/master branch to the top if it exists
|
||||
main_branches = []
|
||||
other_branches = []
|
||||
|
||||
for branch in all_branches:
|
||||
if branch.name.lower() in ['main', 'master']:
|
||||
main_branches.append(branch)
|
||||
else:
|
||||
other_branches.append(branch)
|
||||
|
||||
return main_branches + other_branches
|
||||
|
||||
@@ -91,6 +91,13 @@ class User(BaseModel):
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class Branch(BaseModel):
|
||||
name: str
|
||||
commit_sha: str
|
||||
protected: bool
|
||||
last_push_date: str | None = None # ISO 8601 format date string
|
||||
|
||||
|
||||
class Repository(BaseModel):
|
||||
id: int
|
||||
full_name: str
|
||||
@@ -164,7 +171,7 @@ class BaseGitService(ABC):
|
||||
|
||||
def handle_http_error(self, e: HTTPError) -> UnknownException:
|
||||
logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
|
||||
return UnknownException('Unknown error')
|
||||
return UnknownException(f'HTTP error {type(e).__name__}')
|
||||
|
||||
|
||||
class GitService(Protocol):
|
||||
@@ -211,3 +218,6 @@ class GitService(Protocol):
|
||||
self, repository: str
|
||||
) -> Repository:
|
||||
"""Gets all repository details from repository name"""
|
||||
|
||||
async def get_branches(self, repository: str) -> list[Branch]:
|
||||
"""Get branches for a repository"""
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
Please summarize your work.
|
||||
|
||||
If you answered a question, please re-state the answer to the question
|
||||
If you made changes, please create a concise overview on whether the request has been addressed successfully or if there are were issues with the attempt.
|
||||
If successful, make sure your changes are pushed to the remote branch.
|
||||
@@ -9,6 +9,7 @@ We follow format from: https://docs.litellm.ai/docs/completion/function_call
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from typing import Iterable
|
||||
|
||||
from litellm import ChatCompletionToolParam
|
||||
@@ -47,8 +48,15 @@ Reminder:
|
||||
|
||||
STOP_WORDS = ['</function']
|
||||
|
||||
|
||||
def refine_prompt(prompt: str) -> str:
|
||||
if sys.platform == 'win32':
|
||||
return prompt.replace('bash', 'powershell')
|
||||
return prompt
|
||||
|
||||
|
||||
# NOTE: we need to make sure this example is always in-sync with the tool interface designed in openhands/agenthub/codeact_agent/function_calling.py
|
||||
IN_CONTEXT_LEARNING_EXAMPLE_PREFIX = """
|
||||
IN_CONTEXT_LEARNING_EXAMPLE_PREFIX = refine_prompt("""
|
||||
Here's a running example of how to perform a task with the provided tools.
|
||||
|
||||
--------------------- START OF EXAMPLE ---------------------
|
||||
@@ -75,7 +83,7 @@ from flask import Flask
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
def index() -> str:
|
||||
numbers = list(range(1, 11))
|
||||
return str(numbers)
|
||||
|
||||
@@ -218,7 +226,7 @@ The server is running on port 5000 with PID 126. You can access the list of numb
|
||||
Do NOT assume the environment is the same as in the example above.
|
||||
|
||||
--------------------- NEW TASK DESCRIPTION ---------------------
|
||||
""".lstrip()
|
||||
""").lstrip()
|
||||
|
||||
IN_CONTEXT_LEARNING_EXAMPLE_SUFFIX = """
|
||||
--------------------- END OF NEW TASK DESCRIPTION ---------------------
|
||||
@@ -245,12 +253,12 @@ def convert_tool_call_to_string(tool_call: dict) -> str:
|
||||
if tool_call['type'] != 'function':
|
||||
raise FunctionCallConversionError("Tool call type must be 'function'.")
|
||||
|
||||
ret = f"<function={tool_call['function']['name']}>\n"
|
||||
ret = f'<function={tool_call["function"]["name"]}>\n'
|
||||
try:
|
||||
args = json.loads(tool_call['function']['arguments'])
|
||||
except json.JSONDecodeError as e:
|
||||
raise FunctionCallConversionError(
|
||||
f"Failed to parse arguments as JSON. Arguments: {tool_call['function']['arguments']}"
|
||||
f'Failed to parse arguments as JSON. Arguments: {tool_call["function"]["arguments"]}'
|
||||
) from e
|
||||
for param_name, param_value in args.items():
|
||||
is_multiline = isinstance(param_value, str) and '\n' in param_value
|
||||
@@ -272,8 +280,8 @@ def convert_tools_to_description(tools: list[dict]) -> str:
|
||||
fn = tool['function']
|
||||
if i > 0:
|
||||
ret += '\n'
|
||||
ret += f"---- BEGIN FUNCTION #{i+1}: {fn['name']} ----\n"
|
||||
ret += f"Description: {fn['description']}\n"
|
||||
ret += f'---- BEGIN FUNCTION #{i + 1}: {fn["name"]} ----\n'
|
||||
ret += f'Description: {fn["description"]}\n'
|
||||
|
||||
if 'parameters' in fn:
|
||||
ret += 'Parameters:\n'
|
||||
@@ -295,12 +303,12 @@ def convert_tools_to_description(tools: list[dict]) -> str:
|
||||
desc += f'\nAllowed values: [{enum_values}]'
|
||||
|
||||
ret += (
|
||||
f' ({j+1}) {param_name} ({param_type}, {param_status}): {desc}\n'
|
||||
f' ({j + 1}) {param_name} ({param_type}, {param_status}): {desc}\n'
|
||||
)
|
||||
else:
|
||||
ret += 'No parameters are required for this function.\n'
|
||||
|
||||
ret += f'---- END FUNCTION #{i+1} ----\n'
|
||||
ret += f'---- END FUNCTION #{i + 1} ----\n'
|
||||
return ret
|
||||
|
||||
|
||||
@@ -351,7 +359,8 @@ def convert_fncall_messages_to_non_fncall_messages(
|
||||
and any(
|
||||
(
|
||||
tool['type'] == 'function'
|
||||
and tool['function']['name'] == 'execute_bash'
|
||||
and tool['function']['name']
|
||||
== refine_prompt('execute_bash')
|
||||
and 'command'
|
||||
in tool['function']['parameters']['properties']
|
||||
)
|
||||
@@ -658,7 +667,7 @@ def convert_non_fncall_messages_to_fncall_messages(
|
||||
'content': [{'type': 'text', 'text': tool_result}]
|
||||
if isinstance(content, list)
|
||||
else tool_result,
|
||||
'tool_call_id': f'toolu_{tool_call_counter-1:02d}', # Use last generated ID
|
||||
'tool_call_id': f'toolu_{tool_call_counter - 1:02d}', # Use last generated ID
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -781,14 +790,14 @@ def convert_from_multiple_tool_calls_to_single_tool_call_messages(
|
||||
# add the tool result
|
||||
converted_messages.append(message)
|
||||
else:
|
||||
assert (
|
||||
len(pending_tool_calls) == 0
|
||||
), f'Found pending tool calls but not found in pending list: {pending_tool_calls=}'
|
||||
assert len(pending_tool_calls) == 0, (
|
||||
f'Found pending tool calls but not found in pending list: {pending_tool_calls=}'
|
||||
)
|
||||
converted_messages.append(message)
|
||||
else:
|
||||
assert (
|
||||
len(pending_tool_calls) == 0
|
||||
), f'Found pending tool calls but not expect to handle it with role {role}: {pending_tool_calls=}, {message=}'
|
||||
assert len(pending_tool_calls) == 0, (
|
||||
f'Found pending tool calls but not expect to handle it with role {role}: {pending_tool_calls=}, {message=}'
|
||||
)
|
||||
converted_messages.append(message)
|
||||
|
||||
if not ignore_final_tool_result and len(pending_tool_calls) > 0:
|
||||
|
||||
@@ -49,6 +49,8 @@ LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
|
||||
# remove this when we gemini and deepseek are supported
|
||||
CACHE_PROMPT_SUPPORTED_MODELS = [
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-sonnet-3-7-latest',
|
||||
'claude-3.7-sonnet',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
'claude-3-5-sonnet-20240620',
|
||||
'claude-3-5-haiku-20241022',
|
||||
@@ -59,6 +61,7 @@ CACHE_PROMPT_SUPPORTED_MODELS = [
|
||||
# function calling supporting models
|
||||
FUNCTION_CALLING_SUPPORTED_MODELS = [
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-sonnet-3-7-latest',
|
||||
'claude-3-5-sonnet',
|
||||
'claude-3-5-sonnet-20240620',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
@@ -108,7 +111,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
config: LLMConfig,
|
||||
metrics: Metrics | None = None,
|
||||
retry_listener: Callable[[int, int], None] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
|
||||
|
||||
Passing simple parameters always overrides config.
|
||||
@@ -199,7 +202,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
|
||||
from openhands.io import json
|
||||
|
||||
messages: list[dict[str, Any]] | dict[str, Any] = []
|
||||
messages_kwarg: list[dict[str, Any]] | dict[str, Any] = []
|
||||
mock_function_calling = not self.is_function_calling_active()
|
||||
|
||||
# some callers might send the model and messages directly
|
||||
@@ -209,16 +212,18 @@ class LLM(RetryMixin, DebugMixin):
|
||||
# design wise: we don't allow overriding the configured values
|
||||
# implementation wise: the partial function set the model as a kwarg already
|
||||
# as well as other kwargs
|
||||
messages = args[1] if len(args) > 1 else args[0]
|
||||
kwargs['messages'] = messages
|
||||
messages_kwarg = args[1] if len(args) > 1 else args[0]
|
||||
kwargs['messages'] = messages_kwarg
|
||||
|
||||
# remove the first args, they're sent in kwargs
|
||||
args = args[2:]
|
||||
elif 'messages' in kwargs:
|
||||
messages = kwargs['messages']
|
||||
messages_kwarg = kwargs['messages']
|
||||
|
||||
# ensure we work with a list of messages
|
||||
messages = messages if isinstance(messages, list) else [messages]
|
||||
messages: list[dict[str, Any]] = (
|
||||
messages_kwarg if isinstance(messages_kwarg, list) else [messages_kwarg]
|
||||
)
|
||||
|
||||
# handle conversion of to non-function calling messages if needed
|
||||
original_fncall_messages = copy.deepcopy(messages)
|
||||
@@ -290,6 +295,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
|
||||
non_fncall_response_message = resp.choices[0].message
|
||||
# messages is already a list with proper typing from line 223
|
||||
fn_call_messages_with_response = (
|
||||
convert_non_fncall_messages_to_fncall_messages(
|
||||
messages + [non_fncall_response_message], mock_fncall_tools
|
||||
@@ -412,6 +418,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
if current_model_info:
|
||||
self.model_info = current_model_info['model_info']
|
||||
logger.debug(f'Got model info from litellm proxy: {self.model_info}')
|
||||
|
||||
# Last two attempts to get model info from NAME
|
||||
if not self.model_info:
|
||||
@@ -467,7 +474,10 @@ class LLM(RetryMixin, DebugMixin):
|
||||
self.model_info['max_tokens'], int
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_tokens']
|
||||
if 'claude-3-7-sonnet' in self.config.model:
|
||||
if any(
|
||||
model in self.config.model
|
||||
for model in ['claude-3-7-sonnet', 'claude-3.7-sonnet']
|
||||
):
|
||||
self.config.max_output_tokens = 64000 # litellm set max to 128k, but that requires a header to be set
|
||||
|
||||
# Initialize function calling capability
|
||||
@@ -598,6 +608,12 @@ class LLM(RetryMixin, DebugMixin):
|
||||
if cache_write_tokens:
|
||||
stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'
|
||||
|
||||
# Get context window from model info
|
||||
context_window = 0
|
||||
if self.model_info and 'max_input_tokens' in self.model_info:
|
||||
context_window = self.model_info['max_input_tokens']
|
||||
logger.debug(f'Using context window: {context_window}')
|
||||
|
||||
# Record in metrics
|
||||
# We'll treat cache_hit_tokens as "cache read" and cache_write_tokens as "cache write"
|
||||
self.metrics.add_token_usage(
|
||||
@@ -605,6 +621,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_hit_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
context_window=context_window,
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
@@ -631,7 +648,15 @@ class LLM(RetryMixin, DebugMixin):
|
||||
logger.info(
|
||||
'Message objects now include serialized tool calls in token counting'
|
||||
)
|
||||
messages = self.format_messages_for_llm(messages) # type: ignore
|
||||
# Assert the expected type for format_messages_for_llm
|
||||
assert isinstance(messages, list) and all(
|
||||
isinstance(m, Message) for m in messages
|
||||
), 'Expected list of Message objects'
|
||||
|
||||
# We've already asserted that messages is a list of Message objects
|
||||
# Use explicit typing to satisfy mypy
|
||||
messages_typed: list[Message] = messages # type: ignore
|
||||
messages = self.format_messages_for_llm(messages_typed)
|
||||
|
||||
# try to get the token count with the default litellm tokenizers
|
||||
# or the custom tokenizer if set for this LLM configuration
|
||||
@@ -662,7 +687,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
boolean: True if executing a local model.
|
||||
"""
|
||||
if self.config.base_url is not None:
|
||||
for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
|
||||
for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
|
||||
if substring in self.config.base_url:
|
||||
return True
|
||||
elif self.config.model is not None:
|
||||
|
||||
@@ -26,6 +26,8 @@ class TokenUsage(BaseModel):
|
||||
completion_tokens: int = Field(default=0)
|
||||
cache_read_tokens: int = Field(default=0)
|
||||
cache_write_tokens: int = Field(default=0)
|
||||
context_window: int = Field(default=0)
|
||||
per_turn_token: int = Field(default=0)
|
||||
response_id: str = Field(default='')
|
||||
|
||||
def __add__(self, other: 'TokenUsage') -> 'TokenUsage':
|
||||
@@ -36,6 +38,8 @@ class TokenUsage(BaseModel):
|
||||
completion_tokens=self.completion_tokens + other.completion_tokens,
|
||||
cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
|
||||
cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
|
||||
context_window=max(self.context_window, other.context_window),
|
||||
per_turn_token=other.per_turn_token,
|
||||
response_id=self.response_id,
|
||||
)
|
||||
|
||||
@@ -60,6 +64,7 @@ class Metrics:
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
@@ -107,6 +112,7 @@ class Metrics:
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
return self._accumulated_token_usage
|
||||
@@ -130,15 +136,22 @@ class Metrics:
|
||||
completion_tokens: int,
|
||||
cache_read_tokens: int,
|
||||
cache_write_tokens: int,
|
||||
context_window: int,
|
||||
response_id: str,
|
||||
) -> None:
|
||||
"""Add a single usage record."""
|
||||
|
||||
# Token each turn for calculating context usage.
|
||||
per_turn_token = prompt_tokens + completion_tokens
|
||||
|
||||
usage = TokenUsage(
|
||||
model=self.model_name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
context_window=context_window,
|
||||
per_turn_token=per_turn_token,
|
||||
response_id=response_id,
|
||||
)
|
||||
self._token_usages.append(usage)
|
||||
@@ -150,6 +163,8 @@ class Metrics:
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
context_window=context_window,
|
||||
per_turn_token=per_turn_token,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
@@ -190,6 +205,7 @@ class Metrics:
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
@@ -18,8 +18,8 @@ class MCPClient(BaseModel):
|
||||
session: Optional[ClientSession] = None
|
||||
exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
description: str = 'MCP client tools for server interaction'
|
||||
tools: List[MCPClientTool] = Field(default_factory=list)
|
||||
tool_map: Dict[str, MCPClientTool] = Field(default_factory=dict)
|
||||
tools: list[MCPClientTool] = Field(default_factory=list)
|
||||
tool_map: dict[str, MCPClientTool] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -91,7 +91,7 @@ class MCPClient(BaseModel):
|
||||
f'Connected to server with tools: {[tool.name for tool in response.tools]}'
|
||||
)
|
||||
|
||||
async def call_tool(self, tool_name: str, args: Dict):
|
||||
async def call_tool(self, tool_name: str, args: dict):
|
||||
"""Call a tool on the MCP server."""
|
||||
if tool_name not in self.tool_map:
|
||||
raise ValueError(f'Tool {tool_name} not found.')
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Dict
|
||||
|
||||
from mcp.types import Tool
|
||||
|
||||
|
||||
@@ -14,7 +12,7 @@ class MCPClientTool(Tool):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def to_param(self) -> Dict:
|
||||
def to_param(self) -> dict:
|
||||
"""Convert tool to function call format."""
|
||||
return {
|
||||
'type': 'function',
|
||||
|
||||
@@ -158,12 +158,12 @@ async def add_mcp_tools_to_agent(
|
||||
ActionExecutionClient, # inline import to avoid circular import
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
runtime, ActionExecutionClient
|
||||
), 'Runtime must be an instance of ActionExecutionClient'
|
||||
assert (
|
||||
runtime.runtime_initialized
|
||||
), 'Runtime must be initialized before adding MCP tools'
|
||||
assert isinstance(runtime, ActionExecutionClient), (
|
||||
'Runtime must be an instance of ActionExecutionClient'
|
||||
)
|
||||
assert runtime.runtime_initialized, (
|
||||
'Runtime must be initialized before adding MCP tools'
|
||||
)
|
||||
|
||||
# Add the runtime as another MCP server
|
||||
updated_mcp_config = runtime.get_updated_mcp_config()
|
||||
@@ -171,7 +171,7 @@ async def add_mcp_tools_to_agent(
|
||||
mcp_tools = await fetch_mcp_tools_from_config(updated_mcp_config)
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(mcp_tools)} MCP tools: {[tool['function']['name'] for tool in mcp_tools]}"
|
||||
f'Loaded {len(mcp_tools)} MCP tools: {[tool["function"]["name"] for tool in mcp_tools]}'
|
||||
)
|
||||
|
||||
# Set the MCP tools on the agent
|
||||
|
||||
@@ -28,7 +28,7 @@ class BrowserOutputCondenser(Condenser):
|
||||
):
|
||||
results.append(
|
||||
AgentCondensationObservation(
|
||||
f'Current URL: {event.url}\nContent Omitted'
|
||||
f'Visited URL {event.url}\nContent omitted'
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -26,6 +26,7 @@ jobs:
|
||||
base_container_image: ${{ vars.OPENHANDS_BASE_CONTAINER_IMAGE || '' }}
|
||||
LLM_MODEL: ${{ vars.LLM_MODEL || 'anthropic/claude-3-5-sonnet-20241022' }}
|
||||
target_branch: ${{ vars.TARGET_BRANCH || 'main' }}
|
||||
runner: ${{ vars.TARGET_RUNNER }}
|
||||
secrets:
|
||||
PAT_TOKEN: ${{ secrets.PAT_TOKEN }}
|
||||
PAT_USERNAME: ${{ secrets.PAT_USERNAME }}
|
||||
|
||||
@@ -214,7 +214,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||
response = httpx.get(
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}',
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split("/")[-1]}',
|
||||
headers=self.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -225,7 +225,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
'note_id': discussions.get('notes', [])[-1]['id'],
|
||||
}
|
||||
response = httpx.post(
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}/notes',
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split("/")[-1]}/notes',
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
|
||||
80
openhands/resolver/issue_handler_factory.py
Normal file
80
openhands/resolver/issue_handler_factory.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
|
||||
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
|
||||
from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
|
||||
|
||||
class IssueHandlerFactory:
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str,
|
||||
platform: ProviderType,
|
||||
base_domain: str,
|
||||
issue_type: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
self.owner = owner
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.platform = platform
|
||||
self.base_domain = base_domain
|
||||
self.issue_type = issue_type
|
||||
self.llm_config = llm_config
|
||||
|
||||
def create(self) -> ServiceContextIssue | ServiceContextPR:
|
||||
if self.issue_type == 'issue':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextIssue(
|
||||
GithubIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextIssue(
|
||||
GitlabIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
elif self.issue_type == 'pr':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextPR(
|
||||
GithubPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextPR(
|
||||
GitlabPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid issue type: {self.issue_type}')
|
||||
@@ -6,11 +6,9 @@ class HunkException(PatchingException):
|
||||
def __init__(self, msg: str, hunk: int | None = None) -> None:
|
||||
self.hunk = hunk
|
||||
if hunk is not None:
|
||||
super(HunkException, self).__init__(
|
||||
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
|
||||
)
|
||||
super().__init__('{msg}, in hunk #{n}'.format(msg=msg, n=hunk))
|
||||
else:
|
||||
super(HunkException, self).__init__(msg)
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class ApplyException(PatchingException):
|
||||
@@ -19,7 +17,7 @@ class ApplyException(PatchingException):
|
||||
|
||||
class SubprocessException(ApplyException):
|
||||
def __init__(self, msg: str, code: int) -> None:
|
||||
super(SubprocessException, self).__init__(msg)
|
||||
super().__init__(msg)
|
||||
self.code = code
|
||||
|
||||
|
||||
|
||||
@@ -28,13 +28,12 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
|
||||
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
|
||||
from openhands.resolver.interfaces.issue import Issue
|
||||
from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
|
||||
from openhands.resolver.resolver_output import ResolverOutput
|
||||
from openhands.resolver.utils import (
|
||||
codeact_user_response,
|
||||
@@ -111,12 +110,22 @@ class IssueResolver:
|
||||
model = args.llm_model or os.environ['LLM_MODEL']
|
||||
base_url = args.llm_base_url or os.environ.get('LLM_BASE_URL', None)
|
||||
api_version = os.environ.get('LLM_API_VERSION', None)
|
||||
llm_num_retries = int(os.environ.get('LLM_NUM_RETRIES', '4'))
|
||||
llm_retry_min_wait = int(os.environ.get('LLM_RETRY_MIN_WAIT', '5'))
|
||||
llm_retry_max_wait = int(os.environ.get('LLM_RETRY_MAX_WAIT', '30'))
|
||||
llm_retry_multiplier = int(os.environ.get('LLM_RETRY_MULTIPLIER', 2))
|
||||
llm_timeout = int(os.environ.get('LLM_TIMEOUT', 0))
|
||||
|
||||
# Create LLMConfig instance
|
||||
llm_config = LLMConfig(
|
||||
model=model,
|
||||
api_key=SecretStr(api_key) if api_key else None,
|
||||
base_url=base_url,
|
||||
num_retries=llm_num_retries,
|
||||
retry_min_wait=llm_retry_min_wait,
|
||||
retry_max_wait=llm_retry_max_wait,
|
||||
retry_multiplier=llm_retry_multiplier,
|
||||
timeout=llm_timeout,
|
||||
)
|
||||
|
||||
# Only set api_version if it was explicitly provided, otherwise let LLMConfig handle it
|
||||
@@ -152,8 +161,6 @@ class IssueResolver:
|
||||
|
||||
self.owner = owner
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.platform = platform
|
||||
self.runtime_container_image = runtime_container_image
|
||||
self.base_container_image = base_container_image
|
||||
@@ -165,9 +172,20 @@ class IssueResolver:
|
||||
self.repo_instruction = repo_instruction
|
||||
self.issue_number = args.issue_number
|
||||
self.comment_id = args.comment_id
|
||||
self.base_domain = base_domain
|
||||
self.platform = platform
|
||||
|
||||
factory = IssueHandlerFactory(
|
||||
owner=self.owner,
|
||||
repo=self.repo,
|
||||
token=token,
|
||||
username=username,
|
||||
platform=self.platform,
|
||||
base_domain=base_domain,
|
||||
issue_type=self.issue_type,
|
||||
llm_config=self.llm_config,
|
||||
)
|
||||
self.issue_handler = factory.create()
|
||||
|
||||
def initialize_runtime(
|
||||
self,
|
||||
runtime: Runtime,
|
||||
@@ -435,58 +453,6 @@ class IssueResolver:
|
||||
)
|
||||
return output
|
||||
|
||||
def issue_handler_factory(self) -> ServiceContextIssue | ServiceContextPR:
|
||||
# Determine default base_domain based on platform
|
||||
|
||||
if self.issue_type == 'issue':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextIssue(
|
||||
GithubIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextIssue(
|
||||
GitlabIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
elif self.issue_type == 'pr':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextPR(
|
||||
GithubPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextPR(
|
||||
GitlabPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid issue type: {self.issue_type}')
|
||||
|
||||
async def resolve_issue(
|
||||
self,
|
||||
reset_logger: bool = False,
|
||||
@@ -497,10 +463,8 @@ class IssueResolver:
|
||||
reset_logger: Whether to reset the logger for multiprocessing.
|
||||
"""
|
||||
|
||||
issue_handler = self.issue_handler_factory()
|
||||
|
||||
# Load dataset
|
||||
issues: list[Issue] = issue_handler.get_converted_issues(
|
||||
issues: list[Issue] = self.issue_handler.get_converted_issues(
|
||||
issue_numbers=[self.issue_number], comment_id=self.comment_id
|
||||
)
|
||||
|
||||
@@ -546,7 +510,7 @@ class IssueResolver:
|
||||
[
|
||||
'git',
|
||||
'clone',
|
||||
issue_handler.get_clone_url(),
|
||||
self.issue_handler.get_clone_url(),
|
||||
f'{self.output_dir}/repo',
|
||||
]
|
||||
).decode('utf-8')
|
||||
@@ -625,7 +589,7 @@ class IssueResolver:
|
||||
output = await self.process_issue(
|
||||
issue,
|
||||
base_commit,
|
||||
issue_handler,
|
||||
self.issue_handler,
|
||||
reset_logger,
|
||||
)
|
||||
output_fp.write(output.model_dump_json() + '\n')
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Type
|
||||
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.daytona.daytona_runtime import DaytonaRuntime
|
||||
from openhands.runtime.impl.docker.docker_runtime import (
|
||||
@@ -13,7 +11,7 @@ from openhands.runtime.impl.runloop.runloop_runtime import RunloopRuntime
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
# mypy: disable-error-code="type-abstract"
|
||||
_DEFAULT_RUNTIME_CLASSES: dict[str, Type[Runtime]] = {
|
||||
_DEFAULT_RUNTIME_CLASSES: dict[str, type[Runtime]] = {
|
||||
'eventstream': DockerRuntime,
|
||||
'docker': DockerRuntime,
|
||||
'e2b': E2BRuntime,
|
||||
@@ -25,7 +23,7 @@ _DEFAULT_RUNTIME_CLASSES: dict[str, Type[Runtime]] = {
|
||||
}
|
||||
|
||||
|
||||
def get_runtime_cls(name: str) -> Type[Runtime]:
|
||||
def get_runtime_cls(name: str) -> type[Runtime]:
|
||||
"""
|
||||
If name is one of the predefined runtime names (e.g. 'docker'), return its class.
|
||||
Otherwise attempt to resolve name as subclass of Runtime and return it.
|
||||
|
||||
@@ -13,6 +13,7 @@ import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
@@ -76,6 +77,10 @@ from openhands.utils.async_utils import call_sync_from_async, wait_all
|
||||
mcp_router_logger.setLevel(logger.getEffectiveLevel())
|
||||
|
||||
|
||||
if sys.platform == 'win32':
|
||||
from openhands.runtime.utils.windows_bash import WindowsPowershellSession
|
||||
|
||||
|
||||
class ActionRequest(BaseModel):
|
||||
action: dict
|
||||
|
||||
@@ -100,7 +105,7 @@ def _execute_file_editor(
|
||||
view_range: list[int] | None = None,
|
||||
old_str: str | None = None,
|
||||
new_str: str | None = None,
|
||||
insert_line: int | None = None,
|
||||
insert_line: int | str | None = None,
|
||||
enable_linting: bool = False,
|
||||
) -> tuple[str, tuple[str | None, str | None]]:
|
||||
"""Execute file editor command and handle exceptions.
|
||||
@@ -113,13 +118,24 @@ def _execute_file_editor(
|
||||
view_range: Optional view range tuple (start, end)
|
||||
old_str: Optional string to replace
|
||||
new_str: Optional replacement string
|
||||
insert_line: Optional line number for insertion
|
||||
insert_line: Optional line number for insertion (can be int or str)
|
||||
enable_linting: Whether to enable linting
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the output string and a tuple of old and new file content
|
||||
"""
|
||||
result: ToolResult | None = None
|
||||
|
||||
# Convert insert_line from string to int if needed
|
||||
if insert_line is not None and isinstance(insert_line, str):
|
||||
try:
|
||||
insert_line = int(insert_line)
|
||||
except ValueError:
|
||||
return (
|
||||
f"ERROR:\nInvalid insert_line value: '{insert_line}'. Expected an integer.",
|
||||
(None, None),
|
||||
)
|
||||
|
||||
try:
|
||||
result = editor(
|
||||
command=command,
|
||||
@@ -133,6 +149,9 @@ def _execute_file_editor(
|
||||
)
|
||||
except ToolError as e:
|
||||
result = ToolResult(error=e.message)
|
||||
except TypeError as e:
|
||||
# Handle unexpected arguments or type errors
|
||||
return f'ERROR:\n{str(e)}', (None, None)
|
||||
|
||||
if result.error:
|
||||
return f'ERROR:\n{result.error}', (None, None)
|
||||
@@ -167,13 +186,14 @@ class ActionExecutor:
|
||||
if _updated_user_id is not None:
|
||||
self.user_id = _updated_user_id
|
||||
|
||||
self.bash_session: BashSession | None = None
|
||||
self.bash_session: BashSession | 'WindowsPowershellSession' | None = None # type: ignore[name-defined]
|
||||
self.lock = asyncio.Lock()
|
||||
self.plugins: dict[str, Plugin] = {}
|
||||
self.file_editor = OHEditor(workspace_root=self._initial_cwd)
|
||||
self.browser: BrowserEnv | None = None
|
||||
self.browser_init_task: asyncio.Task | None = None
|
||||
self.browsergym_eval_env = browsergym_eval_env
|
||||
|
||||
self.start_time = time.time()
|
||||
self.last_execution_time = self.start_time
|
||||
self._initialized = False
|
||||
@@ -199,6 +219,10 @@ class ActionExecutor:
|
||||
|
||||
async def _init_browser_async(self):
|
||||
"""Initialize the browser asynchronously."""
|
||||
if sys.platform == 'win32':
|
||||
logger.warning('Browser environment not supported on windows')
|
||||
return
|
||||
|
||||
logger.debug('Initializing browser asynchronously')
|
||||
try:
|
||||
self.browser = BrowserEnv(self.browsergym_eval_env)
|
||||
@@ -232,15 +256,25 @@ class ActionExecutor:
|
||||
async def ainit(self):
|
||||
# bash needs to be initialized first
|
||||
logger.debug('Initializing bash session')
|
||||
self.bash_session = BashSession(
|
||||
work_dir=self._initial_cwd,
|
||||
username=self.username,
|
||||
no_change_timeout_seconds=int(
|
||||
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 10)
|
||||
),
|
||||
max_memory_mb=self.max_memory_gb * 1024 if self.max_memory_gb else None,
|
||||
)
|
||||
self.bash_session.initialize()
|
||||
if sys.platform == 'win32':
|
||||
self.bash_session = WindowsPowershellSession( # type: ignore[name-defined]
|
||||
work_dir=self._initial_cwd,
|
||||
username=self.username,
|
||||
no_change_timeout_seconds=int(
|
||||
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 10)
|
||||
),
|
||||
max_memory_mb=self.max_memory_gb * 1024 if self.max_memory_gb else None,
|
||||
)
|
||||
else:
|
||||
self.bash_session = BashSession(
|
||||
work_dir=self._initial_cwd,
|
||||
username=self.username,
|
||||
no_change_timeout_seconds=int(
|
||||
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 10)
|
||||
),
|
||||
max_memory_mb=self.max_memory_gb * 1024 if self.max_memory_gb else None,
|
||||
)
|
||||
self.bash_session.initialize()
|
||||
logger.debug('Bash session initialized')
|
||||
|
||||
# Start browser initialization in the background
|
||||
@@ -282,19 +316,55 @@ class ActionExecutor:
|
||||
logger.debug(f'Initializing plugin: {plugin.name}')
|
||||
|
||||
if isinstance(plugin, JupyterPlugin):
|
||||
# Escape backslashes in Windows path
|
||||
cwd = self.bash_session.cwd.replace('\\', '/')
|
||||
await self.run_ipython(
|
||||
IPythonRunCellAction(
|
||||
code=f'import os; os.chdir("{self.bash_session.cwd}")'
|
||||
)
|
||||
IPythonRunCellAction(code=f'import os; os.chdir(r"{cwd}")')
|
||||
)
|
||||
|
||||
async def _init_bash_commands(self):
|
||||
INIT_COMMANDS = [
|
||||
'git config --file ./.git_config user.name "openhands" && git config --file ./.git_config user.email "openhands@all-hands.dev" && alias git="git --no-pager" && export GIT_CONFIG=$(pwd)/.git_config'
|
||||
if os.environ.get('LOCAL_RUNTIME_MODE') == '1'
|
||||
else 'git config --global user.name "openhands" && git config --global user.email "openhands@all-hands.dev" && alias git="git --no-pager"'
|
||||
]
|
||||
logger.debug(f'Initializing by running {len(INIT_COMMANDS)} bash commands...')
|
||||
INIT_COMMANDS = []
|
||||
is_local_runtime = os.environ.get('LOCAL_RUNTIME_MODE') == '1'
|
||||
is_windows = sys.platform == 'win32'
|
||||
|
||||
# Determine git config commands based on platform and runtime mode
|
||||
if is_local_runtime:
|
||||
if is_windows:
|
||||
# Windows, local - split into separate commands
|
||||
INIT_COMMANDS.append(
|
||||
'git config --file ./.git_config user.name "openhands"'
|
||||
)
|
||||
INIT_COMMANDS.append(
|
||||
'git config --file ./.git_config user.email "openhands@all-hands.dev"'
|
||||
)
|
||||
INIT_COMMANDS.append(
|
||||
'$env:GIT_CONFIG = (Join-Path (Get-Location) ".git_config")'
|
||||
)
|
||||
else:
|
||||
# Linux/macOS, local
|
||||
base_git_config = (
|
||||
'git config --file ./.git_config user.name "openhands" && '
|
||||
'git config --file ./.git_config user.email "openhands@all-hands.dev" && '
|
||||
'export GIT_CONFIG=$(pwd)/.git_config'
|
||||
)
|
||||
INIT_COMMANDS.append(base_git_config)
|
||||
else:
|
||||
# Non-local (implies Linux/macOS)
|
||||
base_git_config = (
|
||||
'git config --global user.name "openhands" && '
|
||||
'git config --global user.email "openhands@all-hands.dev"'
|
||||
)
|
||||
INIT_COMMANDS.append(base_git_config)
|
||||
|
||||
# Determine no-pager command
|
||||
if is_windows:
|
||||
no_pager_cmd = 'function git { git.exe --no-pager $args }'
|
||||
else:
|
||||
no_pager_cmd = 'alias git="git --no-pager"'
|
||||
|
||||
INIT_COMMANDS.append(no_pager_cmd)
|
||||
|
||||
logger.info(f'Initializing by running {len(INIT_COMMANDS)} bash commands...')
|
||||
for command in INIT_COMMANDS:
|
||||
action = CmdRunAction(command=command)
|
||||
action.set_hard_timeout(300)
|
||||
@@ -345,9 +415,9 @@ class ActionExecutor:
|
||||
logger.debug(
|
||||
f'{self.bash_session.cwd} != {jupyter_cwd} -> reset Jupyter PWD'
|
||||
)
|
||||
reset_jupyter_cwd_code = (
|
||||
f'import os; os.chdir("{self.bash_session.cwd}")'
|
||||
)
|
||||
# escape windows paths
|
||||
cwd = self.bash_session.cwd.replace('\\', '/')
|
||||
reset_jupyter_cwd_code = f'import os; os.chdir("{cwd}")'
|
||||
_aux_action = IPythonRunCellAction(code=reset_jupyter_cwd_code)
|
||||
_reset_obs: IPythonRunCellObservation = await _jupyter_plugin.run(
|
||||
_aux_action
|
||||
@@ -527,12 +597,20 @@ class ActionExecutor:
|
||||
)
|
||||
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
if self.browser is None:
|
||||
return ErrorObservation(
|
||||
'Browser functionality is not supported on Windows.'
|
||||
)
|
||||
await self._ensure_browser_ready()
|
||||
return await browse(action, self.browser)
|
||||
return await browse(action, self.browser, self.initial_cwd)
|
||||
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
if self.browser is None:
|
||||
return ErrorObservation(
|
||||
'Browser functionality is not supported on Windows.'
|
||||
)
|
||||
await self._ensure_browser_ready()
|
||||
return await browse(action, self.browser)
|
||||
return await browse(action, self.browser, self.initial_cwd)
|
||||
|
||||
def close(self):
|
||||
self.memory_monitor.stop_monitoring()
|
||||
@@ -726,7 +804,6 @@ if __name__ == '__main__':
|
||||
if not isinstance(action, Action):
|
||||
raise HTTPException(status_code=400, detail='Invalid action type')
|
||||
client.last_execution_time = time.time()
|
||||
|
||||
observation = await client.run_action(action)
|
||||
return event_to_dict(observation)
|
||||
except Exception as e:
|
||||
@@ -897,7 +974,7 @@ if __name__ == '__main__':
|
||||
|
||||
To list files:
|
||||
```sh
|
||||
curl http://localhost:3000/api/list-files
|
||||
curl -X POST -d '{"path": "/"}' http://localhost:3000/list_files
|
||||
```
|
||||
|
||||
Args:
|
||||
|
||||
@@ -72,6 +72,7 @@ STATUS_MESSAGES = {
|
||||
'STATUS$CONTAINER_STARTED': 'Container started.',
|
||||
'STATUS$WAITING_FOR_CLIENT': 'Waiting for client...',
|
||||
'STATUS$SETTING_UP_WORKSPACE': 'Setting up workspace...',
|
||||
'STATUS$SETTING_UP_GIT_HOOKS': 'Setting up git hooks...',
|
||||
}
|
||||
|
||||
|
||||
@@ -424,21 +425,278 @@ class Runtime(FileEditRuntimeMixin):
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log('error', f'Setup script failed: {obs.content}')
|
||||
|
||||
def maybe_setup_git_hooks(self):
|
||||
"""Set up git hooks if .openhands/pre-commit.sh exists in the workspace or repository."""
|
||||
pre_commit_script = '.openhands/pre-commit.sh'
|
||||
read_obs = self.read(FileReadAction(path=pre_commit_script))
|
||||
if isinstance(read_obs, ErrorObservation):
|
||||
return
|
||||
|
||||
if self.status_callback:
|
||||
self.status_callback(
|
||||
'info', 'STATUS$SETTING_UP_GIT_HOOKS', 'Setting up git hooks...'
|
||||
)
|
||||
|
||||
# Ensure the git hooks directory exists
|
||||
action = CmdRunAction('mkdir -p .git/hooks')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log('error', f'Failed to create git hooks directory: {obs.content}')
|
||||
return
|
||||
|
||||
# Make the pre-commit script executable
|
||||
action = CmdRunAction(f'chmod +x {pre_commit_script}')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log(
|
||||
'error', f'Failed to make pre-commit script executable: {obs.content}'
|
||||
)
|
||||
return
|
||||
|
||||
# Check if there's an existing pre-commit hook
|
||||
pre_commit_hook = '.git/hooks/pre-commit'
|
||||
pre_commit_local = '.git/hooks/pre-commit.local'
|
||||
|
||||
# Read the existing pre-commit hook if it exists
|
||||
read_obs = self.read(FileReadAction(path=pre_commit_hook))
|
||||
if not isinstance(read_obs, ErrorObservation):
|
||||
# If the existing hook wasn't created by OpenHands, preserve it
|
||||
if 'This hook was installed by OpenHands' not in read_obs.content:
|
||||
self.log('info', 'Preserving existing pre-commit hook')
|
||||
# Move the existing hook to pre-commit.local
|
||||
action = CmdRunAction(f'mv {pre_commit_hook} {pre_commit_local}')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log(
|
||||
'error',
|
||||
f'Failed to preserve existing pre-commit hook: {obs.content}',
|
||||
)
|
||||
return
|
||||
|
||||
# Make it executable
|
||||
action = CmdRunAction(f'chmod +x {pre_commit_local}')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log(
|
||||
'error',
|
||||
f'Failed to make preserved hook executable: {obs.content}',
|
||||
)
|
||||
return
|
||||
|
||||
# Create the pre-commit hook that calls our script
|
||||
pre_commit_hook_content = f"""#!/bin/bash
|
||||
# This hook was installed by OpenHands
|
||||
# It calls the pre-commit script in the .openhands directory
|
||||
|
||||
if [ -x "{pre_commit_script}" ]; then
|
||||
source "{pre_commit_script}"
|
||||
exit $?
|
||||
else
|
||||
echo "Warning: {pre_commit_script} not found or not executable"
|
||||
exit 0
|
||||
fi
|
||||
"""
|
||||
|
||||
# Write the pre-commit hook
|
||||
write_obs = self.write(
|
||||
FileWriteAction(path=pre_commit_hook, content=pre_commit_hook_content)
|
||||
)
|
||||
if isinstance(write_obs, ErrorObservation):
|
||||
self.log('error', f'Failed to write pre-commit hook: {write_obs.content}')
|
||||
return
|
||||
|
||||
# Make the pre-commit hook executable
|
||||
action = CmdRunAction(f'chmod +x {pre_commit_hook}')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log(
|
||||
'error', f'Failed to make pre-commit hook executable: {obs.content}'
|
||||
)
|
||||
return
|
||||
|
||||
self.log('info', 'Git pre-commit hook installed successfully')
|
||||
|
||||
def _load_microagents_from_directory(
|
||||
self, microagents_dir: Path, source_description: str
|
||||
) -> list[BaseMicroagent]:
|
||||
"""Load microagents from a directory.
|
||||
|
||||
Args:
|
||||
microagents_dir: Path to the directory containing microagents
|
||||
source_description: Description of the source for logging purposes
|
||||
|
||||
Returns:
|
||||
A list of loaded microagents
|
||||
"""
|
||||
loaded_microagents: list[BaseMicroagent] = []
|
||||
files = self.list_files(str(microagents_dir))
|
||||
|
||||
if not files:
|
||||
return loaded_microagents
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Found {len(files)} files in {source_description} microagents directory',
|
||||
)
|
||||
zip_path = self.copy_from(str(microagents_dir))
|
||||
microagent_folder = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
with ZipFile(zip_path, 'r') as zip_file:
|
||||
zip_file.extractall(microagent_folder)
|
||||
|
||||
zip_path.unlink()
|
||||
repo_agents, knowledge_agents = load_microagents_from_dir(microagent_folder)
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Loaded {len(repo_agents)} repo agents and {len(knowledge_agents)} knowledge agents from {source_description}',
|
||||
)
|
||||
|
||||
loaded_microagents.extend(repo_agents.values())
|
||||
loaded_microagents.extend(knowledge_agents.values())
|
||||
finally:
|
||||
shutil.rmtree(microagent_folder)
|
||||
|
||||
return loaded_microagents
|
||||
|
||||
def _get_authenticated_git_url(self, repo_path: str) -> str:
|
||||
"""Get an authenticated git URL for a repository.
|
||||
|
||||
Args:
|
||||
repo_path: Repository path (e.g., "github.com/acme-co/api")
|
||||
|
||||
Returns:
|
||||
Authenticated git URL if credentials are available, otherwise regular HTTPS URL
|
||||
"""
|
||||
remote_url = f'https://{repo_path}.git'
|
||||
|
||||
# Determine provider from repo path
|
||||
provider = None
|
||||
if 'github.com' in repo_path:
|
||||
provider = ProviderType.GITHUB
|
||||
elif 'gitlab.com' in repo_path:
|
||||
provider = ProviderType.GITLAB
|
||||
|
||||
# Add authentication if available
|
||||
if (
|
||||
provider
|
||||
and self.git_provider_tokens
|
||||
and provider in self.git_provider_tokens
|
||||
):
|
||||
git_token = self.git_provider_tokens[provider].token
|
||||
if git_token:
|
||||
if provider == ProviderType.GITLAB:
|
||||
remote_url = f'https://oauth2:{git_token.get_secret_value()}@{repo_path.replace("gitlab.com/", "")}.git'
|
||||
else:
|
||||
remote_url = f'https://{git_token.get_secret_value()}@{repo_path.replace("github.com/", "")}.git'
|
||||
|
||||
return remote_url
|
||||
|
||||
def get_microagents_from_org_or_user(
|
||||
self, selected_repository: str
|
||||
) -> list[BaseMicroagent]:
|
||||
"""Load microagents from the organization or user level .openhands repository.
|
||||
|
||||
For example, if the repository is github.com/acme-co/api, this will check if
|
||||
github.com/acme-co/.openhands exists. If it does, it will clone it and load
|
||||
the microagents from the ./microagents/ folder.
|
||||
|
||||
Args:
|
||||
selected_repository: The repository path (e.g., "github.com/acme-co/api")
|
||||
|
||||
Returns:
|
||||
A list of loaded microagents from the org/user level repository
|
||||
"""
|
||||
loaded_microagents: list[BaseMicroagent] = []
|
||||
workspace_root = Path(self.config.workspace_mount_path_in_sandbox)
|
||||
|
||||
repo_parts = selected_repository.split('/')
|
||||
if len(repo_parts) < 2:
|
||||
return loaded_microagents
|
||||
|
||||
# Extract the domain and org/user name
|
||||
domain = repo_parts[0] if len(repo_parts) > 2 else 'github.com'
|
||||
org_name = repo_parts[-2]
|
||||
|
||||
# Construct the org-level .openhands repo path
|
||||
org_openhands_repo = f'{domain}/{org_name}/.openhands'
|
||||
if domain not in org_openhands_repo:
|
||||
org_openhands_repo = f'github.com/{org_openhands_repo}'
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Checking for org-level microagents at {org_openhands_repo}',
|
||||
)
|
||||
|
||||
# Try to clone the org-level .openhands repo
|
||||
try:
|
||||
# Create a temporary directory for the org-level repo
|
||||
org_repo_dir = workspace_root / f'org_openhands_{org_name}'
|
||||
|
||||
# Get authenticated URL and do a shallow clone (--depth 1) for efficiency
|
||||
remote_url = self._get_authenticated_git_url(org_openhands_repo)
|
||||
clone_cmd = f"git clone --depth 1 {remote_url} {org_repo_dir} 2>/dev/null || echo 'Org repo not found'"
|
||||
|
||||
action = CmdRunAction(command=clone_cmd)
|
||||
obs = self.run_action(action)
|
||||
|
||||
if (
|
||||
isinstance(obs, CmdOutputObservation)
|
||||
and obs.exit_code == 0
|
||||
and 'Org repo not found' not in obs.content
|
||||
):
|
||||
self.log(
|
||||
'info',
|
||||
f'Successfully cloned org-level microagents from {org_openhands_repo}',
|
||||
)
|
||||
|
||||
# Load microagents from the org-level repo
|
||||
org_microagents_dir = org_repo_dir / 'microagents'
|
||||
loaded_microagents = self._load_microagents_from_directory(
|
||||
org_microagents_dir, 'org-level'
|
||||
)
|
||||
|
||||
# Clean up the org repo directory
|
||||
shutil.rmtree(org_repo_dir)
|
||||
else:
|
||||
self.log(
|
||||
'info',
|
||||
f'No org-level microagents found at {org_openhands_repo}',
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.log('error', f'Error loading org-level microagents: {str(e)}')
|
||||
|
||||
return loaded_microagents
|
||||
|
||||
def get_microagents_from_selected_repo(
|
||||
self, selected_repository: str | None
|
||||
) -> list[BaseMicroagent]:
|
||||
"""Load microagents from the selected repository.
|
||||
If selected_repository is None, load microagents from the current workspace.
|
||||
This is the main entry point for loading microagents.
|
||||
|
||||
This method also checks for user/org level microagents stored in a .openhands repository.
|
||||
For example, if the repository is github.com/acme-co/api, it will also check for
|
||||
github.com/acme-co/.openhands and load microagents from there if it exists.
|
||||
"""
|
||||
|
||||
loaded_microagents: list[BaseMicroagent] = []
|
||||
workspace_root = Path(self.config.workspace_mount_path_in_sandbox)
|
||||
microagents_dir = workspace_root / '.openhands' / 'microagents'
|
||||
repo_root = None
|
||||
|
||||
# Check for user/org level microagents if a repository is selected
|
||||
if selected_repository:
|
||||
# Load microagents from the org/user level repository
|
||||
org_microagents = self.get_microagents_from_org_or_user(selected_repository)
|
||||
loaded_microagents.extend(org_microagents)
|
||||
|
||||
# Continue with repository-specific microagents
|
||||
repo_root = workspace_root / selected_repository.split('/')[-1]
|
||||
microagents_dir = repo_root / '.openhands' / 'microagents'
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Selected repo: {selected_repository}, loading microagents from {microagents_dir} (inside runtime)',
|
||||
@@ -470,35 +728,10 @@ class Runtime(FileEditRuntimeMixin):
|
||||
)
|
||||
|
||||
# Load microagents from directory
|
||||
files = self.list_files(str(microagents_dir))
|
||||
if files:
|
||||
self.log('info', f'Found {len(files)} files in microagents directory.')
|
||||
zip_path = self.copy_from(str(microagents_dir))
|
||||
microagent_folder = tempfile.mkdtemp()
|
||||
|
||||
# Properly handle the zip file
|
||||
with ZipFile(zip_path, 'r') as zip_file:
|
||||
zip_file.extractall(microagent_folder)
|
||||
|
||||
# Add debug print of directory structure
|
||||
self.log('debug', 'Microagent folder structure:')
|
||||
for root, _, files in os.walk(microagent_folder):
|
||||
relative_path = os.path.relpath(root, microagent_folder)
|
||||
self.log('debug', f'Directory: {relative_path}/')
|
||||
for file in files:
|
||||
self.log('debug', f' File: {os.path.join(relative_path, file)}')
|
||||
|
||||
# Clean up the temporary zip file
|
||||
zip_path.unlink()
|
||||
# Load all microagents using the existing function
|
||||
repo_agents, knowledge_agents = load_microagents_from_dir(microagent_folder)
|
||||
self.log(
|
||||
'info',
|
||||
f'Loaded {len(repo_agents)} repo agents and {len(knowledge_agents)} knowledge agents',
|
||||
)
|
||||
loaded_microagents.extend(repo_agents.values())
|
||||
loaded_microagents.extend(knowledge_agents.values())
|
||||
shutil.rmtree(microagent_folder)
|
||||
repo_microagents = self._load_microagents_from_directory(
|
||||
microagents_dir, 'repository'
|
||||
)
|
||||
loaded_microagents.extend(repo_microagents)
|
||||
|
||||
return loaded_microagents
|
||||
|
||||
|
||||
34
openhands/runtime/browser/base64.py
Normal file
34
openhands/runtime/browser/base64.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import base64
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def image_to_png_base64_url(
|
||||
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||
) -> str:
|
||||
"""Convert a numpy array to a base64 encoded png image url."""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
image = image.convert('RGB')
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format='PNG')
|
||||
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
return (
|
||||
f'data:image/png;base64,{image_base64}'
|
||||
if add_data_prefix
|
||||
else f'{image_base64}'
|
||||
)
|
||||
|
||||
|
||||
def png_base64_url_to_image(png_base64_url: str) -> Image.Image:
|
||||
"""Convert a base64 encoded png image url to a PIL Image."""
|
||||
splited = png_base64_url.split(',')
|
||||
if len(splited) == 2:
|
||||
base64_data = splited[1]
|
||||
else:
|
||||
base64_data = png_base64_url
|
||||
return Image.open(io.BytesIO(base64.b64decode(base64_data)))
|
||||
@@ -1,6 +1,4 @@
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import multiprocessing
|
||||
import time
|
||||
@@ -9,13 +7,12 @@ import uuid
|
||||
import browsergym.core # noqa F401 (we register the openended task as a gym environment)
|
||||
import gymnasium as gym
|
||||
import html2text
|
||||
import numpy as np
|
||||
import tenacity
|
||||
from browsergym.utils.obs import flatten_dom_to_str, overlay_som
|
||||
from PIL import Image
|
||||
|
||||
from openhands.core.exceptions import BrowserInitException
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.runtime.browser.base64 import image_to_png_base64_url
|
||||
from openhands.utils.shutdown_listener import should_continue, should_exit
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
@@ -40,7 +37,7 @@ class BrowserEnv:
|
||||
self.init_browser()
|
||||
atexit.register(self.close)
|
||||
|
||||
def get_html_text_converter(self):
|
||||
def get_html_text_converter(self) -> html2text.HTML2Text:
|
||||
html_text_converter = html2text.HTML2Text()
|
||||
# ignore links and images
|
||||
html_text_converter.ignore_links = False
|
||||
@@ -56,7 +53,7 @@ class BrowserEnv:
|
||||
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
|
||||
retry=tenacity.retry_if_exception_type(BrowserInitException),
|
||||
)
|
||||
def init_browser(self):
|
||||
def init_browser(self) -> None:
|
||||
logger.debug('Starting browser env...')
|
||||
try:
|
||||
self.process = multiprocessing.Process(target=self.browser_process)
|
||||
@@ -69,7 +66,7 @@ class BrowserEnv:
|
||||
self.close()
|
||||
raise BrowserInitException('Failed to start browser environment.')
|
||||
|
||||
def browser_process(self):
|
||||
def browser_process(self) -> None:
|
||||
if self.eval_mode:
|
||||
assert self.browsergym_eval_env is not None
|
||||
logger.info('Initializing browser env for web browsing evaluation.')
|
||||
@@ -165,13 +162,13 @@ class BrowserEnv:
|
||||
html_str = flatten_dom_to_str(obs['dom_object'])
|
||||
obs['text_content'] = self.html_text_converter.handle(html_str)
|
||||
# make observation serializable
|
||||
obs['set_of_marks'] = self.image_to_png_base64_url(
|
||||
obs['set_of_marks'] = image_to_png_base64_url(
|
||||
overlay_som(
|
||||
obs['screenshot'], obs.get('extra_element_properties', {})
|
||||
),
|
||||
add_data_prefix=True,
|
||||
)
|
||||
obs['screenshot'] = self.image_to_png_base64_url(
|
||||
obs['screenshot'] = image_to_png_base64_url(
|
||||
obs['screenshot'], add_data_prefix=True
|
||||
)
|
||||
obs['active_page_index'] = obs['active_page_index'].item()
|
||||
@@ -196,17 +193,18 @@ class BrowserEnv:
|
||||
if self.agent_side.poll(timeout=0.01):
|
||||
response_id, obs = self.agent_side.recv()
|
||||
if response_id == unique_request_id:
|
||||
return obs
|
||||
return dict(obs)
|
||||
|
||||
def check_alive(self, timeout: float = 60):
|
||||
def check_alive(self, timeout: float = 60) -> bool:
|
||||
self.agent_side.send(('IS_ALIVE', None))
|
||||
if self.agent_side.poll(timeout=timeout):
|
||||
response_id, _ = self.agent_side.recv()
|
||||
if response_id == 'ALIVE':
|
||||
return True
|
||||
logger.debug(f'Browser env is not alive. Response ID: {response_id}')
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if not self.process.is_alive():
|
||||
return
|
||||
try:
|
||||
@@ -225,41 +223,3 @@ class BrowserEnv:
|
||||
self.browser_side.close()
|
||||
except Exception as e:
|
||||
logger.error(f'Encountered an error when closing browser env: {e}')
|
||||
|
||||
@staticmethod
|
||||
def image_to_png_base64_url(
|
||||
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||
):
|
||||
"""Convert a numpy array to a base64 encoded png image url."""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
image = image.convert('RGB')
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format='PNG')
|
||||
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
return (
|
||||
f'data:image/png;base64,{image_base64}'
|
||||
if add_data_prefix
|
||||
else f'{image_base64}'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def image_to_jpg_base64_url(
|
||||
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||
):
|
||||
"""Convert a numpy array to a base64 encoded jpeg image url."""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
image = image.convert('RGB')
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format='JPEG')
|
||||
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
return (
|
||||
f'data:image/jpeg;base64,{image_base64}'
|
||||
if add_data_prefix
|
||||
else f'{image_base64}'
|
||||
)
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from openhands.core.exceptions import BrowserUnavailableException
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.observation import BrowserOutputObservation
|
||||
from openhands.runtime.browser.base64 import png_base64_url_to_image
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
async def browse(
|
||||
action: BrowseURLAction | BrowseInteractiveAction, browser: BrowserEnv | None
|
||||
action: BrowseURLAction | BrowseInteractiveAction,
|
||||
browser: BrowserEnv | None,
|
||||
workspace_dir: str | None = None,
|
||||
) -> BrowserOutputObservation:
|
||||
if browser is None:
|
||||
raise BrowserUnavailableException()
|
||||
@@ -31,10 +39,50 @@ async def browse(
|
||||
try:
|
||||
# obs provided by BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/env.py#L396
|
||||
obs = await call_sync_from_async(browser.step, action_str)
|
||||
|
||||
# Save screenshot if workspace_dir is provided
|
||||
screenshot_path = None
|
||||
if workspace_dir is not None and obs.get('screenshot'):
|
||||
# Create screenshots directory if it doesn't exist
|
||||
screenshots_dir = Path(workspace_dir) / '.browser_screenshots'
|
||||
screenshots_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Generate a filename based on timestamp
|
||||
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f')
|
||||
screenshot_filename = f'screenshot_{timestamp}.png'
|
||||
screenshot_path = str(screenshots_dir / screenshot_filename)
|
||||
|
||||
# Direct image saving from base64 data without using PIL's Image.open
|
||||
# This approach bypasses potential encoding issues that might occur when
|
||||
# converting between different image representations, ensuring the raw PNG
|
||||
# data from the browser is saved directly to disk.
|
||||
|
||||
# Extract the base64 data
|
||||
base64_data = obs.get('screenshot', '')
|
||||
if ',' in base64_data:
|
||||
base64_data = base64_data.split(',')[1]
|
||||
|
||||
try:
|
||||
# Decode base64 directly to binary
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# Write binary data directly to file
|
||||
with open(screenshot_path, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
# Verify the image was saved correctly by opening it
|
||||
# This is just a verification step and can be removed in production
|
||||
Image.open(screenshot_path).verify()
|
||||
except Exception:
|
||||
# If direct saving fails, fall back to the original method
|
||||
image = png_base64_url_to_image(obs.get('screenshot'))
|
||||
image.save(screenshot_path, format='PNG', optimize=True)
|
||||
|
||||
return BrowserOutputObservation(
|
||||
content=obs['text_content'], # text content of the page
|
||||
url=obs.get('url', ''), # URL of the page
|
||||
screenshot=obs.get('screenshot', None), # base64-encoded screenshot, png
|
||||
screenshot_path=screenshot_path, # path to saved screenshot file
|
||||
set_of_marks=obs.get(
|
||||
'set_of_marks', None
|
||||
), # base64-encoded Set-of-Marks annotated screenshot, png,
|
||||
@@ -60,6 +108,7 @@ async def browse(
|
||||
return BrowserOutputObservation(
|
||||
content=str(e),
|
||||
screenshot='',
|
||||
screenshot_path=None,
|
||||
error=True,
|
||||
last_browser_action_error=str(e),
|
||||
url=asked_url if action.action == ActionType.BROWSE else '',
|
||||
|
||||
@@ -36,7 +36,7 @@ class DockerRuntimeBuilder(RuntimeBuilder):
|
||||
self.rolling_logger = RollingLogger(max_lines=10)
|
||||
|
||||
@staticmethod
|
||||
def check_buildx(is_podman: bool = False):
|
||||
def check_buildx(is_podman: bool = False) -> bool:
|
||||
"""Check if Docker Buildx is available"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
|
||||
@@ -99,8 +99,8 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
logger.info(f'Build status: {status}')
|
||||
|
||||
if status == 'SUCCESS':
|
||||
logger.debug(f"Successfully built {status_data['image']}")
|
||||
return status_data['image']
|
||||
logger.debug(f'Successfully built {status_data["image"]}')
|
||||
return str(status_data['image'])
|
||||
elif status in [
|
||||
'FAILURE',
|
||||
'INTERNAL_ERROR',
|
||||
@@ -139,11 +139,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
|
||||
if result['exists']:
|
||||
logger.debug(
|
||||
f"Image {image_name} exists. "
|
||||
f"Uploaded at: {result['image']['upload_time']}, "
|
||||
f"Size: {result['image']['image_size_bytes'] / 1024 / 1024:.2f} MB"
|
||||
f'Image {image_name} exists. '
|
||||
f'Uploaded at: {result["image"]["upload_time"]}, '
|
||||
f'Size: {result["image"]["image_size_bytes"] / 1024 / 1024:.2f} MB'
|
||||
)
|
||||
else:
|
||||
logger.debug(f'Image {image_name} does not exist.')
|
||||
|
||||
return result['exists']
|
||||
return bool(result['exists'])
|
||||
|
||||
@@ -5,7 +5,6 @@ This server has no authentication and only listens to localhost traffic.
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
@@ -22,12 +21,12 @@ def create_app() -> FastAPI:
|
||||
)
|
||||
|
||||
@app.get('/')
|
||||
async def root():
|
||||
async def root() -> dict[str, str]:
|
||||
"""Root endpoint to check if the server is running."""
|
||||
return {'status': 'File viewer server is running'}
|
||||
|
||||
@app.get('/view')
|
||||
async def view_file(path: str, request: Request):
|
||||
async def view_file(path: str, request: Request) -> HTMLResponse:
|
||||
"""View a file using an embedded viewer.
|
||||
|
||||
Args:
|
||||
@@ -75,7 +74,7 @@ def create_app() -> FastAPI:
|
||||
return app
|
||||
|
||||
|
||||
def start_file_viewer_server(port: int) -> Tuple[str, threading.Thread]:
|
||||
def start_file_viewer_server(port: int) -> tuple[str, threading.Thread]:
|
||||
"""Start the file viewer server on the specified port or find an available one.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -158,7 +158,6 @@ class ActionExecutionClient(Runtime):
|
||||
|
||||
def copy_from(self, path: str) -> Path:
|
||||
"""Zip all files in the sandbox and return as a stream of bytes."""
|
||||
|
||||
try:
|
||||
params = {'path': path}
|
||||
with self.session.stream(
|
||||
@@ -183,25 +182,44 @@ class ActionExecutionClient(Runtime):
|
||||
if not os.path.exists(host_src):
|
||||
raise FileNotFoundError(f'Source file {host_src} does not exist')
|
||||
|
||||
temp_zip_path: str | None = None # Define temp_zip_path outside the try block
|
||||
|
||||
try:
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
file_to_upload = None
|
||||
upload_data = {}
|
||||
|
||||
if recursive:
|
||||
# Create and write the zip file inside the try block
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix='.zip', delete=False
|
||||
) as temp_zip:
|
||||
temp_zip_path = temp_zip.name
|
||||
|
||||
with ZipFile(temp_zip_path, 'w') as zipf:
|
||||
for root, _, files in os.walk(host_src):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
arcname = os.path.relpath(
|
||||
file_path, os.path.dirname(host_src)
|
||||
)
|
||||
zipf.write(file_path, arcname)
|
||||
try:
|
||||
with ZipFile(temp_zip_path, 'w') as zipf:
|
||||
for root, _, files in os.walk(host_src):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
arcname = os.path.relpath(
|
||||
file_path, os.path.dirname(host_src)
|
||||
)
|
||||
zipf.write(file_path, arcname)
|
||||
|
||||
upload_data = {'file': open(temp_zip_path, 'rb')}
|
||||
self.log(
|
||||
'debug',
|
||||
f'Opening temporary zip file for upload: {temp_zip_path}',
|
||||
)
|
||||
file_to_upload = open(temp_zip_path, 'rb')
|
||||
upload_data = {'file': file_to_upload}
|
||||
except Exception as e:
|
||||
# Ensure temp file is cleaned up if zipping fails
|
||||
if temp_zip_path and os.path.exists(temp_zip_path):
|
||||
os.unlink(temp_zip_path)
|
||||
raise e # Re-raise the exception after cleanup attempt
|
||||
else:
|
||||
upload_data = {'file': open(host_src, 'rb')}
|
||||
file_to_upload = open(host_src, 'rb')
|
||||
upload_data = {'file': file_to_upload}
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
@@ -217,11 +235,18 @@ class ActionExecutionClient(Runtime):
|
||||
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
|
||||
)
|
||||
finally:
|
||||
if recursive:
|
||||
os.unlink(temp_zip_path)
|
||||
self.log(
|
||||
'debug', f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}'
|
||||
)
|
||||
if file_to_upload:
|
||||
file_to_upload.close()
|
||||
|
||||
# Cleanup the temporary zip file if it was created
|
||||
if temp_zip_path and os.path.exists(temp_zip_path):
|
||||
try:
|
||||
os.unlink(temp_zip_path)
|
||||
except Exception as e:
|
||||
self.log(
|
||||
'error',
|
||||
f'Failed to delete temporary zip file {temp_zip_path}: {e}',
|
||||
)
|
||||
|
||||
def get_vscode_token(self) -> str:
|
||||
if self.vscode_enabled and self.runtime_initialized:
|
||||
@@ -334,26 +359,34 @@ class ActionExecutionClient(Runtime):
|
||||
server.model_dump(mode='json')
|
||||
for server in updated_mcp_config.stdio_servers
|
||||
]
|
||||
self.log('debug', f'Updating MCP server to: {stdio_tools}')
|
||||
response = self._send_action_server_request(
|
||||
'POST',
|
||||
f'{self.action_execution_server_url}/update_mcp_server',
|
||||
json=stdio_tools,
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f'Failed to update MCP server: {response.text}')
|
||||
|
||||
# No API key by default. Child runtime can override this when appropriate
|
||||
updated_mcp_config.sse_servers.append(
|
||||
MCPSSEServerConfig(
|
||||
url=self.action_execution_server_url.rstrip('/') + '/sse', api_key=None
|
||||
if len(stdio_tools) > 0:
|
||||
self.log('debug', f'Updating MCP server to: {stdio_tools}')
|
||||
response = self._send_action_server_request(
|
||||
'POST',
|
||||
f'{self.action_execution_server_url}/update_mcp_server',
|
||||
json=stdio_tools,
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.log('warning', f'Failed to update MCP server: {response.text}')
|
||||
|
||||
# No API key by default. Child runtime can override this when appropriate
|
||||
updated_mcp_config.sse_servers.append(
|
||||
MCPSSEServerConfig(
|
||||
url=self.action_execution_server_url.rstrip('/') + '/sse',
|
||||
api_key=None,
|
||||
)
|
||||
)
|
||||
self.log(
|
||||
'info',
|
||||
f'Updated MCP config: {updated_mcp_config.sse_servers}',
|
||||
)
|
||||
else:
|
||||
self.log(
|
||||
'debug',
|
||||
'MCP servers inside runtime is not updated since no stdio servers are provided',
|
||||
)
|
||||
)
|
||||
self.log(
|
||||
'debug',
|
||||
f'Updated MCP config by adding runtime as another server: {updated_mcp_config}',
|
||||
)
|
||||
return updated_mcp_config
|
||||
|
||||
async def call_tool_mcp(self, action: MCPAction) -> Observation:
|
||||
|
||||
@@ -12,18 +12,32 @@
|
||||
|
||||
### Step 2: Set Your API Key as an Environment Variable
|
||||
Run the following command in your terminal, replacing `<your-api-key>` with the actual key you copied:
|
||||
|
||||
Mac/Linux:
|
||||
```bash
|
||||
export DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
Windows PowerShell:
|
||||
```powershell
|
||||
$env:DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
This step ensures that OpenHands can authenticate with the Daytona platform when it runs.
|
||||
|
||||
### Step 3: Run OpenHands Locally Using Docker
|
||||
To start the latest version of OpenHands on your machine, execute the following command in your terminal:
|
||||
|
||||
Mac/Linux:
|
||||
```bash
|
||||
bash -i <(curl -sL https://get.daytona.io/openhands)
|
||||
```
|
||||
|
||||
Windows:
|
||||
```powershell
|
||||
powershell -Command "irm https://get.daytona.io/openhands-windows | iex"
|
||||
```
|
||||
|
||||
#### What This Command Does:
|
||||
- Downloads the latest OpenHands release script.
|
||||
- Runs the script in an interactive Bash session.
|
||||
@@ -36,10 +50,16 @@ Once executed, OpenHands should be running locally and ready for use.
|
||||
### Step 1: Set the `OPENHANDS_VERSION` Environment Variable
|
||||
Run the following command in your terminal, replacing `<openhands-release>` with the latest release's version seen in the [main README.md file](https://github.com/All-Hands-AI/OpenHands?tab=readme-ov-file#-quick-start):
|
||||
|
||||
#### Mac/Linux:
|
||||
```bash
|
||||
export OPENHANDS_VERSION="<openhands-release>" # e.g. 0.27
|
||||
```
|
||||
|
||||
#### Windows PowerShell:
|
||||
```powershell
|
||||
$env:OPENHANDS_VERSION="<openhands-release>" # e.g. 0.27
|
||||
```
|
||||
|
||||
### Step 2: Retrieve Your Daytona API Key
|
||||
1. Visit the [Daytona Dashboard](https://app.daytona.io/dashboard/keys).
|
||||
2. Click **"Create Key"**.
|
||||
@@ -48,13 +68,21 @@ export OPENHANDS_VERSION="<openhands-release>" # e.g. 0.27
|
||||
|
||||
### Step 3: Set Your API Key as an Environment Variable:
|
||||
Run the following command in your terminal, replacing `<your-api-key>` with the actual key you copied:
|
||||
|
||||
#### Mac/Linux:
|
||||
```bash
|
||||
export DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
#### Windows PowerShell:
|
||||
```powershell
|
||||
$env:DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
### Step 4: Run the following `docker` command:
|
||||
This command pulls and runs the OpenHands container using Docker. Once executed, OpenHands should be running locally and ready for use.
|
||||
|
||||
#### Mac/Linux:
|
||||
```bash
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:${OPENHANDS_VERSION}-nikolaik \
|
||||
@@ -67,16 +95,36 @@ docker run -it --rm --pull=always \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:${OPENHANDS_VERSION}
|
||||
```
|
||||
|
||||
#### Windows:
|
||||
```powershell
|
||||
docker run -it --rm --pull=always `
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:${env:OPENHANDS_VERSION}-nikolaik `
|
||||
-e LOG_ALL_EVENTS=true `
|
||||
-e RUNTIME=daytona `
|
||||
-e DAYTONA_API_KEY=${env:DAYTONA_API_KEY} `
|
||||
-v ~/.openhands-state:/.openhands-state `
|
||||
-p 3000:3000 `
|
||||
--name openhands-app `
|
||||
docker.all-hands.dev/all-hands-ai/openhands:${env:OPENHANDS_VERSION}
|
||||
```
|
||||
|
||||
> **Tip:** If you don't want your sandboxes to default to the EU region, you can set the `DAYTONA_TARGET` environment variable to `us`
|
||||
|
||||
### Running OpenHands Locally Without Docker
|
||||
|
||||
Alternatively, if you want to run the OpenHands app on your local machine using `make run` without Docker, make sure to set the following environment variables first:
|
||||
|
||||
#### Mac/Linux:
|
||||
```bash
|
||||
export RUNTIME="daytona"
|
||||
export DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
#### Windows PowerShell:
|
||||
```powershell
|
||||
$env:RUNTIME="daytona"
|
||||
$env:DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
## Documentation
|
||||
Read more by visiting our [documentation](https://www.daytona.io/docs/) page.
|
||||
|
||||
@@ -115,12 +115,12 @@ class DaytonaRuntime(ActionExecutionClient):
|
||||
|
||||
def _construct_api_url(self, port: int) -> str:
|
||||
assert self.workspace is not None, 'Workspace is not initialized'
|
||||
assert (
|
||||
self.workspace.instance.info is not None
|
||||
), 'Workspace info is not available'
|
||||
assert (
|
||||
self.workspace.instance.info.provider_metadata is not None
|
||||
), 'Provider metadata is not available'
|
||||
assert self.workspace.instance.info is not None, (
|
||||
'Workspace info is not available'
|
||||
)
|
||||
assert self.workspace.instance.info.provider_metadata is not None, (
|
||||
'Provider metadata is not available'
|
||||
)
|
||||
|
||||
node_domain = json.loads(self.workspace.instance.info.provider_metadata)[
|
||||
'nodeDomain'
|
||||
|
||||
@@ -47,6 +47,7 @@ def _is_retryable_wait_until_alive_error(exception):
|
||||
exception,
|
||||
(
|
||||
ConnectionError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.NetworkError,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.HTTPStatusError,
|
||||
@@ -207,12 +208,64 @@ class DockerRuntime(ActionExecutionClient):
|
||||
)
|
||||
raise ex
|
||||
|
||||
def _process_volumes(self) -> dict[str, dict[str, str]]:
|
||||
"""Process volume mounts based on configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping host paths to container bind mounts with their modes.
|
||||
"""
|
||||
# Initialize volumes dictionary
|
||||
volumes: dict[str, dict[str, str]] = {}
|
||||
|
||||
# Process volumes (comma-delimited)
|
||||
if self.config.sandbox.volumes is not None:
|
||||
# Handle multiple mounts with comma delimiter
|
||||
mounts = self.config.sandbox.volumes.split(',')
|
||||
|
||||
for mount in mounts:
|
||||
parts = mount.split(':')
|
||||
if len(parts) >= 2:
|
||||
host_path = os.path.abspath(parts[0])
|
||||
container_path = parts[1]
|
||||
# Default mode is 'rw' if not specified
|
||||
mount_mode = parts[2] if len(parts) > 2 else 'rw'
|
||||
|
||||
volumes[host_path] = {
|
||||
'bind': container_path,
|
||||
'mode': mount_mode,
|
||||
}
|
||||
logger.debug(
|
||||
f'Mount dir (sandbox.volumes): {host_path} to {container_path} with mode: {mount_mode}'
|
||||
)
|
||||
|
||||
# Legacy mounting with workspace_* parameters
|
||||
elif (
|
||||
self.config.workspace_mount_path is not None
|
||||
and self.config.workspace_mount_path_in_sandbox is not None
|
||||
):
|
||||
mount_mode = 'rw' # Default mode
|
||||
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes[self.config.workspace_mount_path] = {
|
||||
'bind': self.config.workspace_mount_path_in_sandbox,
|
||||
'mode': mount_mode,
|
||||
}
|
||||
logger.debug(
|
||||
f'Mount dir (legacy): {self.config.workspace_mount_path} with mode: {mount_mode}'
|
||||
)
|
||||
|
||||
return volumes
|
||||
|
||||
def _init_container(self):
|
||||
self.log('debug', 'Preparing to start container...')
|
||||
self.send_status_message('STATUS$PREPARING_CONTAINER')
|
||||
self._host_port = self._find_available_port(EXECUTION_SERVER_PORT_RANGE)
|
||||
self._container_port = self._host_port
|
||||
self._vscode_port = self._find_available_port(VSCODE_PORT_RANGE)
|
||||
# Use the configured vscode_port if provided, otherwise find an available port
|
||||
self._vscode_port = (
|
||||
self.config.sandbox.vscode_port
|
||||
or self._find_available_port(VSCODE_PORT_RANGE)
|
||||
)
|
||||
self._app_ports = [
|
||||
self._find_available_port(APP_PORT_RANGE_1),
|
||||
self._find_available_port(APP_PORT_RANGE_2),
|
||||
@@ -268,23 +321,16 @@ class DockerRuntime(ActionExecutionClient):
|
||||
environment.update(self.config.sandbox.runtime_startup_env_vars)
|
||||
|
||||
self.log('debug', f'Workspace Base: {self.config.workspace_base}')
|
||||
if (
|
||||
self.config.workspace_mount_path is not None
|
||||
and self.config.workspace_mount_path_in_sandbox is not None
|
||||
):
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes = {
|
||||
self.config.workspace_mount_path: {
|
||||
'bind': self.config.workspace_mount_path_in_sandbox,
|
||||
'mode': 'rw',
|
||||
}
|
||||
}
|
||||
logger.debug(f'Mount dir: {self.config.workspace_mount_path}')
|
||||
else:
|
||||
|
||||
# Process volumes for mounting
|
||||
volumes = self._process_volumes()
|
||||
|
||||
# If no volumes were configured, set to None
|
||||
if not volumes:
|
||||
logger.debug(
|
||||
'Mount dir is not set, will not mount the workspace directory to the container'
|
||||
)
|
||||
volumes = None
|
||||
volumes = {} # Empty dict instead of None to satisfy mypy
|
||||
self.log(
|
||||
'debug',
|
||||
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
|
||||
@@ -443,8 +489,9 @@ class DockerRuntime(ActionExecutionClient):
|
||||
def web_hosts(self):
|
||||
hosts: dict[str, int] = {}
|
||||
|
||||
host_addr = os.environ.get('DOCKER_HOST_ADDR', 'localhost')
|
||||
for port in self._app_ports:
|
||||
hosts[f'http://localhost:{port}'] = port
|
||||
hosts[f'http://{host_addr}:{port}'] = port
|
||||
|
||||
return hosts
|
||||
|
||||
|
||||
@@ -40,9 +40,9 @@ class E2BBox:
|
||||
|
||||
def _archive(self, host_src: str, recursive: bool = False):
|
||||
if recursive:
|
||||
assert os.path.isdir(
|
||||
host_src
|
||||
), 'Source must be a directory when recursive is True'
|
||||
assert os.path.isdir(host_src), (
|
||||
'Source must be a directory when recursive is True'
|
||||
)
|
||||
files = glob(host_src + '/**/*', recursive=True)
|
||||
srcname = os.path.basename(host_src)
|
||||
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
|
||||
@@ -52,9 +52,9 @@ class E2BBox:
|
||||
file, arcname=os.path.relpath(file, os.path.dirname(host_src))
|
||||
)
|
||||
else:
|
||||
assert os.path.isfile(
|
||||
host_src
|
||||
), 'Source must be a file when recursive is False'
|
||||
assert os.path.isfile(host_src), (
|
||||
'Source must be a file when recursive is False'
|
||||
)
|
||||
srcname = os.path.basename(host_src)
|
||||
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
|
||||
with tarfile.open(tar_filename, mode='w') as tar:
|
||||
|
||||
@@ -41,6 +41,18 @@ from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
def get_user_info():
|
||||
"""Get user ID and username in a cross-platform way."""
|
||||
username = os.getenv('USER')
|
||||
if sys.platform == 'win32':
|
||||
# On Windows, we don't use user IDs the same way
|
||||
# Return a default value that won't cause issues
|
||||
return 1000, username
|
||||
else:
|
||||
# On Unix systems, use os.getuid()
|
||||
return os.getuid(), username
|
||||
|
||||
|
||||
def check_dependencies(code_repo_path: str, poetry_venvs_path: str):
|
||||
ERROR_MESSAGE = 'Please follow the instructions in https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md to install OpenHands.'
|
||||
if not os.path.exists(code_repo_path):
|
||||
@@ -63,28 +75,33 @@ def check_dependencies(code_repo_path: str, poetry_venvs_path: str):
|
||||
if 'jupyter' not in output.lower():
|
||||
raise ValueError('Jupyter is not properly installed. ' + ERROR_MESSAGE)
|
||||
|
||||
# Check libtmux is installed
|
||||
logger.debug('Checking dependencies: libtmux')
|
||||
import libtmux
|
||||
# Check libtmux is installed (skip on Windows)
|
||||
|
||||
server = libtmux.Server()
|
||||
try:
|
||||
session = server.new_session(session_name='test-session')
|
||||
except Exception:
|
||||
raise ValueError('tmux is not properly installed or available on the path.')
|
||||
pane = session.attached_pane
|
||||
pane.send_keys('echo "test"')
|
||||
pane_output = '\n'.join(pane.cmd('capture-pane', '-p').stdout)
|
||||
session.kill_session()
|
||||
if 'test' not in pane_output:
|
||||
raise ValueError('libtmux is not properly installed. ' + ERROR_MESSAGE)
|
||||
if sys.platform != 'win32':
|
||||
logger.debug('Checking dependencies: libtmux')
|
||||
import libtmux
|
||||
|
||||
# Check browser works
|
||||
logger.debug('Checking dependencies: browser')
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
server = libtmux.Server()
|
||||
try:
|
||||
session = server.new_session(session_name='test-session')
|
||||
except Exception:
|
||||
raise ValueError('tmux is not properly installed or available on the path.')
|
||||
pane = session.attached_pane
|
||||
pane.send_keys('echo "test"')
|
||||
pane_output = '\n'.join(pane.cmd('capture-pane', '-p').stdout)
|
||||
session.kill_session()
|
||||
if 'test' not in pane_output:
|
||||
raise ValueError('libtmux is not properly installed. ' + ERROR_MESSAGE)
|
||||
|
||||
browser = BrowserEnv()
|
||||
browser.close()
|
||||
# Skip browser environment check on Windows
|
||||
if sys.platform != 'win32':
|
||||
logger.debug('Checking dependencies: browser')
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
|
||||
browser = BrowserEnv()
|
||||
browser.close()
|
||||
else:
|
||||
logger.warning('Running on Windows - browser environment check skipped.')
|
||||
|
||||
|
||||
class LocalRuntime(ActionExecutionClient):
|
||||
@@ -110,9 +127,15 @@ class LocalRuntime(ActionExecutionClient):
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
):
|
||||
self.is_windows = sys.platform == 'win32'
|
||||
if self.is_windows:
|
||||
logger.warning(
|
||||
'Running on Windows - some features that require tmux will be limited. '
|
||||
'For full functionality, please consider using WSL or Docker runtime.'
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self._user_id = os.getuid()
|
||||
self._username = os.getenv('USER')
|
||||
self._user_id, self._username = get_user_info()
|
||||
|
||||
if self.config.workspace_base is not None:
|
||||
logger.warning(
|
||||
@@ -161,6 +184,7 @@ class LocalRuntime(ActionExecutionClient):
|
||||
self.status_callback = status_callback
|
||||
self.server_process: subprocess.Popen[str] | None = None
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
self._log_thread_exit_event = threading.Event() # Add exit event
|
||||
|
||||
# Update env vars
|
||||
if self.config.sandbox.runtime_startup_env_vars:
|
||||
@@ -199,7 +223,7 @@ class LocalRuntime(ActionExecutionClient):
|
||||
server_port=self._host_port,
|
||||
plugins=self.plugins,
|
||||
app_config=self.config,
|
||||
python_prefix=[],
|
||||
python_prefix=['poetry', 'run'],
|
||||
override_user_id=self._user_id,
|
||||
override_username=self._username,
|
||||
)
|
||||
@@ -208,7 +232,7 @@ class LocalRuntime(ActionExecutionClient):
|
||||
env = os.environ.copy()
|
||||
# Get the code repo path
|
||||
code_repo_path = os.path.dirname(os.path.dirname(openhands.__file__))
|
||||
env['PYTHONPATH'] = f'{code_repo_path}{os.pathsep}{env.get("PYTHONPATH", "")}'
|
||||
env['PYTHONPATH'] = os.pathsep.join([code_repo_path, env.get('PYTHONPATH', '')])
|
||||
env['OPENHANDS_REPO_PATH'] = code_repo_path
|
||||
env['LOCAL_RUNTIME_MODE'] = '1'
|
||||
|
||||
@@ -230,19 +254,50 @@ class LocalRuntime(ActionExecutionClient):
|
||||
universal_newlines=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
cwd=code_repo_path, # Explicitly set the working directory
|
||||
)
|
||||
|
||||
# Start a thread to read and log server output
|
||||
def log_output():
|
||||
while (
|
||||
self.server_process
|
||||
and self.server_process.poll()
|
||||
and self.server_process.stdout
|
||||
):
|
||||
line = self.server_process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
self.log('debug', f'Server: {line.strip()}')
|
||||
if not self.server_process or not self.server_process.stdout:
|
||||
self.log('error', 'Server process or stdout not available for logging.')
|
||||
return
|
||||
|
||||
try:
|
||||
# Read lines while the process is running and stdout is available
|
||||
while self.server_process.poll() is None:
|
||||
if self._log_thread_exit_event.is_set(): # Check exit event
|
||||
self.log('info', 'Log thread received exit signal.')
|
||||
break # Exit loop if signaled
|
||||
line = self.server_process.stdout.readline()
|
||||
if not line:
|
||||
# Process might have exited between poll() and readline()
|
||||
break
|
||||
self.log('info', f'Server: {line.strip()}')
|
||||
|
||||
# Capture any remaining output after the process exits OR if signaled
|
||||
if (
|
||||
not self._log_thread_exit_event.is_set()
|
||||
): # Check again before reading remaining
|
||||
self.log('info', 'Server process exited, reading remaining output.')
|
||||
for line in self.server_process.stdout:
|
||||
if (
|
||||
self._log_thread_exit_event.is_set()
|
||||
): # Check inside loop too
|
||||
self.log(
|
||||
'info',
|
||||
'Log thread received exit signal while reading remaining output.',
|
||||
)
|
||||
break
|
||||
self.log('info', f'Server (remaining): {line.strip()}')
|
||||
|
||||
except Exception as e:
|
||||
# Log the error, but don't prevent the thread from potentially exiting
|
||||
self.log('error', f'Error reading server output: {e}')
|
||||
finally:
|
||||
self.log(
|
||||
'info', 'Log output thread finished.'
|
||||
) # Add log for thread exit
|
||||
|
||||
self._log_thread = threading.Thread(target=log_output, daemon=True)
|
||||
self._log_thread.start()
|
||||
@@ -312,6 +367,8 @@ class LocalRuntime(ActionExecutionClient):
|
||||
|
||||
def close(self):
|
||||
"""Stop the server process."""
|
||||
self._log_thread_exit_event.set() # Signal the log thread to exit
|
||||
|
||||
if self.server_process:
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
@@ -319,7 +376,7 @@ class LocalRuntime(ActionExecutionClient):
|
||||
except subprocess.TimeoutExpired:
|
||||
self.server_process.kill()
|
||||
self.server_process = None
|
||||
self._log_thread.join()
|
||||
self._log_thread.join(timeout=5) # Add timeout to join
|
||||
|
||||
if self._temp_workspace:
|
||||
shutil.rmtree(self._temp_workspace)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable
|
||||
from typing import Any, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import tenacity
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.exceptions import (
|
||||
@@ -37,6 +38,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
runtime_id: str | None = None
|
||||
runtime_url: str | None = None
|
||||
_runtime_initialized: bool = False
|
||||
runtime_builder: RemoteRuntimeBuilder
|
||||
container_image: str
|
||||
available_hosts: dict[str, int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,12 +49,12 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
status_callback: Callable[..., None] | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
@@ -94,10 +98,12 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
getattr(logger, level)(message, stacklevel=2)
|
||||
|
||||
@property
|
||||
def action_execution_server_url(self):
|
||||
def action_execution_server_url(self) -> str:
|
||||
if self.runtime_url is None:
|
||||
raise NotImplementedError('Runtime URL is not initialized')
|
||||
return self.runtime_url
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
try:
|
||||
await call_sync_from_async(self._start_or_attach_to_runtime)
|
||||
except Exception:
|
||||
@@ -107,7 +113,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
await call_sync_from_async(self.setup_initial_env)
|
||||
self._runtime_initialized = True
|
||||
|
||||
def _start_or_attach_to_runtime(self):
|
||||
def _start_or_attach_to_runtime(self) -> None:
|
||||
existing_runtime = self._check_existing_runtime()
|
||||
if existing_runtime:
|
||||
self.log('debug', f'Using existing runtime with ID: {self.runtime_id}')
|
||||
@@ -130,12 +136,12 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
)
|
||||
self.container_image = self.config.sandbox.runtime_container_image
|
||||
self._start_runtime()
|
||||
assert (
|
||||
self.runtime_id is not None
|
||||
), 'Runtime ID is not set. This should never happen.'
|
||||
assert (
|
||||
self.runtime_url is not None
|
||||
), 'Runtime URL is not set. This should never happen.'
|
||||
assert self.runtime_id is not None, (
|
||||
'Runtime ID is not set. This should never happen.'
|
||||
)
|
||||
assert self.runtime_url is not None, (
|
||||
'Runtime URL is not set. This should never happen.'
|
||||
)
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', 'Waiting for runtime to be alive...')
|
||||
@@ -179,7 +185,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
self.log('error', f'Invalid response from runtime API: {data}')
|
||||
return False
|
||||
|
||||
def _build_runtime(self):
|
||||
def _build_runtime(self) -> None:
|
||||
self.log('debug', f'Building RemoteRuntime config:\n{self.config}')
|
||||
response = self._send_runtime_api_request(
|
||||
'GET',
|
||||
@@ -223,18 +229,18 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
f'Container image {self.container_image} does not exist'
|
||||
)
|
||||
|
||||
def _start_runtime(self):
|
||||
def _start_runtime(self) -> None:
|
||||
# Prepare the request body for the /start endpoint
|
||||
command = get_action_execution_server_startup_command(
|
||||
server_port=self.port,
|
||||
plugins=self.plugins,
|
||||
app_config=self.config,
|
||||
)
|
||||
environment = {}
|
||||
environment: dict[str, str] = {}
|
||||
if self.config.debug or os.environ.get('DEBUG', 'false').lower() == 'true':
|
||||
environment['DEBUG'] = 'true'
|
||||
environment.update(self.config.sandbox.runtime_startup_env_vars)
|
||||
start_request = {
|
||||
start_request: dict[str, Any] = {
|
||||
'image': self.container_image,
|
||||
'command': command,
|
||||
'working_dir': '/openhands/code/',
|
||||
@@ -262,8 +268,10 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
self.log('error', f'Unable to start runtime: {str(e)}')
|
||||
raise AgentRuntimeUnavailableError() from e
|
||||
|
||||
def _resume_runtime(self):
|
||||
"""
|
||||
def _resume_runtime(self) -> None:
|
||||
"""Resume a stopped runtime.
|
||||
|
||||
Steps:
|
||||
1. Show status update that runtime is being started.
|
||||
2. Send the runtime API a /resume request
|
||||
3. Poll for the runtime to be ready
|
||||
@@ -279,7 +287,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
self.setup_initial_env()
|
||||
self.log('debug', 'Runtime resumed.')
|
||||
|
||||
def _parse_runtime_response(self, response: httpx.Response):
|
||||
def _parse_runtime_response(self, response: httpx.Response) -> None:
|
||||
start_response = response.json()
|
||||
self.runtime_id = start_response['runtime_id']
|
||||
self.runtime_url = start_response['url']
|
||||
@@ -310,7 +318,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
def web_hosts(self) -> dict[str, int]:
|
||||
return self.available_hosts
|
||||
|
||||
def _wait_until_alive(self):
|
||||
def _wait_until_alive(self) -> None:
|
||||
retry_decorator = tenacity.retry(
|
||||
stop=tenacity.stop_after_delay(
|
||||
self.config.sandbox.remote_runtime_init_timeout
|
||||
@@ -321,9 +329,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
retry=tenacity.retry_if_exception_type(AgentRuntimeNotReadyError),
|
||||
wait=tenacity.wait_fixed(2),
|
||||
)
|
||||
return retry_decorator(self._wait_until_alive_impl)()
|
||||
retry_decorator(self._wait_until_alive_impl)()
|
||||
|
||||
def _wait_until_alive_impl(self):
|
||||
def _wait_until_alive_impl(self) -> None:
|
||||
self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}')
|
||||
runtime_info_response = self._send_runtime_api_request(
|
||||
'GET',
|
||||
@@ -384,7 +392,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
)
|
||||
raise AgentRuntimeNotReadyError()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if self.attach_to_existing:
|
||||
super().close()
|
||||
return
|
||||
@@ -417,7 +425,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
finally:
|
||||
super().close()
|
||||
|
||||
def _send_runtime_api_request(self, method, url, **kwargs):
|
||||
def _send_runtime_api_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> httpx.Response:
|
||||
try:
|
||||
kwargs['timeout'] = self.config.sandbox.remote_runtime_api_timeout
|
||||
return send_request(self.session, method, url, **kwargs)
|
||||
@@ -428,7 +438,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
)
|
||||
raise
|
||||
|
||||
def _send_action_server_request(self, method, url, **kwargs):
|
||||
def _send_action_server_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> httpx.Response:
|
||||
if not self.config.sandbox.remote_runtime_enable_retries:
|
||||
return self._send_action_server_request_impl(method, url, **kwargs)
|
||||
|
||||
@@ -444,7 +456,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
method, url, **kwargs
|
||||
)
|
||||
|
||||
def _send_action_server_request_impl(self, method, url, **kwargs):
|
||||
def _send_action_server_request_impl(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> httpx.Response:
|
||||
try:
|
||||
return super()._send_action_server_request(method, url, **kwargs)
|
||||
except httpx.TimeoutException:
|
||||
@@ -455,7 +469,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
raise
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
if e.response.status_code in (404, 502, 504):
|
||||
if hasattr(e, 'response') and e.response.status_code in (404, 502, 504):
|
||||
if e.response.status_code == 404:
|
||||
raise AgentRuntimeDisconnectedError(
|
||||
f'Runtime is not responding. This may be temporary, please try again. Original error: {e}'
|
||||
@@ -464,7 +478,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
raise AgentRuntimeDisconnectedError(
|
||||
f'Runtime is temporarily unavailable. This may be due to a restart or network issue, please try again. Original error: {e}'
|
||||
) from e
|
||||
elif e.response.status_code == 503:
|
||||
elif hasattr(e, 'response') and e.response.status_code == 503:
|
||||
if self.config.sandbox.keep_runtime_alive:
|
||||
self.log('warning', 'Runtime appears to be paused. Resuming...')
|
||||
self._resume_runtime()
|
||||
@@ -476,5 +490,5 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
else:
|
||||
raise e
|
||||
|
||||
def _stop_if_closed(self, retry_state: tenacity.RetryCallState) -> bool:
|
||||
def _stop_if_closed(self, retry_state: RetryCallState) -> bool:
|
||||
return self._runtime_closed
|
||||
|
||||
@@ -157,7 +157,7 @@ def _print_window(
|
||||
else:
|
||||
output += '(this is the beginning of the file)\n'
|
||||
for i in range(start, end + 1):
|
||||
_new_line = f'{i}|{lines[i-1]}'
|
||||
_new_line = f'{i}|{lines[i - 1]}'
|
||||
if not _new_line.endswith('\n'):
|
||||
_new_line += '\n'
|
||||
output += _new_line
|
||||
|
||||
@@ -2,7 +2,7 @@ from types import ModuleType
|
||||
|
||||
|
||||
def import_functions(
|
||||
module: ModuleType, function_names: list[str], target_globals: dict
|
||||
module: ModuleType, function_names: list[str], target_globals: dict[str, object]
|
||||
) -> None:
|
||||
for name in function_names:
|
||||
if hasattr(module, name):
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -20,7 +23,7 @@ class JupyterPlugin(Plugin):
|
||||
name: str = 'jupyter'
|
||||
kernel_gateway_port: int
|
||||
kernel_id: str
|
||||
gateway_process: asyncio.subprocess.Process
|
||||
gateway_process: asyncio.subprocess.Process | subprocess.Popen
|
||||
python_interpreter_path: str
|
||||
|
||||
async def initialize(
|
||||
@@ -28,7 +31,10 @@ class JupyterPlugin(Plugin):
|
||||
) -> None:
|
||||
self.kernel_gateway_port = find_available_tcp_port(40000, 49999)
|
||||
self.kernel_id = kernel_id
|
||||
if username in ['root', 'openhands']:
|
||||
is_local_runtime = os.environ.get('LOCAL_RUNTIME_MODE') == '1'
|
||||
is_windows = sys.platform == 'win32'
|
||||
|
||||
if not is_local_runtime:
|
||||
# Non-LocalRuntime
|
||||
prefix = f'su - {username} -s '
|
||||
# cd to code repo, setup all env vars and run micromamba
|
||||
@@ -50,37 +56,84 @@ class JupyterPlugin(Plugin):
|
||||
)
|
||||
# The correct environment is ensured by the PATH in LocalRuntime.
|
||||
poetry_prefix = f'cd {code_repo_path}\n'
|
||||
jupyter_launch_command = (
|
||||
f"{prefix}/bin/bash << 'EOF'\n"
|
||||
f'{poetry_prefix}'
|
||||
'poetry run jupyter kernelgateway '
|
||||
'--KernelGatewayApp.ip=0.0.0.0 '
|
||||
f'--KernelGatewayApp.port={self.kernel_gateway_port}\n'
|
||||
'EOF'
|
||||
)
|
||||
logger.debug(f'Jupyter launch command: {jupyter_launch_command}')
|
||||
|
||||
# Using asyncio.create_subprocess_shell instead of subprocess.Popen
|
||||
# to avoid ASYNC101 linting error
|
||||
self.gateway_process = await asyncio.create_subprocess_shell(
|
||||
jupyter_launch_command,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
)
|
||||
# read stdout until the kernel gateway is ready
|
||||
output = ''
|
||||
while should_continue() and self.gateway_process.stdout is not None:
|
||||
line_bytes = await self.gateway_process.stdout.readline()
|
||||
line = line_bytes.decode('utf-8')
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
logger.debug('Waiting for jupyter kernel gateway to start...')
|
||||
if is_windows:
|
||||
# Windows-specific command format
|
||||
jupyter_launch_command = (
|
||||
f'cd /d "{code_repo_path}" && '
|
||||
'poetry run jupyter kernelgateway '
|
||||
'--KernelGatewayApp.ip=0.0.0.0 '
|
||||
f'--KernelGatewayApp.port={self.kernel_gateway_port}'
|
||||
)
|
||||
logger.debug(f'Jupyter launch command (Windows): {jupyter_launch_command}')
|
||||
|
||||
# Using synchronous subprocess.Popen for Windows as asyncio.create_subprocess_shell
|
||||
# has limitations on Windows platforms
|
||||
self.gateway_process = subprocess.Popen( # type: ignore[ASYNC101] # noqa: ASYNC101
|
||||
jupyter_launch_command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# Windows-specific stdout handling with synchronous time.sleep
|
||||
# as asyncio has limitations on Windows for subprocess operations
|
||||
output = ''
|
||||
while should_continue():
|
||||
if self.gateway_process.stdout is None:
|
||||
time.sleep(1) # type: ignore[ASYNC101] # noqa: ASYNC101
|
||||
continue
|
||||
|
||||
line = self.gateway_process.stdout.readline()
|
||||
if not line:
|
||||
time.sleep(1) # type: ignore[ASYNC101] # noqa: ASYNC101
|
||||
continue
|
||||
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
|
||||
time.sleep(1) # type: ignore[ASYNC101] # noqa: ASYNC101
|
||||
logger.debug('Waiting for jupyter kernel gateway to start...')
|
||||
|
||||
logger.debug(
|
||||
f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}'
|
||||
)
|
||||
else:
|
||||
# Unix systems (Linux/macOS)
|
||||
jupyter_launch_command = (
|
||||
f"{prefix}/bin/bash << 'EOF'\n"
|
||||
f'{poetry_prefix}'
|
||||
'poetry run jupyter kernelgateway '
|
||||
'--KernelGatewayApp.ip=0.0.0.0 '
|
||||
f'--KernelGatewayApp.port={self.kernel_gateway_port}\n'
|
||||
'EOF'
|
||||
)
|
||||
logger.debug(f'Jupyter launch command: {jupyter_launch_command}')
|
||||
|
||||
# Using asyncio.create_subprocess_shell instead of subprocess.Popen
|
||||
# to avoid ASYNC101 linting error
|
||||
self.gateway_process = await asyncio.create_subprocess_shell(
|
||||
jupyter_launch_command,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
)
|
||||
# read stdout until the kernel gateway is ready
|
||||
output = ''
|
||||
while should_continue() and self.gateway_process.stdout is not None:
|
||||
line_bytes = await self.gateway_process.stdout.readline()
|
||||
line = line_bytes.decode('utf-8')
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
logger.debug('Waiting for jupyter kernel gateway to start...')
|
||||
|
||||
logger.debug(
|
||||
f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}'
|
||||
)
|
||||
_obs = await self.run(
|
||||
IPythonRunCellAction(code='import sys; print(sys.executable)')
|
||||
)
|
||||
|
||||
@@ -138,7 +138,7 @@ class JupyterKernel:
|
||||
retry=retry_if_exception_type(ConnectionRefusedError),
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(2),
|
||||
)
|
||||
) # type: ignore
|
||||
async def execute(self, code: str, timeout: int = 120) -> str:
|
||||
if not self.ws or self.ws.stream.closed():
|
||||
await self._connect()
|
||||
@@ -189,7 +189,7 @@ class JupyterKernel:
|
||||
|
||||
if os.environ.get('DEBUG'):
|
||||
logging.info(
|
||||
f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg_dict['content']}"
|
||||
f'MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg_dict["content"]}'
|
||||
)
|
||||
|
||||
if msg_type == 'error':
|
||||
@@ -203,7 +203,7 @@ class JupyterKernel:
|
||||
if 'image/png' in msg_dict['content']['data']:
|
||||
# use markdone to display image (in case of large image)
|
||||
outputs.append(
|
||||
f"\n\n"
|
||||
f'\n\n'
|
||||
)
|
||||
|
||||
elif msg_type == 'execute_reply':
|
||||
@@ -272,7 +272,7 @@ class ExecuteHandler(tornado.web.RequestHandler):
|
||||
|
||||
def make_app() -> tornado.web.Application:
|
||||
jupyter_kernel = JupyterKernel(
|
||||
f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT', '8888')}",
|
||||
f'localhost:{os.environ.get("JUPYTER_GATEWAY_PORT", "8888")}',
|
||||
os.environ.get('JUPYTER_GATEWAY_KERNEL_ID', 'default'),
|
||||
)
|
||||
asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize())
|
||||
|
||||
@@ -6,7 +6,7 @@ import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import bashlex # type: ignore
|
||||
import bashlex
|
||||
import libtmux
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -25,7 +25,13 @@ def split_bash_commands(commands: str) -> list[str]:
|
||||
return ['']
|
||||
try:
|
||||
parsed = bashlex.parse(commands)
|
||||
except (bashlex.errors.ParsingError, NotImplementedError, TypeError):
|
||||
except (
|
||||
bashlex.errors.ParsingError,
|
||||
NotImplementedError,
|
||||
TypeError,
|
||||
AttributeError,
|
||||
):
|
||||
# Added AttributeError to catch 'str' object has no attribute 'kind' error (issue #8369)
|
||||
logger.debug(
|
||||
f'Failed to parse bash commands\n'
|
||||
f'[input]: {commands}\n'
|
||||
@@ -501,9 +507,9 @@ class BashSession:
|
||||
if len(splited_commands) > 1:
|
||||
return ErrorObservation(
|
||||
content=(
|
||||
f"ERROR: Cannot execute multiple commands at once.\n"
|
||||
f"Please run each command separately OR chain them into a single command via && or ;\n"
|
||||
f"Provided commands:\n{'\n'.join(f'({i + 1}) {cmd}' for i, cmd in enumerate(splited_commands))}"
|
||||
f'ERROR: Cannot execute multiple commands at once.\n'
|
||||
f'Please run each command separately OR chain them into a single command via && or ;\n'
|
||||
f'Provided commands:\n{"\n".join(f"({i + 1}) {cmd}" for i, cmd in enumerate(splited_commands))}'
|
||||
)
|
||||
)
|
||||
|
||||
@@ -591,8 +597,8 @@ class BashSession:
|
||||
logger.debug(
|
||||
f'PANE CONTENT GOT after {time.time() - _start_time:.2f} seconds'
|
||||
)
|
||||
logger.debug(f"BEGIN OF PANE CONTENT: {cur_pane_output.split('\n')[:10]}")
|
||||
logger.debug(f"END OF PANE CONTENT: {cur_pane_output.split('\n')[-10:]}")
|
||||
logger.debug(f'BEGIN OF PANE CONTENT: {cur_pane_output.split("\n")[:10]}')
|
||||
logger.debug(f'END OF PANE CONTENT: {cur_pane_output.split("\n")[-10:]}')
|
||||
ps1_matches = CmdOutputMetadata.matches_ps1_metadata(cur_pane_output)
|
||||
current_ps1_count = len(ps1_matches)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from openhands_aci.utils.diff import get_diff # type: ignore
|
||||
from openhands_aci.utils.diff import get_diff
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -35,8 +35,8 @@ def generate_file_viewer_html(file_path: str) -> str:
|
||||
# Check if the file extension is supported
|
||||
if file_extension not in supported_extensions:
|
||||
raise ValueError(
|
||||
f"Unsupported file extension: {file_extension}. "
|
||||
f"Supported extensions are: {', '.join(supported_extensions)}"
|
||||
f'Unsupported file extension: {file_extension}. '
|
||||
f'Supported extensions are: {", ".join(supported_extensions)}'
|
||||
)
|
||||
|
||||
# Check if the file exists
|
||||
|
||||
@@ -28,7 +28,7 @@ class GitHandler:
|
||||
self.execute = execute_shell_fn
|
||||
self.cwd: str | None = None
|
||||
|
||||
def set_cwd(self, cwd: str):
|
||||
def set_cwd(self, cwd: str) -> None:
|
||||
"""
|
||||
Sets the current working directory for Git operations.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class LogStreamer:
|
||||
def __init__(
|
||||
self,
|
||||
container: docker.models.containers.Container,
|
||||
logFn: Callable,
|
||||
logFn: Callable[[str, str], None],
|
||||
):
|
||||
self.log = logFn
|
||||
# Initialize all attributes before starting the thread on this instance
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import threading
|
||||
|
||||
from memory_profiler import memory_usage # type: ignore
|
||||
from memory_profiler import memory_usage
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class RequestHTTPError(httpx.HTTPStatusError):
|
||||
s = super().__str__()
|
||||
if self.detail is not None:
|
||||
s += f'\nDetails: {self.detail}'
|
||||
return s
|
||||
return str(s)
|
||||
|
||||
|
||||
def is_retryable_error(exception: Any) -> bool:
|
||||
@@ -57,4 +57,4 @@ def send_request(
|
||||
response=e.response,
|
||||
detail=_json.get('detail') if _json is not None else None,
|
||||
) from e
|
||||
return response # type: ignore
|
||||
return response
|
||||
|
||||
@@ -6,10 +6,9 @@ import string
|
||||
import tempfile
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import docker
|
||||
from dirhash import dirhash # type: ignore
|
||||
from dirhash import dirhash
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
import openhands
|
||||
@@ -111,7 +110,7 @@ def build_runtime_image(
|
||||
build_folder: str | None = None,
|
||||
dry_run: bool = False,
|
||||
force_rebuild: bool = False,
|
||||
extra_build_args: List[str] | None = None,
|
||||
extra_build_args: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Prepares the final docker build folder.
|
||||
|
||||
@@ -167,7 +166,7 @@ def build_runtime_image_in_folder(
|
||||
dry_run: bool,
|
||||
force_rebuild: bool,
|
||||
platform: str | None = None,
|
||||
extra_build_args: List[str] | None = None,
|
||||
extra_build_args: list[str] | None = None,
|
||||
) -> str:
|
||||
runtime_image_repo, _ = get_runtime_image_repo_and_tag(base_image)
|
||||
lock_tag = f'oh_v{oh_version}_{get_hash_for_lock_files(base_image)}'
|
||||
@@ -284,8 +283,9 @@ def prep_build_folder(
|
||||
build_from=build_from,
|
||||
extra_deps=extra_deps,
|
||||
)
|
||||
with open(Path(build_folder, 'Dockerfile'), 'w') as file: # type: ignore
|
||||
file.write(dockerfile_content) # type: ignore
|
||||
dockerfile_path = Path(build_folder, 'Dockerfile')
|
||||
with open(str(dockerfile_path), 'w') as f:
|
||||
f.write(dockerfile_content)
|
||||
|
||||
|
||||
_ALPHABET = string.digits + string.ascii_lowercase
|
||||
@@ -294,7 +294,7 @@ _ALPHABET = string.digits + string.ascii_lowercase
|
||||
def truncate_hash(hash: str) -> str:
|
||||
"""Convert the base16 hash to base36 and truncate at 16 characters."""
|
||||
value = int(hash, 16)
|
||||
result: List[str] = []
|
||||
result: list[str] = []
|
||||
while value > 0 and len(result) < 16:
|
||||
value, remainder = divmod(value, len(_ALPHABET))
|
||||
result.append(_ALPHABET[remainder])
|
||||
@@ -347,7 +347,7 @@ def _build_sandbox_image(
|
||||
lock_tag: str,
|
||||
versioned_tag: str | None,
|
||||
platform: str | None = None,
|
||||
extra_build_args: List[str] | None = None,
|
||||
extra_build_args: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Build and tag the sandbox image. The image will be tagged with all tags that do not yet exist."""
|
||||
names = [
|
||||
@@ -385,9 +385,9 @@ if __name__ == '__main__':
|
||||
# and create a Dockerfile dynamically and place it in the build_folder only. This allows the Docker image to
|
||||
# then be created using the Dockerfile (most likely using the containers/build.sh script)
|
||||
build_folder = args.build_folder
|
||||
assert os.path.exists(
|
||||
build_folder
|
||||
), f'Build folder {build_folder} does not exist'
|
||||
assert os.path.exists(build_folder), (
|
||||
f'Build folder {build_folder} does not exist'
|
||||
)
|
||||
logger.debug(
|
||||
f'Copying the source code and generating the Dockerfile in the build folder: {build_folder}'
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -32,6 +33,17 @@ def init_user_and_working_directory(
|
||||
Returns:
|
||||
int | None: The user ID if it was updated, None otherwise.
|
||||
"""
|
||||
# If running on Windows, just create the directory and return
|
||||
if sys.platform == 'win32':
|
||||
logger.debug('Running on Windows, skipping Unix-specific user setup')
|
||||
logger.debug(f'Client working directory: {initial_cwd}')
|
||||
|
||||
# Create the working directory if it doesn't exist
|
||||
os.makedirs(initial_cwd, exist_ok=True)
|
||||
logger.debug(f'Created working directory: {initial_cwd}')
|
||||
|
||||
return None
|
||||
|
||||
# if username is CURRENT_USER, then we don't need to do anything
|
||||
# This is specific to the local runtime
|
||||
if username == os.getenv('USER') and username not in ['root', 'openhands']:
|
||||
|
||||
@@ -38,7 +38,9 @@ RUN ln -s "$(dirname $(which node))/corepack" /usr/local/bin/corepack && \
|
||||
{% endif %}
|
||||
|
||||
# Install uv (required by MCP)
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="/openhands/bin" sh
|
||||
# Add /openhands/bin to PATH
|
||||
ENV PATH="/openhands/bin:${PATH}"
|
||||
|
||||
# Remove UID 1000 named pn or ubuntu, so the 'openhands' user can be created from ubuntu hosts
|
||||
RUN (if getent passwd 1000 | grep -q pn; then userdel pn; fi) && \
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
import psutil
|
||||
|
||||
|
||||
def get_system_stats() -> dict:
|
||||
def get_system_stats() -> dict[str, object]:
|
||||
"""Get current system resource statistics.
|
||||
|
||||
Returns:
|
||||
|
||||
1413
openhands/runtime/utils/windows_bash.py
Normal file
1413
openhands/runtime/utils/windows_bash.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -176,9 +176,9 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
],
|
||||
)
|
||||
)
|
||||
assert (
|
||||
self.guardrail_llm is not None
|
||||
), 'InvariantAnalyzer.guardrail_llm should be initialized before calling check_usertask'
|
||||
assert self.guardrail_llm is not None, (
|
||||
'InvariantAnalyzer.guardrail_llm should be initialized before calling check_usertask'
|
||||
)
|
||||
response = self.guardrail_llm.completion(
|
||||
messages=self.guardrail_llm.format_messages_for_llm(messages),
|
||||
stop=['.'],
|
||||
@@ -261,9 +261,9 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
],
|
||||
)
|
||||
)
|
||||
assert (
|
||||
self.guardrail_llm is not None
|
||||
), 'InvariantAnalyzer.guardrail_llm should be initialized before calling check_fillaction'
|
||||
assert self.guardrail_llm is not None, (
|
||||
'InvariantAnalyzer.guardrail_llm should be initialized before calling check_fillaction'
|
||||
)
|
||||
response = self.guardrail_llm.completion(
|
||||
messages=self.guardrail_llm.format_messages_for_llm(messages),
|
||||
stop=['.'],
|
||||
|
||||
@@ -20,7 +20,7 @@ TraceElement = Message | ToolCall | ToolOutput | Function
|
||||
|
||||
|
||||
def get_next_id(trace: list[TraceElement]) -> str:
|
||||
used_ids = [el.id for el in trace if type(el) == ToolCall]
|
||||
used_ids = [el.id for el in trace if isinstance(el, ToolCall)]
|
||||
for i in range(1, len(used_ids) + 2):
|
||||
if str(i) not in used_ids:
|
||||
return str(i)
|
||||
@@ -31,7 +31,7 @@ def get_last_id(
|
||||
trace: list[TraceElement],
|
||||
) -> str | None:
|
||||
for el in reversed(trace):
|
||||
if type(el) == ToolCall:
|
||||
if isinstance(el, ToolCall):
|
||||
return el.id
|
||||
return None
|
||||
|
||||
@@ -39,12 +39,12 @@ def get_last_id(
|
||||
def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
|
||||
next_id = get_next_id(trace)
|
||||
inv_trace: list[TraceElement] = []
|
||||
if type(action) == MessageAction:
|
||||
if isinstance(action, MessageAction):
|
||||
if action.source == EventSource.USER:
|
||||
inv_trace.append(Message(role='user', content=action.content))
|
||||
else:
|
||||
inv_trace.append(Message(role='assistant', content=action.content))
|
||||
elif type(action) in [NullAction, ChangeAgentStateAction]:
|
||||
elif isinstance(action, (NullAction, ChangeAgentStateAction)):
|
||||
pass
|
||||
elif hasattr(action, 'action') and action.action is not None:
|
||||
event_dict = event_to_dict(action)
|
||||
@@ -63,7 +63,7 @@ def parse_observation(
|
||||
trace: list[TraceElement], obs: Observation
|
||||
) -> list[TraceElement]:
|
||||
last_id = get_last_id(trace)
|
||||
if type(obs) in [NullObservation, AgentStateChangedObservation]:
|
||||
if isinstance(obs, (NullObservation, AgentStateChangedObservation)):
|
||||
return []
|
||||
elif hasattr(obs, 'content') and obs.content is not None:
|
||||
return [ToolOutput(role='tool', content=obs.content, tool_call_id=last_id)]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Type
|
||||
|
||||
from openhands.security.analyzer import SecurityAnalyzer
|
||||
from openhands.security.invariant.analyzer import InvariantAnalyzer
|
||||
|
||||
SecurityAnalyzers: dict[str, Type[SecurityAnalyzer]] = {
|
||||
SecurityAnalyzers: dict[str, type[SecurityAnalyzer]] = {
|
||||
'invariant': InvariantAnalyzer,
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ websocat ws://127.0.0.1:3000/ws
|
||||
```sh
|
||||
LLM_API_KEY=sk-... # Your Anthropic API Key
|
||||
LLM_MODEL=claude-3-5-sonnet-20241022 # Default model for the agent to use
|
||||
WORKSPACE_BASE=/path/to/your/workspace # Default absolute path to workspace
|
||||
SANDBOX_VOLUMES=/path/to/your/workspace:/workspace:rw # Mount paths in format host_path:container_path:mode
|
||||
```
|
||||
|
||||
## API Schema
|
||||
|
||||
@@ -16,7 +16,7 @@ class ServerConfig(ServerConfigInterface):
|
||||
'openhands.storage.settings.file_settings_store.FileSettingsStore'
|
||||
)
|
||||
secret_store_class: str = (
|
||||
'openhands.storage.settings.file_secrets_store.FileSecretsStore'
|
||||
'openhands.storage.secrets.file_secrets_store.FileSecretsStore'
|
||||
)
|
||||
conversation_store_class: str = (
|
||||
'openhands.storage.conversation.file_conversation_store.FileConversationStore'
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Iterable, Type
|
||||
from typing import Callable, Iterable
|
||||
|
||||
import socketio
|
||||
|
||||
@@ -23,6 +23,10 @@ from openhands.storage.data_models.conversation_metadata import ConversationMeta
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync, wait_all
|
||||
from openhands.utils.conversation_summary import (
|
||||
auto_generate_title,
|
||||
get_default_conversation_title,
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
@@ -52,7 +56,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
_conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
_cleanup_task: asyncio.Task | None = None
|
||||
_conversation_store_class: Type | None = None
|
||||
_conversation_store_class: type[ConversationStore] | None = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_stale())
|
||||
@@ -283,7 +287,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
response_ids = await self.get_running_agent_loops(user_id)
|
||||
if len(response_ids) >= self.config.max_concurrent_conversations:
|
||||
logger.info(
|
||||
f'too_many_sessions_for:{user_id or ''}',
|
||||
f'too_many_sessions_for:{user_id or ""}',
|
||||
extra={'session_id': sid, 'user_id': user_id},
|
||||
)
|
||||
# Get the conversations sorted (oldest first)
|
||||
@@ -296,7 +300,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
while len(conversations) >= self.config.max_concurrent_conversations:
|
||||
oldest_conversation_id = conversations.pop().conversation_id
|
||||
logger.debug(
|
||||
f'closing_from_too_many_sessions:{user_id or ''}:{oldest_conversation_id}',
|
||||
f'closing_from_too_many_sessions:{user_id or ""}:{oldest_conversation_id}',
|
||||
extra={'session_id': oldest_conversation_id, 'user_id': user_id},
|
||||
)
|
||||
# Send status message to client and close session.
|
||||
@@ -328,7 +332,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
try:
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(user_id, github_user_id, sid),
|
||||
self._create_conversation_update_callback(
|
||||
user_id, github_user_id, sid, settings
|
||||
),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@@ -425,7 +431,11 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
self, user_id: str | None, github_user_id: str | None, conversation_id: str
|
||||
self,
|
||||
user_id: str | None,
|
||||
github_user_id: str | None,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
) -> Callable:
|
||||
def callback(event, *args, **kwargs):
|
||||
call_async_from_sync(
|
||||
@@ -434,13 +444,19 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id,
|
||||
github_user_id,
|
||||
conversation_id,
|
||||
settings,
|
||||
event,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
async def _update_conversation_for_event(
|
||||
self, user_id: str, github_user_id: str, conversation_id: str, event=None
|
||||
self,
|
||||
user_id: str,
|
||||
github_user_id: str,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
event=None,
|
||||
):
|
||||
conversation_store = await self._get_conversation_store(user_id, github_user_id)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
@@ -462,6 +478,32 @@ class StandaloneConversationManager(ConversationManager):
|
||||
conversation.total_tokens = (
|
||||
token_usage.prompt_tokens + token_usage.completion_tokens
|
||||
)
|
||||
default_title = get_default_conversation_title(conversation_id)
|
||||
if (
|
||||
conversation.title == default_title
|
||||
): # attempt to autogenerate if default title is in use
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, self.file_store, settings
|
||||
)
|
||||
if title and not title.isspace():
|
||||
conversation.title = title
|
||||
try:
|
||||
# Emit a status update to the client with the new title
|
||||
status_update_dict = {
|
||||
'status_update': True,
|
||||
'type': 'info',
|
||||
'message': conversation_id,
|
||||
'conversation_title': conversation.title,
|
||||
}
|
||||
await self.sio.emit(
|
||||
'oh_event',
|
||||
status_update_dict,
|
||||
to=ROOM_KEY.format(sid=conversation_id),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error emitting title update event: {e}')
|
||||
else:
|
||||
conversation.title = default_title
|
||||
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]:
|
||||
display_feedback = feedback.model_dump()
|
||||
if 'trajectory' in display_feedback:
|
||||
display_feedback['trajectory'] = (
|
||||
f"elided [length: {len(display_feedback['trajectory'])}"
|
||||
f'elided [length: {len(display_feedback["trajectory"])}'
|
||||
)
|
||||
if 'token' in display_feedback:
|
||||
display_feedback['token'] = 'elided'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from types import MappingProxyType
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
@@ -20,14 +21,18 @@ from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderToken
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import (
|
||||
SecretsStoreImpl,
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
server_config,
|
||||
sio,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.storage.conversation.conversation_validator import (
|
||||
create_conversation_validator,
|
||||
)
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
|
||||
|
||||
def create_provider_tokens_object(
|
||||
@@ -43,74 +48,96 @@ def create_provider_tokens_object(
|
||||
|
||||
@sio.event
|
||||
async def connect(connection_id: str, environ):
|
||||
logger.info(f'sio:connect: {connection_id}')
|
||||
query_params = parse_qs(environ.get('QUERY_STRING', ''))
|
||||
latest_event_id_str = query_params.get('latest_event_id', [-1])[0]
|
||||
try:
|
||||
latest_event_id = int(latest_event_id_str)
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
f'Invalid latest_event_id value: {latest_event_id_str}, defaulting to -1'
|
||||
logger.info(f'sio:connect: {connection_id}')
|
||||
query_params = parse_qs(environ.get('QUERY_STRING', ''))
|
||||
latest_event_id_str = query_params.get('latest_event_id', [-1])[0]
|
||||
try:
|
||||
latest_event_id = int(latest_event_id_str)
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
f'Invalid latest_event_id value: {latest_event_id_str}, defaulting to -1'
|
||||
)
|
||||
latest_event_id = -1
|
||||
conversation_id = query_params.get('conversation_id', [None])[0]
|
||||
raw_list = query_params.get('providers_set', [])
|
||||
providers_list = []
|
||||
for item in raw_list:
|
||||
providers_list.extend(item.split(',') if isinstance(item, str) else [])
|
||||
providers_list = [p for p in providers_list if p]
|
||||
providers_set = [ProviderType(p) for p in providers_list]
|
||||
|
||||
if not conversation_id:
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
# Get Authorization header from the environment
|
||||
# Headers in WSGI/ASGI are prefixed with 'HTTP_' and have dashes replaced with underscores
|
||||
authorization_header = environ.get('HTTP_AUTHORIZATION', None)
|
||||
conversation_validator = create_conversation_validator()
|
||||
user_id, github_user_id = await conversation_validator.validate(
|
||||
conversation_id, cookies_str, authorization_header
|
||||
)
|
||||
latest_event_id = -1
|
||||
conversation_id = query_params.get('conversation_id', [None])[0]
|
||||
raw_list = query_params.get('providers_set', [])
|
||||
providers_list = []
|
||||
for item in raw_list:
|
||||
providers_list.extend(item.split(',') if isinstance(item, str) else [])
|
||||
providers_list = [p for p in providers_list if p]
|
||||
providers_set = [ProviderType(p) for p in providers_list]
|
||||
|
||||
if not conversation_id:
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
conversation_validator = create_conversation_validator()
|
||||
user_id, github_user_id = await conversation_validator.validate(
|
||||
conversation_id, cookies_str
|
||||
)
|
||||
secrets_store = await SecretsStoreImpl.get_instance(config, user_id)
|
||||
user_secrets: UserSecrets | None = await secrets_store.load()
|
||||
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
if not settings:
|
||||
raise ConnectionRefusedError(
|
||||
'Settings not found', {'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND'}
|
||||
)
|
||||
session_init_args: dict = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
if not settings:
|
||||
raise ConnectionRefusedError(
|
||||
'Settings not found', {'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND'}
|
||||
git_provider_tokens = create_provider_tokens_object(providers_set)
|
||||
if server_config.app_mode != AppMode.SAAS and user_secrets:
|
||||
git_provider_tokens = user_secrets.provider_tokens
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id,
|
||||
connection_id,
|
||||
conversation_init_data,
|
||||
user_id,
|
||||
github_user_id,
|
||||
)
|
||||
session_init_args: dict = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
session_init_args['git_provider_tokens'] = create_provider_tokens_object(
|
||||
providers_set
|
||||
)
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id, connection_id, conversation_init_data, user_id, github_user_id
|
||||
)
|
||||
logger.info(
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
)
|
||||
agent_state_changed = None
|
||||
if event_stream is None:
|
||||
raise ConnectionRefusedError('Failed to join conversation')
|
||||
async_store = AsyncEventStoreWrapper(event_stream, latest_event_id + 1)
|
||||
async for event in async_store:
|
||||
logger.debug(f'oh_event: {event.__class__.__name__}')
|
||||
if isinstance(
|
||||
event,
|
||||
(NullAction, NullObservation, RecallAction),
|
||||
):
|
||||
continue
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
agent_state_changed = event
|
||||
else:
|
||||
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
|
||||
if agent_state_changed:
|
||||
await sio.emit('oh_event', event_to_dict(agent_state_changed), to=connection_id)
|
||||
logger.info(f'Finished replaying event stream for conversation {conversation_id}')
|
||||
logger.info(
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
)
|
||||
agent_state_changed = None
|
||||
if event_stream is None:
|
||||
raise ConnectionRefusedError('Failed to join conversation')
|
||||
async_store = AsyncEventStoreWrapper(event_stream, latest_event_id + 1)
|
||||
async for event in async_store:
|
||||
logger.debug(f'oh_event: {event.__class__.__name__}')
|
||||
if isinstance(
|
||||
event,
|
||||
(NullAction, NullObservation, RecallAction),
|
||||
):
|
||||
continue
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
agent_state_changed = event
|
||||
else:
|
||||
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
|
||||
if agent_state_changed:
|
||||
await sio.emit(
|
||||
'oh_event', event_to_dict(agent_state_changed), to=connection_id
|
||||
)
|
||||
logger.info(
|
||||
f'Finished replaying event stream for conversation {conversation_id}'
|
||||
)
|
||||
except ConnectionRefusedError:
|
||||
# Close the broken connection after sending an error message
|
||||
asyncio.create_task(sio.disconnect(connection_id))
|
||||
raise
|
||||
|
||||
|
||||
@sio.event
|
||||
|
||||
@@ -8,7 +8,7 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
|
||||
|
||||
@app.get('/config')
|
||||
async def get_remote_runtime_config(request: Request):
|
||||
async def get_remote_runtime_config(request: Request) -> JSONResponse:
|
||||
"""Retrieve the runtime configuration.
|
||||
|
||||
Currently, this is the session ID and runtime ID (if available).
|
||||
@@ -25,7 +25,7 @@ async def get_remote_runtime_config(request: Request):
|
||||
|
||||
|
||||
@app.get('/vscode-url')
|
||||
async def get_vscode_url(request: Request):
|
||||
async def get_vscode_url(request: Request) -> JSONResponse:
|
||||
"""Get the VSCode URL.
|
||||
|
||||
This endpoint allows getting the VSCode URL.
|
||||
@@ -55,7 +55,7 @@ async def get_vscode_url(request: Request):
|
||||
|
||||
|
||||
@app.get('/web-hosts')
|
||||
async def get_hosts(request: Request):
|
||||
async def get_hosts(request: Request) -> JSONResponse:
|
||||
"""Get the hosts used by the runtime.
|
||||
|
||||
This endpoint allows getting the hosts used by the runtime.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -41,8 +42,17 @@ from openhands.utils.async_utils import call_sync_from_async
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
|
||||
|
||||
@app.get('/list-files')
|
||||
async def list_files(request: Request, path: str | None = None):
|
||||
@app.get(
|
||||
'/list-files',
|
||||
response_model=list[str],
|
||||
responses={
|
||||
404: {'description': 'Runtime not initialized', 'model': dict},
|
||||
500: {'description': 'Error listing or filtering files', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def list_files(
|
||||
request: Request, path: str | None = None
|
||||
) -> list[str] | JSONResponse:
|
||||
"""List files in the specified path.
|
||||
|
||||
This function retrieves a list of files from the agent's runtime file store,
|
||||
@@ -83,7 +93,7 @@ async def list_files(request: Request, path: str | None = None):
|
||||
|
||||
file_list = [f for f in file_list if f not in FILES_TO_IGNORE]
|
||||
|
||||
async def filter_for_gitignore(file_list, base_path):
|
||||
async def filter_for_gitignore(file_list: list[str], base_path: str) -> list[str]:
|
||||
gitignore_path = os.path.join(base_path, '.gitignore')
|
||||
try:
|
||||
read_action = FileReadAction(gitignore_path)
|
||||
@@ -109,8 +119,21 @@ async def list_files(request: Request, path: str | None = None):
|
||||
return file_list
|
||||
|
||||
|
||||
@app.get('/select-file')
|
||||
async def select_file(file: str, request: Request):
|
||||
# NOTE: We use response_model=None for endpoints that can return multiple response types
|
||||
# (like FileResponse | JSONResponse). This is because FastAPI's response_model expects a
|
||||
# Pydantic model, but Starlette response classes like FileResponse are not Pydantic models.
|
||||
# Instead, we document the possible responses using the 'responses' parameter and maintain
|
||||
# proper type annotations for mypy.
|
||||
@app.get(
|
||||
'/select-file',
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {'description': 'File content returned as JSON', 'model': dict[str, str]},
|
||||
500: {'description': 'Error opening file', 'model': dict},
|
||||
415: {'description': 'Unsupported media type', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def select_file(file: str, request: Request) -> FileResponse | JSONResponse:
|
||||
"""Retrieve the content of a specified file.
|
||||
|
||||
To select a file:
|
||||
@@ -144,7 +167,7 @@ async def select_file(file: str, request: Request):
|
||||
|
||||
if isinstance(observation, FileReadObservation):
|
||||
content = observation.content
|
||||
return {'code': content}
|
||||
return JSONResponse(content={'code': content})
|
||||
elif isinstance(observation, ErrorObservation):
|
||||
logger.error(f'Error opening file {file}: {observation}')
|
||||
|
||||
@@ -158,10 +181,23 @@ async def select_file(file: str, request: Request):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'error': f'Error opening file: {observation}'},
|
||||
)
|
||||
else:
|
||||
# Handle unexpected observation types
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'error': f'Unexpected observation type: {type(observation)}'},
|
||||
)
|
||||
|
||||
|
||||
@app.get('/zip-directory')
|
||||
def zip_current_workspace(request: Request):
|
||||
@app.get(
|
||||
'/zip-directory',
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {'description': 'Zipped workspace returned as FileResponse'},
|
||||
500: {'description': 'Error zipping workspace', 'model': dict},
|
||||
},
|
||||
)
|
||||
def zip_current_workspace(request: Request) -> FileResponse | JSONResponse:
|
||||
try:
|
||||
logger.debug('Zipping workspace')
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
@@ -188,12 +224,19 @@ def zip_current_workspace(request: Request):
|
||||
)
|
||||
|
||||
|
||||
@app.get('/git/changes')
|
||||
@app.get(
|
||||
'/git/changes',
|
||||
response_model=dict[str, Any],
|
||||
responses={
|
||||
404: {'description': 'Not a git repository', 'model': dict},
|
||||
500: {'description': 'Error getting changes', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def git_changes(
|
||||
request: Request,
|
||||
conversation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
) -> dict[str, Any] | JSONResponse:
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config,
|
||||
@@ -229,13 +272,17 @@ async def git_changes(
|
||||
)
|
||||
|
||||
|
||||
@app.get('/git/diff')
|
||||
@app.get(
|
||||
'/git/diff',
|
||||
response_model=dict[str, Any],
|
||||
responses={500: {'description': 'Error getting diff', 'model': dict}},
|
||||
)
|
||||
async def git_diff(
|
||||
request: Request,
|
||||
path: str,
|
||||
conversation_id: str,
|
||||
conversation_store=Depends(get_conversation_store),
|
||||
):
|
||||
conversation_store: Any = Depends(get_conversation_store),
|
||||
) -> dict[str, Any] | JSONResponse:
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
|
||||
cwd = await get_cwd(
|
||||
@@ -259,7 +306,7 @@ async def get_cwd(
|
||||
conversation_store: ConversationStore,
|
||||
conversation_id: str,
|
||||
workspace_mount_path_in_sandbox: str,
|
||||
):
|
||||
) -> str:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
conversation_info = await _get_conversation_info(metadata, is_running)
|
||||
|
||||
@@ -8,6 +8,7 @@ from openhands.integrations.provider import (
|
||||
)
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
Branch,
|
||||
Repository,
|
||||
SuggestedTask,
|
||||
UnknownException,
|
||||
@@ -29,7 +30,7 @@ async def get_user_repositories(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
):
|
||||
) -> list[Repository] | JSONResponse:
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
@@ -65,7 +66,7 @@ async def get_user_repositories(
|
||||
async def get_user(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
) -> User | JSONResponse:
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_token=access_token
|
||||
@@ -101,7 +102,7 @@ async def search_repositories(
|
||||
order: str = 'desc',
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
) -> list[Repository] | JSONResponse:
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_token=access_token
|
||||
@@ -134,7 +135,7 @@ async def search_repositories(
|
||||
async def get_suggested_tasks(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
) -> list[SuggestedTask] | JSONResponse:
|
||||
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
|
||||
|
||||
Returns:
|
||||
@@ -165,3 +166,43 @@ async def get_suggested_tasks(
|
||||
content='No providers set.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/repository/branches', response_model=list[Branch])
|
||||
async def get_repository_branches(
|
||||
repository: str,
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
) -> list[Branch] | JSONResponse:
|
||||
"""Get branches for a repository.
|
||||
|
||||
Args:
|
||||
repository: The repository name in the format 'owner/repo'
|
||||
|
||||
Returns:
|
||||
A list of branches for the repository
|
||||
"""
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_token=access_token
|
||||
)
|
||||
try:
|
||||
branches: list[Branch] = await client.get_branches(repository)
|
||||
return branches
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='Git provider token required. (such as GitHub).',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, status
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
@@ -30,7 +28,6 @@ from openhands.server.shared import (
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
file_store,
|
||||
)
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth import (
|
||||
@@ -47,7 +44,7 @@ from openhands.storage.data_models.conversation_metadata import (
|
||||
)
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.conversation_summary import generate_conversation_title
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
@@ -61,10 +58,9 @@ class InitSessionRequest(BaseModel):
|
||||
image_urls: list[str] | None = None
|
||||
replay_json: str | None = None
|
||||
suggested_task: SuggestedTask | None = None
|
||||
|
||||
model_config = {
|
||||
"extra": "forbid"
|
||||
}
|
||||
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
|
||||
async def _create_new_conversation(
|
||||
user_id: str | None,
|
||||
@@ -76,7 +72,7 @@ async def _create_new_conversation(
|
||||
replay_json: str | None,
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_convo_id: bool = False,
|
||||
):
|
||||
) -> str:
|
||||
logger.info(
|
||||
'Creating conversation',
|
||||
extra={
|
||||
@@ -90,7 +86,7 @@ async def _create_new_conversation(
|
||||
settings = await settings_store.load()
|
||||
logger.info('Settings loaded')
|
||||
|
||||
session_init_args: dict = {}
|
||||
session_init_args: dict[str, Any] = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
# We could use litellm.check_valid_key for a more accurate check,
|
||||
@@ -99,13 +95,13 @@ async def _create_new_conversation(
|
||||
not settings.llm_api_key
|
||||
or settings.llm_api_key.get_secret_value().isspace()
|
||||
):
|
||||
logger.warn(f'Missing api key for model {settings.llm_model}')
|
||||
logger.warning(f'Missing api key for model {settings.llm_model}')
|
||||
raise LLMAuthenticationError(
|
||||
'Error authenticating with the LLM provider. Please check your API key'
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warn('Settings not present, not starting conversation')
|
||||
logger.warning('Settings not present, not starting conversation')
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
@@ -163,7 +159,6 @@ async def _create_new_conversation(
|
||||
replay_json=replay_json,
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
|
||||
return conversation_id
|
||||
|
||||
|
||||
@@ -173,7 +168,7 @@ async def new_conversation(
|
||||
user_id: str = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
|
||||
auth_type: AuthType | None = Depends(get_auth_type),
|
||||
):
|
||||
) -> JSONResponse:
|
||||
"""Initialize a new session or join an existing one.
|
||||
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
@@ -246,7 +241,7 @@ async def new_conversation(
|
||||
},
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@app.get('/conversations')
|
||||
async def search_conversations(
|
||||
@@ -301,109 +296,6 @@ async def get_conversation(
|
||||
return None
|
||||
|
||||
|
||||
def get_default_conversation_title(conversation_id: str) -> str:
|
||||
"""
|
||||
Generate a default title for a conversation based on its ID.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
|
||||
Returns:
|
||||
A default title string
|
||||
"""
|
||||
return f'Conversation {conversation_id[:5]}'
|
||||
|
||||
|
||||
async def auto_generate_title(conversation_id: str, user_id: str | None) -> str:
|
||||
"""
|
||||
Auto-generate a title for a conversation based on the first user message.
|
||||
Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
user_id: The ID of the user
|
||||
|
||||
Returns:
|
||||
A generated title string
|
||||
"""
|
||||
logger.info(f'Auto-generating title for conversation {conversation_id}')
|
||||
|
||||
try:
|
||||
# Create an event stream for the conversation
|
||||
event_stream = EventStream(conversation_id, file_store, user_id)
|
||||
|
||||
# Find the first user message
|
||||
first_user_message = None
|
||||
for event in event_stream.get_events():
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and isinstance(event, MessageAction)
|
||||
and event.content
|
||||
and event.content.strip()
|
||||
):
|
||||
first_user_message = event.content
|
||||
break
|
||||
|
||||
if first_user_message:
|
||||
# Get LLM config from user settings
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
|
||||
if settings and settings.llm_model:
|
||||
# Create LLM config from settings
|
||||
llm_config = LLMConfig(
|
||||
model=settings.llm_model,
|
||||
api_key=settings.llm_api_key,
|
||||
base_url=settings.llm_base_url,
|
||||
)
|
||||
|
||||
# Try to generate title using LLM
|
||||
llm_title = await generate_conversation_title(
|
||||
first_user_message, llm_config
|
||||
)
|
||||
if llm_title:
|
||||
logger.info(f'Generated title using LLM: {llm_title}')
|
||||
return llm_title
|
||||
except Exception as e:
|
||||
logger.error(f'Error using LLM for title generation: {e}')
|
||||
|
||||
# Fall back to simple truncation if LLM generation fails or is unavailable
|
||||
first_user_message = first_user_message.strip()
|
||||
title = first_user_message[:30]
|
||||
if len(first_user_message) > 30:
|
||||
title += '...'
|
||||
logger.info(f'Generated title using truncation: {title}')
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating title: {str(e)}')
|
||||
return ''
|
||||
|
||||
|
||||
@app.patch('/conversations/{conversation_id}')
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
title: str = Body(embed=True),
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
return False
|
||||
|
||||
# If title is empty or unspecified, auto-generate it
|
||||
if not title or title.isspace():
|
||||
title = await auto_generate_title(conversation_id, user_id)
|
||||
|
||||
# If we still don't have a title, use the default
|
||||
if not title or title.isspace():
|
||||
title = get_default_conversation_title(conversation_id)
|
||||
|
||||
metadata.title = title
|
||||
await conversation_store.save_metadata(metadata)
|
||||
return True
|
||||
|
||||
|
||||
@app.delete('/conversations/{conversation_id}')
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
|
||||
@@ -1,32 +1,36 @@
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.integrations.utils import validate_provider_token
|
||||
from openhands.server.settings import GETCustomSecrets, POSTCustomSecrets, POSTProviderModel
|
||||
from openhands.server.user_auth import get_secrets_store, get_user_secrets, get_user_settings_store
|
||||
from openhands.server.settings import (
|
||||
GETCustomSecrets,
|
||||
POSTCustomSecrets,
|
||||
POSTProviderModel,
|
||||
)
|
||||
from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_secrets_store,
|
||||
get_user_secrets,
|
||||
)
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.settings.secret_store import SecretsStore
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Handle git provider tokens
|
||||
# =================================================
|
||||
|
||||
|
||||
async def invalidate_legacy_secrets_store(
|
||||
settings: Settings,
|
||||
settings_store: SettingsStore,
|
||||
secrets_store: SecretsStore) -> UserSecrets | None:
|
||||
|
||||
settings: Settings, settings_store: SettingsStore, secrets_store: SecretsStore
|
||||
) -> UserSecrets | None:
|
||||
"""
|
||||
We are moving `secrets_store` (a field from `Settings` object) to its own dedicated store
|
||||
This function moves the values from Settings to UserSecrets, and deletes the values in Settings
|
||||
@@ -34,7 +38,9 @@ async def invalidate_legacy_secrets_store(
|
||||
"""
|
||||
|
||||
if len(settings.secrets_store.provider_tokens.items()) > 0:
|
||||
user_secrets = UserSecrets(provider_tokens=settings.secrets_store.provider_tokens)
|
||||
user_secrets = UserSecrets(
|
||||
provider_tokens=settings.secrets_store.provider_tokens
|
||||
)
|
||||
await secrets_store.store(user_secrets)
|
||||
|
||||
# Invalidate old tokens via settings store serializer
|
||||
@@ -44,69 +50,97 @@ async def invalidate_legacy_secrets_store(
|
||||
await settings_store.store(invalidated_secrets_settings)
|
||||
|
||||
return user_secrets
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def check_provider_tokens(provider_info: POSTProviderModel) -> str:
|
||||
print(provider_info)
|
||||
if provider_info.provider_tokens:
|
||||
# Determine whether tokens are valid
|
||||
for token_type, token_value in provider_info.provider_tokens.items():
|
||||
if token_value.token:
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
token_value.token
|
||||
)
|
||||
if not confirmed_token_type or confirmed_token_type != token_type:
|
||||
return f'Invalid token. Please make sure it is a valid {token_type.value} token.'
|
||||
def process_token_validation_result(
|
||||
confirmed_token_type: ProviderType | None, token_type: ProviderType
|
||||
):
|
||||
if not confirmed_token_type or confirmed_token_type != token_type:
|
||||
return (
|
||||
f'Invalid token. Please make sure it is a valid {token_type.value} token.'
|
||||
)
|
||||
|
||||
return ''
|
||||
|
||||
|
||||
async def check_provider_tokens(
|
||||
incoming_provider_tokens: POSTProviderModel,
|
||||
existing_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
) -> str:
|
||||
msg = ''
|
||||
if incoming_provider_tokens.provider_tokens:
|
||||
# Determine whether tokens are valid
|
||||
for token_type, token_value in incoming_provider_tokens.provider_tokens.items():
|
||||
if token_value.token:
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
token_value.token, token_value.host
|
||||
) # FE always sends latest host
|
||||
msg = process_token_validation_result(confirmed_token_type, token_type)
|
||||
|
||||
existing_token = (
|
||||
existing_provider_tokens.get(token_type, None)
|
||||
if existing_provider_tokens
|
||||
else None
|
||||
)
|
||||
if (
|
||||
existing_token
|
||||
and (existing_token.host != token_value.host)
|
||||
and existing_token.token
|
||||
):
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
existing_token.token, token_value.host
|
||||
) # Host has changed, check it against existing token
|
||||
if not confirmed_token_type or confirmed_token_type != token_type:
|
||||
msg = process_token_validation_result(
|
||||
confirmed_token_type, token_type
|
||||
)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
@app.post('/add-git-providers')
|
||||
async def store_provider_tokens(
|
||||
provider_info: POSTProviderModel,
|
||||
secrets_store: SecretsStore = Depends(get_secrets_store)
|
||||
provider_info: POSTProviderModel,
|
||||
secrets_store: SecretsStore = Depends(get_secrets_store),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
) -> JSONResponse:
|
||||
provider_err_msg = await check_provider_tokens(provider_info)
|
||||
provider_err_msg = await check_provider_tokens(provider_info, provider_tokens)
|
||||
if provider_err_msg:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': provider_err_msg},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
user_secrets = await secrets_store.load()
|
||||
if not user_secrets:
|
||||
user_secrets = UserSecrets()
|
||||
|
||||
if provider_info.provider_tokens:
|
||||
existing_providers = [provider for provider in user_secrets.provider_tokens]
|
||||
|
||||
if user_secrets:
|
||||
if provider_info.provider_tokens:
|
||||
existing_providers = [
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in list(provider_info.provider_tokens.items()):
|
||||
if provider in existing_providers and not token_value.token:
|
||||
existing_token = user_secrets.provider_tokens.get(provider)
|
||||
if existing_token and existing_token.token:
|
||||
provider_info.provider_tokens[provider] = existing_token
|
||||
|
||||
provider_info.provider_tokens[provider] = provider_info.provider_tokens[
|
||||
provider
|
||||
for provider in user_secrets.provider_tokens
|
||||
]
|
||||
].model_copy(update={'host': token_value.host})
|
||||
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in list(provider_info.provider_tokens.items()):
|
||||
if provider in existing_providers and not token_value.token:
|
||||
existing_token = (
|
||||
user_secrets.provider_tokens.get(provider)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
provider_info.provider_tokens[provider] = existing_token
|
||||
updated_secrets = user_secrets.model_copy(
|
||||
update={'provider_tokens': provider_info.provider_tokens}
|
||||
)
|
||||
await secrets_store.store(updated_secrets)
|
||||
|
||||
else: # nothing passed in means keep current settings
|
||||
provider_info.provider_tokens = dict(user_secrets.provider_tokens)
|
||||
|
||||
|
||||
updated_secrets = user_secrets.model_copy(update={"provider_tokens":provider_info.provider_tokens})
|
||||
await secrets_store.store(updated_secrets)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Git providers stored'},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Git providers stored'},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Something went wrong storing git providers: {e}')
|
||||
return JSONResponse(
|
||||
@@ -117,14 +151,12 @@ async def store_provider_tokens(
|
||||
|
||||
@app.post('/unset-provider-tokens', response_model=dict[str, str])
|
||||
async def unset_provider_tokens(
|
||||
secrets_store: SecretsStore = Depends(get_secrets_store)
|
||||
secrets_store: SecretsStore = Depends(get_secrets_store),
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
user_secrets = await secrets_store.load()
|
||||
if user_secrets:
|
||||
updated_secrets = user_secrets.model_copy(
|
||||
update={'provider_tokens': {}}
|
||||
)
|
||||
updated_secrets = user_secrets.model_copy(update={'provider_tokens': {}})
|
||||
await secrets_store.store(updated_secrets)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -140,14 +172,11 @@ async def unset_provider_tokens(
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Handle custom secrets
|
||||
# =================================================
|
||||
|
||||
|
||||
|
||||
@app.get('/secrets', response_model=GETCustomSecrets)
|
||||
async def load_custom_secrets_names(
|
||||
user_secrets: UserSecrets | None = Depends(get_user_secrets),
|
||||
@@ -158,15 +187,15 @@ async def load_custom_secrets_names(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'User secrets not found'},
|
||||
)
|
||||
|
||||
|
||||
custom_secrets = list(user_secrets.custom_secrets.keys())
|
||||
return GETCustomSecrets(custom_secrets=custom_secrets)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid token: {e}')
|
||||
logger.warning(f'Failed to load secret names: {e}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Invalid token'},
|
||||
content={'error': 'Failed to get secret names'},
|
||||
)
|
||||
|
||||
|
||||
@@ -186,9 +215,9 @@ async def create_custom_secret(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={'message': f'Secret {secret_name} already exists'},
|
||||
)
|
||||
|
||||
|
||||
custom_secrets[secret_name] = secret_value
|
||||
|
||||
|
||||
# Create a new UserSecrets that preserves provider tokens
|
||||
updated_user_secrets = UserSecrets(
|
||||
custom_secrets=custom_secrets,
|
||||
@@ -208,10 +237,11 @@ async def create_custom_secret(
|
||||
content={'error': 'Something went wrong creating secret'},
|
||||
)
|
||||
|
||||
|
||||
@app.put('/secrets/{secret_id}', response_model=dict[str, str])
|
||||
async def update_custom_secret(
|
||||
secret_id: str,
|
||||
incoming_secret: POSTCustomSecrets,
|
||||
secret_id: str,
|
||||
incoming_secret: POSTCustomSecrets,
|
||||
secrets_store: SecretsStore = Depends(get_secrets_store),
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
@@ -289,4 +319,3 @@ async def delete_custom_secret(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'error': 'Something went wrong deleting secret'},
|
||||
)
|
||||
|
||||
|
||||
@@ -6,8 +6,6 @@ from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderType,
|
||||
)
|
||||
|
||||
|
||||
from openhands.server.routes.secrets import invalidate_legacy_secrets_store
|
||||
from openhands.server.settings import (
|
||||
GETSettingsModel,
|
||||
@@ -18,39 +16,50 @@ from openhands.server.user_auth import (
|
||||
get_secrets_store,
|
||||
get_user_settings_store,
|
||||
)
|
||||
from openhands.storage.settings.secret_store import SecretsStore
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
|
||||
@app.get('/settings', response_model=GETSettingsModel)
|
||||
@app.get(
|
||||
'/settings',
|
||||
response_model=GETSettingsModel,
|
||||
responses={
|
||||
404: {'description': 'Settings not found', 'model': dict},
|
||||
401: {'description': 'Invalid token', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def load_settings(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
secrets_store: SecretsStore = Depends(get_secrets_store)
|
||||
secrets_store: SecretsStore = Depends(get_secrets_store),
|
||||
) -> GETSettingsModel | JSONResponse:
|
||||
|
||||
settings = await settings_store.load()
|
||||
|
||||
|
||||
try:
|
||||
if not settings:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Settings not found'},
|
||||
)
|
||||
|
||||
# On initial load, user secrets may not be populated with values migrated from settings store
|
||||
user_secrets = await invalidate_legacy_secrets_store(settings, settings_store, secrets_store)
|
||||
# If invalidation is successful, then the returned user secrets holds the most recent values
|
||||
git_providers = user_secrets.provider_tokens if user_secrets else provider_tokens
|
||||
|
||||
provider_tokens_set: dict[ProviderType, str | None] = {}
|
||||
# On initial load, user secrets may not be populated with values migrated from settings store
|
||||
user_secrets = await invalidate_legacy_secrets_store(
|
||||
settings, settings_store, secrets_store
|
||||
)
|
||||
|
||||
# If invalidation is successful, then the returned user secrets holds the most recent values
|
||||
git_providers = (
|
||||
user_secrets.provider_tokens if user_secrets else provider_tokens
|
||||
)
|
||||
|
||||
provider_tokens_set: dict[ProviderType, str | None] = {}
|
||||
if git_providers:
|
||||
for provider_type, provider_token in git_providers.items():
|
||||
if provider_token.token or provider_token.user_id:
|
||||
provider_tokens_set[provider_type] = None
|
||||
provider_tokens_set[provider_type] = provider_token.host
|
||||
|
||||
settings_with_token_data = GETSettingsModel(
|
||||
**settings.model_dump(exclude='secrets_store'),
|
||||
@@ -68,7 +77,15 @@ async def load_settings(
|
||||
)
|
||||
|
||||
|
||||
@app.post('/reset-settings', response_model=dict[str, str])
|
||||
@app.post(
|
||||
'/reset-settings',
|
||||
responses={
|
||||
410: {
|
||||
'description': 'Reset settings functionality has been removed',
|
||||
'model': dict,
|
||||
}
|
||||
},
|
||||
)
|
||||
async def reset_settings() -> JSONResponse:
|
||||
"""
|
||||
Resets user settings. (Deprecated)
|
||||
@@ -98,7 +115,18 @@ async def store_llm_settings(
|
||||
return settings
|
||||
|
||||
|
||||
@app.post('/settings', response_model=dict[str, str])
|
||||
# NOTE: We use response_model=None for endpoints that return JSONResponse directly.
|
||||
# This is because FastAPI's response_model expects a Pydantic model, but we're returning
|
||||
# a response object directly. We document the possible responses using the 'responses'
|
||||
# parameter and maintain proper type annotations for mypy.
|
||||
@app.post(
|
||||
'/settings',
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {'description': 'Settings stored successfully', 'model': dict},
|
||||
500: {'description': 'Error storing settings', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def store_settings(
|
||||
settings: Settings,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
|
||||
@@ -58,7 +58,7 @@ class AgentSession:
|
||||
file_store: FileStore,
|
||||
status_callback: Callable | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initializes a new instance of the Session class
|
||||
|
||||
Parameters:
|
||||
@@ -89,7 +89,7 @@ class AgentSession:
|
||||
selected_branch: str | None = None,
|
||||
initial_message: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Starts the Agent session
|
||||
Parameters:
|
||||
- runtime_name: The name of the runtime associated with the session
|
||||
@@ -188,7 +188,7 @@ class AgentSession:
|
||||
},
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""Closes the Agent session"""
|
||||
if self._closed:
|
||||
return
|
||||
@@ -245,7 +245,7 @@ class AgentSession:
|
||||
assert isinstance(replay_events[0], MessageAction)
|
||||
return replay_events[0]
|
||||
|
||||
def _create_security_analyzer(self, security_analyzer: str | None):
|
||||
def _create_security_analyzer(self, security_analyzer: str | None) -> None:
|
||||
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions
|
||||
|
||||
Parameters:
|
||||
@@ -333,6 +333,7 @@ class AgentSession:
|
||||
git_provider_tokens, selected_repository, selected_branch
|
||||
)
|
||||
await call_sync_from_async(self.runtime.maybe_run_setup_script)
|
||||
await call_sync_from_async(self.runtime.maybe_setup_git_hooks)
|
||||
|
||||
self.logger.debug(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||
|
||||
@@ -38,10 +38,10 @@ class Conversation:
|
||||
headless_mode=False,
|
||||
)
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
await self.runtime.connect()
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
if self.event_stream:
|
||||
self.event_stream.close()
|
||||
asyncio.create_task(call_sync_from_async(self.runtime.close))
|
||||
|
||||
@@ -6,7 +6,7 @@ from logging import LoggerAdapter
|
||||
import socketio
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.config import AppConfig, MCPConfig
|
||||
from openhands.core.config.condenser_config import (
|
||||
BrowserOutputCondenserConfig,
|
||||
CondenserPipelineConfig,
|
||||
@@ -73,7 +73,7 @@ class Session:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.user_id = user_id
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
if self.sio:
|
||||
await self.sio.emit(
|
||||
'oh_event',
|
||||
@@ -90,7 +90,7 @@ class Session:
|
||||
settings: Settings,
|
||||
initial_message: MessageAction | None,
|
||||
replay_json: str | None,
|
||||
):
|
||||
) -> None:
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING),
|
||||
EventSource.ENVIRONMENT,
|
||||
@@ -114,6 +114,7 @@ class Session:
|
||||
or settings.sandbox_runtime_container_image
|
||||
else self.config.sandbox.runtime_container_image
|
||||
)
|
||||
self.config.mcp = settings.mcp_config or MCPConfig()
|
||||
max_iterations = settings.max_iterations or self.config.max_iterations
|
||||
|
||||
# This is a shallow copy of the default LLM config, so changes here will
|
||||
@@ -137,7 +138,7 @@ class Session:
|
||||
# output, which should keep the summarization cost down.
|
||||
default_condenser_config = CondenserPipelineConfig(
|
||||
condensers=[
|
||||
BrowserOutputCondenserConfig(),
|
||||
BrowserOutputCondenserConfig(attention_window=2),
|
||||
LLMSummarizingCondenserConfig(
|
||||
llm_config=llm.config, keep_first=4, max_size=80
|
||||
),
|
||||
@@ -194,10 +195,10 @@ class Session:
|
||||
'info', msg_id, f'Retrying LLM request, {retries} / {max}'
|
||||
)
|
||||
|
||||
def on_event(self, event: Event):
|
||||
def on_event(self, event: Event) -> None:
|
||||
asyncio.get_event_loop().run_until_complete(self._on_event(event))
|
||||
|
||||
async def _on_event(self, event: Event):
|
||||
async def _on_event(self, event: Event) -> None:
|
||||
"""Callback function for events that mainly come from the agent.
|
||||
Event is the base class for any agent action and observation.
|
||||
|
||||
@@ -235,7 +236,7 @@ class Session:
|
||||
event_dict['source'] = EventSource.AGENT
|
||||
await self.send(event_dict)
|
||||
|
||||
async def dispatch(self, data: dict):
|
||||
async def dispatch(self, data: dict) -> None:
|
||||
event = event_from_dict(data.copy())
|
||||
# This checks if the model supports images
|
||||
if isinstance(event, MessageAction) and event.image_urls:
|
||||
@@ -253,7 +254,7 @@ class Session:
|
||||
return
|
||||
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
|
||||
async def send(self, data: dict[str, object]):
|
||||
async def send(self, data: dict[str, object]) -> None:
|
||||
if asyncio.get_running_loop() != self.loop:
|
||||
self.loop.create_task(self._send(data))
|
||||
return
|
||||
@@ -273,11 +274,11 @@ class Session:
|
||||
self.is_alive = False
|
||||
return False
|
||||
|
||||
async def send_error(self, message: str):
|
||||
async def send_error(self, message: str) -> None:
|
||||
"""Sends an error message to the client."""
|
||||
await self.send({'error': True, 'message': message})
|
||||
|
||||
async def _send_status_message(self, msg_type: str, id: str, message: str):
|
||||
async def _send_status_message(self, msg_type: str, id: str, message: str) -> None:
|
||||
"""Sends a status message to the client."""
|
||||
if msg_type == 'error':
|
||||
agent_session = self.agent_session
|
||||
@@ -292,7 +293,7 @@ class Session:
|
||||
{'status_update': True, 'type': msg_type, 'id': id, 'message': message}
|
||||
)
|
||||
|
||||
def queue_status_message(self, msg_type: str, id: str, message: str):
|
||||
def queue_status_message(self, msg_type: str, id: str, message: str) -> None:
|
||||
"""Queues a status message to be sent asynchronously."""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._send_status_message(msg_type, id, message), self.loop
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import (
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.integrations.provider import ProviderToken
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@@ -15,6 +16,7 @@ class POSTProviderModel(BaseModel):
|
||||
Settings for POST requests
|
||||
"""
|
||||
|
||||
mcp_config: MCPConfig | None = None
|
||||
provider_tokens: dict[ProviderType, ProviderToken] = {}
|
||||
|
||||
|
||||
@@ -36,6 +38,8 @@ class GETSettingsModel(Settings):
|
||||
)
|
||||
llm_api_key_set: bool
|
||||
|
||||
model_config = {'use_enum_values': True}
|
||||
|
||||
|
||||
class GETCustomSecrets(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -11,7 +11,7 @@ from openhands.server.conversation_manager.conversation_manager import (
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.settings.secret_store import SecretsStore
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ from pydantic import SecretStr
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.settings.secret_store import SecretsStore
|
||||
from openhands.server.user_auth.user_auth import AuthType, get_user_auth
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from openhands.server import shared
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.settings.secret_store import SecretsStore
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
@@ -69,7 +69,6 @@ class DefaultUserAuth(UserAuth):
|
||||
self._user_secrets = user_secrets
|
||||
return user_secrets
|
||||
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
secrets_store = await self.get_user_secrets()
|
||||
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
|
||||
|
||||
@@ -10,7 +10,7 @@ from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.settings.secret_store import SecretsStore
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
@@ -60,7 +60,7 @@ class UserAuth(ABC):
|
||||
@abstractmethod
|
||||
async def get_user_secrets(self) -> UserSecrets | None:
|
||||
"""Get the user's secrets"""
|
||||
|
||||
|
||||
def get_auth_type(self) -> AuthType | None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -7,7 +7,10 @@ class ConversationValidator:
|
||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||
|
||||
async def validate(
|
||||
self, conversation_id: str, cookies_str: str
|
||||
self,
|
||||
conversation_id: str,
|
||||
cookies_str: str,
|
||||
authorization_header: str | None = None,
|
||||
) -> tuple[None, None]:
|
||||
return None, None
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from pydantic import (
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.config.utils import load_app_config
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
|
||||
@@ -33,9 +34,11 @@ class Settings(BaseModel):
|
||||
secrets_store: UserSecrets = Field(default_factory=UserSecrets, frozen=True)
|
||||
enable_default_condenser: bool = True
|
||||
enable_sound_notifications: bool = False
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
user_consents_to_analytics: bool | None = None
|
||||
sandbox_base_container_image: str | None = None
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
mcp_config: MCPConfig | None = None
|
||||
|
||||
model_config = {
|
||||
'validate_assignment': True,
|
||||
@@ -94,9 +97,7 @@ class Settings(BaseModel):
|
||||
"""Custom serializer for secrets store."""
|
||||
|
||||
"""Force invalidate secret store"""
|
||||
return {
|
||||
'provider_tokens': {}
|
||||
}
|
||||
return {'provider_tokens': {}}
|
||||
|
||||
@staticmethod
|
||||
def from_config() -> Settings | None:
|
||||
@@ -106,6 +107,12 @@ class Settings(BaseModel):
|
||||
# If no api key has been set, we take this to mean that there is no reasonable default
|
||||
return None
|
||||
security = app_config.security
|
||||
|
||||
# Get MCP config if available
|
||||
mcp_config = None
|
||||
if hasattr(app_config, 'mcp'):
|
||||
mcp_config = app_config.mcp
|
||||
|
||||
settings = Settings(
|
||||
language='en',
|
||||
agent=app_config.default_agent,
|
||||
@@ -116,5 +123,6 @@ class Settings(BaseModel):
|
||||
llm_api_key=llm_config.api_key,
|
||||
llm_base_url=llm_config.base_url,
|
||||
remote_runtime_resource_factor=app_config.sandbox.remote_runtime_resource_factor,
|
||||
mcp_config=mcp_config,
|
||||
)
|
||||
return settings
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -10,7 +11,14 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic.json import pydantic_encoder
|
||||
from openhands.integrations.provider import CUSTOM_SECRETS_TYPE, PROVIDER_TOKEN_TYPE, PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA, ProviderToken
|
||||
|
||||
from openhands.integrations.provider import (
|
||||
CUSTOM_SECRETS_TYPE,
|
||||
CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA,
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA,
|
||||
ProviderToken,
|
||||
)
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
@@ -19,7 +27,7 @@ class UserSecrets(BaseModel):
|
||||
default_factory=lambda: MappingProxyType({})
|
||||
)
|
||||
|
||||
custom_secrets: CUSTOM_SECRETS_TYPE = Field(
|
||||
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA = Field(
|
||||
default_factory=lambda: MappingProxyType({})
|
||||
)
|
||||
|
||||
@@ -29,7 +37,6 @@ class UserSecrets(BaseModel):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
@field_serializer('provider_tokens')
|
||||
def provider_tokens_serializer(
|
||||
self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo
|
||||
@@ -38,7 +45,7 @@ class UserSecrets(BaseModel):
|
||||
expose_secrets = info.context and info.context.get('expose_secrets', False)
|
||||
|
||||
for token_type, provider_token in provider_tokens.items():
|
||||
if not provider_token or not provider_token.token:
|
||||
if not provider_token:
|
||||
continue
|
||||
|
||||
token_type_str = (
|
||||
@@ -46,10 +53,18 @@ class UserSecrets(BaseModel):
|
||||
if isinstance(token_type, ProviderType)
|
||||
else str(token_type)
|
||||
)
|
||||
|
||||
token = None
|
||||
if provider_token.token:
|
||||
token = (
|
||||
provider_token.token.get_secret_value()
|
||||
if expose_secrets
|
||||
else pydantic_encoder(provider_token.token)
|
||||
)
|
||||
|
||||
tokens[token_type_str] = {
|
||||
'token': provider_token.token.get_secret_value()
|
||||
if expose_secrets
|
||||
else pydantic_encoder(provider_token.token),
|
||||
'token': token,
|
||||
'host': provider_token.host,
|
||||
'user_id': provider_token.user_id,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, TypedDict
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
@@ -16,7 +16,7 @@ class GetObjectOutputDict(TypedDict):
|
||||
|
||||
|
||||
class ListObjectsV2OutputDict(TypedDict):
|
||||
Contents: List[S3ObjectDict] | None
|
||||
Contents: list[S3ObjectDict] | None
|
||||
|
||||
|
||||
class S3FileStore(FileStore):
|
||||
|
||||
@@ -7,7 +7,7 @@ from openhands.core.config.app_config import AppConfig
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.settings.secret_store import SecretsStore
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@@ -34,4 +34,4 @@ class FileSecretsStore(SecretsStore):
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> FileSecretsStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileSecretsStore(file_store)
|
||||
return FileSecretsStore(file_store)
|
||||
@@ -1,12 +1,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
|
||||
|
||||
|
||||
class SecretsStore(ABC):
|
||||
"""Storage for secrets. May or may not support multiple users depending on the environment."""
|
||||
|
||||
@@ -20,7 +19,5 @@ class SecretsStore(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
) -> SecretsStore:
|
||||
"""Get a store for the user represented by the token given."""
|
||||
async def get_instance(cls, config: AppConfig, user_id: str | None) -> SecretsStore:
|
||||
"""Get a store for the user represented by the token given."""
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Coroutine, Iterable, List
|
||||
from typing import Callable, Coroutine, Iterable
|
||||
|
||||
GENERAL_TIMEOUT: int = 15
|
||||
EXECUTOR = ThreadPoolExecutor()
|
||||
@@ -64,7 +64,7 @@ async def call_coro_in_bg_thread(
|
||||
|
||||
async def wait_all(
|
||||
iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT
|
||||
) -> List:
|
||||
) -> list:
|
||||
"""
|
||||
Shorthand for waiting for all the coroutines in the iterable given in parallel. Creates
|
||||
a task for each coroutine.
|
||||
|
||||
@@ -4,7 +4,12 @@ from typing import Optional
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
async def generate_conversation_title(
|
||||
@@ -55,3 +60,81 @@ async def generate_conversation_title(
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating conversation title: {e}')
|
||||
return None
|
||||
|
||||
|
||||
def get_default_conversation_title(conversation_id: str) -> str:
|
||||
"""
|
||||
Generate a default title for a conversation based on its ID.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
|
||||
Returns:
|
||||
A default title string
|
||||
"""
|
||||
return f'Conversation {conversation_id[:5]}'
|
||||
|
||||
|
||||
async def auto_generate_title(
|
||||
conversation_id: str, user_id: str | None, file_store: FileStore, settings: Settings
|
||||
) -> str:
|
||||
"""
|
||||
Auto-generate a title for a conversation based on the first user message.
|
||||
Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
user_id: The ID of the user
|
||||
|
||||
Returns:
|
||||
A generated title string
|
||||
"""
|
||||
logger.info(f'Auto-generating title for conversation {conversation_id}')
|
||||
|
||||
try:
|
||||
# Create an event stream for the conversation
|
||||
event_stream = EventStream(conversation_id, file_store, user_id)
|
||||
|
||||
# Find the first user message
|
||||
first_user_message = None
|
||||
for event in event_stream.get_events():
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and isinstance(event, MessageAction)
|
||||
and event.content
|
||||
and event.content.strip()
|
||||
):
|
||||
first_user_message = event.content
|
||||
break
|
||||
|
||||
if first_user_message:
|
||||
# Get LLM config from user settings
|
||||
try:
|
||||
if settings and settings.llm_model:
|
||||
# Create LLM config from settings
|
||||
llm_config = LLMConfig(
|
||||
model=settings.llm_model,
|
||||
api_key=settings.llm_api_key,
|
||||
base_url=settings.llm_base_url,
|
||||
)
|
||||
|
||||
# Try to generate title using LLM
|
||||
llm_title = await generate_conversation_title(
|
||||
first_user_message, llm_config
|
||||
)
|
||||
if llm_title:
|
||||
logger.info(f'Generated title using LLM: {llm_title}')
|
||||
return llm_title
|
||||
except Exception as e:
|
||||
logger.error(f'Error using LLM for title generation: {e}')
|
||||
|
||||
# Fall back to simple truncation if LLM generation fails or is unavailable
|
||||
first_user_message = first_user_message.strip()
|
||||
title = first_user_message[:30]
|
||||
if len(first_user_message) > 30:
|
||||
title += '...'
|
||||
logger.info(f'Generated title using truncation: {title}')
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating title: {str(e)}')
|
||||
return ''
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import importlib
|
||||
from functools import lru_cache
|
||||
from typing import Type, TypeVar
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
@@ -15,7 +15,7 @@ def import_from(qual_name: str):
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_impl(cls: Type[T], impl_name: str | None) -> Type[T]:
|
||||
def get_impl(cls: type[T], impl_name: str | None) -> type[T]:
|
||||
"""Import a named implementation of the specified class"""
|
||||
if impl_name is None:
|
||||
return cls
|
||||
|
||||
286
poetry.lock
generated
286
poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -1,25 +1,31 @@
|
||||
[build-system]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
requires = [
|
||||
"poetry-core",
|
||||
]
|
||||
|
||||
[tool.poetry]
|
||||
name = "openhands-ai"
|
||||
version = "0.36.0"
|
||||
version = "0.37.0"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
authors = ["OpenHands"]
|
||||
authors = [ "OpenHands" ]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/All-Hands-AI/OpenHands"
|
||||
packages = [
|
||||
{ include = "openhands/**/*" },
|
||||
{ include = "pyproject.toml", to = "openhands" },
|
||||
{ include = "poetry.lock", to = "openhands" }
|
||||
{ include = "poetry.lock", to = "openhands" },
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.12"
|
||||
litellm = "^1.60.0, !=1.64.4" # avoid 1.64.4 (known bug)
|
||||
aiohttp = ">=3.9.0,!=3.11.13" # Pin to avoid yanked version 3.11.13
|
||||
google-generativeai = "*" # To use litellm with Gemini Pro API
|
||||
google-api-python-client = "^2.164.0" # For Google Sheets API
|
||||
google-auth-httplib2 = "*" # For Google Sheets authentication
|
||||
google-auth-oauthlib = "*" # For Google Sheets OAuth
|
||||
python = "^3.12,<3.14"
|
||||
litellm = "^1.60.0, !=1.64.4, !=1.67.*" # avoid 1.64.4 (known bug) & 1.67.* (known bug #10272)
|
||||
aiohttp = ">=3.9.0,!=3.11.13" # Pin to avoid yanked version 3.11.13
|
||||
google-generativeai = "*" # To use litellm with Gemini Pro API
|
||||
google-api-python-client = "^2.164.0" # For Google Sheets API
|
||||
google-auth-httplib2 = "*" # For Google Sheets authentication
|
||||
google-auth-oauthlib = "*" # For Google Sheets OAuth
|
||||
termcolor = "*"
|
||||
docker = "*"
|
||||
fastapi = "*"
|
||||
@@ -28,7 +34,7 @@ uvicorn = "*"
|
||||
types-toml = "*"
|
||||
numpy = "*"
|
||||
json-repair = "*"
|
||||
browsergym-core = "0.13.3" # integrate browsergym-core as the browsing interface
|
||||
browsergym-core = "0.13.3" # integrate browsergym-core as the browsing interface
|
||||
html2text = "*"
|
||||
e2b = ">=1.0.5,<1.4.0"
|
||||
pexpect = "*"
|
||||
@@ -40,7 +46,7 @@ tenacity = ">=8.5,<10.0"
|
||||
zope-interface = "7.2"
|
||||
pathspec = "^0.12.1"
|
||||
google-cloud-aiplatform = "*"
|
||||
anthropic = {extras = ["vertex"], version = "*"}
|
||||
anthropic = { extras = [ "vertex" ], version = "*" }
|
||||
tree-sitter = "^0.24.0"
|
||||
bashlex = "^0.18"
|
||||
pyjwt = "^2.9.0"
|
||||
@@ -54,7 +60,7 @@ tornado = "*"
|
||||
python-dotenv = "*"
|
||||
pylcs = "^0.1.1"
|
||||
whatthepatch = "^1.0.6"
|
||||
protobuf = "^4.21.6,<5.0.0" # chromadb currently fails on 5.0+
|
||||
protobuf = "^4.21.6,<5.0.0" # chromadb currently fails on 5.0+
|
||||
opentelemetry-api = "1.25.0"
|
||||
opentelemetry-exporter-otlp-proto-grpc = "1.25.0"
|
||||
modal = ">=0.66.26,<0.75.0"
|
||||
@@ -71,14 +77,14 @@ stripe = ">=11.5,<13.0"
|
||||
ipywidgets = "^8.1.5"
|
||||
qtconsole = "^5.6.1"
|
||||
memory-profiler = "^0.61.0"
|
||||
daytona-sdk = "0.15.0"
|
||||
mcp = "1.7.0"
|
||||
daytona-sdk = "0.16.1"
|
||||
mcp = "1.7.1"
|
||||
python-json-logger = "^3.2.1"
|
||||
playwright = "^1.51.0"
|
||||
prompt-toolkit = "^3.0.50"
|
||||
mcpm = "1.8.0"
|
||||
mcpm = "1.9.0"
|
||||
poetry = "^2.1.2"
|
||||
anyio = "4.9.0"
|
||||
pythonnet = "*"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.11.8"
|
||||
@@ -97,39 +103,12 @@ pandas = "*"
|
||||
reportlab = "*"
|
||||
gevent = ">=24.2.1,<26.0.0"
|
||||
|
||||
[tool.coverage.run]
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
jupyter_kernel_gateway = "*"
|
||||
flake8 = "*"
|
||||
|
||||
[build-system]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
requires = [
|
||||
"poetry-core",
|
||||
]
|
||||
|
||||
[tool.autopep8]
|
||||
# autopep8 fights with mypy on line length issue
|
||||
ignore = [ "E501" ]
|
||||
|
||||
[tool.black]
|
||||
# prevent black (if installed) from changing single quotes to double quotes
|
||||
skip-string-normalization = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["D"]
|
||||
# ignore warnings for missing docstrings
|
||||
ignore = ["D1"]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
@@ -148,14 +127,10 @@ browsergym = "0.13.3"
|
||||
browsergym-webarena = "0.13.3"
|
||||
browsergym-miniwob = "0.13.3"
|
||||
browsergym-visualwebarena = "0.13.3"
|
||||
boto3-stubs = {extras = ["s3"], version = "^1.37.19"}
|
||||
pyarrow = "20.0.0" # transitive dependency, pinned here to avoid conflicts
|
||||
boto3-stubs = { extras = [ "s3" ], version = "^1.37.19" }
|
||||
pyarrow = "20.0.0" # transitive dependency, pinned here to avoid conflicts
|
||||
datasets = "*"
|
||||
|
||||
[tool.poetry-dynamic-versioning]
|
||||
enable = true
|
||||
style = "semver"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
openhands = "openhands.core.cli:main"
|
||||
|
||||
@@ -164,3 +139,24 @@ fuzzywuzzy = "^0.18.0"
|
||||
rouge = "^1.0.1"
|
||||
python-levenshtein = ">=0.26.1,<0.28.0"
|
||||
tree-sitter-python = "^0.23.6"
|
||||
|
||||
[tool.poetry-dynamic-versioning]
|
||||
enable = true
|
||||
style = "semver"
|
||||
|
||||
[tool.autopep8]
|
||||
# autopep8 fights with mypy on line length issue
|
||||
ignore = [ "E501" ]
|
||||
|
||||
[tool.black]
|
||||
# prevent black (if installed) from changing single quotes to double quotes
|
||||
skip-string-normalization = true
|
||||
|
||||
[tool.ruff]
|
||||
lint.select = [ "D" ]
|
||||
# ignore warnings for missing docstrings
|
||||
lint.ignore = [ "D1" ]
|
||||
lint.pydocstyle.convention = "google"
|
||||
|
||||
[tool.coverage.run]
|
||||
concurrency = [ "gevent" ]
|
||||
|
||||
60
tests/runtime/README.md
Normal file
60
tests/runtime/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
## Runtime Tests
|
||||
|
||||
This folder contains integration tests that verify the functionality of OpenHands' runtime environments and their interactions with various tools and features.
|
||||
|
||||
### What are Runtime Tests?
|
||||
|
||||
Runtime tests focus on testing:
|
||||
- Tool interactions within a runtime environment (bash commands, browsing, file operations)
|
||||
- Environment setup and configuration
|
||||
- Resource management and cleanup
|
||||
- Browser-based operations and file viewing capabilities
|
||||
- IPython/Jupyter integration
|
||||
- Environment variables and configuration handling
|
||||
|
||||
The tests can be run against different runtime environments (Docker, Local, Remote, Runloop, or Daytona) by setting the TEST_RUNTIME environment variable. By default, tests run using the Docker runtime.
|
||||
|
||||
### How are they different from Unit Tests?
|
||||
|
||||
While unit tests in `tests/unit/` focus on testing individual components in isolation, runtime tests verify:
|
||||
1. Integration between components
|
||||
2. Actual execution of commands in different runtime environments
|
||||
3. System-level interactions (file system, network, browser)
|
||||
4. Environment setup and teardown
|
||||
5. Tool functionality in real runtime contexts
|
||||
|
||||
### Running the Tests
|
||||
|
||||
Run all runtime tests:
|
||||
|
||||
```bash
|
||||
poetry run pytest ./tests/runtime
|
||||
```
|
||||
|
||||
Run specific test file:
|
||||
|
||||
```bash
|
||||
poetry run pytest ./tests/runtime/test_bash.py
|
||||
```
|
||||
|
||||
Run specific test:
|
||||
|
||||
```bash
|
||||
poetry run pytest ./tests/runtime/test_bash.py::test_bash_command_env
|
||||
```
|
||||
|
||||
For verbose output, add the `-v` flag (more verbose: `-vv` and `-vvv`):
|
||||
|
||||
```bash
|
||||
poetry run pytest -v ./tests/runtime/test_bash.py
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The runtime tests can be configured using environment variables:
|
||||
- `TEST_IN_CI`: Set to 'True' when running in CI environment
|
||||
- `TEST_RUNTIME`: Specify the runtime to test ('docker', 'local', 'remote', 'runloop', 'daytona')
|
||||
- `RUN_AS_OPENHANDS`: Set to 'True' to run tests as openhands user (default), 'False' for root
|
||||
- `SANDBOX_BASE_CONTAINER_IMAGE`: Specify a custom base container image for Docker runtime
|
||||
|
||||
For more details on pytest usage, see the [pytest documentation](https://docs.pytest.org/en/latest/contents.html).
|
||||
@@ -1,11 +1,13 @@
|
||||
"""Editor-related tests for the DockerRuntime."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from conftest import _close_test_runtime, _load_runtime
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import FileEditAction, FileWriteAction
|
||||
from openhands.runtime.action_execution_server import _execute_file_editor
|
||||
|
||||
|
||||
def test_view_file(temp_dir, runtime_cls, run_as_openhands):
|
||||
@@ -690,3 +692,32 @@ def test_view_large_file_with_truncation(temp_dir, runtime_cls, run_as_openhands
|
||||
)
|
||||
finally:
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_insert_line_string_conversion():
|
||||
"""Test that insert_line is properly converted from string to int.
|
||||
|
||||
This test reproduces issue #8369 Example 2 where a string value for insert_line
|
||||
causes a TypeError in the editor.
|
||||
"""
|
||||
# Mock the OHEditor
|
||||
mock_editor = MagicMock()
|
||||
mock_editor.return_value = MagicMock(
|
||||
error=None, output='Success', old_content=None, new_content=None
|
||||
)
|
||||
|
||||
# Test with string insert_line
|
||||
result, _ = _execute_file_editor(
|
||||
editor=mock_editor,
|
||||
command='insert',
|
||||
path='/test/path.py',
|
||||
insert_line='185', # String instead of int
|
||||
new_str='test content',
|
||||
)
|
||||
|
||||
# Verify the editor was called with the correct parameters (insert_line converted to int)
|
||||
mock_editor.assert_called_once()
|
||||
args, kwargs = mock_editor.call_args
|
||||
assert isinstance(kwargs['insert_line'], int)
|
||||
assert kwargs['insert_line'] == 185
|
||||
assert result == 'Success'
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -117,7 +117,20 @@ def test_read_pdf_browse(temp_dir, runtime_cls, run_as_openhands):
|
||||
observation_text = str(obs)
|
||||
assert '[Action executed successfully.]' in observation_text
|
||||
assert 'Canvas' in observation_text
|
||||
assert (
|
||||
'Screenshot saved to: /workspace/.browser_screenshots/screenshot_'
|
||||
in observation_text
|
||||
)
|
||||
|
||||
# Check the /workspace/.browser_screenshots folder
|
||||
action_cmd = CmdRunAction(command='ls /workspace/.browser_screenshots')
|
||||
logger.info(action_cmd, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action_cmd)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert 'screenshot_' in obs.content
|
||||
assert '.png' in obs.content
|
||||
finally:
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
@@ -169,6 +182,19 @@ def test_read_png_browse(temp_dir, runtime_cls, run_as_openhands):
|
||||
observation_text = str(obs)
|
||||
assert '[Action executed successfully.]' in observation_text
|
||||
assert 'File Viewer - test_image.png' in observation_text
|
||||
assert (
|
||||
'Screenshot saved to: /workspace/.browser_screenshots/screenshot_'
|
||||
in observation_text
|
||||
)
|
||||
|
||||
# Check the /workspace/.browser_screenshots folder
|
||||
action_cmd = CmdRunAction(command='ls /workspace/.browser_screenshots')
|
||||
logger.info(action_cmd, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action_cmd)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
assert obs.exit_code == 0
|
||||
assert 'screenshot_' in obs.content
|
||||
assert '.png' in obs.content
|
||||
finally:
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
@@ -25,9 +25,9 @@ def test_env_vars_os_environ(temp_dir, runtime_cls, run_as_openhands):
|
||||
)
|
||||
print(obs)
|
||||
assert obs.exit_code == 0, 'The exit code should be 0.'
|
||||
assert (
|
||||
obs.content.strip().split('\n\r')[0].strip() == 'BAZ'
|
||||
), f'Output: [{obs.content}] for {runtime_cls}'
|
||||
assert obs.content.strip().split('\n\r')[0].strip() == 'BAZ', (
|
||||
f'Output: [{obs.content}] for {runtime_cls}'
|
||||
)
|
||||
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
@@ -168,9 +168,9 @@ def test_grep_to_cmdrun_paths_with_spaces(runtime_cls, run_as_openhands, temp_di
|
||||
|
||||
obs = _run_cmd_action(runtime, cmd)
|
||||
assert obs.exit_code == 0, f'Grep command failed for path: {path}'
|
||||
assert (
|
||||
'function' in obs.content
|
||||
), f'Expected pattern not found in output for path: {path}'
|
||||
assert 'function' in obs.content, (
|
||||
f'Expected pattern not found in output for path: {path}'
|
||||
)
|
||||
|
||||
# Verify the actual file was found
|
||||
if path == 'src/my project':
|
||||
|
||||
@@ -77,9 +77,9 @@ def test_simple_cmd_ipython_and_fileop(temp_dir, runtime_cls, run_as_openhands):
|
||||
action_read = FileReadAction(path='hello.sh')
|
||||
logger.info(action_read, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action_read)
|
||||
assert isinstance(
|
||||
obs, FileReadObservation
|
||||
), 'The observation should be a FileReadObservation.'
|
||||
assert isinstance(obs, FileReadObservation), (
|
||||
'The observation should be a FileReadObservation.'
|
||||
)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert obs.content == 'echo "Hello, World!"\n'
|
||||
@@ -194,6 +194,52 @@ def test_ipython_simple(temp_dir, runtime_cls):
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_ipython_chdir(temp_dir, runtime_cls):
|
||||
"""Test that os.chdir correctly handles paths with slashes."""
|
||||
runtime, config = _load_runtime(temp_dir, runtime_cls)
|
||||
|
||||
# Create a test directory and get its absolute path
|
||||
test_code = """
|
||||
import os
|
||||
os.makedirs('test_dir', exist_ok=True)
|
||||
abs_path = os.path.abspath('test_dir')
|
||||
print(abs_path)
|
||||
"""
|
||||
action_ipython = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action_ipython, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action_ipython)
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
test_dir_path = obs.content.split('\n')[0].strip()
|
||||
logger.info(f'test_dir_path: {test_dir_path}')
|
||||
assert test_dir_path # Verify we got a valid path
|
||||
|
||||
# Change to the test directory using its absolute path
|
||||
test_code = f"""
|
||||
import os
|
||||
os.chdir(r'{test_dir_path}')
|
||||
print(os.getcwd())
|
||||
"""
|
||||
action_ipython = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action_ipython, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action_ipython)
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
current_dir = obs.content.split('\n')[0].strip()
|
||||
assert current_dir == test_dir_path # Verify we changed to the correct directory
|
||||
|
||||
# Clean up
|
||||
test_code = """
|
||||
import os
|
||||
import shutil
|
||||
shutil.rmtree('test_dir', ignore_errors=True)
|
||||
"""
|
||||
action_ipython = IPythonRunCellAction(code=test_code)
|
||||
logger.info(action_ipython, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action_ipython)
|
||||
assert isinstance(obs, IPythonRunCellObservation)
|
||||
|
||||
_close_test_runtime(runtime)
|
||||
|
||||
|
||||
def test_ipython_package_install(temp_dir, runtime_cls, run_as_openhands):
|
||||
"""Make sure that cd in bash also update the current working directory in ipython."""
|
||||
runtime, config = _load_runtime(temp_dir, runtime_cls, run_as_openhands)
|
||||
|
||||
@@ -39,9 +39,9 @@ def test_edit_from_scratch(temp_dir, runtime_cls, run_as_openhands):
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(
|
||||
obs, FileEditObservation
|
||||
), 'The observation should be a FileEditObservation.'
|
||||
assert isinstance(obs, FileEditObservation), (
|
||||
'The observation should be a FileEditObservation.'
|
||||
)
|
||||
|
||||
action = FileReadAction(
|
||||
path=os.path.join('/workspace', 'app.py'),
|
||||
@@ -78,9 +78,9 @@ def test_edit(temp_dir, runtime_cls, run_as_openhands):
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(
|
||||
obs, FileEditObservation
|
||||
), 'The observation should be a FileEditObservation.'
|
||||
assert isinstance(obs, FileEditObservation), (
|
||||
'The observation should be a FileEditObservation.'
|
||||
)
|
||||
|
||||
action = FileReadAction(
|
||||
path=os.path.join('/workspace', 'app.py'),
|
||||
@@ -138,9 +138,9 @@ def test_edit_long_file(temp_dir, runtime_cls, run_as_openhands):
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(
|
||||
obs, FileEditObservation
|
||||
), 'The observation should be a FileEditObservation.'
|
||||
assert isinstance(obs, FileEditObservation), (
|
||||
'The observation should be a FileEditObservation.'
|
||||
)
|
||||
|
||||
action = FileReadAction(
|
||||
path=os.path.join('/workspace', 'app.py'),
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
|
||||
import docker
|
||||
import pytest
|
||||
from conftest import (
|
||||
_load_runtime,
|
||||
@@ -10,7 +13,7 @@ from conftest import (
|
||||
|
||||
import openhands
|
||||
from openhands.core.config import MCPConfig
|
||||
from openhands.core.config.mcp_config import MCPStdioServerConfig
|
||||
from openhands.core.config.mcp_config import MCPSSEServerConfig, MCPStdioServerConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import CmdRunAction, MCPAction
|
||||
from openhands.events.observation import CmdOutputObservation, MCPObservation
|
||||
@@ -20,12 +23,90 @@ from openhands.events.observation import CmdOutputObservation, MCPObservation
|
||||
# ============================================================================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sse_mcp_docker_server():
|
||||
"""Manages the lifecycle of the SSE MCP Docker container for tests, using a random available port."""
|
||||
image_name = 'supercorp/supergateway'
|
||||
|
||||
# Find a free port
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
host_port = s.getsockname()[1]
|
||||
|
||||
container_internal_port = (
|
||||
8000 # The port the MCP server listens on *inside* the container
|
||||
)
|
||||
|
||||
container_command_args = [
|
||||
'--stdio',
|
||||
'npx -y @modelcontextprotocol/server-filesystem /',
|
||||
'--port',
|
||||
str(container_internal_port), # MCP server inside container listens on this
|
||||
'--baseUrl',
|
||||
f'http://localhost:{host_port}', # The URL used to access the server from the host
|
||||
]
|
||||
client = docker.from_env()
|
||||
container = None
|
||||
log_streamer = None
|
||||
|
||||
# Import LogStreamer here as it's specific to this fixture's needs
|
||||
from openhands.runtime.utils.log_streamer import LogStreamer
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f'Starting Docker container {image_name} with command: {" ".join(container_command_args)} '
|
||||
f'and mapping internal port {container_internal_port} to host port {host_port}',
|
||||
extra={'msg_type': 'ACTION'},
|
||||
)
|
||||
container = client.containers.run(
|
||||
image_name,
|
||||
command=container_command_args,
|
||||
ports={
|
||||
f'{container_internal_port}/tcp': host_port
|
||||
}, # Map container's internal port to the random host port
|
||||
detach=True,
|
||||
auto_remove=True,
|
||||
stdin_open=True,
|
||||
)
|
||||
logger.info(
|
||||
f'Container {container.short_id} started, listening on host port {host_port}.'
|
||||
)
|
||||
|
||||
log_streamer = LogStreamer(
|
||||
container,
|
||||
lambda level, msg: getattr(logger, level.lower())(
|
||||
f'[MCP server {container.short_id}] {msg}'
|
||||
),
|
||||
)
|
||||
# Wait for the server to initialize, as in the original tests
|
||||
time.sleep(10)
|
||||
|
||||
yield {'url': f'http://localhost:{host_port}/sse'}
|
||||
|
||||
finally:
|
||||
if container:
|
||||
logger.info(f'Stopping container {container.short_id}...')
|
||||
try:
|
||||
container.stop(timeout=5)
|
||||
logger.info(
|
||||
f'Container {container.short_id} stopped (and should be auto-removed).'
|
||||
)
|
||||
except docker.errors.NotFound:
|
||||
logger.info(
|
||||
f'Container {container.short_id} not found, likely already stopped and removed.'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error stopping container {container.short_id}: {e}')
|
||||
if log_streamer:
|
||||
log_streamer.close()
|
||||
|
||||
|
||||
def test_default_activated_tools():
|
||||
project_root = os.path.dirname(openhands.__file__)
|
||||
mcp_config_path = os.path.join(project_root, 'runtime', 'mcp', 'config.json')
|
||||
assert os.path.exists(
|
||||
mcp_config_path
|
||||
), f'MCP config file not found at {mcp_config_path}'
|
||||
assert os.path.exists(mcp_config_path), (
|
||||
f'MCP config file not found at {mcp_config_path}'
|
||||
)
|
||||
with open(mcp_config_path, 'r') as f:
|
||||
mcp_config = json.load(f)
|
||||
assert 'default' in mcp_config
|
||||
@@ -62,9 +143,9 @@ async def test_fetch_mcp_via_stdio(temp_dir, runtime_cls, run_as_openhands):
|
||||
mcp_action = MCPAction(name='fetch', arguments={'url': 'http://localhost:8000'})
|
||||
obs = await runtime.call_tool_mcp(mcp_action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(
|
||||
obs, MCPObservation
|
||||
), 'The observation should be a MCPObservation.'
|
||||
assert isinstance(obs, MCPObservation), (
|
||||
'The observation should be a MCPObservation.'
|
||||
)
|
||||
|
||||
result_json = json.loads(obs.content)
|
||||
assert not result_json['isError']
|
||||
@@ -76,3 +157,110 @@ async def test_fetch_mcp_via_stdio(temp_dir, runtime_cls, run_as_openhands):
|
||||
)
|
||||
|
||||
runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filesystem_mcp_via_sse(
|
||||
temp_dir, runtime_cls, run_as_openhands, sse_mcp_docker_server
|
||||
):
|
||||
sse_server_info = sse_mcp_docker_server
|
||||
sse_url = sse_server_info['url']
|
||||
runtime = None
|
||||
try:
|
||||
mcp_sse_server_config = MCPSSEServerConfig(url=sse_url)
|
||||
override_mcp_config = MCPConfig(sse_servers=[mcp_sse_server_config])
|
||||
runtime, config = _load_runtime(
|
||||
temp_dir,
|
||||
runtime_cls,
|
||||
run_as_openhands,
|
||||
override_mcp_config=override_mcp_config,
|
||||
)
|
||||
|
||||
mcp_action = MCPAction(name='list_directory', arguments={'path': '.'})
|
||||
obs = await runtime.call_tool_mcp(mcp_action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs, MCPObservation), (
|
||||
'The observation should be a MCPObservation.'
|
||||
)
|
||||
assert '[FILE] .dockerenv' in obs.content
|
||||
|
||||
finally:
|
||||
if runtime:
|
||||
runtime.close()
|
||||
# Container and log_streamer cleanup is handled by the sse_mcp_docker_server fixture
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_stdio_and_sse_mcp(
|
||||
temp_dir, runtime_cls, run_as_openhands, sse_mcp_docker_server
|
||||
):
|
||||
sse_server_info = sse_mcp_docker_server
|
||||
sse_url = sse_server_info['url']
|
||||
runtime = None
|
||||
try:
|
||||
mcp_sse_server_config = MCPSSEServerConfig(url=sse_url)
|
||||
|
||||
# Also add stdio server
|
||||
mcp_stdio_server_config = MCPStdioServerConfig(
|
||||
name='fetch', command='uvx', args=['mcp-server-fetch']
|
||||
)
|
||||
|
||||
override_mcp_config = MCPConfig(
|
||||
sse_servers=[mcp_sse_server_config], stdio_servers=[mcp_stdio_server_config]
|
||||
)
|
||||
runtime, config = _load_runtime(
|
||||
temp_dir,
|
||||
runtime_cls,
|
||||
run_as_openhands,
|
||||
override_mcp_config=override_mcp_config,
|
||||
)
|
||||
|
||||
# ======= Test SSE server =======
|
||||
mcp_action_sse = MCPAction(name='list_directory', arguments={'path': '.'})
|
||||
obs_sse = await runtime.call_tool_mcp(mcp_action_sse)
|
||||
logger.info(obs_sse, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs_sse, MCPObservation), (
|
||||
'The observation should be a MCPObservation.'
|
||||
)
|
||||
assert '[FILE] .dockerenv' in obs_sse.content
|
||||
|
||||
# ======= Test stdio server =======
|
||||
# Test browser server
|
||||
action_cmd_http = CmdRunAction(
|
||||
command='python3 -m http.server 8000 > server.log 2>&1 &'
|
||||
)
|
||||
logger.info(action_cmd_http, extra={'msg_type': 'ACTION'})
|
||||
obs_http = runtime.run_action(action_cmd_http)
|
||||
logger.info(obs_http, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert isinstance(obs_http, CmdOutputObservation)
|
||||
assert obs_http.exit_code == 0
|
||||
assert '[1]' in obs_http.content
|
||||
|
||||
action_cmd_cat = CmdRunAction(command='sleep 3 && cat server.log')
|
||||
logger.info(action_cmd_cat, extra={'msg_type': 'ACTION'})
|
||||
obs_cat = runtime.run_action(action_cmd_cat)
|
||||
logger.info(obs_cat, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs_cat.exit_code == 0
|
||||
|
||||
mcp_action_fetch = MCPAction(
|
||||
name='fetch', arguments={'url': 'http://localhost:8000'}
|
||||
)
|
||||
obs_fetch = await runtime.call_tool_mcp(mcp_action_fetch)
|
||||
logger.info(obs_fetch, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(obs_fetch, MCPObservation), (
|
||||
'The observation should be a MCPObservation.'
|
||||
)
|
||||
|
||||
result_json = json.loads(obs_fetch.content)
|
||||
assert not result_json['isError']
|
||||
assert len(result_json['content']) == 1
|
||||
assert result_json['content'][0]['type'] == 'text'
|
||||
assert (
|
||||
result_json['content'][0]['text']
|
||||
== 'Contents of http://localhost:8000/:\n---\n\n* <server.log>\n\n---'
|
||||
)
|
||||
finally:
|
||||
if runtime:
|
||||
runtime.close()
|
||||
# SSE Docker container cleanup is handled by the sse_mcp_docker_server fixture
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Stress tests for the DockerRuntime, which connects to the ActionExecutor running in the sandbox."""
|
||||
|
||||
import pytest
|
||||
from conftest import _close_test_runtime, _load_runtime
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -7,6 +8,7 @@ from openhands.events.action import CmdRunAction
|
||||
|
||||
|
||||
def test_stress_docker_runtime(temp_dir, runtime_cls, repeat=1):
|
||||
pytest.skip('This test is flaky')
|
||||
runtime, config = _load_runtime(
|
||||
temp_dir,
|
||||
runtime_cls,
|
||||
|
||||
@@ -468,9 +468,9 @@ def test_stress_runtime_memory_limits_with_repeated_file_edit():
|
||||
new_str=f'-content_{i:03d}',
|
||||
)
|
||||
obs = runtime.run_action(edit_action)
|
||||
assert (
|
||||
f'The file {test_file} has been edited' in obs.content
|
||||
), f'Edit failed at iteration {i}'
|
||||
assert f'The file {test_file} has been edited' in obs.content, (
|
||||
f'Edit failed at iteration {i}'
|
||||
)
|
||||
logger.info(f'finished iteration {i}')
|
||||
|
||||
# Verify final file state using FileEditAction view command
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user