mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 497fd4a02c | |||
| 7fac7d6dd0 | |||
| 4155b8f801 | |||
| e712b013f9 | |||
| 29b137e9b1 | |||
| 1292f0c2ea | |||
| da8c946078 | |||
| 43cef1f969 | |||
| ab661b485b | |||
| 9632914bf0 | |||
| 3db780ef93 | |||
| 43c16516e8 |
@@ -316,10 +316,6 @@ llm_config = 'gpt3'
|
||||
# Additional Docker runtime kwargs
|
||||
#docker_runtime_kwargs = {}
|
||||
|
||||
# Specific port to use for VSCode. If not set, a random port will be chosen.
|
||||
# Useful when deploying OpenHands in a remote machine where you need to expose a specific port.
|
||||
#vscode_port = 41234
|
||||
|
||||
#################################### Security ###################################
|
||||
# Configuration for security features
|
||||
##############################################################################
|
||||
|
||||
@@ -4,38 +4,6 @@
|
||||
OpenHands only supports Windows via WSL. Please be sure to run all commands inside your WSL terminal.
|
||||
:::
|
||||
|
||||
### Unable to access VS Code tab via local IP
|
||||
|
||||
**Description**
|
||||
|
||||
When accessing OpenHands through a non-localhost URL (such as a LAN IP address), the VS Code tab shows a "Forbidden" error, while other parts of the UI work fine.
|
||||
|
||||
**Resolution**
|
||||
|
||||
This happens because VS Code runs on a random high port that may not be exposed or accessible from other machines. To fix this:
|
||||
|
||||
1. Set a specific port for VS Code using the `SANDBOX_VSCODE_PORT` environment variable:
|
||||
```bash
|
||||
docker run -it --rm \
|
||||
-e SANDBOX_VSCODE_PORT=41234 \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:latest \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-v ~/.openhands-state:/.openhands-state \
|
||||
-p 3000:3000 \
|
||||
-p 41234:41234 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:latest
|
||||
```
|
||||
|
||||
2. Make sure to expose the same port with `-p 41234:41234` in your Docker command.
|
||||
|
||||
3. Alternatively, you can set this in your `config.toml` file:
|
||||
```toml
|
||||
[sandbox]
|
||||
vscode_port = 41234
|
||||
```
|
||||
|
||||
### Launch docker client failed
|
||||
|
||||
**Description**
|
||||
|
||||
@@ -89,19 +89,8 @@ describe("Settings Billing", () => {
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
// Instead of looking for exact text, we'll check if any element contains "Credits"
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
|
||||
// Wait for the component to render fully
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Get all text elements and check if any contain "Credits"
|
||||
const allElements = within(navbar).queryAllByText(/./i);
|
||||
const hasCreditsTab = allElements.some(el =>
|
||||
el.textContent && el.textContent.toLowerCase().includes("credits")
|
||||
);
|
||||
|
||||
expect(hasCreditsTab).toBe(true);
|
||||
within(navbar).getByText("Credits");
|
||||
});
|
||||
|
||||
it("should render the billing settings if clicking the credits item", async () => {
|
||||
@@ -119,28 +108,10 @@ describe("Settings Billing", () => {
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
|
||||
// Wait for the component to render fully
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Find all links in the navbar
|
||||
const navLinks = navbar.querySelectorAll('a');
|
||||
|
||||
// Find the credits link by checking the href
|
||||
const creditsLink = Array.from(navLinks).find(link =>
|
||||
link.getAttribute('href')?.includes('/settings/credits') ||
|
||||
link.textContent?.toLowerCase().includes('credits')
|
||||
);
|
||||
|
||||
// Make sure we found the credits link
|
||||
expect(creditsLink).toBeTruthy();
|
||||
|
||||
// Click the credits link if found
|
||||
if (creditsLink) {
|
||||
await user.click(creditsLink);
|
||||
|
||||
const billingSection = await screen.findByTestId("billing-settings");
|
||||
expect(billingSection).toBeInTheDocument();
|
||||
}
|
||||
const credits = within(navbar).getByText("Credits");
|
||||
await user.click(credits);
|
||||
|
||||
const billingSection = await screen.findByTestId("billing-settings");
|
||||
expect(billingSection).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -118,30 +118,17 @@ describe("Settings Screen", () => {
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
|
||||
// Wait for the component to render fully
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Get all text elements in the navbar
|
||||
const allElements = navbar.querySelectorAll('a span');
|
||||
const allText = Array.from(allElements).map(el => el.textContent?.toLowerCase() || '');
|
||||
|
||||
// Check that each section to include has a matching element
|
||||
sectionsToInclude.forEach((section) => {
|
||||
const hasSection = allText.some(text =>
|
||||
text.includes(section.toLowerCase())
|
||||
) || Array.from(navbar.querySelectorAll('a')).some(link =>
|
||||
link.getAttribute('href')?.toLowerCase().includes(section.toLowerCase())
|
||||
);
|
||||
expect(hasSection).toBe(true);
|
||||
const sectionElement = within(navbar).getByText(section, {
|
||||
exact: false, // case insensitive
|
||||
});
|
||||
expect(sectionElement).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Check that each section to exclude does not have a matching element
|
||||
sectionsToExclude.forEach((section) => {
|
||||
const hasSection = allText.some(text =>
|
||||
text.includes(section.toLowerCase())
|
||||
);
|
||||
expect(hasSection).toBe(false);
|
||||
const sectionElement = within(navbar).queryByText(section, {
|
||||
exact: false, // case insensitive
|
||||
});
|
||||
expect(sectionElement).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import { useEffect } from "react";
|
||||
import { openHands } from "#/api/open-hands-axios";
|
||||
import { useLogoutHandler } from "#/hooks/useLogoutHandler";
|
||||
|
||||
interface AxiosInterceptorSetupProps {
|
||||
appMode?: string;
|
||||
}
|
||||
|
||||
export function AxiosInterceptorSetup({ appMode }: AxiosInterceptorSetupProps) {
|
||||
const handleLogoutAndRefresh = useLogoutHandler(appMode);
|
||||
|
||||
useEffect(() => {
|
||||
const interceptor = openHands.interceptors.response.use(
|
||||
(response) => response,
|
||||
async (error) => {
|
||||
if (
|
||||
error.response &&
|
||||
error.response.status === 401 &&
|
||||
localStorage.getItem("providersAreSet") === "true"
|
||||
) {
|
||||
await handleLogoutAndRefresh();
|
||||
}
|
||||
|
||||
return Promise.reject(error);
|
||||
},
|
||||
);
|
||||
|
||||
return () => {
|
||||
openHands.interceptors.response.eject(interceptor);
|
||||
};
|
||||
}, [handleLogoutAndRefresh]);
|
||||
|
||||
return null; // It's a logical component
|
||||
}
|
||||
@@ -28,11 +28,6 @@ function AuthProvider({
|
||||
initialProvidersAreSet,
|
||||
);
|
||||
|
||||
// Update localStorage when providersAreSet changes
|
||||
React.useEffect(() => {
|
||||
localStorage.setItem("providersAreSet", providersAreSet.toString());
|
||||
}, [providersAreSet]);
|
||||
|
||||
const value = React.useMemo(
|
||||
() => ({
|
||||
providerTokensSet,
|
||||
@@ -40,10 +35,10 @@ function AuthProvider({
|
||||
providersAreSet,
|
||||
setProvidersAreSet,
|
||||
}),
|
||||
[providerTokensSet, providersAreSet],
|
||||
[providerTokensSet],
|
||||
);
|
||||
|
||||
return <AuthContext.Provider value={value}>{children}</AuthContext.Provider>;
|
||||
return <AuthContext value={value}>{children}</AuthContext>;
|
||||
}
|
||||
|
||||
function useAuth() {
|
||||
|
||||
@@ -17,22 +17,19 @@ import { AuthProvider } from "./context/auth-context";
|
||||
import { queryClientConfig } from "./query-client-config";
|
||||
import OpenHands from "./api/open-hands";
|
||||
import { displayErrorToast } from "./utils/custom-toast-handlers";
|
||||
import { AxiosInterceptorSetup } from "./components/AxiosInterceptorSetup";
|
||||
|
||||
function AppInitializers() {
|
||||
function PosthogInit() {
|
||||
const [posthogClientKey, setPosthogClientKey] = React.useState<string | null>(
|
||||
null,
|
||||
);
|
||||
const [appMode, setAppMode] = React.useState<string | undefined>(undefined);
|
||||
|
||||
React.useEffect(() => {
|
||||
(async () => {
|
||||
try {
|
||||
const config = await OpenHands.getConfig();
|
||||
setPosthogClientKey(config.POSTHOG_CLIENT_KEY);
|
||||
setAppMode(config.APP_MODE);
|
||||
} catch (error) {
|
||||
displayErrorToast("Error fetching app configuration");
|
||||
displayErrorToast("Error fetching PostHog client key");
|
||||
}
|
||||
})();
|
||||
}, []);
|
||||
@@ -46,7 +43,7 @@ function AppInitializers() {
|
||||
}
|
||||
}, [posthogClientKey]);
|
||||
|
||||
return appMode ? <AxiosInterceptorSetup appMode={appMode} /> : null;
|
||||
return null;
|
||||
}
|
||||
|
||||
async function prepareApp() {
|
||||
@@ -73,7 +70,7 @@ prepareApp().then(() =>
|
||||
<AuthProvider>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<HydratedRouter />
|
||||
<AppInitializers />
|
||||
<PosthogInit />
|
||||
</QueryClientProvider>
|
||||
</AuthProvider>
|
||||
</Provider>
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
import React from "react";
|
||||
import { createLogoutHandler } from "#/utils/auth-utils";
|
||||
|
||||
export const useLogoutHandler = (appMode?: string) =>
|
||||
React.useMemo(() => createLogoutHandler(appMode), [appMode]);
|
||||
@@ -1,27 +0,0 @@
|
||||
/**
|
||||
* Utility functions for authentication
|
||||
*/
|
||||
|
||||
/**
|
||||
* Creates a logout handler function
|
||||
* @param appMode The current app mode
|
||||
* @returns A function that handles logout and browser refresh
|
||||
*/
|
||||
export const createLogoutHandler =
|
||||
(appMode: string | undefined) => async (): Promise<void> => {
|
||||
if (appMode === "saas") {
|
||||
try {
|
||||
const baseURL = `${window.location.protocol}//${
|
||||
import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host
|
||||
}`;
|
||||
await fetch(`${baseURL}/api/logout`, {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
} catch (error) {
|
||||
// Error during logout is not critical as we'll refresh anyway
|
||||
} finally {
|
||||
window.location.reload();
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -162,7 +162,7 @@ class BrowsingAgent(Agent):
|
||||
last_action = event
|
||||
elif isinstance(event, MessageAction) and event.source == EventSource.AGENT:
|
||||
# agent has responded, task finished.
|
||||
return AgentFinishAction(outputs={'content': event.content})
|
||||
return AgentFinishAction(final_thought=event.content)
|
||||
elif isinstance(event, Observation):
|
||||
last_obs = event
|
||||
|
||||
@@ -201,10 +201,8 @@ class BrowsingAgent(Agent):
|
||||
)
|
||||
return MessageAction('Error encountered when browsing.')
|
||||
|
||||
goal, _ = state.get_current_user_intent()
|
||||
|
||||
if goal is None:
|
||||
goal = state.inputs['task']
|
||||
user_message_action = state.get_current_user_intent()
|
||||
goal = user_message_action.content
|
||||
|
||||
system_msg = get_system_message(
|
||||
goal,
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import copy
|
||||
import os
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import ChatCompletionToolParam
|
||||
from openhands.events.action import Action
|
||||
from openhands.llm.llm import ModelResponse
|
||||
from litellm import ChatCompletionToolParam
|
||||
|
||||
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
|
||||
from openhands.agenthub.codeact_agent.tools.bash import create_cmd_run_tool
|
||||
@@ -24,7 +20,7 @@ from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message
|
||||
from openhands.events.action import AgentFinishAction, MessageAction
|
||||
from openhands.events.action import Action, AgentFinishAction, MessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.condenser import Condenser
|
||||
@@ -79,26 +75,23 @@ class CodeActAgent(Agent):
|
||||
- config (AgentConfig): The configuration for this agent
|
||||
"""
|
||||
super().__init__(llm, config)
|
||||
self.pending_actions: deque['Action'] = deque()
|
||||
self.pending_actions: deque[Action] = deque()
|
||||
self.reset()
|
||||
self.tools = self._get_tools()
|
||||
|
||||
self.prompt_manager = PromptManager(
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
)
|
||||
|
||||
# Create a ConversationMemory instance
|
||||
self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)
|
||||
|
||||
self.condenser = Condenser.from_config(self.config.condenser)
|
||||
logger.debug(f'Using condenser: {type(self.condenser)}')
|
||||
|
||||
@property
|
||||
def prompt_manager(self) -> PromptManager:
|
||||
if self._prompt_manager is None:
|
||||
self._prompt_manager = PromptManager(
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
)
|
||||
self.response_to_actions_fn = codeact_function_calling.response_to_actions
|
||||
|
||||
return self._prompt_manager
|
||||
|
||||
def _get_tools(self) -> list['ChatCompletionToolParam']:
|
||||
def _get_tools(self) -> list[ChatCompletionToolParam]:
|
||||
# For these models, we use short tool descriptions ( < 1024 tokens)
|
||||
# to avoid hitting the OpenAI token limit for tool descriptions.
|
||||
SHORT_TOOL_DESCRIPTION_LLM_SUBSTRS = ['gpt-', 'o3', 'o1', 'o4']
|
||||
@@ -137,7 +130,7 @@ class CodeActAgent(Agent):
|
||||
super().reset()
|
||||
self.pending_actions.clear()
|
||||
|
||||
def step(self, state: State) -> 'Action':
|
||||
def step(self, state: State) -> Action:
|
||||
"""Performs one step using the CodeAct Agent.
|
||||
|
||||
This includes gathering info on previous steps and prompting the model to make a command to execute.
|
||||
@@ -205,7 +198,9 @@ class CodeActAgent(Agent):
|
||||
params['extra_body'] = {'metadata': state.to_llm_metadata(agent_name=self.name)}
|
||||
response = self.llm.completion(**params)
|
||||
logger.debug(f'Response from LLM: {response}')
|
||||
actions = self.response_to_actions(response)
|
||||
actions = self.response_to_actions_fn(
|
||||
response, mcp_tool_names=list(self.mcp_tools.keys())
|
||||
)
|
||||
logger.debug(f'Actions after response_to_actions: {actions}')
|
||||
for action in actions:
|
||||
self.pending_actions.append(action)
|
||||
@@ -279,8 +274,3 @@ class CodeActAgent(Agent):
|
||||
self.conversation_memory.apply_prompt_caching(messages)
|
||||
|
||||
return messages
|
||||
|
||||
def response_to_actions(self, response: 'ModelResponse') -> list['Action']:
|
||||
return codeact_function_calling.response_to_actions(
|
||||
response, mcp_tool_names=list(self.mcp_tools.keys())
|
||||
)
|
||||
|
||||
@@ -105,7 +105,8 @@ def response_to_actions(
|
||||
elif tool_call.function.name == 'delegate_to_browsing_agent':
|
||||
action = AgentDelegateAction(
|
||||
agent='BrowsingAgent',
|
||||
inputs=arguments,
|
||||
prompt=arguments.get('prompt', ''),
|
||||
inputs={},
|
||||
)
|
||||
|
||||
# ================================================
|
||||
@@ -113,8 +114,10 @@ def response_to_actions(
|
||||
# ================================================
|
||||
elif tool_call.function.name == FinishTool['function']['name']:
|
||||
action = AgentFinishAction(
|
||||
final_thought=arguments.get('message', ''),
|
||||
outputs=arguments.get('outputs', {}),
|
||||
thought=arguments.get('thought', ''),
|
||||
task_completed=arguments.get('task_completed', None),
|
||||
final_thought=arguments.get('final_thought', ''),
|
||||
)
|
||||
|
||||
# ================================================
|
||||
|
||||
@@ -4,13 +4,6 @@ ReadOnlyAgent - A specialized version of CodeActAgent that only uses read-only t
|
||||
|
||||
import os
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import ChatCompletionToolParam
|
||||
from openhands.events.action import Action
|
||||
from openhands.llm.llm import ModelResponse
|
||||
|
||||
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from openhands.agenthub.readonly_agent import (
|
||||
function_calling as readonly_function_calling,
|
||||
@@ -48,27 +41,24 @@ class ReadOnlyAgent(CodeActAgent):
|
||||
- llm (LLM): The llm to be used by this agent
|
||||
- config (AgentConfig): The configuration for this agent
|
||||
"""
|
||||
# Initialize the CodeActAgent class; some of it is overridden with class methods
|
||||
# Initialize the CodeActAgent class but we'll override some of its behavior
|
||||
super().__init__(llm, config)
|
||||
|
||||
# Override the tools to only include read-only tools
|
||||
# Get the read-only tools from our own function_calling module
|
||||
self.tools = readonly_function_calling.get_tools()
|
||||
|
||||
# Set up our own prompt manager
|
||||
self.prompt_manager = PromptManager(
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
)
|
||||
|
||||
self.response_to_actions_fn = readonly_function_calling.response_to_actions
|
||||
|
||||
logger.debug(
|
||||
f"TOOLS loaded for ReadOnlyAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}"
|
||||
)
|
||||
|
||||
@property
|
||||
def prompt_manager(self) -> PromptManager:
|
||||
# Set up our own prompt manager
|
||||
if self._prompt_manager is None:
|
||||
self._prompt_manager = PromptManager(
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
)
|
||||
return self._prompt_manager
|
||||
|
||||
def _get_tools(self) -> list['ChatCompletionToolParam']:
|
||||
# Override the tools to only include read-only tools
|
||||
# Get the read-only tools from our own function_calling module
|
||||
return readonly_function_calling.get_tools()
|
||||
|
||||
def set_mcp_tools(self, mcp_tools: list[dict]) -> None:
|
||||
"""Sets the list of MCP tools for the agent.
|
||||
|
||||
@@ -78,8 +68,3 @@ class ReadOnlyAgent(CodeActAgent):
|
||||
logger.warning(
|
||||
'ReadOnlyAgent does not support MCP tools. MCP tools will be ignored by the agent.'
|
||||
)
|
||||
|
||||
def response_to_actions(self, response: 'ModelResponse') -> list['Action']:
|
||||
return readonly_function_calling.response_to_actions(
|
||||
response, mcp_tool_names=list(self.mcp_tools.keys())
|
||||
)
|
||||
|
||||
@@ -216,7 +216,7 @@ Note:
|
||||
last_action = event
|
||||
elif isinstance(event, MessageAction) and event.source == EventSource.AGENT:
|
||||
# agent has responded, task finished.
|
||||
return AgentFinishAction(outputs={'content': event.content})
|
||||
return AgentFinishAction(final_thought=event.content)
|
||||
elif isinstance(event, Observation):
|
||||
# Only process BrowserOutputObservation and skip other observation types
|
||||
if not isinstance(event, BrowserOutputObservation):
|
||||
@@ -271,10 +271,10 @@ Note:
|
||||
)
|
||||
return MessageAction('Error encountered when browsing.')
|
||||
set_of_marks = last_obs.set_of_marks
|
||||
goal, image_urls = state.get_current_user_intent()
|
||||
user_message_action = state.get_current_user_intent()
|
||||
goal = user_message_action.content
|
||||
image_urls = user_message_action.image_urls
|
||||
|
||||
if goal is None:
|
||||
goal = state.inputs['task']
|
||||
goal_txt, goal_images = create_goal_prompt(goal, image_urls)
|
||||
observation_txt, som_screenshot = create_observation_prompt(
|
||||
cur_axtree_txt, tabs, focused_element, error_prefix, set_of_marks
|
||||
|
||||
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.utils.prompt import PromptManager
|
||||
from litellm import ChatCompletionToolParam
|
||||
|
||||
from openhands.core.exceptions import (
|
||||
@@ -20,6 +19,9 @@ from openhands.events.event import EventSource
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
class Agent(ABC):
|
||||
DEPRECATED = False
|
||||
@@ -41,16 +43,10 @@ class Agent(ABC):
|
||||
self.llm = llm
|
||||
self.config = config
|
||||
self._complete = False
|
||||
self._prompt_manager: 'PromptManager' | None = None
|
||||
self.prompt_manager: 'PromptManager' | None = None
|
||||
self.mcp_tools: dict[str, ChatCompletionToolParam] = {}
|
||||
self.tools: list = []
|
||||
|
||||
@property
|
||||
def prompt_manager(self) -> 'PromptManager':
|
||||
if self._prompt_manager is None:
|
||||
raise ValueError(f'Prompt manager not initialized for agent {self.name}')
|
||||
return self._prompt_manager
|
||||
|
||||
def get_system_message(self) -> 'SystemMessageAction | None':
|
||||
"""
|
||||
Returns a SystemMessageAction containing the system message and tools.
|
||||
|
||||
@@ -438,12 +438,13 @@ class AgentController:
|
||||
elif isinstance(action, AgentDelegateAction):
|
||||
await self.start_delegate(action)
|
||||
assert self.delegate is not None
|
||||
# Post a MessageAction with the task for the delegate
|
||||
if 'task' in action.inputs:
|
||||
# Post a MessageAction with the prompt for the delegate
|
||||
if action.prompt:
|
||||
self.event_stream.add_event(
|
||||
MessageAction(content='TASK: ' + action.inputs['task']),
|
||||
EventSource.USER,
|
||||
MessageAction(content=action.prompt),
|
||||
EventSource.USER, # Source is USER, as it represents the task prompt for the delegate
|
||||
)
|
||||
# Delegate starts in RUNNING state as it receives the prompt immediately
|
||||
await self.delegate.set_agent_state_to(AgentState.RUNNING)
|
||||
return
|
||||
|
||||
@@ -727,34 +728,22 @@ class AgentController:
|
||||
# close the delegate controller before adding new events
|
||||
asyncio.get_event_loop().run_until_complete(self.delegate.close())
|
||||
|
||||
if delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
# retrieve delegate result
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
# prepare delegate result observation
|
||||
delegate_outputs = self.delegate.state.outputs if self.delegate.state else {}
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in delegate_outputs.items()
|
||||
)
|
||||
|
||||
# prepare delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in delegate_outputs.items()
|
||||
)
|
||||
if delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
)
|
||||
else:
|
||||
# delegate state is ERROR
|
||||
# emit AgentDelegateObservation with error content
|
||||
delegate_outputs = (
|
||||
self.delegate.state.outputs if self.delegate.state else {}
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} encountered an error during execution.'
|
||||
)
|
||||
|
||||
content = f'Delegated agent finished with result:\n\n{content}'
|
||||
content = f'{self.delegate.agent.name} encountered an error during execution. Known results: {delegate_outputs}'
|
||||
|
||||
# emit the delegate result observation
|
||||
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
|
||||
obs = AgentDelegateObservation(content=content, outputs={})
|
||||
|
||||
# associate the delegate action with the initiating tool call
|
||||
for event in reversed(self.state.history):
|
||||
|
||||
@@ -188,19 +188,39 @@ class State:
|
||||
if not hasattr(self, 'history'):
|
||||
self.history = []
|
||||
|
||||
def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
|
||||
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
last_user_message = None
|
||||
last_user_message_image_urls: list[str] | None = []
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
last_user_message = event.content
|
||||
last_user_message_image_urls = event.image_urls
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
if last_user_message is not None:
|
||||
return last_user_message, None
|
||||
def get_current_user_intent(self) -> MessageAction:
|
||||
"""Returns the latest user MessageAction that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
likely_task: MessageAction | None = None
|
||||
|
||||
return last_user_message, last_user_message_image_urls
|
||||
# Search in the view for the latest user message after the last finish action
|
||||
for event in reversed(self.view):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
likely_task = event
|
||||
elif isinstance(event, AgentFinishAction):
|
||||
# If a FinishAction is found, the user message after it is the one we just found (if any)
|
||||
break
|
||||
|
||||
# If a user message was found in the view after the last finish action, return it
|
||||
if likely_task is not None:
|
||||
return likely_task
|
||||
|
||||
# If no user message was found in the view after the last finish action,
|
||||
# it means either there were no user messages in the view, or the last event in the view was a FinishAction
|
||||
# In this case, we fall back to finding the very first user message in the full history.
|
||||
logger.warning(
|
||||
'No user message found in the view after the last FinishAction. Returning the first message in history.'
|
||||
)
|
||||
if self.history:
|
||||
# Look for the very first user message in the full history
|
||||
for event in self.history:
|
||||
if (
|
||||
isinstance(event, MessageAction)
|
||||
and event.source == EventSource.USER
|
||||
):
|
||||
return event
|
||||
|
||||
# If no user message is found in the entire history, raise an error
|
||||
raise ValueError('No user message found in history. This should not happen.')
|
||||
|
||||
def get_last_agent_message(self) -> MessageAction | None:
|
||||
for event in reversed(self.view):
|
||||
|
||||
@@ -6,11 +6,13 @@ from uuid import uuid4
|
||||
from prompt_toolkit.shortcuts import clear
|
||||
|
||||
import openhands.agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from openhands.cli.commands import (
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.cli_commands import (
|
||||
check_folder_security_agreement,
|
||||
handle_commands,
|
||||
)
|
||||
from openhands.cli.tui import (
|
||||
from openhands.core.cli_tui import (
|
||||
UsageMetrics,
|
||||
display_agent_running_message,
|
||||
display_banner,
|
||||
@@ -23,11 +25,9 @@ from openhands.cli.tui import (
|
||||
read_confirmation_input,
|
||||
read_prompt_input,
|
||||
)
|
||||
from openhands.cli.utils import (
|
||||
from openhands.core.cli_utils import (
|
||||
update_usage_metrics,
|
||||
)
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import (
|
||||
AppConfig,
|
||||
parse_arguments,
|
||||
@@ -5,12 +5,12 @@ from prompt_toolkit import print_formatted_text
|
||||
from prompt_toolkit.shortcuts import clear, print_container
|
||||
from prompt_toolkit.widgets import Frame, TextArea
|
||||
|
||||
from openhands.cli.settings import (
|
||||
from openhands.core.cli_settings import (
|
||||
display_settings,
|
||||
modify_llm_settings_advanced,
|
||||
modify_llm_settings_basic,
|
||||
)
|
||||
from openhands.cli.tui import (
|
||||
from openhands.core.cli_tui import (
|
||||
COLOR_GREY,
|
||||
UsageMetrics,
|
||||
cli_confirm,
|
||||
@@ -18,7 +18,7 @@ from openhands.cli.tui import (
|
||||
display_shutdown_message,
|
||||
display_status,
|
||||
)
|
||||
from openhands.cli.utils import (
|
||||
from openhands.core.cli_utils import (
|
||||
add_local_config_trusted_dir,
|
||||
get_local_config_trusted_dirs,
|
||||
read_file,
|
||||
@@ -5,19 +5,19 @@ from prompt_toolkit.shortcuts import print_container
|
||||
from prompt_toolkit.widgets import Frame, TextArea
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.cli.tui import (
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.cli_tui import (
|
||||
COLOR_GREY,
|
||||
UserCancelledError,
|
||||
cli_confirm,
|
||||
kb_cancel,
|
||||
)
|
||||
from openhands.cli.utils import (
|
||||
from openhands.core.cli_utils import (
|
||||
VERIFIED_ANTHROPIC_MODELS,
|
||||
VERIFIED_OPENAI_MODELS,
|
||||
VERIFIED_PROVIDERS,
|
||||
organize_models_and_providers,
|
||||
)
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.config.condenser_config import NoOpCondenserConfig
|
||||
from openhands.core.config.utils import OH_DEFAULT_AGENT
|
||||
@@ -3,7 +3,7 @@ from typing import Dict, List
|
||||
|
||||
import toml
|
||||
|
||||
from openhands.cli.tui import (
|
||||
from openhands.core.cli_tui import (
|
||||
UsageMetrics,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
@@ -39,8 +39,6 @@ class SandboxConfig(BaseModel):
|
||||
docker_runtime_kwargs: Additional keyword arguments to pass to the Docker runtime when running containers.
|
||||
This should be a JSON string that will be parsed into a dictionary.
|
||||
trusted_dirs: List of directories that can be trusted to run the OpenHands CLI.
|
||||
vscode_port: The port to use for VSCode. If None, a random port will be chosen.
|
||||
This is useful when deploying OpenHands in a remote machine where you need to expose a specific port.
|
||||
"""
|
||||
|
||||
remote_runtime_api_url: str | None = Field(default='http://localhost:8000')
|
||||
@@ -79,7 +77,6 @@ class SandboxConfig(BaseModel):
|
||||
docker_runtime_kwargs: dict | None = Field(default=None)
|
||||
selected_repo: str | None = Field(default=None)
|
||||
trusted_dirs: list[str] = Field(default_factory=list)
|
||||
vscode_port: int | None = Field(default=None)
|
||||
|
||||
model_config = {'extra': 'forbid'}
|
||||
|
||||
|
||||
@@ -86,6 +86,13 @@ class AgentRejectAction(Action):
|
||||
class AgentDelegateAction(Action):
|
||||
agent: str
|
||||
inputs: dict
|
||||
"""Deprecated.
|
||||
Delegate agents run similarly to the main agent:
|
||||
- start from a prompt (passed in the 'prompt' field)
|
||||
- end with an AgentFinishAction.
|
||||
"""
|
||||
prompt: str
|
||||
"""The prompt/task for the delegate agent"""
|
||||
thought: str = ''
|
||||
action: str = ActionType.DELEGATE
|
||||
|
||||
|
||||
@@ -10,13 +10,18 @@ class AgentDelegateObservation(Observation):
|
||||
|
||||
Attributes:
|
||||
content (str): The content of the observation.
|
||||
outputs (dict): The outputs of the delegated agent.
|
||||
outputs (dict): The outputs of the delegated agent. (deprecated)
|
||||
observation (str): The type of observation.
|
||||
"""
|
||||
|
||||
outputs: dict
|
||||
"""Deprecated.
|
||||
Delegate agents run similarly to the main agent:
|
||||
- start from a prompt (passed in the 'prompt' field)
|
||||
- end with an AgentFinishAction.
|
||||
"""
|
||||
observation: str = ObservationType.DELEGATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return ''
|
||||
return self.content
|
||||
|
||||
@@ -19,9 +19,6 @@ from openhands.integrations.service_types import (
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.integrations.github.queries import suggested_task_pr_graphql_query, suggested_task_issue_graphql_query
|
||||
from datetime import datetime
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class GitHubService(BaseGitService, GitService):
|
||||
@@ -47,9 +44,6 @@ class GitHubService(BaseGitService, GitService):
|
||||
if base_domain:
|
||||
self.BASE_URL = f'https://{base_domain}/api/v3'
|
||||
|
||||
self.external_auth_id = external_auth_id
|
||||
self.external_auth_token = external_auth_token
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return ProviderType.GITHUB.value
|
||||
@@ -290,21 +284,60 @@ class GitHubService(BaseGitService, GitService):
|
||||
Returns:
|
||||
- PRs authored by the user.
|
||||
- Issues assigned to the user.
|
||||
|
||||
Note: Queries are split to avoid timeout issues.
|
||||
"""
|
||||
# Get user info to use in queries
|
||||
user = await self.get_user()
|
||||
login = user.login
|
||||
tasks: list[SuggestedTask] = []
|
||||
|
||||
query = """
|
||||
query GetUserTasks($login: String!) {
|
||||
user(login: $login) {
|
||||
pullRequests(first: 100, states: [OPEN], orderBy: {field: UPDATED_AT, direction: DESC}) {
|
||||
nodes {
|
||||
number
|
||||
title
|
||||
repository {
|
||||
nameWithOwner
|
||||
}
|
||||
mergeable
|
||||
commits(last: 1) {
|
||||
nodes {
|
||||
commit {
|
||||
statusCheckRollup {
|
||||
state
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
reviews(first: 100, states: [CHANGES_REQUESTED, COMMENTED]) {
|
||||
nodes {
|
||||
state
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
issues(first: 100, states: [OPEN], filterBy: {assignee: $login}, orderBy: {field: UPDATED_AT, direction: DESC}) {
|
||||
nodes {
|
||||
number
|
||||
title
|
||||
repository {
|
||||
nameWithOwner
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
variables = {'login': login}
|
||||
|
||||
try:
|
||||
pr_response = await self.execute_graphql_query(suggested_task_pr_graphql_query, variables)
|
||||
pr_data = pr_response['data']['user']
|
||||
|
||||
response = await self.execute_graphql_query(query, variables)
|
||||
data = response['data']['user']
|
||||
tasks: list[SuggestedTask] = []
|
||||
|
||||
# Process pull requests
|
||||
for pr in pr_data['pullRequests']['nodes']:
|
||||
for pr in data['pullRequests']['nodes']:
|
||||
repo_name = pr['repository']['nameWithOwner']
|
||||
|
||||
# Start with default task type
|
||||
@@ -340,18 +373,8 @@ class GitHubService(BaseGitService, GitService):
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"Error fetching suggested task for PRs: {e}",
|
||||
extra={'signal': 'github_suggested_tasks', 'user_id': self.external_auth_id})
|
||||
|
||||
try:
|
||||
# Execute issue query
|
||||
issue_response = await self.execute_graphql_query(suggested_task_issue_graphql_query, variables)
|
||||
issue_data = issue_response['data']['user']
|
||||
|
||||
# Process issues
|
||||
for issue in issue_data['issues']['nodes']:
|
||||
for issue in data['issues']['nodes']:
|
||||
repo_name = issue['repository']['nameWithOwner']
|
||||
tasks.append(
|
||||
SuggestedTask(
|
||||
@@ -364,12 +387,8 @@ class GitHubService(BaseGitService, GitService):
|
||||
)
|
||||
|
||||
return tasks
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"Error fetching suggested task for issues: {e}",
|
||||
extra={'signal': 'github_suggested_tasks', 'user_id': self.external_auth_id})
|
||||
|
||||
return tasks
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def get_repository_details_from_repo_name(
|
||||
self, repository: str
|
||||
|
||||
@@ -29,7 +29,7 @@ suggested_task_pr_graphql_query = """
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
suggested_task_issue_graphql_query = """
|
||||
query GetUserIssues($login: String!) {
|
||||
user(login: $login) {
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.service_types import (
|
||||
BaseGitService,
|
||||
Branch,
|
||||
GitService,
|
||||
ProviderType,
|
||||
Repository,
|
||||
@@ -131,7 +132,7 @@ class GitLabService(BaseGitService, GitService):
|
||||
|
||||
payload = {
|
||||
'query': query,
|
||||
'variables': variables,
|
||||
'variables': variables if variables is not None else {},
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
@@ -195,6 +196,7 @@ class GitLabService(BaseGitService, GitService):
|
||||
full_name=repo.get('path_with_namespace'),
|
||||
stargazers_count=repo.get('star_count'),
|
||||
git_provider=ProviderType.GITLAB,
|
||||
is_public=True,
|
||||
)
|
||||
for repo in response
|
||||
]
|
||||
@@ -398,6 +400,44 @@ class GitLabService(BaseGitService, GitService):
|
||||
is_public=repo.get('visibility') == 'public',
|
||||
)
|
||||
|
||||
async def get_branches(self, repository: str) -> list[Branch]:
|
||||
"""Get branches for a repository"""
|
||||
encoded_name = repository.replace('/', '%2F')
|
||||
url = f'{self.BASE_URL}/projects/{encoded_name}/repository/branches'
|
||||
|
||||
# Set maximum branches to fetch (10 pages with 100 per page)
|
||||
MAX_BRANCHES = 1000
|
||||
PER_PAGE = 100
|
||||
|
||||
all_branches: list[Branch] = []
|
||||
page = 1
|
||||
|
||||
# Fetch up to 10 pages of branches
|
||||
while page <= 10 and len(all_branches) < MAX_BRANCHES:
|
||||
params = {'per_page': str(PER_PAGE), 'page': str(page)}
|
||||
response, headers = await self._make_request(url, params)
|
||||
|
||||
if not response: # No more branches
|
||||
break
|
||||
|
||||
for branch_data in response:
|
||||
branch = Branch(
|
||||
name=branch_data.get('name'),
|
||||
commit_sha=branch_data.get('commit', {}).get('id', ''),
|
||||
protected=branch_data.get('protected', False),
|
||||
last_push_date=branch_data.get('commit', {}).get('committed_date'),
|
||||
)
|
||||
all_branches.append(branch)
|
||||
|
||||
page += 1
|
||||
|
||||
# Check if we've reached the last page
|
||||
link_header = headers.get('Link', '')
|
||||
if 'rel="next"' not in link_header:
|
||||
break
|
||||
|
||||
return all_branches
|
||||
|
||||
|
||||
gitlab_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS',
|
||||
|
||||
@@ -18,6 +18,7 @@ from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
Branch,
|
||||
GitService,
|
||||
ProviderType,
|
||||
Repository,
|
||||
@@ -30,6 +31,7 @@ from openhands.server.types import AppMode
|
||||
class ProviderToken(BaseModel):
|
||||
token: SecretStr | None = Field(default=None)
|
||||
user_id: str | None = Field(default=None)
|
||||
host: str | None = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
'frozen': True, # Makes the entire model immutable
|
||||
@@ -39,15 +41,20 @@ class ProviderToken(BaseModel):
|
||||
@classmethod
|
||||
def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
|
||||
"""Factory method to create a ProviderToken from various input types"""
|
||||
if isinstance(token_value, ProviderToken):
|
||||
if isinstance(token_value, cls):
|
||||
return token_value
|
||||
elif isinstance(token_value, dict):
|
||||
token_str = token_value.get('token')
|
||||
token_str = token_value.get('token', '')
|
||||
# Override with emtpy string if it was set to None
|
||||
# Cannot pass None to SecretStr
|
||||
if token_str is None:
|
||||
token_str = ''
|
||||
user_id = token_value.get('user_id')
|
||||
return cls(token=SecretStr(token_str), user_id=user_id)
|
||||
host = token_value.get('host')
|
||||
return cls(token=SecretStr(token_str), user_id=user_id, host=host)
|
||||
|
||||
else:
|
||||
raise ValueError('Unsupport Provider token type')
|
||||
raise ValueError('Unsupported Provider token type')
|
||||
|
||||
|
||||
PROVIDER_TOKEN_TYPE = MappingProxyType[ProviderType, ProviderToken]
|
||||
@@ -165,7 +172,8 @@ class ProviderHandler:
|
||||
query, per_page, sort, order
|
||||
)
|
||||
all_repos.extend(service_repos)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f'Error searching repos from {provider}: {e}')
|
||||
continue
|
||||
|
||||
return all_repos
|
||||
@@ -305,3 +313,56 @@ class ProviderHandler:
|
||||
pass
|
||||
|
||||
raise AuthenticationError(f'Unable to access repo {repository}')
|
||||
|
||||
async def get_branches(
|
||||
self, repository: str, specified_provider: ProviderType | None = None
|
||||
) -> list[Branch]:
|
||||
"""
|
||||
Get branches for a repository
|
||||
|
||||
Args:
|
||||
repository: The repository name
|
||||
specified_provider: Optional provider type to use
|
||||
|
||||
Returns:
|
||||
A list of branches for the repository
|
||||
"""
|
||||
all_branches: list[Branch] = []
|
||||
|
||||
if specified_provider:
|
||||
try:
|
||||
service = self._get_service(specified_provider)
|
||||
branches = await service.get_branches(repository)
|
||||
return branches
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Error fetching branches from {specified_provider}: {e}'
|
||||
)
|
||||
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self._get_service(provider)
|
||||
branches = await service.get_branches(repository)
|
||||
all_branches.extend(branches)
|
||||
# If we found branches, no need to check other providers
|
||||
if all_branches:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f'Error fetching branches from {provider}: {e}')
|
||||
|
||||
# Sort branches by last push date (newest first)
|
||||
all_branches.sort(
|
||||
key=lambda b: b.last_push_date if b.last_push_date else '', reverse=True
|
||||
)
|
||||
|
||||
# Move main/master branch to the top if it exists
|
||||
main_branches = []
|
||||
other_branches = []
|
||||
|
||||
for branch in all_branches:
|
||||
if branch.name.lower() in ['main', 'master']:
|
||||
main_branches.append(branch)
|
||||
else:
|
||||
other_branches.append(branch)
|
||||
|
||||
return main_branches + other_branches
|
||||
|
||||
@@ -91,6 +91,13 @@ class User(BaseModel):
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class Branch(BaseModel):
|
||||
name: str
|
||||
commit_sha: str
|
||||
protected: bool
|
||||
last_push_date: str | None = None # ISO 8601 format date string
|
||||
|
||||
|
||||
class Repository(BaseModel):
|
||||
id: int
|
||||
full_name: str
|
||||
@@ -211,3 +218,6 @@ class GitService(Protocol):
|
||||
self, repository: str
|
||||
) -> Repository:
|
||||
"""Gets all repository details from repository name"""
|
||||
|
||||
async def get_branches(self, repository: str) -> list[Branch]:
|
||||
"""Get branches for a repository"""
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
Please summarize your work.
|
||||
|
||||
If you answered a question, please re-state the answer to the question
|
||||
If you made changes, please create a concise overview on whether the request has been addressed successfully or if there are were issues with the attempt.
|
||||
If successful, make sure your changes are pushed to the remote branch.
|
||||
@@ -9,6 +9,7 @@ We follow format from: https://docs.litellm.ai/docs/completion/function_call
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from typing import Iterable
|
||||
|
||||
from litellm import ChatCompletionToolParam
|
||||
@@ -47,8 +48,15 @@ Reminder:
|
||||
|
||||
STOP_WORDS = ['</function']
|
||||
|
||||
|
||||
def refine_prompt(prompt: str) -> str:
|
||||
if sys.platform == 'win32':
|
||||
return prompt.replace('bash', 'powershell')
|
||||
return prompt
|
||||
|
||||
|
||||
# NOTE: we need to make sure this example is always in-sync with the tool interface designed in openhands/agenthub/codeact_agent/function_calling.py
|
||||
IN_CONTEXT_LEARNING_EXAMPLE_PREFIX = """
|
||||
IN_CONTEXT_LEARNING_EXAMPLE_PREFIX = refine_prompt("""
|
||||
Here's a running example of how to perform a task with the provided tools.
|
||||
|
||||
--------------------- START OF EXAMPLE ---------------------
|
||||
@@ -75,7 +83,7 @@ from flask import Flask
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
def index() -> str:
|
||||
numbers = list(range(1, 11))
|
||||
return str(numbers)
|
||||
|
||||
@@ -218,7 +226,7 @@ The server is running on port 5000 with PID 126. You can access the list of numb
|
||||
Do NOT assume the environment is the same as in the example above.
|
||||
|
||||
--------------------- NEW TASK DESCRIPTION ---------------------
|
||||
""".lstrip()
|
||||
""").lstrip()
|
||||
|
||||
IN_CONTEXT_LEARNING_EXAMPLE_SUFFIX = """
|
||||
--------------------- END OF NEW TASK DESCRIPTION ---------------------
|
||||
@@ -245,12 +253,12 @@ def convert_tool_call_to_string(tool_call: dict) -> str:
|
||||
if tool_call['type'] != 'function':
|
||||
raise FunctionCallConversionError("Tool call type must be 'function'.")
|
||||
|
||||
ret = f"<function={tool_call['function']['name']}>\n"
|
||||
ret = f'<function={tool_call["function"]["name"]}>\n'
|
||||
try:
|
||||
args = json.loads(tool_call['function']['arguments'])
|
||||
except json.JSONDecodeError as e:
|
||||
raise FunctionCallConversionError(
|
||||
f"Failed to parse arguments as JSON. Arguments: {tool_call['function']['arguments']}"
|
||||
f'Failed to parse arguments as JSON. Arguments: {tool_call["function"]["arguments"]}'
|
||||
) from e
|
||||
for param_name, param_value in args.items():
|
||||
is_multiline = isinstance(param_value, str) and '\n' in param_value
|
||||
@@ -272,8 +280,8 @@ def convert_tools_to_description(tools: list[dict]) -> str:
|
||||
fn = tool['function']
|
||||
if i > 0:
|
||||
ret += '\n'
|
||||
ret += f"---- BEGIN FUNCTION #{i+1}: {fn['name']} ----\n"
|
||||
ret += f"Description: {fn['description']}\n"
|
||||
ret += f'---- BEGIN FUNCTION #{i + 1}: {fn["name"]} ----\n'
|
||||
ret += f'Description: {fn["description"]}\n'
|
||||
|
||||
if 'parameters' in fn:
|
||||
ret += 'Parameters:\n'
|
||||
@@ -295,12 +303,12 @@ def convert_tools_to_description(tools: list[dict]) -> str:
|
||||
desc += f'\nAllowed values: [{enum_values}]'
|
||||
|
||||
ret += (
|
||||
f' ({j+1}) {param_name} ({param_type}, {param_status}): {desc}\n'
|
||||
f' ({j + 1}) {param_name} ({param_type}, {param_status}): {desc}\n'
|
||||
)
|
||||
else:
|
||||
ret += 'No parameters are required for this function.\n'
|
||||
|
||||
ret += f'---- END FUNCTION #{i+1} ----\n'
|
||||
ret += f'---- END FUNCTION #{i + 1} ----\n'
|
||||
return ret
|
||||
|
||||
|
||||
@@ -351,7 +359,8 @@ def convert_fncall_messages_to_non_fncall_messages(
|
||||
and any(
|
||||
(
|
||||
tool['type'] == 'function'
|
||||
and tool['function']['name'] == 'execute_bash'
|
||||
and tool['function']['name']
|
||||
== refine_prompt('execute_bash')
|
||||
and 'command'
|
||||
in tool['function']['parameters']['properties']
|
||||
)
|
||||
@@ -658,7 +667,7 @@ def convert_non_fncall_messages_to_fncall_messages(
|
||||
'content': [{'type': 'text', 'text': tool_result}]
|
||||
if isinstance(content, list)
|
||||
else tool_result,
|
||||
'tool_call_id': f'toolu_{tool_call_counter-1:02d}', # Use last generated ID
|
||||
'tool_call_id': f'toolu_{tool_call_counter - 1:02d}', # Use last generated ID
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -781,14 +790,14 @@ def convert_from_multiple_tool_calls_to_single_tool_call_messages(
|
||||
# add the tool result
|
||||
converted_messages.append(message)
|
||||
else:
|
||||
assert (
|
||||
len(pending_tool_calls) == 0
|
||||
), f'Found pending tool calls but not found in pending list: {pending_tool_calls=}'
|
||||
assert len(pending_tool_calls) == 0, (
|
||||
f'Found pending tool calls but not found in pending list: {pending_tool_calls=}'
|
||||
)
|
||||
converted_messages.append(message)
|
||||
else:
|
||||
assert (
|
||||
len(pending_tool_calls) == 0
|
||||
), f'Found pending tool calls but not expect to handle it with role {role}: {pending_tool_calls=}, {message=}'
|
||||
assert len(pending_tool_calls) == 0, (
|
||||
f'Found pending tool calls but not expect to handle it with role {role}: {pending_tool_calls=}, {message=}'
|
||||
)
|
||||
converted_messages.append(message)
|
||||
|
||||
if not ignore_final_tool_result and len(pending_tool_calls) > 0:
|
||||
|
||||
+34
-9
@@ -49,6 +49,8 @@ LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
|
||||
# remove this when we gemini and deepseek are supported
|
||||
CACHE_PROMPT_SUPPORTED_MODELS = [
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-sonnet-3-7-latest',
|
||||
'claude-3.7-sonnet',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
'claude-3-5-sonnet-20240620',
|
||||
'claude-3-5-haiku-20241022',
|
||||
@@ -59,6 +61,7 @@ CACHE_PROMPT_SUPPORTED_MODELS = [
|
||||
# function calling supporting models
|
||||
FUNCTION_CALLING_SUPPORTED_MODELS = [
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-sonnet-3-7-latest',
|
||||
'claude-3-5-sonnet',
|
||||
'claude-3-5-sonnet-20240620',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
@@ -108,7 +111,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
config: LLMConfig,
|
||||
metrics: Metrics | None = None,
|
||||
retry_listener: Callable[[int, int], None] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
|
||||
|
||||
Passing simple parameters always overrides config.
|
||||
@@ -199,7 +202,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
|
||||
from openhands.io import json
|
||||
|
||||
messages: list[dict[str, Any]] | dict[str, Any] = []
|
||||
messages_kwarg: list[dict[str, Any]] | dict[str, Any] = []
|
||||
mock_function_calling = not self.is_function_calling_active()
|
||||
|
||||
# some callers might send the model and messages directly
|
||||
@@ -209,16 +212,18 @@ class LLM(RetryMixin, DebugMixin):
|
||||
# design wise: we don't allow overriding the configured values
|
||||
# implementation wise: the partial function set the model as a kwarg already
|
||||
# as well as other kwargs
|
||||
messages = args[1] if len(args) > 1 else args[0]
|
||||
kwargs['messages'] = messages
|
||||
messages_kwarg = args[1] if len(args) > 1 else args[0]
|
||||
kwargs['messages'] = messages_kwarg
|
||||
|
||||
# remove the first args, they're sent in kwargs
|
||||
args = args[2:]
|
||||
elif 'messages' in kwargs:
|
||||
messages = kwargs['messages']
|
||||
messages_kwarg = kwargs['messages']
|
||||
|
||||
# ensure we work with a list of messages
|
||||
messages = messages if isinstance(messages, list) else [messages]
|
||||
messages: list[dict[str, Any]] = (
|
||||
messages_kwarg if isinstance(messages_kwarg, list) else [messages_kwarg]
|
||||
)
|
||||
|
||||
# handle conversion of to non-function calling messages if needed
|
||||
original_fncall_messages = copy.deepcopy(messages)
|
||||
@@ -290,6 +295,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
|
||||
non_fncall_response_message = resp.choices[0].message
|
||||
# messages is already a list with proper typing from line 223
|
||||
fn_call_messages_with_response = (
|
||||
convert_non_fncall_messages_to_fncall_messages(
|
||||
messages + [non_fncall_response_message], mock_fncall_tools
|
||||
@@ -412,6 +418,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
)
|
||||
if current_model_info:
|
||||
self.model_info = current_model_info['model_info']
|
||||
logger.debug(f'Got model info from litellm proxy: {self.model_info}')
|
||||
|
||||
# Last two attempts to get model info from NAME
|
||||
if not self.model_info:
|
||||
@@ -467,7 +474,10 @@ class LLM(RetryMixin, DebugMixin):
|
||||
self.model_info['max_tokens'], int
|
||||
):
|
||||
self.config.max_output_tokens = self.model_info['max_tokens']
|
||||
if 'claude-3-7-sonnet' in self.config.model:
|
||||
if any(
|
||||
model in self.config.model
|
||||
for model in ['claude-3-7-sonnet', 'claude-3.7-sonnet']
|
||||
):
|
||||
self.config.max_output_tokens = 64000 # litellm set max to 128k, but that requires a header to be set
|
||||
|
||||
# Initialize function calling capability
|
||||
@@ -598,6 +608,12 @@ class LLM(RetryMixin, DebugMixin):
|
||||
if cache_write_tokens:
|
||||
stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'
|
||||
|
||||
# Get context window from model info
|
||||
context_window = 0
|
||||
if self.model_info and 'max_input_tokens' in self.model_info:
|
||||
context_window = self.model_info['max_input_tokens']
|
||||
logger.debug(f'Using context window: {context_window}')
|
||||
|
||||
# Record in metrics
|
||||
# We'll treat cache_hit_tokens as "cache read" and cache_write_tokens as "cache write"
|
||||
self.metrics.add_token_usage(
|
||||
@@ -605,6 +621,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_hit_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
context_window=context_window,
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
@@ -631,7 +648,15 @@ class LLM(RetryMixin, DebugMixin):
|
||||
logger.info(
|
||||
'Message objects now include serialized tool calls in token counting'
|
||||
)
|
||||
messages = self.format_messages_for_llm(messages) # type: ignore
|
||||
# Assert the expected type for format_messages_for_llm
|
||||
assert isinstance(messages, list) and all(
|
||||
isinstance(m, Message) for m in messages
|
||||
), 'Expected list of Message objects'
|
||||
|
||||
# We've already asserted that messages is a list of Message objects
|
||||
# Use explicit typing to satisfy mypy
|
||||
messages_typed: list[Message] = messages # type: ignore
|
||||
messages = self.format_messages_for_llm(messages_typed)
|
||||
|
||||
# try to get the token count with the default litellm tokenizers
|
||||
# or the custom tokenizer if set for this LLM configuration
|
||||
@@ -662,7 +687,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
boolean: True if executing a local model.
|
||||
"""
|
||||
if self.config.base_url is not None:
|
||||
for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
|
||||
for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
|
||||
if substring in self.config.base_url:
|
||||
return True
|
||||
elif self.config.model is not None:
|
||||
|
||||
@@ -26,6 +26,8 @@ class TokenUsage(BaseModel):
|
||||
completion_tokens: int = Field(default=0)
|
||||
cache_read_tokens: int = Field(default=0)
|
||||
cache_write_tokens: int = Field(default=0)
|
||||
context_window: int = Field(default=0)
|
||||
per_turn_token: int = Field(default=0)
|
||||
response_id: str = Field(default='')
|
||||
|
||||
def __add__(self, other: 'TokenUsage') -> 'TokenUsage':
|
||||
@@ -36,6 +38,8 @@ class TokenUsage(BaseModel):
|
||||
completion_tokens=self.completion_tokens + other.completion_tokens,
|
||||
cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
|
||||
cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
|
||||
context_window=max(self.context_window, other.context_window),
|
||||
per_turn_token=other.per_turn_token,
|
||||
response_id=self.response_id,
|
||||
)
|
||||
|
||||
@@ -60,6 +64,7 @@ class Metrics:
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
@@ -107,6 +112,7 @@ class Metrics:
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
return self._accumulated_token_usage
|
||||
@@ -130,15 +136,22 @@ class Metrics:
|
||||
completion_tokens: int,
|
||||
cache_read_tokens: int,
|
||||
cache_write_tokens: int,
|
||||
context_window: int,
|
||||
response_id: str,
|
||||
) -> None:
|
||||
"""Add a single usage record."""
|
||||
|
||||
# Token each turn for calculating context usage.
|
||||
per_turn_token = prompt_tokens + completion_tokens
|
||||
|
||||
usage = TokenUsage(
|
||||
model=self.model_name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
context_window=context_window,
|
||||
per_turn_token=per_turn_token,
|
||||
response_id=response_id,
|
||||
)
|
||||
self._token_usages.append(usage)
|
||||
@@ -150,6 +163,8 @@ class Metrics:
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
context_window=context_window,
|
||||
per_turn_token=per_turn_token,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
@@ -190,6 +205,7 @@ class Metrics:
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
@@ -18,8 +18,8 @@ class MCPClient(BaseModel):
|
||||
session: Optional[ClientSession] = None
|
||||
exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
description: str = 'MCP client tools for server interaction'
|
||||
tools: List[MCPClientTool] = Field(default_factory=list)
|
||||
tool_map: Dict[str, MCPClientTool] = Field(default_factory=dict)
|
||||
tools: list[MCPClientTool] = Field(default_factory=list)
|
||||
tool_map: dict[str, MCPClientTool] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -91,7 +91,7 @@ class MCPClient(BaseModel):
|
||||
f'Connected to server with tools: {[tool.name for tool in response.tools]}'
|
||||
)
|
||||
|
||||
async def call_tool(self, tool_name: str, args: Dict):
|
||||
async def call_tool(self, tool_name: str, args: dict):
|
||||
"""Call a tool on the MCP server."""
|
||||
if tool_name not in self.tool_map:
|
||||
raise ValueError(f'Tool {tool_name} not found.')
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Dict
|
||||
|
||||
from mcp.types import Tool
|
||||
|
||||
|
||||
@@ -14,7 +12,7 @@ class MCPClientTool(Tool):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def to_param(self) -> Dict:
|
||||
def to_param(self) -> dict:
|
||||
"""Convert tool to function call format."""
|
||||
return {
|
||||
'type': 'function',
|
||||
|
||||
@@ -158,12 +158,12 @@ async def add_mcp_tools_to_agent(
|
||||
ActionExecutionClient, # inline import to avoid circular import
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
runtime, ActionExecutionClient
|
||||
), 'Runtime must be an instance of ActionExecutionClient'
|
||||
assert (
|
||||
runtime.runtime_initialized
|
||||
), 'Runtime must be initialized before adding MCP tools'
|
||||
assert isinstance(runtime, ActionExecutionClient), (
|
||||
'Runtime must be an instance of ActionExecutionClient'
|
||||
)
|
||||
assert runtime.runtime_initialized, (
|
||||
'Runtime must be initialized before adding MCP tools'
|
||||
)
|
||||
|
||||
# Add the runtime as another MCP server
|
||||
updated_mcp_config = runtime.get_updated_mcp_config()
|
||||
@@ -171,7 +171,7 @@ async def add_mcp_tools_to_agent(
|
||||
mcp_tools = await fetch_mcp_tools_from_config(updated_mcp_config)
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(mcp_tools)} MCP tools: {[tool['function']['name'] for tool in mcp_tools]}"
|
||||
f'Loaded {len(mcp_tools)} MCP tools: {[tool["function"]["name"] for tool in mcp_tools]}'
|
||||
)
|
||||
|
||||
# Set the MCP tools on the agent
|
||||
|
||||
@@ -28,7 +28,7 @@ class BrowserOutputCondenser(Condenser):
|
||||
):
|
||||
results.append(
|
||||
AgentCondensationObservation(
|
||||
f'Current URL: {event.url}\nContent Omitted'
|
||||
f'Visited URL {event.url}\nContent omitted'
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -26,6 +26,7 @@ jobs:
|
||||
base_container_image: ${{ vars.OPENHANDS_BASE_CONTAINER_IMAGE || '' }}
|
||||
LLM_MODEL: ${{ vars.LLM_MODEL || 'anthropic/claude-3-5-sonnet-20241022' }}
|
||||
target_branch: ${{ vars.TARGET_BRANCH || 'main' }}
|
||||
runner: ${{ vars.TARGET_RUNNER }}
|
||||
secrets:
|
||||
PAT_TOKEN: ${{ secrets.PAT_TOKEN }}
|
||||
PAT_USERNAME: ${{ secrets.PAT_USERNAME }}
|
||||
|
||||
@@ -214,7 +214,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
|
||||
def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None:
|
||||
response = httpx.get(
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}',
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split("/")[-1]}',
|
||||
headers=self.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -225,7 +225,7 @@ class GitlabIssueHandler(IssueHandlerInterface):
|
||||
'note_id': discussions.get('notes', [])[-1]['id'],
|
||||
}
|
||||
response = httpx.post(
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}/notes',
|
||||
f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split("/")[-1]}/notes',
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
|
||||
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
|
||||
from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
|
||||
|
||||
class IssueHandlerFactory:
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str,
|
||||
platform: ProviderType,
|
||||
base_domain: str,
|
||||
issue_type: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
self.owner = owner
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.platform = platform
|
||||
self.base_domain = base_domain
|
||||
self.issue_type = issue_type
|
||||
self.llm_config = llm_config
|
||||
|
||||
def create(self) -> ServiceContextIssue | ServiceContextPR:
|
||||
if self.issue_type == 'issue':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextIssue(
|
||||
GithubIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextIssue(
|
||||
GitlabIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
elif self.issue_type == 'pr':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextPR(
|
||||
GithubPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextPR(
|
||||
GitlabPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid issue type: {self.issue_type}')
|
||||
@@ -6,11 +6,9 @@ class HunkException(PatchingException):
|
||||
def __init__(self, msg: str, hunk: int | None = None) -> None:
|
||||
self.hunk = hunk
|
||||
if hunk is not None:
|
||||
super(HunkException, self).__init__(
|
||||
'{msg}, in hunk #{n}'.format(msg=msg, n=hunk)
|
||||
)
|
||||
super().__init__('{msg}, in hunk #{n}'.format(msg=msg, n=hunk))
|
||||
else:
|
||||
super(HunkException, self).__init__(msg)
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class ApplyException(PatchingException):
|
||||
@@ -19,7 +17,7 @@ class ApplyException(PatchingException):
|
||||
|
||||
class SubprocessException(ApplyException):
|
||||
def __init__(self, msg: str, code: int) -> None:
|
||||
super(SubprocessException, self).__init__(msg)
|
||||
super().__init__(msg)
|
||||
self.code = code
|
||||
|
||||
|
||||
|
||||
@@ -28,13 +28,12 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
|
||||
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
|
||||
from openhands.resolver.interfaces.issue import Issue
|
||||
from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
|
||||
from openhands.resolver.resolver_output import ResolverOutput
|
||||
from openhands.resolver.utils import (
|
||||
codeact_user_response,
|
||||
@@ -111,12 +110,22 @@ class IssueResolver:
|
||||
model = args.llm_model or os.environ['LLM_MODEL']
|
||||
base_url = args.llm_base_url or os.environ.get('LLM_BASE_URL', None)
|
||||
api_version = os.environ.get('LLM_API_VERSION', None)
|
||||
llm_num_retries = int(os.environ.get('LLM_NUM_RETRIES', '4'))
|
||||
llm_retry_min_wait = int(os.environ.get('LLM_RETRY_MIN_WAIT', '5'))
|
||||
llm_retry_max_wait = int(os.environ.get('LLM_RETRY_MAX_WAIT', '30'))
|
||||
llm_retry_multiplier = int(os.environ.get('LLM_RETRY_MULTIPLIER', 2))
|
||||
llm_timeout = int(os.environ.get('LLM_TIMEOUT', 0))
|
||||
|
||||
# Create LLMConfig instance
|
||||
llm_config = LLMConfig(
|
||||
model=model,
|
||||
api_key=SecretStr(api_key) if api_key else None,
|
||||
base_url=base_url,
|
||||
num_retries=llm_num_retries,
|
||||
retry_min_wait=llm_retry_min_wait,
|
||||
retry_max_wait=llm_retry_max_wait,
|
||||
retry_multiplier=llm_retry_multiplier,
|
||||
timeout=llm_timeout,
|
||||
)
|
||||
|
||||
# Only set api_version if it was explicitly provided, otherwise let LLMConfig handle it
|
||||
@@ -152,8 +161,6 @@ class IssueResolver:
|
||||
|
||||
self.owner = owner
|
||||
self.repo = repo
|
||||
self.token = token
|
||||
self.username = username
|
||||
self.platform = platform
|
||||
self.runtime_container_image = runtime_container_image
|
||||
self.base_container_image = base_container_image
|
||||
@@ -165,9 +172,20 @@ class IssueResolver:
|
||||
self.repo_instruction = repo_instruction
|
||||
self.issue_number = args.issue_number
|
||||
self.comment_id = args.comment_id
|
||||
self.base_domain = base_domain
|
||||
self.platform = platform
|
||||
|
||||
factory = IssueHandlerFactory(
|
||||
owner=self.owner,
|
||||
repo=self.repo,
|
||||
token=token,
|
||||
username=username,
|
||||
platform=self.platform,
|
||||
base_domain=base_domain,
|
||||
issue_type=self.issue_type,
|
||||
llm_config=self.llm_config,
|
||||
)
|
||||
self.issue_handler = factory.create()
|
||||
|
||||
def initialize_runtime(
|
||||
self,
|
||||
runtime: Runtime,
|
||||
@@ -435,58 +453,6 @@ class IssueResolver:
|
||||
)
|
||||
return output
|
||||
|
||||
def issue_handler_factory(self) -> ServiceContextIssue | ServiceContextPR:
|
||||
# Determine default base_domain based on platform
|
||||
|
||||
if self.issue_type == 'issue':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextIssue(
|
||||
GithubIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextIssue(
|
||||
GitlabIssueHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
elif self.issue_type == 'pr':
|
||||
if self.platform == ProviderType.GITHUB:
|
||||
return ServiceContextPR(
|
||||
GithubPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else: # platform == Platform.GITLAB
|
||||
return ServiceContextPR(
|
||||
GitlabPRHandler(
|
||||
self.owner,
|
||||
self.repo,
|
||||
self.token,
|
||||
self.username,
|
||||
self.base_domain,
|
||||
),
|
||||
self.llm_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid issue type: {self.issue_type}')
|
||||
|
||||
async def resolve_issue(
|
||||
self,
|
||||
reset_logger: bool = False,
|
||||
@@ -497,10 +463,8 @@ class IssueResolver:
|
||||
reset_logger: Whether to reset the logger for multiprocessing.
|
||||
"""
|
||||
|
||||
issue_handler = self.issue_handler_factory()
|
||||
|
||||
# Load dataset
|
||||
issues: list[Issue] = issue_handler.get_converted_issues(
|
||||
issues: list[Issue] = self.issue_handler.get_converted_issues(
|
||||
issue_numbers=[self.issue_number], comment_id=self.comment_id
|
||||
)
|
||||
|
||||
@@ -546,7 +510,7 @@ class IssueResolver:
|
||||
[
|
||||
'git',
|
||||
'clone',
|
||||
issue_handler.get_clone_url(),
|
||||
self.issue_handler.get_clone_url(),
|
||||
f'{self.output_dir}/repo',
|
||||
]
|
||||
).decode('utf-8')
|
||||
@@ -625,7 +589,7 @@ class IssueResolver:
|
||||
output = await self.process_issue(
|
||||
issue,
|
||||
base_commit,
|
||||
issue_handler,
|
||||
self.issue_handler,
|
||||
reset_logger,
|
||||
)
|
||||
output_fp.write(output.model_dump_json() + '\n')
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Type
|
||||
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.daytona.daytona_runtime import DaytonaRuntime
|
||||
from openhands.runtime.impl.docker.docker_runtime import (
|
||||
@@ -13,7 +11,7 @@ from openhands.runtime.impl.runloop.runloop_runtime import RunloopRuntime
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
# mypy: disable-error-code="type-abstract"
|
||||
_DEFAULT_RUNTIME_CLASSES: dict[str, Type[Runtime]] = {
|
||||
_DEFAULT_RUNTIME_CLASSES: dict[str, type[Runtime]] = {
|
||||
'eventstream': DockerRuntime,
|
||||
'docker': DockerRuntime,
|
||||
'e2b': E2BRuntime,
|
||||
@@ -25,7 +23,7 @@ _DEFAULT_RUNTIME_CLASSES: dict[str, Type[Runtime]] = {
|
||||
}
|
||||
|
||||
|
||||
def get_runtime_cls(name: str) -> Type[Runtime]:
|
||||
def get_runtime_cls(name: str) -> type[Runtime]:
|
||||
"""
|
||||
If name is one of the predefined runtime names (e.g. 'docker'), return its class.
|
||||
Otherwise attempt to resolve name as subclass of Runtime and return it.
|
||||
|
||||
@@ -13,6 +13,7 @@ import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
@@ -76,6 +77,10 @@ from openhands.utils.async_utils import call_sync_from_async, wait_all
|
||||
mcp_router_logger.setLevel(logger.getEffectiveLevel())
|
||||
|
||||
|
||||
if sys.platform == 'win32':
|
||||
from openhands.runtime.utils.windows_bash import WindowsPowershellSession
|
||||
|
||||
|
||||
class ActionRequest(BaseModel):
|
||||
action: dict
|
||||
|
||||
@@ -100,7 +105,7 @@ def _execute_file_editor(
|
||||
view_range: list[int] | None = None,
|
||||
old_str: str | None = None,
|
||||
new_str: str | None = None,
|
||||
insert_line: int | None = None,
|
||||
insert_line: int | str | None = None,
|
||||
enable_linting: bool = False,
|
||||
) -> tuple[str, tuple[str | None, str | None]]:
|
||||
"""Execute file editor command and handle exceptions.
|
||||
@@ -113,13 +118,24 @@ def _execute_file_editor(
|
||||
view_range: Optional view range tuple (start, end)
|
||||
old_str: Optional string to replace
|
||||
new_str: Optional replacement string
|
||||
insert_line: Optional line number for insertion
|
||||
insert_line: Optional line number for insertion (can be int or str)
|
||||
enable_linting: Whether to enable linting
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the output string and a tuple of old and new file content
|
||||
"""
|
||||
result: ToolResult | None = None
|
||||
|
||||
# Convert insert_line from string to int if needed
|
||||
if insert_line is not None and isinstance(insert_line, str):
|
||||
try:
|
||||
insert_line = int(insert_line)
|
||||
except ValueError:
|
||||
return (
|
||||
f"ERROR:\nInvalid insert_line value: '{insert_line}'. Expected an integer.",
|
||||
(None, None),
|
||||
)
|
||||
|
||||
try:
|
||||
result = editor(
|
||||
command=command,
|
||||
@@ -133,6 +149,9 @@ def _execute_file_editor(
|
||||
)
|
||||
except ToolError as e:
|
||||
result = ToolResult(error=e.message)
|
||||
except TypeError as e:
|
||||
# Handle unexpected arguments or type errors
|
||||
return f'ERROR:\n{str(e)}', (None, None)
|
||||
|
||||
if result.error:
|
||||
return f'ERROR:\n{result.error}', (None, None)
|
||||
@@ -167,13 +186,14 @@ class ActionExecutor:
|
||||
if _updated_user_id is not None:
|
||||
self.user_id = _updated_user_id
|
||||
|
||||
self.bash_session: BashSession | None = None
|
||||
self.bash_session: BashSession | 'WindowsPowershellSession' | None = None # type: ignore[name-defined]
|
||||
self.lock = asyncio.Lock()
|
||||
self.plugins: dict[str, Plugin] = {}
|
||||
self.file_editor = OHEditor(workspace_root=self._initial_cwd)
|
||||
self.browser: BrowserEnv | None = None
|
||||
self.browser_init_task: asyncio.Task | None = None
|
||||
self.browsergym_eval_env = browsergym_eval_env
|
||||
|
||||
self.start_time = time.time()
|
||||
self.last_execution_time = self.start_time
|
||||
self._initialized = False
|
||||
@@ -199,6 +219,10 @@ class ActionExecutor:
|
||||
|
||||
async def _init_browser_async(self):
|
||||
"""Initialize the browser asynchronously."""
|
||||
if sys.platform == 'win32':
|
||||
logger.warning('Browser environment not supported on windows')
|
||||
return
|
||||
|
||||
logger.debug('Initializing browser asynchronously')
|
||||
try:
|
||||
self.browser = BrowserEnv(self.browsergym_eval_env)
|
||||
@@ -232,15 +256,25 @@ class ActionExecutor:
|
||||
async def ainit(self):
|
||||
# bash needs to be initialized first
|
||||
logger.debug('Initializing bash session')
|
||||
self.bash_session = BashSession(
|
||||
work_dir=self._initial_cwd,
|
||||
username=self.username,
|
||||
no_change_timeout_seconds=int(
|
||||
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 10)
|
||||
),
|
||||
max_memory_mb=self.max_memory_gb * 1024 if self.max_memory_gb else None,
|
||||
)
|
||||
self.bash_session.initialize()
|
||||
if sys.platform == 'win32':
|
||||
self.bash_session = WindowsPowershellSession( # type: ignore[name-defined]
|
||||
work_dir=self._initial_cwd,
|
||||
username=self.username,
|
||||
no_change_timeout_seconds=int(
|
||||
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 10)
|
||||
),
|
||||
max_memory_mb=self.max_memory_gb * 1024 if self.max_memory_gb else None,
|
||||
)
|
||||
else:
|
||||
self.bash_session = BashSession(
|
||||
work_dir=self._initial_cwd,
|
||||
username=self.username,
|
||||
no_change_timeout_seconds=int(
|
||||
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 10)
|
||||
),
|
||||
max_memory_mb=self.max_memory_gb * 1024 if self.max_memory_gb else None,
|
||||
)
|
||||
self.bash_session.initialize()
|
||||
logger.debug('Bash session initialized')
|
||||
|
||||
# Start browser initialization in the background
|
||||
@@ -282,19 +316,55 @@ class ActionExecutor:
|
||||
logger.debug(f'Initializing plugin: {plugin.name}')
|
||||
|
||||
if isinstance(plugin, JupyterPlugin):
|
||||
# Escape backslashes in Windows path
|
||||
cwd = self.bash_session.cwd.replace('\\', '/')
|
||||
await self.run_ipython(
|
||||
IPythonRunCellAction(
|
||||
code=f'import os; os.chdir("{self.bash_session.cwd}")'
|
||||
)
|
||||
IPythonRunCellAction(code=f'import os; os.chdir(r"{cwd}")')
|
||||
)
|
||||
|
||||
async def _init_bash_commands(self):
|
||||
INIT_COMMANDS = [
|
||||
'git config --file ./.git_config user.name "openhands" && git config --file ./.git_config user.email "openhands@all-hands.dev" && alias git="git --no-pager" && export GIT_CONFIG=$(pwd)/.git_config'
|
||||
if os.environ.get('LOCAL_RUNTIME_MODE') == '1'
|
||||
else 'git config --global user.name "openhands" && git config --global user.email "openhands@all-hands.dev" && alias git="git --no-pager"'
|
||||
]
|
||||
logger.debug(f'Initializing by running {len(INIT_COMMANDS)} bash commands...')
|
||||
INIT_COMMANDS = []
|
||||
is_local_runtime = os.environ.get('LOCAL_RUNTIME_MODE') == '1'
|
||||
is_windows = sys.platform == 'win32'
|
||||
|
||||
# Determine git config commands based on platform and runtime mode
|
||||
if is_local_runtime:
|
||||
if is_windows:
|
||||
# Windows, local - split into separate commands
|
||||
INIT_COMMANDS.append(
|
||||
'git config --file ./.git_config user.name "openhands"'
|
||||
)
|
||||
INIT_COMMANDS.append(
|
||||
'git config --file ./.git_config user.email "openhands@all-hands.dev"'
|
||||
)
|
||||
INIT_COMMANDS.append(
|
||||
'$env:GIT_CONFIG = (Join-Path (Get-Location) ".git_config")'
|
||||
)
|
||||
else:
|
||||
# Linux/macOS, local
|
||||
base_git_config = (
|
||||
'git config --file ./.git_config user.name "openhands" && '
|
||||
'git config --file ./.git_config user.email "openhands@all-hands.dev" && '
|
||||
'export GIT_CONFIG=$(pwd)/.git_config'
|
||||
)
|
||||
INIT_COMMANDS.append(base_git_config)
|
||||
else:
|
||||
# Non-local (implies Linux/macOS)
|
||||
base_git_config = (
|
||||
'git config --global user.name "openhands" && '
|
||||
'git config --global user.email "openhands@all-hands.dev"'
|
||||
)
|
||||
INIT_COMMANDS.append(base_git_config)
|
||||
|
||||
# Determine no-pager command
|
||||
if is_windows:
|
||||
no_pager_cmd = 'function git { git.exe --no-pager $args }'
|
||||
else:
|
||||
no_pager_cmd = 'alias git="git --no-pager"'
|
||||
|
||||
INIT_COMMANDS.append(no_pager_cmd)
|
||||
|
||||
logger.info(f'Initializing by running {len(INIT_COMMANDS)} bash commands...')
|
||||
for command in INIT_COMMANDS:
|
||||
action = CmdRunAction(command=command)
|
||||
action.set_hard_timeout(300)
|
||||
@@ -345,9 +415,9 @@ class ActionExecutor:
|
||||
logger.debug(
|
||||
f'{self.bash_session.cwd} != {jupyter_cwd} -> reset Jupyter PWD'
|
||||
)
|
||||
reset_jupyter_cwd_code = (
|
||||
f'import os; os.chdir("{self.bash_session.cwd}")'
|
||||
)
|
||||
# escape windows paths
|
||||
cwd = self.bash_session.cwd.replace('\\', '/')
|
||||
reset_jupyter_cwd_code = f'import os; os.chdir("{cwd}")'
|
||||
_aux_action = IPythonRunCellAction(code=reset_jupyter_cwd_code)
|
||||
_reset_obs: IPythonRunCellObservation = await _jupyter_plugin.run(
|
||||
_aux_action
|
||||
@@ -527,12 +597,20 @@ class ActionExecutor:
|
||||
)
|
||||
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
if self.browser is None:
|
||||
return ErrorObservation(
|
||||
'Browser functionality is not supported on Windows.'
|
||||
)
|
||||
await self._ensure_browser_ready()
|
||||
return await browse(action, self.browser)
|
||||
return await browse(action, self.browser, self.initial_cwd)
|
||||
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
if self.browser is None:
|
||||
return ErrorObservation(
|
||||
'Browser functionality is not supported on Windows.'
|
||||
)
|
||||
await self._ensure_browser_ready()
|
||||
return await browse(action, self.browser)
|
||||
return await browse(action, self.browser, self.initial_cwd)
|
||||
|
||||
def close(self):
|
||||
self.memory_monitor.stop_monitoring()
|
||||
@@ -726,7 +804,6 @@ if __name__ == '__main__':
|
||||
if not isinstance(action, Action):
|
||||
raise HTTPException(status_code=400, detail='Invalid action type')
|
||||
client.last_execution_time = time.time()
|
||||
|
||||
observation = await client.run_action(action)
|
||||
return event_to_dict(observation)
|
||||
except Exception as e:
|
||||
@@ -897,7 +974,7 @@ if __name__ == '__main__':
|
||||
|
||||
To list files:
|
||||
```sh
|
||||
curl http://localhost:3000/api/list-files
|
||||
curl -X POST -d '{"path": "/"}' http://localhost:3000/list_files
|
||||
```
|
||||
|
||||
Args:
|
||||
|
||||
+262
-29
@@ -72,6 +72,7 @@ STATUS_MESSAGES = {
|
||||
'STATUS$CONTAINER_STARTED': 'Container started.',
|
||||
'STATUS$WAITING_FOR_CLIENT': 'Waiting for client...',
|
||||
'STATUS$SETTING_UP_WORKSPACE': 'Setting up workspace...',
|
||||
'STATUS$SETTING_UP_GIT_HOOKS': 'Setting up git hooks...',
|
||||
}
|
||||
|
||||
|
||||
@@ -424,21 +425,278 @@ class Runtime(FileEditRuntimeMixin):
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log('error', f'Setup script failed: {obs.content}')
|
||||
|
||||
def maybe_setup_git_hooks(self):
|
||||
"""Set up git hooks if .openhands/pre-commit.sh exists in the workspace or repository."""
|
||||
pre_commit_script = '.openhands/pre-commit.sh'
|
||||
read_obs = self.read(FileReadAction(path=pre_commit_script))
|
||||
if isinstance(read_obs, ErrorObservation):
|
||||
return
|
||||
|
||||
if self.status_callback:
|
||||
self.status_callback(
|
||||
'info', 'STATUS$SETTING_UP_GIT_HOOKS', 'Setting up git hooks...'
|
||||
)
|
||||
|
||||
# Ensure the git hooks directory exists
|
||||
action = CmdRunAction('mkdir -p .git/hooks')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log('error', f'Failed to create git hooks directory: {obs.content}')
|
||||
return
|
||||
|
||||
# Make the pre-commit script executable
|
||||
action = CmdRunAction(f'chmod +x {pre_commit_script}')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log(
|
||||
'error', f'Failed to make pre-commit script executable: {obs.content}'
|
||||
)
|
||||
return
|
||||
|
||||
# Check if there's an existing pre-commit hook
|
||||
pre_commit_hook = '.git/hooks/pre-commit'
|
||||
pre_commit_local = '.git/hooks/pre-commit.local'
|
||||
|
||||
# Read the existing pre-commit hook if it exists
|
||||
read_obs = self.read(FileReadAction(path=pre_commit_hook))
|
||||
if not isinstance(read_obs, ErrorObservation):
|
||||
# If the existing hook wasn't created by OpenHands, preserve it
|
||||
if 'This hook was installed by OpenHands' not in read_obs.content:
|
||||
self.log('info', 'Preserving existing pre-commit hook')
|
||||
# Move the existing hook to pre-commit.local
|
||||
action = CmdRunAction(f'mv {pre_commit_hook} {pre_commit_local}')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log(
|
||||
'error',
|
||||
f'Failed to preserve existing pre-commit hook: {obs.content}',
|
||||
)
|
||||
return
|
||||
|
||||
# Make it executable
|
||||
action = CmdRunAction(f'chmod +x {pre_commit_local}')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log(
|
||||
'error',
|
||||
f'Failed to make preserved hook executable: {obs.content}',
|
||||
)
|
||||
return
|
||||
|
||||
# Create the pre-commit hook that calls our script
|
||||
pre_commit_hook_content = f"""#!/bin/bash
|
||||
# This hook was installed by OpenHands
|
||||
# It calls the pre-commit script in the .openhands directory
|
||||
|
||||
if [ -x "{pre_commit_script}" ]; then
|
||||
source "{pre_commit_script}"
|
||||
exit $?
|
||||
else
|
||||
echo "Warning: {pre_commit_script} not found or not executable"
|
||||
exit 0
|
||||
fi
|
||||
"""
|
||||
|
||||
# Write the pre-commit hook
|
||||
write_obs = self.write(
|
||||
FileWriteAction(path=pre_commit_hook, content=pre_commit_hook_content)
|
||||
)
|
||||
if isinstance(write_obs, ErrorObservation):
|
||||
self.log('error', f'Failed to write pre-commit hook: {write_obs.content}')
|
||||
return
|
||||
|
||||
# Make the pre-commit hook executable
|
||||
action = CmdRunAction(f'chmod +x {pre_commit_hook}')
|
||||
obs = self.run_action(action)
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code != 0:
|
||||
self.log(
|
||||
'error', f'Failed to make pre-commit hook executable: {obs.content}'
|
||||
)
|
||||
return
|
||||
|
||||
self.log('info', 'Git pre-commit hook installed successfully')
|
||||
|
||||
def _load_microagents_from_directory(
|
||||
self, microagents_dir: Path, source_description: str
|
||||
) -> list[BaseMicroagent]:
|
||||
"""Load microagents from a directory.
|
||||
|
||||
Args:
|
||||
microagents_dir: Path to the directory containing microagents
|
||||
source_description: Description of the source for logging purposes
|
||||
|
||||
Returns:
|
||||
A list of loaded microagents
|
||||
"""
|
||||
loaded_microagents: list[BaseMicroagent] = []
|
||||
files = self.list_files(str(microagents_dir))
|
||||
|
||||
if not files:
|
||||
return loaded_microagents
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Found {len(files)} files in {source_description} microagents directory',
|
||||
)
|
||||
zip_path = self.copy_from(str(microagents_dir))
|
||||
microagent_folder = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
with ZipFile(zip_path, 'r') as zip_file:
|
||||
zip_file.extractall(microagent_folder)
|
||||
|
||||
zip_path.unlink()
|
||||
repo_agents, knowledge_agents = load_microagents_from_dir(microagent_folder)
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Loaded {len(repo_agents)} repo agents and {len(knowledge_agents)} knowledge agents from {source_description}',
|
||||
)
|
||||
|
||||
loaded_microagents.extend(repo_agents.values())
|
||||
loaded_microagents.extend(knowledge_agents.values())
|
||||
finally:
|
||||
shutil.rmtree(microagent_folder)
|
||||
|
||||
return loaded_microagents
|
||||
|
||||
def _get_authenticated_git_url(self, repo_path: str) -> str:
|
||||
"""Get an authenticated git URL for a repository.
|
||||
|
||||
Args:
|
||||
repo_path: Repository path (e.g., "github.com/acme-co/api")
|
||||
|
||||
Returns:
|
||||
Authenticated git URL if credentials are available, otherwise regular HTTPS URL
|
||||
"""
|
||||
remote_url = f'https://{repo_path}.git'
|
||||
|
||||
# Determine provider from repo path
|
||||
provider = None
|
||||
if 'github.com' in repo_path:
|
||||
provider = ProviderType.GITHUB
|
||||
elif 'gitlab.com' in repo_path:
|
||||
provider = ProviderType.GITLAB
|
||||
|
||||
# Add authentication if available
|
||||
if (
|
||||
provider
|
||||
and self.git_provider_tokens
|
||||
and provider in self.git_provider_tokens
|
||||
):
|
||||
git_token = self.git_provider_tokens[provider].token
|
||||
if git_token:
|
||||
if provider == ProviderType.GITLAB:
|
||||
remote_url = f'https://oauth2:{git_token.get_secret_value()}@{repo_path.replace("gitlab.com/", "")}.git'
|
||||
else:
|
||||
remote_url = f'https://{git_token.get_secret_value()}@{repo_path.replace("github.com/", "")}.git'
|
||||
|
||||
return remote_url
|
||||
|
||||
def get_microagents_from_org_or_user(
|
||||
self, selected_repository: str
|
||||
) -> list[BaseMicroagent]:
|
||||
"""Load microagents from the organization or user level .openhands repository.
|
||||
|
||||
For example, if the repository is github.com/acme-co/api, this will check if
|
||||
github.com/acme-co/.openhands exists. If it does, it will clone it and load
|
||||
the microagents from the ./microagents/ folder.
|
||||
|
||||
Args:
|
||||
selected_repository: The repository path (e.g., "github.com/acme-co/api")
|
||||
|
||||
Returns:
|
||||
A list of loaded microagents from the org/user level repository
|
||||
"""
|
||||
loaded_microagents: list[BaseMicroagent] = []
|
||||
workspace_root = Path(self.config.workspace_mount_path_in_sandbox)
|
||||
|
||||
repo_parts = selected_repository.split('/')
|
||||
if len(repo_parts) < 2:
|
||||
return loaded_microagents
|
||||
|
||||
# Extract the domain and org/user name
|
||||
domain = repo_parts[0] if len(repo_parts) > 2 else 'github.com'
|
||||
org_name = repo_parts[-2]
|
||||
|
||||
# Construct the org-level .openhands repo path
|
||||
org_openhands_repo = f'{domain}/{org_name}/.openhands'
|
||||
if domain not in org_openhands_repo:
|
||||
org_openhands_repo = f'github.com/{org_openhands_repo}'
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Checking for org-level microagents at {org_openhands_repo}',
|
||||
)
|
||||
|
||||
# Try to clone the org-level .openhands repo
|
||||
try:
|
||||
# Create a temporary directory for the org-level repo
|
||||
org_repo_dir = workspace_root / f'org_openhands_{org_name}'
|
||||
|
||||
# Get authenticated URL and do a shallow clone (--depth 1) for efficiency
|
||||
remote_url = self._get_authenticated_git_url(org_openhands_repo)
|
||||
clone_cmd = f"git clone --depth 1 {remote_url} {org_repo_dir} 2>/dev/null || echo 'Org repo not found'"
|
||||
|
||||
action = CmdRunAction(command=clone_cmd)
|
||||
obs = self.run_action(action)
|
||||
|
||||
if (
|
||||
isinstance(obs, CmdOutputObservation)
|
||||
and obs.exit_code == 0
|
||||
and 'Org repo not found' not in obs.content
|
||||
):
|
||||
self.log(
|
||||
'info',
|
||||
f'Successfully cloned org-level microagents from {org_openhands_repo}',
|
||||
)
|
||||
|
||||
# Load microagents from the org-level repo
|
||||
org_microagents_dir = org_repo_dir / 'microagents'
|
||||
loaded_microagents = self._load_microagents_from_directory(
|
||||
org_microagents_dir, 'org-level'
|
||||
)
|
||||
|
||||
# Clean up the org repo directory
|
||||
shutil.rmtree(org_repo_dir)
|
||||
else:
|
||||
self.log(
|
||||
'info',
|
||||
f'No org-level microagents found at {org_openhands_repo}',
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.log('error', f'Error loading org-level microagents: {str(e)}')
|
||||
|
||||
return loaded_microagents
|
||||
|
||||
def get_microagents_from_selected_repo(
|
||||
self, selected_repository: str | None
|
||||
) -> list[BaseMicroagent]:
|
||||
"""Load microagents from the selected repository.
|
||||
If selected_repository is None, load microagents from the current workspace.
|
||||
This is the main entry point for loading microagents.
|
||||
|
||||
This method also checks for user/org level microagents stored in a .openhands repository.
|
||||
For example, if the repository is github.com/acme-co/api, it will also check for
|
||||
github.com/acme-co/.openhands and load microagents from there if it exists.
|
||||
"""
|
||||
|
||||
loaded_microagents: list[BaseMicroagent] = []
|
||||
workspace_root = Path(self.config.workspace_mount_path_in_sandbox)
|
||||
microagents_dir = workspace_root / '.openhands' / 'microagents'
|
||||
repo_root = None
|
||||
|
||||
# Check for user/org level microagents if a repository is selected
|
||||
if selected_repository:
|
||||
# Load microagents from the org/user level repository
|
||||
org_microagents = self.get_microagents_from_org_or_user(selected_repository)
|
||||
loaded_microagents.extend(org_microagents)
|
||||
|
||||
# Continue with repository-specific microagents
|
||||
repo_root = workspace_root / selected_repository.split('/')[-1]
|
||||
microagents_dir = repo_root / '.openhands' / 'microagents'
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Selected repo: {selected_repository}, loading microagents from {microagents_dir} (inside runtime)',
|
||||
@@ -470,35 +728,10 @@ class Runtime(FileEditRuntimeMixin):
|
||||
)
|
||||
|
||||
# Load microagents from directory
|
||||
files = self.list_files(str(microagents_dir))
|
||||
if files:
|
||||
self.log('info', f'Found {len(files)} files in microagents directory.')
|
||||
zip_path = self.copy_from(str(microagents_dir))
|
||||
microagent_folder = tempfile.mkdtemp()
|
||||
|
||||
# Properly handle the zip file
|
||||
with ZipFile(zip_path, 'r') as zip_file:
|
||||
zip_file.extractall(microagent_folder)
|
||||
|
||||
# Add debug print of directory structure
|
||||
self.log('debug', 'Microagent folder structure:')
|
||||
for root, _, files in os.walk(microagent_folder):
|
||||
relative_path = os.path.relpath(root, microagent_folder)
|
||||
self.log('debug', f'Directory: {relative_path}/')
|
||||
for file in files:
|
||||
self.log('debug', f' File: {os.path.join(relative_path, file)}')
|
||||
|
||||
# Clean up the temporary zip file
|
||||
zip_path.unlink()
|
||||
# Load all microagents using the existing function
|
||||
repo_agents, knowledge_agents = load_microagents_from_dir(microagent_folder)
|
||||
self.log(
|
||||
'info',
|
||||
f'Loaded {len(repo_agents)} repo agents and {len(knowledge_agents)} knowledge agents',
|
||||
)
|
||||
loaded_microagents.extend(repo_agents.values())
|
||||
loaded_microagents.extend(knowledge_agents.values())
|
||||
shutil.rmtree(microagent_folder)
|
||||
repo_microagents = self._load_microagents_from_directory(
|
||||
microagents_dir, 'repository'
|
||||
)
|
||||
loaded_microagents.extend(repo_microagents)
|
||||
|
||||
return loaded_microagents
|
||||
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
import base64
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def image_to_png_base64_url(
|
||||
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||
) -> str:
|
||||
"""Convert a numpy array to a base64 encoded png image url."""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
image = image.convert('RGB')
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format='PNG')
|
||||
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
return (
|
||||
f'data:image/png;base64,{image_base64}'
|
||||
if add_data_prefix
|
||||
else f'{image_base64}'
|
||||
)
|
||||
|
||||
|
||||
def png_base64_url_to_image(png_base64_url: str) -> Image.Image:
|
||||
"""Convert a base64 encoded png image url to a PIL Image."""
|
||||
splited = png_base64_url.split(',')
|
||||
if len(splited) == 2:
|
||||
base64_data = splited[1]
|
||||
else:
|
||||
base64_data = png_base64_url
|
||||
return Image.open(io.BytesIO(base64.b64decode(base64_data)))
|
||||
@@ -1,6 +1,4 @@
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import multiprocessing
|
||||
import time
|
||||
@@ -9,13 +7,12 @@ import uuid
|
||||
import browsergym.core # noqa F401 (we register the openended task as a gym environment)
|
||||
import gymnasium as gym
|
||||
import html2text
|
||||
import numpy as np
|
||||
import tenacity
|
||||
from browsergym.utils.obs import flatten_dom_to_str, overlay_som
|
||||
from PIL import Image
|
||||
|
||||
from openhands.core.exceptions import BrowserInitException
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.runtime.browser.base64 import image_to_png_base64_url
|
||||
from openhands.utils.shutdown_listener import should_continue, should_exit
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
@@ -40,7 +37,7 @@ class BrowserEnv:
|
||||
self.init_browser()
|
||||
atexit.register(self.close)
|
||||
|
||||
def get_html_text_converter(self):
|
||||
def get_html_text_converter(self) -> html2text.HTML2Text:
|
||||
html_text_converter = html2text.HTML2Text()
|
||||
# ignore links and images
|
||||
html_text_converter.ignore_links = False
|
||||
@@ -56,7 +53,7 @@ class BrowserEnv:
|
||||
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
|
||||
retry=tenacity.retry_if_exception_type(BrowserInitException),
|
||||
)
|
||||
def init_browser(self):
|
||||
def init_browser(self) -> None:
|
||||
logger.debug('Starting browser env...')
|
||||
try:
|
||||
self.process = multiprocessing.Process(target=self.browser_process)
|
||||
@@ -69,7 +66,7 @@ class BrowserEnv:
|
||||
self.close()
|
||||
raise BrowserInitException('Failed to start browser environment.')
|
||||
|
||||
def browser_process(self):
|
||||
def browser_process(self) -> None:
|
||||
if self.eval_mode:
|
||||
assert self.browsergym_eval_env is not None
|
||||
logger.info('Initializing browser env for web browsing evaluation.')
|
||||
@@ -165,13 +162,13 @@ class BrowserEnv:
|
||||
html_str = flatten_dom_to_str(obs['dom_object'])
|
||||
obs['text_content'] = self.html_text_converter.handle(html_str)
|
||||
# make observation serializable
|
||||
obs['set_of_marks'] = self.image_to_png_base64_url(
|
||||
obs['set_of_marks'] = image_to_png_base64_url(
|
||||
overlay_som(
|
||||
obs['screenshot'], obs.get('extra_element_properties', {})
|
||||
),
|
||||
add_data_prefix=True,
|
||||
)
|
||||
obs['screenshot'] = self.image_to_png_base64_url(
|
||||
obs['screenshot'] = image_to_png_base64_url(
|
||||
obs['screenshot'], add_data_prefix=True
|
||||
)
|
||||
obs['active_page_index'] = obs['active_page_index'].item()
|
||||
@@ -196,17 +193,18 @@ class BrowserEnv:
|
||||
if self.agent_side.poll(timeout=0.01):
|
||||
response_id, obs = self.agent_side.recv()
|
||||
if response_id == unique_request_id:
|
||||
return obs
|
||||
return dict(obs)
|
||||
|
||||
def check_alive(self, timeout: float = 60):
|
||||
def check_alive(self, timeout: float = 60) -> bool:
|
||||
self.agent_side.send(('IS_ALIVE', None))
|
||||
if self.agent_side.poll(timeout=timeout):
|
||||
response_id, _ = self.agent_side.recv()
|
||||
if response_id == 'ALIVE':
|
||||
return True
|
||||
logger.debug(f'Browser env is not alive. Response ID: {response_id}')
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if not self.process.is_alive():
|
||||
return
|
||||
try:
|
||||
@@ -225,41 +223,3 @@ class BrowserEnv:
|
||||
self.browser_side.close()
|
||||
except Exception as e:
|
||||
logger.error(f'Encountered an error when closing browser env: {e}')
|
||||
|
||||
@staticmethod
|
||||
def image_to_png_base64_url(
|
||||
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||
):
|
||||
"""Convert a numpy array to a base64 encoded png image url."""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
image = image.convert('RGB')
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format='PNG')
|
||||
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
return (
|
||||
f'data:image/png;base64,{image_base64}'
|
||||
if add_data_prefix
|
||||
else f'{image_base64}'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def image_to_jpg_base64_url(
|
||||
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||
):
|
||||
"""Convert a numpy array to a base64 encoded jpeg image url."""
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
image = image.convert('RGB')
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format='JPEG')
|
||||
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||
return (
|
||||
f'data:image/jpeg;base64,{image_base64}'
|
||||
if add_data_prefix
|
||||
else f'{image_base64}'
|
||||
)
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from openhands.core.exceptions import BrowserUnavailableException
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.observation import BrowserOutputObservation
|
||||
from openhands.runtime.browser.base64 import png_base64_url_to_image
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
async def browse(
|
||||
action: BrowseURLAction | BrowseInteractiveAction, browser: BrowserEnv | None
|
||||
action: BrowseURLAction | BrowseInteractiveAction,
|
||||
browser: BrowserEnv | None,
|
||||
workspace_dir: str | None = None,
|
||||
) -> BrowserOutputObservation:
|
||||
if browser is None:
|
||||
raise BrowserUnavailableException()
|
||||
@@ -31,10 +39,50 @@ async def browse(
|
||||
try:
|
||||
# obs provided by BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/env.py#L396
|
||||
obs = await call_sync_from_async(browser.step, action_str)
|
||||
|
||||
# Save screenshot if workspace_dir is provided
|
||||
screenshot_path = None
|
||||
if workspace_dir is not None and obs.get('screenshot'):
|
||||
# Create screenshots directory if it doesn't exist
|
||||
screenshots_dir = Path(workspace_dir) / '.browser_screenshots'
|
||||
screenshots_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Generate a filename based on timestamp
|
||||
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f')
|
||||
screenshot_filename = f'screenshot_{timestamp}.png'
|
||||
screenshot_path = str(screenshots_dir / screenshot_filename)
|
||||
|
||||
# Direct image saving from base64 data without using PIL's Image.open
|
||||
# This approach bypasses potential encoding issues that might occur when
|
||||
# converting between different image representations, ensuring the raw PNG
|
||||
# data from the browser is saved directly to disk.
|
||||
|
||||
# Extract the base64 data
|
||||
base64_data = obs.get('screenshot', '')
|
||||
if ',' in base64_data:
|
||||
base64_data = base64_data.split(',')[1]
|
||||
|
||||
try:
|
||||
# Decode base64 directly to binary
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# Write binary data directly to file
|
||||
with open(screenshot_path, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
# Verify the image was saved correctly by opening it
|
||||
# This is just a verification step and can be removed in production
|
||||
Image.open(screenshot_path).verify()
|
||||
except Exception:
|
||||
# If direct saving fails, fall back to the original method
|
||||
image = png_base64_url_to_image(obs.get('screenshot'))
|
||||
image.save(screenshot_path, format='PNG', optimize=True)
|
||||
|
||||
return BrowserOutputObservation(
|
||||
content=obs['text_content'], # text content of the page
|
||||
url=obs.get('url', ''), # URL of the page
|
||||
screenshot=obs.get('screenshot', None), # base64-encoded screenshot, png
|
||||
screenshot_path=screenshot_path, # path to saved screenshot file
|
||||
set_of_marks=obs.get(
|
||||
'set_of_marks', None
|
||||
), # base64-encoded Set-of-Marks annotated screenshot, png,
|
||||
@@ -60,6 +108,7 @@ async def browse(
|
||||
return BrowserOutputObservation(
|
||||
content=str(e),
|
||||
screenshot='',
|
||||
screenshot_path=None,
|
||||
error=True,
|
||||
last_browser_action_error=str(e),
|
||||
url=asked_url if action.action == ActionType.BROWSE else '',
|
||||
|
||||
@@ -36,7 +36,7 @@ class DockerRuntimeBuilder(RuntimeBuilder):
|
||||
self.rolling_logger = RollingLogger(max_lines=10)
|
||||
|
||||
@staticmethod
|
||||
def check_buildx(is_podman: bool = False):
|
||||
def check_buildx(is_podman: bool = False) -> bool:
|
||||
"""Check if Docker Buildx is available"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
|
||||
@@ -99,8 +99,8 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
logger.info(f'Build status: {status}')
|
||||
|
||||
if status == 'SUCCESS':
|
||||
logger.debug(f"Successfully built {status_data['image']}")
|
||||
return status_data['image']
|
||||
logger.debug(f'Successfully built {status_data["image"]}')
|
||||
return str(status_data['image'])
|
||||
elif status in [
|
||||
'FAILURE',
|
||||
'INTERNAL_ERROR',
|
||||
@@ -139,11 +139,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
|
||||
if result['exists']:
|
||||
logger.debug(
|
||||
f"Image {image_name} exists. "
|
||||
f"Uploaded at: {result['image']['upload_time']}, "
|
||||
f"Size: {result['image']['image_size_bytes'] / 1024 / 1024:.2f} MB"
|
||||
f'Image {image_name} exists. '
|
||||
f'Uploaded at: {result["image"]["upload_time"]}, '
|
||||
f'Size: {result["image"]["image_size_bytes"] / 1024 / 1024:.2f} MB'
|
||||
)
|
||||
else:
|
||||
logger.debug(f'Image {image_name} does not exist.')
|
||||
|
||||
return result['exists']
|
||||
return bool(result['exists'])
|
||||
|
||||
@@ -5,7 +5,6 @@ This server has no authentication and only listens to localhost traffic.
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
@@ -22,12 +21,12 @@ def create_app() -> FastAPI:
|
||||
)
|
||||
|
||||
@app.get('/')
|
||||
async def root():
|
||||
async def root() -> dict[str, str]:
|
||||
"""Root endpoint to check if the server is running."""
|
||||
return {'status': 'File viewer server is running'}
|
||||
|
||||
@app.get('/view')
|
||||
async def view_file(path: str, request: Request):
|
||||
async def view_file(path: str, request: Request) -> HTMLResponse:
|
||||
"""View a file using an embedded viewer.
|
||||
|
||||
Args:
|
||||
@@ -75,7 +74,7 @@ def create_app() -> FastAPI:
|
||||
return app
|
||||
|
||||
|
||||
def start_file_viewer_server(port: int) -> Tuple[str, threading.Thread]:
|
||||
def start_file_viewer_server(port: int) -> tuple[str, threading.Thread]:
|
||||
"""Start the file viewer server on the specified port or find an available one.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -158,7 +158,6 @@ class ActionExecutionClient(Runtime):
|
||||
|
||||
def copy_from(self, path: str) -> Path:
|
||||
"""Zip all files in the sandbox and return as a stream of bytes."""
|
||||
|
||||
try:
|
||||
params = {'path': path}
|
||||
with self.session.stream(
|
||||
@@ -183,25 +182,44 @@ class ActionExecutionClient(Runtime):
|
||||
if not os.path.exists(host_src):
|
||||
raise FileNotFoundError(f'Source file {host_src} does not exist')
|
||||
|
||||
temp_zip_path: str | None = None # Define temp_zip_path outside the try block
|
||||
|
||||
try:
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
file_to_upload = None
|
||||
upload_data = {}
|
||||
|
||||
if recursive:
|
||||
# Create and write the zip file inside the try block
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix='.zip', delete=False
|
||||
) as temp_zip:
|
||||
temp_zip_path = temp_zip.name
|
||||
|
||||
with ZipFile(temp_zip_path, 'w') as zipf:
|
||||
for root, _, files in os.walk(host_src):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
arcname = os.path.relpath(
|
||||
file_path, os.path.dirname(host_src)
|
||||
)
|
||||
zipf.write(file_path, arcname)
|
||||
try:
|
||||
with ZipFile(temp_zip_path, 'w') as zipf:
|
||||
for root, _, files in os.walk(host_src):
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
arcname = os.path.relpath(
|
||||
file_path, os.path.dirname(host_src)
|
||||
)
|
||||
zipf.write(file_path, arcname)
|
||||
|
||||
upload_data = {'file': open(temp_zip_path, 'rb')}
|
||||
self.log(
|
||||
'debug',
|
||||
f'Opening temporary zip file for upload: {temp_zip_path}',
|
||||
)
|
||||
file_to_upload = open(temp_zip_path, 'rb')
|
||||
upload_data = {'file': file_to_upload}
|
||||
except Exception as e:
|
||||
# Ensure temp file is cleaned up if zipping fails
|
||||
if temp_zip_path and os.path.exists(temp_zip_path):
|
||||
os.unlink(temp_zip_path)
|
||||
raise e # Re-raise the exception after cleanup attempt
|
||||
else:
|
||||
upload_data = {'file': open(host_src, 'rb')}
|
||||
file_to_upload = open(host_src, 'rb')
|
||||
upload_data = {'file': file_to_upload}
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
@@ -217,11 +235,18 @@ class ActionExecutionClient(Runtime):
|
||||
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
|
||||
)
|
||||
finally:
|
||||
if recursive:
|
||||
os.unlink(temp_zip_path)
|
||||
self.log(
|
||||
'debug', f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}'
|
||||
)
|
||||
if file_to_upload:
|
||||
file_to_upload.close()
|
||||
|
||||
# Cleanup the temporary zip file if it was created
|
||||
if temp_zip_path and os.path.exists(temp_zip_path):
|
||||
try:
|
||||
os.unlink(temp_zip_path)
|
||||
except Exception as e:
|
||||
self.log(
|
||||
'error',
|
||||
f'Failed to delete temporary zip file {temp_zip_path}: {e}',
|
||||
)
|
||||
|
||||
def get_vscode_token(self) -> str:
|
||||
if self.vscode_enabled and self.runtime_initialized:
|
||||
@@ -334,26 +359,34 @@ class ActionExecutionClient(Runtime):
|
||||
server.model_dump(mode='json')
|
||||
for server in updated_mcp_config.stdio_servers
|
||||
]
|
||||
self.log('debug', f'Updating MCP server to: {stdio_tools}')
|
||||
response = self._send_action_server_request(
|
||||
'POST',
|
||||
f'{self.action_execution_server_url}/update_mcp_server',
|
||||
json=stdio_tools,
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f'Failed to update MCP server: {response.text}')
|
||||
|
||||
# No API key by default. Child runtime can override this when appropriate
|
||||
updated_mcp_config.sse_servers.append(
|
||||
MCPSSEServerConfig(
|
||||
url=self.action_execution_server_url.rstrip('/') + '/sse', api_key=None
|
||||
if len(stdio_tools) > 0:
|
||||
self.log('debug', f'Updating MCP server to: {stdio_tools}')
|
||||
response = self._send_action_server_request(
|
||||
'POST',
|
||||
f'{self.action_execution_server_url}/update_mcp_server',
|
||||
json=stdio_tools,
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.log('warning', f'Failed to update MCP server: {response.text}')
|
||||
|
||||
# No API key by default. Child runtime can override this when appropriate
|
||||
updated_mcp_config.sse_servers.append(
|
||||
MCPSSEServerConfig(
|
||||
url=self.action_execution_server_url.rstrip('/') + '/sse',
|
||||
api_key=None,
|
||||
)
|
||||
)
|
||||
self.log(
|
||||
'info',
|
||||
f'Updated MCP config: {updated_mcp_config.sse_servers}',
|
||||
)
|
||||
else:
|
||||
self.log(
|
||||
'debug',
|
||||
'MCP servers inside runtime is not updated since no stdio servers are provided',
|
||||
)
|
||||
)
|
||||
self.log(
|
||||
'debug',
|
||||
f'Updated MCP config by adding runtime as another server: {updated_mcp_config}',
|
||||
)
|
||||
return updated_mcp_config
|
||||
|
||||
async def call_tool_mcp(self, action: MCPAction) -> Observation:
|
||||
|
||||
@@ -12,18 +12,32 @@
|
||||
|
||||
### Step 2: Set Your API Key as an Environment Variable
|
||||
Run the following command in your terminal, replacing `<your-api-key>` with the actual key you copied:
|
||||
|
||||
Mac/Linux:
|
||||
```bash
|
||||
export DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
Windows PowerShell:
|
||||
```powershell
|
||||
$env:DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
This step ensures that OpenHands can authenticate with the Daytona platform when it runs.
|
||||
|
||||
### Step 3: Run OpenHands Locally Using Docker
|
||||
To start the latest version of OpenHands on your machine, execute the following command in your terminal:
|
||||
|
||||
Mac/Linux:
|
||||
```bash
|
||||
bash -i <(curl -sL https://get.daytona.io/openhands)
|
||||
```
|
||||
|
||||
Windows:
|
||||
```powershell
|
||||
powershell -Command "irm https://get.daytona.io/openhands-windows | iex"
|
||||
```
|
||||
|
||||
#### What This Command Does:
|
||||
- Downloads the latest OpenHands release script.
|
||||
- Runs the script in an interactive Bash session.
|
||||
@@ -36,10 +50,16 @@ Once executed, OpenHands should be running locally and ready for use.
|
||||
### Step 1: Set the `OPENHANDS_VERSION` Environment Variable
|
||||
Run the following command in your terminal, replacing `<openhands-release>` with the latest release's version seen in the [main README.md file](https://github.com/All-Hands-AI/OpenHands?tab=readme-ov-file#-quick-start):
|
||||
|
||||
#### Mac/Linux:
|
||||
```bash
|
||||
export OPENHANDS_VERSION="<openhands-release>" # e.g. 0.27
|
||||
```
|
||||
|
||||
#### Windows PowerShell:
|
||||
```powershell
|
||||
$env:OPENHANDS_VERSION="<openhands-release>" # e.g. 0.27
|
||||
```
|
||||
|
||||
### Step 2: Retrieve Your Daytona API Key
|
||||
1. Visit the [Daytona Dashboard](https://app.daytona.io/dashboard/keys).
|
||||
2. Click **"Create Key"**.
|
||||
@@ -48,13 +68,21 @@ export OPENHANDS_VERSION="<openhands-release>" # e.g. 0.27
|
||||
|
||||
### Step 3: Set Your API Key as an Environment Variable:
|
||||
Run the following command in your terminal, replacing `<your-api-key>` with the actual key you copied:
|
||||
|
||||
#### Mac/Linux:
|
||||
```bash
|
||||
export DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
#### Windows PowerShell:
|
||||
```powershell
|
||||
$env:DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
### Step 4: Run the following `docker` command:
|
||||
This command pulls and runs the OpenHands container using Docker. Once executed, OpenHands should be running locally and ready for use.
|
||||
|
||||
#### Mac/Linux:
|
||||
```bash
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:${OPENHANDS_VERSION}-nikolaik \
|
||||
@@ -67,16 +95,36 @@ docker run -it --rm --pull=always \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:${OPENHANDS_VERSION}
|
||||
```
|
||||
|
||||
#### Windows:
|
||||
```powershell
|
||||
docker run -it --rm --pull=always `
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:${env:OPENHANDS_VERSION}-nikolaik `
|
||||
-e LOG_ALL_EVENTS=true `
|
||||
-e RUNTIME=daytona `
|
||||
-e DAYTONA_API_KEY=${env:DAYTONA_API_KEY} `
|
||||
-v ~/.openhands-state:/.openhands-state `
|
||||
-p 3000:3000 `
|
||||
--name openhands-app `
|
||||
docker.all-hands.dev/all-hands-ai/openhands:${env:OPENHANDS_VERSION}
|
||||
```
|
||||
|
||||
> **Tip:** If you don't want your sandboxes to default to the EU region, you can set the `DAYTONA_TARGET` environment variable to `us`
|
||||
|
||||
### Running OpenHands Locally Without Docker
|
||||
|
||||
Alternatively, if you want to run the OpenHands app on your local machine using `make run` without Docker, make sure to set the following environment variables first:
|
||||
|
||||
#### Mac/Linux:
|
||||
```bash
|
||||
export RUNTIME="daytona"
|
||||
export DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
#### Windows PowerShell:
|
||||
```powershell
|
||||
$env:RUNTIME="daytona"
|
||||
$env:DAYTONA_API_KEY="<your-api-key>"
|
||||
```
|
||||
|
||||
## Documentation
|
||||
Read more by visiting our [documentation](https://www.daytona.io/docs/) page.
|
||||
|
||||
@@ -115,12 +115,12 @@ class DaytonaRuntime(ActionExecutionClient):
|
||||
|
||||
def _construct_api_url(self, port: int) -> str:
|
||||
assert self.workspace is not None, 'Workspace is not initialized'
|
||||
assert (
|
||||
self.workspace.instance.info is not None
|
||||
), 'Workspace info is not available'
|
||||
assert (
|
||||
self.workspace.instance.info.provider_metadata is not None
|
||||
), 'Provider metadata is not available'
|
||||
assert self.workspace.instance.info is not None, (
|
||||
'Workspace info is not available'
|
||||
)
|
||||
assert self.workspace.instance.info.provider_metadata is not None, (
|
||||
'Provider metadata is not available'
|
||||
)
|
||||
|
||||
node_domain = json.loads(self.workspace.instance.info.provider_metadata)[
|
||||
'nodeDomain'
|
||||
|
||||
@@ -47,6 +47,7 @@ def _is_retryable_wait_until_alive_error(exception):
|
||||
exception,
|
||||
(
|
||||
ConnectionError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.NetworkError,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.HTTPStatusError,
|
||||
@@ -207,6 +208,54 @@ class DockerRuntime(ActionExecutionClient):
|
||||
)
|
||||
raise ex
|
||||
|
||||
def _process_volumes(self) -> dict[str, dict[str, str]]:
|
||||
"""Process volume mounts based on configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping host paths to container bind mounts with their modes.
|
||||
"""
|
||||
# Initialize volumes dictionary
|
||||
volumes: dict[str, dict[str, str]] = {}
|
||||
|
||||
# Process volumes (comma-delimited)
|
||||
if self.config.sandbox.volumes is not None:
|
||||
# Handle multiple mounts with comma delimiter
|
||||
mounts = self.config.sandbox.volumes.split(',')
|
||||
|
||||
for mount in mounts:
|
||||
parts = mount.split(':')
|
||||
if len(parts) >= 2:
|
||||
host_path = os.path.abspath(parts[0])
|
||||
container_path = parts[1]
|
||||
# Default mode is 'rw' if not specified
|
||||
mount_mode = parts[2] if len(parts) > 2 else 'rw'
|
||||
|
||||
volumes[host_path] = {
|
||||
'bind': container_path,
|
||||
'mode': mount_mode,
|
||||
}
|
||||
logger.debug(
|
||||
f'Mount dir (sandbox.volumes): {host_path} to {container_path} with mode: {mount_mode}'
|
||||
)
|
||||
|
||||
# Legacy mounting with workspace_* parameters
|
||||
elif (
|
||||
self.config.workspace_mount_path is not None
|
||||
and self.config.workspace_mount_path_in_sandbox is not None
|
||||
):
|
||||
mount_mode = 'rw' # Default mode
|
||||
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes[self.config.workspace_mount_path] = {
|
||||
'bind': self.config.workspace_mount_path_in_sandbox,
|
||||
'mode': mount_mode,
|
||||
}
|
||||
logger.debug(
|
||||
f'Mount dir (legacy): {self.config.workspace_mount_path} with mode: {mount_mode}'
|
||||
)
|
||||
|
||||
return volumes
|
||||
|
||||
def _init_container(self):
|
||||
self.log('debug', 'Preparing to start container...')
|
||||
self.send_status_message('STATUS$PREPARING_CONTAINER')
|
||||
@@ -272,23 +321,16 @@ class DockerRuntime(ActionExecutionClient):
|
||||
environment.update(self.config.sandbox.runtime_startup_env_vars)
|
||||
|
||||
self.log('debug', f'Workspace Base: {self.config.workspace_base}')
|
||||
if (
|
||||
self.config.workspace_mount_path is not None
|
||||
and self.config.workspace_mount_path_in_sandbox is not None
|
||||
):
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes = {
|
||||
self.config.workspace_mount_path: {
|
||||
'bind': self.config.workspace_mount_path_in_sandbox,
|
||||
'mode': 'rw',
|
||||
}
|
||||
}
|
||||
logger.debug(f'Mount dir: {self.config.workspace_mount_path}')
|
||||
else:
|
||||
|
||||
# Process volumes for mounting
|
||||
volumes = self._process_volumes()
|
||||
|
||||
# If no volumes were configured, set to None
|
||||
if not volumes:
|
||||
logger.debug(
|
||||
'Mount dir is not set, will not mount the workspace directory to the container'
|
||||
)
|
||||
volumes = None
|
||||
volumes = {} # Empty dict instead of None to satisfy mypy
|
||||
self.log(
|
||||
'debug',
|
||||
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
|
||||
@@ -447,8 +489,9 @@ class DockerRuntime(ActionExecutionClient):
|
||||
def web_hosts(self):
|
||||
hosts: dict[str, int] = {}
|
||||
|
||||
host_addr = os.environ.get('DOCKER_HOST_ADDR', 'localhost')
|
||||
for port in self._app_ports:
|
||||
hosts[f'http://localhost:{port}'] = port
|
||||
hosts[f'http://{host_addr}:{port}'] = port
|
||||
|
||||
return hosts
|
||||
|
||||
|
||||
@@ -40,9 +40,9 @@ class E2BBox:
|
||||
|
||||
def _archive(self, host_src: str, recursive: bool = False):
|
||||
if recursive:
|
||||
assert os.path.isdir(
|
||||
host_src
|
||||
), 'Source must be a directory when recursive is True'
|
||||
assert os.path.isdir(host_src), (
|
||||
'Source must be a directory when recursive is True'
|
||||
)
|
||||
files = glob(host_src + '/**/*', recursive=True)
|
||||
srcname = os.path.basename(host_src)
|
||||
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
|
||||
@@ -52,9 +52,9 @@ class E2BBox:
|
||||
file, arcname=os.path.relpath(file, os.path.dirname(host_src))
|
||||
)
|
||||
else:
|
||||
assert os.path.isfile(
|
||||
host_src
|
||||
), 'Source must be a file when recursive is False'
|
||||
assert os.path.isfile(host_src), (
|
||||
'Source must be a file when recursive is False'
|
||||
)
|
||||
srcname = os.path.basename(host_src)
|
||||
tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar')
|
||||
with tarfile.open(tar_filename, mode='w') as tar:
|
||||
|
||||
@@ -41,6 +41,18 @@ from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
def get_user_info():
|
||||
"""Get user ID and username in a cross-platform way."""
|
||||
username = os.getenv('USER')
|
||||
if sys.platform == 'win32':
|
||||
# On Windows, we don't use user IDs the same way
|
||||
# Return a default value that won't cause issues
|
||||
return 1000, username
|
||||
else:
|
||||
# On Unix systems, use os.getuid()
|
||||
return os.getuid(), username
|
||||
|
||||
|
||||
def check_dependencies(code_repo_path: str, poetry_venvs_path: str):
|
||||
ERROR_MESSAGE = 'Please follow the instructions in https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md to install OpenHands.'
|
||||
if not os.path.exists(code_repo_path):
|
||||
@@ -63,28 +75,33 @@ def check_dependencies(code_repo_path: str, poetry_venvs_path: str):
|
||||
if 'jupyter' not in output.lower():
|
||||
raise ValueError('Jupyter is not properly installed. ' + ERROR_MESSAGE)
|
||||
|
||||
# Check libtmux is installed
|
||||
logger.debug('Checking dependencies: libtmux')
|
||||
import libtmux
|
||||
# Check libtmux is installed (skip on Windows)
|
||||
|
||||
server = libtmux.Server()
|
||||
try:
|
||||
session = server.new_session(session_name='test-session')
|
||||
except Exception:
|
||||
raise ValueError('tmux is not properly installed or available on the path.')
|
||||
pane = session.attached_pane
|
||||
pane.send_keys('echo "test"')
|
||||
pane_output = '\n'.join(pane.cmd('capture-pane', '-p').stdout)
|
||||
session.kill_session()
|
||||
if 'test' not in pane_output:
|
||||
raise ValueError('libtmux is not properly installed. ' + ERROR_MESSAGE)
|
||||
if sys.platform != 'win32':
|
||||
logger.debug('Checking dependencies: libtmux')
|
||||
import libtmux
|
||||
|
||||
# Check browser works
|
||||
logger.debug('Checking dependencies: browser')
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
server = libtmux.Server()
|
||||
try:
|
||||
session = server.new_session(session_name='test-session')
|
||||
except Exception:
|
||||
raise ValueError('tmux is not properly installed or available on the path.')
|
||||
pane = session.attached_pane
|
||||
pane.send_keys('echo "test"')
|
||||
pane_output = '\n'.join(pane.cmd('capture-pane', '-p').stdout)
|
||||
session.kill_session()
|
||||
if 'test' not in pane_output:
|
||||
raise ValueError('libtmux is not properly installed. ' + ERROR_MESSAGE)
|
||||
|
||||
browser = BrowserEnv()
|
||||
browser.close()
|
||||
# Skip browser environment check on Windows
|
||||
if sys.platform != 'win32':
|
||||
logger.debug('Checking dependencies: browser')
|
||||
from openhands.runtime.browser.browser_env import BrowserEnv
|
||||
|
||||
browser = BrowserEnv()
|
||||
browser.close()
|
||||
else:
|
||||
logger.warning('Running on Windows - browser environment check skipped.')
|
||||
|
||||
|
||||
class LocalRuntime(ActionExecutionClient):
|
||||
@@ -110,9 +127,15 @@ class LocalRuntime(ActionExecutionClient):
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
):
|
||||
self.is_windows = sys.platform == 'win32'
|
||||
if self.is_windows:
|
||||
logger.warning(
|
||||
'Running on Windows - some features that require tmux will be limited. '
|
||||
'For full functionality, please consider using WSL or Docker runtime.'
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self._user_id = os.getuid()
|
||||
self._username = os.getenv('USER')
|
||||
self._user_id, self._username = get_user_info()
|
||||
|
||||
if self.config.workspace_base is not None:
|
||||
logger.warning(
|
||||
@@ -161,6 +184,7 @@ class LocalRuntime(ActionExecutionClient):
|
||||
self.status_callback = status_callback
|
||||
self.server_process: subprocess.Popen[str] | None = None
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
self._log_thread_exit_event = threading.Event() # Add exit event
|
||||
|
||||
# Update env vars
|
||||
if self.config.sandbox.runtime_startup_env_vars:
|
||||
@@ -199,7 +223,7 @@ class LocalRuntime(ActionExecutionClient):
|
||||
server_port=self._host_port,
|
||||
plugins=self.plugins,
|
||||
app_config=self.config,
|
||||
python_prefix=[],
|
||||
python_prefix=['poetry', 'run'],
|
||||
override_user_id=self._user_id,
|
||||
override_username=self._username,
|
||||
)
|
||||
@@ -208,7 +232,7 @@ class LocalRuntime(ActionExecutionClient):
|
||||
env = os.environ.copy()
|
||||
# Get the code repo path
|
||||
code_repo_path = os.path.dirname(os.path.dirname(openhands.__file__))
|
||||
env['PYTHONPATH'] = f'{code_repo_path}{os.pathsep}{env.get("PYTHONPATH", "")}'
|
||||
env['PYTHONPATH'] = os.pathsep.join([code_repo_path, env.get('PYTHONPATH', '')])
|
||||
env['OPENHANDS_REPO_PATH'] = code_repo_path
|
||||
env['LOCAL_RUNTIME_MODE'] = '1'
|
||||
|
||||
@@ -230,19 +254,50 @@ class LocalRuntime(ActionExecutionClient):
|
||||
universal_newlines=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
cwd=code_repo_path, # Explicitly set the working directory
|
||||
)
|
||||
|
||||
# Start a thread to read and log server output
|
||||
def log_output():
|
||||
while (
|
||||
self.server_process
|
||||
and self.server_process.poll()
|
||||
and self.server_process.stdout
|
||||
):
|
||||
line = self.server_process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
self.log('debug', f'Server: {line.strip()}')
|
||||
if not self.server_process or not self.server_process.stdout:
|
||||
self.log('error', 'Server process or stdout not available for logging.')
|
||||
return
|
||||
|
||||
try:
|
||||
# Read lines while the process is running and stdout is available
|
||||
while self.server_process.poll() is None:
|
||||
if self._log_thread_exit_event.is_set(): # Check exit event
|
||||
self.log('info', 'Log thread received exit signal.')
|
||||
break # Exit loop if signaled
|
||||
line = self.server_process.stdout.readline()
|
||||
if not line:
|
||||
# Process might have exited between poll() and readline()
|
||||
break
|
||||
self.log('info', f'Server: {line.strip()}')
|
||||
|
||||
# Capture any remaining output after the process exits OR if signaled
|
||||
if (
|
||||
not self._log_thread_exit_event.is_set()
|
||||
): # Check again before reading remaining
|
||||
self.log('info', 'Server process exited, reading remaining output.')
|
||||
for line in self.server_process.stdout:
|
||||
if (
|
||||
self._log_thread_exit_event.is_set()
|
||||
): # Check inside loop too
|
||||
self.log(
|
||||
'info',
|
||||
'Log thread received exit signal while reading remaining output.',
|
||||
)
|
||||
break
|
||||
self.log('info', f'Server (remaining): {line.strip()}')
|
||||
|
||||
except Exception as e:
|
||||
# Log the error, but don't prevent the thread from potentially exiting
|
||||
self.log('error', f'Error reading server output: {e}')
|
||||
finally:
|
||||
self.log(
|
||||
'info', 'Log output thread finished.'
|
||||
) # Add log for thread exit
|
||||
|
||||
self._log_thread = threading.Thread(target=log_output, daemon=True)
|
||||
self._log_thread.start()
|
||||
@@ -312,6 +367,8 @@ class LocalRuntime(ActionExecutionClient):
|
||||
|
||||
def close(self):
|
||||
"""Stop the server process."""
|
||||
self._log_thread_exit_event.set() # Signal the log thread to exit
|
||||
|
||||
if self.server_process:
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
@@ -319,7 +376,7 @@ class LocalRuntime(ActionExecutionClient):
|
||||
except subprocess.TimeoutExpired:
|
||||
self.server_process.kill()
|
||||
self.server_process = None
|
||||
self._log_thread.join()
|
||||
self._log_thread.join(timeout=5) # Add timeout to join
|
||||
|
||||
if self._temp_workspace:
|
||||
shutil.rmtree(self._temp_workspace)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable
|
||||
from typing import Any, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import tenacity
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.exceptions import (
|
||||
@@ -37,6 +38,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
runtime_id: str | None = None
|
||||
runtime_url: str | None = None
|
||||
_runtime_initialized: bool = False
|
||||
runtime_builder: RemoteRuntimeBuilder
|
||||
container_image: str
|
||||
available_hosts: dict[str, int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,12 +49,12 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
status_callback: Callable[..., None] | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
@@ -94,10 +98,12 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
getattr(logger, level)(message, stacklevel=2)
|
||||
|
||||
@property
|
||||
def action_execution_server_url(self):
|
||||
def action_execution_server_url(self) -> str:
|
||||
if self.runtime_url is None:
|
||||
raise NotImplementedError('Runtime URL is not initialized')
|
||||
return self.runtime_url
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
try:
|
||||
await call_sync_from_async(self._start_or_attach_to_runtime)
|
||||
except Exception:
|
||||
@@ -107,7 +113,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
await call_sync_from_async(self.setup_initial_env)
|
||||
self._runtime_initialized = True
|
||||
|
||||
def _start_or_attach_to_runtime(self):
|
||||
def _start_or_attach_to_runtime(self) -> None:
|
||||
existing_runtime = self._check_existing_runtime()
|
||||
if existing_runtime:
|
||||
self.log('debug', f'Using existing runtime with ID: {self.runtime_id}')
|
||||
@@ -130,12 +136,12 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
)
|
||||
self.container_image = self.config.sandbox.runtime_container_image
|
||||
self._start_runtime()
|
||||
assert (
|
||||
self.runtime_id is not None
|
||||
), 'Runtime ID is not set. This should never happen.'
|
||||
assert (
|
||||
self.runtime_url is not None
|
||||
), 'Runtime URL is not set. This should never happen.'
|
||||
assert self.runtime_id is not None, (
|
||||
'Runtime ID is not set. This should never happen.'
|
||||
)
|
||||
assert self.runtime_url is not None, (
|
||||
'Runtime URL is not set. This should never happen.'
|
||||
)
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', 'Waiting for runtime to be alive...')
|
||||
@@ -179,7 +185,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
self.log('error', f'Invalid response from runtime API: {data}')
|
||||
return False
|
||||
|
||||
def _build_runtime(self):
|
||||
def _build_runtime(self) -> None:
|
||||
self.log('debug', f'Building RemoteRuntime config:\n{self.config}')
|
||||
response = self._send_runtime_api_request(
|
||||
'GET',
|
||||
@@ -223,18 +229,18 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
f'Container image {self.container_image} does not exist'
|
||||
)
|
||||
|
||||
def _start_runtime(self):
|
||||
def _start_runtime(self) -> None:
|
||||
# Prepare the request body for the /start endpoint
|
||||
command = get_action_execution_server_startup_command(
|
||||
server_port=self.port,
|
||||
plugins=self.plugins,
|
||||
app_config=self.config,
|
||||
)
|
||||
environment = {}
|
||||
environment: dict[str, str] = {}
|
||||
if self.config.debug or os.environ.get('DEBUG', 'false').lower() == 'true':
|
||||
environment['DEBUG'] = 'true'
|
||||
environment.update(self.config.sandbox.runtime_startup_env_vars)
|
||||
start_request = {
|
||||
start_request: dict[str, Any] = {
|
||||
'image': self.container_image,
|
||||
'command': command,
|
||||
'working_dir': '/openhands/code/',
|
||||
@@ -262,8 +268,10 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
self.log('error', f'Unable to start runtime: {str(e)}')
|
||||
raise AgentRuntimeUnavailableError() from e
|
||||
|
||||
def _resume_runtime(self):
|
||||
"""
|
||||
def _resume_runtime(self) -> None:
|
||||
"""Resume a stopped runtime.
|
||||
|
||||
Steps:
|
||||
1. Show status update that runtime is being started.
|
||||
2. Send the runtime API a /resume request
|
||||
3. Poll for the runtime to be ready
|
||||
@@ -279,7 +287,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
self.setup_initial_env()
|
||||
self.log('debug', 'Runtime resumed.')
|
||||
|
||||
def _parse_runtime_response(self, response: httpx.Response):
|
||||
def _parse_runtime_response(self, response: httpx.Response) -> None:
|
||||
start_response = response.json()
|
||||
self.runtime_id = start_response['runtime_id']
|
||||
self.runtime_url = start_response['url']
|
||||
@@ -310,7 +318,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
def web_hosts(self) -> dict[str, int]:
|
||||
return self.available_hosts
|
||||
|
||||
def _wait_until_alive(self):
|
||||
def _wait_until_alive(self) -> None:
|
||||
retry_decorator = tenacity.retry(
|
||||
stop=tenacity.stop_after_delay(
|
||||
self.config.sandbox.remote_runtime_init_timeout
|
||||
@@ -321,9 +329,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
retry=tenacity.retry_if_exception_type(AgentRuntimeNotReadyError),
|
||||
wait=tenacity.wait_fixed(2),
|
||||
)
|
||||
return retry_decorator(self._wait_until_alive_impl)()
|
||||
retry_decorator(self._wait_until_alive_impl)()
|
||||
|
||||
def _wait_until_alive_impl(self):
|
||||
def _wait_until_alive_impl(self) -> None:
|
||||
self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}')
|
||||
runtime_info_response = self._send_runtime_api_request(
|
||||
'GET',
|
||||
@@ -384,7 +392,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
)
|
||||
raise AgentRuntimeNotReadyError()
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if self.attach_to_existing:
|
||||
super().close()
|
||||
return
|
||||
@@ -417,7 +425,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
finally:
|
||||
super().close()
|
||||
|
||||
def _send_runtime_api_request(self, method, url, **kwargs):
|
||||
def _send_runtime_api_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> httpx.Response:
|
||||
try:
|
||||
kwargs['timeout'] = self.config.sandbox.remote_runtime_api_timeout
|
||||
return send_request(self.session, method, url, **kwargs)
|
||||
@@ -428,7 +438,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
)
|
||||
raise
|
||||
|
||||
def _send_action_server_request(self, method, url, **kwargs):
|
||||
def _send_action_server_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> httpx.Response:
|
||||
if not self.config.sandbox.remote_runtime_enable_retries:
|
||||
return self._send_action_server_request_impl(method, url, **kwargs)
|
||||
|
||||
@@ -444,7 +456,9 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
method, url, **kwargs
|
||||
)
|
||||
|
||||
def _send_action_server_request_impl(self, method, url, **kwargs):
|
||||
def _send_action_server_request_impl(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> httpx.Response:
|
||||
try:
|
||||
return super()._send_action_server_request(method, url, **kwargs)
|
||||
except httpx.TimeoutException:
|
||||
@@ -455,7 +469,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
raise
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
if e.response.status_code in (404, 502, 504):
|
||||
if hasattr(e, 'response') and e.response.status_code in (404, 502, 504):
|
||||
if e.response.status_code == 404:
|
||||
raise AgentRuntimeDisconnectedError(
|
||||
f'Runtime is not responding. This may be temporary, please try again. Original error: {e}'
|
||||
@@ -464,7 +478,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
raise AgentRuntimeDisconnectedError(
|
||||
f'Runtime is temporarily unavailable. This may be due to a restart or network issue, please try again. Original error: {e}'
|
||||
) from e
|
||||
elif e.response.status_code == 503:
|
||||
elif hasattr(e, 'response') and e.response.status_code == 503:
|
||||
if self.config.sandbox.keep_runtime_alive:
|
||||
self.log('warning', 'Runtime appears to be paused. Resuming...')
|
||||
self._resume_runtime()
|
||||
@@ -476,5 +490,5 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
else:
|
||||
raise e
|
||||
|
||||
def _stop_if_closed(self, retry_state: tenacity.RetryCallState) -> bool:
|
||||
def _stop_if_closed(self, retry_state: RetryCallState) -> bool:
|
||||
return self._runtime_closed
|
||||
|
||||
@@ -157,7 +157,7 @@ def _print_window(
|
||||
else:
|
||||
output += '(this is the beginning of the file)\n'
|
||||
for i in range(start, end + 1):
|
||||
_new_line = f'{i}|{lines[i-1]}'
|
||||
_new_line = f'{i}|{lines[i - 1]}'
|
||||
if not _new_line.endswith('\n'):
|
||||
_new_line += '\n'
|
||||
output += _new_line
|
||||
|
||||
@@ -2,7 +2,7 @@ from types import ModuleType
|
||||
|
||||
|
||||
def import_functions(
|
||||
module: ModuleType, function_names: list[str], target_globals: dict
|
||||
module: ModuleType, function_names: list[str], target_globals: dict[str, object]
|
||||
) -> None:
|
||||
for name in function_names:
|
||||
if hasattr(module, name):
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -20,7 +23,7 @@ class JupyterPlugin(Plugin):
|
||||
name: str = 'jupyter'
|
||||
kernel_gateway_port: int
|
||||
kernel_id: str
|
||||
gateway_process: asyncio.subprocess.Process
|
||||
gateway_process: asyncio.subprocess.Process | subprocess.Popen
|
||||
python_interpreter_path: str
|
||||
|
||||
async def initialize(
|
||||
@@ -28,7 +31,10 @@ class JupyterPlugin(Plugin):
|
||||
) -> None:
|
||||
self.kernel_gateway_port = find_available_tcp_port(40000, 49999)
|
||||
self.kernel_id = kernel_id
|
||||
if username in ['root', 'openhands']:
|
||||
is_local_runtime = os.environ.get('LOCAL_RUNTIME_MODE') == '1'
|
||||
is_windows = sys.platform == 'win32'
|
||||
|
||||
if not is_local_runtime:
|
||||
# Non-LocalRuntime
|
||||
prefix = f'su - {username} -s '
|
||||
# cd to code repo, setup all env vars and run micromamba
|
||||
@@ -50,37 +56,84 @@ class JupyterPlugin(Plugin):
|
||||
)
|
||||
# The correct environment is ensured by the PATH in LocalRuntime.
|
||||
poetry_prefix = f'cd {code_repo_path}\n'
|
||||
jupyter_launch_command = (
|
||||
f"{prefix}/bin/bash << 'EOF'\n"
|
||||
f'{poetry_prefix}'
|
||||
'poetry run jupyter kernelgateway '
|
||||
'--KernelGatewayApp.ip=0.0.0.0 '
|
||||
f'--KernelGatewayApp.port={self.kernel_gateway_port}\n'
|
||||
'EOF'
|
||||
)
|
||||
logger.debug(f'Jupyter launch command: {jupyter_launch_command}')
|
||||
|
||||
# Using asyncio.create_subprocess_shell instead of subprocess.Popen
|
||||
# to avoid ASYNC101 linting error
|
||||
self.gateway_process = await asyncio.create_subprocess_shell(
|
||||
jupyter_launch_command,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
)
|
||||
# read stdout until the kernel gateway is ready
|
||||
output = ''
|
||||
while should_continue() and self.gateway_process.stdout is not None:
|
||||
line_bytes = await self.gateway_process.stdout.readline()
|
||||
line = line_bytes.decode('utf-8')
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
logger.debug('Waiting for jupyter kernel gateway to start...')
|
||||
if is_windows:
|
||||
# Windows-specific command format
|
||||
jupyter_launch_command = (
|
||||
f'cd /d "{code_repo_path}" && '
|
||||
'poetry run jupyter kernelgateway '
|
||||
'--KernelGatewayApp.ip=0.0.0.0 '
|
||||
f'--KernelGatewayApp.port={self.kernel_gateway_port}'
|
||||
)
|
||||
logger.debug(f'Jupyter launch command (Windows): {jupyter_launch_command}')
|
||||
|
||||
# Using synchronous subprocess.Popen for Windows as asyncio.create_subprocess_shell
|
||||
# has limitations on Windows platforms
|
||||
self.gateway_process = subprocess.Popen( # type: ignore[ASYNC101] # noqa: ASYNC101
|
||||
jupyter_launch_command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
shell=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# Windows-specific stdout handling with synchronous time.sleep
|
||||
# as asyncio has limitations on Windows for subprocess operations
|
||||
output = ''
|
||||
while should_continue():
|
||||
if self.gateway_process.stdout is None:
|
||||
time.sleep(1) # type: ignore[ASYNC101] # noqa: ASYNC101
|
||||
continue
|
||||
|
||||
line = self.gateway_process.stdout.readline()
|
||||
if not line:
|
||||
time.sleep(1) # type: ignore[ASYNC101] # noqa: ASYNC101
|
||||
continue
|
||||
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
|
||||
time.sleep(1) # type: ignore[ASYNC101] # noqa: ASYNC101
|
||||
logger.debug('Waiting for jupyter kernel gateway to start...')
|
||||
|
||||
logger.debug(
|
||||
f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}'
|
||||
)
|
||||
else:
|
||||
# Unix systems (Linux/macOS)
|
||||
jupyter_launch_command = (
|
||||
f"{prefix}/bin/bash << 'EOF'\n"
|
||||
f'{poetry_prefix}'
|
||||
'poetry run jupyter kernelgateway '
|
||||
'--KernelGatewayApp.ip=0.0.0.0 '
|
||||
f'--KernelGatewayApp.port={self.kernel_gateway_port}\n'
|
||||
'EOF'
|
||||
)
|
||||
logger.debug(f'Jupyter launch command: {jupyter_launch_command}')
|
||||
|
||||
# Using asyncio.create_subprocess_shell instead of subprocess.Popen
|
||||
# to avoid ASYNC101 linting error
|
||||
self.gateway_process = await asyncio.create_subprocess_shell(
|
||||
jupyter_launch_command,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
)
|
||||
# read stdout until the kernel gateway is ready
|
||||
output = ''
|
||||
while should_continue() and self.gateway_process.stdout is not None:
|
||||
line_bytes = await self.gateway_process.stdout.readline()
|
||||
line = line_bytes.decode('utf-8')
|
||||
output += line
|
||||
if 'at' in line:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
logger.debug('Waiting for jupyter kernel gateway to start...')
|
||||
|
||||
logger.debug(
|
||||
f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}'
|
||||
)
|
||||
_obs = await self.run(
|
||||
IPythonRunCellAction(code='import sys; print(sys.executable)')
|
||||
)
|
||||
|
||||
@@ -138,7 +138,7 @@ class JupyterKernel:
|
||||
retry=retry_if_exception_type(ConnectionRefusedError),
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(2),
|
||||
)
|
||||
) # type: ignore
|
||||
async def execute(self, code: str, timeout: int = 120) -> str:
|
||||
if not self.ws or self.ws.stream.closed():
|
||||
await self._connect()
|
||||
@@ -189,7 +189,7 @@ class JupyterKernel:
|
||||
|
||||
if os.environ.get('DEBUG'):
|
||||
logging.info(
|
||||
f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg_dict['content']}"
|
||||
f'MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg_dict["content"]}'
|
||||
)
|
||||
|
||||
if msg_type == 'error':
|
||||
@@ -203,7 +203,7 @@ class JupyterKernel:
|
||||
if 'image/png' in msg_dict['content']['data']:
|
||||
# use markdone to display image (in case of large image)
|
||||
outputs.append(
|
||||
f"\n\n"
|
||||
f'\n\n'
|
||||
)
|
||||
|
||||
elif msg_type == 'execute_reply':
|
||||
@@ -272,7 +272,7 @@ class ExecuteHandler(tornado.web.RequestHandler):
|
||||
|
||||
def make_app() -> tornado.web.Application:
|
||||
jupyter_kernel = JupyterKernel(
|
||||
f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT', '8888')}",
|
||||
f'localhost:{os.environ.get("JUPYTER_GATEWAY_PORT", "8888")}',
|
||||
os.environ.get('JUPYTER_GATEWAY_KERNEL_ID', 'default'),
|
||||
)
|
||||
asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize())
|
||||
|
||||
@@ -6,7 +6,7 @@ import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import bashlex # type: ignore
|
||||
import bashlex
|
||||
import libtmux
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -25,7 +25,13 @@ def split_bash_commands(commands: str) -> list[str]:
|
||||
return ['']
|
||||
try:
|
||||
parsed = bashlex.parse(commands)
|
||||
except (bashlex.errors.ParsingError, NotImplementedError, TypeError):
|
||||
except (
|
||||
bashlex.errors.ParsingError,
|
||||
NotImplementedError,
|
||||
TypeError,
|
||||
AttributeError,
|
||||
):
|
||||
# Added AttributeError to catch 'str' object has no attribute 'kind' error (issue #8369)
|
||||
logger.debug(
|
||||
f'Failed to parse bash commands\n'
|
||||
f'[input]: {commands}\n'
|
||||
@@ -501,9 +507,9 @@ class BashSession:
|
||||
if len(splited_commands) > 1:
|
||||
return ErrorObservation(
|
||||
content=(
|
||||
f"ERROR: Cannot execute multiple commands at once.\n"
|
||||
f"Please run each command separately OR chain them into a single command via && or ;\n"
|
||||
f"Provided commands:\n{'\n'.join(f'({i + 1}) {cmd}' for i, cmd in enumerate(splited_commands))}"
|
||||
f'ERROR: Cannot execute multiple commands at once.\n'
|
||||
f'Please run each command separately OR chain them into a single command via && or ;\n'
|
||||
f'Provided commands:\n{"\n".join(f"({i + 1}) {cmd}" for i, cmd in enumerate(splited_commands))}'
|
||||
)
|
||||
)
|
||||
|
||||
@@ -591,8 +597,8 @@ class BashSession:
|
||||
logger.debug(
|
||||
f'PANE CONTENT GOT after {time.time() - _start_time:.2f} seconds'
|
||||
)
|
||||
logger.debug(f"BEGIN OF PANE CONTENT: {cur_pane_output.split('\n')[:10]}")
|
||||
logger.debug(f"END OF PANE CONTENT: {cur_pane_output.split('\n')[-10:]}")
|
||||
logger.debug(f'BEGIN OF PANE CONTENT: {cur_pane_output.split("\n")[:10]}')
|
||||
logger.debug(f'END OF PANE CONTENT: {cur_pane_output.split("\n")[-10:]}')
|
||||
ps1_matches = CmdOutputMetadata.matches_ps1_metadata(cur_pane_output)
|
||||
current_ps1_count = len(ps1_matches)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from openhands_aci.utils.diff import get_diff # type: ignore
|
||||
from openhands_aci.utils.diff import get_diff
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -35,8 +35,8 @@ def generate_file_viewer_html(file_path: str) -> str:
|
||||
# Check if the file extension is supported
|
||||
if file_extension not in supported_extensions:
|
||||
raise ValueError(
|
||||
f"Unsupported file extension: {file_extension}. "
|
||||
f"Supported extensions are: {', '.join(supported_extensions)}"
|
||||
f'Unsupported file extension: {file_extension}. '
|
||||
f'Supported extensions are: {", ".join(supported_extensions)}'
|
||||
)
|
||||
|
||||
# Check if the file exists
|
||||
|
||||
@@ -28,7 +28,7 @@ class GitHandler:
|
||||
self.execute = execute_shell_fn
|
||||
self.cwd: str | None = None
|
||||
|
||||
def set_cwd(self, cwd: str):
|
||||
def set_cwd(self, cwd: str) -> None:
|
||||
"""
|
||||
Sets the current working directory for Git operations.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class LogStreamer:
|
||||
def __init__(
|
||||
self,
|
||||
container: docker.models.containers.Container,
|
||||
logFn: Callable,
|
||||
logFn: Callable[[str, str], None],
|
||||
):
|
||||
self.log = logFn
|
||||
# Initialize all attributes before starting the thread on this instance
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import threading
|
||||
|
||||
from memory_profiler import memory_usage # type: ignore
|
||||
from memory_profiler import memory_usage
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class RequestHTTPError(httpx.HTTPStatusError):
|
||||
s = super().__str__()
|
||||
if self.detail is not None:
|
||||
s += f'\nDetails: {self.detail}'
|
||||
return s
|
||||
return str(s)
|
||||
|
||||
|
||||
def is_retryable_error(exception: Any) -> bool:
|
||||
@@ -57,4 +57,4 @@ def send_request(
|
||||
response=e.response,
|
||||
detail=_json.get('detail') if _json is not None else None,
|
||||
) from e
|
||||
return response # type: ignore
|
||||
return response
|
||||
|
||||
@@ -6,10 +6,9 @@ import string
|
||||
import tempfile
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import docker
|
||||
from dirhash import dirhash # type: ignore
|
||||
from dirhash import dirhash
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
import openhands
|
||||
@@ -111,7 +110,7 @@ def build_runtime_image(
|
||||
build_folder: str | None = None,
|
||||
dry_run: bool = False,
|
||||
force_rebuild: bool = False,
|
||||
extra_build_args: List[str] | None = None,
|
||||
extra_build_args: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Prepares the final docker build folder.
|
||||
|
||||
@@ -167,7 +166,7 @@ def build_runtime_image_in_folder(
|
||||
dry_run: bool,
|
||||
force_rebuild: bool,
|
||||
platform: str | None = None,
|
||||
extra_build_args: List[str] | None = None,
|
||||
extra_build_args: list[str] | None = None,
|
||||
) -> str:
|
||||
runtime_image_repo, _ = get_runtime_image_repo_and_tag(base_image)
|
||||
lock_tag = f'oh_v{oh_version}_{get_hash_for_lock_files(base_image)}'
|
||||
@@ -284,8 +283,9 @@ def prep_build_folder(
|
||||
build_from=build_from,
|
||||
extra_deps=extra_deps,
|
||||
)
|
||||
with open(Path(build_folder, 'Dockerfile'), 'w') as file: # type: ignore
|
||||
file.write(dockerfile_content) # type: ignore
|
||||
dockerfile_path = Path(build_folder, 'Dockerfile')
|
||||
with open(str(dockerfile_path), 'w') as f:
|
||||
f.write(dockerfile_content)
|
||||
|
||||
|
||||
_ALPHABET = string.digits + string.ascii_lowercase
|
||||
@@ -294,7 +294,7 @@ _ALPHABET = string.digits + string.ascii_lowercase
|
||||
def truncate_hash(hash: str) -> str:
|
||||
"""Convert the base16 hash to base36 and truncate at 16 characters."""
|
||||
value = int(hash, 16)
|
||||
result: List[str] = []
|
||||
result: list[str] = []
|
||||
while value > 0 and len(result) < 16:
|
||||
value, remainder = divmod(value, len(_ALPHABET))
|
||||
result.append(_ALPHABET[remainder])
|
||||
@@ -347,7 +347,7 @@ def _build_sandbox_image(
|
||||
lock_tag: str,
|
||||
versioned_tag: str | None,
|
||||
platform: str | None = None,
|
||||
extra_build_args: List[str] | None = None,
|
||||
extra_build_args: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Build and tag the sandbox image. The image will be tagged with all tags that do not yet exist."""
|
||||
names = [
|
||||
@@ -385,9 +385,9 @@ if __name__ == '__main__':
|
||||
# and create a Dockerfile dynamically and place it in the build_folder only. This allows the Docker image to
|
||||
# then be created using the Dockerfile (most likely using the containers/build.sh script)
|
||||
build_folder = args.build_folder
|
||||
assert os.path.exists(
|
||||
build_folder
|
||||
), f'Build folder {build_folder} does not exist'
|
||||
assert os.path.exists(build_folder), (
|
||||
f'Build folder {build_folder} does not exist'
|
||||
)
|
||||
logger.debug(
|
||||
f'Copying the source code and generating the Dockerfile in the build folder: {build_folder}'
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -32,6 +33,17 @@ def init_user_and_working_directory(
|
||||
Returns:
|
||||
int | None: The user ID if it was updated, None otherwise.
|
||||
"""
|
||||
# If running on Windows, just create the directory and return
|
||||
if sys.platform == 'win32':
|
||||
logger.debug('Running on Windows, skipping Unix-specific user setup')
|
||||
logger.debug(f'Client working directory: {initial_cwd}')
|
||||
|
||||
# Create the working directory if it doesn't exist
|
||||
os.makedirs(initial_cwd, exist_ok=True)
|
||||
logger.debug(f'Created working directory: {initial_cwd}')
|
||||
|
||||
return None
|
||||
|
||||
# if username is CURRENT_USER, then we don't need to do anything
|
||||
# This is specific to the local runtime
|
||||
if username == os.getenv('USER') and username not in ['root', 'openhands']:
|
||||
|
||||
@@ -38,7 +38,9 @@ RUN ln -s "$(dirname $(which node))/corepack" /usr/local/bin/corepack && \
|
||||
{% endif %}
|
||||
|
||||
# Install uv (required by MCP)
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="/openhands/bin" sh
|
||||
# Add /openhands/bin to PATH
|
||||
ENV PATH="/openhands/bin:${PATH}"
|
||||
|
||||
# Remove UID 1000 named pn or ubuntu, so the 'openhands' user can be created from ubuntu hosts
|
||||
RUN (if getent passwd 1000 | grep -q pn; then userdel pn; fi) && \
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
import psutil
|
||||
|
||||
|
||||
def get_system_stats() -> dict:
|
||||
def get_system_stats() -> dict[str, object]:
|
||||
"""Get current system resource statistics.
|
||||
|
||||
Returns:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -176,9 +176,9 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
],
|
||||
)
|
||||
)
|
||||
assert (
|
||||
self.guardrail_llm is not None
|
||||
), 'InvariantAnalyzer.guardrail_llm should be initialized before calling check_usertask'
|
||||
assert self.guardrail_llm is not None, (
|
||||
'InvariantAnalyzer.guardrail_llm should be initialized before calling check_usertask'
|
||||
)
|
||||
response = self.guardrail_llm.completion(
|
||||
messages=self.guardrail_llm.format_messages_for_llm(messages),
|
||||
stop=['.'],
|
||||
@@ -261,9 +261,9 @@ class InvariantAnalyzer(SecurityAnalyzer):
|
||||
],
|
||||
)
|
||||
)
|
||||
assert (
|
||||
self.guardrail_llm is not None
|
||||
), 'InvariantAnalyzer.guardrail_llm should be initialized before calling check_fillaction'
|
||||
assert self.guardrail_llm is not None, (
|
||||
'InvariantAnalyzer.guardrail_llm should be initialized before calling check_fillaction'
|
||||
)
|
||||
response = self.guardrail_llm.completion(
|
||||
messages=self.guardrail_llm.format_messages_for_llm(messages),
|
||||
stop=['.'],
|
||||
|
||||
@@ -20,7 +20,7 @@ TraceElement = Message | ToolCall | ToolOutput | Function
|
||||
|
||||
|
||||
def get_next_id(trace: list[TraceElement]) -> str:
|
||||
used_ids = [el.id for el in trace if type(el) == ToolCall]
|
||||
used_ids = [el.id for el in trace if isinstance(el, ToolCall)]
|
||||
for i in range(1, len(used_ids) + 2):
|
||||
if str(i) not in used_ids:
|
||||
return str(i)
|
||||
@@ -31,7 +31,7 @@ def get_last_id(
|
||||
trace: list[TraceElement],
|
||||
) -> str | None:
|
||||
for el in reversed(trace):
|
||||
if type(el) == ToolCall:
|
||||
if isinstance(el, ToolCall):
|
||||
return el.id
|
||||
return None
|
||||
|
||||
@@ -39,12 +39,12 @@ def get_last_id(
|
||||
def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
|
||||
next_id = get_next_id(trace)
|
||||
inv_trace: list[TraceElement] = []
|
||||
if type(action) == MessageAction:
|
||||
if isinstance(action, MessageAction):
|
||||
if action.source == EventSource.USER:
|
||||
inv_trace.append(Message(role='user', content=action.content))
|
||||
else:
|
||||
inv_trace.append(Message(role='assistant', content=action.content))
|
||||
elif type(action) in [NullAction, ChangeAgentStateAction]:
|
||||
elif isinstance(action, (NullAction, ChangeAgentStateAction)):
|
||||
pass
|
||||
elif hasattr(action, 'action') and action.action is not None:
|
||||
event_dict = event_to_dict(action)
|
||||
@@ -63,7 +63,7 @@ def parse_observation(
|
||||
trace: list[TraceElement], obs: Observation
|
||||
) -> list[TraceElement]:
|
||||
last_id = get_last_id(trace)
|
||||
if type(obs) in [NullObservation, AgentStateChangedObservation]:
|
||||
if isinstance(obs, (NullObservation, AgentStateChangedObservation)):
|
||||
return []
|
||||
elif hasattr(obs, 'content') and obs.content is not None:
|
||||
return [ToolOutput(role='tool', content=obs.content, tool_call_id=last_id)]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Type
|
||||
|
||||
from openhands.security.analyzer import SecurityAnalyzer
|
||||
from openhands.security.invariant.analyzer import InvariantAnalyzer
|
||||
|
||||
SecurityAnalyzers: dict[str, Type[SecurityAnalyzer]] = {
|
||||
SecurityAnalyzers: dict[str, type[SecurityAnalyzer]] = {
|
||||
'invariant': InvariantAnalyzer,
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ websocat ws://127.0.0.1:3000/ws
|
||||
```sh
|
||||
LLM_API_KEY=sk-... # Your Anthropic API Key
|
||||
LLM_MODEL=claude-3-5-sonnet-20241022 # Default model for the agent to use
|
||||
WORKSPACE_BASE=/path/to/your/workspace # Default absolute path to workspace
|
||||
SANDBOX_VOLUMES=/path/to/your/workspace:/workspace:rw # Mount paths in format host_path:container_path:mode
|
||||
```
|
||||
|
||||
## API Schema
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Iterable, Type
|
||||
from typing import Callable, Iterable
|
||||
|
||||
import socketio
|
||||
|
||||
@@ -23,6 +23,10 @@ from openhands.storage.data_models.conversation_metadata import ConversationMeta
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync, wait_all
|
||||
from openhands.utils.conversation_summary import (
|
||||
auto_generate_title,
|
||||
get_default_conversation_title,
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
@@ -52,7 +56,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
_conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
_cleanup_task: asyncio.Task | None = None
|
||||
_conversation_store_class: Type | None = None
|
||||
_conversation_store_class: type[ConversationStore] | None = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_stale())
|
||||
@@ -283,7 +287,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
response_ids = await self.get_running_agent_loops(user_id)
|
||||
if len(response_ids) >= self.config.max_concurrent_conversations:
|
||||
logger.info(
|
||||
f'too_many_sessions_for:{user_id or ''}',
|
||||
f'too_many_sessions_for:{user_id or ""}',
|
||||
extra={'session_id': sid, 'user_id': user_id},
|
||||
)
|
||||
# Get the conversations sorted (oldest first)
|
||||
@@ -296,7 +300,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
while len(conversations) >= self.config.max_concurrent_conversations:
|
||||
oldest_conversation_id = conversations.pop().conversation_id
|
||||
logger.debug(
|
||||
f'closing_from_too_many_sessions:{user_id or ''}:{oldest_conversation_id}',
|
||||
f'closing_from_too_many_sessions:{user_id or ""}:{oldest_conversation_id}',
|
||||
extra={'session_id': oldest_conversation_id, 'user_id': user_id},
|
||||
)
|
||||
# Send status message to client and close session.
|
||||
@@ -328,7 +332,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
try:
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(user_id, github_user_id, sid),
|
||||
self._create_conversation_update_callback(
|
||||
user_id, github_user_id, sid, settings
|
||||
),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@@ -425,7 +431,11 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
self, user_id: str | None, github_user_id: str | None, conversation_id: str
|
||||
self,
|
||||
user_id: str | None,
|
||||
github_user_id: str | None,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
) -> Callable:
|
||||
def callback(event, *args, **kwargs):
|
||||
call_async_from_sync(
|
||||
@@ -434,13 +444,19 @@ class StandaloneConversationManager(ConversationManager):
|
||||
user_id,
|
||||
github_user_id,
|
||||
conversation_id,
|
||||
settings,
|
||||
event,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
async def _update_conversation_for_event(
|
||||
self, user_id: str, github_user_id: str, conversation_id: str, event=None
|
||||
self,
|
||||
user_id: str,
|
||||
github_user_id: str,
|
||||
conversation_id: str,
|
||||
settings: Settings,
|
||||
event=None,
|
||||
):
|
||||
conversation_store = await self._get_conversation_store(user_id, github_user_id)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
@@ -462,6 +478,32 @@ class StandaloneConversationManager(ConversationManager):
|
||||
conversation.total_tokens = (
|
||||
token_usage.prompt_tokens + token_usage.completion_tokens
|
||||
)
|
||||
default_title = get_default_conversation_title(conversation_id)
|
||||
if (
|
||||
conversation.title == default_title
|
||||
): # attempt to autogenerate if default title is in use
|
||||
title = await auto_generate_title(
|
||||
conversation_id, user_id, self.file_store, settings
|
||||
)
|
||||
if title and not title.isspace():
|
||||
conversation.title = title
|
||||
try:
|
||||
# Emit a status update to the client with the new title
|
||||
status_update_dict = {
|
||||
'status_update': True,
|
||||
'type': 'info',
|
||||
'message': conversation_id,
|
||||
'conversation_title': conversation.title,
|
||||
}
|
||||
await self.sio.emit(
|
||||
'oh_event',
|
||||
status_update_dict,
|
||||
to=ROOM_KEY.format(sid=conversation_id),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error emitting title update event: {e}')
|
||||
else:
|
||||
conversation.title = default_title
|
||||
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]:
|
||||
display_feedback = feedback.model_dump()
|
||||
if 'trajectory' in display_feedback:
|
||||
display_feedback['trajectory'] = (
|
||||
f"elided [length: {len(display_feedback['trajectory'])}"
|
||||
f'elided [length: {len(display_feedback["trajectory"])}'
|
||||
)
|
||||
if 'token' in display_feedback:
|
||||
display_feedback['token'] = 'elided'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from types import MappingProxyType
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
@@ -20,14 +21,18 @@ from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderToken
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import (
|
||||
SecretsStoreImpl,
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
server_config,
|
||||
sio,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.storage.conversation.conversation_validator import (
|
||||
create_conversation_validator,
|
||||
)
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
|
||||
|
||||
def create_provider_tokens_object(
|
||||
@@ -43,74 +48,96 @@ def create_provider_tokens_object(
|
||||
|
||||
@sio.event
|
||||
async def connect(connection_id: str, environ):
|
||||
logger.info(f'sio:connect: {connection_id}')
|
||||
query_params = parse_qs(environ.get('QUERY_STRING', ''))
|
||||
latest_event_id_str = query_params.get('latest_event_id', [-1])[0]
|
||||
try:
|
||||
latest_event_id = int(latest_event_id_str)
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
f'Invalid latest_event_id value: {latest_event_id_str}, defaulting to -1'
|
||||
logger.info(f'sio:connect: {connection_id}')
|
||||
query_params = parse_qs(environ.get('QUERY_STRING', ''))
|
||||
latest_event_id_str = query_params.get('latest_event_id', [-1])[0]
|
||||
try:
|
||||
latest_event_id = int(latest_event_id_str)
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
f'Invalid latest_event_id value: {latest_event_id_str}, defaulting to -1'
|
||||
)
|
||||
latest_event_id = -1
|
||||
conversation_id = query_params.get('conversation_id', [None])[0]
|
||||
raw_list = query_params.get('providers_set', [])
|
||||
providers_list = []
|
||||
for item in raw_list:
|
||||
providers_list.extend(item.split(',') if isinstance(item, str) else [])
|
||||
providers_list = [p for p in providers_list if p]
|
||||
providers_set = [ProviderType(p) for p in providers_list]
|
||||
|
||||
if not conversation_id:
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
# Get Authorization header from the environment
|
||||
# Headers in WSGI/ASGI are prefixed with 'HTTP_' and have dashes replaced with underscores
|
||||
authorization_header = environ.get('HTTP_AUTHORIZATION', None)
|
||||
conversation_validator = create_conversation_validator()
|
||||
user_id, github_user_id = await conversation_validator.validate(
|
||||
conversation_id, cookies_str, authorization_header
|
||||
)
|
||||
latest_event_id = -1
|
||||
conversation_id = query_params.get('conversation_id', [None])[0]
|
||||
raw_list = query_params.get('providers_set', [])
|
||||
providers_list = []
|
||||
for item in raw_list:
|
||||
providers_list.extend(item.split(',') if isinstance(item, str) else [])
|
||||
providers_list = [p for p in providers_list if p]
|
||||
providers_set = [ProviderType(p) for p in providers_list]
|
||||
|
||||
if not conversation_id:
|
||||
logger.error('No conversation_id in query params')
|
||||
raise ConnectionRefusedError('No conversation_id in query params')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
conversation_validator = create_conversation_validator()
|
||||
user_id, github_user_id = await conversation_validator.validate(
|
||||
conversation_id, cookies_str
|
||||
)
|
||||
secrets_store = await SecretsStoreImpl.get_instance(config, user_id)
|
||||
user_secrets: UserSecrets | None = await secrets_store.load()
|
||||
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
if not settings:
|
||||
raise ConnectionRefusedError(
|
||||
'Settings not found', {'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND'}
|
||||
)
|
||||
session_init_args: dict = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
if not settings:
|
||||
raise ConnectionRefusedError(
|
||||
'Settings not found', {'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND'}
|
||||
git_provider_tokens = create_provider_tokens_object(providers_set)
|
||||
if server_config.app_mode != AppMode.SAAS and user_secrets:
|
||||
git_provider_tokens = user_secrets.provider_tokens
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id,
|
||||
connection_id,
|
||||
conversation_init_data,
|
||||
user_id,
|
||||
github_user_id,
|
||||
)
|
||||
session_init_args: dict = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
session_init_args['git_provider_tokens'] = create_provider_tokens_object(
|
||||
providers_set
|
||||
)
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id, connection_id, conversation_init_data, user_id, github_user_id
|
||||
)
|
||||
logger.info(
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
)
|
||||
agent_state_changed = None
|
||||
if event_stream is None:
|
||||
raise ConnectionRefusedError('Failed to join conversation')
|
||||
async_store = AsyncEventStoreWrapper(event_stream, latest_event_id + 1)
|
||||
async for event in async_store:
|
||||
logger.debug(f'oh_event: {event.__class__.__name__}')
|
||||
if isinstance(
|
||||
event,
|
||||
(NullAction, NullObservation, RecallAction),
|
||||
):
|
||||
continue
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
agent_state_changed = event
|
||||
else:
|
||||
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
|
||||
if agent_state_changed:
|
||||
await sio.emit('oh_event', event_to_dict(agent_state_changed), to=connection_id)
|
||||
logger.info(f'Finished replaying event stream for conversation {conversation_id}')
|
||||
logger.info(
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
)
|
||||
agent_state_changed = None
|
||||
if event_stream is None:
|
||||
raise ConnectionRefusedError('Failed to join conversation')
|
||||
async_store = AsyncEventStoreWrapper(event_stream, latest_event_id + 1)
|
||||
async for event in async_store:
|
||||
logger.debug(f'oh_event: {event.__class__.__name__}')
|
||||
if isinstance(
|
||||
event,
|
||||
(NullAction, NullObservation, RecallAction),
|
||||
):
|
||||
continue
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
agent_state_changed = event
|
||||
else:
|
||||
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
|
||||
if agent_state_changed:
|
||||
await sio.emit(
|
||||
'oh_event', event_to_dict(agent_state_changed), to=connection_id
|
||||
)
|
||||
logger.info(
|
||||
f'Finished replaying event stream for conversation {conversation_id}'
|
||||
)
|
||||
except ConnectionRefusedError:
|
||||
# Close the broken connection after sending an error message
|
||||
asyncio.create_task(sio.disconnect(connection_id))
|
||||
raise
|
||||
|
||||
|
||||
@sio.event
|
||||
|
||||
@@ -8,7 +8,7 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
|
||||
|
||||
@app.get('/config')
|
||||
async def get_remote_runtime_config(request: Request):
|
||||
async def get_remote_runtime_config(request: Request) -> JSONResponse:
|
||||
"""Retrieve the runtime configuration.
|
||||
|
||||
Currently, this is the session ID and runtime ID (if available).
|
||||
@@ -25,7 +25,7 @@ async def get_remote_runtime_config(request: Request):
|
||||
|
||||
|
||||
@app.get('/vscode-url')
|
||||
async def get_vscode_url(request: Request):
|
||||
async def get_vscode_url(request: Request) -> JSONResponse:
|
||||
"""Get the VSCode URL.
|
||||
|
||||
This endpoint allows getting the VSCode URL.
|
||||
@@ -55,7 +55,7 @@ async def get_vscode_url(request: Request):
|
||||
|
||||
|
||||
@app.get('/web-hosts')
|
||||
async def get_hosts(request: Request):
|
||||
async def get_hosts(request: Request) -> JSONResponse:
|
||||
"""Get the hosts used by the runtime.
|
||||
|
||||
This endpoint allows getting the hosts used by the runtime.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -41,8 +42,17 @@ from openhands.utils.async_utils import call_sync_from_async
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
|
||||
|
||||
@app.get('/list-files')
|
||||
async def list_files(request: Request, path: str | None = None):
|
||||
@app.get(
|
||||
'/list-files',
|
||||
response_model=list[str],
|
||||
responses={
|
||||
404: {'description': 'Runtime not initialized', 'model': dict},
|
||||
500: {'description': 'Error listing or filtering files', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def list_files(
|
||||
request: Request, path: str | None = None
|
||||
) -> list[str] | JSONResponse:
|
||||
"""List files in the specified path.
|
||||
|
||||
This function retrieves a list of files from the agent's runtime file store,
|
||||
@@ -83,7 +93,7 @@ async def list_files(request: Request, path: str | None = None):
|
||||
|
||||
file_list = [f for f in file_list if f not in FILES_TO_IGNORE]
|
||||
|
||||
async def filter_for_gitignore(file_list, base_path):
|
||||
async def filter_for_gitignore(file_list: list[str], base_path: str) -> list[str]:
|
||||
gitignore_path = os.path.join(base_path, '.gitignore')
|
||||
try:
|
||||
read_action = FileReadAction(gitignore_path)
|
||||
@@ -109,8 +119,21 @@ async def list_files(request: Request, path: str | None = None):
|
||||
return file_list
|
||||
|
||||
|
||||
@app.get('/select-file')
|
||||
async def select_file(file: str, request: Request):
|
||||
# NOTE: We use response_model=None for endpoints that can return multiple response types
|
||||
# (like FileResponse | JSONResponse). This is because FastAPI's response_model expects a
|
||||
# Pydantic model, but Starlette response classes like FileResponse are not Pydantic models.
|
||||
# Instead, we document the possible responses using the 'responses' parameter and maintain
|
||||
# proper type annotations for mypy.
|
||||
@app.get(
|
||||
'/select-file',
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {'description': 'File content returned as JSON', 'model': dict[str, str]},
|
||||
500: {'description': 'Error opening file', 'model': dict},
|
||||
415: {'description': 'Unsupported media type', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def select_file(file: str, request: Request) -> FileResponse | JSONResponse:
|
||||
"""Retrieve the content of a specified file.
|
||||
|
||||
To select a file:
|
||||
@@ -144,7 +167,7 @@ async def select_file(file: str, request: Request):
|
||||
|
||||
if isinstance(observation, FileReadObservation):
|
||||
content = observation.content
|
||||
return {'code': content}
|
||||
return JSONResponse(content={'code': content})
|
||||
elif isinstance(observation, ErrorObservation):
|
||||
logger.error(f'Error opening file {file}: {observation}')
|
||||
|
||||
@@ -158,10 +181,23 @@ async def select_file(file: str, request: Request):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'error': f'Error opening file: {observation}'},
|
||||
)
|
||||
else:
|
||||
# Handle unexpected observation types
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'error': f'Unexpected observation type: {type(observation)}'},
|
||||
)
|
||||
|
||||
|
||||
@app.get('/zip-directory')
|
||||
def zip_current_workspace(request: Request):
|
||||
@app.get(
|
||||
'/zip-directory',
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {'description': 'Zipped workspace returned as FileResponse'},
|
||||
500: {'description': 'Error zipping workspace', 'model': dict},
|
||||
},
|
||||
)
|
||||
def zip_current_workspace(request: Request) -> FileResponse | JSONResponse:
|
||||
try:
|
||||
logger.debug('Zipping workspace')
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
@@ -188,12 +224,19 @@ def zip_current_workspace(request: Request):
|
||||
)
|
||||
|
||||
|
||||
@app.get('/git/changes')
|
||||
@app.get(
|
||||
'/git/changes',
|
||||
response_model=dict[str, Any],
|
||||
responses={
|
||||
404: {'description': 'Not a git repository', 'model': dict},
|
||||
500: {'description': 'Error getting changes', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def git_changes(
|
||||
request: Request,
|
||||
conversation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
) -> dict[str, Any] | JSONResponse:
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config,
|
||||
@@ -229,13 +272,17 @@ async def git_changes(
|
||||
)
|
||||
|
||||
|
||||
@app.get('/git/diff')
|
||||
@app.get(
|
||||
'/git/diff',
|
||||
response_model=dict[str, Any],
|
||||
responses={500: {'description': 'Error getting diff', 'model': dict}},
|
||||
)
|
||||
async def git_diff(
|
||||
request: Request,
|
||||
path: str,
|
||||
conversation_id: str,
|
||||
conversation_store=Depends(get_conversation_store),
|
||||
):
|
||||
conversation_store: Any = Depends(get_conversation_store),
|
||||
) -> dict[str, Any] | JSONResponse:
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
|
||||
cwd = await get_cwd(
|
||||
@@ -259,7 +306,7 @@ async def get_cwd(
|
||||
conversation_store: ConversationStore,
|
||||
conversation_id: str,
|
||||
workspace_mount_path_in_sandbox: str,
|
||||
):
|
||||
) -> str:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
conversation_info = await _get_conversation_info(metadata, is_running)
|
||||
|
||||
@@ -8,6 +8,7 @@ from openhands.integrations.provider import (
|
||||
)
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
Branch,
|
||||
Repository,
|
||||
SuggestedTask,
|
||||
UnknownException,
|
||||
@@ -29,7 +30,7 @@ async def get_user_repositories(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
):
|
||||
) -> list[Repository] | JSONResponse:
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
@@ -65,7 +66,7 @@ async def get_user_repositories(
|
||||
async def get_user(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
) -> User | JSONResponse:
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_token=access_token
|
||||
@@ -101,7 +102,7 @@ async def search_repositories(
|
||||
order: str = 'desc',
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
) -> list[Repository] | JSONResponse:
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_token=access_token
|
||||
@@ -134,7 +135,7 @@ async def search_repositories(
|
||||
async def get_suggested_tasks(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
) -> list[SuggestedTask] | JSONResponse:
|
||||
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
|
||||
|
||||
Returns:
|
||||
@@ -165,3 +166,43 @@ async def get_suggested_tasks(
|
||||
content='No providers set.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/repository/branches', response_model=list[Branch])
|
||||
async def get_repository_branches(
|
||||
repository: str,
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
) -> list[Branch] | JSONResponse:
|
||||
"""Get branches for a repository.
|
||||
|
||||
Args:
|
||||
repository: The repository name in the format 'owner/repo'
|
||||
|
||||
Returns:
|
||||
A list of branches for the repository
|
||||
"""
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_token=access_token
|
||||
)
|
||||
try:
|
||||
branches: list[Branch] = await client.get_branches(repository)
|
||||
return branches
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='Git provider token required. (such as GitHub).',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, status
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
@@ -30,7 +28,6 @@ from openhands.server.shared import (
|
||||
SettingsStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
file_store,
|
||||
)
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth import (
|
||||
@@ -47,7 +44,7 @@ from openhands.storage.data_models.conversation_metadata import (
|
||||
)
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.conversation_summary import generate_conversation_title
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
@@ -75,7 +72,7 @@ async def _create_new_conversation(
|
||||
replay_json: str | None,
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_convo_id: bool = False,
|
||||
):
|
||||
) -> str:
|
||||
logger.info(
|
||||
'Creating conversation',
|
||||
extra={
|
||||
@@ -89,7 +86,7 @@ async def _create_new_conversation(
|
||||
settings = await settings_store.load()
|
||||
logger.info('Settings loaded')
|
||||
|
||||
session_init_args: dict = {}
|
||||
session_init_args: dict[str, Any] = {}
|
||||
if settings:
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
# We could use litellm.check_valid_key for a more accurate check,
|
||||
@@ -98,13 +95,13 @@ async def _create_new_conversation(
|
||||
not settings.llm_api_key
|
||||
or settings.llm_api_key.get_secret_value().isspace()
|
||||
):
|
||||
logger.warn(f'Missing api key for model {settings.llm_model}')
|
||||
logger.warning(f'Missing api key for model {settings.llm_model}')
|
||||
raise LLMAuthenticationError(
|
||||
'Error authenticating with the LLM provider. Please check your API key'
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warn('Settings not present, not starting conversation')
|
||||
logger.warning('Settings not present, not starting conversation')
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
@@ -162,7 +159,6 @@ async def _create_new_conversation(
|
||||
replay_json=replay_json,
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
|
||||
return conversation_id
|
||||
|
||||
|
||||
@@ -172,7 +168,7 @@ async def new_conversation(
|
||||
user_id: str = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
|
||||
auth_type: AuthType | None = Depends(get_auth_type),
|
||||
):
|
||||
) -> JSONResponse:
|
||||
"""Initialize a new session or join an existing one.
|
||||
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
@@ -300,109 +296,6 @@ async def get_conversation(
|
||||
return None
|
||||
|
||||
|
||||
def get_default_conversation_title(conversation_id: str) -> str:
|
||||
"""
|
||||
Generate a default title for a conversation based on its ID.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
|
||||
Returns:
|
||||
A default title string
|
||||
"""
|
||||
return f'Conversation {conversation_id[:5]}'
|
||||
|
||||
|
||||
async def auto_generate_title(conversation_id: str, user_id: str | None) -> str:
|
||||
"""
|
||||
Auto-generate a title for a conversation based on the first user message.
|
||||
Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
user_id: The ID of the user
|
||||
|
||||
Returns:
|
||||
A generated title string
|
||||
"""
|
||||
logger.info(f'Auto-generating title for conversation {conversation_id}')
|
||||
|
||||
try:
|
||||
# Create an event stream for the conversation
|
||||
event_stream = EventStream(conversation_id, file_store, user_id)
|
||||
|
||||
# Find the first user message
|
||||
first_user_message = None
|
||||
for event in event_stream.get_events():
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and isinstance(event, MessageAction)
|
||||
and event.content
|
||||
and event.content.strip()
|
||||
):
|
||||
first_user_message = event.content
|
||||
break
|
||||
|
||||
if first_user_message:
|
||||
# Get LLM config from user settings
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
|
||||
if settings and settings.llm_model:
|
||||
# Create LLM config from settings
|
||||
llm_config = LLMConfig(
|
||||
model=settings.llm_model,
|
||||
api_key=settings.llm_api_key,
|
||||
base_url=settings.llm_base_url,
|
||||
)
|
||||
|
||||
# Try to generate title using LLM
|
||||
llm_title = await generate_conversation_title(
|
||||
first_user_message, llm_config
|
||||
)
|
||||
if llm_title:
|
||||
logger.info(f'Generated title using LLM: {llm_title}')
|
||||
return llm_title
|
||||
except Exception as e:
|
||||
logger.error(f'Error using LLM for title generation: {e}')
|
||||
|
||||
# Fall back to simple truncation if LLM generation fails or is unavailable
|
||||
first_user_message = first_user_message.strip()
|
||||
title = first_user_message[:30]
|
||||
if len(first_user_message) > 30:
|
||||
title += '...'
|
||||
logger.info(f'Generated title using truncation: {title}')
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating title: {str(e)}')
|
||||
return ''
|
||||
|
||||
|
||||
@app.patch('/conversations/{conversation_id}')
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
title: str = Body(embed=True),
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
return False
|
||||
|
||||
# If title is empty or unspecified, auto-generate it
|
||||
if not title or title.isspace():
|
||||
title = await auto_generate_title(conversation_id, user_id)
|
||||
|
||||
# If we still don't have a title, use the default
|
||||
if not title or title.isspace():
|
||||
title = get_default_conversation_title(conversation_id)
|
||||
|
||||
metadata.title = title
|
||||
await conversation_store.save_metadata(metadata)
|
||||
return True
|
||||
|
||||
|
||||
@app.delete('/conversations/{conversation_id}')
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
|
||||
@@ -2,6 +2,8 @@ from fastapi import APIRouter, Depends, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.integrations.utils import validate_provider_token
|
||||
from openhands.server.settings import (
|
||||
GETCustomSecrets,
|
||||
@@ -9,6 +11,7 @@ from openhands.server.settings import (
|
||||
POSTProviderModel,
|
||||
)
|
||||
from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_secrets_store,
|
||||
get_user_secrets,
|
||||
)
|
||||
@@ -51,25 +54,59 @@ async def invalidate_legacy_secrets_store(
|
||||
return None
|
||||
|
||||
|
||||
async def check_provider_tokens(provider_info: POSTProviderModel) -> str:
|
||||
print(provider_info)
|
||||
if provider_info.provider_tokens:
|
||||
# Determine whether tokens are valid
|
||||
for token_type, token_value in provider_info.provider_tokens.items():
|
||||
if token_value.token:
|
||||
confirmed_token_type = await validate_provider_token(token_value.token)
|
||||
if not confirmed_token_type or confirmed_token_type != token_type:
|
||||
return f'Invalid token. Please make sure it is a valid {token_type.value} token.'
|
||||
def process_token_validation_result(
|
||||
confirmed_token_type: ProviderType | None, token_type: ProviderType
|
||||
):
|
||||
if not confirmed_token_type or confirmed_token_type != token_type:
|
||||
return (
|
||||
f'Invalid token. Please make sure it is a valid {token_type.value} token.'
|
||||
)
|
||||
|
||||
return ''
|
||||
|
||||
|
||||
async def check_provider_tokens(
|
||||
incoming_provider_tokens: POSTProviderModel,
|
||||
existing_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
) -> str:
|
||||
msg = ''
|
||||
if incoming_provider_tokens.provider_tokens:
|
||||
# Determine whether tokens are valid
|
||||
for token_type, token_value in incoming_provider_tokens.provider_tokens.items():
|
||||
if token_value.token:
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
token_value.token, token_value.host
|
||||
) # FE always sends latest host
|
||||
msg = process_token_validation_result(confirmed_token_type, token_type)
|
||||
|
||||
existing_token = (
|
||||
existing_provider_tokens.get(token_type, None)
|
||||
if existing_provider_tokens
|
||||
else None
|
||||
)
|
||||
if (
|
||||
existing_token
|
||||
and (existing_token.host != token_value.host)
|
||||
and existing_token.token
|
||||
):
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
existing_token.token, token_value.host
|
||||
) # Host has changed, check it against existing token
|
||||
if not confirmed_token_type or confirmed_token_type != token_type:
|
||||
msg = process_token_validation_result(
|
||||
confirmed_token_type, token_type
|
||||
)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
@app.post('/add-git-providers')
|
||||
async def store_provider_tokens(
|
||||
provider_info: POSTProviderModel,
|
||||
secrets_store: SecretsStore = Depends(get_secrets_store),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
) -> JSONResponse:
|
||||
provider_err_msg = await check_provider_tokens(provider_info)
|
||||
provider_err_msg = await check_provider_tokens(provider_info, provider_tokens)
|
||||
if provider_err_msg:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -91,8 +128,9 @@ async def store_provider_tokens(
|
||||
if existing_token and existing_token.token:
|
||||
provider_info.provider_tokens[provider] = existing_token
|
||||
|
||||
else: # nothing passed in means keep current settings
|
||||
provider_info.provider_tokens = dict(user_secrets.provider_tokens)
|
||||
provider_info.provider_tokens[provider] = provider_info.provider_tokens[
|
||||
provider
|
||||
].model_copy(update={'host': token_value.host})
|
||||
|
||||
updated_secrets = user_secrets.model_copy(
|
||||
update={'provider_tokens': provider_info.provider_tokens}
|
||||
@@ -154,10 +192,10 @@ async def load_custom_secrets_names(
|
||||
return GETCustomSecrets(custom_secrets=custom_secrets)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid token: {e}')
|
||||
logger.warning(f'Failed to load secret names: {e}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Invalid token'},
|
||||
content={'error': 'Failed to get secret names'},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,14 @@ from openhands.storage.settings.settings_store import SettingsStore
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
|
||||
@app.get('/settings', response_model=GETSettingsModel)
|
||||
@app.get(
|
||||
'/settings',
|
||||
response_model=GETSettingsModel,
|
||||
responses={
|
||||
404: {'description': 'Settings not found', 'model': dict},
|
||||
401: {'description': 'Invalid token', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def load_settings(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
@@ -42,6 +49,7 @@ async def load_settings(
|
||||
user_secrets = await invalidate_legacy_secrets_store(
|
||||
settings, settings_store, secrets_store
|
||||
)
|
||||
|
||||
# If invalidation is successful, then the returned user secrets holds the most recent values
|
||||
git_providers = (
|
||||
user_secrets.provider_tokens if user_secrets else provider_tokens
|
||||
@@ -51,7 +59,7 @@ async def load_settings(
|
||||
if git_providers:
|
||||
for provider_type, provider_token in git_providers.items():
|
||||
if provider_token.token or provider_token.user_id:
|
||||
provider_tokens_set[provider_type] = None
|
||||
provider_tokens_set[provider_type] = provider_token.host
|
||||
|
||||
settings_with_token_data = GETSettingsModel(
|
||||
**settings.model_dump(exclude='secrets_store'),
|
||||
@@ -69,7 +77,15 @@ async def load_settings(
|
||||
)
|
||||
|
||||
|
||||
@app.post('/reset-settings', response_model=dict[str, str])
|
||||
@app.post(
|
||||
'/reset-settings',
|
||||
responses={
|
||||
410: {
|
||||
'description': 'Reset settings functionality has been removed',
|
||||
'model': dict,
|
||||
}
|
||||
},
|
||||
)
|
||||
async def reset_settings() -> JSONResponse:
|
||||
"""
|
||||
Resets user settings. (Deprecated)
|
||||
@@ -99,7 +115,18 @@ async def store_llm_settings(
|
||||
return settings
|
||||
|
||||
|
||||
@app.post('/settings', response_model=dict[str, str])
|
||||
# NOTE: We use response_model=None for endpoints that return JSONResponse directly.
|
||||
# This is because FastAPI's response_model expects a Pydantic model, but we're returning
|
||||
# a response object directly. We document the possible responses using the 'responses'
|
||||
# parameter and maintain proper type annotations for mypy.
|
||||
@app.post(
|
||||
'/settings',
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {'description': 'Settings stored successfully', 'model': dict},
|
||||
500: {'description': 'Error storing settings', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def store_settings(
|
||||
settings: Settings,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
|
||||
@@ -58,7 +58,7 @@ class AgentSession:
|
||||
file_store: FileStore,
|
||||
status_callback: Callable | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initializes a new instance of the Session class
|
||||
|
||||
Parameters:
|
||||
@@ -89,7 +89,7 @@ class AgentSession:
|
||||
selected_branch: str | None = None,
|
||||
initial_message: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Starts the Agent session
|
||||
Parameters:
|
||||
- runtime_name: The name of the runtime associated with the session
|
||||
@@ -188,7 +188,7 @@ class AgentSession:
|
||||
},
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""Closes the Agent session"""
|
||||
if self._closed:
|
||||
return
|
||||
@@ -245,7 +245,7 @@ class AgentSession:
|
||||
assert isinstance(replay_events[0], MessageAction)
|
||||
return replay_events[0]
|
||||
|
||||
def _create_security_analyzer(self, security_analyzer: str | None):
|
||||
def _create_security_analyzer(self, security_analyzer: str | None) -> None:
|
||||
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions
|
||||
|
||||
Parameters:
|
||||
@@ -333,6 +333,7 @@ class AgentSession:
|
||||
git_provider_tokens, selected_repository, selected_branch
|
||||
)
|
||||
await call_sync_from_async(self.runtime.maybe_run_setup_script)
|
||||
await call_sync_from_async(self.runtime.maybe_setup_git_hooks)
|
||||
|
||||
self.logger.debug(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||
|
||||
@@ -38,10 +38,10 @@ class Conversation:
|
||||
headless_mode=False,
|
||||
)
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
await self.runtime.connect()
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
if self.event_stream:
|
||||
self.event_stream.close()
|
||||
asyncio.create_task(call_sync_from_async(self.runtime.close))
|
||||
|
||||
@@ -6,7 +6,7 @@ from logging import LoggerAdapter
|
||||
import socketio
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.config import AppConfig, MCPConfig
|
||||
from openhands.core.config.condenser_config import (
|
||||
BrowserOutputCondenserConfig,
|
||||
CondenserPipelineConfig,
|
||||
@@ -73,7 +73,7 @@ class Session:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.user_id = user_id
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
if self.sio:
|
||||
await self.sio.emit(
|
||||
'oh_event',
|
||||
@@ -90,7 +90,7 @@ class Session:
|
||||
settings: Settings,
|
||||
initial_message: MessageAction | None,
|
||||
replay_json: str | None,
|
||||
):
|
||||
) -> None:
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING),
|
||||
EventSource.ENVIRONMENT,
|
||||
@@ -114,6 +114,7 @@ class Session:
|
||||
or settings.sandbox_runtime_container_image
|
||||
else self.config.sandbox.runtime_container_image
|
||||
)
|
||||
self.config.mcp = settings.mcp_config or MCPConfig()
|
||||
max_iterations = settings.max_iterations or self.config.max_iterations
|
||||
|
||||
# This is a shallow copy of the default LLM config, so changes here will
|
||||
@@ -137,7 +138,7 @@ class Session:
|
||||
# output, which should keep the summarization cost down.
|
||||
default_condenser_config = CondenserPipelineConfig(
|
||||
condensers=[
|
||||
BrowserOutputCondenserConfig(),
|
||||
BrowserOutputCondenserConfig(attention_window=2),
|
||||
LLMSummarizingCondenserConfig(
|
||||
llm_config=llm.config, keep_first=4, max_size=80
|
||||
),
|
||||
@@ -194,10 +195,10 @@ class Session:
|
||||
'info', msg_id, f'Retrying LLM request, {retries} / {max}'
|
||||
)
|
||||
|
||||
def on_event(self, event: Event):
|
||||
def on_event(self, event: Event) -> None:
|
||||
asyncio.get_event_loop().run_until_complete(self._on_event(event))
|
||||
|
||||
async def _on_event(self, event: Event):
|
||||
async def _on_event(self, event: Event) -> None:
|
||||
"""Callback function for events that mainly come from the agent.
|
||||
Event is the base class for any agent action and observation.
|
||||
|
||||
@@ -235,7 +236,7 @@ class Session:
|
||||
event_dict['source'] = EventSource.AGENT
|
||||
await self.send(event_dict)
|
||||
|
||||
async def dispatch(self, data: dict):
|
||||
async def dispatch(self, data: dict) -> None:
|
||||
event = event_from_dict(data.copy())
|
||||
# This checks if the model supports images
|
||||
if isinstance(event, MessageAction) and event.image_urls:
|
||||
@@ -253,7 +254,7 @@ class Session:
|
||||
return
|
||||
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
|
||||
async def send(self, data: dict[str, object]):
|
||||
async def send(self, data: dict[str, object]) -> None:
|
||||
if asyncio.get_running_loop() != self.loop:
|
||||
self.loop.create_task(self._send(data))
|
||||
return
|
||||
@@ -273,11 +274,11 @@ class Session:
|
||||
self.is_alive = False
|
||||
return False
|
||||
|
||||
async def send_error(self, message: str):
|
||||
async def send_error(self, message: str) -> None:
|
||||
"""Sends an error message to the client."""
|
||||
await self.send({'error': True, 'message': message})
|
||||
|
||||
async def _send_status_message(self, msg_type: str, id: str, message: str):
|
||||
async def _send_status_message(self, msg_type: str, id: str, message: str) -> None:
|
||||
"""Sends a status message to the client."""
|
||||
if msg_type == 'error':
|
||||
agent_session = self.agent_session
|
||||
@@ -292,7 +293,7 @@ class Session:
|
||||
{'status_update': True, 'type': msg_type, 'id': id, 'message': message}
|
||||
)
|
||||
|
||||
def queue_status_message(self, msg_type: str, id: str, message: str):
|
||||
def queue_status_message(self, msg_type: str, id: str, message: str) -> None:
|
||||
"""Queues a status message to be sent asynchronously."""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._send_status_message(msg_type, id, message), self.loop
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import (
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.integrations.provider import ProviderToken
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@@ -15,6 +16,7 @@ class POSTProviderModel(BaseModel):
|
||||
Settings for POST requests
|
||||
"""
|
||||
|
||||
mcp_config: MCPConfig | None = None
|
||||
provider_tokens: dict[ProviderType, ProviderToken] = {}
|
||||
|
||||
|
||||
@@ -36,6 +38,8 @@ class GETSettingsModel(Settings):
|
||||
)
|
||||
llm_api_key_set: bool
|
||||
|
||||
model_config = {'use_enum_values': True}
|
||||
|
||||
|
||||
class GETCustomSecrets(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,10 @@ class ConversationValidator:
|
||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||
|
||||
async def validate(
|
||||
self, conversation_id: str, cookies_str: str
|
||||
self,
|
||||
conversation_id: str,
|
||||
cookies_str: str,
|
||||
authorization_header: str | None = None,
|
||||
) -> tuple[None, None]:
|
||||
return None, None
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from pydantic import (
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.mcp_config import MCPConfig
|
||||
from openhands.core.config.utils import load_app_config
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
|
||||
@@ -33,9 +34,11 @@ class Settings(BaseModel):
|
||||
secrets_store: UserSecrets = Field(default_factory=UserSecrets, frozen=True)
|
||||
enable_default_condenser: bool = True
|
||||
enable_sound_notifications: bool = False
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
user_consents_to_analytics: bool | None = None
|
||||
sandbox_base_container_image: str | None = None
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
mcp_config: MCPConfig | None = None
|
||||
|
||||
model_config = {
|
||||
'validate_assignment': True,
|
||||
@@ -104,6 +107,12 @@ class Settings(BaseModel):
|
||||
# If no api key has been set, we take this to mean that there is no reasonable default
|
||||
return None
|
||||
security = app_config.security
|
||||
|
||||
# Get MCP config if available
|
||||
mcp_config = None
|
||||
if hasattr(app_config, 'mcp'):
|
||||
mcp_config = app_config.mcp
|
||||
|
||||
settings = Settings(
|
||||
language='en',
|
||||
agent=app_config.default_agent,
|
||||
@@ -114,5 +123,6 @@ class Settings(BaseModel):
|
||||
llm_api_key=llm_config.api_key,
|
||||
llm_base_url=llm_config.base_url,
|
||||
remote_runtime_resource_factor=app_config.sandbox.remote_runtime_resource_factor,
|
||||
mcp_config=mcp_config,
|
||||
)
|
||||
return settings
|
||||
|
||||
@@ -14,6 +14,7 @@ from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.integrations.provider import (
|
||||
CUSTOM_SECRETS_TYPE,
|
||||
CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA,
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
PROVIDER_TOKEN_TYPE_WITH_JSON_SCHEMA,
|
||||
ProviderToken,
|
||||
@@ -26,7 +27,7 @@ class UserSecrets(BaseModel):
|
||||
default_factory=lambda: MappingProxyType({})
|
||||
)
|
||||
|
||||
custom_secrets: CUSTOM_SECRETS_TYPE = Field(
|
||||
custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA = Field(
|
||||
default_factory=lambda: MappingProxyType({})
|
||||
)
|
||||
|
||||
@@ -44,7 +45,7 @@ class UserSecrets(BaseModel):
|
||||
expose_secrets = info.context and info.context.get('expose_secrets', False)
|
||||
|
||||
for token_type, provider_token in provider_tokens.items():
|
||||
if not provider_token or not provider_token.token:
|
||||
if not provider_token:
|
||||
continue
|
||||
|
||||
token_type_str = (
|
||||
@@ -52,10 +53,18 @@ class UserSecrets(BaseModel):
|
||||
if isinstance(token_type, ProviderType)
|
||||
else str(token_type)
|
||||
)
|
||||
|
||||
token = None
|
||||
if provider_token.token:
|
||||
token = (
|
||||
provider_token.token.get_secret_value()
|
||||
if expose_secrets
|
||||
else pydantic_encoder(provider_token.token)
|
||||
)
|
||||
|
||||
tokens[token_type_str] = {
|
||||
'token': provider_token.token.get_secret_value()
|
||||
if expose_secrets
|
||||
else pydantic_encoder(provider_token.token),
|
||||
'token': token,
|
||||
'host': provider_token.host,
|
||||
'user_id': provider_token.user_id,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, TypedDict
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
@@ -16,7 +16,7 @@ class GetObjectOutputDict(TypedDict):
|
||||
|
||||
|
||||
class ListObjectsV2OutputDict(TypedDict):
|
||||
Contents: List[S3ObjectDict] | None
|
||||
Contents: list[S3ObjectDict] | None
|
||||
|
||||
|
||||
class S3FileStore(FileStore):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Coroutine, Iterable, List
|
||||
from typing import Callable, Coroutine, Iterable
|
||||
|
||||
GENERAL_TIMEOUT: int = 15
|
||||
EXECUTOR = ThreadPoolExecutor()
|
||||
@@ -64,7 +64,7 @@ async def call_coro_in_bg_thread(
|
||||
|
||||
async def wait_all(
|
||||
iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT
|
||||
) -> List:
|
||||
) -> list:
|
||||
"""
|
||||
Shorthand for waiting for all the coroutines in the iterable given in parallel. Creates
|
||||
a task for each coroutine.
|
||||
|
||||
@@ -4,7 +4,12 @@ from typing import Optional
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
async def generate_conversation_title(
|
||||
@@ -55,3 +60,81 @@ async def generate_conversation_title(
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating conversation title: {e}')
|
||||
return None
|
||||
|
||||
|
||||
def get_default_conversation_title(conversation_id: str) -> str:
|
||||
"""
|
||||
Generate a default title for a conversation based on its ID.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
|
||||
Returns:
|
||||
A default title string
|
||||
"""
|
||||
return f'Conversation {conversation_id[:5]}'
|
||||
|
||||
|
||||
async def auto_generate_title(
|
||||
conversation_id: str, user_id: str | None, file_store: FileStore, settings: Settings
|
||||
) -> str:
|
||||
"""
|
||||
Auto-generate a title for a conversation based on the first user message.
|
||||
Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
user_id: The ID of the user
|
||||
|
||||
Returns:
|
||||
A generated title string
|
||||
"""
|
||||
logger.info(f'Auto-generating title for conversation {conversation_id}')
|
||||
|
||||
try:
|
||||
# Create an event stream for the conversation
|
||||
event_stream = EventStream(conversation_id, file_store, user_id)
|
||||
|
||||
# Find the first user message
|
||||
first_user_message = None
|
||||
for event in event_stream.get_events():
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and isinstance(event, MessageAction)
|
||||
and event.content
|
||||
and event.content.strip()
|
||||
):
|
||||
first_user_message = event.content
|
||||
break
|
||||
|
||||
if first_user_message:
|
||||
# Get LLM config from user settings
|
||||
try:
|
||||
if settings and settings.llm_model:
|
||||
# Create LLM config from settings
|
||||
llm_config = LLMConfig(
|
||||
model=settings.llm_model,
|
||||
api_key=settings.llm_api_key,
|
||||
base_url=settings.llm_base_url,
|
||||
)
|
||||
|
||||
# Try to generate title using LLM
|
||||
llm_title = await generate_conversation_title(
|
||||
first_user_message, llm_config
|
||||
)
|
||||
if llm_title:
|
||||
logger.info(f'Generated title using LLM: {llm_title}')
|
||||
return llm_title
|
||||
except Exception as e:
|
||||
logger.error(f'Error using LLM for title generation: {e}')
|
||||
|
||||
# Fall back to simple truncation if LLM generation fails or is unavailable
|
||||
first_user_message = first_user_message.strip()
|
||||
title = first_user_message[:30]
|
||||
if len(first_user_message) > 30:
|
||||
title += '...'
|
||||
logger.info(f'Generated title using truncation: {title}')
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.error(f'Error generating title: {str(e)}')
|
||||
return ''
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import importlib
|
||||
from functools import lru_cache
|
||||
from typing import Type, TypeVar
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
@@ -15,7 +15,7 @@ def import_from(qual_name: str):
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_impl(cls: Type[T], impl_name: str | None) -> Type[T]:
|
||||
def get_impl(cls: type[T], impl_name: str | None) -> type[T]:
|
||||
"""Import a named implementation of the specified class"""
|
||||
if impl_name is None:
|
||||
return cls
|
||||
|
||||
Generated
+208
-78
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user