Compare commits

..

2 Commits

Author SHA1 Message Date
openhands 29ce9f9822 Fix linting issues in use-update-conversation.ts 2025-04-24 20:23:52 +00:00
openhands b29252fbef Fix title tag synchronization issues 2025-04-24 20:19:24 +00:00
42 changed files with 1387 additions and 1154 deletions
+3 -3
View File
@@ -61,8 +61,8 @@ RUN add-apt-repository ppa:deadsnakes/ppa \
&& apt-get install -y python3.12 python3.12-venv python3.12-dev python3-pip \
&& ln -s /usr/bin/python3.12 /usr/bin/python
# NodeJS >= 22.x
RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
# NodeJS >= 18.17.1
RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \
&& apt-get install -y nodejs
# Poetry >= 1.8
@@ -108,7 +108,7 @@ WORKDIR /app
# cache build dependencies
RUN \
--mount=type=bind,source=./,target=/app/,rw \
--mount=type=bind,source=./,target=/app/ \
<<EOF
#!/bin/bash
make -s clean
+40 -27
View File
@@ -1646,32 +1646,6 @@
}
}
},
"/api/reset-settings": {
"post": {
"summary": "Reset settings (Deprecated)",
"description": "This endpoint is deprecated and will return a 410 Gone error. Reset functionality has been removed.",
"operationId": "resetSettings",
"deprecated": true,
"responses": {
"410": {
"description": "Feature removed",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"error": {
"type": "string",
"example": "Reset settings functionality has been removed."
}
}
}
}
}
}
}
}
},
"/api/unset-settings-tokens": {
"post": {
"summary": "Unset settings tokens",
@@ -1711,6 +1685,45 @@
}
}
},
"/api/reset-settings": {
"post": {
"summary": "Reset settings",
"description": "Reset user settings to defaults",
"operationId": "resetSettings",
"responses": {
"200": {
"description": "Settings reset successfully",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"message": {
"type": "string"
}
}
}
}
}
},
"500": {
"description": "Error resetting settings",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"error": {
"type": "string"
}
}
}
}
}
}
}
}
},
"/api/options/models": {
"get": {
"summary": "Get models",
@@ -2082,4 +2095,4 @@
"bearerAuth": []
}
]
}
}
+138
View File
@@ -25,6 +25,7 @@ const mock_provider_tokens_are_set: Record<Provider, boolean> = {
describe("Settings Screen", () => {
const getSettingsSpy = vi.spyOn(OpenHands, "getSettings");
const saveSettingsSpy = vi.spyOn(OpenHands, "saveSettings");
const resetSettingsSpy = vi.spyOn(OpenHands, "resetSettings");
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
const { handleLogoutMock } = vi.hoisted(() => ({
@@ -66,6 +67,7 @@ describe("Settings Screen", () => {
// Use queryAllByText to handle multiple elements with the same text
expect(screen.queryAllByText("SETTINGS$LLM_SETTINGS")).not.toHaveLength(0);
screen.getByText("ACCOUNT_SETTINGS$ADDITIONAL_SETTINGS");
screen.getByText("BUTTON$RESET_TO_DEFAULTS");
screen.getByText("BUTTON$SAVE");
});
});
@@ -540,6 +542,54 @@ describe("Settings Screen", () => {
});
});
test("resetting settings with no changes but having advanced enabled should hide the advanced items", async () => {
const user = userEvent.setup();
getSettingsSpy.mockResolvedValueOnce({
...MOCK_DEFAULT_USER_SETTINGS,
});
renderSettingsScreen();
await toggleAdvancedSettings(user);
const resetButton = screen.getByText("BUTTON$RESET_TO_DEFAULTS");
await user.click(resetButton);
// show modal
const modal = await screen.findByTestId("reset-modal");
expect(modal).toBeInTheDocument();
// Mock the settings that will be returned after reset
// This should be the default settings with no advanced settings enabled
getSettingsSpy.mockResolvedValueOnce({
...MOCK_DEFAULT_USER_SETTINGS,
llm_base_url: "",
confirmation_mode: false,
security_analyzer: "",
});
// confirm reset
const confirmButton = within(modal).getByText("Reset");
await user.click(confirmButton);
await waitFor(() => {
expect(
screen.queryByTestId("llm-custom-model-input"),
).not.toBeInTheDocument();
expect(
screen.queryByTestId("base-url-input"),
).not.toBeInTheDocument();
expect(screen.queryByTestId("agent-input")).not.toBeInTheDocument();
expect(
screen.queryByTestId("security-analyzer-input"),
).not.toBeInTheDocument();
expect(
screen.queryByTestId("enable-confirmation-mode-switch"),
).not.toBeInTheDocument();
});
});
it("should save if only confirmation mode is enabled", async () => {
const user = userEvent.setup();
renderSettingsScreen();
@@ -712,6 +762,81 @@ describe("Settings Screen", () => {
);
});
it("should reset the settings when the 'Reset to defaults' button is clicked", async () => {
const user = userEvent.setup();
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
renderSettingsScreen();
const languageInput = await screen.findByTestId("language-input");
await user.click(languageInput);
const norskOption = await screen.findByText("Norsk");
await user.click(norskOption);
expect(languageInput).toHaveValue("Norsk");
const resetButton = screen.getByText("BUTTON$RESET_TO_DEFAULTS");
await user.click(resetButton);
expect(saveSettingsSpy).not.toHaveBeenCalled();
// show modal
const modal = await screen.findByTestId("reset-modal");
expect(modal).toBeInTheDocument();
// confirm reset
const confirmButton = within(modal).getByText("Reset");
await user.click(confirmButton);
await waitFor(() => {
expect(resetSettingsSpy).toHaveBeenCalled();
});
// Mock the settings response after reset
getSettingsSpy.mockResolvedValueOnce({
...MOCK_DEFAULT_USER_SETTINGS,
llm_base_url: "",
confirmation_mode: false,
security_analyzer: "",
});
// Wait for the mutation to complete and the modal to be removed
await waitFor(() => {
expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument();
expect(
screen.queryByTestId("llm-custom-model-input"),
).not.toBeInTheDocument();
expect(screen.queryByTestId("base-url-input")).not.toBeInTheDocument();
expect(screen.queryByTestId("agent-input")).not.toBeInTheDocument();
expect(
screen.queryByTestId("security-analyzer-input"),
).not.toBeInTheDocument();
expect(
screen.queryByTestId("enable-confirmation-mode-switch"),
).not.toBeInTheDocument();
});
});
it("should cancel the reset when the 'Cancel' button is clicked", async () => {
const user = userEvent.setup();
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
renderSettingsScreen();
const resetButton = await screen.findByText("BUTTON$RESET_TO_DEFAULTS");
await user.click(resetButton);
const modal = await screen.findByTestId("reset-modal");
expect(modal).toBeInTheDocument();
const cancelButton = within(modal).getByText("Cancel");
await user.click(cancelButton);
expect(saveSettingsSpy).not.toHaveBeenCalled();
expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument();
});
it("should call handleCaptureConsent with true if the save is successful", async () => {
const user = userEvent.setup();
const handleCaptureConsentSpy = vi.spyOn(
@@ -919,5 +1044,18 @@ describe("Settings Screen", () => {
);
});
it("should not submit the unwanted fields when resetting", async () => {
const user = userEvent.setup();
renderSettingsScreen();
const resetButton = await screen.findByText("BUTTON$RESET_TO_DEFAULTS");
await user.click(resetButton);
const modal = await screen.findByTestId("reset-modal");
const confirmButton = within(modal).getByText("Reset");
await user.click(confirmButton);
expect(saveSettingsSpy).not.toHaveBeenCalled();
expect(resetSettingsSpy).toHaveBeenCalled();
});
});
});
+8
View File
@@ -199,6 +199,14 @@ class OpenHands {
return data.status === 200;
}
/**
* Reset user settings in server
*/
static async resetSettings(): Promise<boolean> {
const response = await openHands.post("/api/reset-settings");
return response.status === 200;
}
static async createCheckoutSession(amount: number): Promise<string> {
const { data } = await openHands.post(
"/api/billing/create-checkout-session",
@@ -4,7 +4,15 @@ import OpenHands from "#/api/open-hands";
import { PostSettings, PostApiSettings } from "#/types/settings";
import { useSettings } from "../query/use-settings";
const saveSettingsMutationFn = async (settings: Partial<PostSettings>) => {
const saveSettingsMutationFn = async (
settings: Partial<PostSettings> | null,
) => {
// If settings is null, we're resetting
if (settings === null) {
await OpenHands.resetSettings();
return;
}
const apiSettings: Partial<PostApiSettings> = {
llm_model: settings.LLM_MODEL,
llm_base_url: settings.LLM_BASE_URL,
@@ -31,7 +39,12 @@ export const useSaveSettings = () => {
const { data: currentSettings } = useSettings();
return useMutation({
mutationFn: async (settings: Partial<PostSettings>) => {
mutationFn: async (settings: Partial<PostSettings> | null) => {
if (settings === null) {
await saveSettingsMutationFn(null);
return;
}
const newSettings = { ...currentSettings, ...settings };
await saveSettingsMutationFn(newSettings);
},
@@ -11,8 +11,14 @@ export const useUpdateConversation = () => {
conversation: Partial<Omit<Conversation, "id">>;
}) =>
OpenHands.updateUserConversation(variables.id, variables.conversation),
onSuccess: () => {
onSuccess: (_, variables) => {
// Invalidate the conversations list
queryClient.invalidateQueries({ queryKey: ["user", "conversations"] });
// Also invalidate the specific conversation to ensure title updates are reflected
queryClient.invalidateQueries({
queryKey: ["user", "conversation", variables.id],
});
},
});
};
@@ -7,6 +7,6 @@ export const useUserConversation = (cid: string | null) =>
queryFn: () => OpenHands.getConversation(cid!),
enabled: !!cid,
retry: false,
staleTime: 1000 * 60 * 5, // 5 minutes
staleTime: 1000 * 30, // 30 seconds - reduced from 5 minutes for more responsive title updates
gcTime: 1000 * 60 * 15, // 15 minutes
});
+24 -6
View File
@@ -48,6 +48,7 @@ export function useAutoTitle() {
return;
}
// Request title generation from the backend
updateConversation(
{
id: conversationId,
@@ -56,13 +57,30 @@ export function useAutoTitle() {
{
onSuccess: async () => {
try {
const updatedConversation =
await OpenHands.getConversation(conversationId);
// Add a small delay to allow the backend to generate the title
// This helps avoid race conditions where we fetch before the title is ready
setTimeout(async () => {
try {
const updatedConversation =
await OpenHands.getConversation(conversationId);
queryClient.setQueryData(
["user", "conversation", conversationId],
updatedConversation,
);
if (updatedConversation && updatedConversation.title) {
queryClient.setQueryData(
["user", "conversation", conversationId],
updatedConversation,
);
} else {
// If we still don't have a title, invalidate the query to trigger a refetch
queryClient.invalidateQueries({
queryKey: ["user", "conversation", conversationId],
});
}
} catch (error) {
queryClient.invalidateQueries({
queryKey: ["user", "conversation", conversationId],
});
}
}, 1000); // 1 second delay
} catch (error) {
queryClient.invalidateQueries({
queryKey: ["user", "conversation", conversationId],
+2 -1
View File
@@ -82,6 +82,7 @@ export enum I18nKey {
API$DONT_KNOW_KEY = "API$DONT_KNOW_KEY",
BUTTON$SAVE = "BUTTON$SAVE",
BUTTON$CLOSE = "BUTTON$CLOSE",
BUTTON$RESET_TO_DEFAULTS = "BUTTON$RESET_TO_DEFAULTS",
MODAL$CONFIRM_RESET_TITLE = "MODAL$CONFIRM_RESET_TITLE",
MODAL$CONFIRM_RESET_MESSAGE = "MODAL$CONFIRM_RESET_MESSAGE",
MODAL$END_SESSION_TITLE = "MODAL$END_SESSION_TITLE",
@@ -320,6 +321,7 @@ export enum I18nKey {
SETTINGS_FORM$ENABLE_CONFIRMATION_MODE_LABEL = "SETTINGS_FORM$ENABLE_CONFIRMATION_MODE_LABEL",
SETTINGS_FORM$SAVE_LABEL = "SETTINGS_FORM$SAVE_LABEL",
SETTINGS_FORM$CLOSE_LABEL = "SETTINGS_FORM$CLOSE_LABEL",
SETTINGS_FORM$RESET_TO_DEFAULTS_LABEL = "SETTINGS_FORM$RESET_TO_DEFAULTS_LABEL",
SETTINGS_FORM$CANCEL_LABEL = "SETTINGS_FORM$CANCEL_LABEL",
SETTINGS_FORM$END_SESSION_LABEL = "SETTINGS_FORM$END_SESSION_LABEL",
SETTINGS_FORM$CHANGING_WORKSPACE_WARNING_MESSAGE = "SETTINGS_FORM$CHANGING_WORKSPACE_WARNING_MESSAGE",
@@ -338,7 +340,6 @@ export enum I18nKey {
STATUS$LLM_RETRY = "STATUS$LLM_RETRY",
AGENT_ERROR$BAD_ACTION = "AGENT_ERROR$BAD_ACTION",
AGENT_ERROR$ACTION_TIMEOUT = "AGENT_ERROR$ACTION_TIMEOUT",
AGENT_ERROR$TOO_MANY_CONVERSATIONS = "AGENT_ERROR$TOO_MANY_CONVERSATIONS",
PROJECT_MENU_CARD_CONTEXT_MENU$CONNECT_TO_GITHUB_LABEL = "PROJECT_MENU_CARD_CONTEXT_MENU$CONNECT_TO_GITHUB_LABEL",
PROJECT_MENU_CARD_CONTEXT_MENU$PUSH_TO_GITHUB_LABEL = "PROJECT_MENU_CARD_CONTEXT_MENU$PUSH_TO_GITHUB_LABEL",
PROJECT_MENU_CARD_CONTEXT_MENU$DOWNLOAD_FILES_LABEL = "PROJECT_MENU_CARD_CONTEXT_MENU$DOWNLOAD_FILES_LABEL",
+30 -3
View File
@@ -1231,6 +1231,21 @@
"tr": "Kapat",
"de": "Schließen"
},
"BUTTON$RESET_TO_DEFAULTS": {
"en": "Reset to defaults",
"ja": "デフォルトにリセット",
"zh-CN": "重置为默认值",
"zh-TW": "還原為預設值",
"ko-KR": "기본값으로 재설정",
"no": "Tilbakestill til standard",
"it": "Ripristina valori predefiniti",
"pt": "Restaurar padrões",
"es": "Restablecer valores predeterminados",
"ar": "إعادة التعيين إلى الإعدادات الافتراضية",
"fr": "Réinitialiser aux valeurs par défaut",
"tr": "Varsayılanlara sıfırla",
"de": "Auf Standardwerte zurücksetzen"
},
"MODAL$CONFIRM_RESET_TITLE": {
"en": "Are you sure?",
"ja": "本当によろしいですか?",
@@ -4526,6 +4541,21 @@
"pt": "Fechar",
"tr": "Kapat"
},
"SETTINGS_FORM$RESET_TO_DEFAULTS_LABEL": {
"en": "Reset to defaults",
"es": "Reiniciar valores por defect",
"zh-CN": "重置为默认值",
"zh-TW": "還原為預設值",
"ko-KR": "기본값으로 재설정",
"ja": "デフォルトに戻す",
"no": "Tilbakestill til standardverdier",
"ar": "إعادة التعيين إلى الإعدادات الافتراضية",
"de": "Auf Standardwerte zurücksetzen",
"fr": "Réinitialiser aux valeurs par défaut",
"it": "Ripristina valori predefiniti",
"pt": "Restaurar padrões",
"tr": "Varsayılanlara sıfırla"
},
"SETTINGS_FORM$CANCEL_LABEL": {
"en": "Cancel",
"es": "Cancelar",
@@ -4796,9 +4826,6 @@
"es": "La acción expiró",
"tr": "İşlem zaman aşımına uğradı"
},
"AGENT_ERROR$TOO_MANY_CONVERSATIONS": {
"en": "Too many conversations at once."
},
"PROJECT_MENU_CARD_CONTEXT_MENU$CONNECT_TO_GITHUB_LABEL": {
"en": "Connect to GitHub",
"es": "Conectar a GitHub",
+54
View File
@@ -9,6 +9,7 @@ import { SettingsDropdownInput } from "#/components/features/settings/settings-d
import { SettingsInput } from "#/components/features/settings/settings-input";
import { SettingsSwitch } from "#/components/features/settings/settings-switch";
import { LoadingSpinner } from "#/components/shared/loading-spinner";
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
import { ModelSelector } from "#/components/shared/modals/settings/model-selector";
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
import { useAIConfigOptions } from "#/hooks/query/use-ai-config-options";
@@ -94,6 +95,8 @@ function AccountSettings() {
>(isAdvancedSettingsSet ? "advanced" : "basic");
const [confirmationModeIsEnabled, setConfirmationModeIsEnabled] =
React.useState(!!settings?.SECURITY_ANALYZER);
const [resetSettingsModalIsOpen, setResetSettingsModalIsOpen] =
React.useState(false);
const formRef = React.useRef<HTMLFormElement>(null);
@@ -177,6 +180,16 @@ function AccountSettings() {
});
};
const handleReset = () => {
saveSettings(null, {
onSuccess: () => {
displaySuccessToast(t(I18nKey.SETTINGS$RESET));
setResetSettingsModalIsOpen(false);
setLlmConfigMode("basic");
},
});
};
React.useEffect(() => {
// If settings is still loading by the time the state is set, it will always
// default to basic settings. This is a workaround to ensure the correct
@@ -514,6 +527,13 @@ function AccountSettings() {
</form>
<footer className="flex gap-6 p-6 justify-end border-t border-t-tertiary">
<BrandButton
type="button"
variant="secondary"
onClick={() => setResetSettingsModalIsOpen(true)}
>
{t(I18nKey.BUTTON$RESET_TO_DEFAULTS)}
</BrandButton>
<BrandButton
type="button"
variant="primary"
@@ -524,6 +544,40 @@ function AccountSettings() {
{t(I18nKey.BUTTON$SAVE)}
</BrandButton>
</footer>
{resetSettingsModalIsOpen && (
<ModalBackdrop>
<div
data-testid="reset-modal"
className="bg-base-secondary p-4 rounded-xl flex flex-col gap-4 border border-tertiary"
>
<p>{t(I18nKey.SETTINGS$RESET_CONFIRMATION)}</p>
<div className="w-full flex gap-2">
<BrandButton
type="button"
variant="primary"
className="grow"
onClick={() => {
handleReset();
}}
>
Reset
</BrandButton>
<BrandButton
type="button"
variant="secondary"
className="grow"
onClick={() => {
setResetSettingsModalIsOpen(false);
}}
>
Cancel
</BrandButton>
</div>
</div>
</ModalBackdrop>
)}
</>
);
}
+1
View File
@@ -135,6 +135,7 @@ def event_to_dict(event: 'Event') -> dict:
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
for k, v in props.items()
}
logger.debug(f'extras data in event_to_dict: {d["extras"]}')
# Include success field for CmdOutputObservation
if hasattr(event, 'success'):
d['success'] = event.success
@@ -18,8 +18,6 @@ from openhands.integrations.service_types import (
)
from openhands.server.types import AppMode
from openhands.utils.import_utils import get_impl
from datetime import datetime
class GitHubService(BaseGitService, GitService):
@@ -159,13 +157,6 @@ class GitHubService(BaseGitService, GitService):
return repos[:max_repos] # Trim to max_repos if needed
def parse_pushed_at_date(self, repo):
ts = repo.get("pushed_at")
return datetime.strptime(ts, "%Y-%m-%dT%H:%M:%SZ") if ts else datetime.min
async def get_repositories(self, sort: str, app_mode: AppMode) -> list[Repository]:
MAX_REPOS = 1000
PER_PAGE = 100 # Maximum allowed by GitHub API
@@ -192,11 +183,6 @@ class GitHubService(BaseGitService, GitService):
# If we've already reached MAX_REPOS, no need to check other installations
if len(all_repos) >= MAX_REPOS:
break
if sort == "pushed":
all_repos.sort(
key=self.parse_pushed_at_date, reverse=True
)
else:
# Original behavior for non-SaaS mode
params = {'per_page': str(PER_PAGE), 'sort': sort}
@@ -205,7 +191,6 @@ class GitHubService(BaseGitService, GitService):
# Fetch user repositories
all_repos = await self._fetch_paginated_repos(url, params, MAX_REPOS)
# Convert to Repository objects
return [
Repository(
+1 -1
View File
@@ -111,7 +111,7 @@ class BaseGitService(ABC):
return UnknownException('Unknown error')
def handle_http_error(self, e: HTTPError) -> UnknownException:
logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
logger.warning(f'HTTP error on {self.provider} API: {e}')
return UnknownException('Unknown error')
+1 -1
View File
@@ -713,7 +713,7 @@ class LLM(RetryMixin, DebugMixin):
completion_response=response, **extra_kwargs
)
except Exception as e:
logger.debug(f'Error getting cost from litellm: {e}')
logger.error(f'Error getting cost from litellm: {e}')
if cost is None:
_model_name = '/'.join(self.config.model.split('/')[1:])
+1 -1
View File
@@ -183,7 +183,7 @@ python -m openhands.resolver.send_pull_request --issue-number ISSUE_NUMBER --use
## Providing Custom Instructions
You can customize how the AI agent approaches issue resolution by adding a repository microagent file at `.openhands/microagents/repo.md` in your repository. This file's contents will be automatically loaded in the prompt when working with your repository. For more information about repository microagents, see [Repository Instructions](https://github.com/All-Hands-AI/OpenHands/tree/main/microagents#2-repository-instructions-private).
You can customize how the AI agent approaches issue resolution by adding a `.openhands_instructions` file to the root of your repository. If present, this file's contents will be injected into the prompt for openhands edits.
## Troubleshooting
+8 -26
View File
@@ -98,7 +98,6 @@ class Runtime(FileEditRuntimeMixin):
initial_env_vars: dict[str, str]
attach_to_existing: bool
status_callback: Callable | None
git_dir: str | None
def __init__(
self,
@@ -113,11 +112,9 @@ class Runtime(FileEditRuntimeMixin):
user_id: str | None = None,
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
):
# GitHandler will be initialized with an async function
self.git_handler = GitHandler(
execute_shell_fn=self._execute_shell_fn_git_handler
)
self.git_dir = None
self.sid = sid
self.event_stream = event_stream
self.event_stream.subscribe(
@@ -319,9 +316,6 @@ class Runtime(FileEditRuntimeMixin):
selected_branch: str | None,
repository_provider: ProviderType = ProviderType.GITHUB,
) -> str:
# Set the git_dir to the workspace mount path by default
self.git_dir = self.config.workspace_mount_path_in_sandbox
if not selected_repository:
# In SaaS mode (indicated by user_id being set), always run git init
# In OSS mode, only run git init if workspace_base is not set
@@ -333,7 +327,6 @@ class Runtime(FileEditRuntimeMixin):
command='git init',
)
self.run_action(action)
# git_dir is already set to workspace mount path
else:
logger.info(
'In workspace mount mode, not initializing a new git repository.'
@@ -402,13 +395,6 @@ class Runtime(FileEditRuntimeMixin):
)
self.log('info', f'Cloning repo: {selected_repository}')
self.run_action(action)
# Update git_dir to point to the cloned repository directory
self.git_dir = os.path.join(
self.config.workspace_mount_path_in_sandbox, dir_name
)
self.git_handler.set_cwd(self.git_dir)
return dir_name
def maybe_run_setup_script(self):
@@ -626,15 +612,13 @@ class Runtime(FileEditRuntimeMixin):
# Git
# ====================================================================
async def _execute_shell_fn_git_handler(
def _execute_shell_fn_git_handler(
self, command: str, cwd: str | None
) -> CommandResult:
"""
This function is used by the GitHandler to execute shell commands.
"""
obs = await call_sync_from_async(
self.run, CmdRunAction(command=command, is_static=True, cwd=cwd)
)
obs = self.run(CmdRunAction(command=command, is_static=True, cwd=cwd))
exit_code = 0
content = ''
@@ -645,15 +629,13 @@ class Runtime(FileEditRuntimeMixin):
return CommandResult(content=content, exit_code=exit_code)
async def get_git_changes(self) -> list[dict[str, str]] | None:
if self.git_dir:
self.git_handler.set_cwd(self.git_dir)
return await call_sync_from_async(self.git_handler.get_git_changes)
def get_git_changes(self, cwd: str) -> list[dict[str, str]] | None:
self.git_handler.set_cwd(cwd)
return self.git_handler.get_git_changes()
async def get_git_diff(self, file_path: str) -> dict[str, str]:
if self.git_dir:
self.git_handler.set_cwd(self.git_dir)
return await call_sync_from_async(self.git_handler.get_git_diff, file_path)
def get_git_diff(self, file_path: str, cwd: str) -> dict[str, str]:
self.git_handler.set_cwd(cwd)
return self.git_handler.get_git_diff(file_path)
@property
def additional_agent_instructions(self) -> str:
+2 -2
View File
@@ -215,13 +215,13 @@ class BashSession:
self.session.set_option('history-limit', str(self.HISTORY_LIMIT), _global=True)
self.session.history_limit = self.HISTORY_LIMIT
# We need to create a new pane because the initial pane's history limit is (default) 2000
_initial_window = self.session.active_window
_initial_window = self.session.attached_window
self.window = self.session.new_window(
window_name='bash',
window_shell=window_command,
start_directory=self.work_dir, # This parameter is supported by libtmux
)
self.pane = self.window.active_pane
self.pane = self.window.attached_pane
logger.debug(f'pane: {self.pane}; history_limit: {self.session.history_limit}')
_initial_window.kill_window()
+31 -35
View File
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Awaitable, Callable
from typing import Callable
@dataclass
@@ -23,7 +23,7 @@ class GitHandler:
def __init__(
self,
execute_shell_fn: Callable[[str, str | None], Awaitable[CommandResult]],
execute_shell_fn: Callable[[str, str | None], CommandResult],
):
self.execute = execute_shell_fn
self.cwd: str | None = None
@@ -37,11 +37,7 @@ class GitHandler:
"""
self.cwd = cwd
async def _execute_async(self, cmd: str, cwd: str | None) -> CommandResult:
"""Execute the command asynchronously."""
return await self.execute(cmd, cwd)
async def _is_git_repo(self) -> bool:
def _is_git_repo(self) -> bool:
"""
Checks if the current directory is a Git repository.
@@ -49,10 +45,10 @@ class GitHandler:
bool: True if inside a Git repository, otherwise False.
"""
cmd = 'git rev-parse --is-inside-work-tree'
output = await self._execute_async(cmd, self.cwd)
output = self.execute(cmd, self.cwd)
return output.content.strip() == 'true'
async def _get_current_file_content(self, file_path: str) -> str:
def _get_current_file_content(self, file_path: str) -> str:
"""
Retrieves the current content of a given file.
@@ -62,10 +58,10 @@ class GitHandler:
Returns:
str: The file content.
"""
output = await self._execute_async(f'cat {file_path}', self.cwd)
output = self.execute(f'cat {file_path}', self.cwd)
return output.content
async def _verify_ref_exists(self, ref: str) -> bool:
def _verify_ref_exists(self, ref: str) -> bool:
"""
Verifies whether a specific Git reference exists.
@@ -76,18 +72,18 @@ class GitHandler:
bool: True if the reference exists, otherwise False.
"""
cmd = f'git rev-parse --verify {ref}'
output = await self._execute_async(cmd, self.cwd)
output = self.execute(cmd, self.cwd)
return output.exit_code == 0
async def _get_valid_ref(self) -> str | None:
def _get_valid_ref(self) -> str | None:
"""
Determines a valid Git reference for comparison.
Returns:
str | None: A valid Git reference or None if no valid reference is found.
"""
current_branch = await self._get_current_branch()
default_branch = await self._get_default_branch()
current_branch = self._get_current_branch()
default_branch = self._get_default_branch()
ref_current_branch = f'origin/{current_branch}'
ref_non_default_branch = f'$(git merge-base HEAD "$(git rev-parse --abbrev-ref origin/{default_branch})")'
@@ -101,12 +97,12 @@ class GitHandler:
ref_new_repo,
]
for ref in refs:
if await self._verify_ref_exists(ref):
if self._verify_ref_exists(ref):
return ref
return None
async def _get_ref_content(self, file_path: str) -> str:
def _get_ref_content(self, file_path: str) -> str:
"""
Retrieves the content of a file from a valid Git reference.
@@ -116,15 +112,15 @@ class GitHandler:
Returns:
str: The content of the file from the reference, or an empty string if unavailable.
"""
ref = await self._get_valid_ref()
ref = self._get_valid_ref()
if not ref:
return ''
cmd = f'git show {ref}:{file_path}'
output = await self._execute_async(cmd, self.cwd)
output = self.execute(cmd, self.cwd)
return output.content if output.exit_code == 0 else ''
async def _get_default_branch(self) -> str:
def _get_default_branch(self) -> str:
"""
Retrieves the primary Git branch name of the repository.
@@ -132,10 +128,10 @@ class GitHandler:
str: The name of the primary branch.
"""
cmd = 'git remote show origin | grep "HEAD branch"'
output = await self._execute_async(cmd, self.cwd)
output = self.execute(cmd, self.cwd)
return output.content.split()[-1].strip()
async def _get_current_branch(self) -> str:
def _get_current_branch(self) -> str:
"""
Retrieves the currently selected Git branch.
@@ -143,25 +139,25 @@ class GitHandler:
str: The name of the current branch.
"""
cmd = 'git rev-parse --abbrev-ref HEAD'
output = await self._execute_async(cmd, self.cwd)
output = self.execute(cmd, self.cwd)
return output.content.strip()
async def _get_changed_files(self) -> list[str]:
def _get_changed_files(self) -> list[str]:
"""
Retrieves a list of changed files compared to a valid Git reference.
Returns:
list[str]: A list of changed file paths.
"""
ref = await self._get_valid_ref()
ref = self._get_valid_ref()
if not ref:
return []
diff_cmd = f'git diff --name-status {ref}'
output = await self._execute_async(diff_cmd, self.cwd)
output = self.execute(diff_cmd, self.cwd)
return output.content.splitlines()
async def _get_untracked_files(self) -> list[dict[str, str]]:
def _get_untracked_files(self) -> list[dict[str, str]]:
"""
Retrieves a list of untracked files in the repository. This is useful for detecting new files.
@@ -169,7 +165,7 @@ class GitHandler:
list[dict[str, str]]: A list of dictionaries containing file paths and statuses.
"""
cmd = 'git ls-files --others --exclude-standard'
output = await self._execute_async(cmd, self.cwd)
output = self.execute(cmd, self.cwd)
obs_list = output.content.splitlines()
return (
[{'status': 'A', 'path': path} for path in obs_list]
@@ -177,24 +173,24 @@ class GitHandler:
else []
)
async def get_git_changes(self) -> list[dict[str, str]] | None:
def get_git_changes(self) -> list[dict[str, str]] | None:
"""
Retrieves the list of changed files in the Git repository.
Returns:
list[dict[str, str]] | None: A list of dictionaries containing file paths and statuses. None if not a git repository.
"""
if not await self._is_git_repo():
if not self._is_git_repo():
return None
changes_list = await self._get_changed_files()
changes_list = self._get_changed_files()
result = parse_git_changes(changes_list)
# join with any untracked files
result += await self._get_untracked_files()
result += self._get_untracked_files()
return result
async def get_git_diff(self, file_path: str) -> dict[str, str]:
def get_git_diff(self, file_path: str) -> dict[str, str]:
"""
Retrieves the original and modified content of a file in the repository.
@@ -204,8 +200,8 @@ class GitHandler:
Returns:
dict[str, str]: A dictionary containing the original and modified content.
"""
modified = await self._get_current_file_content(file_path)
original = await self._get_ref_content(file_path)
modified = self._get_current_file_content(file_path)
original = self._get_ref_content(file_path)
return {
'modified': modified,
+34
View File
@@ -0,0 +1,34 @@
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
"""Get GitHub token from request state. For backward compatibility."""
return getattr(request.state, 'provider_tokens', None)
def get_access_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'access_token', None)
def get_user_id(request: Request) -> str | None:
return getattr(request.state, 'user_id', None)
def get_github_token(request: Request) -> SecretStr | None:
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
return provider_tokens[ProviderType.GITHUB].token
return None
def get_github_user_id(request: Request) -> str | None:
provider_tokens = get_provider_tokens(request)
if provider_tokens and ProviderType.GITHUB in provider_tokens:
return provider_tokens[ProviderType.GITHUB].user_id
return None
-1
View File
@@ -20,7 +20,6 @@ class ServerConfig(ServerConfigInterface):
)
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
monitoring_listener_class: str = 'openhands.server.monitoring.MonitoringListener'
user_auth_class: str = 'openhands.server.user_auth.default_user_auth.DefaultUserAuth'
def verify_config(self):
if self.config_cls:
@@ -202,7 +202,9 @@ class StandaloneConversationManager(ConversationManager):
ConversationStore, # type: ignore
self.server_config.conversation_store_class,
)
store = await conversation_store_class.get_instance(self.config, user_id)
store = await conversation_store_class.get_instance(
self.config, user_id, github_user_id
)
return store
async def get_running_agent_loops(
@@ -283,7 +285,7 @@ class StandaloneConversationManager(ConversationManager):
response_ids = await self.get_running_agent_loops(user_id)
if len(response_ids) >= self.config.max_concurrent_conversations:
logger.info(
f'too_many_sessions_for:{user_id or ''}',
'too_many_sessions_for:{user_id}',
extra={'session_id': sid, 'user_id': user_id},
)
# Get the conversations sorted (oldest first)
@@ -295,22 +297,6 @@ class StandaloneConversationManager(ConversationManager):
while len(conversations) >= self.config.max_concurrent_conversations:
oldest_conversation_id = conversations.pop().conversation_id
logger.debug(
f'closing_from_too_many_sessions:{user_id or ''}:{oldest_conversation_id}',
extra={'session_id': oldest_conversation_id, 'user_id': user_id},
)
# Send status message to client and close session.
status_update_dict = {
'status_update': True,
'type': 'error',
'id': 'AGENT_ERROR$TOO_MANY_CONVERSATIONS',
'message': 'Too many conversations at once. If you are still using this one, try reactivating it by prompting the agent to continue',
}
await self.sio.emit(
'oh_event',
status_update_dict,
to=ROOM_KEY.format(sid=oldest_conversation_id),
)
await self.close_session(oldest_conversation_id)
session = Session(
@@ -395,8 +381,8 @@ class StandaloneConversationManager(ConversationManager):
f'removing connections: {connection_ids_to_remove}',
extra={'session_id': sid},
)
for connection_id in connection_ids_to_remove:
self._local_connection_id_to_session_id.pop(connection_id, None)
for connnnection_id in connection_ids_to_remove:
self._local_connection_id_to_session_id.pop(connnnection_id, None)
session = self._local_agent_loops_by_sid.pop(sid, None)
if not session:
+2
View File
@@ -9,6 +9,7 @@ from openhands.server.middleware import (
CacheControlMiddleware,
InMemoryRateLimiter,
LocalhostCORSMiddleware,
ProviderTokenMiddleware,
RateLimitMiddleware,
)
from openhands.server.static import SPAStaticFiles
@@ -31,5 +32,6 @@ base_app.add_middleware(
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
)
base_app.middleware('http')(AttachConversationMiddleware(base_app))
base_app.middleware('http')(ProviderTokenMiddleware(base_app))
app = socketio.ASGIApp(sio, other_asgi_app=base_app)
+26 -3
View File
@@ -12,8 +12,8 @@ from starlette.requests import Request as StarletteRequest
from starlette.types import ASGIApp
from openhands.server import shared
from openhands.server.auth import get_user_id
from openhands.server.types import SessionMiddlewareInterface
from openhands.server.user_auth import get_user_id
class LocalhostCORSMiddleware(CORSMiddleware):
@@ -147,10 +147,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
"""
Attach the user's session based on the provided authentication token.
"""
user_id = await get_user_id(request)
request.state.conversation = (
await shared.conversation_manager.attach_to_conversation(
request.state.sid, user_id
request.state.sid, get_user_id(request)
)
)
if not request.state.conversation:
@@ -184,3 +183,27 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
await self._detach_session(request)
return response
class ProviderTokenMiddleware(SessionMiddlewareInterface):
def __init__(self, app):
self.app = app
async def __call__(self, request: Request, call_next: Callable):
settings_store = await shared.SettingsStoreImpl.get_instance(
shared.config, get_user_id(request)
)
settings = await settings_store.load()
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
if getattr(request.state, 'provider_tokens', None) is None:
if (
settings
and settings.secrets_store
and settings.secrets_store.provider_tokens
):
request.state.provider_tokens = settings.secrets_store.provider_tokens
else:
request.state.provider_tokens = None
return await call_next(request)
+77 -17
View File
@@ -2,7 +2,6 @@ import os
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
status,
@@ -22,10 +21,19 @@ from openhands.events.observation import (
FileReadObservation,
)
from openhands.runtime.base import Runtime
from openhands.server.auth import get_github_user_id, get_user_id
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.file_config import (
FILES_TO_IGNORE,
)
from openhands.server.user_auth import get_user_id
from openhands.server.shared import (
ConversationStoreImpl,
config,
conversation_manager,
)
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.utils.async_utils import call_sync_from_async
app = APIRouter(prefix='/api/conversations/{conversation_id}')
@@ -179,19 +187,24 @@ def zip_current_workspace(request: Request):
@app.get('/git/changes')
async def git_changes(
request: Request,
conversation_id: str,
user_id: str = Depends(get_user_id),
):
async def git_changes(request: Request, conversation_id: str):
runtime: Runtime = request.state.conversation.runtime
logger.info(f'Getting git changes in {runtime.git_dir}')
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
)
cwd = await get_cwd(
conversation_store,
conversation_id,
runtime.config.workspace_mount_path_in_sandbox,
)
logger.info(f'Getting git changes in {cwd}')
try:
changes = await call_sync_from_async(runtime.get_git_changes)
changes = await call_sync_from_async(runtime.get_git_changes, cwd)
if changes is None:
return JSONResponse(
status_code=404,
status_code=500,
content={'error': 'Not a git repository'},
)
return changes
@@ -210,16 +223,20 @@ async def git_changes(
@app.get('/git/diff')
async def git_diff(
request: Request,
path: str,
conversation_id: str,
):
async def git_diff(request: Request, path: str, conversation_id: str):
runtime: Runtime = request.state.conversation.runtime
logger.info(f'Getting git diff for {path} in {runtime.git_dir}')
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
)
cwd = await get_cwd(
conversation_store,
conversation_id,
runtime.config.workspace_mount_path_in_sandbox,
)
try:
diff = await call_sync_from_async(runtime.get_git_diff, path)
diff = await call_sync_from_async(runtime.get_git_diff, path, cwd)
return diff
except AgentRuntimeUnavailableError as e:
logger.error(f'Error getting diff: {e}')
@@ -227,3 +244,46 @@ async def git_diff(
status_code=500,
content={'error': f'Error getting diff: {e}'},
)
async def get_cwd(
conversation_store: ConversationStore,
conversation_id: str,
workspace_mount_path_in_sandbox: str,
):
metadata = await conversation_store.get_metadata(conversation_id)
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
conversation_info = await _get_conversation_info(metadata, is_running)
cwd = workspace_mount_path_in_sandbox
if conversation_info and conversation_info.selected_repository:
repo_dir = conversation_info.selected_repository.split('/')[-1]
cwd = os.path.join(cwd, repo_dir)
return cwd
async def _get_conversation_info(
conversation: ConversationMetadata,
is_running: bool,
) -> ConversationInfo | None:
try:
title = conversation.title
if not title:
title = f'Conversation {conversation.conversation_id[:5]}'
return ConversationInfo(
conversation_id=conversation.conversation_id,
title=title,
last_updated_at=conversation.last_updated_at,
created_at=conversation.created_at,
selected_repository=conversation.selected_repository,
status=ConversationStatus.RUNNING
if is_running
else ConversationStatus.STOPPED,
)
except Exception as e:
logger.error(
f'Error loading conversation {conversation.conversation_id}: {str(e)}',
extra={'session_id': conversation.conversation_id},
)
return None
+1 -5
View File
@@ -15,12 +15,8 @@ from openhands.integrations.service_types import (
UnknownException,
User,
)
from openhands.server.auth import get_access_token, get_provider_tokens, get_user_id
from openhands.server.shared import server_config
from openhands.server.user_auth import (
get_access_token,
get_provider_tokens,
get_user_id,
)
app = APIRouter(prefix='/api/user')
+24 -25
View File
@@ -1,7 +1,7 @@
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Body, Depends, status
from fastapi import APIRouter, Body, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
@@ -15,6 +15,11 @@ from openhands.integrations.provider import (
)
from openhands.integrations.service_types import Repository
from openhands.runtime import get_runtime_cls
from openhands.server.auth import (
get_github_user_id,
get_provider_tokens,
get_user_id,
)
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@@ -28,12 +33,6 @@ from openhands.server.shared import (
file_store,
)
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth import (
get_provider_tokens,
get_user_id,
)
from openhands.server.utils import get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import (
ConversationMetadata,
ConversationTrigger,
@@ -96,7 +95,7 @@ async def _create_new_conversation(
session_init_args['selected_branch'] = selected_branch
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
@@ -153,17 +152,14 @@ async def _create_new_conversation(
@app.post('/conversations')
async def new_conversation(
data: InitSessionRequest,
user_id: str = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
):
async def new_conversation(request: Request, data: InitSessionRequest):
"""Initialize a new session or join an existing one.
After successful initialization, the client should connect to the WebSocket
using the returned conversation ID.
"""
logger.info('Initializing new conversation')
provider_tokens = get_provider_tokens(request)
selected_repository = data.selected_repository
selected_branch = data.selected_branch
initial_user_msg = data.initial_user_msg
@@ -173,7 +169,7 @@ async def new_conversation(
try:
# Create conversation with initial message
conversation_id = await _create_new_conversation(
user_id,
get_user_id(request),
provider_tokens,
selected_repository,
selected_branch,
@@ -208,11 +204,13 @@ async def new_conversation(
@app.get('/conversations')
async def search_conversations(
request: Request,
page_id: str | None = None,
limit: int = 20,
user_id: str | None = Depends(get_user_id),
conversation_store: ConversationStore = Depends(get_conversation_store),
) -> ConversationInfoResultSet:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
# Filter out conversations older than max_age
@@ -230,7 +228,7 @@ async def search_conversations(
conversation.conversation_id for conversation in filtered_results
)
running_conversations = await conversation_manager.get_running_agent_loops(
user_id, set(conversation_ids)
get_user_id(request), set(conversation_ids)
)
result = ConversationInfoResultSet(
results=await wait_all(
@@ -247,9 +245,11 @@ async def search_conversations(
@app.get('/conversations/{conversation_id}')
async def get_conversation(
conversation_id: str,
conversation_store: ConversationStore = Depends(get_conversation_store),
conversation_id: str, request: Request
) -> ConversationInfo | None:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request), get_github_user_id(request)
)
try:
metadata = await conversation_store.get_metadata(conversation_id)
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
@@ -340,12 +340,11 @@ async def auto_generate_title(conversation_id: str, user_id: str | None) -> str:
@app.patch('/conversations/{conversation_id}')
async def update_conversation(
conversation_id: str,
title: str = Body(embed=True),
user_id: str | None = Depends(get_user_id),
request: Request, conversation_id: str, title: str = Body(embed=True)
) -> bool:
user_id = get_user_id(request)
conversation_store = await ConversationStoreImpl.get_instance(
config, user_id
config, user_id, get_github_user_id(request)
)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
@@ -367,10 +366,10 @@ async def update_conversation(
@app.delete('/conversations/{conversation_id}')
async def delete_conversation(
conversation_id: str,
user_id: str | None = Depends(get_user_id),
request: Request,
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, user_id
config, get_user_id(request), get_github_user_id(request)
)
try:
await conversation_store.get_metadata(conversation_id)
+88 -49
View File
@@ -1,15 +1,11 @@
from fastapi import APIRouter, Depends, Request, status
from fastapi import APIRouter, Request, status
from fastapi.responses import JSONResponse
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderToken,
ProviderType,
SecretStore,
)
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
from openhands.integrations.utils import validate_provider_token
from openhands.server.auth import get_provider_tokens, get_user_id
from openhands.server.settings import (
GETSettingsCustomSecrets,
GETSettingsModel,
@@ -19,24 +15,16 @@ from openhands.server.settings import (
)
from openhands.server.shared import SettingsStoreImpl, config, server_config
from openhands.server.types import AppMode
from openhands.server.user_auth import (
get_provider_tokens,
get_user_id,
get_user_settings,
get_user_settings_store,
)
from openhands.storage.settings.settings_store import SettingsStore
app = APIRouter(prefix='/api')
@app.get('/settings', response_model=GETSettingsModel)
async def load_settings(
user_id: str | None = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
settings: Settings | None = Depends(get_user_settings),
) -> GETSettingsModel | JSONResponse:
async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
try:
user_id = get_user_id(request)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
if not settings:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
@@ -48,6 +36,7 @@ async def load_settings(
if bool(user_id):
provider_tokens_set[ProviderType.GITHUB.value] = True
provider_tokens = get_provider_tokens(request)
if provider_tokens:
all_provider_types = [provider.value for provider in ProviderType]
provider_tokens_types = [provider.value for provider in provider_tokens]
@@ -59,7 +48,7 @@ async def load_settings(
settings_with_token_data = GETSettingsModel(
**settings.model_dump(exclude='secrets_store'),
llm_api_key_set=settings.llm_api_key is not None and bool(settings.llm_api_key),
llm_api_key_set=settings.llm_api_key is not None,
provider_tokens_set=provider_tokens_set,
)
settings_with_token_data.llm_api_key = None
@@ -74,9 +63,12 @@ async def load_settings(
@app.get('/secrets', response_model=GETSettingsCustomSecrets)
async def load_custom_secrets_names(
settings: Settings | None = Depends(get_user_settings),
request: Request,
) -> GETSettingsCustomSecrets | JSONResponse:
try:
user_id = get_user_id(request)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
if not settings:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
@@ -101,11 +93,13 @@ async def load_custom_secrets_names(
@app.post('/secrets', response_model=dict[str, str])
async def add_custom_secret(
incoming_secrets: POSTSettingsCustomSecrets,
settings_store: SettingsStore = Depends(get_user_settings_store),
request: Request, incoming_secrets: POSTSettingsCustomSecrets
) -> JSONResponse:
try:
existing_settings = await settings_store.load()
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings: Settings = await settings_store.load()
if existing_settings:
for (
secret_name,
@@ -127,6 +121,7 @@ async def add_custom_secret(
update={'secrets_store': updated_secret_store}
)
updated_settings = convert_to_settings(updated_settings)
await settings_store.store(updated_settings)
return JSONResponse(
@@ -142,11 +137,11 @@ async def add_custom_secret(
@app.delete('/secrets/{secret_id}')
async def delete_custom_secret(
secret_id: str,
settings_store: SettingsStore = Depends(get_user_settings_store),
) -> JSONResponse:
async def delete_custom_secret(request: Request, secret_id: str) -> JSONResponse:
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings: Settings | None = await settings_store.load()
custom_secrets = {}
if existing_settings:
@@ -167,6 +162,7 @@ async def delete_custom_secret(
update={'secrets_store': updated_secret_store}
)
updated_settings = convert_to_settings(updated_settings)
await settings_store.store(updated_settings)
return JSONResponse(
@@ -182,10 +178,12 @@ async def delete_custom_secret(
@app.post('/unset-settings-tokens', response_model=dict[str, str])
async def unset_settings_tokens(
settings_store: SettingsStore = Depends(get_user_settings_store),
) -> JSONResponse:
async def unset_settings_tokens(request: Request) -> JSONResponse:
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings = await settings_store.load()
if existing_settings:
settings = existing_settings.model_copy(
@@ -207,20 +205,58 @@ async def unset_settings_tokens(
@app.post('/reset-settings', response_model=dict[str, str])
async def reset_settings() -> JSONResponse:
async def reset_settings(request: Request) -> JSONResponse:
"""
Resets user settings. (Deprecated)
Resets user settings.
"""
logger.warning(
f"Deprecated endpoint /api/reset-settings called by user"
)
return JSONResponse(
status_code=status.HTTP_410_GONE,
content={'error': 'Reset settings functionality has been removed.'},
)
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings = await settings_store.load()
settings = Settings(
language='en',
agent='CodeActAgent',
security_analyzer='',
confirmation_mode=False,
llm_model='anthropic/claude-3-5-sonnet-20241022',
llm_api_key='',
llm_base_url='',
remote_runtime_resource_factor=1,
enable_default_condenser=True,
enable_sound_notifications=False,
user_consents_to_analytics=existing_settings.user_consents_to_analytics
if existing_settings
else False,
)
server_config_values = server_config.get_config()
is_hide_llm_settings_enabled = server_config_values.get(
'FEATURE_FLAGS', {}
).get('HIDE_LLM_SETTINGS', False)
# We don't want the user to be able to modify these settings in SaaS
if server_config.app_mode == AppMode.SAAS and is_hide_llm_settings_enabled:
if existing_settings:
settings.llm_api_key = existing_settings.llm_api_key
settings.llm_base_url = existing_settings.llm_base_url
settings.llm_model = existing_settings.llm_model
await settings_store.store(settings)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={'message': 'Settings stored'},
)
except Exception as e:
logger.warning(f'Something went wrong resetting settings: {e}')
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': 'Something went wrong resetting settings'},
)
async def check_provider_tokens(settings: POSTSettingsModel) -> str:
async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str:
if settings.provider_tokens:
# Remove extraneous token types
provider_types = [provider.value for provider in ProviderType]
@@ -240,9 +276,8 @@ async def check_provider_tokens(settings: POSTSettingsModel) -> str:
return ''
async def store_provider_tokens(
settings: POSTSettingsModel, settings_store: SettingsStore
):
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
existing_settings = await settings_store.load()
if existing_settings:
if settings.provider_tokens:
@@ -276,8 +311,9 @@ async def store_provider_tokens(
async def store_llm_settings(
settings: POSTSettingsModel, settings_store: SettingsStore
request: Request, settings: POSTSettingsModel
) -> POSTSettingsModel:
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
@@ -295,11 +331,11 @@ async def store_llm_settings(
@app.post('/settings', response_model=dict[str, str])
async def store_settings(
request: Request,
settings: POSTSettingsModel,
settings_store: SettingsStore = Depends(get_user_settings_store),
) -> JSONResponse:
# Check provider tokens are valid
provider_err_msg = await check_provider_tokens(settings)
provider_err_msg = await check_provider_tokens(request, settings)
if provider_err_msg:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -307,11 +343,14 @@ async def store_settings(
)
try:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
if existing_settings:
settings = await store_llm_settings(settings, settings_store)
settings = await store_llm_settings(request, settings)
# Keep existing analytics consent if not provided
if settings.user_consents_to_analytics is None:
@@ -319,7 +358,7 @@ async def store_settings(
existing_settings.user_consents_to_analytics
)
settings = await store_provider_tokens(settings, settings_store)
settings = await store_provider_tokens(request, settings)
# Update sandbox config with new settings
if settings.remote_runtime_resource_factor is not None:
+1 -4
View File
@@ -94,10 +94,7 @@ class Settings(BaseModel):
return {
'provider_tokens': secrets.provider_tokens_serializer(
secrets.provider_tokens, info
),
'custom_secrets': secrets.custom_secrets_serializer(
secrets.custom_secrets, info
),
)
}
@staticmethod
-48
View File
@@ -1,48 +0,0 @@
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.integrations.service_types import ProviderType
from openhands.server.settings import Settings
from openhands.server.user_auth.user_auth import get_user_auth
from openhands.storage.settings.settings_store import SettingsStore
async def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
user_auth = await get_user_auth(request)
provider_tokens = await user_auth.get_provider_tokens()
return provider_tokens
async def get_access_token(request: Request) -> SecretStr | None:
user_auth = await get_user_auth(request)
access_token = await user_auth.get_access_token()
return access_token
async def get_user_id(request: Request) -> str | None:
user_auth = await get_user_auth(request)
user_id = await user_auth.get_user_id()
return user_id
async def get_github_user_id(request: Request) -> str | None:
provider_tokens = await get_provider_tokens(request)
if not provider_tokens:
return None
github_provider = provider_tokens.get(ProviderType.GITHUB)
if github_provider:
return github_provider.user_id
return None
async def get_user_settings(request: Request) -> Settings | None:
user_auth = await get_user_auth(request)
user_settings = await user_auth.get_user_settings()
return user_settings
async def get_user_settings_store(request: Request) -> SettingsStore | None:
user_auth = await get_user_auth(request)
user_settings_store = await user_auth.get_user_settings_store()
return user_settings_store
@@ -1,57 +0,0 @@
from dataclasses import dataclass
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.server import shared
from openhands.server.settings import Settings
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.settings.settings_store import SettingsStore
@dataclass
class DefaultUserAuth(UserAuth):
"""Default user authentication mechanism"""
_settings: Settings | None = None
_settings_store: SettingsStore | None = None
async def get_user_id(self) -> str | None:
"""The default implementation does not support multi tenancy, so user_id is always None"""
return None
async def get_access_token(self) -> SecretStr | None:
"""The default implementation does not support multi tenancy, so access_token is always None"""
return None
async def get_user_settings_store(self):
settings_store = self._settings_store
if settings_store:
return settings_store
user_id = await self.get_user_id()
settings_store = await shared.SettingsStoreImpl.get_instance(
shared.config, user_id
)
self._settings_store = settings_store
return settings_store
async def get_user_settings(self) -> Settings | None:
settings = self._settings
if settings:
return settings
settings_store = await self.get_user_settings_store()
settings = await settings_store.load()
self._settings = settings
return settings
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
settings = await self.get_user_settings()
secrets_store = getattr(settings, 'secrets_store', None)
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
return provider_tokens
@classmethod
async def get_instance(cls, request: Request) -> UserAuth:
user_auth = DefaultUserAuth()
return user_auth
-63
View File
@@ -1,63 +0,0 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from fastapi import Request
from pydantic import SecretStr
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.server.settings import Settings
from openhands.server.shared import server_config
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.import_utils import get_impl
class UserAuth(ABC):
"""Extensible class encapsulating user Authentication"""
_settings: Settings | None
@abstractmethod
async def get_user_id(self) -> str | None:
"""Get the unique identifier for the current user"""
@abstractmethod
async def get_access_token(self) -> SecretStr | None:
"""Get the access token for the current user"""
@abstractmethod
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
"""Get the provider tokens for the current user."""
@abstractmethod
async def get_user_settings_store(self) -> SettingsStore | None:
"""Get the settings store for the current user."""
async def get_user_settings(self) -> Settings | None:
"""Get the user settings for the current user"""
settings = self._settings
if settings:
return settings
settings_store = await self.get_user_settings_store()
if settings_store is None:
return None
settings = await settings_store.load()
self._settings = settings
return settings
@classmethod
@abstractmethod
async def get_instance(cls, request: Request) -> UserAuth:
"""Get an instance of UserAuth from the request given"""
async def get_user_auth(request: Request) -> UserAuth:
user_auth = getattr(request.state, 'user_auth', None)
if user_auth:
return user_auth
impl_name = server_config.user_auth_class
impl = get_impl(UserAuth, impl_name)
user_auth = await impl.get_instance(request)
request.state.user_auth = user_auth
return user_auth
-16
View File
@@ -1,16 +0,0 @@
from fastapi import Request
from openhands.server.shared import ConversationStoreImpl, config
from openhands.server.user_auth import get_user_auth
from openhands.storage.conversation.conversation_store import ConversationStore
async def get_conversation_store(request: Request) -> ConversationStore | None:
conversation_store = getattr(request.state, 'conversation_store', None)
if conversation_store:
return conversation_store
user_auth = await get_user_auth(request)
user_id = await user_auth.get_user_id()
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
request.state.conversation_store = conversation_store
return conversation_store
@@ -60,6 +60,6 @@ class ConversationStore(ABC):
@classmethod
@abstractmethod
async def get_instance(
cls, config: AppConfig, user_id: str | None
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
) -> ConversationStore:
"""Get a store for the user represented by the token given."""
@@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore):
@classmethod
async def get_instance(
cls, config: AppConfig, user_id: str | None
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileConversationStore(file_store)
Generated
+38 -43
View File
@@ -496,18 +496,18 @@ files = [
[[package]]
name = "boto3"
version = "1.38.2"
version = "1.38.1"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "boto3-1.38.2-py3-none-any.whl", hash = "sha256:ef3237b169cd906a44a32c03b3229833d923c9e9733355b329ded2151f91ec0b"},
{file = "boto3-1.38.2.tar.gz", hash = "sha256:53c8d44b231251fa9421dd13d968236d59fe2cf0421e077afedbf3821653fb3b"},
{file = "boto3-1.38.1-py3-none-any.whl", hash = "sha256:f192a4a34885a9e3e970b5ce5e6bec947be0f3fe6c4693b2a737c14407b12a5a"},
{file = "boto3-1.38.1.tar.gz", hash = "sha256:988e7fae7fd4d59798f84604d73a3a019c07b048f746c7c40258c0e656473887"},
]
[package.dependencies]
botocore = ">=1.38.2,<1.39.0"
botocore = ">=1.38.1,<1.39.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.12.0,<0.13.0"
@@ -516,14 +516,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "boto3-stubs"
version = "1.38.2"
description = "Type annotations for boto3 1.38.2 generated with mypy-boto3-builder 8.10.1"
version = "1.38.1"
description = "Type annotations for boto3 1.38.1 generated with mypy-boto3-builder 8.10.1"
optional = false
python-versions = ">=3.8"
groups = ["evaluation"]
files = [
{file = "boto3_stubs-1.38.2-py3-none-any.whl", hash = "sha256:e18f2dc194c4b8a29f61275ba039689d063c4775a78560e35a5ce820ec257fb5"},
{file = "boto3_stubs-1.38.2.tar.gz", hash = "sha256:405cd777d41530cf8ed009d20b04daef1f7d4bd2fd9fd3636ac86eccdb55159c"},
{file = "boto3_stubs-1.38.1-py3-none-any.whl", hash = "sha256:3501f98c39b8c2d613b1138a4e8881ceef2ac9497ac030be47cf4336f1aa0573"},
{file = "boto3_stubs-1.38.1.tar.gz", hash = "sha256:25b03fdbda288c1576fbe002ecf40088e9f5d6cdf0518de8a84a7467aa898092"},
]
[package.dependencies]
@@ -579,7 +579,7 @@ bedrock-data-automation-runtime = ["mypy-boto3-bedrock-data-automation-runtime (
bedrock-runtime = ["mypy-boto3-bedrock-runtime (>=1.38.0,<1.39.0)"]
billing = ["mypy-boto3-billing (>=1.38.0,<1.39.0)"]
billingconductor = ["mypy-boto3-billingconductor (>=1.38.0,<1.39.0)"]
boto3 = ["boto3 (==1.38.2)"]
boto3 = ["boto3 (==1.38.1)"]
braket = ["mypy-boto3-braket (>=1.38.0,<1.39.0)"]
budgets = ["mypy-boto3-budgets (>=1.38.0,<1.39.0)"]
ce = ["mypy-boto3-ce (>=1.38.0,<1.39.0)"]
@@ -943,14 +943,14 @@ xray = ["mypy-boto3-xray (>=1.38.0,<1.39.0)"]
[[package]]
name = "botocore"
version = "1.38.2"
version = "1.38.1"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "botocore-1.38.2-py3-none-any.whl", hash = "sha256:5d9cffedb1c759a058b43793d16647ed44ec87072f98a1bd6cd673ac0ae6b81d"},
{file = "botocore-1.38.2.tar.gz", hash = "sha256:b688a9bd17211a1eaae3a6c965ba9f3973e5435efaaa4fa201f499d3467830e1"},
{file = "botocore-1.38.1-py3-none-any.whl", hash = "sha256:b1673975e3c42d0e2d1804f9f73e88961e95eac371c8f8c0a0d7e661ec3c90c3"},
{file = "botocore-1.38.1.tar.gz", hash = "sha256:c2eb42eeaa502f236ba894a65ea7f7241711150cc450b9d59fbbad41e741adc0"},
]
[package.dependencies]
@@ -2523,7 +2523,6 @@ files = [
{file = "gevent-25.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c1d1a66a28372d505e0d8f6f1fdb62f7d5b3423e49431f41b99bd9133f006b7"},
{file = "gevent-25.4.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:fdf9aec76a7285b00fb64ec942cd9ff88f8765874a5abf99c4e8c5374b3133e9"},
{file = "gevent-25.4.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7442b3ffac08f6239d6463ee2943fd9a619b64b2db11cec292acf8caccb70536"},
{file = "gevent-25.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:d7999e4d4b3597b706a333f9a7bf2efbd8365cd244312405f33b4870fa3b411d"},
{file = "gevent-25.4.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:2270a8607661e609c44e4f72811b6380dcfede558041e4ee3134e66753865038"},
{file = "gevent-25.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb89ed32e2b766fcb1afc52847e33d8c369d2b40f23d4c96977fd092b5a0ea86"},
{file = "gevent-25.4.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43469ed40ea6cfb1c88e8d85a57aa5f52dd6b3b94a2e499752ab7e60a90c7dba"},
@@ -2531,7 +2530,6 @@ files = [
{file = "gevent-25.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccbc835939416a7df7834b79c655409a2a9d2deb9bf119b28dedf72a168f7895"},
{file = "gevent-25.4.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:feb5f2f44dcdad1a6b80e7ce24e7557ce25d01ff13b7a74ca276d113adf9d4af"},
{file = "gevent-25.4.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:91408dd197c13ca0f1e0d5cdcc9870c674963bb87a7e370b2884d1426d73834f"},
{file = "gevent-25.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:12b596c027cf546a235231d421473483fdf7fa586d38162d36b07c8efa9081ba"},
{file = "gevent-25.4.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:5940174c7d1ffc7bb4b0ea9f2908f4f361eb03ada9e145d3590b8df1e61c379b"},
{file = "gevent-25.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7ae7ad4ff9c4492d4b633702e35153509b07dc6ffd20f1577076d7647c9caba"},
{file = "gevent-25.4.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d68fdf9bff0068367126983d7d85765124c292b4bc3d4d19ed8138335d8426a7"},
@@ -2539,7 +2537,6 @@ files = [
{file = "gevent-25.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7c70ab6d33dfeb43bfe982c636609d8f90506dacaaa1f409a3c43c66d578fb1"},
{file = "gevent-25.4.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e740bc08ba4c34951f4bb6351dbe04209416e12d620691fb57e115b218a7818"},
{file = "gevent-25.4.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c535d96ded6e26b37fadda9242a49fea6308754da5945173940614b7520c07b4"},
{file = "gevent-25.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:c62bf14557d2cb54f5e3c1ba0a3b3f4b69bf0441081c32d63b205763b495b251"},
{file = "gevent-25.4.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:f735f57bc19d0f8bbc784093cfb7953a9ad66612b05c3ff876ec7951a96d7edd"},
{file = "gevent-25.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63aecf1e43b8d01086ea574ed05f7272ed40c48dd41fa3d061e3c5ca900abcdd"},
{file = "gevent-25.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f12e570777027f807dc7dc3ea1945ea040befaf1c9485deb6f24d7110009fc12"},
@@ -2551,8 +2548,6 @@ files = [
{file = "gevent-25.4.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:b0a656eccd9cb115d01c9bbe55bfe84cf20c8422c495503f41aef747b193c33d"},
{file = "gevent-25.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95790dd8aeb4ca8df9ac215ec353a29108647797e54daa652a4634ca316f70d4"},
{file = "gevent-25.4.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:76c440972ff57eb64e089f85210ccc0fa247ab71cdedff5414c6b86392f7f791"},
{file = "gevent-25.4.2-cp39-cp39-win32.whl", hash = "sha256:b91e862ab0ddecf37ee6e3bf33965ef4c3e38ba9cdc106eef552293caed512f9"},
{file = "gevent-25.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:03587078c402aee27231ecaabd81aec1e8b3de2629830fbd4486e2d09e638ddc"},
{file = "gevent-25.4.2-pp310-pypy310_pp73-macosx_11_0_universal2.whl", hash = "sha256:498f548330c4724e3b0cee0d75551165fc9e4309ae3ddcba3d644aaa866ca9c3"},
{file = "gevent-25.4.2.tar.gz", hash = "sha256:7ffba461458ed28a85a01285ea0e0dc14f883204d17ce5ed82fa839a9d620028"},
]
@@ -2676,14 +2671,14 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
[[package]]
name = "google-api-python-client"
version = "2.168.0"
version = "2.167.0"
description = "Google API Client Library for Python"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "google_api_python_client-2.168.0-py3-none-any.whl", hash = "sha256:ebf27fc318a3cf682dc994cefc46b6794eafee91d91fc659d46e018155ace530"},
{file = "google_api_python_client-2.168.0.tar.gz", hash = "sha256:10759c3c8f5bbb17752b349ff336963ab215db150f34594a5581d5cd9b5add41"},
{file = "google_api_python_client-2.167.0-py2.py3-none-any.whl", hash = "sha256:ce25290cc229505d770ca5c8d03850e0ae87d8e998fc6dd743ecece018baa396"},
{file = "google_api_python_client-2.167.0.tar.gz", hash = "sha256:a458d402572e1c2caf9db090d8e7b270f43ff326bd9349c731a86b19910e3995"},
]
[package.dependencies]
@@ -4883,14 +4878,14 @@ files = [
[[package]]
name = "modal"
version = "0.74.23"
version = "0.74.20"
description = "Python client library for Modal"
optional = false
python-versions = ">=3.9"
groups = ["main", "evaluation"]
files = [
{file = "modal-0.74.23-py3-none-any.whl", hash = "sha256:96c397487ed5f499ad040b5edf5f378ada8e0676da17523a2d6fadb3f1d384e1"},
{file = "modal-0.74.23.tar.gz", hash = "sha256:3a042cdf482975b43341da0b33fa6a6adae06978ead69a086ca658a7dcb0cd6d"},
{file = "modal-0.74.20-py3-none-any.whl", hash = "sha256:d6aa369f83399a399b2b151f98124df703a31b8358d5ace3289a88cb61fb9ef0"},
{file = "modal-0.74.20.tar.gz", hash = "sha256:2d4e6e6592c309346448d8e278a0c392596c8a6971b5d789bc81af0d8180f2de"},
]
[package.dependencies]
@@ -7950,30 +7945,30 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.11.7"
version = "0.11.6"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev", "evaluation"]
files = [
{file = "ruff-0.11.7-py3-none-linux_armv6l.whl", hash = "sha256:d29e909d9a8d02f928d72ab7837b5cbc450a5bdf578ab9ebee3263d0a525091c"},
{file = "ruff-0.11.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dd1fb86b168ae349fb01dd497d83537b2c5541fe0626e70c786427dd8363aaee"},
{file = "ruff-0.11.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d3d7d2e140a6fbbc09033bce65bd7ea29d6a0adeb90b8430262fbacd58c38ada"},
{file = "ruff-0.11.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4809df77de390a1c2077d9b7945d82f44b95d19ceccf0c287c56e4dc9b91ca64"},
{file = "ruff-0.11.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f3a0c2e169e6b545f8e2dba185eabbd9db4f08880032e75aa0e285a6d3f48201"},
{file = "ruff-0.11.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49b888200a320dd96a68e86736cf531d6afba03e4f6cf098401406a257fcf3d6"},
{file = "ruff-0.11.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:2b19cdb9cf7dae00d5ee2e7c013540cdc3b31c4f281f1dacb5a799d610e90db4"},
{file = "ruff-0.11.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64e0ee994c9e326b43539d133a36a455dbaab477bc84fe7bfbd528abe2f05c1e"},
{file = "ruff-0.11.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bad82052311479a5865f52c76ecee5d468a58ba44fb23ee15079f17dd4c8fd63"},
{file = "ruff-0.11.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7940665e74e7b65d427b82bffc1e46710ec7f30d58b4b2d5016e3f0321436502"},
{file = "ruff-0.11.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:169027e31c52c0e36c44ae9a9c7db35e505fee0b39f8d9fca7274a6305295a92"},
{file = "ruff-0.11.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:305b93f9798aee582e91e34437810439acb28b5fc1fee6b8205c78c806845a94"},
{file = "ruff-0.11.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a681db041ef55550c371f9cd52a3cf17a0da4c75d6bd691092dfc38170ebc4b6"},
{file = "ruff-0.11.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:07f1496ad00a4a139f4de220b0c97da6d4c85e0e4aa9b2624167b7d4d44fd6b6"},
{file = "ruff-0.11.7-py3-none-win32.whl", hash = "sha256:f25dfb853ad217e6e5f1924ae8a5b3f6709051a13e9dad18690de6c8ff299e26"},
{file = "ruff-0.11.7-py3-none-win_amd64.whl", hash = "sha256:0a931d85959ceb77e92aea4bbedfded0a31534ce191252721128f77e5ae1f98a"},
{file = "ruff-0.11.7-py3-none-win_arm64.whl", hash = "sha256:778c1e5d6f9e91034142dfd06110534ca13220bfaad5c3735f6cb844654f6177"},
{file = "ruff-0.11.7.tar.gz", hash = "sha256:655089ad3224070736dc32844fde783454f8558e71f501cb207485fe4eee23d4"},
{file = "ruff-0.11.6-py3-none-linux_armv6l.whl", hash = "sha256:d84dcbe74cf9356d1bdb4a78cf74fd47c740bf7bdeb7529068f69b08272239a1"},
{file = "ruff-0.11.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9bc583628e1096148011a5d51ff3c836f51899e61112e03e5f2b1573a9b726de"},
{file = "ruff-0.11.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f2959049faeb5ba5e3b378709e9d1bf0cab06528b306b9dd6ebd2a312127964a"},
{file = "ruff-0.11.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63c5d4e30d9d0de7fedbfb3e9e20d134b73a30c1e74b596f40f0629d5c28a193"},
{file = "ruff-0.11.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26a4b9a4e1439f7d0a091c6763a100cef8fbdc10d68593df6f3cfa5abdd9246e"},
{file = "ruff-0.11.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5edf270223dd622218256569636dc3e708c2cb989242262fe378609eccf1308"},
{file = "ruff-0.11.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f55844e818206a9dd31ff27f91385afb538067e2dc0beb05f82c293ab84f7d55"},
{file = "ruff-0.11.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d8f782286c5ff562e4e00344f954b9320026d8e3fae2ba9e6948443fafd9ffc"},
{file = "ruff-0.11.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:01c63ba219514271cee955cd0adc26a4083df1956d57847978383b0e50ffd7d2"},
{file = "ruff-0.11.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15adac20ef2ca296dd3d8e2bedc6202ea6de81c091a74661c3666e5c4c223ff6"},
{file = "ruff-0.11.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4dd6b09e98144ad7aec026f5588e493c65057d1b387dd937d7787baa531d9bc2"},
{file = "ruff-0.11.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:45b2e1d6c0eed89c248d024ea95074d0e09988d8e7b1dad8d3ab9a67017a5b03"},
{file = "ruff-0.11.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bd40de4115b2ec4850302f1a1d8067f42e70b4990b68838ccb9ccd9f110c5e8b"},
{file = "ruff-0.11.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:77cda2dfbac1ab73aef5e514c4cbfc4ec1fbef4b84a44c736cc26f61b3814cd9"},
{file = "ruff-0.11.6-py3-none-win32.whl", hash = "sha256:5151a871554be3036cd6e51d0ec6eef56334d74dfe1702de717a995ee3d5b287"},
{file = "ruff-0.11.6-py3-none-win_amd64.whl", hash = "sha256:cce85721d09c51f3b782c331b0abd07e9d7d5f775840379c640606d3159cae0e"},
{file = "ruff-0.11.6-py3-none-win_arm64.whl", hash = "sha256:3567ba0d07fb170b1b48d944715e3294b77f5b7679e8ba258199a250383ccb79"},
{file = "ruff-0.11.6.tar.gz", hash = "sha256:bec8bcc3ac228a45ccc811e45f7eb61b950dbf4cf31a67fa89352574b01c7d79"},
]
[[package]]
@@ -10257,4 +10252,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.1"
python-versions = "^3.12"
content-hash = "e0d99d8657168051347da0ebbeb0ff23b3c035149627253736cf9d2ec3930435"
content-hash = "ce7e49638e83acefc31930cb89f11f33a0c243552e0ed703e769440305c88aa9"
+1 -1
View File
@@ -78,7 +78,7 @@ playwright = "^1.51.0"
prompt-toolkit = "^3.0.50"
[tool.poetry.group.dev.dependencies]
ruff = "0.11.7"
ruff = "0.11.6"
mypy = "1.15.0"
pre-commit = "4.2.0"
build = "*"
+40 -160
View File
@@ -1,10 +1,11 @@
import json
from contextlib import contextmanager
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
import pytest
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@@ -15,7 +16,6 @@ from openhands.server.routes.manage_conversations import (
search_conversations,
update_conversation,
)
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.storage.locations import get_conversation_metadata_filename
from openhands.storage.memory import InMemoryFileStore
@@ -72,36 +72,9 @@ async def test_search_conversations():
)
mock_datetime.fromisoformat = datetime.fromisoformat
mock_datetime.timezone = timezone
# Mock the conversation store
mock_store = MagicMock()
mock_store.search = AsyncMock(
return_value=ConversationInfoResultSet(
results=[
ConversationMetadata(
conversation_id='some_conversation_id',
title='Some Conversation',
created_at=datetime.fromisoformat(
'2025-01-01T00:00:00+00:00'
),
last_updated_at=datetime.fromisoformat(
'2025-01-01T00:01:00+00:00'
),
selected_repository='foobar',
github_user_id='12345',
user_id='12345',
)
]
)
)
result_set = await search_conversations(
page_id=None,
limit=20,
user_id='12345',
conversation_store=mock_store,
MagicMock(state=MagicMock(github_token=''))
)
expected = ConversationInfoResultSet(
results=[
ConversationInfo(
@@ -124,51 +97,26 @@ async def test_search_conversations():
@pytest.mark.asyncio
async def test_get_conversation():
with _patch_store():
# Mock the conversation store
mock_store = MagicMock()
mock_store.get_metadata = AsyncMock(
return_value=ConversationMetadata(
conversation_id='some_conversation_id',
title='Some Conversation',
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
selected_repository='foobar',
github_user_id='12345',
user_id='12345',
)
conversation = await get_conversation(
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
)
# Mock the conversation manager
with patch(
'openhands.server.routes.manage_conversations.conversation_manager'
) as mock_manager:
mock_manager.is_agent_loop_running = AsyncMock(return_value=False)
conversation = await get_conversation(
'some_conversation_id', conversation_store=mock_store
)
expected = ConversationInfo(
conversation_id='some_conversation_id',
title='Some Conversation',
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
status=ConversationStatus.STOPPED,
selected_repository='foobar',
)
assert conversation == expected
expected = ConversationInfo(
conversation_id='some_conversation_id',
title='Some Conversation',
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
status=ConversationStatus.STOPPED,
selected_repository='foobar',
)
assert conversation == expected
@pytest.mark.asyncio
async def test_get_missing_conversation():
with _patch_store():
# Mock the conversation store
mock_store = MagicMock()
mock_store.get_metadata = AsyncMock(side_effect=FileNotFoundError)
assert (
await get_conversation(
'no_such_conversation', conversation_store=mock_store
'no_such_conversation', MagicMock(state=MagicMock(github_token=''))
)
is None
)
@@ -177,102 +125,34 @@ async def test_get_missing_conversation():
@pytest.mark.asyncio
async def test_update_conversation():
with _patch_store():
# Mock the ConversationStoreImpl.get_instance
with patch(
'openhands.server.routes.manage_conversations.ConversationStoreImpl.get_instance'
) as mock_get_instance:
# Create a mock conversation store
mock_store = MagicMock()
# Mock metadata
metadata = ConversationMetadata(
conversation_id='some_conversation_id',
title='Some Conversation',
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
selected_repository='foobar',
github_user_id='12345',
user_id='12345',
)
# Set up the mock to return metadata and then save it
mock_store.get_metadata = AsyncMock(return_value=metadata)
mock_store.save_metadata = AsyncMock()
# Return the mock store from get_instance
mock_get_instance.return_value = mock_store
# Call update_conversation
result = await update_conversation(
'some_conversation_id',
'New Title',
user_id='12345',
)
# Verify the result
assert result is True
# Verify that save_metadata was called with updated metadata
mock_store.save_metadata.assert_called_once()
saved_metadata = mock_store.save_metadata.call_args[0][0]
assert saved_metadata.title == 'New Title'
await update_conversation(
MagicMock(state=MagicMock(github_token='')),
'some_conversation_id',
'New Title',
)
conversation = await get_conversation(
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
)
expected = ConversationInfo(
conversation_id='some_conversation_id',
title='New Title',
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
status=ConversationStatus.STOPPED,
selected_repository='foobar',
)
assert conversation == expected
@pytest.mark.asyncio
async def test_delete_conversation():
with _patch_store():
# Mock the ConversationStoreImpl.get_instance
with patch(
'openhands.server.routes.manage_conversations.ConversationStoreImpl.get_instance'
) as mock_get_instance:
# Create a mock conversation store
mock_store = MagicMock()
# Set up the mock to return metadata and then delete it
mock_store.get_metadata = AsyncMock(
return_value=ConversationMetadata(
conversation_id='some_conversation_id',
title='Some Conversation',
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
selected_repository='foobar',
github_user_id='12345',
user_id='12345',
)
with patch.object(DockerRuntime, 'delete', return_value=None):
await delete_conversation(
'some_conversation_id',
MagicMock(state=MagicMock(github_token='')),
)
mock_store.delete_metadata = AsyncMock()
# Return the mock store from get_instance
mock_get_instance.return_value = mock_store
# Mock the conversation manager
with patch(
'openhands.server.routes.manage_conversations.conversation_manager'
) as mock_manager:
mock_manager.is_agent_loop_running = AsyncMock(return_value=False)
# Mock the runtime class
with patch(
'openhands.server.routes.manage_conversations.get_runtime_cls'
) as mock_get_runtime_cls:
mock_runtime_cls = MagicMock()
mock_runtime_cls.delete = AsyncMock()
mock_get_runtime_cls.return_value = mock_runtime_cls
# Call delete_conversation
result = await delete_conversation(
'some_conversation_id', user_id='12345'
)
# Verify the result
assert result is True
# Verify that delete_metadata was called
mock_store.delete_metadata.assert_called_once_with(
'some_conversation_id'
)
# Verify that runtime.delete was called
mock_runtime_cls.delete.assert_called_once_with(
'some_conversation_id'
)
conversation = await get_conversation(
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
)
assert conversation is None
+40 -57
View File
@@ -1,17 +1,11 @@
import asyncio
import os
import shutil
import subprocess
import tempfile
import unittest
import pytest
from openhands.runtime.utils.git_handler import CommandResult, GitHandler
# Mark all test methods as asyncio tests
pytestmark = pytest.mark.asyncio
class TestGitHandler(unittest.TestCase):
def setUp(self):
@@ -110,9 +104,9 @@ class TestGitHandler(unittest.TestCase):
# Push the feature branch to origin
self._execute_command('git push -u origin feature-branch', self.local_dir)
async def test_is_git_repo(self):
def test_is_git_repo(self):
"""Test that _is_git_repo returns True for a git repository."""
self.assertTrue(await self.git_handler._is_git_repo())
self.assertTrue(self.git_handler._is_git_repo())
# Verify the command was executed
self.assertTrue(
@@ -122,9 +116,9 @@ class TestGitHandler(unittest.TestCase):
)
)
async def test_get_default_branch(self):
def test_get_default_branch(self):
"""Test that _get_default_branch returns the correct branch name."""
branch = await self.git_handler._get_default_branch()
branch = self.git_handler._get_default_branch()
self.assertEqual(branch, 'main')
# Verify the command was executed
@@ -135,9 +129,9 @@ class TestGitHandler(unittest.TestCase):
)
)
async def test_get_current_branch(self):
def test_get_current_branch(self):
"""Test that _get_current_branch returns the correct branch name."""
branch = await self.git_handler._get_current_branch()
branch = self.git_handler._get_current_branch()
self.assertEqual(branch, 'feature-branch')
# Verify the command was executed
@@ -148,10 +142,10 @@ class TestGitHandler(unittest.TestCase):
)
)
async def test_get_valid_ref_with_origin_current_branch(self):
def test_get_valid_ref_with_origin_current_branch(self):
"""Test that _get_valid_ref returns the current branch in origin when it exists."""
# This test uses the setup from setUp where the current branch exists in origin
ref = await self.git_handler._get_valid_ref()
ref = self.git_handler._get_valid_ref()
self.assertIsNotNone(ref)
# Check that the refs were checked in the correct order
@@ -171,7 +165,7 @@ class TestGitHandler(unittest.TestCase):
result = self._execute_command(f'git rev-parse --verify {ref}', self.local_dir)
self.assertEqual(result.exit_code, 0)
async def test_get_valid_ref_without_origin_current_branch(self):
def test_get_valid_ref_without_origin_current_branch(self):
"""Test that _get_valid_ref falls back to default branch when current branch doesn't exist in origin."""
# Create a new branch that doesn't exist in origin
self._execute_command('git checkout -b new-local-branch', self.local_dir)
@@ -179,7 +173,7 @@ class TestGitHandler(unittest.TestCase):
# Clear the executed commands to start fresh
self.executed_commands = []
ref = await self.git_handler._get_valid_ref()
ref = self.git_handler._get_valid_ref()
self.assertIsNotNone(ref)
# Check that the refs were checked in the correct order
@@ -202,7 +196,7 @@ class TestGitHandler(unittest.TestCase):
result = self._execute_command(f'git rev-parse --verify {ref}', self.local_dir)
self.assertEqual(result.exit_code, 0)
async def test_get_valid_ref_without_origin(self):
def test_get_valid_ref_without_origin(self):
"""Test that _get_valid_ref falls back to empty tree ref when there's no origin."""
# Create a new directory with a git repo but no origin
no_origin_dir = os.path.join(self.test_dir, 'no-origin')
@@ -213,20 +207,18 @@ class TestGitHandler(unittest.TestCase):
self._execute_command("git config user.email 'test@example.com'", no_origin_dir)
self._execute_command("git config user.name 'Test User'", no_origin_dir)
# Create a file and commit it using subprocess
file_path = os.path.join(no_origin_dir, 'file1.txt')
self._execute_command(
f'echo "Content in repo without origin" > {file_path}', no_origin_dir
)
# Create a file and commit it
with open(os.path.join(no_origin_dir, 'file1.txt'), 'w') as f:
f.write('Content in repo without origin')
self._execute_command('git add file1.txt', no_origin_dir)
self._execute_command("git commit -m 'Initial commit'", no_origin_dir)
# Create a custom GitHandler with a modified _get_default_branch method for this test
class TestGitHandler(GitHandler):
async def _get_default_branch(self) -> str:
def _get_default_branch(self) -> str:
# Override to handle repos without origin
try:
return await super()._get_default_branch()
return super()._get_default_branch()
except IndexError:
return 'main' # Default fallback
@@ -237,7 +229,7 @@ class TestGitHandler(unittest.TestCase):
# Clear the executed commands to start fresh
self.executed_commands = []
ref = await no_origin_handler._get_valid_ref()
ref = no_origin_handler._get_valid_ref()
# Verify that git commands were executed
self.assertTrue(
@@ -259,9 +251,9 @@ class TestGitHandler(unittest.TestCase):
)
self.assertEqual(result.exit_code, 0)
async def test_get_ref_content(self):
def test_get_ref_content(self):
"""Test that _get_ref_content returns the content from a valid ref."""
content = await self.git_handler._get_ref_content('file1.txt')
content = self.git_handler._get_ref_content('file1.txt')
self.assertEqual(content.strip(), 'Modified content')
# Should have called _get_valid_ref and then git show
@@ -270,9 +262,9 @@ class TestGitHandler(unittest.TestCase):
]
self.assertTrue(any('file1.txt' in cmd for cmd in show_commands))
async def test_get_current_file_content(self):
def test_get_current_file_content(self):
"""Test that _get_current_file_content returns the current content of a file."""
content = await self.git_handler._get_current_file_content('file1.txt')
content = self.git_handler._get_current_file_content('file1.txt')
self.assertEqual(content.strip(), 'Modified content again')
# Verify the command was executed
@@ -280,15 +272,14 @@ class TestGitHandler(unittest.TestCase):
any(cmd == 'cat file1.txt' for cmd, _ in self.executed_commands)
)
async def test_get_changed_files(self):
def test_get_changed_files(self):
"""Test that _get_changed_files returns the list of changed files."""
# Let's create a new file to ensure it shows up in the diff
# Use subprocess directly to create and add the file
file_path = os.path.join(self.local_dir, 'new_file.txt')
self._execute_command(f'echo "New file content" > {file_path}', self.local_dir)
with open(os.path.join(self.local_dir, 'new_file.txt'), 'w') as f:
f.write('New file content')
self._execute_command('git add new_file.txt', self.local_dir)
files = await self.git_handler._get_changed_files()
files = self.git_handler._get_changed_files()
self.assertTrue(files)
# Should include file1.txt (modified) and file3.txt (deleted)
@@ -304,15 +295,13 @@ class TestGitHandler(unittest.TestCase):
]
self.assertTrue(diff_commands)
async def test_get_untracked_files(self):
def test_get_untracked_files(self):
"""Test that _get_untracked_files returns the list of untracked files."""
# Create an untracked file using subprocess
file_path = os.path.join(self.local_dir, 'untracked.txt')
self._execute_command(
f'echo "Untracked file content" > {file_path}', self.local_dir
)
# Create an untracked file
with open(os.path.join(self.local_dir, 'untracked.txt'), 'w') as f:
f.write('Untracked file content')
files = await self.git_handler._get_untracked_files()
files = self.git_handler._get_untracked_files()
self.assertEqual(len(files), 1)
self.assertEqual(files[0]['path'], 'untracked.txt')
self.assertEqual(files[0]['status'], 'A')
@@ -325,22 +314,18 @@ class TestGitHandler(unittest.TestCase):
)
)
async def test_get_git_changes(self):
def test_get_git_changes(self):
"""Test that get_git_changes returns the combined list of changed and untracked files."""
# Create an untracked file using subprocess
file_path = os.path.join(self.local_dir, 'untracked.txt')
self._execute_command(
f'echo "Untracked file content" > {file_path}', self.local_dir
)
# Create an untracked file
with open(os.path.join(self.local_dir, 'untracked.txt'), 'w') as f:
f.write('Untracked file content')
# Create a new file and stage it
file_path2 = os.path.join(self.local_dir, 'new_file2.txt')
self._execute_command(
f'echo "New file 2 content" > {file_path2}', self.local_dir
)
with open(os.path.join(self.local_dir, 'new_file2.txt'), 'w') as f:
f.write('New file 2 content')
self._execute_command('git add new_file2.txt', self.local_dir)
changes = await self.git_handler.get_git_changes()
changes = self.git_handler.get_git_changes()
self.assertIsNotNone(changes)
# Should include file1.txt (modified), file3.txt (deleted), new_file2.txt (added), and untracked.txt (untracked)
@@ -356,9 +341,9 @@ class TestGitHandler(unittest.TestCase):
self.assertIn('A', statuses) # Added
self.assertIn('D', statuses) # Deleted
async def test_get_git_diff(self):
def test_get_git_diff(self):
"""Test that get_git_diff returns the original and modified content of a file."""
diff = await self.git_handler.get_git_diff('file1.txt')
diff = self.git_handler.get_git_diff('file1.txt')
self.assertEqual(diff['modified'].strip(), 'Modified content again')
self.assertEqual(diff['original'].strip(), 'Modified content')
@@ -375,6 +360,4 @@ class TestGitHandler(unittest.TestCase):
if __name__ == '__main__':
import asyncio
asyncio.run(unittest.main())
unittest.main()
+313 -334
View File
@@ -1,8 +1,6 @@
"""Tests for the custom secrets API endpoints."""
# flake8: noqa: E501
from contextlib import contextmanager
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
@@ -12,8 +10,6 @@ from pydantic import SecretStr
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
from openhands.server.routes.settings import app as settings_app
from openhands.server.settings import Settings
from openhands.storage.memory import InMemoryFileStore
from openhands.storage.settings.file_settings_store import FileSettingsStore
@pytest.fixture
@@ -24,128 +20,127 @@ def test_client():
return TestClient(app)
@contextmanager
def patch_file_settings_store():
store = FileSettingsStore(InMemoryFileStore())
with patch(
'openhands.storage.settings.file_settings_store.FileSettingsStore.get_instance',
AsyncMock(return_value=store),
):
yield store
@pytest.fixture
def mock_settings_store():
with patch('openhands.server.routes.settings.SettingsStoreImpl') as mock:
store_instance = MagicMock()
mock.get_instance = AsyncMock(return_value=store_instance)
store_instance.load = AsyncMock()
store_instance.store = AsyncMock()
yield store_instance
@pytest.fixture
def mock_convert_to_settings():
with patch('openhands.server.routes.settings.convert_to_settings') as mock:
# Make the mock function pass through the input settings
mock.side_effect = lambda settings: settings
yield mock
@pytest.fixture
def mock_get_user_id():
with patch('openhands.server.routes.settings.get_user_id') as mock:
mock.return_value = 'test-user'
yield mock
@pytest.mark.asyncio
async def test_load_custom_secrets_names(test_client):
async def test_load_custom_secrets_names(test_client, mock_settings_store):
"""Test loading custom secrets names."""
with patch_file_settings_store() as file_settings_store:
# Create initial settings with custom secrets
custom_secrets = {
'API_KEY': SecretStr('api-key-value'),
'DB_PASSWORD': SecretStr('db-password-value'),
}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Create initial settings with custom secrets
custom_secrets = {
'API_KEY': SecretStr('api-key-value'),
'DB_PASSWORD': SecretStr('db-password-value'),
}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Store the initial settings
await file_settings_store.store(initial_settings)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# Make the GET request
response = test_client.get('/api/secrets')
assert response.status_code == 200
# Make the GET request
response = test_client.get('/api/secrets')
assert response.status_code == 200
# Check the response
data = response.json()
assert 'custom_secrets' in data
assert sorted(data['custom_secrets']) == ['API_KEY', 'DB_PASSWORD']
# Verify that the original settings were not modified
stored_settings = await file_settings_store.load()
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'api-key-value'
)
assert (
stored_settings.secrets_store.custom_secrets[
'DB_PASSWORD'
].get_secret_value()
== 'db-password-value'
)
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
# Check the response
data = response.json()
assert 'custom_secrets' in data
assert sorted(data['custom_secrets']) == ['API_KEY', 'DB_PASSWORD']
@pytest.mark.asyncio
async def test_load_custom_secrets_names_empty(test_client):
async def test_load_custom_secrets_names_empty(test_client, mock_settings_store):
"""Test loading custom secrets names when there are no custom secrets."""
with patch_file_settings_store() as file_settings_store:
# Create initial settings with no custom secrets
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(provider_tokens=provider_tokens)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Create initial settings with no custom secrets
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(provider_tokens=provider_tokens)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Store the initial settings
await file_settings_store.store(initial_settings)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# Make the GET request
response = test_client.get('/api/secrets')
assert response.status_code == 200
# Make the GET request
response = test_client.get('/api/secrets')
assert response.status_code == 200
# Check the response
data = response.json()
assert 'custom_secrets' in data
assert data['custom_secrets'] == []
# Check the response
data = response.json()
assert 'custom_secrets' in data
assert data['custom_secrets'] == []
@pytest.mark.asyncio
async def test_add_custom_secret(test_client):
async def test_add_custom_secret(
test_client, mock_settings_store, mock_convert_to_settings
):
"""Test adding a new custom secret."""
# Create initial settings with provider tokens but no custom secrets
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(provider_tokens=provider_tokens)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
with patch_file_settings_store() as file_settings_store:
# Create initial settings with provider tokens but no custom secrets
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(provider_tokens=provider_tokens)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# Store the initial settings
await file_settings_store.store(initial_settings)
# Make the POST request to add a custom secret
add_secret_data = {'custom_secrets': {'API_KEY': 'api-key-value'}}
response = test_client.post('/api/secrets', json=add_secret_data)
assert response.status_code == 200
# Make the POST request to add a custom secret
add_secret_data = {'custom_secrets': {'API_KEY': 'api-key-value'}}
response = test_client.post('/api/secrets', json=add_secret_data)
assert response.status_code == 200
# Verify that the settings were stored with the new secret
stored_settings = mock_settings_store.store.call_args[0][0]
# Verify that the settings were stored with the new secret
stored_settings = await file_settings_store.load()
# Check that the secret was added
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'api-key-value'
)
# Check that the secret was added
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'api-key-value'
)
# Check that other settings were preserved
assert stored_settings.language == 'en'
@@ -154,274 +149,258 @@ async def test_add_custom_secret(test_client):
@pytest.mark.asyncio
async def test_update_existing_custom_secret(test_client):
async def test_update_existing_custom_secret(
test_client, mock_settings_store, mock_convert_to_settings
):
"""Test updating an existing custom secret."""
with patch_file_settings_store() as file_settings_store:
# Create initial settings with a custom secret
custom_secrets = {'API_KEY': SecretStr('old-api-key')}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Create initial settings with a custom secret
custom_secrets = {'API_KEY': SecretStr('old-api-key')}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Store the initial settings
await file_settings_store.store(initial_settings)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# Make the POST request to update the custom secret
update_secret_data = {'custom_secrets': {'API_KEY': 'new-api-key'}}
response = test_client.post('/api/secrets', json=update_secret_data)
assert response.status_code == 200
# Make the POST request to update the custom secret
update_secret_data = {'custom_secrets': {'API_KEY': 'new-api-key'}}
response = test_client.post('/api/secrets', json=update_secret_data)
assert response.status_code == 200
# Verify that the settings were stored with the updated secret
stored_settings = await file_settings_store.load()
# Verify that the settings were stored with the updated secret
stored_settings: Settings = mock_settings_store.store.call_args[0][0]
# Check that the secret was updated
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'new-api-key'
)
# Check that the secret was updated
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'new-api-key'
)
# Check that other settings were preserved
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
# Check that other settings were preserved
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
@pytest.mark.asyncio
async def test_add_multiple_custom_secrets(test_client):
async def test_add_multiple_custom_secrets(
test_client, mock_settings_store, mock_convert_to_settings
):
"""Test adding multiple custom secrets at once."""
with patch_file_settings_store() as file_settings_store:
# Create initial settings with one custom secret
custom_secrets = {'EXISTING_SECRET': SecretStr('existing-value')}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
# Create initial settings with one custom secret
custom_secrets = {'EXISTING_SECRET': SecretStr('existing-value')}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# Make the POST request to add multiple custom secrets
add_secrets_data = {
'custom_secrets': {
'API_KEY': 'api-key-value',
'DB_PASSWORD': 'db-password-value',
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
}
response = test_client.post('/api/secrets', json=add_secrets_data)
assert response.status_code == 200
# Store the initial settings
await file_settings_store.store(initial_settings)
# Verify that the settings were stored with the new secrets
stored_settings = mock_settings_store.store.call_args[0][0]
# Make the POST request to add multiple custom secrets
add_secrets_data = {
'custom_secrets': {
'API_KEY': 'api-key-value',
'DB_PASSWORD': 'db-password-value',
}
}
response = test_client.post('/api/secrets', json=add_secrets_data)
assert response.status_code == 200
# Check that the new secrets were added
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'api-key-value'
)
assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['DB_PASSWORD'].get_secret_value()
== 'db-password-value'
)
# Verify that the settings were stored with the new secrets
stored_settings = await file_settings_store.load()
# Check that existing secrets were preserved
assert 'EXISTING_SECRET' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets[
'EXISTING_SECRET'
].get_secret_value()
== 'existing-value'
)
# Check that the new secrets were added
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'api-key-value'
)
assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets[
'DB_PASSWORD'
].get_secret_value()
== 'db-password-value'
)
# Check that existing secrets were preserved
assert 'EXISTING_SECRET' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets[
'EXISTING_SECRET'
].get_secret_value()
== 'existing-value'
)
# Check that other settings were preserved
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
# Check that other settings were preserved
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
@pytest.mark.asyncio
async def test_delete_custom_secret(test_client):
async def test_delete_custom_secret(
test_client, mock_settings_store, mock_convert_to_settings
):
"""Test deleting a custom secret."""
with patch_file_settings_store() as file_settings_store:
# Create initial settings with multiple custom secrets
custom_secrets = {
'API_KEY': SecretStr('api-key-value'),
'DB_PASSWORD': SecretStr('db-password-value'),
}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Create initial settings with multiple custom secrets
custom_secrets = {
'API_KEY': SecretStr('api-key-value'),
'DB_PASSWORD': SecretStr('db-password-value'),
}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Store the initial settings
await file_settings_store.store(initial_settings)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# Make the DELETE request to delete a custom secret
response = test_client.delete('/api/secrets/API_KEY')
assert response.status_code == 200
# Make the DELETE request to delete a custom secret
response = test_client.delete('/api/secrets/API_KEY')
assert response.status_code == 200
# Verify that the settings were stored without the deleted secret
stored_settings = await file_settings_store.load()
# Verify that the settings were stored without the deleted secret
stored_settings = mock_settings_store.store.call_args[0][0]
# Check that the specified secret was deleted
assert 'API_KEY' not in stored_settings.secrets_store.custom_secrets
# Check that the specified secret was deleted
assert 'API_KEY' not in stored_settings.secrets_store.custom_secrets
# Check that other secrets were preserved
assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets[
'DB_PASSWORD'
].get_secret_value()
== 'db-password-value'
)
# Check that other secrets were preserved
assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['DB_PASSWORD'].get_secret_value()
== 'db-password-value'
)
# Check that other settings were preserved
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
# Check that other settings were preserved
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
@pytest.mark.asyncio
async def test_delete_nonexistent_custom_secret(test_client):
async def test_delete_nonexistent_custom_secret(
test_client, mock_settings_store, mock_convert_to_settings
):
"""Test deleting a custom secret that doesn't exist."""
with patch_file_settings_store() as file_settings_store:
# Create initial settings with a custom secret
custom_secrets = {'API_KEY': SecretStr('api-key-value')}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Create initial settings with a custom secret
custom_secrets = {'API_KEY': SecretStr('api-key-value')}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
llm_api_key=SecretStr('test-llm-key'),
secrets_store=secret_store,
)
# Store the initial settings
await file_settings_store.store(initial_settings)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# Make the DELETE request to delete a nonexistent custom secret
response = test_client.delete('/api/secrets/NONEXISTENT_KEY')
assert response.status_code == 200
# Make the DELETE request to delete a nonexistent custom secret
response = test_client.delete('/api/secrets/NONEXISTENT_KEY')
assert response.status_code == 200
# Verify that the settings were stored without changes to existing secrets
stored_settings = await file_settings_store.load()
# Verify that the settings were stored without changes to existing secrets
stored_settings = mock_settings_store.store.call_args[0][0]
# Check that the existing secret was preserved
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'api-key-value'
)
# Check that the existing secret was preserved
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
assert (
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
== 'api-key-value'
)
# Check that other settings were preserved
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
# Check that other settings were preserved
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
@pytest.mark.asyncio
async def test_custom_secrets_operations_preserve_settings(test_client):
async def test_custom_secrets_operations_preserve_settings(
test_client, mock_settings_store, mock_convert_to_settings
):
"""Test that operations on custom secrets preserve all other settings."""
with patch_file_settings_store() as file_settings_store:
# Create initial settings with comprehensive data
custom_secrets = {'INITIAL_SECRET': SecretStr('initial-value')}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')),
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab-token')),
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
max_iterations=100,
security_analyzer='default',
confirmation_mode=True,
llm_model='test-model',
llm_api_key=SecretStr('test-llm-key'),
llm_base_url='https://test.com',
remote_runtime_resource_factor=2,
enable_default_condenser=True,
enable_sound_notifications=False,
user_consents_to_analytics=True,
secrets_store=secret_store,
)
# Create initial settings with comprehensive data
custom_secrets = {'INITIAL_SECRET': SecretStr('initial-value')}
provider_tokens = {
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')),
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab-token')),
}
secret_store = SecretStore(
custom_secrets=custom_secrets, provider_tokens=provider_tokens
)
initial_settings = Settings(
language='en',
agent='test-agent',
max_iterations=100,
security_analyzer='default',
confirmation_mode=True,
llm_model='test-model',
llm_api_key=SecretStr('test-llm-key'),
llm_base_url='https://test.com',
remote_runtime_resource_factor=2,
enable_default_condenser=True,
enable_sound_notifications=False,
user_consents_to_analytics=True,
secrets_store=secret_store,
)
# Store the initial settings
await file_settings_store.store(initial_settings)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# 1. Test adding a new custom secret
add_secret_data = {'custom_secrets': {'NEW_SECRET': 'new-value'}}
response = test_client.post('/api/secrets', json=add_secret_data)
assert response.status_code == 200
# 1. Test adding a new custom secret
add_secret_data = {'custom_secrets': {'NEW_SECRET': 'new-value'}}
response = test_client.post('/api/secrets', json=add_secret_data)
assert response.status_code == 200
# Verify all settings are preserved
stored_settings = await file_settings_store.load()
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.max_iterations == 100
assert stored_settings.security_analyzer == 'default'
assert stored_settings.confirmation_mode is True
assert stored_settings.llm_model == 'test-model'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
assert stored_settings.llm_base_url == 'https://test.com'
assert stored_settings.remote_runtime_resource_factor == 2
assert stored_settings.enable_default_condenser is True
assert stored_settings.enable_sound_notifications is False
assert stored_settings.user_consents_to_analytics is True
assert len(stored_settings.secrets_store.provider_tokens) == 2
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
assert ProviderType.GITLAB in stored_settings.secrets_store.provider_tokens
assert (
stored_settings.secrets_store.custom_secrets[
'INITIAL_SECRET'
].get_secret_value()
== 'initial-value'
)
assert (
stored_settings.secrets_store.custom_secrets[
'NEW_SECRET'
].get_secret_value()
== 'new-value'
)
# Verify all settings are preserved
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.max_iterations == 100
assert stored_settings.security_analyzer == 'default'
assert stored_settings.confirmation_mode is True
assert stored_settings.llm_model == 'test-model'
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
assert stored_settings.llm_base_url == 'https://test.com'
assert stored_settings.remote_runtime_resource_factor == 2
assert stored_settings.enable_default_condenser is True
assert stored_settings.enable_sound_notifications is False
assert stored_settings.user_consents_to_analytics is True
assert len(stored_settings.secrets_store.provider_tokens) == 2
# 2. Test updating an existing custom secret
update_secret_data = {'custom_secrets': {'INITIAL_SECRET': 'updated-value'}}
@@ -429,7 +408,7 @@ async def test_custom_secrets_operations_preserve_settings(test_client):
assert response.status_code == 200
# Verify all settings are still preserved
stored_settings = await file_settings_store.load()
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.max_iterations == 100
@@ -449,7 +428,7 @@ async def test_custom_secrets_operations_preserve_settings(test_client):
assert response.status_code == 200
# Verify all settings are still preserved
stored_settings = await file_settings_store.load()
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.language == 'en'
assert stored_settings.agent == 'test-agent'
assert stored_settings.max_iterations == 100
+226 -57
View File
@@ -1,60 +1,82 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request
from fastapi.testclient import TestClient
from pydantic import SecretStr
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.core.config.sandbox_config import SandboxConfig
from openhands.integrations.provider import ProviderType, SecretStore
from openhands.server.app import app
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.settings.settings_store import SettingsStore
class MockUserAuth(UserAuth):
"""Mock implementation of UserAuth for testing"""
def __init__(self):
self._settings = None
self._settings_store = MagicMock()
self._settings_store.load = AsyncMock(return_value=None)
self._settings_store.store = AsyncMock()
async def get_user_id(self) -> str | None:
return 'test-user'
async def get_access_token(self) -> SecretStr | None:
return SecretStr('test-token')
async def get_provider_tokens(self) -> dict[ProviderType, ProviderToken] | None: # noqa: E501
return None
async def get_user_settings_store(self) -> SettingsStore | None:
return self._settings_store
@classmethod
async def get_instance(cls, request: Request) -> UserAuth:
return MockUserAuth()
from openhands.server.settings import Settings
@pytest.fixture
def test_client():
# Create a test client
with patch(
'openhands.server.user_auth.user_auth.UserAuth.get_instance',
return_value=MockUserAuth(),
):
with patch(
'openhands.server.routes.settings.validate_provider_token',
return_value=ProviderType.GITHUB,
):
client = TestClient(app)
yield client
def mock_settings_store():
with patch('openhands.server.routes.settings.SettingsStoreImpl') as mock:
store_instance = MagicMock()
mock.get_instance = AsyncMock(return_value=store_instance)
store_instance.load = AsyncMock()
store_instance.store = AsyncMock()
yield store_instance
@pytest.fixture
def mock_get_user_id():
with patch('openhands.server.routes.settings.get_user_id') as mock:
mock.return_value = 'test-user'
yield mock
@pytest.fixture
def mock_validate_provider_token():
with patch('openhands.server.routes.settings.validate_provider_token') as mock:
async def mock_determine(*args, **kwargs):
return ProviderType.GITHUB
mock.side_effect = mock_determine
yield mock
@pytest.fixture
def test_client(mock_settings_store):
# Mock the middleware that adds github_token
class MockMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
settings = mock_settings_store.load.return_value
token = None
if settings and settings.secrets_store.provider_tokens.get(
ProviderType.GITHUB
):
token = settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token
if scope['type'] == 'http':
scope['state'] = {'token': token}
await self.app(scope, receive, send)
# Replace the middleware
app.middleware_stack = None # Clear existing middleware
app.add_middleware(MockMiddleware)
return TestClient(app)
@pytest.fixture
def mock_github_service():
with patch('openhands.server.routes.settings.GitHubService') as mock:
yield mock
@pytest.mark.asyncio
async def test_settings_api_endpoints(test_client):
"""Test that the settings API endpoints work with the new auth system"""
async def test_settings_api_runtime_factor(
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
):
# Mock the settings store to return None initially (no existing settings)
mock_settings_store.load.return_value = None
# Test data with remote_runtime_resource_factor
settings_data = {
@@ -70,29 +92,176 @@ async def test_settings_api_endpoints(test_client):
'provider_tokens': {'github': 'test-token'},
}
# The test_client fixture already handles authentication
# Make the POST request to store settings
response = test_client.post('/api/settings', json=settings_data)
# We're not checking the exact response, just that it doesn't error
assert response.status_code == 200
# Test the GET settings endpoint
# Verify the settings were stored with the correct runtime factor
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.remote_runtime_resource_factor == 2
# Mock settings store to return our settings for the GET request
mock_settings_store.load.return_value = Settings(**settings_data)
# Make a GET request to retrieve settings
response = test_client.get('/api/settings')
assert response.status_code == 200
assert response.json()['remote_runtime_resource_factor'] == 2
# Verify that the sandbox config gets updated when settings are loaded
with patch('openhands.server.shared.config') as mock_config:
mock_config.sandbox = SandboxConfig()
response = test_client.get('/api/settings')
assert response.status_code == 200
# Verify that the sandbox config was updated with the new value
mock_settings_store.store.assert_called()
stored_settings = mock_settings_store.store.call_args[0][0]
assert stored_settings.remote_runtime_resource_factor == 2
assert isinstance(stored_settings.llm_api_key, SecretStr)
assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
@pytest.mark.asyncio
async def test_settings_llm_api_key(
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
):
# Mock the settings store to return None initially (no existing settings)
mock_settings_store.load.return_value = None
# Test data with remote_runtime_resource_factor
settings_data = {
'llm_api_key': 'test-key',
'provider_tokens': {'github': 'test-token'},
}
# The test_client fixture already handles authentication
# Make the POST request to store settings
response = test_client.post('/api/settings', json=settings_data)
assert response.status_code == 200
# Verify the settings were stored with the correct secret API key
stored_settings = mock_settings_store.store.call_args[0][0]
assert isinstance(stored_settings.llm_api_key, SecretStr)
assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
# Mock settings store to return our settings for the GET request
mock_settings_store.load.return_value = Settings(**settings_data)
# Make a GET request to retrieve settings
response = test_client.get('/api/settings')
assert response.status_code == 200
# Test updating with partial settings
partial_settings = {
'language': 'fr',
'llm_model': None, # Should preserve existing value
'llm_api_key': None, # Should preserve existing value
# We should never expose the API key in the response
assert 'test-key' not in response.json()
@pytest.mark.skip(
reason='Mock middleware does not seem to properly set the github_token'
)
@pytest.mark.asyncio
async def test_settings_api_set_github_token(
mock_github_service,
test_client,
mock_settings_store,
mock_get_user_id,
mock_validate_provider_token,
):
# Test data with provider token set
settings_data = {
'language': 'en',
'agent': 'test-agent',
'max_iterations': 100,
'security_analyzer': 'default',
'confirmation_mode': True,
'llm_model': 'test-model',
'llm_api_key': 'test-key',
'llm_base_url': 'https://test.com',
'provider_tokens': {'github': 'test-token'},
}
response = test_client.post('/api/settings', json=partial_settings)
# Make the POST request to store settings
response = test_client.post('/api/settings', json=settings_data)
assert response.status_code == 200
# Test the unset-settings-tokens endpoint
response = test_client.post('/api/unset-settings-tokens')
# Verify the settings were stored with the provider token
stored_settings = mock_settings_store.store.call_args[0][0]
assert (
stored_settings.secrets_store.provider_tokens[
ProviderType.GITHUB
].token.get_secret_value()
== 'test-token'
)
# Mock settings store to return our settings for the GET request
mock_settings_store.load.return_value = Settings(**settings_data)
# Make a GET request to retrieve settings
response = test_client.get('/api/settings')
data = response.json()
assert response.status_code == 200
assert data.get('token') is None
assert data['token_is_set'] is True
@pytest.mark.asyncio
async def test_settings_preserve_llm_fields_when_none(test_client, mock_settings_store):
# Setup initial settings with LLM fields populated
initial_settings = Settings(
language='en',
agent='test-agent',
max_iterations=100,
security_analyzer='default',
confirmation_mode=True,
llm_model='existing-model',
llm_api_key=SecretStr('existing-key'),
llm_base_url='https://existing.com',
secrets_store=SecretStore(),
)
# Mock the settings store to return our initial settings
mock_settings_store.load.return_value = initial_settings
# Test data with None values for LLM fields
settings_update = {
'language': 'fr', # Change something else to verify the update happens
'llm_model': None,
'llm_api_key': None,
'llm_base_url': None,
}
# Make the POST request to update settings
response = test_client.post('/api/settings', json=settings_update)
assert response.status_code == 200
# We'll skip the secrets endpoints for now as they require more complex mocking # noqa: E501
# and they're not directly related to the authentication refactoring
# Verify that the settings were stored with preserved LLM values
stored_settings = mock_settings_store.store.call_args[0][0]
# Check that language was updated
assert stored_settings.language == 'fr'
# Check that LLM fields were preserved and not cleared
assert stored_settings.llm_model == 'existing-model'
assert isinstance(stored_settings.llm_api_key, SecretStr)
assert stored_settings.llm_api_key.get_secret_value() == 'existing-key'
assert stored_settings.llm_base_url == 'https://existing.com'
# Update the mock to return our new settings for the GET request
mock_settings_store.load.return_value = stored_settings
# Make a GET request to verify the updated settings
response = test_client.get('/api/settings')
assert response.status_code == 200
data = response.json()
# Verify fields in the response
assert data['language'] == 'fr'
assert data['llm_model'] == 'existing-model'
assert data['llm_base_url'] == 'https://existing.com'
# We expect the API key not to be included in the response
assert 'test-key' not in str(response.content)
+101 -67
View File
@@ -23,6 +23,7 @@ async def get_settings_store(request):
@pytest.mark.asyncio
async def test_check_provider_tokens_valid():
"""Test check_provider_tokens with valid tokens."""
mock_request = MagicMock()
settings = POSTSettingsModel(provider_tokens={'github': 'valid-token'})
# Mock the validate_provider_token function to return GITHUB for valid tokens
@@ -31,7 +32,7 @@ async def test_check_provider_tokens_valid():
) as mock_validate:
mock_validate.return_value = ProviderType.GITHUB
result = await check_provider_tokens(settings)
result = await check_provider_tokens(mock_request, settings)
# Should return empty string for valid token
assert result == ''
@@ -41,6 +42,7 @@ async def test_check_provider_tokens_valid():
@pytest.mark.asyncio
async def test_check_provider_tokens_invalid():
"""Test check_provider_tokens with invalid tokens."""
mock_request = MagicMock()
settings = POSTSettingsModel(provider_tokens={'github': 'invalid-token'})
# Mock the validate_provider_token function to return None for invalid tokens
@@ -49,7 +51,7 @@ async def test_check_provider_tokens_invalid():
) as mock_validate:
mock_validate.return_value = None
result = await check_provider_tokens(settings)
result = await check_provider_tokens(mock_request, settings)
# Should return error message for invalid token
assert 'Invalid token' in result
@@ -59,9 +61,10 @@ async def test_check_provider_tokens_invalid():
@pytest.mark.asyncio
async def test_check_provider_tokens_wrong_type():
"""Test check_provider_tokens with unsupported provider type."""
mock_request = MagicMock()
settings = POSTSettingsModel(provider_tokens={'unsupported': 'some-token'})
result = await check_provider_tokens(settings)
result = await check_provider_tokens(mock_request, settings)
# Should return empty string for unsupported provider
assert result == ''
@@ -70,9 +73,10 @@ async def test_check_provider_tokens_wrong_type():
@pytest.mark.asyncio
async def test_check_provider_tokens_no_tokens():
"""Test check_provider_tokens with no tokens."""
mock_request = MagicMock()
settings = POSTSettingsModel(provider_tokens={})
result = await check_provider_tokens(settings)
result = await check_provider_tokens(mock_request, settings)
# Should return empty string when no tokens provided
assert result == ''
@@ -82,6 +86,7 @@ async def test_check_provider_tokens_no_tokens():
@pytest.mark.asyncio
async def test_store_llm_settings_new_settings():
"""Test store_llm_settings with new settings."""
mock_request = MagicMock()
settings = POSTSettingsModel(
llm_model='gpt-4',
llm_api_key='test-api-key',
@@ -89,20 +94,25 @@ async def test_store_llm_settings_new_settings():
)
# Mock the settings store
mock_store = MagicMock()
mock_store.load = AsyncMock(return_value=None) # No existing settings
with patch(
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
) as mock_get_store:
mock_store = MagicMock()
mock_store.load = AsyncMock(return_value=None) # No existing settings
mock_get_store.return_value = mock_store
result = await store_llm_settings(settings, mock_store)
result = await store_llm_settings(mock_request, settings)
# Should return settings with the provided values
assert result.llm_model == 'gpt-4'
assert result.llm_api_key.get_secret_value() == 'test-api-key'
assert result.llm_base_url == 'https://api.example.com'
# Should return settings with the provided values
assert result.llm_model == 'gpt-4'
assert result.llm_api_key.get_secret_value() == 'test-api-key'
assert result.llm_base_url == 'https://api.example.com'
@pytest.mark.asyncio
async def test_store_llm_settings_update_existing():
"""Test store_llm_settings updates existing settings."""
mock_request = MagicMock()
settings = POSTSettingsModel(
llm_model='gpt-4',
llm_api_key='new-api-key',
@@ -110,118 +120,142 @@ async def test_store_llm_settings_update_existing():
)
# Mock the settings store
mock_store = MagicMock()
with patch(
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
) as mock_get_store:
mock_store = MagicMock()
# Create existing settings
existing_settings = Settings(
llm_model='gpt-3.5',
llm_api_key=SecretStr('old-api-key'),
llm_base_url='https://old.example.com',
)
# Create existing settings
existing_settings = Settings(
llm_model='gpt-3.5',
llm_api_key=SecretStr('old-api-key'),
llm_base_url='https://old.example.com',
)
mock_store.load = AsyncMock(return_value=existing_settings)
mock_store.load = AsyncMock(return_value=existing_settings)
mock_get_store.return_value = mock_store
result = await store_llm_settings(settings, mock_store)
result = await store_llm_settings(mock_request, settings)
# Should return settings with the updated values
assert result.llm_model == 'gpt-4'
assert result.llm_api_key.get_secret_value() == 'new-api-key'
assert result.llm_base_url == 'https://new.example.com'
# Should return settings with the updated values
assert result.llm_model == 'gpt-4'
assert result.llm_api_key.get_secret_value() == 'new-api-key'
assert result.llm_base_url == 'https://new.example.com'
@pytest.mark.asyncio
async def test_store_llm_settings_partial_update():
"""Test store_llm_settings with partial update."""
mock_request = MagicMock()
settings = POSTSettingsModel(
llm_model='gpt-4' # Only updating model
)
# Mock the settings store
mock_store = MagicMock()
with patch(
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
) as mock_get_store:
mock_store = MagicMock()
# Create existing settings
existing_settings = Settings(
llm_model='gpt-3.5',
llm_api_key=SecretStr('existing-api-key'),
llm_base_url='https://existing.example.com',
)
# Create existing settings
existing_settings = Settings(
llm_model='gpt-3.5',
llm_api_key=SecretStr('existing-api-key'),
llm_base_url='https://existing.example.com',
)
mock_store.load = AsyncMock(return_value=existing_settings)
mock_store.load = AsyncMock(return_value=existing_settings)
mock_get_store.return_value = mock_store
result = await store_llm_settings(settings, mock_store)
result = await store_llm_settings(mock_request, settings)
# Should return settings with updated model but keep other values
assert result.llm_model == 'gpt-4'
# For SecretStr objects, we need to compare the secret value
assert result.llm_api_key.get_secret_value() == 'existing-api-key'
assert result.llm_base_url == 'https://existing.example.com'
# Should return settings with updated model but keep other values
assert result.llm_model == 'gpt-4'
# For SecretStr objects, we need to compare the secret value
assert result.llm_api_key.get_secret_value() == 'existing-api-key'
assert result.llm_base_url == 'https://existing.example.com'
# Tests for store_provider_tokens
@pytest.mark.asyncio
async def test_store_provider_tokens_new_tokens():
"""Test store_provider_tokens with new tokens."""
mock_request = MagicMock()
settings = POSTSettingsModel(provider_tokens={'github': 'new-token'})
# Mock the settings store
mock_store = MagicMock()
mock_store.load = AsyncMock(return_value=None) # No existing settings
with patch(
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
) as mock_get_store:
mock_store = MagicMock()
mock_store.load = AsyncMock(return_value=None) # No existing settings
mock_get_store.return_value = mock_store
result = await store_provider_tokens(settings, mock_store)
result = await store_provider_tokens(mock_request, settings)
# Should return settings with the provided tokens
assert result.provider_tokens == {'github': 'new-token'}
# Should return settings with the provided tokens
assert result.provider_tokens == {'github': 'new-token'}
@pytest.mark.asyncio
async def test_store_provider_tokens_update_existing():
"""Test store_provider_tokens updates existing tokens."""
mock_request = MagicMock()
settings = POSTSettingsModel(provider_tokens={'github': 'updated-token'})
# Mock the settings store
mock_store = MagicMock()
with patch(
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
) as mock_get_store:
mock_store = MagicMock()
# Create existing settings with a GitHub token
github_token = ProviderToken(token=SecretStr('old-token'))
provider_tokens = {ProviderType.GITHUB: github_token}
# Create existing settings with a GitHub token
github_token = ProviderToken(token=SecretStr('old-token'))
provider_tokens = {ProviderType.GITHUB: github_token}
# Create a SecretStore with the provider tokens
secrets_store = SecretStore(provider_tokens=provider_tokens)
# Create a SecretStore with the provider tokens
secrets_store = SecretStore(provider_tokens=provider_tokens)
# Create existing settings with the secrets store
existing_settings = Settings(secrets_store=secrets_store)
# Create existing settings with the secrets store
existing_settings = Settings(secrets_store=secrets_store)
mock_store.load = AsyncMock(return_value=existing_settings)
mock_store.load = AsyncMock(return_value=existing_settings)
mock_get_store.return_value = mock_store
result = await store_provider_tokens(settings, mock_store)
result = await store_provider_tokens(mock_request, settings)
# Should return settings with the updated tokens
assert result.provider_tokens == {'github': 'updated-token'}
# Should return settings with the updated tokens
assert result.provider_tokens == {'github': 'updated-token'}
@pytest.mark.asyncio
async def test_store_provider_tokens_keep_existing():
"""Test store_provider_tokens keeps existing tokens when empty string provided."""
mock_request = MagicMock()
settings = POSTSettingsModel(
provider_tokens={'github': ''} # Empty string should keep existing token
)
# Mock the settings store
mock_store = MagicMock()
with patch(
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
) as mock_get_store:
mock_store = MagicMock()
# Create existing settings with a GitHub token
github_token = ProviderToken(token=SecretStr('existing-token'))
provider_tokens = {ProviderType.GITHUB: github_token}
# Create existing settings with a GitHub token
github_token = ProviderToken(token=SecretStr('existing-token'))
provider_tokens = {ProviderType.GITHUB: github_token}
# Create a SecretStore with the provider tokens
secrets_store = SecretStore(provider_tokens=provider_tokens)
# Create a SecretStore with the provider tokens
secrets_store = SecretStore(provider_tokens=provider_tokens)
# Create existing settings with the secrets store
existing_settings = Settings(secrets_store=secrets_store)
# Create existing settings with the secrets store
existing_settings = Settings(secrets_store=secrets_store)
mock_store.load = AsyncMock(return_value=existing_settings)
mock_store.load = AsyncMock(return_value=existing_settings)
mock_get_store.return_value = mock_store
result = await store_provider_tokens(settings, mock_store)
result = await store_provider_tokens(mock_request, settings)
# Should return settings with the existing token preserved
assert result.provider_tokens == {'github': 'existing-token'}
# Should return settings with the existing token preserved
assert result.provider_tokens == {'github': 'existing-token'}