Compare commits

..

1 Commits

Author SHA1 Message Date
Chuck Butkus 60c8480dbd Try 2025-06-05 23:14:45 -04:00
53 changed files with 730 additions and 734 deletions
Executable → Regular
View File
-4
View File
@@ -313,8 +313,6 @@ jobs:
TEST_IN_CI=true \
RUN_AS_OPENHANDS=false \
poetry run pytest -n 7 -raRs --reruns 2 --reruns-delay 5 -s ./tests/runtime --ignore=tests/runtime/test_browsergym_envs.py --durations=10
env:
DEBUG: "1"
# Run unit tests with the Docker runtime Docker images as openhands user
test_runtime_oh:
@@ -380,8 +378,6 @@ jobs:
TEST_IN_CI=true \
RUN_AS_OPENHANDS=true \
poetry run pytest -n 7 -raRs --reruns 2 --reruns-delay 5 -s ./tests/runtime --ignore=tests/runtime/test_browsergym_envs.py --durations=10
env:
DEBUG: "1"
# The two following jobs (named identically) are to check whether all the runtime tests have passed as the
# "All Runtime Tests Passed" is a required job for PRs to merge
-6
View File
@@ -74,11 +74,5 @@ jobs:
run: poetry install --with dev,test,runtime
- name: Run Windows unit tests
run: poetry run pytest -svv tests/unit/test_windows_bash.py
env:
DEBUG: "1"
- name: Run Windows runtime tests with LocalRuntime
run: $env:TEST_RUNTIME="local"; poetry run pytest -svv tests/runtime/test_bash.py
env:
TEST_RUNTIME: local
DEBUG: "1"
-8
View File
@@ -109,14 +109,6 @@ OpenHands requires an API key to access most language models. Here's how to get
</Accordion>
<Accordion title="Google (Gemini)">
1. Create a Google account if you don't already have one.
2. [Generate an API key](https://aistudio.google.com/apikey).
3. [Set up billing](https://aistudio.google.com/usage?tab=billing).
</Accordion>
</AccordionGroup>
Consider setting usage limits to control costs.
+1 -56
View File
@@ -1,60 +1,5 @@
import axios, { AxiosError, AxiosResponse } from "axios";
import axios from "axios";
export const openHands = axios.create({
baseURL: `${window.location.protocol}//${import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host}`,
});
// Helper function to check if a response contains an email verification error
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const checkForEmailVerificationError = (data: any): boolean => {
const EMAIL_NOT_VERIFIED = "EmailNotVerifiedError";
if (typeof data === "string") {
return data.includes(EMAIL_NOT_VERIFIED);
}
if (typeof data === "object" && data !== null) {
if ("message" in data) {
const { message } = data;
if (typeof message === "string") {
return message.includes(EMAIL_NOT_VERIFIED);
}
if (Array.isArray(message)) {
return message.some(
(msg) => typeof msg === "string" && msg.includes(EMAIL_NOT_VERIFIED),
);
}
}
// Search any values in object in case message key is different
return Object.values(data).some(
(value) =>
(typeof value === "string" && value.includes(EMAIL_NOT_VERIFIED)) ||
(Array.isArray(value) &&
value.some(
(v) => typeof v === "string" && v.includes(EMAIL_NOT_VERIFIED),
)),
);
}
return false;
};
// Set up the global interceptor
openHands.interceptors.response.use(
(response: AxiosResponse) => response,
(error: AxiosError) => {
// Check if it's a 403 error with the email verification message
if (
error.response?.status === 403 &&
checkForEmailVerificationError(error.response?.data)
) {
if (window.location.pathname !== "/settings/user") {
window.location.reload();
}
}
// Continue with the error for other error handlers
return Promise.reject(error);
},
);
-20
View File
@@ -236,26 +236,6 @@ class OpenHands {
return data;
}
static async startConversation(
conversationId: string,
): Promise<Conversation | null> {
const { data } = await openHands.post<Conversation | null>(
`/api/conversations/${conversationId}/start`,
);
return data;
}
static async stopConversation(
conversationId: string,
): Promise<Conversation | null> {
const { data } = await openHands.post<Conversation | null>(
`/api/conversations/${conversationId}/stop`,
);
return data;
}
/**
* Get the settings from the server or use the default settings if not found
*/
@@ -84,7 +84,7 @@ export function AgentStatusBar() {
setStatusMessage(t(I18nKey.STATUS$STARTING_RUNTIME));
setIndicatorColor(IndicatorColor.RED);
} else if (status === WsClientProviderStatus.DISCONNECTED) {
setStatusMessage(t(I18nKey.STATUS$WEBSOCKET_CLOSED));
setStatusMessage(t(I18nKey.STATUS$CONNECTED)); // Using STATUS$CONNECTED instead of STATUS$CONNECTING
setIndicatorColor(IndicatorColor.RED);
} else {
setStatusMessage(AGENT_STATUS_MAP[curAgentState].message);
@@ -122,7 +122,7 @@ export function FileDiffViewer({ path, type }: FileDiffViewerProps) {
modifiedEditor.onDidContentSizeChange(updateEditorHeight);
};
const status = (type === "U" ? STATUS_MAP.A : STATUS_MAP[type]) || "?";
const status = type === "U" ? STATUS_MAP.A : STATUS_MAP[type];
let statusIcon: React.ReactNode;
if (typeof status === "string") {
+7 -15
View File
@@ -150,8 +150,7 @@ export function WsClientProvider({
const { providers } = useUserProviders();
const messageRateHandler = useRate({ threshold: 250 });
const { data: conversation, refetch: refetchConversation } =
useActiveConversation();
const { data: conversation } = useActiveConversation();
function send(event: Record<string, unknown>) {
if (!sioRef.current) {
@@ -270,11 +269,14 @@ export function WsClientProvider({
sio.io.opts.query.latest_event_id = lastEventRef.current?.id;
updateStatusWhenErrorMessagePresent(data);
setErrorMessage(hasValidMessageProperty(data) ? data.message : "");
setErrorMessage(
hasValidMessageProperty(data)
? data.message
: "The WebSocket connection was closed.",
);
}
function handleError(data: unknown) {
// set status
setStatus(WsClientProviderStatus.DISCONNECTED);
updateStatusWhenErrorMessagePresent(data);
@@ -283,9 +285,6 @@ export function WsClientProvider({
? data.message
: "An unknown error occurred on the WebSocket connection.",
);
// check if something went wrong with the conversation.
refetchConversation();
}
React.useEffect(() => {
@@ -301,19 +300,12 @@ export function WsClientProvider({
if (!conversationId) {
throw new Error("No conversation ID provided");
}
if (
!conversation ||
["STOPPED", "STARTING"].includes(conversation.status)
) {
if (!conversation || conversation.status === "STARTING") {
return () => undefined; // conversation not yet loaded
}
let sio = sioRef.current;
if (sio?.connected) {
sio.disconnect();
}
const lastEvent = lastEventRef.current;
const query = {
latest_event_id: lastEvent?.id ?? -1,
@@ -9,7 +9,7 @@ export const useActiveConversation = () => {
const { conversationId } = useConversationId();
const userConversation = useUserConversation(conversationId, (query) => {
if (query.state.data?.status === "STARTING") {
return 3000; // 3 seconds
return 2000; // 2 seconds
}
return FIVE_MINUTES;
});
@@ -17,10 +17,6 @@ export const useActiveConversation = () => {
useEffect(() => {
const conversation = userConversation.data;
OpenHands.setCurrentConversation(conversation || null);
}, [
conversationId,
userConversation.isFetched,
userConversation?.data?.status,
]);
}, [conversationId, userConversation.isFetched]);
return userConversation;
};
@@ -0,0 +1,116 @@
import { useEffect } from "react";
import { useQueryClient } from "@tanstack/react-query";
import { useNavigate } from "react-router";
import { AxiosError } from "axios";
import { openHands } from "#/api/open-hands-axios";
import { Settings } from "#/types/settings";
import { useConfig } from "#/hooks/query/use-config";
/**
* Hook to handle email verification errors (403 with "Email has not been verified" message)
* This hook sets up an axios interceptor that will reload settings and navigate to the user settings page
* when a 403 error with the specific message is encountered.
*/
export const useHandleEmailVerification = () => {
const queryClient = useQueryClient();
const navigate = useNavigate();
const { data: config } = useConfig();
const appMode = config?.APP_MODE;
console.log(`config: ${config}`);
console.log(`AppMode: ${appMode}`);
useEffect(() => {
// Add response interceptor
const interceptorId = openHands.interceptors.response.use(
(response) => response,
(error: AxiosError) => {
console.log(
`Received error ${error.response?.status} with message ${error.response?.data}`,
);
const EMAIL_NOT_VERIFIED = "EmailNotVerifiedError";
// check for email verification error message no matter how it is returned.
const isEmailNotVerified = (() => {
const data = error.response?.data;
if (typeof data === "string") {
return data.includes(EMAIL_NOT_VERIFIED);
}
if (typeof data === "object" && data !== null) {
if ("message" in data) {
const { message } = data;
if (typeof message === "string") {
return message.includes(EMAIL_NOT_VERIFIED);
}
if (Array.isArray(message)) {
return message.some(
(msg) =>
typeof msg === "string" && msg.includes(EMAIL_NOT_VERIFIED),
);
}
}
// Search any values in object in case message key is different
return Object.values(data).some(
(value) =>
(typeof value === "string" &&
value.includes(EMAIL_NOT_VERIFIED)) ||
(Array.isArray(value) &&
value.some(
(v) =>
typeof v === "string" && v.includes(EMAIL_NOT_VERIFIED),
)),
);
}
return false;
})();
// Check if it's a 403 error with the specific message
if (error.response?.status === 403 && isEmailNotVerified) {
console.log("EMAIL VERIFICATION ERROR");
// Only handle this in SAAS mode
console.log(`config1: ${config}`);
console.log(`AppMode1: ${appMode}`);
if (appMode === "saas") {
// Update settings to mark email as unverified
queryClient.setQueryData(
["settings"],
(oldData: Settings | undefined) => {
if (oldData) {
console.log("ADDING EMAIL_VERIFIED is FALSE");
return {
...oldData,
EMAIL_VERIFIED: false,
};
}
console.log("NO CHANGES TO SETTINGS");
return oldData;
},
);
// Invalidate settings to reload them
queryClient.invalidateQueries({ queryKey: ["settings"] });
// Navigate to settings/user page
// The EmailVerificationGuard will handle the redirect
console.log("NAVIGATING to /settings/user");
navigate("/settings/user");
}
} else {
console.log("NOT EMAIL VERIFICATION ERROR");
console.log(typeof error.response?.data);
}
// Continue with the error for other error handlers
return Promise.reject(error);
},
);
// Clean up interceptor when component unmounts
return () => {
openHands.interceptors.response.eject(interceptorId);
};
}, [queryClient, navigate]);
};
-1
View File
@@ -1,6 +1,5 @@
// this file generate by script, don't modify it manually!!!
export enum I18nKey {
STATUS$WEBSOCKET_CLOSED = "STATUS$WEBSOCKET_CLOSED",
HOME$LAUNCH_FROM_SCRATCH = "HOME$LAUNCH_FROM_SCRATCH",
HOME$READ_THIS = "HOME$READ_THIS",
AUTH$LOGGING_BACK_IN = "AUTH$LOGGING_BACK_IN",
-32
View File
@@ -1,20 +1,4 @@
{
"STATUS$WEBSOCKET_CLOSED": {
"en": "The WebSocket connection was closed.",
"ja": "WebSocket接続が閉じられました。",
"zh-CN": "WebSocket连接已关闭。",
"zh-TW": "WebSocket連接已關閉。",
"ko-KR": "WebSocket 연결이 닫혔습니다.",
"no": "WebSocket-tilkoblingen ble lukket.",
"it": "La connessione WebSocket è stata chiusa.",
"pt": "A conexão WebSocket foi fechada.",
"es": "La conexión WebSocket se ha cerrado.",
"ar": "تم إغلاق اتصال WebSocket.",
"fr": "La connexion WebSocket a été fermée.",
"tr": "WebSocket bağlantısı kapatıldı.",
"de": "Die WebSocket-Verbindung wurde geschlossen.",
"uk": "З'єднання WebSocket було закрито."
},
"HOME$LAUNCH_FROM_SCRATCH": {
"en": "Launch from Scratch",
"ja": "ゼロから始める",
@@ -9071,22 +9055,6 @@
"de": "Sie müssen Ihre E-Mail-Adresse bestätigen, bevor Sie All Hands verwenden können",
"uk": "Ви повинні підтвердити свою електронну адресу перед використанням All Hands"
},
"SETTINGS$INVALID_EMAIL_FORMAT": {
"en": "Please enter a valid email address",
"ja": "有効なメールアドレスを入力してください",
"zh-CN": "请输入有效的电子邮件地址",
"zh-TW": "請輸入有效的電子郵件地址",
"ko-KR": "유효한 이메일 주소를 입력하세요",
"no": "Vennligst skriv inn en gyldig e-postadresse",
"it": "Inserisci un indirizzo email valido",
"pt": "Por favor, insira um endereço de e-mail válido",
"es": "Por favor, introduzca una dirección de correo electrónico válida",
"ar": "الرجاء إدخال عنوان بريد إلكتروني صالح",
"fr": "Veuillez entrer une adresse e-mail valide",
"tr": "Lütfen geçerli bir e-posta adresi girin",
"de": "Bitte geben Sie eine gültige E-Mail-Adresse ein",
"uk": "Будь ласка, введіть дійсну електронну адресу"
},
"SETTINGS$EMAIL_VERIFICATION_RESTRICTION_MESSAGE": {
"en": "Your access is limited until your email is verified. You can only access this settings page.",
"ja": "メールが確認されるまでアクセスが制限されています。この設定ページにのみアクセスできます。",
+2 -7
View File
@@ -43,7 +43,7 @@ function AppContent() {
const { t } = useTranslation();
const { data: settings } = useSettings();
const { conversationId } = useConversationId();
const { data: conversation, isFetched, refetch } = useActiveConversation();
const { data: conversation, isFetched } = useActiveConversation();
const { data: isAuthed } = useIsAuthed();
const { curAgentState } = useSelector((state: RootState) => state.agent);
@@ -61,13 +61,8 @@ function AppContent() {
"This conversation does not exist, or you do not have permission to access it.",
);
navigate("/");
} else if (conversation?.status === "STOPPED") {
// start the conversation if the state is stopped on initial load
OpenHands.startConversation(conversation.conversation_id).then(() =>
refetch(),
);
}
}, [conversation?.conversation_id, isFetched, isAuthed]);
}, [conversation, isFetched, isAuthed]);
React.useEffect(() => {
dispatch(clearTerminal());
+4
View File
@@ -24,6 +24,7 @@ import { displaySuccessToast } from "#/utils/custom-toast-handlers";
import { useIsOnTosPage } from "#/hooks/use-is-on-tos-page";
import { useAutoLogin } from "#/hooks/use-auto-login";
import { useAuthCallback } from "#/hooks/use-auth-callback";
import { useHandleEmailVerification } from "#/hooks/use-handle-email-verification";
import { LOCAL_STORAGE_KEYS } from "#/utils/local-storage";
import { EmailVerificationGuard } from "#/components/features/guards/email-verification-guard";
@@ -93,6 +94,9 @@ export default function MainApp() {
// Handle authentication callback and set login method after successful authentication
useAuthCallback();
// Set up interceptor for email verification errors
useHandleEmailVerification();
React.useEffect(() => {
// Don't change language when on TOS page
if (!isOnTosPage && settings?.LANGUAGE) {
+4 -24
View File
@@ -5,9 +5,6 @@ import { useSettings } from "#/hooks/query/use-settings";
import { openHands } from "#/api/open-hands-axios";
import { displaySuccessToast } from "#/utils/custom-toast-handlers";
// Email validation regex pattern
const EMAIL_REGEX = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$/;
function EmailInputSection({
email,
onEmailChange,
@@ -17,7 +14,6 @@ function EmailInputSection({
isResendingVerification,
isEmailChanged,
emailVerified,
isEmailValid,
children,
}: {
email: string;
@@ -28,7 +24,6 @@ function EmailInputSection({
isResendingVerification: boolean;
isEmailChanged: boolean;
emailVerified?: boolean;
isEmailValid: boolean;
children: React.ReactNode;
}) {
const { t } = useTranslation();
@@ -41,27 +36,17 @@ function EmailInputSection({
type="email"
value={email}
onChange={onEmailChange}
className={`text-base text-white p-2 bg-base-tertiary rounded border ${
isEmailChanged && !isEmailValid
? "border-red-500"
: "border-tertiary"
} flex-grow focus:outline-none focus:border-transparent focus:ring-0`}
className="text-base text-white p-2 bg-base-tertiary rounded border border-tertiary flex-grow focus:outline-none focus:border-transparent focus:ring-0"
placeholder={t("SETTINGS$USER_EMAIL_LOADING")}
data-testid="email-input"
/>
</div>
{isEmailChanged && !isEmailValid && (
<div className="text-red-500 text-sm mt-1" data-testid="email-validation-error">
{t("SETTINGS$INVALID_EMAIL_FORMAT")}
</div>
)}
<div className="flex items-center gap-3 mt-2">
<button
type="button"
onClick={onSaveEmail}
disabled={!isEmailChanged || isSaving || !isEmailValid}
disabled={!isEmailChanged || isSaving}
className="px-4 py-2 rounded bg-primary text-white hover:opacity-80 disabled:opacity-30 disabled:cursor-not-allowed disabled:text-[#0D0F11]"
data-testid="save-email-button"
>
@@ -113,7 +98,6 @@ function UserSettingsScreen() {
const [originalEmail, setOriginalEmail] = useState("");
const [isSaving, setIsSaving] = useState(false);
const [isResendingVerification, setIsResendingVerification] = useState(false);
const [isEmailValid, setIsEmailValid] = useState(true);
const queryClient = useQueryClient();
const pollingIntervalRef = useRef<number | null>(null);
const prevVerificationStatusRef = useRef<boolean | undefined>(undefined);
@@ -122,7 +106,6 @@ function UserSettingsScreen() {
if (settings?.EMAIL) {
setEmail(settings.EMAIL);
setOriginalEmail(settings.EMAIL);
setIsEmailValid(EMAIL_REGEX.test(settings.EMAIL));
}
}, [settings?.EMAIL]);
@@ -160,13 +143,11 @@ function UserSettingsScreen() {
}, [settings?.EMAIL_VERIFIED, refetch, queryClient, t]);
const handleEmailChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const newEmail = e.target.value;
setEmail(newEmail);
setIsEmailValid(EMAIL_REGEX.test(newEmail));
setEmail(e.target.value);
};
const handleSaveEmail = async () => {
if (email === originalEmail || !isEmailValid) return;
if (email === originalEmail) return;
try {
setIsSaving(true);
await openHands.post("/api/email", { email }, { withCredentials: true });
@@ -213,7 +194,6 @@ function UserSettingsScreen() {
isResendingVerification={isResendingVerification}
isEmailChanged={isEmailChanged}
emailVerified={settings?.EMAIL_VERIFIED}
isEmailValid={isEmailValid}
>
{settings?.EMAIL_VERIFIED === false && <VerificationAlert />}
</EmailInputSection>
@@ -505,7 +505,10 @@ class GitHubService(BaseGitService, GitService):
)
# Return the HTML URL of the created PR
return response['html_url']
if 'html_url' in response:
return response['html_url']
else:
return f'PR created but URL not found in response: {response}'
@@ -500,8 +500,12 @@ class GitLabService(BaseGitService, GitService):
url=url, params=payload, method=RequestMethod.POST
)
# Return the web URL of the created MR
if 'web_url' in response:
return response['web_url']
else:
return f'MR created but URL not found in response: {response}'
return response['web_url']
@@ -15,3 +15,4 @@ When you're done, make sure to
2. Use the `create_pr` tool to open a new PR
3. Name the branch using `openhands/` as a prefix (e.g `openhands/update-readme`)
4. The PR description should mention that it "fixes" or "closes" the issue number
5. Make sure to leave the following sentence at the end of the PR description: `@{{ username }} can click here to [continue refining the PR]({{ conversation_url }})`
@@ -9,3 +9,4 @@ When you're done, make sure to
1. Use the `create_pr` tool to open a new PR
2. The PR description should mention that it "fixes" or "closes" the issue number
3. Make sure to leave the following sentence at the end of the PR description: `@{{ username }} can click here to [continue refining the PR]({{ conversation_url }})`
@@ -15,3 +15,4 @@ When you're done, make sure to
2. Use the `create_mr` tool to open a new MR
3. Name the branch using `openhands/` as a prefix (e.g `openhands/update-readme`)
4. The MR description should mention that it "fixes" or "closes" the issue number
5. Make sure to leave the following sentence at the end of the MR description: `@{{ username }} can click here to [continue refining the MR]({{ conversation_url }})`
@@ -9,3 +9,4 @@ When you're done, make sure to
1. Use the `create_mr` tool to open a new MR
2. The MR description should mention that it "fixes" or "closes" the issue number
3. Make sure to leave the following sentence at the end of the MR description: `@{{ username }} can click here to [continue refining the MR]({{ conversation_url }})`
@@ -5,3 +5,7 @@ These are a list of text messages attached in order of most recent.
{{ message }}
{% if not loop.last %}\n\n{% endif %}
{% endfor %}
If you opened a pull request, please leave the following comment at the end your summary and pull request description
`{{ username }} can click here to [continue refining the PR]({{ conversation_url }})`
+170 -66
View File
@@ -1,12 +1,13 @@
import asyncio
import datetime
from contextlib import AsyncExitStack
from typing import Optional
from fastmcp import Client
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
from mcp import McpError
from mcp.types import CallToolResult
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel, Field
from openhands.core.config.mcp_config import MCPSHTTPServerConfig, MCPSSEServerConfig
from openhands.core.logger import openhands_logger as logger
from openhands.mcp.tool import MCPClientTool
@@ -16,7 +17,8 @@ class MCPClient(BaseModel):
A collection of tools that connects to an MCP server and manages available tools through the Model Context Protocol.
"""
client: Optional[Client] = None
session: Optional[ClientSession] = None
exit_stack: AsyncExitStack = AsyncExitStack()
description: str = 'MCP client tools for server interaction'
tools: list[MCPClientTool] = Field(default_factory=list)
tool_map: dict[str, MCPClientTool] = Field(default_factory=dict)
@@ -24,87 +26,189 @@ class MCPClient(BaseModel):
class Config:
arbitrary_types_allowed = True
async def connect_sse(
self,
server_url: str,
api_key: str | None = None,
conversation_id: str | None = None,
timeout: float = 30.0,
) -> None:
"""Connect to an MCP server using SSE transport.
Args:
server_url: The URL of the SSE server to connect to.
timeout: Connection timeout in seconds. Default is 30 seconds.
"""
if not server_url:
raise ValueError('Server URL is required.')
if self.session:
await self.disconnect()
try:
# Use asyncio.wait_for to enforce the timeout
async def connect_with_timeout():
headers = (
{
'Authorization': f'Bearer {api_key}',
's': api_key, # We need this for action execution server's MCP Router
'X-Session-API-Key': api_key, # We need this for Remote Runtime
}
if api_key
else {}
)
if conversation_id:
headers['X-OpenHands-Conversation-ID'] = conversation_id
# Convert float timeout to datetime.timedelta for consistency
timeout_delta = datetime.timedelta(seconds=timeout)
streams_context = sse_client(
url=server_url,
headers=headers if headers else None,
timeout=timeout,
)
streams = await self.exit_stack.enter_async_context(streams_context)
# For SSE client, we only get read_stream and write_stream (2 values)
read_stream, write_stream = streams
self.session = await self.exit_stack.enter_async_context(
ClientSession(
read_stream, write_stream, read_timeout_seconds=timeout_delta
)
)
await self._initialize_and_list_tools()
# Apply timeout to the entire connection process
await asyncio.wait_for(connect_with_timeout(), timeout=timeout)
except asyncio.TimeoutError:
logger.error(
f'Connection to {server_url} timed out after {timeout} seconds'
)
await self.disconnect() # Clean up resources
raise # Re-raise the TimeoutError
except Exception as e:
logger.error(f'Error connecting to {server_url}: {str(e)}')
await self.disconnect() # Clean up resources
raise
async def _initialize_and_list_tools(self) -> None:
"""Initialize session and populate tool map."""
if not self.client:
if not self.session:
raise RuntimeError('Session not initialized.')
async with self.client:
tools = await self.client.list_tools()
await self.session.initialize()
response = await self.session.list_tools()
# Clear existing tools
self.tools = []
# Create proper tool objects for each server tool
for tool in tools:
for tool in response.tools:
server_tool = MCPClientTool(
name=tool.name,
description=tool.description,
inputSchema=tool.inputSchema,
session=self.client,
session=self.session,
)
self.tool_map[tool.name] = server_tool
self.tools.append(server_tool)
logger.info(f'Connected to server with tools: {[tool.name for tool in tools]}')
logger.info(
f'Connected to server with tools: {[tool.name for tool in response.tools]}'
)
async def connect_http(
self,
server: MCPSSEServerConfig | MCPSHTTPServerConfig,
conversation_id: str | None = None,
timeout: float = 30.0,
):
"""Connect to MCP server using SHTTP or SSE transport"""
server_url = server.url
api_key = server.api_key
if not server_url:
raise ValueError('Server URL is required.')
try:
headers = (
{
'Authorization': f'Bearer {api_key}',
's': api_key, # We need this for action execution server's MCP Router
'X-Session-API-Key': api_key, # We need this for Remote Runtime
}
if api_key
else {}
)
if conversation_id:
headers['X-OpenHands-Conversation-ID'] = conversation_id
# Instantiate custom transports due to custom headers
if isinstance(server, MCPSHTTPServerConfig):
transport = StreamableHttpTransport(
url=server_url,
headers=headers if headers else None,
)
else:
transport = SSETransport(
url=server_url,
headers=headers if headers else None,
)
self.client = Client(transport, timeout=timeout)
await self._initialize_and_list_tools()
except McpError as e:
logger.error(f'McpError connecting to {server_url}: {e}')
raise # Re-raise the error
except Exception as e:
logger.error(f'Error connecting to {server_url}: {e}')
raise
async def call_tool(self, tool_name: str, args: dict) -> CallToolResult:
async def call_tool(self, tool_name: str, args: dict):
"""Call a tool on the MCP server."""
if tool_name not in self.tool_map:
raise ValueError(f'Tool {tool_name} not found.')
# The MCPClientTool is primarily for metadata; use the session to call the actual tool.
if not self.client:
if not self.session:
raise RuntimeError('Client session is not available.')
return await self.session.call_tool(name=tool_name, arguments=args)
async with self.client:
return await self.client.call_tool_mcp(name=tool_name, arguments=args)
async def connect_shttp(
self,
server_url: str,
api_key: str | None = None,
conversation_id: str | None = None,
timeout: float = 30.0,
) -> None:
"""Connect to an MCP server using StreamableHTTP transport.
Args:
server_url: The URL of the StreamableHTTP server to connect to.
api_key: Optional API key for authentication.
conversation_id: Optional conversation ID for session tracking.
timeout: Connection timeout in seconds. Default is 30 seconds.
"""
if not server_url:
raise ValueError('Server URL is required.')
if self.session:
await self.disconnect()
try:
# Use asyncio.wait_for to enforce the timeout
async def connect_with_timeout():
headers = (
{
'Authorization': f'Bearer {api_key}',
's': api_key, # We need this for action execution server's MCP Router
'X-Session-API-Key': api_key, # We need this for Remote Runtime
}
if api_key
else {}
)
if conversation_id:
headers['X-OpenHands-Conversation-ID'] = conversation_id
# Convert float timeout to datetime.timedelta
timeout_delta = datetime.timedelta(seconds=timeout)
sse_read_timeout_delta = datetime.timedelta(
seconds=timeout * 10
) # 10x longer for read timeout
streams_context = streamablehttp_client(
url=server_url,
headers=headers if headers else None,
timeout=timeout_delta,
sse_read_timeout=sse_read_timeout_delta,
)
streams = await self.exit_stack.enter_async_context(streams_context)
# For StreamableHTTP client, we get read_stream, write_stream, and get_session_id (3 values)
read_stream, write_stream, _ = streams
self.session = await self.exit_stack.enter_async_context(
ClientSession(
read_stream, write_stream, read_timeout_seconds=timeout_delta
)
)
await self._initialize_and_list_tools()
# Apply timeout to the entire connection process
await asyncio.wait_for(connect_with_timeout(), timeout=timeout)
except asyncio.TimeoutError:
logger.error(
f'Connection to {server_url} timed out after {timeout} seconds'
)
await self.disconnect() # Clean up resources
raise # Re-raise the TimeoutError
except Exception as e:
logger.error(f'Error connecting to {server_url}: {str(e)}')
await self.disconnect() # Clean up resources
raise
async def disconnect(self) -> None:
"""Disconnect from the MCP server and clean up resources."""
if self.session:
try:
# Close the session first
if hasattr(self.session, 'close'):
await self.session.close()
# Then close the exit stack
await self.exit_stack.aclose()
except Exception as e:
logger.error(f'Error during disconnect: {str(e)}')
finally:
self.session = None
self.tools = []
logger.info('Disconnected from MCP server')
+27 -4
View File
@@ -72,22 +72,38 @@ async def create_mcp_clients(
mcp_clients = []
for server in servers:
is_shttp = isinstance(server, MCPSHTTPServerConfig)
connection_type = 'SHTTP' if is_shttp else 'SSE'
is_sse = isinstance(server, MCPSSEServerConfig)
connection_type = 'SSE' if is_sse else 'SHTTP'
logger.info(
f'Initializing MCP agent for {server} with {connection_type} connection...'
)
client = MCPClient()
try:
await client.connect_http(server, conversation_id=conversation_id)
if is_sse:
await client.connect_sse(
server.url,
api_key=server.api_key,
conversation_id=conversation_id,
)
else:
await client.connect_shttp(
server.url,
api_key=server.api_key,
conversation_id=conversation_id,
)
# Only add the client to the list after a successful connection
mcp_clients.append(client)
except Exception as e:
logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True)
try:
await client.disconnect()
except Exception as disconnect_error:
logger.error(
f'Error during disconnect after failed connection: {str(disconnect_error)}'
)
return mcp_clients
@@ -127,6 +143,13 @@ async def fetch_mcp_tools_from_config(
# Convert tools to the format expected by the agent
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)
# Always disconnect clients to clean up resources
for mcp_client in mcp_clients:
try:
await mcp_client.disconnect()
except Exception as disconnect_error:
logger.error(f'Error disconnecting MCP client: {str(disconnect_error)}')
except Exception as e:
logger.error(f'Error fetching MCP tools: {str(e)}')
return []
+20 -16
View File
@@ -65,6 +65,7 @@ from openhands.runtime.browser.browser_env import BrowserEnv
from openhands.runtime.file_viewer_server import start_file_viewer_server
from openhands.runtime.plugins import ALL_PLUGINS, JupyterPlugin, Plugin, VSCodePlugin
from openhands.runtime.utils import find_available_tcp_port
from openhands.runtime.utils.async_bash import AsyncBashSession
from openhands.runtime.utils.bash import BashSession
from openhands.runtime.utils.files import insert_lines, read_lines
from openhands.runtime.utils.log_capture import capture_logs
@@ -253,10 +254,12 @@ class ActionExecutor:
# If we get here, the browser is ready
logger.debug('Browser is ready')
def _create_bash_session(self, cwd: str | None = None):
async def ainit(self):
# bash needs to be initialized first
logger.debug('Initializing bash session')
if sys.platform == 'win32':
return WindowsPowershellSession( # type: ignore[name-defined]
work_dir=cwd or self._initial_cwd,
self.bash_session = WindowsPowershellSession( # type: ignore[name-defined]
work_dir=self._initial_cwd,
username=self.username,
no_change_timeout_seconds=int(
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 10)
@@ -264,21 +267,15 @@ class ActionExecutor:
max_memory_mb=self.max_memory_gb * 1024 if self.max_memory_gb else None,
)
else:
bash_session = BashSession(
work_dir=cwd or self._initial_cwd,
self.bash_session = BashSession(
work_dir=self._initial_cwd,
username=self.username,
no_change_timeout_seconds=int(
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 10)
),
max_memory_mb=self.max_memory_gb * 1024 if self.max_memory_gb else None,
)
bash_session.initialize()
return bash_session
async def ainit(self):
# bash needs to be initialized first
logger.debug('Initializing bash session')
self.bash_session = self._create_bash_session()
self.bash_session.initialize()
logger.debug('Bash session initialized')
# Start browser initialization in the background
@@ -391,11 +388,18 @@ class ActionExecutor:
self, action: CmdRunAction
) -> CmdOutputObservation | ErrorObservation:
try:
bash_session = self.bash_session
if action.is_static:
bash_session = self._create_bash_session(action.cwd)
assert bash_session is not None
obs = await call_sync_from_async(bash_session.execute, action)
path = action.cwd or self._initial_cwd
result = await AsyncBashSession.execute(action.command, path)
obs = CmdOutputObservation(
content=result.content,
exit_code=result.exit_code,
command=action.command,
)
return obs
assert self.bash_session is not None
obs = await call_sync_from_async(self.bash_session.execute, action)
return obs
except Exception as e:
logger.error(f'Error running command: {e}')
+1 -4
View File
@@ -400,7 +400,7 @@ class Runtime(FileEditRuntimeMixin):
'No repository selected. Initializing a new git repository in the workspace.'
)
action = CmdRunAction(
command=f'git init && git config --global --add safe.directory {self.workspace_root}'
command='git init',
)
self.run_action(action)
else:
@@ -952,9 +952,6 @@ fi
exit_code = 0
content = ''
if isinstance(obs, ErrorObservation):
exit_code = -1
if hasattr(obs, 'exit_code'):
exit_code = obs.exit_code
if hasattr(obs, 'content'):
@@ -406,7 +406,7 @@ class ActionExecutionClient(Runtime):
'POST',
f'{self.action_execution_server_url}/update_mcp_server',
json=stdio_tools,
timeout=60,
timeout=10,
)
result = response.json()
if response.status_code != 200:
@@ -464,13 +464,16 @@ class ActionExecutionClient(Runtime):
)
# Create clients for this specific operation
mcp_clients = await create_mcp_clients(
updated_mcp_config.sse_servers, updated_mcp_config.shttp_servers, self.sid
)
mcp_clients = await create_mcp_clients(updated_mcp_config.sse_servers, updated_mcp_config.shttp_servers, self.sid)
# Call the tool and return the result
# No need for try/finally since disconnect() is now just resetting state
result = await call_tool_mcp_handler(mcp_clients, action)
# Reset client state (no active connections to worry about)
for client in mcp_clients:
await client.disconnect()
return result
def close(self) -> None:
+54
View File
@@ -0,0 +1,54 @@
import asyncio
import os
from openhands.runtime.base import CommandResult
class AsyncBashSession:
@staticmethod
async def execute(command: str, work_dir: str) -> CommandResult:
"""Execute a command in the bash session asynchronously."""
work_dir = os.path.abspath(work_dir)
if not os.path.exists(work_dir):
raise ValueError(f'Work directory {work_dir} does not exist.')
command = command.strip()
if not command:
return CommandResult(content='', exit_code=0)
try:
process = await asyncio.subprocess.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=work_dir,
)
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=30
)
output = stdout.decode('utf-8')
if stderr:
output = stderr.decode('utf-8')
print(f'!##! Error running command: {stderr.decode("utf-8")}')
return CommandResult(content=output, exit_code=process.returncode or 0)
except asyncio.TimeoutError:
process.terminate()
# Allow a brief moment for cleanup
try:
await asyncio.wait_for(process.wait(), timeout=1.0)
except asyncio.TimeoutError:
process.kill() # Force kill if it doesn't terminate cleanly
return CommandResult(content='Command timed out.', exit_code=-1)
except Exception as e:
return CommandResult(
content=f'Error running command: {str(e)}', exit_code=-1
)
+6 -3
View File
@@ -17,7 +17,6 @@ from openhands.events.observation.commands import (
CmdOutputMetadata,
CmdOutputObservation,
)
from openhands.runtime.utils.bash_constants import TIMEOUT_MESSAGE_TEMPLATE
from openhands.utils.shutdown_listener import should_continue
@@ -380,7 +379,9 @@ class BashSession:
metadata = CmdOutputMetadata() # No metadata available
metadata.suffix = (
f'\n[The command has no new output after {self.NO_CHANGE_TIMEOUT_SECONDS} seconds. '
f'{TIMEOUT_MESSAGE_TEMPLATE}]'
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
command_output = self._get_command_output(
command,
@@ -413,7 +414,9 @@ class BashSession:
metadata = CmdOutputMetadata() # No metadata available
metadata.suffix = (
f'\n[The command timed out after {timeout} seconds. '
f'{TIMEOUT_MESSAGE_TEMPLATE}]'
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
command_output = self._get_command_output(
command,
@@ -1,7 +0,0 @@
# Common timeout message that can be used across different timeout scenarios
TIMEOUT_MESSAGE_TEMPLATE = (
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'send keys to interrupt/kill the command, '
'or use the timeout parameter in execute_bash for future commands.'
)
+9 -13
View File
@@ -44,7 +44,7 @@ class GitHandler:
Returns:
bool: True if inside a Git repository, otherwise False.
"""
cmd = 'git --no-pager rev-parse --is-inside-work-tree'
cmd = 'git rev-parse --is-inside-work-tree'
output = self.execute(cmd, self.cwd)
return output.content.strip() == 'true'
@@ -71,7 +71,7 @@ class GitHandler:
Returns:
bool: True if the reference exists, otherwise False.
"""
cmd = f'git --no-pager rev-parse --verify {ref}'
cmd = f'git rev-parse --verify {ref}'
output = self.execute(cmd, self.cwd)
return output.exit_code == 0
@@ -86,9 +86,9 @@ class GitHandler:
default_branch = self._get_default_branch()
ref_current_branch = f'origin/{current_branch}'
ref_non_default_branch = f'$(git --no-pager merge-base HEAD "$(git --no-pager rev-parse --abbrev-ref origin/{default_branch})")'
ref_non_default_branch = f'$(git merge-base HEAD "$(git rev-parse --abbrev-ref origin/{default_branch})")'
ref_default_branch = 'origin/' + default_branch
ref_new_repo = '$(git --no-pager rev-parse --verify 4b825dc642cb6eb9a060e54bf8d69288fbee4904)' # compares with empty tree
ref_new_repo = '$(git rev-parse --verify 4b825dc642cb6eb9a060e54bf8d69288fbee4904)' # compares with empty tree
refs = [
ref_current_branch,
@@ -116,7 +116,7 @@ class GitHandler:
if not ref:
return ''
cmd = f'git --no-pager show {ref}:{file_path}'
cmd = f'git show {ref}:{file_path}'
output = self.execute(cmd, self.cwd)
return output.content if output.exit_code == 0 else ''
@@ -127,7 +127,7 @@ class GitHandler:
Returns:
str: The name of the primary branch.
"""
cmd = 'git --no-pager remote show origin | grep "HEAD branch"'
cmd = 'git remote show origin | grep "HEAD branch"'
output = self.execute(cmd, self.cwd)
return output.content.split()[-1].strip()
@@ -138,7 +138,7 @@ class GitHandler:
Returns:
str: The name of the current branch.
"""
cmd = 'git --no-pager rev-parse --abbrev-ref HEAD'
cmd = 'git rev-parse --abbrev-ref HEAD'
output = self.execute(cmd, self.cwd)
return output.content.strip()
@@ -153,12 +153,8 @@ class GitHandler:
if not ref:
return []
diff_cmd = f'git --no-pager diff --name-status {ref}'
diff_cmd = f'git diff --name-status {ref}'
output = self.execute(diff_cmd, self.cwd)
if output.exit_code != 0:
raise RuntimeError(
f'Failed to get diff for ref {ref} in {self.cwd}. Command output: {output.content}'
)
return output.content.splitlines()
def _get_untracked_files(self) -> list[dict[str, str]]:
@@ -168,7 +164,7 @@ class GitHandler:
Returns:
list[dict[str, str]]: A list of dictionaries containing file paths and statuses.
"""
cmd = 'git --no-pager ls-files --others --exclude-standard'
cmd = 'git ls-files --others --exclude-standard'
output = self.execute(cmd, self.cwd)
obs_list = output.content.splitlines()
return (
+6 -3
View File
@@ -20,7 +20,6 @@ from openhands.events.observation.commands import (
CmdOutputMetadata,
CmdOutputObservation,
)
from openhands.runtime.utils.bash_constants import TIMEOUT_MESSAGE_TEMPLATE
from openhands.utils.shutdown_listener import should_continue
pythonnet.load('coreclr')
@@ -560,7 +559,9 @@ class WindowsPowershellSession:
else:
metadata.suffix = (
f'\n[The command timed out after {timeout_seconds} seconds. '
f'{TIMEOUT_MESSAGE_TEMPLATE}]'
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
return CmdOutputObservation(
@@ -1330,7 +1331,9 @@ class WindowsPowershellSession:
# Align suffix with bash.py timeout message
suffix = (
f'\n[The command timed out after {timeout_seconds} seconds. '
f'{TIMEOUT_MESSAGE_TEMPLATE}]'
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
elif shutdown_requested:
# Align suffix with bash.py equivalent (though bash.py might not have specific shutdown message)
@@ -281,23 +281,7 @@ class DockerNestedConversationManager(ConversationManager):
raise ValueError('unsupported_operation')
async def close_session(self, sid: str):
# First try to graceful stop server.
try:
container = self.docker_client.containers.get(f'openhands-runtime-{sid}')
except docker.errors.NotFound as e:
return
try:
nested_url = self.get_nested_url_for_container(container)
async with httpx.AsyncClient(
headers={
'X-Session-API-Key': self._get_session_api_key_for_conversation(sid)
}
) as client:
response = await client.post(f'{nested_url}/api/conversations/{sid}/stop')
response.raise_for_status()
except Exception:
logger.exception("error_stopping_container")
container.stop()
stop_all_containers(f'openhands-runtime-{sid}')
async def get_agent_loop_info(self, user_id: str | None = None, filter_to_sids: set[str] | None = None) -> list[AgentLoopInfo]:
results = []
@@ -369,9 +369,7 @@ class StandaloneConversationManager(ConversationManager):
f'removing connections: {connection_ids_to_remove}',
extra={'session_id': sid},
)
# Perform a graceful shutdown of each connection
for connection_id in connection_ids_to_remove:
await self.sio.disconnect(connection_id)
self._local_connection_id_to_session_id.pop(connection_id, None)
session = self._local_agent_loops_by_sid.pop(sid, None)
+23 -46
View File
@@ -12,7 +12,6 @@ from openhands.events.action import (
)
from openhands.events.action.agent import RecallAction
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.event_store import EventStore
from openhands.events.observation import (
NullObservation,
)
@@ -125,48 +124,6 @@ async def connect(connection_id: str, environ: dict) -> None:
f'User {user_id} is allowed to connect to conversation {conversation_id}'
)
try:
event_store = EventStore(
conversation_id, conversation_manager.file_store, user_id
)
except FileNotFoundError as e:
logger.error(
f'Failed to create EventStore for conversation {conversation_id}: {e}'
)
raise ConnectionRefusedError(f'Failed to access conversation events: {e}')
logger.info(
f'Replaying event stream for conversation {conversation_id} with connection_id {connection_id}...'
)
agent_state_changed = None
# Create an async store to replay events
async_store = AsyncEventStoreWrapper(event_store, latest_event_id + 1)
# Process all available events
async for event in async_store:
logger.debug(f'oh_event: {event.__class__.__name__}')
if isinstance(
event,
(NullAction, NullObservation, RecallAction),
):
continue
elif isinstance(event, AgentStateChangedObservation):
agent_state_changed = event
else:
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
# Send the agent state changed event last if we have one
if agent_state_changed:
await sio.emit(
'oh_event', event_to_dict(agent_state_changed), to=connection_id
)
logger.info(
f'Finished replaying event stream for conversation {conversation_id}'
)
conversation_init_data = await setup_init_convo_settings(
user_id, conversation_id, providers_set
)
@@ -176,12 +133,32 @@ async def connect(connection_id: str, environ: dict) -> None:
conversation_init_data,
user_id,
)
logger.info(
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
)
agent_state_changed = None
if agent_loop_info is None:
raise ConnectionRefusedError('Failed to join conversation')
async_store = AsyncEventStoreWrapper(
agent_loop_info.event_store, latest_event_id + 1
)
async for event in async_store:
logger.debug(f'oh_event: {event.__class__.__name__}')
if isinstance(
event,
(NullAction, NullObservation, RecallAction),
):
continue
elif isinstance(event, AgentStateChangedObservation):
agent_state_changed = event
else:
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
if agent_state_changed:
await sio.emit(
'oh_event', event_to_dict(agent_state_changed), to=connection_id
)
logger.info(
f'Successfully joined conversation {conversation_id} with connection_id {connection_id}'
f'Finished replaying event stream for conversation {conversation_id}'
)
except ConnectionRefusedError:
# Close the broken connection after sending an error message
+7 -45
View File
@@ -1,4 +1,3 @@
import os
import re
from typing import Annotated
@@ -11,10 +10,9 @@ from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.provider import ProviderToken
from openhands.integrations.service_types import GitService, ProviderType
from openhands.integrations.service_types import ProviderType
from openhands.server.dependencies import get_dependencies
from openhands.server.shared import ConversationStoreImpl, config, server_config
from openhands.server.types import AppMode
from openhands.server.shared import ConversationStoreImpl, config
from openhands.server.user_auth import (
get_access_token,
get_provider_tokens,
@@ -22,31 +20,7 @@ from openhands.server.user_auth import (
)
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
mcp_server = FastMCP(
'mcp', stateless_http=True, dependencies=get_dependencies(), mask_error_details=True
)
HOST = f'https://{os.getenv("WEB_HOST", "app.all-hands.dev").strip()}'
CONVO_URL = HOST + '/{}'
async def get_convo_link(service: GitService, conversation_id: str, body: str) -> str:
"""
Appends a followup link, in the PR body, to the OpenHands conversation that opened the PR
"""
if server_config.app_mode != AppMode.SAAS:
return body
user = await service.get_user()
username = user.login
convo_url = CONVO_URL.format(conversation_id)
convo_link = (
f'@{username} can click here to [continue refining the PR]({convo_url})'
)
body += f'\n\n{convo_link}'
return body
mcp_server = FastMCP('mcp', stateless_http=True, dependencies=get_dependencies(), mask_error_details=True)
async def save_pr_metadata(
user_id: str, conversation_id: str, tool_result: str
@@ -110,11 +84,6 @@ async def create_pr(
base_domain=github_token.host,
)
try:
body = await get_convo_link(github_service, conversation_id, body or '')
except Exception as e:
logger.warning(f'Failed to append convo link: {e}')
try:
response = await github_service.create_pr(
repo_name=repo_name,
@@ -128,7 +97,7 @@ async def create_pr(
await save_pr_metadata(user_id, conversation_id, response)
except Exception as e:
error = f'Error creating pull request: {e}'
error = f"Error creating pull request: {e}"
raise ToolError(str(error))
return response
@@ -163,7 +132,7 @@ async def create_mr(
else ProviderToken()
)
gitlab_service = GitLabServiceImpl(
github_service = GitLabServiceImpl(
user_id=github_token.user_id,
external_auth_id=user_id,
external_auth_token=access_token,
@@ -172,14 +141,7 @@ async def create_mr(
)
try:
description = await get_convo_link(
gitlab_service, conversation_id, description or ''
)
except Exception as e:
logger.warning(f'Failed to append convo link: {e}')
try:
response = await gitlab_service.create_mr(
response = await github_service.create_mr(
id=id,
source_branch=source_branch,
target_branch=target_branch,
@@ -191,7 +153,7 @@ async def create_mr(
await save_pr_metadata(user_id, conversation_id, response)
except Exception as e:
error = f'Error creating merge request: {e}'
error = f"Error creating merge request: {e}"
raise ToolError(str(error))
return response
@@ -124,8 +124,8 @@ async def create_new_conversation(
image_urls=image_urls or [],
)
if attach_convo_id:
logger.warning('Attaching convo ID is deprecated, skipping process')
if attach_convo_id and conversation_instructions:
conversation_instructions = conversation_instructions.format(conversation_id)
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
conversation_id,
+1 -1
View File
@@ -12,7 +12,7 @@ async def get_conversation_store(request: Request) -> ConversationStore | None:
)
if conversation_store:
return conversation_store
user_id = await get_user_id(request)
user_id = get_user_id(request)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
request.state.conversation_store = conversation_store
return conversation_store
Generated
+58 -61
View File
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -373,7 +373,7 @@ description = "LTS Port of Python audioop"
optional = false
python-versions = ">=3.13"
groups = ["main"]
markers = "python_version == \"3.13\""
markers = "python_version >= \"3.13\""
files = [
{file = "audioop_lts-0.2.1-cp313-abi3-macosx_10_13_universal2.whl", hash = "sha256:fd1345ae99e17e6910f47ce7d52673c6a1a70820d78b67de1b7abb3af29c426a"},
{file = "audioop_lts-0.2.1-cp313-abi3-macosx_10_13_x86_64.whl", hash = "sha256:e175350da05d2087e12cea8e72a70a1a8b14a17e92ed2022952a4419689ede5e"},
@@ -2969,8 +2969,8 @@ files = [
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
proto-plus = [
{version = ">=1.22.3,<2.0.0dev"},
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0dev"},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
@@ -2992,8 +2992,8 @@ googleapis-common-protos = ">=1.56.2,<2.0.0"
grpcio = {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
grpcio-status = {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
proto-plus = [
{version = ">=1.22.3,<2.0.0"},
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0"},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
requests = ">=2.18.0,<3.0.0"
@@ -3211,8 +3211,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0", extras
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0"
grpc-google-iam-v1 = ">=0.14.0,<1.0.0"
proto-plus = [
{version = ">=1.22.3,<2.0.0"},
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
{version = ">=1.22.3,<2.0.0", markers = "python_version < \"3.13\""},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
@@ -6456,106 +6456,102 @@ et-xmlfile = "*"
[[package]]
name = "opentelemetry-api"
version = "1.34.0"
version = "1.25.0"
description = "OpenTelemetry Python API"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "opentelemetry_api-1.34.0-py3-none-any.whl", hash = "sha256:390b81984affe4453180820ca518de55e3be051111e70cc241bb3b0071ca3a2c"},
{file = "opentelemetry_api-1.34.0.tar.gz", hash = "sha256:48d167589134799093005b7f7f347c69cc67859c693b17787f334fbe8871279f"},
{file = "opentelemetry_api-1.25.0-py3-none-any.whl", hash = "sha256:757fa1aa020a0f8fa139f8959e53dec2051cc26b832e76fa839a6d76ecefd737"},
{file = "opentelemetry_api-1.25.0.tar.gz", hash = "sha256:77c4985f62f2614e42ce77ee4c9da5fa5f0bc1e1821085e9a47533a9323ae869"},
]
[package.dependencies]
importlib-metadata = ">=6.0,<8.8.0"
typing-extensions = ">=4.5.0"
deprecated = ">=1.2.6"
importlib-metadata = ">=6.0,<=7.1"
[[package]]
name = "opentelemetry-exporter-otlp-proto-common"
version = "1.34.0"
version = "1.25.0"
description = "OpenTelemetry Protobuf encoding"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "opentelemetry_exporter_otlp_proto_common-1.34.0-py3-none-any.whl", hash = "sha256:a5ab7a9b7c3c7ba957c8ddcb08c0c93b1d732e066f544682a250ecf4d7a9ceef"},
{file = "opentelemetry_exporter_otlp_proto_common-1.34.0.tar.gz", hash = "sha256:5916d9ceda8c733adbec5e9cecf654fbf359e9f619ff43214277076fba888557"},
{file = "opentelemetry_exporter_otlp_proto_common-1.25.0-py3-none-any.whl", hash = "sha256:15637b7d580c2675f70246563363775b4e6de947871e01d0f4e3881d1848d693"},
{file = "opentelemetry_exporter_otlp_proto_common-1.25.0.tar.gz", hash = "sha256:c93f4e30da4eee02bacd1e004eb82ce4da143a2f8e15b987a9f603e0a85407d3"},
]
[package.dependencies]
opentelemetry-proto = "1.34.0"
opentelemetry-proto = "1.25.0"
[[package]]
name = "opentelemetry-exporter-otlp-proto-grpc"
version = "1.34.0"
version = "1.25.0"
description = "OpenTelemetry Collector Protobuf over gRPC Exporter"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "opentelemetry_exporter_otlp_proto_grpc-1.34.0-py3-none-any.whl", hash = "sha256:31c41017af85833242d49beb07bde7341b0a145f0b898ee383f3e3019037afb1"},
{file = "opentelemetry_exporter_otlp_proto_grpc-1.34.0.tar.gz", hash = "sha256:a634425340f506d5ebf641c92d88eb873754d4c5259b5b816afb234c6f87b37d"},
{file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0-py3-none-any.whl", hash = "sha256:3131028f0c0a155a64c430ca600fd658e8e37043cb13209f0109db5c1a3e4eb4"},
{file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0.tar.gz", hash = "sha256:c0b1661415acec5af87625587efa1ccab68b873745ca0ee96b69bb1042087eac"},
]
[package.dependencies]
deprecated = ">=1.2.6"
googleapis-common-protos = ">=1.52,<2.0"
grpcio = [
{version = ">=1.63.2,<2.0.0", markers = "python_version < \"3.13\""},
{version = ">=1.66.2,<2.0.0", markers = "python_version >= \"3.13\""},
]
grpcio = ">=1.0.0,<2.0.0"
opentelemetry-api = ">=1.15,<2.0"
opentelemetry-exporter-otlp-proto-common = "1.34.0"
opentelemetry-proto = "1.34.0"
opentelemetry-sdk = ">=1.34.0,<1.35.0"
typing-extensions = ">=4.5.0"
opentelemetry-exporter-otlp-proto-common = "1.25.0"
opentelemetry-proto = "1.25.0"
opentelemetry-sdk = ">=1.25.0,<1.26.0"
[[package]]
name = "opentelemetry-proto"
version = "1.34.0"
version = "1.25.0"
description = "OpenTelemetry Python Proto"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "opentelemetry_proto-1.34.0-py3-none-any.whl", hash = "sha256:ffb1f1b27552fda5a1cd581e34243cc0b6f134fb14c1c2a33cc3b4b208c9bf97"},
{file = "opentelemetry_proto-1.34.0.tar.gz", hash = "sha256:73e40509b692630a47192888424f7e0b8fb19d9ecf2f04e6f708170cd3346dfe"},
{file = "opentelemetry_proto-1.25.0-py3-none-any.whl", hash = "sha256:f07e3341c78d835d9b86665903b199893befa5e98866f63d22b00d0b7ca4972f"},
{file = "opentelemetry_proto-1.25.0.tar.gz", hash = "sha256:35b6ef9dc4a9f7853ecc5006738ad40443701e52c26099e197895cbda8b815a3"},
]
[package.dependencies]
protobuf = ">=5.0,<6.0"
protobuf = ">=3.19,<5.0"
[[package]]
name = "opentelemetry-sdk"
version = "1.34.0"
version = "1.25.0"
description = "OpenTelemetry Python SDK"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "opentelemetry_sdk-1.34.0-py3-none-any.whl", hash = "sha256:7850bcd5b5c95f9aae48603d6592bdad5c7bdef50c03e06393f8f457d891fe32"},
{file = "opentelemetry_sdk-1.34.0.tar.gz", hash = "sha256:719559622afcd515c2aec462ccb749ba2e70075a01df45837623643814d33716"},
{file = "opentelemetry_sdk-1.25.0-py3-none-any.whl", hash = "sha256:d97ff7ec4b351692e9d5a15af570c693b8715ad78b8aafbec5c7100fe966b4c9"},
{file = "opentelemetry_sdk-1.25.0.tar.gz", hash = "sha256:ce7fc319c57707ef5bf8b74fb9f8ebdb8bfafbe11898410e0d2a761d08a98ec7"},
]
[package.dependencies]
opentelemetry-api = "1.34.0"
opentelemetry-semantic-conventions = "0.55b0"
typing-extensions = ">=4.5.0"
opentelemetry-api = "1.25.0"
opentelemetry-semantic-conventions = "0.46b0"
typing-extensions = ">=3.7.4"
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.55b0"
version = "0.46b0"
description = "OpenTelemetry Semantic Conventions"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "opentelemetry_semantic_conventions-0.55b0-py3-none-any.whl", hash = "sha256:63bb15b67377700e51c422d0d24092ca6ce9f3a4cb6f032375aa8af1fc2aab65"},
{file = "opentelemetry_semantic_conventions-0.55b0.tar.gz", hash = "sha256:933d2e20c2dbc0f9b2f4f52138282875b4b14c66c491f5273bcdef1781368e9c"},
{file = "opentelemetry_semantic_conventions-0.46b0-py3-none-any.whl", hash = "sha256:6daef4ef9fa51d51855d9f8e0ccd3a1bd59e0e545abe99ac6203804e36ab3e07"},
{file = "opentelemetry_semantic_conventions-0.46b0.tar.gz", hash = "sha256:fbc982ecbb6a6e90869b15c1673be90bd18c8a56ff1cffc0864e38e2edffaefa"},
]
[package.dependencies]
opentelemetry-api = "1.34.0"
typing-extensions = ">=4.5.0"
opentelemetry-api = "1.25.0"
[[package]]
name = "overrides"
@@ -7175,23 +7171,23 @@ testing = ["google-api-core (>=1.31.5)"]
[[package]]
name = "protobuf"
version = "5.29.5"
version = "4.25.8"
description = ""
optional = false
python-versions = ">=3.8"
groups = ["main", "evaluation"]
files = [
{file = "protobuf-5.29.5-cp310-abi3-win32.whl", hash = "sha256:3f1c6468a2cfd102ff4703976138844f78ebd1fb45f49011afc5139e9e283079"},
{file = "protobuf-5.29.5-cp310-abi3-win_amd64.whl", hash = "sha256:3f76e3a3675b4a4d867b52e4a5f5b78a2ef9565549d4037e06cf7b0942b1d3fc"},
{file = "protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e38c5add5a311f2a6eb0340716ef9b039c1dfa428b28f25a7838ac329204a671"},
{file = "protobuf-5.29.5-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:fa18533a299d7ab6c55a238bf8629311439995f2e7eca5caaff08663606e9015"},
{file = "protobuf-5.29.5-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:63848923da3325e1bf7e9003d680ce6e14b07e55d0473253a690c3a8b8fd6e61"},
{file = "protobuf-5.29.5-cp38-cp38-win32.whl", hash = "sha256:ef91363ad4faba7b25d844ef1ada59ff1604184c0bcd8b39b8a6bef15e1af238"},
{file = "protobuf-5.29.5-cp38-cp38-win_amd64.whl", hash = "sha256:7318608d56b6402d2ea7704ff1e1e4597bee46d760e7e4dd42a3d45e24b87f2e"},
{file = "protobuf-5.29.5-cp39-cp39-win32.whl", hash = "sha256:6f642dc9a61782fa72b90878af134c5afe1917c89a568cd3476d758d3c3a0736"},
{file = "protobuf-5.29.5-cp39-cp39-win_amd64.whl", hash = "sha256:470f3af547ef17847a28e1f47200a1cbf0ba3ff57b7de50d22776607cd2ea353"},
{file = "protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5"},
{file = "protobuf-5.29.5.tar.gz", hash = "sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84"},
{file = "protobuf-4.25.8-cp310-abi3-win32.whl", hash = "sha256:504435d831565f7cfac9f0714440028907f1975e4bed228e58e72ecfff58a1e0"},
{file = "protobuf-4.25.8-cp310-abi3-win_amd64.whl", hash = "sha256:bd551eb1fe1d7e92c1af1d75bdfa572eff1ab0e5bf1736716814cdccdb2360f9"},
{file = "protobuf-4.25.8-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:ca809b42f4444f144f2115c4c1a747b9a404d590f18f37e9402422033e464e0f"},
{file = "protobuf-4.25.8-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:9ad7ef62d92baf5a8654fbb88dac7fa5594cfa70fd3440488a5ca3bfc6d795a7"},
{file = "protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:83e6e54e93d2b696a92cad6e6efc924f3850f82b52e1563778dfab8b355101b0"},
{file = "protobuf-4.25.8-cp38-cp38-win32.whl", hash = "sha256:27d498ffd1f21fb81d987a041c32d07857d1d107909f5134ba3350e1ce80a4af"},
{file = "protobuf-4.25.8-cp38-cp38-win_amd64.whl", hash = "sha256:d552c53d0415449c8d17ced5c341caba0d89dbf433698e1436c8fa0aae7808a3"},
{file = "protobuf-4.25.8-cp39-cp39-win32.whl", hash = "sha256:077ff8badf2acf8bc474406706ad890466274191a48d0abd3bd6987107c9cde5"},
{file = "protobuf-4.25.8-cp39-cp39-win_amd64.whl", hash = "sha256:f4510b93a3bec6eba8fd8f1093e9d7fb0d4a24d1a81377c10c0e5bbfe9e4ed24"},
{file = "protobuf-4.25.8-py3-none-any.whl", hash = "sha256:15a0af558aa3b13efef102ae6e4f3efac06f1eea11afb3a57db2901447d9fb59"},
{file = "protobuf-4.25.8.tar.gz", hash = "sha256:6135cf8affe1fc6f76cced2641e4ea8d3e59518d1f24ae41ba97bcad82d397cd"},
]
[[package]]
@@ -9339,6 +9335,7 @@ files = [
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
]
markers = {evaluation = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
@@ -9581,7 +9578,7 @@ description = "Standard library aifc redistribution. \"dead battery\"."
optional = false
python-versions = "*"
groups = ["main"]
markers = "python_version == \"3.13\""
markers = "python_version >= \"3.13\""
files = [
{file = "standard_aifc-3.13.0-py3-none-any.whl", hash = "sha256:f7ae09cc57de1224a0dd8e3eb8f73830be7c3d0bc485de4c1f82b4a7f645ac66"},
{file = "standard_aifc-3.13.0.tar.gz", hash = "sha256:64e249c7cb4b3daf2fdba4e95721f811bde8bdfc43ad9f936589b7bb2fae2e43"},
@@ -9598,7 +9595,7 @@ description = "Standard library chunk redistribution. \"dead battery\"."
optional = false
python-versions = "*"
groups = ["main"]
markers = "python_version == \"3.13\""
markers = "python_version >= \"3.13\""
files = [
{file = "standard_chunk-3.13.0-py3-none-any.whl", hash = "sha256:17880a26c285189c644bd5bd8f8ed2bdb795d216e3293e6dbe55bbd848e2982c"},
{file = "standard_chunk-3.13.0.tar.gz", hash = "sha256:4ac345d37d7e686d2755e01836b8d98eda0d1a3ee90375e597ae43aaf064d654"},
@@ -11760,4 +11757,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = "^3.12,<3.14"
content-hash = "d9f6c24fa80dd191f180af0c802ea11ecf514d86aaa421cb19a9bb497362c101"
content-hash = "eaa84e30dbafb061a75b4b173a8ba16542c4a03ab74583c55ab282cd6119e430"
+10 -10
View File
@@ -20,12 +20,12 @@ packages = [
[tool.poetry.dependencies]
python = "^3.12,<3.14"
litellm = "^1.60.0, !=1.64.4, !=1.67.*" # avoid 1.64.4 (known bug) & 1.67.* (known bug #10272)
aiohttp = ">=3.9.0,!=3.11.13" # Pin to avoid yanked version 3.11.13
google-generativeai = "*" # To use litellm with Gemini Pro API
google-api-python-client = "^2.164.0" # For Google Sheets API
google-auth-httplib2 = "*" # For Google Sheets authentication
google-auth-oauthlib = "*" # For Google Sheets OAuth
litellm = "^1.60.0, !=1.64.4, !=1.67.*" # avoid 1.64.4 (known bug) & 1.67.* (known bug #10272)
aiohttp = ">=3.9.0,!=3.11.13" # Pin to avoid yanked version 3.11.13
google-generativeai = "*" # To use litellm with Gemini Pro API
google-api-python-client = "^2.164.0" # For Google Sheets API
google-auth-httplib2 = "*" # For Google Sheets authentication
google-auth-oauthlib = "*" # For Google Sheets OAuth
termcolor = "*"
docker = "*"
fastapi = "*"
@@ -34,7 +34,7 @@ types-toml = "*"
uvicorn = "*"
numpy = "*"
json-repair = "*"
browsergym-core = "0.13.3" # integrate browsergym-core as the browsing interface
browsergym-core = "0.13.3" # integrate browsergym-core as the browsing interface
html2text = "*"
e2b = ">=1.0.5,<1.4.0"
pexpect = "*"
@@ -49,9 +49,9 @@ tornado = "*"
python-dotenv = "*"
rapidfuzz = "^3.9.0"
whatthepatch = "^1.0.6"
protobuf = "^5.0.0,<6.0.0" # Updated to support newer opentelemetry
opentelemetry-api = "^1.33.1"
opentelemetry-exporter-otlp-proto-grpc = "^1.33.1"
protobuf = "^4.21.6,<5.0.0" # chromadb currently fails on 5.0+
opentelemetry-api = "1.25.0"
opentelemetry-exporter-otlp-proto-grpc = "1.25.0"
modal = ">=0.66.26,<0.78.0"
runloop-api-client = "0.33.0"
libtmux = ">=0.37,<0.40"
+4 -11
View File
@@ -16,16 +16,6 @@ from openhands.events.action import CmdRunAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
from openhands.runtime.impl.local.local_runtime import LocalRuntime
from openhands.runtime.utils.bash_constants import TIMEOUT_MESSAGE_TEMPLATE
def get_timeout_suffix(timeout_seconds):
"""Helper function to generate the expected timeout suffix."""
return (
f'[The command timed out after {timeout_seconds} seconds. '
f'{TIMEOUT_MESSAGE_TEMPLATE}]'
)
# ============================================================================================================================
# Bash-specific tests
@@ -66,7 +56,10 @@ def test_bash_server(temp_dir, runtime_cls, run_as_openhands):
if runtime_cls == CLIRuntime:
assert '[The command timed out after 1.0 seconds.]' in obs.metadata.suffix
else:
assert get_timeout_suffix(1.0) in obs.metadata.suffix
assert (
"[The command timed out after 1.0 seconds. You may wait longer to see additional output by sending empty command '', send other commands to interact with the current process, or send keys to interrupt/kill the command.]"
in obs.metadata.suffix
)
action = CmdRunAction(command='C-c', is_input=True)
action.set_hard_timeout(30)
+1 -1
View File
@@ -589,7 +589,7 @@
"working_dir": null,
"py_interpreter_path": null,
"prefix": "",
"suffix": "\n[The command has no new output after 30 seconds. You may wait longer to see additional output by sending empty command '', send other commands to interact with the current process, send keys to interrupt/kill the command, or use the timeout parameter in execute_bash for future commands.]"
"suffix": "\n[The command has no new output after 30 seconds. You may wait longer to see additional output by sending empty command '', send other commands to interact with the current process, or send keys to interrupt/kill the command.]"
},
"hidden": false
},
+42 -16
View File
@@ -5,15 +5,6 @@ import time
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import CmdRunAction
from openhands.runtime.utils.bash import BashCommandStatus, BashSession
from openhands.runtime.utils.bash_constants import TIMEOUT_MESSAGE_TEMPLATE
def get_no_change_timeout_suffix(timeout_seconds):
"""Helper function to generate the expected no-change timeout suffix."""
return (
f'\n[The command has no new output after {timeout_seconds} seconds. '
f'{TIMEOUT_MESSAGE_TEMPLATE}]'
)
def test_session_initialization():
@@ -92,7 +83,12 @@ def test_long_running_command_follow_by_execute():
assert '1' in obs.content # First number should appear before timeout
assert obs.metadata.exit_code == -1 # -1 indicates command is still running
assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT
assert obs.metadata.suffix == get_no_change_timeout_suffix(2)
assert obs.metadata.suffix == (
'\n[The command has no new output after 2 seconds. '
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
assert obs.metadata.prefix == ''
# Continue watching output
@@ -100,7 +96,12 @@ def test_long_running_command_follow_by_execute():
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert '2' in obs.content
assert obs.metadata.prefix == '[Below is the output of the previous command.]\n'
assert obs.metadata.suffix == get_no_change_timeout_suffix(2)
assert obs.metadata.suffix == (
'\n[The command has no new output after 2 seconds. '
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
assert obs.metadata.exit_code == -1 # -1 indicates command is still running
assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT
@@ -141,7 +142,12 @@ def test_interactive_command():
assert 'Enter name:' in obs.content
assert obs.metadata.exit_code == -1 # -1 indicates command is still running
assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT
assert obs.metadata.suffix == get_no_change_timeout_suffix(3)
assert obs.metadata.suffix == (
'\n[The command has no new output after 3 seconds. '
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
assert obs.metadata.prefix == ''
# Send input
@@ -158,21 +164,36 @@ def test_interactive_command():
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.metadata.exit_code == -1
assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT
assert obs.metadata.suffix == get_no_change_timeout_suffix(3)
assert obs.metadata.suffix == (
'\n[The command has no new output after 3 seconds. '
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
assert obs.metadata.prefix == ''
obs = session.execute(CmdRunAction('line 1', is_input=True))
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.metadata.exit_code == -1
assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT
assert obs.metadata.suffix == get_no_change_timeout_suffix(3)
assert obs.metadata.suffix == (
'\n[The command has no new output after 3 seconds. '
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
assert obs.metadata.prefix == '[Below is the output of the previous command.]\n'
obs = session.execute(CmdRunAction('line 2', is_input=True))
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.metadata.exit_code == -1
assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT
assert obs.metadata.suffix == get_no_change_timeout_suffix(3)
assert obs.metadata.suffix == (
'\n[The command has no new output after 3 seconds. '
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
assert obs.metadata.prefix == '[Below is the output of the previous command.]\n'
obs = session.execute(CmdRunAction('EOF', is_input=True))
@@ -195,7 +216,12 @@ def test_ctrl_c():
)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert 'looping' in obs.content
assert obs.metadata.suffix == get_no_change_timeout_suffix(2)
assert obs.metadata.suffix == (
'\n[The command has no new output after 2 seconds. '
"You may wait longer to see additional output by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys to interrupt/kill the command.]'
)
assert obs.metadata.prefix == ''
assert obs.metadata.exit_code == -1 # -1 indicates command is still running
assert session.prev_status == BashCommandStatus.NO_CHANGE_TIMEOUT
+40 -78
View File
@@ -46,40 +46,28 @@ class TestGitHandler(unittest.TestCase):
def _setup_git_repos(self):
"""Set up real git repositories for testing."""
# Set up origin repository
self._execute_command('git init --initial-branch=main', self.origin_dir)
self._execute_command(
'git --no-pager init --initial-branch=main', self.origin_dir
)
self._execute_command(
"git --no-pager config user.email 'test@example.com'", self.origin_dir
)
self._execute_command(
"git --no-pager config user.name 'Test User'", self.origin_dir
"git config user.email 'test@example.com'", self.origin_dir
)
self._execute_command("git config user.name 'Test User'", self.origin_dir)
# Create a file and commit it
with open(os.path.join(self.origin_dir, 'file1.txt'), 'w') as f:
f.write('Original content')
self._execute_command('git --no-pager add file1.txt', self.origin_dir)
self._execute_command(
"git --no-pager commit -m 'Initial commit'", self.origin_dir
)
self._execute_command('git add file1.txt', self.origin_dir)
self._execute_command("git commit -m 'Initial commit'", self.origin_dir)
# Clone the origin repository to local
self._execute_command(f'git clone {self.origin_dir} {self.local_dir}')
self._execute_command(
f'git --no-pager clone {self.origin_dir} {self.local_dir}'
)
self._execute_command(
"git --no-pager config user.email 'test@example.com'", self.local_dir
)
self._execute_command(
"git --no-pager config user.name 'Test User'", self.local_dir
"git config user.email 'test@example.com'", self.local_dir
)
self._execute_command("git config user.name 'Test User'", self.local_dir)
# Create a feature branch in the local repository
self._execute_command(
'git --no-pager checkout -b feature-branch', self.local_dir
)
self._execute_command('git checkout -b feature-branch', self.local_dir)
# Modify a file and create a new file
with open(os.path.join(self.local_dir, 'file1.txt'), 'w') as f:
@@ -89,40 +77,32 @@ class TestGitHandler(unittest.TestCase):
f.write('New file content')
# Add and commit file1.txt changes to create a baseline
self._execute_command('git --no-pager add file1.txt', self.local_dir)
self._execute_command(
"git --no-pager commit -m 'Update file1.txt'", self.local_dir
)
self._execute_command('git add file1.txt', self.local_dir)
self._execute_command("git commit -m 'Update file1.txt'", self.local_dir)
# Add and commit file2.txt, then modify it
self._execute_command('git --no-pager add file2.txt', self.local_dir)
self._execute_command(
"git --no-pager commit -m 'Add file2.txt'", self.local_dir
)
self._execute_command('git add file2.txt', self.local_dir)
self._execute_command("git commit -m 'Add file2.txt'", self.local_dir)
# Modify file2.txt and stage it
with open(os.path.join(self.local_dir, 'file2.txt'), 'w') as f:
f.write('Modified new file content')
self._execute_command('git --no-pager add file2.txt', self.local_dir)
self._execute_command('git add file2.txt', self.local_dir)
# Create a file that will be deleted
with open(os.path.join(self.local_dir, 'file3.txt'), 'w') as f:
f.write('File to be deleted')
self._execute_command('git --no-pager add file3.txt', self.local_dir)
self._execute_command(
"git --no-pager commit -m 'Add file3.txt'", self.local_dir
)
self._execute_command('git --no-pager rm file3.txt', self.local_dir)
self._execute_command('git add file3.txt', self.local_dir)
self._execute_command("git commit -m 'Add file3.txt'", self.local_dir)
self._execute_command('git rm file3.txt', self.local_dir)
# Modify file1.txt again but don't stage it (unstaged change)
with open(os.path.join(self.local_dir, 'file1.txt'), 'w') as f:
f.write('Modified content again')
# Push the feature branch to origin
self._execute_command(
'git --no-pager push -u origin feature-branch', self.local_dir
)
self._execute_command('git push -u origin feature-branch', self.local_dir)
def test_is_git_repo(self):
"""Test that _is_git_repo returns True for a git repository."""
@@ -131,7 +111,7 @@ class TestGitHandler(unittest.TestCase):
# Verify the command was executed
self.assertTrue(
any(
cmd == 'git --no-pager rev-parse --is-inside-work-tree'
cmd == 'git rev-parse --is-inside-work-tree'
for cmd, _ in self.executed_commands
)
)
@@ -144,7 +124,7 @@ class TestGitHandler(unittest.TestCase):
# Verify the command was executed
self.assertTrue(
any(
cmd == 'git --no-pager remote show origin | grep "HEAD branch"'
cmd == 'git remote show origin | grep "HEAD branch"'
for cmd, _ in self.executed_commands
)
)
@@ -153,12 +133,11 @@ class TestGitHandler(unittest.TestCase):
"""Test that _get_current_branch returns the correct branch name."""
branch = self.git_handler._get_current_branch()
self.assertEqual(branch, 'feature-branch')
print('executed commands:', self.executed_commands)
# Verify the command was executed
self.assertTrue(
any(
cmd == 'git --no-pager rev-parse --abbrev-ref HEAD'
cmd == 'git rev-parse --abbrev-ref HEAD'
for cmd, _ in self.executed_commands
)
)
@@ -173,7 +152,7 @@ class TestGitHandler(unittest.TestCase):
verify_commands = [
cmd
for cmd, _ in self.executed_commands
if cmd.startswith('git --no-pager rev-parse --verify')
if cmd.startswith('git rev-parse --verify')
]
# First should check origin/feature-branch (current branch)
@@ -183,17 +162,13 @@ class TestGitHandler(unittest.TestCase):
self.assertEqual(ref, 'origin/feature-branch')
# Verify the ref exists
result = self._execute_command(
f'git --no-pager rev-parse --verify {ref}', self.local_dir
)
result = self._execute_command(f'git rev-parse --verify {ref}', self.local_dir)
self.assertEqual(result.exit_code, 0)
def test_get_valid_ref_without_origin_current_branch(self):
"""Test that _get_valid_ref falls back to default branch when current branch doesn't exist in origin."""
# Create a new branch that doesn't exist in origin
self._execute_command(
'git --no-pager checkout -b new-local-branch', self.local_dir
)
self._execute_command('git checkout -b new-local-branch', self.local_dir)
# Clear the executed commands to start fresh
self.executed_commands = []
@@ -205,7 +180,7 @@ class TestGitHandler(unittest.TestCase):
verify_commands = [
cmd
for cmd, _ in self.executed_commands
if cmd.startswith('git --no-pager rev-parse --verify')
if cmd.startswith('git rev-parse --verify')
]
# Should have tried origin/new-local-branch first (which doesn't exist)
@@ -218,9 +193,7 @@ class TestGitHandler(unittest.TestCase):
self.assertTrue(ref == 'origin/main' or 'merge-base' in ref)
# Verify the ref exists
result = self._execute_command(
f'git --no-pager rev-parse --verify {ref}', self.local_dir
)
result = self._execute_command(f'git rev-parse --verify {ref}', self.local_dir)
self.assertEqual(result.exit_code, 0)
def test_get_valid_ref_without_origin(self):
@@ -230,21 +203,15 @@ class TestGitHandler(unittest.TestCase):
os.makedirs(no_origin_dir, exist_ok=True)
# Initialize git repo without origin
self._execute_command('git --no-pager init', no_origin_dir)
self._execute_command(
"git --no-pager config user.email 'test@example.com'", no_origin_dir
)
self._execute_command(
"git --no-pager config user.name 'Test User'", no_origin_dir
)
self._execute_command('git init', no_origin_dir)
self._execute_command("git config user.email 'test@example.com'", no_origin_dir)
self._execute_command("git config user.name 'Test User'", no_origin_dir)
# Create a file and commit it
with open(os.path.join(no_origin_dir, 'file1.txt'), 'w') as f:
f.write('Content in repo without origin')
self._execute_command('git --no-pager add file1.txt', no_origin_dir)
self._execute_command(
"git --no-pager commit -m 'Initial commit'", no_origin_dir
)
self._execute_command('git add file1.txt', no_origin_dir)
self._execute_command("git commit -m 'Initial commit'", no_origin_dir)
# Create a custom GitHandler with a modified _get_default_branch method for this test
class TestGitHandler(GitHandler):
@@ -267,20 +234,19 @@ class TestGitHandler(unittest.TestCase):
# Verify that git commands were executed
self.assertTrue(
any(
cmd.startswith('git --no-pager rev-parse --verify')
cmd.startswith('git rev-parse --verify')
for cmd, _ in self.executed_commands
)
)
# Should have fallen back to the empty tree ref
self.assertEqual(
ref,
'$(git --no-pager rev-parse --verify 4b825dc642cb6eb9a060e54bf8d69288fbee4904)',
ref, '$(git rev-parse --verify 4b825dc642cb6eb9a060e54bf8d69288fbee4904)'
)
# Verify the ref exists (the empty tree ref always exists)
result = self._execute_command(
'git --no-pager rev-parse --verify 4b825dc642cb6eb9a060e54bf8d69288fbee4904',
'git rev-parse --verify 4b825dc642cb6eb9a060e54bf8d69288fbee4904',
no_origin_dir,
)
self.assertEqual(result.exit_code, 0)
@@ -292,9 +258,7 @@ class TestGitHandler(unittest.TestCase):
# Should have called _get_valid_ref and then git show
show_commands = [
cmd
for cmd, _ in self.executed_commands
if cmd.startswith('git --no-pager show')
cmd for cmd, _ in self.executed_commands if cmd.startswith('git show')
]
self.assertTrue(any('file1.txt' in cmd for cmd in show_commands))
@@ -313,7 +277,7 @@ class TestGitHandler(unittest.TestCase):
# Let's create a new file to ensure it shows up in the diff
with open(os.path.join(self.local_dir, 'new_file.txt'), 'w') as f:
f.write('New file content')
self._execute_command('git --no-pager add new_file.txt', self.local_dir)
self._execute_command('git add new_file.txt', self.local_dir)
files = self.git_handler._get_changed_files()
self.assertTrue(files)
@@ -327,9 +291,7 @@ class TestGitHandler(unittest.TestCase):
# Should have called _get_valid_ref and then git diff
diff_commands = [
cmd
for cmd, _ in self.executed_commands
if cmd.startswith('git --no-pager diff')
cmd for cmd, _ in self.executed_commands if cmd.startswith('git diff')
]
self.assertTrue(diff_commands)
@@ -347,7 +309,7 @@ class TestGitHandler(unittest.TestCase):
# Verify the command was executed
self.assertTrue(
any(
cmd == 'git --no-pager ls-files --others --exclude-standard'
cmd == 'git ls-files --others --exclude-standard'
for cmd, _ in self.executed_commands
)
)
@@ -361,7 +323,7 @@ class TestGitHandler(unittest.TestCase):
# Create a new file and stage it
with open(os.path.join(self.local_dir, 'new_file2.txt'), 'w') as f:
f.write('New file 2 content')
self._execute_command('git --no-pager add new_file2.txt', self.local_dir)
self._execute_command('git add new_file2.txt', self.local_dir)
changes = self.git_handler.get_git_changes()
self.assertIsNotNone(changes)
@@ -391,7 +353,7 @@ class TestGitHandler(unittest.TestCase):
)
self.assertTrue(
any(
'git --no-pager show' in cmd and 'file1.txt' in cmd
'git show' in cmd and 'file1.txt' in cmd
for cmd, _ in self.executed_commands
)
)
+49
View File
@@ -0,0 +1,49 @@
import asyncio
from contextlib import asynccontextmanager
from unittest import mock
import pytest
from openhands.mcp.client import MCPClient
@pytest.mark.asyncio
async def test_connect_sse_timeout():
"""Test that connect_sse properly times out when server_url is invalid."""
client = MCPClient()
# Create a mock async context manager that simulates a timeout
@asynccontextmanager
async def mock_slow_context(*args, **kwargs):
# This will hang for longer than our timeout
await asyncio.sleep(10.0)
yield (mock.AsyncMock(), mock.AsyncMock())
# Patch the sse_client function to return our slow context manager
with mock.patch(
'openhands.mcp.client.sse_client', return_value=mock_slow_context()
):
# Test with a very short timeout
with pytest.raises(asyncio.TimeoutError):
await client.connect_sse('http://example.com', timeout=0.1)
@pytest.mark.asyncio
async def test_connect_streamable_http_timeout():
"""Test that connect_streamable_http properly times out when server_url is invalid."""
client = MCPClient()
# Create a mock async context manager that simulates a timeout
@asynccontextmanager
async def mock_slow_context(*args, **kwargs):
# This will hang for longer than our timeout
await asyncio.sleep(10.0)
yield (mock.AsyncMock(), mock.AsyncMock(), mock.AsyncMock())
# Patch the streamablehttp_client function to return our slow context manager
with mock.patch(
'openhands.mcp.client.streamablehttp_client', return_value=mock_slow_context()
):
# Test with a very short timeout
with pytest.raises(asyncio.TimeoutError):
await client.connect_shttp('http://example.com', timeout=0.1)
+14 -17
View File
@@ -2,7 +2,6 @@ import asyncio
import pytest
from openhands.core.config.mcp_config import MCPSSEServerConfig
from openhands.mcp.client import MCPClient
from openhands.mcp.utils import create_mcp_clients
@@ -11,24 +10,22 @@ from openhands.mcp.utils import create_mcp_clients
async def test_create_mcp_clients_timeout_with_invalid_url():
"""Test that create_mcp_clients properly times out when given an invalid URL."""
# Use a non-existent domain that should cause a connection timeout
server = MCPSSEServerConfig(
url='http://non-existent-domain-that-will-timeout.invalid'
)
invalid_url = 'http://non-existent-domain-that-will-timeout.invalid'
# Temporarily modify the default timeout for the MCPClient.connect_http method
original_connect_connect_http = MCPClient.connect_http
# Temporarily modify the default timeout for the MCPClient.connect_sse method
original_connect_sse = MCPClient.connect_sse
# Create a wrapper that calls the original method but with a shorter timeout
async def connect_http_with_short_timeout(self, server_url, timeout=30.0):
return await original_connect_connect_http(self, server_url, timeout=0.5)
async def connect_sse_with_short_timeout(self, server_url, timeout=30.0):
return await original_connect_sse(self, server_url, timeout=0.5)
try:
# Replace the method with our wrapper
MCPClient.connect_http = connect_http_with_short_timeout
MCPClient.connect_sse = connect_sse_with_short_timeout
# Call create_mcp_clients with the invalid URL
start_time = asyncio.get_event_loop().time()
clients = await create_mcp_clients([server], [])
clients = await create_mcp_clients([invalid_url], [])
end_time = asyncio.get_event_loop().time()
# Verify that no clients were successfully connected
@@ -41,7 +38,7 @@ async def test_create_mcp_clients_timeout_with_invalid_url():
)
finally:
# Restore the original method
MCPClient.connect_http = original_connect_connect_http
MCPClient.connect_sse = original_connect_sse
@pytest.mark.asyncio
@@ -51,16 +48,16 @@ async def test_create_mcp_clients_with_unreachable_host():
# This IP is in the TEST-NET-1 range (192.0.2.0/24) reserved for documentation and examples
unreachable_url = 'http://192.0.2.1:8080'
# Temporarily modify the default timeout for the MCPClient.connect_http method
original_connect_http = MCPClient.connect_http
# Temporarily modify the default timeout for the MCPClient.connect_sse method
original_connect_sse = MCPClient.connect_sse
# Create a wrapper that calls the original method but with a shorter timeout
async def connect_http_with_short_timeout(self, server_url, timeout=30.0):
return await original_connect_http(self, server_url, timeout=1.0)
async def connect_sse_with_short_timeout(self, server_url, timeout=30.0):
return await original_connect_sse(self, server_url, timeout=1.0)
try:
# Replace the method with our wrapper
MCPClient.connect_http = connect_http_with_short_timeout
MCPClient.connect_sse = connect_sse_with_short_timeout
# Call create_mcp_clients with the unreachable URL
start_time = asyncio.get_event_loop().time()
@@ -76,4 +73,4 @@ async def test_create_mcp_clients_with_unreachable_host():
)
finally:
# Restore the original method
MCPClient.connect_http = original_connect_http
MCPClient.connect_sse = original_connect_sse
-86
View File
@@ -1,86 +0,0 @@
from unittest.mock import AsyncMock, patch
import pytest
from openhands.integrations.service_types import GitService
from openhands.server.routes.mcp import get_convo_link
from openhands.server.types import AppMode
@pytest.mark.asyncio
async def test_get_convo_link_non_saas_mode():
"""Test get_convo_link in non-SAAS mode."""
# Mock GitService
mock_service = AsyncMock(spec=GitService)
# Test with non-SAAS mode
with patch('openhands.server.routes.mcp.server_config') as mock_config:
mock_config.app_mode = AppMode.OSS
# Call the function
result = await get_convo_link(
service=mock_service, conversation_id='test-convo-id', body='Original body'
)
# Verify the result
assert result == 'Original body'
# Verify that get_user was not called
mock_service.get_user.assert_not_called()
@pytest.mark.asyncio
async def test_get_convo_link_saas_mode():
"""Test get_convo_link in SAAS mode."""
# Mock GitService and user
mock_service = AsyncMock(spec=GitService)
mock_user = AsyncMock()
mock_user.login = 'testuser'
mock_service.get_user.return_value = mock_user
# Test with SAAS mode
with (
patch('openhands.server.routes.mcp.server_config') as mock_config,
patch('openhands.server.routes.mcp.CONVO_URL', 'https://test.example.com/{}'),
):
mock_config.app_mode = AppMode.SAAS
# Call the function
result = await get_convo_link(
service=mock_service, conversation_id='test-convo-id', body='Original body'
)
# Verify the result
expected_link = '@testuser can click here to [continue refining the PR](https://test.example.com/test-convo-id)'
assert result == f'Original body\n\n{expected_link}'
# Verify that get_user was called
mock_service.get_user.assert_called_once()
@pytest.mark.asyncio
async def test_get_convo_link_empty_body():
"""Test get_convo_link with an empty body."""
# Mock GitService and user
mock_service = AsyncMock(spec=GitService)
mock_user = AsyncMock()
mock_user.login = 'testuser'
mock_service.get_user.return_value = mock_user
# Test with SAAS mode and empty body
with (
patch('openhands.server.routes.mcp.server_config') as mock_config,
patch('openhands.server.routes.mcp.CONVO_URL', 'https://test.example.com/{}'),
):
mock_config.app_mode = AppMode.SAAS
# Call the function
result = await get_convo_link(
service=mock_service, conversation_id='test-convo-id', body=''
)
# Verify the result
expected_link = '@testuser can click here to [continue refining the PR](https://test.example.com/test-convo-id)'
assert result == f'\n\n{expected_link}'
# Verify that get_user was called
mock_service.get_user.assert_called_once()
+8 -5
View File
@@ -13,12 +13,12 @@ async def test_sse_connection_timeout():
# Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient)
# Configure the mock to raise a TimeoutError when connect_http is called
async def mock_connect_http(*args, **kwargs):
# Configure the mock to raise a TimeoutError when connect_sse is called
async def mock_connect_sse(*args, **kwargs):
await asyncio.sleep(0.1) # Simulate some delay
raise asyncio.TimeoutError('Connection timed out')
mock_client.connect_http.side_effect = mock_connect_http
mock_client.connect_sse.side_effect = mock_connect_sse
mock_client.disconnect = mock.AsyncMock()
# Mock the MCPClient constructor to return our mock
@@ -35,8 +35,11 @@ async def test_sse_connection_timeout():
# Verify that no clients were successfully connected
assert len(clients) == 0
# Verify that connect_http was called for each server
assert mock_client.connect_http.call_count == 2
# Verify that connect_sse was called for each server
assert mock_client.connect_sse.call_count == 2
# Verify that disconnect was called for each failed connection
assert mock_client.disconnect.call_count == 2
@pytest.mark.asyncio
+9 -7
View File
@@ -24,7 +24,7 @@ async def test_create_mcp_clients_success(mock_mcp_client):
# Setup mock
mock_client_instance = AsyncMock()
mock_mcp_client.return_value = mock_client_instance
mock_client_instance.connect_http = AsyncMock()
mock_client_instance.connect_sse = AsyncMock()
# Test with two servers
server_configs = [
@@ -38,12 +38,12 @@ async def test_create_mcp_clients_success(mock_mcp_client):
assert len(clients) == 2
assert mock_mcp_client.call_count == 2
# Check that connect_http was called with correct parameters
mock_client_instance.connect_http.assert_any_call(
server_configs[0], conversation_id=None
# Check that connect_sse was called with correct parameters
mock_client_instance.connect_sse.assert_any_call(
'http://server1:8080', api_key=None, conversation_id=None
)
mock_client_instance.connect_http.assert_any_call(
server_configs[1], conversation_id=None
mock_client_instance.connect_sse.assert_any_call(
'http://server2:8080', api_key='test-key', conversation_id=None
)
@@ -56,10 +56,11 @@ async def test_create_mcp_clients_connection_failure(mock_mcp_client):
mock_mcp_client.return_value = mock_client_instance
# First connection succeeds, second fails
mock_client_instance.connect_http.side_effect = [
mock_client_instance.connect_sse.side_effect = [
None, # Success
Exception('Connection failed'), # Failure
]
mock_client_instance.disconnect = AsyncMock()
server_configs = [
MCPSSEServerConfig(url='http://server1:8080'),
@@ -70,6 +71,7 @@ async def test_create_mcp_clients_connection_failure(mock_mcp_client):
# Verify only one client was successfully created
assert len(clients) == 1
assert mock_client_instance.disconnect.call_count == 1
def test_convert_mcp_clients_to_tools_empty():
+2 -8
View File
@@ -234,10 +234,7 @@ async def test_clone_or_init_repo_no_repo_with_user_id(temp_dir):
# Verify that git init was called
assert len(runtime.run_action_calls) == 1
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
assert (
runtime.run_action_calls[0].command
== f'git init && git config --global --add safe.directory {runtime.workspace_root}'
)
assert runtime.run_action_calls[0].command == 'git init'
assert result == ''
@@ -258,10 +255,7 @@ async def test_clone_or_init_repo_no_repo_no_user_id_no_workspace_base(temp_dir)
# Verify that git init was called
assert len(runtime.run_action_calls) == 1
assert isinstance(runtime.run_action_calls[0], CmdRunAction)
assert (
runtime.run_action_calls[0].command
== f'git init && git config --global --add safe.directory {runtime.workspace_root}'
)
assert runtime.run_action_calls[0].command == 'git init'
assert result == ''
@@ -167,7 +167,6 @@ async def test_add_to_local_event_stream():
@pytest.mark.asyncio
async def test_cleanup_session_connections():
sio = get_mock_sio()
sio.disconnect = AsyncMock() # Mock the disconnect method
async with StandaloneConversationManager(
sio, OpenHandsConfig(), InMemoryFileStore(), MonitoringListener()
) as conversation_manager:
@@ -182,7 +181,6 @@ async def test_cleanup_session_connections():
await conversation_manager._close_session('session1')
# Check that connections were removed from the dictionary
remaining_connections = conversation_manager._local_connection_id_to_session_id
assert 'conn1' not in remaining_connections
assert 'conn2' not in remaining_connections
@@ -190,8 +188,3 @@ async def test_cleanup_session_connections():
assert 'conn4' in remaining_connections
assert remaining_connections['conn3'] == 'session2'
assert remaining_connections['conn4'] == 'session2'
# Check that disconnect was called for each connection
assert sio.disconnect.await_count == 2
sio.disconnect.assert_any_call('conn1')
sio.disconnect.assert_any_call('conn2')
+4 -11
View File
@@ -12,16 +12,6 @@ from openhands.events.observation import ErrorObservation
from openhands.events.observation.commands import (
CmdOutputObservation,
)
from openhands.runtime.utils.bash_constants import TIMEOUT_MESSAGE_TEMPLATE
def get_timeout_suffix(timeout_seconds):
"""Helper function to generate the expected timeout suffix."""
return (
f'[The command timed out after {timeout_seconds} seconds. '
f'{TIMEOUT_MESSAGE_TEMPLATE}]'
)
# Skip all tests in this module if not running on Windows
pytestmark = pytest.mark.skipif(
@@ -178,7 +168,10 @@ def test_long_running_command(windows_bash_session):
# Verify the initial output was captured
assert 'Serving HTTP on' in result.content
# Check for timeout specific metadata
assert get_timeout_suffix(1.0) in result.metadata.suffix
assert (
"[The command timed out after 1.0 seconds. You may wait longer to see additional output by sending empty command '', send other commands to interact with the current process, or send keys to interrupt/kill the command.]"
in result.metadata.suffix
)
assert result.exit_code == -1
# The action timed out, but the command should be still running