mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 29ce9f9822 | |||
| b29252fbef |
@@ -61,8 +61,8 @@ RUN add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get install -y python3.12 python3.12-venv python3.12-dev python3-pip \
|
||||
&& ln -s /usr/bin/python3.12 /usr/bin/python
|
||||
|
||||
# NodeJS >= 22.x
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
|
||||
# NodeJS >= 18.17.1
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \
|
||||
&& apt-get install -y nodejs
|
||||
|
||||
# Poetry >= 1.8
|
||||
@@ -108,7 +108,7 @@ WORKDIR /app
|
||||
|
||||
# cache build dependencies
|
||||
RUN \
|
||||
--mount=type=bind,source=./,target=/app/,rw \
|
||||
--mount=type=bind,source=./,target=/app/ \
|
||||
<<EOF
|
||||
#!/bin/bash
|
||||
make -s clean
|
||||
|
||||
Vendored
+40
-27
@@ -1646,32 +1646,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/reset-settings": {
|
||||
"post": {
|
||||
"summary": "Reset settings (Deprecated)",
|
||||
"description": "This endpoint is deprecated and will return a 410 Gone error. Reset functionality has been removed.",
|
||||
"operationId": "resetSettings",
|
||||
"deprecated": true,
|
||||
"responses": {
|
||||
"410": {
|
||||
"description": "Feature removed",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"error": {
|
||||
"type": "string",
|
||||
"example": "Reset settings functionality has been removed."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/unset-settings-tokens": {
|
||||
"post": {
|
||||
"summary": "Unset settings tokens",
|
||||
@@ -1711,6 +1685,45 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/reset-settings": {
|
||||
"post": {
|
||||
"summary": "Reset settings",
|
||||
"description": "Reset user settings to defaults",
|
||||
"operationId": "resetSettings",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Settings reset successfully",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Error resetting settings",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"error": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/options/models": {
|
||||
"get": {
|
||||
"summary": "Get models",
|
||||
@@ -2082,4 +2095,4 @@
|
||||
"bearerAuth": []
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ const mock_provider_tokens_are_set: Record<Provider, boolean> = {
|
||||
describe("Settings Screen", () => {
|
||||
const getSettingsSpy = vi.spyOn(OpenHands, "getSettings");
|
||||
const saveSettingsSpy = vi.spyOn(OpenHands, "saveSettings");
|
||||
const resetSettingsSpy = vi.spyOn(OpenHands, "resetSettings");
|
||||
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
|
||||
|
||||
const { handleLogoutMock } = vi.hoisted(() => ({
|
||||
@@ -66,6 +67,7 @@ describe("Settings Screen", () => {
|
||||
// Use queryAllByText to handle multiple elements with the same text
|
||||
expect(screen.queryAllByText("SETTINGS$LLM_SETTINGS")).not.toHaveLength(0);
|
||||
screen.getByText("ACCOUNT_SETTINGS$ADDITIONAL_SETTINGS");
|
||||
screen.getByText("BUTTON$RESET_TO_DEFAULTS");
|
||||
screen.getByText("BUTTON$SAVE");
|
||||
});
|
||||
});
|
||||
@@ -540,6 +542,54 @@ describe("Settings Screen", () => {
|
||||
});
|
||||
});
|
||||
|
||||
test("resetting settings with no changes but having advanced enabled should hide the advanced items", async () => {
|
||||
const user = userEvent.setup();
|
||||
|
||||
getSettingsSpy.mockResolvedValueOnce({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
});
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
await toggleAdvancedSettings(user);
|
||||
|
||||
const resetButton = screen.getByText("BUTTON$RESET_TO_DEFAULTS");
|
||||
await user.click(resetButton);
|
||||
|
||||
// show modal
|
||||
const modal = await screen.findByTestId("reset-modal");
|
||||
expect(modal).toBeInTheDocument();
|
||||
|
||||
// Mock the settings that will be returned after reset
|
||||
// This should be the default settings with no advanced settings enabled
|
||||
getSettingsSpy.mockResolvedValueOnce({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
llm_base_url: "",
|
||||
confirmation_mode: false,
|
||||
security_analyzer: "",
|
||||
});
|
||||
|
||||
// confirm reset
|
||||
const confirmButton = within(modal).getByText("Reset");
|
||||
await user.click(confirmButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("llm-custom-model-input"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("base-url-input"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId("agent-input")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("security-analyzer-input"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("enable-confirmation-mode-switch"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should save if only confirmation mode is enabled", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderSettingsScreen();
|
||||
@@ -712,6 +762,81 @@ describe("Settings Screen", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("should reset the settings when the 'Reset to defaults' button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
const languageInput = await screen.findByTestId("language-input");
|
||||
await user.click(languageInput);
|
||||
|
||||
const norskOption = await screen.findByText("Norsk");
|
||||
await user.click(norskOption);
|
||||
|
||||
expect(languageInput).toHaveValue("Norsk");
|
||||
|
||||
const resetButton = screen.getByText("BUTTON$RESET_TO_DEFAULTS");
|
||||
await user.click(resetButton);
|
||||
|
||||
expect(saveSettingsSpy).not.toHaveBeenCalled();
|
||||
|
||||
// show modal
|
||||
const modal = await screen.findByTestId("reset-modal");
|
||||
expect(modal).toBeInTheDocument();
|
||||
|
||||
// confirm reset
|
||||
const confirmButton = within(modal).getByText("Reset");
|
||||
await user.click(confirmButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(resetSettingsSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Mock the settings response after reset
|
||||
getSettingsSpy.mockResolvedValueOnce({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
llm_base_url: "",
|
||||
confirmation_mode: false,
|
||||
security_analyzer: "",
|
||||
});
|
||||
|
||||
// Wait for the mutation to complete and the modal to be removed
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("llm-custom-model-input"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId("base-url-input")).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId("agent-input")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("security-analyzer-input"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("enable-confirmation-mode-switch"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should cancel the reset when the 'Cancel' button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
const resetButton = await screen.findByText("BUTTON$RESET_TO_DEFAULTS");
|
||||
await user.click(resetButton);
|
||||
|
||||
const modal = await screen.findByTestId("reset-modal");
|
||||
expect(modal).toBeInTheDocument();
|
||||
|
||||
const cancelButton = within(modal).getByText("Cancel");
|
||||
await user.click(cancelButton);
|
||||
|
||||
expect(saveSettingsSpy).not.toHaveBeenCalled();
|
||||
expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call handleCaptureConsent with true if the save is successful", async () => {
|
||||
const user = userEvent.setup();
|
||||
const handleCaptureConsentSpy = vi.spyOn(
|
||||
@@ -919,5 +1044,18 @@ describe("Settings Screen", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("should not submit the unwanted fields when resetting", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderSettingsScreen();
|
||||
|
||||
const resetButton = await screen.findByText("BUTTON$RESET_TO_DEFAULTS");
|
||||
await user.click(resetButton);
|
||||
|
||||
const modal = await screen.findByTestId("reset-modal");
|
||||
const confirmButton = within(modal).getByText("Reset");
|
||||
await user.click(confirmButton);
|
||||
expect(saveSettingsSpy).not.toHaveBeenCalled();
|
||||
expect(resetSettingsSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -199,6 +199,14 @@ class OpenHands {
|
||||
return data.status === 200;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset user settings in server
|
||||
*/
|
||||
static async resetSettings(): Promise<boolean> {
|
||||
const response = await openHands.post("/api/reset-settings");
|
||||
return response.status === 200;
|
||||
}
|
||||
|
||||
static async createCheckoutSession(amount: number): Promise<string> {
|
||||
const { data } = await openHands.post(
|
||||
"/api/billing/create-checkout-session",
|
||||
|
||||
@@ -4,7 +4,15 @@ import OpenHands from "#/api/open-hands";
|
||||
import { PostSettings, PostApiSettings } from "#/types/settings";
|
||||
import { useSettings } from "../query/use-settings";
|
||||
|
||||
const saveSettingsMutationFn = async (settings: Partial<PostSettings>) => {
|
||||
const saveSettingsMutationFn = async (
|
||||
settings: Partial<PostSettings> | null,
|
||||
) => {
|
||||
// If settings is null, we're resetting
|
||||
if (settings === null) {
|
||||
await OpenHands.resetSettings();
|
||||
return;
|
||||
}
|
||||
|
||||
const apiSettings: Partial<PostApiSettings> = {
|
||||
llm_model: settings.LLM_MODEL,
|
||||
llm_base_url: settings.LLM_BASE_URL,
|
||||
@@ -31,7 +39,12 @@ export const useSaveSettings = () => {
|
||||
const { data: currentSettings } = useSettings();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async (settings: Partial<PostSettings>) => {
|
||||
mutationFn: async (settings: Partial<PostSettings> | null) => {
|
||||
if (settings === null) {
|
||||
await saveSettingsMutationFn(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const newSettings = { ...currentSettings, ...settings };
|
||||
await saveSettingsMutationFn(newSettings);
|
||||
},
|
||||
|
||||
@@ -11,8 +11,14 @@ export const useUpdateConversation = () => {
|
||||
conversation: Partial<Omit<Conversation, "id">>;
|
||||
}) =>
|
||||
OpenHands.updateUserConversation(variables.id, variables.conversation),
|
||||
onSuccess: () => {
|
||||
onSuccess: (_, variables) => {
|
||||
// Invalidate the conversations list
|
||||
queryClient.invalidateQueries({ queryKey: ["user", "conversations"] });
|
||||
|
||||
// Also invalidate the specific conversation to ensure title updates are reflected
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["user", "conversation", variables.id],
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -7,6 +7,6 @@ export const useUserConversation = (cid: string | null) =>
|
||||
queryFn: () => OpenHands.getConversation(cid!),
|
||||
enabled: !!cid,
|
||||
retry: false,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
staleTime: 1000 * 30, // 30 seconds - reduced from 5 minutes for more responsive title updates
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes
|
||||
});
|
||||
|
||||
@@ -48,6 +48,7 @@ export function useAutoTitle() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Request title generation from the backend
|
||||
updateConversation(
|
||||
{
|
||||
id: conversationId,
|
||||
@@ -56,13 +57,30 @@ export function useAutoTitle() {
|
||||
{
|
||||
onSuccess: async () => {
|
||||
try {
|
||||
const updatedConversation =
|
||||
await OpenHands.getConversation(conversationId);
|
||||
// Add a small delay to allow the backend to generate the title
|
||||
// This helps avoid race conditions where we fetch before the title is ready
|
||||
setTimeout(async () => {
|
||||
try {
|
||||
const updatedConversation =
|
||||
await OpenHands.getConversation(conversationId);
|
||||
|
||||
queryClient.setQueryData(
|
||||
["user", "conversation", conversationId],
|
||||
updatedConversation,
|
||||
);
|
||||
if (updatedConversation && updatedConversation.title) {
|
||||
queryClient.setQueryData(
|
||||
["user", "conversation", conversationId],
|
||||
updatedConversation,
|
||||
);
|
||||
} else {
|
||||
// If we still don't have a title, invalidate the query to trigger a refetch
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["user", "conversation", conversationId],
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["user", "conversation", conversationId],
|
||||
});
|
||||
}
|
||||
}, 1000); // 1 second delay
|
||||
} catch (error) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["user", "conversation", conversationId],
|
||||
|
||||
@@ -82,6 +82,7 @@ export enum I18nKey {
|
||||
API$DONT_KNOW_KEY = "API$DONT_KNOW_KEY",
|
||||
BUTTON$SAVE = "BUTTON$SAVE",
|
||||
BUTTON$CLOSE = "BUTTON$CLOSE",
|
||||
BUTTON$RESET_TO_DEFAULTS = "BUTTON$RESET_TO_DEFAULTS",
|
||||
MODAL$CONFIRM_RESET_TITLE = "MODAL$CONFIRM_RESET_TITLE",
|
||||
MODAL$CONFIRM_RESET_MESSAGE = "MODAL$CONFIRM_RESET_MESSAGE",
|
||||
MODAL$END_SESSION_TITLE = "MODAL$END_SESSION_TITLE",
|
||||
@@ -320,6 +321,7 @@ export enum I18nKey {
|
||||
SETTINGS_FORM$ENABLE_CONFIRMATION_MODE_LABEL = "SETTINGS_FORM$ENABLE_CONFIRMATION_MODE_LABEL",
|
||||
SETTINGS_FORM$SAVE_LABEL = "SETTINGS_FORM$SAVE_LABEL",
|
||||
SETTINGS_FORM$CLOSE_LABEL = "SETTINGS_FORM$CLOSE_LABEL",
|
||||
SETTINGS_FORM$RESET_TO_DEFAULTS_LABEL = "SETTINGS_FORM$RESET_TO_DEFAULTS_LABEL",
|
||||
SETTINGS_FORM$CANCEL_LABEL = "SETTINGS_FORM$CANCEL_LABEL",
|
||||
SETTINGS_FORM$END_SESSION_LABEL = "SETTINGS_FORM$END_SESSION_LABEL",
|
||||
SETTINGS_FORM$CHANGING_WORKSPACE_WARNING_MESSAGE = "SETTINGS_FORM$CHANGING_WORKSPACE_WARNING_MESSAGE",
|
||||
@@ -338,7 +340,6 @@ export enum I18nKey {
|
||||
STATUS$LLM_RETRY = "STATUS$LLM_RETRY",
|
||||
AGENT_ERROR$BAD_ACTION = "AGENT_ERROR$BAD_ACTION",
|
||||
AGENT_ERROR$ACTION_TIMEOUT = "AGENT_ERROR$ACTION_TIMEOUT",
|
||||
AGENT_ERROR$TOO_MANY_CONVERSATIONS = "AGENT_ERROR$TOO_MANY_CONVERSATIONS",
|
||||
PROJECT_MENU_CARD_CONTEXT_MENU$CONNECT_TO_GITHUB_LABEL = "PROJECT_MENU_CARD_CONTEXT_MENU$CONNECT_TO_GITHUB_LABEL",
|
||||
PROJECT_MENU_CARD_CONTEXT_MENU$PUSH_TO_GITHUB_LABEL = "PROJECT_MENU_CARD_CONTEXT_MENU$PUSH_TO_GITHUB_LABEL",
|
||||
PROJECT_MENU_CARD_CONTEXT_MENU$DOWNLOAD_FILES_LABEL = "PROJECT_MENU_CARD_CONTEXT_MENU$DOWNLOAD_FILES_LABEL",
|
||||
|
||||
@@ -1231,6 +1231,21 @@
|
||||
"tr": "Kapat",
|
||||
"de": "Schließen"
|
||||
},
|
||||
"BUTTON$RESET_TO_DEFAULTS": {
|
||||
"en": "Reset to defaults",
|
||||
"ja": "デフォルトにリセット",
|
||||
"zh-CN": "重置为默认值",
|
||||
"zh-TW": "還原為預設值",
|
||||
"ko-KR": "기본값으로 재설정",
|
||||
"no": "Tilbakestill til standard",
|
||||
"it": "Ripristina valori predefiniti",
|
||||
"pt": "Restaurar padrões",
|
||||
"es": "Restablecer valores predeterminados",
|
||||
"ar": "إعادة التعيين إلى الإعدادات الافتراضية",
|
||||
"fr": "Réinitialiser aux valeurs par défaut",
|
||||
"tr": "Varsayılanlara sıfırla",
|
||||
"de": "Auf Standardwerte zurücksetzen"
|
||||
},
|
||||
"MODAL$CONFIRM_RESET_TITLE": {
|
||||
"en": "Are you sure?",
|
||||
"ja": "本当によろしいですか?",
|
||||
@@ -4526,6 +4541,21 @@
|
||||
"pt": "Fechar",
|
||||
"tr": "Kapat"
|
||||
},
|
||||
"SETTINGS_FORM$RESET_TO_DEFAULTS_LABEL": {
|
||||
"en": "Reset to defaults",
|
||||
"es": "Reiniciar valores por defect",
|
||||
"zh-CN": "重置为默认值",
|
||||
"zh-TW": "還原為預設值",
|
||||
"ko-KR": "기본값으로 재설정",
|
||||
"ja": "デフォルトに戻す",
|
||||
"no": "Tilbakestill til standardverdier",
|
||||
"ar": "إعادة التعيين إلى الإعدادات الافتراضية",
|
||||
"de": "Auf Standardwerte zurücksetzen",
|
||||
"fr": "Réinitialiser aux valeurs par défaut",
|
||||
"it": "Ripristina valori predefiniti",
|
||||
"pt": "Restaurar padrões",
|
||||
"tr": "Varsayılanlara sıfırla"
|
||||
},
|
||||
"SETTINGS_FORM$CANCEL_LABEL": {
|
||||
"en": "Cancel",
|
||||
"es": "Cancelar",
|
||||
@@ -4796,9 +4826,6 @@
|
||||
"es": "La acción expiró",
|
||||
"tr": "İşlem zaman aşımına uğradı"
|
||||
},
|
||||
"AGENT_ERROR$TOO_MANY_CONVERSATIONS": {
|
||||
"en": "Too many conversations at once."
|
||||
},
|
||||
"PROJECT_MENU_CARD_CONTEXT_MENU$CONNECT_TO_GITHUB_LABEL": {
|
||||
"en": "Connect to GitHub",
|
||||
"es": "Conectar a GitHub",
|
||||
|
||||
@@ -9,6 +9,7 @@ import { SettingsDropdownInput } from "#/components/features/settings/settings-d
|
||||
import { SettingsInput } from "#/components/features/settings/settings-input";
|
||||
import { SettingsSwitch } from "#/components/features/settings/settings-switch";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
import { ModelSelector } from "#/components/shared/modals/settings/model-selector";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
import { useAIConfigOptions } from "#/hooks/query/use-ai-config-options";
|
||||
@@ -94,6 +95,8 @@ function AccountSettings() {
|
||||
>(isAdvancedSettingsSet ? "advanced" : "basic");
|
||||
const [confirmationModeIsEnabled, setConfirmationModeIsEnabled] =
|
||||
React.useState(!!settings?.SECURITY_ANALYZER);
|
||||
const [resetSettingsModalIsOpen, setResetSettingsModalIsOpen] =
|
||||
React.useState(false);
|
||||
|
||||
const formRef = React.useRef<HTMLFormElement>(null);
|
||||
|
||||
@@ -177,6 +180,16 @@ function AccountSettings() {
|
||||
});
|
||||
};
|
||||
|
||||
const handleReset = () => {
|
||||
saveSettings(null, {
|
||||
onSuccess: () => {
|
||||
displaySuccessToast(t(I18nKey.SETTINGS$RESET));
|
||||
setResetSettingsModalIsOpen(false);
|
||||
setLlmConfigMode("basic");
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
// If settings is still loading by the time the state is set, it will always
|
||||
// default to basic settings. This is a workaround to ensure the correct
|
||||
@@ -514,6 +527,13 @@ function AccountSettings() {
|
||||
</form>
|
||||
|
||||
<footer className="flex gap-6 p-6 justify-end border-t border-t-tertiary">
|
||||
<BrandButton
|
||||
type="button"
|
||||
variant="secondary"
|
||||
onClick={() => setResetSettingsModalIsOpen(true)}
|
||||
>
|
||||
{t(I18nKey.BUTTON$RESET_TO_DEFAULTS)}
|
||||
</BrandButton>
|
||||
<BrandButton
|
||||
type="button"
|
||||
variant="primary"
|
||||
@@ -524,6 +544,40 @@ function AccountSettings() {
|
||||
{t(I18nKey.BUTTON$SAVE)}
|
||||
</BrandButton>
|
||||
</footer>
|
||||
|
||||
{resetSettingsModalIsOpen && (
|
||||
<ModalBackdrop>
|
||||
<div
|
||||
data-testid="reset-modal"
|
||||
className="bg-base-secondary p-4 rounded-xl flex flex-col gap-4 border border-tertiary"
|
||||
>
|
||||
<p>{t(I18nKey.SETTINGS$RESET_CONFIRMATION)}</p>
|
||||
<div className="w-full flex gap-2">
|
||||
<BrandButton
|
||||
type="button"
|
||||
variant="primary"
|
||||
className="grow"
|
||||
onClick={() => {
|
||||
handleReset();
|
||||
}}
|
||||
>
|
||||
Reset
|
||||
</BrandButton>
|
||||
|
||||
<BrandButton
|
||||
type="button"
|
||||
variant="secondary"
|
||||
className="grow"
|
||||
onClick={() => {
|
||||
setResetSettingsModalIsOpen(false);
|
||||
}}
|
||||
>
|
||||
Cancel
|
||||
</BrandButton>
|
||||
</div>
|
||||
</div>
|
||||
</ModalBackdrop>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -135,6 +135,7 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
|
||||
for k, v in props.items()
|
||||
}
|
||||
logger.debug(f'extras data in event_to_dict: {d["extras"]}')
|
||||
# Include success field for CmdOutputObservation
|
||||
if hasattr(event, 'success'):
|
||||
d['success'] = event.success
|
||||
|
||||
@@ -18,8 +18,6 @@ from openhands.integrations.service_types import (
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
||||
class GitHubService(BaseGitService, GitService):
|
||||
@@ -159,13 +157,6 @@ class GitHubService(BaseGitService, GitService):
|
||||
|
||||
return repos[:max_repos] # Trim to max_repos if needed
|
||||
|
||||
|
||||
def parse_pushed_at_date(self, repo):
|
||||
ts = repo.get("pushed_at")
|
||||
return datetime.strptime(ts, "%Y-%m-%dT%H:%M:%SZ") if ts else datetime.min
|
||||
|
||||
|
||||
|
||||
async def get_repositories(self, sort: str, app_mode: AppMode) -> list[Repository]:
|
||||
MAX_REPOS = 1000
|
||||
PER_PAGE = 100 # Maximum allowed by GitHub API
|
||||
@@ -192,11 +183,6 @@ class GitHubService(BaseGitService, GitService):
|
||||
# If we've already reached MAX_REPOS, no need to check other installations
|
||||
if len(all_repos) >= MAX_REPOS:
|
||||
break
|
||||
|
||||
if sort == "pushed":
|
||||
all_repos.sort(
|
||||
key=self.parse_pushed_at_date, reverse=True
|
||||
)
|
||||
else:
|
||||
# Original behavior for non-SaaS mode
|
||||
params = {'per_page': str(PER_PAGE), 'sort': sort}
|
||||
@@ -205,7 +191,6 @@ class GitHubService(BaseGitService, GitService):
|
||||
# Fetch user repositories
|
||||
all_repos = await self._fetch_paginated_repos(url, params, MAX_REPOS)
|
||||
|
||||
|
||||
# Convert to Repository objects
|
||||
return [
|
||||
Repository(
|
||||
|
||||
@@ -111,7 +111,7 @@ class BaseGitService(ABC):
|
||||
return UnknownException('Unknown error')
|
||||
|
||||
def handle_http_error(self, e: HTTPError) -> UnknownException:
|
||||
logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
|
||||
logger.warning(f'HTTP error on {self.provider} API: {e}')
|
||||
return UnknownException('Unknown error')
|
||||
|
||||
|
||||
|
||||
@@ -713,7 +713,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
completion_response=response, **extra_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f'Error getting cost from litellm: {e}')
|
||||
logger.error(f'Error getting cost from litellm: {e}')
|
||||
|
||||
if cost is None:
|
||||
_model_name = '/'.join(self.config.model.split('/')[1:])
|
||||
|
||||
@@ -183,7 +183,7 @@ python -m openhands.resolver.send_pull_request --issue-number ISSUE_NUMBER --use
|
||||
|
||||
## Providing Custom Instructions
|
||||
|
||||
You can customize how the AI agent approaches issue resolution by adding a repository microagent file at `.openhands/microagents/repo.md` in your repository. This file's contents will be automatically loaded in the prompt when working with your repository. For more information about repository microagents, see [Repository Instructions](https://github.com/All-Hands-AI/OpenHands/tree/main/microagents#2-repository-instructions-private).
|
||||
You can customize how the AI agent approaches issue resolution by adding a `.openhands_instructions` file to the root of your repository. If present, this file's contents will be injected into the prompt for openhands edits.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
|
||||
@@ -98,7 +98,6 @@ class Runtime(FileEditRuntimeMixin):
|
||||
initial_env_vars: dict[str, str]
|
||||
attach_to_existing: bool
|
||||
status_callback: Callable | None
|
||||
git_dir: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -113,11 +112,9 @@ class Runtime(FileEditRuntimeMixin):
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
):
|
||||
# GitHandler will be initialized with an async function
|
||||
self.git_handler = GitHandler(
|
||||
execute_shell_fn=self._execute_shell_fn_git_handler
|
||||
)
|
||||
self.git_dir = None
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
@@ -319,9 +316,6 @@ class Runtime(FileEditRuntimeMixin):
|
||||
selected_branch: str | None,
|
||||
repository_provider: ProviderType = ProviderType.GITHUB,
|
||||
) -> str:
|
||||
# Set the git_dir to the workspace mount path by default
|
||||
self.git_dir = self.config.workspace_mount_path_in_sandbox
|
||||
|
||||
if not selected_repository:
|
||||
# In SaaS mode (indicated by user_id being set), always run git init
|
||||
# In OSS mode, only run git init if workspace_base is not set
|
||||
@@ -333,7 +327,6 @@ class Runtime(FileEditRuntimeMixin):
|
||||
command='git init',
|
||||
)
|
||||
self.run_action(action)
|
||||
# git_dir is already set to workspace mount path
|
||||
else:
|
||||
logger.info(
|
||||
'In workspace mount mode, not initializing a new git repository.'
|
||||
@@ -402,13 +395,6 @@ class Runtime(FileEditRuntimeMixin):
|
||||
)
|
||||
self.log('info', f'Cloning repo: {selected_repository}')
|
||||
self.run_action(action)
|
||||
|
||||
# Update git_dir to point to the cloned repository directory
|
||||
self.git_dir = os.path.join(
|
||||
self.config.workspace_mount_path_in_sandbox, dir_name
|
||||
)
|
||||
self.git_handler.set_cwd(self.git_dir)
|
||||
|
||||
return dir_name
|
||||
|
||||
def maybe_run_setup_script(self):
|
||||
@@ -626,15 +612,13 @@ class Runtime(FileEditRuntimeMixin):
|
||||
# Git
|
||||
# ====================================================================
|
||||
|
||||
async def _execute_shell_fn_git_handler(
|
||||
def _execute_shell_fn_git_handler(
|
||||
self, command: str, cwd: str | None
|
||||
) -> CommandResult:
|
||||
"""
|
||||
This function is used by the GitHandler to execute shell commands.
|
||||
"""
|
||||
obs = await call_sync_from_async(
|
||||
self.run, CmdRunAction(command=command, is_static=True, cwd=cwd)
|
||||
)
|
||||
obs = self.run(CmdRunAction(command=command, is_static=True, cwd=cwd))
|
||||
exit_code = 0
|
||||
content = ''
|
||||
|
||||
@@ -645,15 +629,13 @@ class Runtime(FileEditRuntimeMixin):
|
||||
|
||||
return CommandResult(content=content, exit_code=exit_code)
|
||||
|
||||
async def get_git_changes(self) -> list[dict[str, str]] | None:
|
||||
if self.git_dir:
|
||||
self.git_handler.set_cwd(self.git_dir)
|
||||
return await call_sync_from_async(self.git_handler.get_git_changes)
|
||||
def get_git_changes(self, cwd: str) -> list[dict[str, str]] | None:
|
||||
self.git_handler.set_cwd(cwd)
|
||||
return self.git_handler.get_git_changes()
|
||||
|
||||
async def get_git_diff(self, file_path: str) -> dict[str, str]:
|
||||
if self.git_dir:
|
||||
self.git_handler.set_cwd(self.git_dir)
|
||||
return await call_sync_from_async(self.git_handler.get_git_diff, file_path)
|
||||
def get_git_diff(self, file_path: str, cwd: str) -> dict[str, str]:
|
||||
self.git_handler.set_cwd(cwd)
|
||||
return self.git_handler.get_git_diff(file_path)
|
||||
|
||||
@property
|
||||
def additional_agent_instructions(self) -> str:
|
||||
|
||||
@@ -215,13 +215,13 @@ class BashSession:
|
||||
self.session.set_option('history-limit', str(self.HISTORY_LIMIT), _global=True)
|
||||
self.session.history_limit = self.HISTORY_LIMIT
|
||||
# We need to create a new pane because the initial pane's history limit is (default) 2000
|
||||
_initial_window = self.session.active_window
|
||||
_initial_window = self.session.attached_window
|
||||
self.window = self.session.new_window(
|
||||
window_name='bash',
|
||||
window_shell=window_command,
|
||||
start_directory=self.work_dir, # This parameter is supported by libtmux
|
||||
)
|
||||
self.pane = self.window.active_pane
|
||||
self.pane = self.window.attached_pane
|
||||
logger.debug(f'pane: {self.pane}; history_limit: {self.session.history_limit}')
|
||||
_initial_window.kill_window()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable
|
||||
from typing import Callable
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,7 +23,7 @@ class GitHandler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execute_shell_fn: Callable[[str, str | None], Awaitable[CommandResult]],
|
||||
execute_shell_fn: Callable[[str, str | None], CommandResult],
|
||||
):
|
||||
self.execute = execute_shell_fn
|
||||
self.cwd: str | None = None
|
||||
@@ -37,11 +37,7 @@ class GitHandler:
|
||||
"""
|
||||
self.cwd = cwd
|
||||
|
||||
async def _execute_async(self, cmd: str, cwd: str | None) -> CommandResult:
|
||||
"""Execute the command asynchronously."""
|
||||
return await self.execute(cmd, cwd)
|
||||
|
||||
async def _is_git_repo(self) -> bool:
|
||||
def _is_git_repo(self) -> bool:
|
||||
"""
|
||||
Checks if the current directory is a Git repository.
|
||||
|
||||
@@ -49,10 +45,10 @@ class GitHandler:
|
||||
bool: True if inside a Git repository, otherwise False.
|
||||
"""
|
||||
cmd = 'git rev-parse --is-inside-work-tree'
|
||||
output = await self._execute_async(cmd, self.cwd)
|
||||
output = self.execute(cmd, self.cwd)
|
||||
return output.content.strip() == 'true'
|
||||
|
||||
async def _get_current_file_content(self, file_path: str) -> str:
|
||||
def _get_current_file_content(self, file_path: str) -> str:
|
||||
"""
|
||||
Retrieves the current content of a given file.
|
||||
|
||||
@@ -62,10 +58,10 @@ class GitHandler:
|
||||
Returns:
|
||||
str: The file content.
|
||||
"""
|
||||
output = await self._execute_async(f'cat {file_path}', self.cwd)
|
||||
output = self.execute(f'cat {file_path}', self.cwd)
|
||||
return output.content
|
||||
|
||||
async def _verify_ref_exists(self, ref: str) -> bool:
|
||||
def _verify_ref_exists(self, ref: str) -> bool:
|
||||
"""
|
||||
Verifies whether a specific Git reference exists.
|
||||
|
||||
@@ -76,18 +72,18 @@ class GitHandler:
|
||||
bool: True if the reference exists, otherwise False.
|
||||
"""
|
||||
cmd = f'git rev-parse --verify {ref}'
|
||||
output = await self._execute_async(cmd, self.cwd)
|
||||
output = self.execute(cmd, self.cwd)
|
||||
return output.exit_code == 0
|
||||
|
||||
async def _get_valid_ref(self) -> str | None:
|
||||
def _get_valid_ref(self) -> str | None:
|
||||
"""
|
||||
Determines a valid Git reference for comparison.
|
||||
|
||||
Returns:
|
||||
str | None: A valid Git reference or None if no valid reference is found.
|
||||
"""
|
||||
current_branch = await self._get_current_branch()
|
||||
default_branch = await self._get_default_branch()
|
||||
current_branch = self._get_current_branch()
|
||||
default_branch = self._get_default_branch()
|
||||
|
||||
ref_current_branch = f'origin/{current_branch}'
|
||||
ref_non_default_branch = f'$(git merge-base HEAD "$(git rev-parse --abbrev-ref origin/{default_branch})")'
|
||||
@@ -101,12 +97,12 @@ class GitHandler:
|
||||
ref_new_repo,
|
||||
]
|
||||
for ref in refs:
|
||||
if await self._verify_ref_exists(ref):
|
||||
if self._verify_ref_exists(ref):
|
||||
return ref
|
||||
|
||||
return None
|
||||
|
||||
async def _get_ref_content(self, file_path: str) -> str:
|
||||
def _get_ref_content(self, file_path: str) -> str:
|
||||
"""
|
||||
Retrieves the content of a file from a valid Git reference.
|
||||
|
||||
@@ -116,15 +112,15 @@ class GitHandler:
|
||||
Returns:
|
||||
str: The content of the file from the reference, or an empty string if unavailable.
|
||||
"""
|
||||
ref = await self._get_valid_ref()
|
||||
ref = self._get_valid_ref()
|
||||
if not ref:
|
||||
return ''
|
||||
|
||||
cmd = f'git show {ref}:{file_path}'
|
||||
output = await self._execute_async(cmd, self.cwd)
|
||||
output = self.execute(cmd, self.cwd)
|
||||
return output.content if output.exit_code == 0 else ''
|
||||
|
||||
async def _get_default_branch(self) -> str:
|
||||
def _get_default_branch(self) -> str:
|
||||
"""
|
||||
Retrieves the primary Git branch name of the repository.
|
||||
|
||||
@@ -132,10 +128,10 @@ class GitHandler:
|
||||
str: The name of the primary branch.
|
||||
"""
|
||||
cmd = 'git remote show origin | grep "HEAD branch"'
|
||||
output = await self._execute_async(cmd, self.cwd)
|
||||
output = self.execute(cmd, self.cwd)
|
||||
return output.content.split()[-1].strip()
|
||||
|
||||
async def _get_current_branch(self) -> str:
|
||||
def _get_current_branch(self) -> str:
|
||||
"""
|
||||
Retrieves the currently selected Git branch.
|
||||
|
||||
@@ -143,25 +139,25 @@ class GitHandler:
|
||||
str: The name of the current branch.
|
||||
"""
|
||||
cmd = 'git rev-parse --abbrev-ref HEAD'
|
||||
output = await self._execute_async(cmd, self.cwd)
|
||||
output = self.execute(cmd, self.cwd)
|
||||
return output.content.strip()
|
||||
|
||||
async def _get_changed_files(self) -> list[str]:
|
||||
def _get_changed_files(self) -> list[str]:
|
||||
"""
|
||||
Retrieves a list of changed files compared to a valid Git reference.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of changed file paths.
|
||||
"""
|
||||
ref = await self._get_valid_ref()
|
||||
ref = self._get_valid_ref()
|
||||
if not ref:
|
||||
return []
|
||||
|
||||
diff_cmd = f'git diff --name-status {ref}'
|
||||
output = await self._execute_async(diff_cmd, self.cwd)
|
||||
output = self.execute(diff_cmd, self.cwd)
|
||||
return output.content.splitlines()
|
||||
|
||||
async def _get_untracked_files(self) -> list[dict[str, str]]:
|
||||
def _get_untracked_files(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Retrieves a list of untracked files in the repository. This is useful for detecting new files.
|
||||
|
||||
@@ -169,7 +165,7 @@ class GitHandler:
|
||||
list[dict[str, str]]: A list of dictionaries containing file paths and statuses.
|
||||
"""
|
||||
cmd = 'git ls-files --others --exclude-standard'
|
||||
output = await self._execute_async(cmd, self.cwd)
|
||||
output = self.execute(cmd, self.cwd)
|
||||
obs_list = output.content.splitlines()
|
||||
return (
|
||||
[{'status': 'A', 'path': path} for path in obs_list]
|
||||
@@ -177,24 +173,24 @@ class GitHandler:
|
||||
else []
|
||||
)
|
||||
|
||||
async def get_git_changes(self) -> list[dict[str, str]] | None:
|
||||
def get_git_changes(self) -> list[dict[str, str]] | None:
|
||||
"""
|
||||
Retrieves the list of changed files in the Git repository.
|
||||
|
||||
Returns:
|
||||
list[dict[str, str]] | None: A list of dictionaries containing file paths and statuses. None if not a git repository.
|
||||
"""
|
||||
if not await self._is_git_repo():
|
||||
if not self._is_git_repo():
|
||||
return None
|
||||
|
||||
changes_list = await self._get_changed_files()
|
||||
changes_list = self._get_changed_files()
|
||||
result = parse_git_changes(changes_list)
|
||||
|
||||
# join with any untracked files
|
||||
result += await self._get_untracked_files()
|
||||
result += self._get_untracked_files()
|
||||
return result
|
||||
|
||||
async def get_git_diff(self, file_path: str) -> dict[str, str]:
|
||||
def get_git_diff(self, file_path: str) -> dict[str, str]:
|
||||
"""
|
||||
Retrieves the original and modified content of a file in the repository.
|
||||
|
||||
@@ -204,8 +200,8 @@ class GitHandler:
|
||||
Returns:
|
||||
dict[str, str]: A dictionary containing the original and modified content.
|
||||
"""
|
||||
modified = await self._get_current_file_content(file_path)
|
||||
original = await self._get_ref_content(file_path)
|
||||
modified = self._get_current_file_content(file_path)
|
||||
original = self._get_ref_content(file_path)
|
||||
|
||||
return {
|
||||
'modified': modified,
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
|
||||
|
||||
def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
|
||||
"""Get GitHub token from request state. For backward compatibility."""
|
||||
return getattr(request.state, 'provider_tokens', None)
|
||||
|
||||
|
||||
def get_access_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'access_token', None)
|
||||
|
||||
|
||||
def get_user_id(request: Request) -> str | None:
|
||||
return getattr(request.state, 'user_id', None)
|
||||
|
||||
|
||||
def get_github_token(request: Request) -> SecretStr | None:
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_github_user_id(request: Request) -> str | None:
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].user_id
|
||||
|
||||
return None
|
||||
@@ -20,7 +20,6 @@ class ServerConfig(ServerConfigInterface):
|
||||
)
|
||||
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
|
||||
monitoring_listener_class: str = 'openhands.server.monitoring.MonitoringListener'
|
||||
user_auth_class: str = 'openhands.server.user_auth.default_user_auth.DefaultUserAuth'
|
||||
|
||||
def verify_config(self):
|
||||
if self.config_cls:
|
||||
|
||||
@@ -202,7 +202,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
ConversationStore, # type: ignore
|
||||
self.server_config.conversation_store_class,
|
||||
)
|
||||
store = await conversation_store_class.get_instance(self.config, user_id)
|
||||
store = await conversation_store_class.get_instance(
|
||||
self.config, user_id, github_user_id
|
||||
)
|
||||
return store
|
||||
|
||||
async def get_running_agent_loops(
|
||||
@@ -283,7 +285,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
response_ids = await self.get_running_agent_loops(user_id)
|
||||
if len(response_ids) >= self.config.max_concurrent_conversations:
|
||||
logger.info(
|
||||
f'too_many_sessions_for:{user_id or ''}',
|
||||
'too_many_sessions_for:{user_id}',
|
||||
extra={'session_id': sid, 'user_id': user_id},
|
||||
)
|
||||
# Get the conversations sorted (oldest first)
|
||||
@@ -295,22 +297,6 @@ class StandaloneConversationManager(ConversationManager):
|
||||
|
||||
while len(conversations) >= self.config.max_concurrent_conversations:
|
||||
oldest_conversation_id = conversations.pop().conversation_id
|
||||
logger.debug(
|
||||
f'closing_from_too_many_sessions:{user_id or ''}:{oldest_conversation_id}',
|
||||
extra={'session_id': oldest_conversation_id, 'user_id': user_id},
|
||||
)
|
||||
# Send status message to client and close session.
|
||||
status_update_dict = {
|
||||
'status_update': True,
|
||||
'type': 'error',
|
||||
'id': 'AGENT_ERROR$TOO_MANY_CONVERSATIONS',
|
||||
'message': 'Too many conversations at once. If you are still using this one, try reactivating it by prompting the agent to continue',
|
||||
}
|
||||
await self.sio.emit(
|
||||
'oh_event',
|
||||
status_update_dict,
|
||||
to=ROOM_KEY.format(sid=oldest_conversation_id),
|
||||
)
|
||||
await self.close_session(oldest_conversation_id)
|
||||
|
||||
session = Session(
|
||||
@@ -395,8 +381,8 @@ class StandaloneConversationManager(ConversationManager):
|
||||
f'removing connections: {connection_ids_to_remove}',
|
||||
extra={'session_id': sid},
|
||||
)
|
||||
for connection_id in connection_ids_to_remove:
|
||||
self._local_connection_id_to_session_id.pop(connection_id, None)
|
||||
for connnnection_id in connection_ids_to_remove:
|
||||
self._local_connection_id_to_session_id.pop(connnnection_id, None)
|
||||
|
||||
session = self._local_agent_loops_by_sid.pop(sid, None)
|
||||
if not session:
|
||||
|
||||
@@ -9,6 +9,7 @@ from openhands.server.middleware import (
|
||||
CacheControlMiddleware,
|
||||
InMemoryRateLimiter,
|
||||
LocalhostCORSMiddleware,
|
||||
ProviderTokenMiddleware,
|
||||
RateLimitMiddleware,
|
||||
)
|
||||
from openhands.server.static import SPAStaticFiles
|
||||
@@ -31,5 +32,6 @@ base_app.add_middleware(
|
||||
rate_limiter=InMemoryRateLimiter(requests=10, seconds=1),
|
||||
)
|
||||
base_app.middleware('http')(AttachConversationMiddleware(base_app))
|
||||
base_app.middleware('http')(ProviderTokenMiddleware(base_app))
|
||||
|
||||
app = socketio.ASGIApp(sio, other_asgi_app=base_app)
|
||||
|
||||
@@ -12,8 +12,8 @@ from starlette.requests import Request as StarletteRequest
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from openhands.server import shared
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.types import SessionMiddlewareInterface
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class LocalhostCORSMiddleware(CORSMiddleware):
|
||||
@@ -147,10 +147,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
"""
|
||||
Attach the user's session based on the provided authentication token.
|
||||
"""
|
||||
user_id = await get_user_id(request)
|
||||
request.state.conversation = (
|
||||
await shared.conversation_manager.attach_to_conversation(
|
||||
request.state.sid, user_id
|
||||
request.state.sid, get_user_id(request)
|
||||
)
|
||||
)
|
||||
if not request.state.conversation:
|
||||
@@ -184,3 +183,27 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
await self._detach_session(request)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ProviderTokenMiddleware(SessionMiddlewareInterface):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
settings_store = await shared.SettingsStoreImpl.get_instance(
|
||||
shared.config, get_user_id(request)
|
||||
)
|
||||
settings = await settings_store.load()
|
||||
|
||||
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
|
||||
if getattr(request.state, 'provider_tokens', None) is None:
|
||||
if (
|
||||
settings
|
||||
and settings.secrets_store
|
||||
and settings.secrets_store.provider_tokens
|
||||
):
|
||||
request.state.provider_tokens = settings.secrets_store.provider_tokens
|
||||
else:
|
||||
request.state.provider_tokens = None
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
status,
|
||||
@@ -22,10 +21,19 @@ from openhands.events.observation import (
|
||||
FileReadObservation,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.auth import get_github_user_id, get_user_id
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.file_config import (
|
||||
FILES_TO_IGNORE,
|
||||
)
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.shared import (
|
||||
ConversationStoreImpl,
|
||||
config,
|
||||
conversation_manager,
|
||||
)
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
@@ -179,19 +187,24 @@ def zip_current_workspace(request: Request):
|
||||
|
||||
|
||||
@app.get('/git/changes')
|
||||
async def git_changes(
|
||||
request: Request,
|
||||
conversation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
async def git_changes(request: Request, conversation_id: str):
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
logger.info(f'Getting git changes in {runtime.git_dir}')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
|
||||
cwd = await get_cwd(
|
||||
conversation_store,
|
||||
conversation_id,
|
||||
runtime.config.workspace_mount_path_in_sandbox,
|
||||
)
|
||||
logger.info(f'Getting git changes in {cwd}')
|
||||
|
||||
try:
|
||||
changes = await call_sync_from_async(runtime.get_git_changes)
|
||||
changes = await call_sync_from_async(runtime.get_git_changes, cwd)
|
||||
if changes is None:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
status_code=500,
|
||||
content={'error': 'Not a git repository'},
|
||||
)
|
||||
return changes
|
||||
@@ -210,16 +223,20 @@ async def git_changes(
|
||||
|
||||
|
||||
@app.get('/git/diff')
|
||||
async def git_diff(
|
||||
request: Request,
|
||||
path: str,
|
||||
conversation_id: str,
|
||||
):
|
||||
async def git_diff(request: Request, path: str, conversation_id: str):
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
logger.info(f'Getting git diff for {path} in {runtime.git_dir}')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
|
||||
cwd = await get_cwd(
|
||||
conversation_store,
|
||||
conversation_id,
|
||||
runtime.config.workspace_mount_path_in_sandbox,
|
||||
)
|
||||
|
||||
try:
|
||||
diff = await call_sync_from_async(runtime.get_git_diff, path)
|
||||
diff = await call_sync_from_async(runtime.get_git_diff, path, cwd)
|
||||
return diff
|
||||
except AgentRuntimeUnavailableError as e:
|
||||
logger.error(f'Error getting diff: {e}')
|
||||
@@ -227,3 +244,46 @@ async def git_diff(
|
||||
status_code=500,
|
||||
content={'error': f'Error getting diff: {e}'},
|
||||
)
|
||||
|
||||
|
||||
async def get_cwd(
|
||||
conversation_store: ConversationStore,
|
||||
conversation_id: str,
|
||||
workspace_mount_path_in_sandbox: str,
|
||||
):
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
conversation_info = await _get_conversation_info(metadata, is_running)
|
||||
|
||||
cwd = workspace_mount_path_in_sandbox
|
||||
if conversation_info and conversation_info.selected_repository:
|
||||
repo_dir = conversation_info.selected_repository.split('/')[-1]
|
||||
cwd = os.path.join(cwd, repo_dir)
|
||||
|
||||
return cwd
|
||||
|
||||
|
||||
async def _get_conversation_info(
|
||||
conversation: ConversationMetadata,
|
||||
is_running: bool,
|
||||
) -> ConversationInfo | None:
|
||||
try:
|
||||
title = conversation.title
|
||||
if not title:
|
||||
title = f'Conversation {conversation.conversation_id[:5]}'
|
||||
return ConversationInfo(
|
||||
conversation_id=conversation.conversation_id,
|
||||
title=title,
|
||||
last_updated_at=conversation.last_updated_at,
|
||||
created_at=conversation.created_at,
|
||||
selected_repository=conversation.selected_repository,
|
||||
status=ConversationStatus.RUNNING
|
||||
if is_running
|
||||
else ConversationStatus.STOPPED,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error loading conversation {conversation.conversation_id}: {str(e)}',
|
||||
extra={'session_id': conversation.conversation_id},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -15,12 +15,8 @@ from openhands.integrations.service_types import (
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
from openhands.server.auth import get_access_token, get_provider_tokens, get_user_id
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.user_auth import (
|
||||
get_access_token,
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
|
||||
app = APIRouter(prefix='/api/user')
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, status
|
||||
from fastapi import APIRouter, Body, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -15,6 +15,11 @@ from openhands.integrations.provider import (
|
||||
)
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import (
|
||||
get_github_user_id,
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@@ -28,12 +33,6 @@ from openhands.server.shared import (
|
||||
file_store,
|
||||
)
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
from openhands.server.utils import get_conversation_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
@@ -96,7 +95,7 @@ async def _create_new_conversation(
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
@@ -153,17 +152,14 @@ async def _create_new_conversation(
|
||||
|
||||
|
||||
@app.post('/conversations')
|
||||
async def new_conversation(
|
||||
data: InitSessionRequest,
|
||||
user_id: str = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
|
||||
):
|
||||
async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
"""Initialize a new session or join an existing one.
|
||||
|
||||
After successful initialization, the client should connect to the WebSocket
|
||||
using the returned conversation ID.
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
selected_repository = data.selected_repository
|
||||
selected_branch = data.selected_branch
|
||||
initial_user_msg = data.initial_user_msg
|
||||
@@ -173,7 +169,7 @@ async def new_conversation(
|
||||
try:
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
user_id,
|
||||
get_user_id(request),
|
||||
provider_tokens,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
@@ -208,11 +204,13 @@ async def new_conversation(
|
||||
|
||||
@app.get('/conversations')
|
||||
async def search_conversations(
|
||||
request: Request,
|
||||
page_id: str | None = None,
|
||||
limit: int = 20,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
conversation_store: ConversationStore = Depends(get_conversation_store),
|
||||
) -> ConversationInfoResultSet:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||
|
||||
# Filter out conversations older than max_age
|
||||
@@ -230,7 +228,7 @@ async def search_conversations(
|
||||
conversation.conversation_id for conversation in filtered_results
|
||||
)
|
||||
running_conversations = await conversation_manager.get_running_agent_loops(
|
||||
user_id, set(conversation_ids)
|
||||
get_user_id(request), set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
@@ -247,9 +245,11 @@ async def search_conversations(
|
||||
|
||||
@app.get('/conversations/{conversation_id}')
|
||||
async def get_conversation(
|
||||
conversation_id: str,
|
||||
conversation_store: ConversationStore = Depends(get_conversation_store),
|
||||
conversation_id: str, request: Request
|
||||
) -> ConversationInfo | None:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
@@ -340,12 +340,11 @@ async def auto_generate_title(conversation_id: str, user_id: str | None) -> str:
|
||||
|
||||
@app.patch('/conversations/{conversation_id}')
|
||||
async def update_conversation(
|
||||
conversation_id: str,
|
||||
title: str = Body(embed=True),
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
request: Request, conversation_id: str, title: str = Body(embed=True)
|
||||
) -> bool:
|
||||
user_id = get_user_id(request)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
config, user_id, get_github_user_id(request)
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
@@ -367,10 +366,10 @@ async def update_conversation(
|
||||
@app.delete('/conversations/{conversation_id}')
|
||||
async def delete_conversation(
|
||||
conversation_id: str,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
request: Request,
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
from fastapi import APIRouter, Depends, Request, status
|
||||
from fastapi import APIRouter, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderToken,
|
||||
ProviderType,
|
||||
SecretStore,
|
||||
)
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
from openhands.integrations.utils import validate_provider_token
|
||||
from openhands.server.auth import get_provider_tokens, get_user_id
|
||||
from openhands.server.settings import (
|
||||
GETSettingsCustomSecrets,
|
||||
GETSettingsModel,
|
||||
@@ -19,24 +15,16 @@ from openhands.server.settings import (
|
||||
)
|
||||
from openhands.server.shared import SettingsStoreImpl, config, server_config
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
get_user_settings,
|
||||
get_user_settings_store,
|
||||
)
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
|
||||
@app.get('/settings', response_model=GETSettingsModel)
|
||||
async def load_settings(
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
settings: Settings | None = Depends(get_user_settings),
|
||||
) -> GETSettingsModel | JSONResponse:
|
||||
async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
|
||||
try:
|
||||
user_id = get_user_id(request)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
if not settings:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -48,6 +36,7 @@ async def load_settings(
|
||||
if bool(user_id):
|
||||
provider_tokens_set[ProviderType.GITHUB.value] = True
|
||||
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
if provider_tokens:
|
||||
all_provider_types = [provider.value for provider in ProviderType]
|
||||
provider_tokens_types = [provider.value for provider in provider_tokens]
|
||||
@@ -59,7 +48,7 @@ async def load_settings(
|
||||
|
||||
settings_with_token_data = GETSettingsModel(
|
||||
**settings.model_dump(exclude='secrets_store'),
|
||||
llm_api_key_set=settings.llm_api_key is not None and bool(settings.llm_api_key),
|
||||
llm_api_key_set=settings.llm_api_key is not None,
|
||||
provider_tokens_set=provider_tokens_set,
|
||||
)
|
||||
settings_with_token_data.llm_api_key = None
|
||||
@@ -74,9 +63,12 @@ async def load_settings(
|
||||
|
||||
@app.get('/secrets', response_model=GETSettingsCustomSecrets)
|
||||
async def load_custom_secrets_names(
|
||||
settings: Settings | None = Depends(get_user_settings),
|
||||
request: Request,
|
||||
) -> GETSettingsCustomSecrets | JSONResponse:
|
||||
try:
|
||||
user_id = get_user_id(request)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
if not settings:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -101,11 +93,13 @@ async def load_custom_secrets_names(
|
||||
|
||||
@app.post('/secrets', response_model=dict[str, str])
|
||||
async def add_custom_secret(
|
||||
incoming_secrets: POSTSettingsCustomSecrets,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
request: Request, incoming_secrets: POSTSettingsCustomSecrets
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
existing_settings = await settings_store.load()
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings: Settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
for (
|
||||
secret_name,
|
||||
@@ -127,6 +121,7 @@ async def add_custom_secret(
|
||||
update={'secrets_store': updated_secret_store}
|
||||
)
|
||||
|
||||
updated_settings = convert_to_settings(updated_settings)
|
||||
await settings_store.store(updated_settings)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -142,11 +137,11 @@ async def add_custom_secret(
|
||||
|
||||
|
||||
@app.delete('/secrets/{secret_id}')
|
||||
async def delete_custom_secret(
|
||||
secret_id: str,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
) -> JSONResponse:
|
||||
async def delete_custom_secret(request: Request, secret_id: str) -> JSONResponse:
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings: Settings | None = await settings_store.load()
|
||||
custom_secrets = {}
|
||||
if existing_settings:
|
||||
@@ -167,6 +162,7 @@ async def delete_custom_secret(
|
||||
update={'secrets_store': updated_secret_store}
|
||||
)
|
||||
|
||||
updated_settings = convert_to_settings(updated_settings)
|
||||
await settings_store.store(updated_settings)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -182,10 +178,12 @@ async def delete_custom_secret(
|
||||
|
||||
|
||||
@app.post('/unset-settings-tokens', response_model=dict[str, str])
|
||||
async def unset_settings_tokens(
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
) -> JSONResponse:
|
||||
async def unset_settings_tokens(request: Request) -> JSONResponse:
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
settings = existing_settings.model_copy(
|
||||
@@ -207,20 +205,58 @@ async def unset_settings_tokens(
|
||||
|
||||
|
||||
@app.post('/reset-settings', response_model=dict[str, str])
|
||||
async def reset_settings() -> JSONResponse:
|
||||
async def reset_settings(request: Request) -> JSONResponse:
|
||||
"""
|
||||
Resets user settings. (Deprecated)
|
||||
Resets user settings.
|
||||
"""
|
||||
logger.warning(
|
||||
f"Deprecated endpoint /api/reset-settings called by user"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_410_GONE,
|
||||
content={'error': 'Reset settings functionality has been removed.'},
|
||||
)
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
|
||||
existing_settings = await settings_store.load()
|
||||
settings = Settings(
|
||||
language='en',
|
||||
agent='CodeActAgent',
|
||||
security_analyzer='',
|
||||
confirmation_mode=False,
|
||||
llm_model='anthropic/claude-3-5-sonnet-20241022',
|
||||
llm_api_key='',
|
||||
llm_base_url='',
|
||||
remote_runtime_resource_factor=1,
|
||||
enable_default_condenser=True,
|
||||
enable_sound_notifications=False,
|
||||
user_consents_to_analytics=existing_settings.user_consents_to_analytics
|
||||
if existing_settings
|
||||
else False,
|
||||
)
|
||||
|
||||
server_config_values = server_config.get_config()
|
||||
is_hide_llm_settings_enabled = server_config_values.get(
|
||||
'FEATURE_FLAGS', {}
|
||||
).get('HIDE_LLM_SETTINGS', False)
|
||||
# We don't want the user to be able to modify these settings in SaaS
|
||||
if server_config.app_mode == AppMode.SAAS and is_hide_llm_settings_enabled:
|
||||
if existing_settings:
|
||||
settings.llm_api_key = existing_settings.llm_api_key
|
||||
settings.llm_base_url = existing_settings.llm_base_url
|
||||
settings.llm_model = existing_settings.llm_model
|
||||
|
||||
await settings_store.store(settings)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Settings stored'},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Something went wrong resetting settings: {e}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'error': 'Something went wrong resetting settings'},
|
||||
)
|
||||
|
||||
|
||||
async def check_provider_tokens(settings: POSTSettingsModel) -> str:
|
||||
async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str:
|
||||
if settings.provider_tokens:
|
||||
# Remove extraneous token types
|
||||
provider_types = [provider.value for provider in ProviderType]
|
||||
@@ -240,9 +276,8 @@ async def check_provider_tokens(settings: POSTSettingsModel) -> str:
|
||||
return ''
|
||||
|
||||
|
||||
async def store_provider_tokens(
|
||||
settings: POSTSettingsModel, settings_store: SettingsStore
|
||||
):
|
||||
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
if settings.provider_tokens:
|
||||
@@ -276,8 +311,9 @@ async def store_provider_tokens(
|
||||
|
||||
|
||||
async def store_llm_settings(
|
||||
settings: POSTSettingsModel, settings_store: SettingsStore
|
||||
request: Request, settings: POSTSettingsModel
|
||||
) -> POSTSettingsModel:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
@@ -295,11 +331,11 @@ async def store_llm_settings(
|
||||
|
||||
@app.post('/settings', response_model=dict[str, str])
|
||||
async def store_settings(
|
||||
request: Request,
|
||||
settings: POSTSettingsModel,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
) -> JSONResponse:
|
||||
# Check provider tokens are valid
|
||||
provider_err_msg = await check_provider_tokens(settings)
|
||||
provider_err_msg = await check_provider_tokens(request, settings)
|
||||
if provider_err_msg:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -307,11 +343,14 @@ async def store_settings(
|
||||
)
|
||||
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
settings = await store_llm_settings(settings, settings_store)
|
||||
settings = await store_llm_settings(request, settings)
|
||||
|
||||
# Keep existing analytics consent if not provided
|
||||
if settings.user_consents_to_analytics is None:
|
||||
@@ -319,7 +358,7 @@ async def store_settings(
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
settings = await store_provider_tokens(settings, settings_store)
|
||||
settings = await store_provider_tokens(request, settings)
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
|
||||
@@ -94,10 +94,7 @@ class Settings(BaseModel):
|
||||
return {
|
||||
'provider_tokens': secrets.provider_tokens_serializer(
|
||||
secrets.provider_tokens, info
|
||||
),
|
||||
'custom_secrets': secrets.custom_secrets_serializer(
|
||||
secrets.custom_secrets, info
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
async def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
return provider_tokens
|
||||
|
||||
|
||||
async def get_access_token(request: Request) -> SecretStr | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
access_token = await user_auth.get_access_token()
|
||||
return access_token
|
||||
|
||||
|
||||
async def get_user_id(request: Request) -> str | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
user_id = await user_auth.get_user_id()
|
||||
return user_id
|
||||
|
||||
|
||||
async def get_github_user_id(request: Request) -> str | None:
|
||||
provider_tokens = await get_provider_tokens(request)
|
||||
if not provider_tokens:
|
||||
return None
|
||||
github_provider = provider_tokens.get(ProviderType.GITHUB)
|
||||
if github_provider:
|
||||
return github_provider.user_id
|
||||
return None
|
||||
|
||||
|
||||
async def get_user_settings(request: Request) -> Settings | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
user_settings = await user_auth.get_user_settings()
|
||||
return user_settings
|
||||
|
||||
|
||||
async def get_user_settings_store(request: Request) -> SettingsStore | None:
|
||||
user_auth = await get_user_auth(request)
|
||||
user_settings_store = await user_auth.get_user_settings_store()
|
||||
return user_settings_store
|
||||
@@ -1,57 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server import shared
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class DefaultUserAuth(UserAuth):
|
||||
"""Default user authentication mechanism"""
|
||||
|
||||
_settings: Settings | None = None
|
||||
_settings_store: SettingsStore | None = None
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
"""The default implementation does not support multi tenancy, so user_id is always None"""
|
||||
return None
|
||||
|
||||
async def get_access_token(self) -> SecretStr | None:
|
||||
"""The default implementation does not support multi tenancy, so access_token is always None"""
|
||||
return None
|
||||
|
||||
async def get_user_settings_store(self):
|
||||
settings_store = self._settings_store
|
||||
if settings_store:
|
||||
return settings_store
|
||||
user_id = await self.get_user_id()
|
||||
settings_store = await shared.SettingsStoreImpl.get_instance(
|
||||
shared.config, user_id
|
||||
)
|
||||
self._settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
async def get_user_settings(self) -> Settings | None:
|
||||
settings = self._settings
|
||||
if settings:
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
self._settings = settings
|
||||
return settings
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
settings = await self.get_user_settings()
|
||||
secrets_store = getattr(settings, 'secrets_store', None)
|
||||
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
|
||||
return provider_tokens
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
user_auth = DefaultUserAuth()
|
||||
return user_auth
|
||||
@@ -1,63 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class UserAuth(ABC):
|
||||
"""Extensible class encapsulating user Authentication"""
|
||||
|
||||
_settings: Settings | None
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_id(self) -> str | None:
|
||||
"""Get the unique identifier for the current user"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_access_token(self) -> SecretStr | None:
|
||||
"""Get the access token for the current user"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
"""Get the provider tokens for the current user."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_settings_store(self) -> SettingsStore | None:
|
||||
"""Get the settings store for the current user."""
|
||||
|
||||
async def get_user_settings(self) -> Settings | None:
|
||||
"""Get the user settings for the current user"""
|
||||
settings = self._settings
|
||||
if settings:
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
if settings_store is None:
|
||||
return None
|
||||
settings = await settings_store.load()
|
||||
self._settings = settings
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
"""Get an instance of UserAuth from the request given"""
|
||||
|
||||
|
||||
async def get_user_auth(request: Request) -> UserAuth:
|
||||
user_auth = getattr(request.state, 'user_auth', None)
|
||||
if user_auth:
|
||||
return user_auth
|
||||
impl_name = server_config.user_auth_class
|
||||
impl = get_impl(UserAuth, impl_name)
|
||||
user_auth = await impl.get_instance(request)
|
||||
request.state.user_auth = user_auth
|
||||
return user_auth
|
||||
@@ -1,16 +0,0 @@
|
||||
from fastapi import Request
|
||||
|
||||
from openhands.server.shared import ConversationStoreImpl, config
|
||||
from openhands.server.user_auth import get_user_auth
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
|
||||
|
||||
async def get_conversation_store(request: Request) -> ConversationStore | None:
|
||||
conversation_store = getattr(request.state, 'conversation_store', None)
|
||||
if conversation_store:
|
||||
return conversation_store
|
||||
user_auth = await get_user_auth(request)
|
||||
user_id = await user_auth.get_user_id()
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
request.state.conversation_store = conversation_store
|
||||
return conversation_store
|
||||
@@ -60,6 +60,6 @@ class ConversationStore(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||
) -> ConversationStore:
|
||||
"""Get a store for the user represented by the token given."""
|
||||
|
||||
@@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore):
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||
) -> FileConversationStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileConversationStore(file_store)
|
||||
|
||||
Generated
+38
-43
@@ -496,18 +496,18 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.38.2"
|
||||
version = "1.38.1"
|
||||
description = "The AWS SDK for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "boto3-1.38.2-py3-none-any.whl", hash = "sha256:ef3237b169cd906a44a32c03b3229833d923c9e9733355b329ded2151f91ec0b"},
|
||||
{file = "boto3-1.38.2.tar.gz", hash = "sha256:53c8d44b231251fa9421dd13d968236d59fe2cf0421e077afedbf3821653fb3b"},
|
||||
{file = "boto3-1.38.1-py3-none-any.whl", hash = "sha256:f192a4a34885a9e3e970b5ce5e6bec947be0f3fe6c4693b2a737c14407b12a5a"},
|
||||
{file = "boto3-1.38.1.tar.gz", hash = "sha256:988e7fae7fd4d59798f84604d73a3a019c07b048f746c7c40258c0e656473887"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.38.2,<1.39.0"
|
||||
botocore = ">=1.38.1,<1.39.0"
|
||||
jmespath = ">=0.7.1,<2.0.0"
|
||||
s3transfer = ">=0.12.0,<0.13.0"
|
||||
|
||||
@@ -516,14 +516,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
||||
|
||||
[[package]]
|
||||
name = "boto3-stubs"
|
||||
version = "1.38.2"
|
||||
description = "Type annotations for boto3 1.38.2 generated with mypy-boto3-builder 8.10.1"
|
||||
version = "1.38.1"
|
||||
description = "Type annotations for boto3 1.38.1 generated with mypy-boto3-builder 8.10.1"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["evaluation"]
|
||||
files = [
|
||||
{file = "boto3_stubs-1.38.2-py3-none-any.whl", hash = "sha256:e18f2dc194c4b8a29f61275ba039689d063c4775a78560e35a5ce820ec257fb5"},
|
||||
{file = "boto3_stubs-1.38.2.tar.gz", hash = "sha256:405cd777d41530cf8ed009d20b04daef1f7d4bd2fd9fd3636ac86eccdb55159c"},
|
||||
{file = "boto3_stubs-1.38.1-py3-none-any.whl", hash = "sha256:3501f98c39b8c2d613b1138a4e8881ceef2ac9497ac030be47cf4336f1aa0573"},
|
||||
{file = "boto3_stubs-1.38.1.tar.gz", hash = "sha256:25b03fdbda288c1576fbe002ecf40088e9f5d6cdf0518de8a84a7467aa898092"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -579,7 +579,7 @@ bedrock-data-automation-runtime = ["mypy-boto3-bedrock-data-automation-runtime (
|
||||
bedrock-runtime = ["mypy-boto3-bedrock-runtime (>=1.38.0,<1.39.0)"]
|
||||
billing = ["mypy-boto3-billing (>=1.38.0,<1.39.0)"]
|
||||
billingconductor = ["mypy-boto3-billingconductor (>=1.38.0,<1.39.0)"]
|
||||
boto3 = ["boto3 (==1.38.2)"]
|
||||
boto3 = ["boto3 (==1.38.1)"]
|
||||
braket = ["mypy-boto3-braket (>=1.38.0,<1.39.0)"]
|
||||
budgets = ["mypy-boto3-budgets (>=1.38.0,<1.39.0)"]
|
||||
ce = ["mypy-boto3-ce (>=1.38.0,<1.39.0)"]
|
||||
@@ -943,14 +943,14 @@ xray = ["mypy-boto3-xray (>=1.38.0,<1.39.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.38.2"
|
||||
version = "1.38.1"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "botocore-1.38.2-py3-none-any.whl", hash = "sha256:5d9cffedb1c759a058b43793d16647ed44ec87072f98a1bd6cd673ac0ae6b81d"},
|
||||
{file = "botocore-1.38.2.tar.gz", hash = "sha256:b688a9bd17211a1eaae3a6c965ba9f3973e5435efaaa4fa201f499d3467830e1"},
|
||||
{file = "botocore-1.38.1-py3-none-any.whl", hash = "sha256:b1673975e3c42d0e2d1804f9f73e88961e95eac371c8f8c0a0d7e661ec3c90c3"},
|
||||
{file = "botocore-1.38.1.tar.gz", hash = "sha256:c2eb42eeaa502f236ba894a65ea7f7241711150cc450b9d59fbbad41e741adc0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2523,7 +2523,6 @@ files = [
|
||||
{file = "gevent-25.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c1d1a66a28372d505e0d8f6f1fdb62f7d5b3423e49431f41b99bd9133f006b7"},
|
||||
{file = "gevent-25.4.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:fdf9aec76a7285b00fb64ec942cd9ff88f8765874a5abf99c4e8c5374b3133e9"},
|
||||
{file = "gevent-25.4.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7442b3ffac08f6239d6463ee2943fd9a619b64b2db11cec292acf8caccb70536"},
|
||||
{file = "gevent-25.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:d7999e4d4b3597b706a333f9a7bf2efbd8365cd244312405f33b4870fa3b411d"},
|
||||
{file = "gevent-25.4.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:2270a8607661e609c44e4f72811b6380dcfede558041e4ee3134e66753865038"},
|
||||
{file = "gevent-25.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb89ed32e2b766fcb1afc52847e33d8c369d2b40f23d4c96977fd092b5a0ea86"},
|
||||
{file = "gevent-25.4.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43469ed40ea6cfb1c88e8d85a57aa5f52dd6b3b94a2e499752ab7e60a90c7dba"},
|
||||
@@ -2531,7 +2530,6 @@ files = [
|
||||
{file = "gevent-25.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccbc835939416a7df7834b79c655409a2a9d2deb9bf119b28dedf72a168f7895"},
|
||||
{file = "gevent-25.4.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:feb5f2f44dcdad1a6b80e7ce24e7557ce25d01ff13b7a74ca276d113adf9d4af"},
|
||||
{file = "gevent-25.4.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:91408dd197c13ca0f1e0d5cdcc9870c674963bb87a7e370b2884d1426d73834f"},
|
||||
{file = "gevent-25.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:12b596c027cf546a235231d421473483fdf7fa586d38162d36b07c8efa9081ba"},
|
||||
{file = "gevent-25.4.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:5940174c7d1ffc7bb4b0ea9f2908f4f361eb03ada9e145d3590b8df1e61c379b"},
|
||||
{file = "gevent-25.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7ae7ad4ff9c4492d4b633702e35153509b07dc6ffd20f1577076d7647c9caba"},
|
||||
{file = "gevent-25.4.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d68fdf9bff0068367126983d7d85765124c292b4bc3d4d19ed8138335d8426a7"},
|
||||
@@ -2539,7 +2537,6 @@ files = [
|
||||
{file = "gevent-25.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7c70ab6d33dfeb43bfe982c636609d8f90506dacaaa1f409a3c43c66d578fb1"},
|
||||
{file = "gevent-25.4.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e740bc08ba4c34951f4bb6351dbe04209416e12d620691fb57e115b218a7818"},
|
||||
{file = "gevent-25.4.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c535d96ded6e26b37fadda9242a49fea6308754da5945173940614b7520c07b4"},
|
||||
{file = "gevent-25.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:c62bf14557d2cb54f5e3c1ba0a3b3f4b69bf0441081c32d63b205763b495b251"},
|
||||
{file = "gevent-25.4.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:f735f57bc19d0f8bbc784093cfb7953a9ad66612b05c3ff876ec7951a96d7edd"},
|
||||
{file = "gevent-25.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63aecf1e43b8d01086ea574ed05f7272ed40c48dd41fa3d061e3c5ca900abcdd"},
|
||||
{file = "gevent-25.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f12e570777027f807dc7dc3ea1945ea040befaf1c9485deb6f24d7110009fc12"},
|
||||
@@ -2551,8 +2548,6 @@ files = [
|
||||
{file = "gevent-25.4.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:b0a656eccd9cb115d01c9bbe55bfe84cf20c8422c495503f41aef747b193c33d"},
|
||||
{file = "gevent-25.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95790dd8aeb4ca8df9ac215ec353a29108647797e54daa652a4634ca316f70d4"},
|
||||
{file = "gevent-25.4.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:76c440972ff57eb64e089f85210ccc0fa247ab71cdedff5414c6b86392f7f791"},
|
||||
{file = "gevent-25.4.2-cp39-cp39-win32.whl", hash = "sha256:b91e862ab0ddecf37ee6e3bf33965ef4c3e38ba9cdc106eef552293caed512f9"},
|
||||
{file = "gevent-25.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:03587078c402aee27231ecaabd81aec1e8b3de2629830fbd4486e2d09e638ddc"},
|
||||
{file = "gevent-25.4.2-pp310-pypy310_pp73-macosx_11_0_universal2.whl", hash = "sha256:498f548330c4724e3b0cee0d75551165fc9e4309ae3ddcba3d644aaa866ca9c3"},
|
||||
{file = "gevent-25.4.2.tar.gz", hash = "sha256:7ffba461458ed28a85a01285ea0e0dc14f883204d17ce5ed82fa839a9d620028"},
|
||||
]
|
||||
@@ -2676,14 +2671,14 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
|
||||
|
||||
[[package]]
|
||||
name = "google-api-python-client"
|
||||
version = "2.168.0"
|
||||
version = "2.167.0"
|
||||
description = "Google API Client Library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "google_api_python_client-2.168.0-py3-none-any.whl", hash = "sha256:ebf27fc318a3cf682dc994cefc46b6794eafee91d91fc659d46e018155ace530"},
|
||||
{file = "google_api_python_client-2.168.0.tar.gz", hash = "sha256:10759c3c8f5bbb17752b349ff336963ab215db150f34594a5581d5cd9b5add41"},
|
||||
{file = "google_api_python_client-2.167.0-py2.py3-none-any.whl", hash = "sha256:ce25290cc229505d770ca5c8d03850e0ae87d8e998fc6dd743ecece018baa396"},
|
||||
{file = "google_api_python_client-2.167.0.tar.gz", hash = "sha256:a458d402572e1c2caf9db090d8e7b270f43ff326bd9349c731a86b19910e3995"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4883,14 +4878,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "modal"
|
||||
version = "0.74.23"
|
||||
version = "0.74.20"
|
||||
description = "Python client library for Modal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main", "evaluation"]
|
||||
files = [
|
||||
{file = "modal-0.74.23-py3-none-any.whl", hash = "sha256:96c397487ed5f499ad040b5edf5f378ada8e0676da17523a2d6fadb3f1d384e1"},
|
||||
{file = "modal-0.74.23.tar.gz", hash = "sha256:3a042cdf482975b43341da0b33fa6a6adae06978ead69a086ca658a7dcb0cd6d"},
|
||||
{file = "modal-0.74.20-py3-none-any.whl", hash = "sha256:d6aa369f83399a399b2b151f98124df703a31b8358d5ace3289a88cb61fb9ef0"},
|
||||
{file = "modal-0.74.20.tar.gz", hash = "sha256:2d4e6e6592c309346448d8e278a0c392596c8a6971b5d789bc81af0d8180f2de"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7950,30 +7945,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.11.7"
|
||||
version = "0.11.6"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev", "evaluation"]
|
||||
files = [
|
||||
{file = "ruff-0.11.7-py3-none-linux_armv6l.whl", hash = "sha256:d29e909d9a8d02f928d72ab7837b5cbc450a5bdf578ab9ebee3263d0a525091c"},
|
||||
{file = "ruff-0.11.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dd1fb86b168ae349fb01dd497d83537b2c5541fe0626e70c786427dd8363aaee"},
|
||||
{file = "ruff-0.11.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d3d7d2e140a6fbbc09033bce65bd7ea29d6a0adeb90b8430262fbacd58c38ada"},
|
||||
{file = "ruff-0.11.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4809df77de390a1c2077d9b7945d82f44b95d19ceccf0c287c56e4dc9b91ca64"},
|
||||
{file = "ruff-0.11.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f3a0c2e169e6b545f8e2dba185eabbd9db4f08880032e75aa0e285a6d3f48201"},
|
||||
{file = "ruff-0.11.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49b888200a320dd96a68e86736cf531d6afba03e4f6cf098401406a257fcf3d6"},
|
||||
{file = "ruff-0.11.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:2b19cdb9cf7dae00d5ee2e7c013540cdc3b31c4f281f1dacb5a799d610e90db4"},
|
||||
{file = "ruff-0.11.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64e0ee994c9e326b43539d133a36a455dbaab477bc84fe7bfbd528abe2f05c1e"},
|
||||
{file = "ruff-0.11.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bad82052311479a5865f52c76ecee5d468a58ba44fb23ee15079f17dd4c8fd63"},
|
||||
{file = "ruff-0.11.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7940665e74e7b65d427b82bffc1e46710ec7f30d58b4b2d5016e3f0321436502"},
|
||||
{file = "ruff-0.11.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:169027e31c52c0e36c44ae9a9c7db35e505fee0b39f8d9fca7274a6305295a92"},
|
||||
{file = "ruff-0.11.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:305b93f9798aee582e91e34437810439acb28b5fc1fee6b8205c78c806845a94"},
|
||||
{file = "ruff-0.11.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a681db041ef55550c371f9cd52a3cf17a0da4c75d6bd691092dfc38170ebc4b6"},
|
||||
{file = "ruff-0.11.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:07f1496ad00a4a139f4de220b0c97da6d4c85e0e4aa9b2624167b7d4d44fd6b6"},
|
||||
{file = "ruff-0.11.7-py3-none-win32.whl", hash = "sha256:f25dfb853ad217e6e5f1924ae8a5b3f6709051a13e9dad18690de6c8ff299e26"},
|
||||
{file = "ruff-0.11.7-py3-none-win_amd64.whl", hash = "sha256:0a931d85959ceb77e92aea4bbedfded0a31534ce191252721128f77e5ae1f98a"},
|
||||
{file = "ruff-0.11.7-py3-none-win_arm64.whl", hash = "sha256:778c1e5d6f9e91034142dfd06110534ca13220bfaad5c3735f6cb844654f6177"},
|
||||
{file = "ruff-0.11.7.tar.gz", hash = "sha256:655089ad3224070736dc32844fde783454f8558e71f501cb207485fe4eee23d4"},
|
||||
{file = "ruff-0.11.6-py3-none-linux_armv6l.whl", hash = "sha256:d84dcbe74cf9356d1bdb4a78cf74fd47c740bf7bdeb7529068f69b08272239a1"},
|
||||
{file = "ruff-0.11.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9bc583628e1096148011a5d51ff3c836f51899e61112e03e5f2b1573a9b726de"},
|
||||
{file = "ruff-0.11.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f2959049faeb5ba5e3b378709e9d1bf0cab06528b306b9dd6ebd2a312127964a"},
|
||||
{file = "ruff-0.11.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63c5d4e30d9d0de7fedbfb3e9e20d134b73a30c1e74b596f40f0629d5c28a193"},
|
||||
{file = "ruff-0.11.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26a4b9a4e1439f7d0a091c6763a100cef8fbdc10d68593df6f3cfa5abdd9246e"},
|
||||
{file = "ruff-0.11.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b5edf270223dd622218256569636dc3e708c2cb989242262fe378609eccf1308"},
|
||||
{file = "ruff-0.11.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f55844e818206a9dd31ff27f91385afb538067e2dc0beb05f82c293ab84f7d55"},
|
||||
{file = "ruff-0.11.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d8f782286c5ff562e4e00344f954b9320026d8e3fae2ba9e6948443fafd9ffc"},
|
||||
{file = "ruff-0.11.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:01c63ba219514271cee955cd0adc26a4083df1956d57847978383b0e50ffd7d2"},
|
||||
{file = "ruff-0.11.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15adac20ef2ca296dd3d8e2bedc6202ea6de81c091a74661c3666e5c4c223ff6"},
|
||||
{file = "ruff-0.11.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4dd6b09e98144ad7aec026f5588e493c65057d1b387dd937d7787baa531d9bc2"},
|
||||
{file = "ruff-0.11.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:45b2e1d6c0eed89c248d024ea95074d0e09988d8e7b1dad8d3ab9a67017a5b03"},
|
||||
{file = "ruff-0.11.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bd40de4115b2ec4850302f1a1d8067f42e70b4990b68838ccb9ccd9f110c5e8b"},
|
||||
{file = "ruff-0.11.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:77cda2dfbac1ab73aef5e514c4cbfc4ec1fbef4b84a44c736cc26f61b3814cd9"},
|
||||
{file = "ruff-0.11.6-py3-none-win32.whl", hash = "sha256:5151a871554be3036cd6e51d0ec6eef56334d74dfe1702de717a995ee3d5b287"},
|
||||
{file = "ruff-0.11.6-py3-none-win_amd64.whl", hash = "sha256:cce85721d09c51f3b782c331b0abd07e9d7d5f775840379c640606d3159cae0e"},
|
||||
{file = "ruff-0.11.6-py3-none-win_arm64.whl", hash = "sha256:3567ba0d07fb170b1b48d944715e3294b77f5b7679e8ba258199a250383ccb79"},
|
||||
{file = "ruff-0.11.6.tar.gz", hash = "sha256:bec8bcc3ac228a45ccc811e45f7eb61b950dbf4cf31a67fa89352574b01c7d79"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -10257,4 +10252,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "e0d99d8657168051347da0ebbeb0ff23b3c035149627253736cf9d2ec3930435"
|
||||
content-hash = "ce7e49638e83acefc31930cb89f11f33a0c243552e0ed703e769440305c88aa9"
|
||||
|
||||
+1
-1
@@ -78,7 +78,7 @@ playwright = "^1.51.0"
|
||||
prompt-toolkit = "^3.0.50"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.11.7"
|
||||
ruff = "0.11.6"
|
||||
mypy = "1.15.0"
|
||||
pre-commit = "4.2.0"
|
||||
build = "*"
|
||||
|
||||
+40
-160
@@ -1,10 +1,11 @@
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@@ -15,7 +16,6 @@ from openhands.server.routes.manage_conversations import (
|
||||
search_conversations,
|
||||
update_conversation,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.storage.locations import get_conversation_metadata_filename
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
@@ -72,36 +72,9 @@ async def test_search_conversations():
|
||||
)
|
||||
mock_datetime.fromisoformat = datetime.fromisoformat
|
||||
mock_datetime.timezone = timezone
|
||||
|
||||
# Mock the conversation store
|
||||
mock_store = MagicMock()
|
||||
mock_store.search = AsyncMock(
|
||||
return_value=ConversationInfoResultSet(
|
||||
results=[
|
||||
ConversationMetadata(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
created_at=datetime.fromisoformat(
|
||||
'2025-01-01T00:00:00+00:00'
|
||||
),
|
||||
last_updated_at=datetime.fromisoformat(
|
||||
'2025-01-01T00:01:00+00:00'
|
||||
),
|
||||
selected_repository='foobar',
|
||||
github_user_id='12345',
|
||||
user_id='12345',
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
result_set = await search_conversations(
|
||||
page_id=None,
|
||||
limit=20,
|
||||
user_id='12345',
|
||||
conversation_store=mock_store,
|
||||
MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
|
||||
expected = ConversationInfoResultSet(
|
||||
results=[
|
||||
ConversationInfo(
|
||||
@@ -124,51 +97,26 @@ async def test_search_conversations():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation():
|
||||
with _patch_store():
|
||||
# Mock the conversation store
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_metadata = AsyncMock(
|
||||
return_value=ConversationMetadata(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
|
||||
selected_repository='foobar',
|
||||
github_user_id='12345',
|
||||
user_id='12345',
|
||||
)
|
||||
conversation = await get_conversation(
|
||||
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
|
||||
# Mock the conversation manager
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.conversation_manager'
|
||||
) as mock_manager:
|
||||
mock_manager.is_agent_loop_running = AsyncMock(return_value=False)
|
||||
|
||||
conversation = await get_conversation(
|
||||
'some_conversation_id', conversation_store=mock_store
|
||||
)
|
||||
|
||||
expected = ConversationInfo(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
|
||||
status=ConversationStatus.STOPPED,
|
||||
selected_repository='foobar',
|
||||
)
|
||||
assert conversation == expected
|
||||
expected = ConversationInfo(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
|
||||
status=ConversationStatus.STOPPED,
|
||||
selected_repository='foobar',
|
||||
)
|
||||
assert conversation == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_missing_conversation():
|
||||
with _patch_store():
|
||||
# Mock the conversation store
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_metadata = AsyncMock(side_effect=FileNotFoundError)
|
||||
|
||||
assert (
|
||||
await get_conversation(
|
||||
'no_such_conversation', conversation_store=mock_store
|
||||
'no_such_conversation', MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
is None
|
||||
)
|
||||
@@ -177,102 +125,34 @@ async def test_get_missing_conversation():
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_conversation():
|
||||
with _patch_store():
|
||||
# Mock the ConversationStoreImpl.get_instance
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.ConversationStoreImpl.get_instance'
|
||||
) as mock_get_instance:
|
||||
# Create a mock conversation store
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Mock metadata
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
|
||||
selected_repository='foobar',
|
||||
github_user_id='12345',
|
||||
user_id='12345',
|
||||
)
|
||||
|
||||
# Set up the mock to return metadata and then save it
|
||||
mock_store.get_metadata = AsyncMock(return_value=metadata)
|
||||
mock_store.save_metadata = AsyncMock()
|
||||
|
||||
# Return the mock store from get_instance
|
||||
mock_get_instance.return_value = mock_store
|
||||
|
||||
# Call update_conversation
|
||||
result = await update_conversation(
|
||||
'some_conversation_id',
|
||||
'New Title',
|
||||
user_id='12345',
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result is True
|
||||
|
||||
# Verify that save_metadata was called with updated metadata
|
||||
mock_store.save_metadata.assert_called_once()
|
||||
saved_metadata = mock_store.save_metadata.call_args[0][0]
|
||||
assert saved_metadata.title == 'New Title'
|
||||
await update_conversation(
|
||||
MagicMock(state=MagicMock(github_token='')),
|
||||
'some_conversation_id',
|
||||
'New Title',
|
||||
)
|
||||
conversation = await get_conversation(
|
||||
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
expected = ConversationInfo(
|
||||
conversation_id='some_conversation_id',
|
||||
title='New Title',
|
||||
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
|
||||
status=ConversationStatus.STOPPED,
|
||||
selected_repository='foobar',
|
||||
)
|
||||
assert conversation == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_conversation():
|
||||
with _patch_store():
|
||||
# Mock the ConversationStoreImpl.get_instance
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.ConversationStoreImpl.get_instance'
|
||||
) as mock_get_instance:
|
||||
# Create a mock conversation store
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Set up the mock to return metadata and then delete it
|
||||
mock_store.get_metadata = AsyncMock(
|
||||
return_value=ConversationMetadata(
|
||||
conversation_id='some_conversation_id',
|
||||
title='Some Conversation',
|
||||
created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'),
|
||||
last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'),
|
||||
selected_repository='foobar',
|
||||
github_user_id='12345',
|
||||
user_id='12345',
|
||||
)
|
||||
with patch.object(DockerRuntime, 'delete', return_value=None):
|
||||
await delete_conversation(
|
||||
'some_conversation_id',
|
||||
MagicMock(state=MagicMock(github_token='')),
|
||||
)
|
||||
mock_store.delete_metadata = AsyncMock()
|
||||
|
||||
# Return the mock store from get_instance
|
||||
mock_get_instance.return_value = mock_store
|
||||
|
||||
# Mock the conversation manager
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.conversation_manager'
|
||||
) as mock_manager:
|
||||
mock_manager.is_agent_loop_running = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the runtime class
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.get_runtime_cls'
|
||||
) as mock_get_runtime_cls:
|
||||
mock_runtime_cls = MagicMock()
|
||||
mock_runtime_cls.delete = AsyncMock()
|
||||
mock_get_runtime_cls.return_value = mock_runtime_cls
|
||||
|
||||
# Call delete_conversation
|
||||
result = await delete_conversation(
|
||||
'some_conversation_id', user_id='12345'
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result is True
|
||||
|
||||
# Verify that delete_metadata was called
|
||||
mock_store.delete_metadata.assert_called_once_with(
|
||||
'some_conversation_id'
|
||||
)
|
||||
|
||||
# Verify that runtime.delete was called
|
||||
mock_runtime_cls.delete.assert_called_once_with(
|
||||
'some_conversation_id'
|
||||
)
|
||||
conversation = await get_conversation(
|
||||
'some_conversation_id', MagicMock(state=MagicMock(github_token=''))
|
||||
)
|
||||
assert conversation is None
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.runtime.utils.git_handler import CommandResult, GitHandler
|
||||
|
||||
# Mark all test methods as asyncio tests
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestGitHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -110,9 +104,9 @@ class TestGitHandler(unittest.TestCase):
|
||||
# Push the feature branch to origin
|
||||
self._execute_command('git push -u origin feature-branch', self.local_dir)
|
||||
|
||||
async def test_is_git_repo(self):
|
||||
def test_is_git_repo(self):
|
||||
"""Test that _is_git_repo returns True for a git repository."""
|
||||
self.assertTrue(await self.git_handler._is_git_repo())
|
||||
self.assertTrue(self.git_handler._is_git_repo())
|
||||
|
||||
# Verify the command was executed
|
||||
self.assertTrue(
|
||||
@@ -122,9 +116,9 @@ class TestGitHandler(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
async def test_get_default_branch(self):
|
||||
def test_get_default_branch(self):
|
||||
"""Test that _get_default_branch returns the correct branch name."""
|
||||
branch = await self.git_handler._get_default_branch()
|
||||
branch = self.git_handler._get_default_branch()
|
||||
self.assertEqual(branch, 'main')
|
||||
|
||||
# Verify the command was executed
|
||||
@@ -135,9 +129,9 @@ class TestGitHandler(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
async def test_get_current_branch(self):
|
||||
def test_get_current_branch(self):
|
||||
"""Test that _get_current_branch returns the correct branch name."""
|
||||
branch = await self.git_handler._get_current_branch()
|
||||
branch = self.git_handler._get_current_branch()
|
||||
self.assertEqual(branch, 'feature-branch')
|
||||
|
||||
# Verify the command was executed
|
||||
@@ -148,10 +142,10 @@ class TestGitHandler(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
async def test_get_valid_ref_with_origin_current_branch(self):
|
||||
def test_get_valid_ref_with_origin_current_branch(self):
|
||||
"""Test that _get_valid_ref returns the current branch in origin when it exists."""
|
||||
# This test uses the setup from setUp where the current branch exists in origin
|
||||
ref = await self.git_handler._get_valid_ref()
|
||||
ref = self.git_handler._get_valid_ref()
|
||||
self.assertIsNotNone(ref)
|
||||
|
||||
# Check that the refs were checked in the correct order
|
||||
@@ -171,7 +165,7 @@ class TestGitHandler(unittest.TestCase):
|
||||
result = self._execute_command(f'git rev-parse --verify {ref}', self.local_dir)
|
||||
self.assertEqual(result.exit_code, 0)
|
||||
|
||||
async def test_get_valid_ref_without_origin_current_branch(self):
|
||||
def test_get_valid_ref_without_origin_current_branch(self):
|
||||
"""Test that _get_valid_ref falls back to default branch when current branch doesn't exist in origin."""
|
||||
# Create a new branch that doesn't exist in origin
|
||||
self._execute_command('git checkout -b new-local-branch', self.local_dir)
|
||||
@@ -179,7 +173,7 @@ class TestGitHandler(unittest.TestCase):
|
||||
# Clear the executed commands to start fresh
|
||||
self.executed_commands = []
|
||||
|
||||
ref = await self.git_handler._get_valid_ref()
|
||||
ref = self.git_handler._get_valid_ref()
|
||||
self.assertIsNotNone(ref)
|
||||
|
||||
# Check that the refs were checked in the correct order
|
||||
@@ -202,7 +196,7 @@ class TestGitHandler(unittest.TestCase):
|
||||
result = self._execute_command(f'git rev-parse --verify {ref}', self.local_dir)
|
||||
self.assertEqual(result.exit_code, 0)
|
||||
|
||||
async def test_get_valid_ref_without_origin(self):
|
||||
def test_get_valid_ref_without_origin(self):
|
||||
"""Test that _get_valid_ref falls back to empty tree ref when there's no origin."""
|
||||
# Create a new directory with a git repo but no origin
|
||||
no_origin_dir = os.path.join(self.test_dir, 'no-origin')
|
||||
@@ -213,20 +207,18 @@ class TestGitHandler(unittest.TestCase):
|
||||
self._execute_command("git config user.email 'test@example.com'", no_origin_dir)
|
||||
self._execute_command("git config user.name 'Test User'", no_origin_dir)
|
||||
|
||||
# Create a file and commit it using subprocess
|
||||
file_path = os.path.join(no_origin_dir, 'file1.txt')
|
||||
self._execute_command(
|
||||
f'echo "Content in repo without origin" > {file_path}', no_origin_dir
|
||||
)
|
||||
# Create a file and commit it
|
||||
with open(os.path.join(no_origin_dir, 'file1.txt'), 'w') as f:
|
||||
f.write('Content in repo without origin')
|
||||
self._execute_command('git add file1.txt', no_origin_dir)
|
||||
self._execute_command("git commit -m 'Initial commit'", no_origin_dir)
|
||||
|
||||
# Create a custom GitHandler with a modified _get_default_branch method for this test
|
||||
class TestGitHandler(GitHandler):
|
||||
async def _get_default_branch(self) -> str:
|
||||
def _get_default_branch(self) -> str:
|
||||
# Override to handle repos without origin
|
||||
try:
|
||||
return await super()._get_default_branch()
|
||||
return super()._get_default_branch()
|
||||
except IndexError:
|
||||
return 'main' # Default fallback
|
||||
|
||||
@@ -237,7 +229,7 @@ class TestGitHandler(unittest.TestCase):
|
||||
# Clear the executed commands to start fresh
|
||||
self.executed_commands = []
|
||||
|
||||
ref = await no_origin_handler._get_valid_ref()
|
||||
ref = no_origin_handler._get_valid_ref()
|
||||
|
||||
# Verify that git commands were executed
|
||||
self.assertTrue(
|
||||
@@ -259,9 +251,9 @@ class TestGitHandler(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(result.exit_code, 0)
|
||||
|
||||
async def test_get_ref_content(self):
|
||||
def test_get_ref_content(self):
|
||||
"""Test that _get_ref_content returns the content from a valid ref."""
|
||||
content = await self.git_handler._get_ref_content('file1.txt')
|
||||
content = self.git_handler._get_ref_content('file1.txt')
|
||||
self.assertEqual(content.strip(), 'Modified content')
|
||||
|
||||
# Should have called _get_valid_ref and then git show
|
||||
@@ -270,9 +262,9 @@ class TestGitHandler(unittest.TestCase):
|
||||
]
|
||||
self.assertTrue(any('file1.txt' in cmd for cmd in show_commands))
|
||||
|
||||
async def test_get_current_file_content(self):
|
||||
def test_get_current_file_content(self):
|
||||
"""Test that _get_current_file_content returns the current content of a file."""
|
||||
content = await self.git_handler._get_current_file_content('file1.txt')
|
||||
content = self.git_handler._get_current_file_content('file1.txt')
|
||||
self.assertEqual(content.strip(), 'Modified content again')
|
||||
|
||||
# Verify the command was executed
|
||||
@@ -280,15 +272,14 @@ class TestGitHandler(unittest.TestCase):
|
||||
any(cmd == 'cat file1.txt' for cmd, _ in self.executed_commands)
|
||||
)
|
||||
|
||||
async def test_get_changed_files(self):
|
||||
def test_get_changed_files(self):
|
||||
"""Test that _get_changed_files returns the list of changed files."""
|
||||
# Let's create a new file to ensure it shows up in the diff
|
||||
# Use subprocess directly to create and add the file
|
||||
file_path = os.path.join(self.local_dir, 'new_file.txt')
|
||||
self._execute_command(f'echo "New file content" > {file_path}', self.local_dir)
|
||||
with open(os.path.join(self.local_dir, 'new_file.txt'), 'w') as f:
|
||||
f.write('New file content')
|
||||
self._execute_command('git add new_file.txt', self.local_dir)
|
||||
|
||||
files = await self.git_handler._get_changed_files()
|
||||
files = self.git_handler._get_changed_files()
|
||||
self.assertTrue(files)
|
||||
|
||||
# Should include file1.txt (modified) and file3.txt (deleted)
|
||||
@@ -304,15 +295,13 @@ class TestGitHandler(unittest.TestCase):
|
||||
]
|
||||
self.assertTrue(diff_commands)
|
||||
|
||||
async def test_get_untracked_files(self):
|
||||
def test_get_untracked_files(self):
|
||||
"""Test that _get_untracked_files returns the list of untracked files."""
|
||||
# Create an untracked file using subprocess
|
||||
file_path = os.path.join(self.local_dir, 'untracked.txt')
|
||||
self._execute_command(
|
||||
f'echo "Untracked file content" > {file_path}', self.local_dir
|
||||
)
|
||||
# Create an untracked file
|
||||
with open(os.path.join(self.local_dir, 'untracked.txt'), 'w') as f:
|
||||
f.write('Untracked file content')
|
||||
|
||||
files = await self.git_handler._get_untracked_files()
|
||||
files = self.git_handler._get_untracked_files()
|
||||
self.assertEqual(len(files), 1)
|
||||
self.assertEqual(files[0]['path'], 'untracked.txt')
|
||||
self.assertEqual(files[0]['status'], 'A')
|
||||
@@ -325,22 +314,18 @@ class TestGitHandler(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
async def test_get_git_changes(self):
|
||||
def test_get_git_changes(self):
|
||||
"""Test that get_git_changes returns the combined list of changed and untracked files."""
|
||||
# Create an untracked file using subprocess
|
||||
file_path = os.path.join(self.local_dir, 'untracked.txt')
|
||||
self._execute_command(
|
||||
f'echo "Untracked file content" > {file_path}', self.local_dir
|
||||
)
|
||||
# Create an untracked file
|
||||
with open(os.path.join(self.local_dir, 'untracked.txt'), 'w') as f:
|
||||
f.write('Untracked file content')
|
||||
|
||||
# Create a new file and stage it
|
||||
file_path2 = os.path.join(self.local_dir, 'new_file2.txt')
|
||||
self._execute_command(
|
||||
f'echo "New file 2 content" > {file_path2}', self.local_dir
|
||||
)
|
||||
with open(os.path.join(self.local_dir, 'new_file2.txt'), 'w') as f:
|
||||
f.write('New file 2 content')
|
||||
self._execute_command('git add new_file2.txt', self.local_dir)
|
||||
|
||||
changes = await self.git_handler.get_git_changes()
|
||||
changes = self.git_handler.get_git_changes()
|
||||
self.assertIsNotNone(changes)
|
||||
|
||||
# Should include file1.txt (modified), file3.txt (deleted), new_file2.txt (added), and untracked.txt (untracked)
|
||||
@@ -356,9 +341,9 @@ class TestGitHandler(unittest.TestCase):
|
||||
self.assertIn('A', statuses) # Added
|
||||
self.assertIn('D', statuses) # Deleted
|
||||
|
||||
async def test_get_git_diff(self):
|
||||
def test_get_git_diff(self):
|
||||
"""Test that get_git_diff returns the original and modified content of a file."""
|
||||
diff = await self.git_handler.get_git_diff('file1.txt')
|
||||
diff = self.git_handler.get_git_diff('file1.txt')
|
||||
self.assertEqual(diff['modified'].strip(), 'Modified content again')
|
||||
self.assertEqual(diff['original'].strip(), 'Modified content')
|
||||
|
||||
@@ -375,6 +360,4 @@ class TestGitHandler(unittest.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
asyncio.run(unittest.main())
|
||||
unittest.main()
|
||||
|
||||
+313
-334
@@ -1,8 +1,6 @@
|
||||
"""Tests for the custom secrets API endpoints."""
|
||||
# flake8: noqa: E501
|
||||
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
@@ -12,8 +10,6 @@ from pydantic import SecretStr
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
from openhands.server.routes.settings import app as settings_app
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
from openhands.storage.settings.file_settings_store import FileSettingsStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -24,128 +20,127 @@ def test_client():
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_file_settings_store():
|
||||
store = FileSettingsStore(InMemoryFileStore())
|
||||
with patch(
|
||||
'openhands.storage.settings.file_settings_store.FileSettingsStore.get_instance',
|
||||
AsyncMock(return_value=store),
|
||||
):
|
||||
yield store
|
||||
@pytest.fixture
|
||||
def mock_settings_store():
|
||||
with patch('openhands.server.routes.settings.SettingsStoreImpl') as mock:
|
||||
store_instance = MagicMock()
|
||||
mock.get_instance = AsyncMock(return_value=store_instance)
|
||||
store_instance.load = AsyncMock()
|
||||
store_instance.store = AsyncMock()
|
||||
yield store_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_convert_to_settings():
|
||||
with patch('openhands.server.routes.settings.convert_to_settings') as mock:
|
||||
# Make the mock function pass through the input settings
|
||||
mock.side_effect = lambda settings: settings
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_user_id():
|
||||
with patch('openhands.server.routes.settings.get_user_id') as mock:
|
||||
mock.return_value = 'test-user'
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_secrets_names(test_client):
|
||||
async def test_load_custom_secrets_names(test_client, mock_settings_store):
|
||||
"""Test loading custom secrets names."""
|
||||
with patch_file_settings_store() as file_settings_store:
|
||||
# Create initial settings with custom secrets
|
||||
custom_secrets = {
|
||||
'API_KEY': SecretStr('api-key-value'),
|
||||
'DB_PASSWORD': SecretStr('db-password-value'),
|
||||
}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
# Create initial settings with custom secrets
|
||||
custom_secrets = {
|
||||
'API_KEY': SecretStr('api-key-value'),
|
||||
'DB_PASSWORD': SecretStr('db-password-value'),
|
||||
}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
|
||||
# Store the initial settings
|
||||
await file_settings_store.store(initial_settings)
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# Make the GET request
|
||||
response = test_client.get('/api/secrets')
|
||||
assert response.status_code == 200
|
||||
# Make the GET request
|
||||
response = test_client.get('/api/secrets')
|
||||
assert response.status_code == 200
|
||||
|
||||
# Check the response
|
||||
data = response.json()
|
||||
assert 'custom_secrets' in data
|
||||
assert sorted(data['custom_secrets']) == ['API_KEY', 'DB_PASSWORD']
|
||||
|
||||
# Verify that the original settings were not modified
|
||||
stored_settings = await file_settings_store.load()
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'api-key-value'
|
||||
)
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets[
|
||||
'DB_PASSWORD'
|
||||
].get_secret_value()
|
||||
== 'db-password-value'
|
||||
)
|
||||
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
|
||||
# Check the response
|
||||
data = response.json()
|
||||
assert 'custom_secrets' in data
|
||||
assert sorted(data['custom_secrets']) == ['API_KEY', 'DB_PASSWORD']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_custom_secrets_names_empty(test_client):
|
||||
async def test_load_custom_secrets_names_empty(test_client, mock_settings_store):
|
||||
"""Test loading custom secrets names when there are no custom secrets."""
|
||||
with patch_file_settings_store() as file_settings_store:
|
||||
# Create initial settings with no custom secrets
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(provider_tokens=provider_tokens)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
# Create initial settings with no custom secrets
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(provider_tokens=provider_tokens)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
|
||||
# Store the initial settings
|
||||
await file_settings_store.store(initial_settings)
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# Make the GET request
|
||||
response = test_client.get('/api/secrets')
|
||||
assert response.status_code == 200
|
||||
# Make the GET request
|
||||
response = test_client.get('/api/secrets')
|
||||
assert response.status_code == 200
|
||||
|
||||
# Check the response
|
||||
data = response.json()
|
||||
assert 'custom_secrets' in data
|
||||
assert data['custom_secrets'] == []
|
||||
# Check the response
|
||||
data = response.json()
|
||||
assert 'custom_secrets' in data
|
||||
assert data['custom_secrets'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_custom_secret(test_client):
|
||||
async def test_add_custom_secret(
|
||||
test_client, mock_settings_store, mock_convert_to_settings
|
||||
):
|
||||
"""Test adding a new custom secret."""
|
||||
# Create initial settings with provider tokens but no custom secrets
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(provider_tokens=provider_tokens)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
|
||||
with patch_file_settings_store() as file_settings_store:
|
||||
# Create initial settings with provider tokens but no custom secrets
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(provider_tokens=provider_tokens)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# Store the initial settings
|
||||
await file_settings_store.store(initial_settings)
|
||||
# Make the POST request to add a custom secret
|
||||
add_secret_data = {'custom_secrets': {'API_KEY': 'api-key-value'}}
|
||||
response = test_client.post('/api/secrets', json=add_secret_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Make the POST request to add a custom secret
|
||||
add_secret_data = {'custom_secrets': {'API_KEY': 'api-key-value'}}
|
||||
response = test_client.post('/api/secrets', json=add_secret_data)
|
||||
assert response.status_code == 200
|
||||
# Verify that the settings were stored with the new secret
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
|
||||
# Verify that the settings were stored with the new secret
|
||||
stored_settings = await file_settings_store.load()
|
||||
|
||||
# Check that the secret was added
|
||||
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'api-key-value'
|
||||
)
|
||||
# Check that the secret was added
|
||||
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'api-key-value'
|
||||
)
|
||||
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
@@ -154,274 +149,258 @@ async def test_add_custom_secret(test_client):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_existing_custom_secret(test_client):
|
||||
async def test_update_existing_custom_secret(
|
||||
test_client, mock_settings_store, mock_convert_to_settings
|
||||
):
|
||||
"""Test updating an existing custom secret."""
|
||||
with patch_file_settings_store() as file_settings_store:
|
||||
# Create initial settings with a custom secret
|
||||
custom_secrets = {'API_KEY': SecretStr('old-api-key')}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
# Create initial settings with a custom secret
|
||||
custom_secrets = {'API_KEY': SecretStr('old-api-key')}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
|
||||
# Store the initial settings
|
||||
await file_settings_store.store(initial_settings)
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# Make the POST request to update the custom secret
|
||||
update_secret_data = {'custom_secrets': {'API_KEY': 'new-api-key'}}
|
||||
response = test_client.post('/api/secrets', json=update_secret_data)
|
||||
assert response.status_code == 200
|
||||
# Make the POST request to update the custom secret
|
||||
update_secret_data = {'custom_secrets': {'API_KEY': 'new-api-key'}}
|
||||
response = test_client.post('/api/secrets', json=update_secret_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the settings were stored with the updated secret
|
||||
stored_settings = await file_settings_store.load()
|
||||
# Verify that the settings were stored with the updated secret
|
||||
stored_settings: Settings = mock_settings_store.store.call_args[0][0]
|
||||
|
||||
# Check that the secret was updated
|
||||
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'new-api-key'
|
||||
)
|
||||
# Check that the secret was updated
|
||||
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'new-api-key'
|
||||
)
|
||||
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_multiple_custom_secrets(test_client):
|
||||
async def test_add_multiple_custom_secrets(
|
||||
test_client, mock_settings_store, mock_convert_to_settings
|
||||
):
|
||||
"""Test adding multiple custom secrets at once."""
|
||||
with patch_file_settings_store() as file_settings_store:
|
||||
# Create initial settings with one custom secret
|
||||
custom_secrets = {'EXISTING_SECRET': SecretStr('existing-value')}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
# Create initial settings with one custom secret
|
||||
custom_secrets = {'EXISTING_SECRET': SecretStr('existing-value')}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# Make the POST request to add multiple custom secrets
|
||||
add_secrets_data = {
|
||||
'custom_secrets': {
|
||||
'API_KEY': 'api-key-value',
|
||||
'DB_PASSWORD': 'db-password-value',
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
}
|
||||
response = test_client.post('/api/secrets', json=add_secrets_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Store the initial settings
|
||||
await file_settings_store.store(initial_settings)
|
||||
# Verify that the settings were stored with the new secrets
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
|
||||
# Make the POST request to add multiple custom secrets
|
||||
add_secrets_data = {
|
||||
'custom_secrets': {
|
||||
'API_KEY': 'api-key-value',
|
||||
'DB_PASSWORD': 'db-password-value',
|
||||
}
|
||||
}
|
||||
response = test_client.post('/api/secrets', json=add_secrets_data)
|
||||
assert response.status_code == 200
|
||||
# Check that the new secrets were added
|
||||
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'api-key-value'
|
||||
)
|
||||
assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['DB_PASSWORD'].get_secret_value()
|
||||
== 'db-password-value'
|
||||
)
|
||||
|
||||
# Verify that the settings were stored with the new secrets
|
||||
stored_settings = await file_settings_store.load()
|
||||
# Check that existing secrets were preserved
|
||||
assert 'EXISTING_SECRET' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets[
|
||||
'EXISTING_SECRET'
|
||||
].get_secret_value()
|
||||
== 'existing-value'
|
||||
)
|
||||
|
||||
# Check that the new secrets were added
|
||||
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'api-key-value'
|
||||
)
|
||||
assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets[
|
||||
'DB_PASSWORD'
|
||||
].get_secret_value()
|
||||
== 'db-password-value'
|
||||
)
|
||||
|
||||
# Check that existing secrets were preserved
|
||||
assert 'EXISTING_SECRET' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets[
|
||||
'EXISTING_SECRET'
|
||||
].get_secret_value()
|
||||
== 'existing-value'
|
||||
)
|
||||
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_custom_secret(test_client):
|
||||
async def test_delete_custom_secret(
|
||||
test_client, mock_settings_store, mock_convert_to_settings
|
||||
):
|
||||
"""Test deleting a custom secret."""
|
||||
with patch_file_settings_store() as file_settings_store:
|
||||
# Create initial settings with multiple custom secrets
|
||||
custom_secrets = {
|
||||
'API_KEY': SecretStr('api-key-value'),
|
||||
'DB_PASSWORD': SecretStr('db-password-value'),
|
||||
}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
# Create initial settings with multiple custom secrets
|
||||
custom_secrets = {
|
||||
'API_KEY': SecretStr('api-key-value'),
|
||||
'DB_PASSWORD': SecretStr('db-password-value'),
|
||||
}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
|
||||
# Store the initial settings
|
||||
await file_settings_store.store(initial_settings)
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# Make the DELETE request to delete a custom secret
|
||||
response = test_client.delete('/api/secrets/API_KEY')
|
||||
assert response.status_code == 200
|
||||
# Make the DELETE request to delete a custom secret
|
||||
response = test_client.delete('/api/secrets/API_KEY')
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the settings were stored without the deleted secret
|
||||
stored_settings = await file_settings_store.load()
|
||||
# Verify that the settings were stored without the deleted secret
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
|
||||
# Check that the specified secret was deleted
|
||||
assert 'API_KEY' not in stored_settings.secrets_store.custom_secrets
|
||||
# Check that the specified secret was deleted
|
||||
assert 'API_KEY' not in stored_settings.secrets_store.custom_secrets
|
||||
|
||||
# Check that other secrets were preserved
|
||||
assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets[
|
||||
'DB_PASSWORD'
|
||||
].get_secret_value()
|
||||
== 'db-password-value'
|
||||
)
|
||||
# Check that other secrets were preserved
|
||||
assert 'DB_PASSWORD' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['DB_PASSWORD'].get_secret_value()
|
||||
== 'db-password-value'
|
||||
)
|
||||
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_custom_secret(test_client):
|
||||
async def test_delete_nonexistent_custom_secret(
|
||||
test_client, mock_settings_store, mock_convert_to_settings
|
||||
):
|
||||
"""Test deleting a custom secret that doesn't exist."""
|
||||
with patch_file_settings_store() as file_settings_store:
|
||||
# Create initial settings with a custom secret
|
||||
custom_secrets = {'API_KEY': SecretStr('api-key-value')}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
# Create initial settings with a custom secret
|
||||
custom_secrets = {'API_KEY': SecretStr('api-key-value')}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token'))
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
|
||||
# Store the initial settings
|
||||
await file_settings_store.store(initial_settings)
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# Make the DELETE request to delete a nonexistent custom secret
|
||||
response = test_client.delete('/api/secrets/NONEXISTENT_KEY')
|
||||
assert response.status_code == 200
|
||||
# Make the DELETE request to delete a nonexistent custom secret
|
||||
response = test_client.delete('/api/secrets/NONEXISTENT_KEY')
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the settings were stored without changes to existing secrets
|
||||
stored_settings = await file_settings_store.load()
|
||||
# Verify that the settings were stored without changes to existing secrets
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
|
||||
# Check that the existing secret was preserved
|
||||
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'api-key-value'
|
||||
)
|
||||
# Check that the existing secret was preserved
|
||||
assert 'API_KEY' in stored_settings.secrets_store.custom_secrets
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets['API_KEY'].get_secret_value()
|
||||
== 'api-key-value'
|
||||
)
|
||||
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
|
||||
# Check that other settings were preserved
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_secrets_operations_preserve_settings(test_client):
|
||||
async def test_custom_secrets_operations_preserve_settings(
|
||||
test_client, mock_settings_store, mock_convert_to_settings
|
||||
):
|
||||
"""Test that operations on custom secrets preserve all other settings."""
|
||||
with patch_file_settings_store() as file_settings_store:
|
||||
# Create initial settings with comprehensive data
|
||||
custom_secrets = {'INITIAL_SECRET': SecretStr('initial-value')}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')),
|
||||
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab-token')),
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
max_iterations=100,
|
||||
security_analyzer='default',
|
||||
confirmation_mode=True,
|
||||
llm_model='test-model',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
llm_base_url='https://test.com',
|
||||
remote_runtime_resource_factor=2,
|
||||
enable_default_condenser=True,
|
||||
enable_sound_notifications=False,
|
||||
user_consents_to_analytics=True,
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
# Create initial settings with comprehensive data
|
||||
custom_secrets = {'INITIAL_SECRET': SecretStr('initial-value')}
|
||||
provider_tokens = {
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github-token')),
|
||||
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab-token')),
|
||||
}
|
||||
secret_store = SecretStore(
|
||||
custom_secrets=custom_secrets, provider_tokens=provider_tokens
|
||||
)
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
max_iterations=100,
|
||||
security_analyzer='default',
|
||||
confirmation_mode=True,
|
||||
llm_model='test-model',
|
||||
llm_api_key=SecretStr('test-llm-key'),
|
||||
llm_base_url='https://test.com',
|
||||
remote_runtime_resource_factor=2,
|
||||
enable_default_condenser=True,
|
||||
enable_sound_notifications=False,
|
||||
user_consents_to_analytics=True,
|
||||
secrets_store=secret_store,
|
||||
)
|
||||
|
||||
# Store the initial settings
|
||||
await file_settings_store.store(initial_settings)
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# 1. Test adding a new custom secret
|
||||
add_secret_data = {'custom_secrets': {'NEW_SECRET': 'new-value'}}
|
||||
response = test_client.post('/api/secrets', json=add_secret_data)
|
||||
assert response.status_code == 200
|
||||
# 1. Test adding a new custom secret
|
||||
add_secret_data = {'custom_secrets': {'NEW_SECRET': 'new-value'}}
|
||||
response = test_client.post('/api/secrets', json=add_secret_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify all settings are preserved
|
||||
stored_settings = await file_settings_store.load()
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.max_iterations == 100
|
||||
assert stored_settings.security_analyzer == 'default'
|
||||
assert stored_settings.confirmation_mode is True
|
||||
assert stored_settings.llm_model == 'test-model'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
assert stored_settings.llm_base_url == 'https://test.com'
|
||||
assert stored_settings.remote_runtime_resource_factor == 2
|
||||
assert stored_settings.enable_default_condenser is True
|
||||
assert stored_settings.enable_sound_notifications is False
|
||||
assert stored_settings.user_consents_to_analytics is True
|
||||
assert len(stored_settings.secrets_store.provider_tokens) == 2
|
||||
assert ProviderType.GITHUB in stored_settings.secrets_store.provider_tokens
|
||||
assert ProviderType.GITLAB in stored_settings.secrets_store.provider_tokens
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets[
|
||||
'INITIAL_SECRET'
|
||||
].get_secret_value()
|
||||
== 'initial-value'
|
||||
)
|
||||
assert (
|
||||
stored_settings.secrets_store.custom_secrets[
|
||||
'NEW_SECRET'
|
||||
].get_secret_value()
|
||||
== 'new-value'
|
||||
)
|
||||
# Verify all settings are preserved
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.max_iterations == 100
|
||||
assert stored_settings.security_analyzer == 'default'
|
||||
assert stored_settings.confirmation_mode is True
|
||||
assert stored_settings.llm_model == 'test-model'
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-llm-key'
|
||||
assert stored_settings.llm_base_url == 'https://test.com'
|
||||
assert stored_settings.remote_runtime_resource_factor == 2
|
||||
assert stored_settings.enable_default_condenser is True
|
||||
assert stored_settings.enable_sound_notifications is False
|
||||
assert stored_settings.user_consents_to_analytics is True
|
||||
assert len(stored_settings.secrets_store.provider_tokens) == 2
|
||||
|
||||
# 2. Test updating an existing custom secret
|
||||
update_secret_data = {'custom_secrets': {'INITIAL_SECRET': 'updated-value'}}
|
||||
@@ -429,7 +408,7 @@ async def test_custom_secrets_operations_preserve_settings(test_client):
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify all settings are still preserved
|
||||
stored_settings = await file_settings_store.load()
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.max_iterations == 100
|
||||
@@ -449,7 +428,7 @@ async def test_custom_secrets_operations_preserve_settings(test_client):
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify all settings are still preserved
|
||||
stored_settings = await file_settings_store.load()
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.language == 'en'
|
||||
assert stored_settings.agent == 'test-agent'
|
||||
assert stored_settings.max_iterations == 100
|
||||
|
||||
+226
-57
@@ -1,60 +1,82 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.integrations.provider import ProviderType, SecretStore
|
||||
from openhands.server.app import app
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
class MockUserAuth(UserAuth):
|
||||
"""Mock implementation of UserAuth for testing"""
|
||||
|
||||
def __init__(self):
|
||||
self._settings = None
|
||||
self._settings_store = MagicMock()
|
||||
self._settings_store.load = AsyncMock(return_value=None)
|
||||
self._settings_store.store = AsyncMock()
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return 'test-user'
|
||||
|
||||
async def get_access_token(self) -> SecretStr | None:
|
||||
return SecretStr('test-token')
|
||||
|
||||
async def get_provider_tokens(self) -> dict[ProviderType, ProviderToken] | None: # noqa: E501
|
||||
return None
|
||||
|
||||
async def get_user_settings_store(self) -> SettingsStore | None:
|
||||
return self._settings_store
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
return MockUserAuth()
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client():
|
||||
# Create a test client
|
||||
with patch(
|
||||
'openhands.server.user_auth.user_auth.UserAuth.get_instance',
|
||||
return_value=MockUserAuth(),
|
||||
):
|
||||
with patch(
|
||||
'openhands.server.routes.settings.validate_provider_token',
|
||||
return_value=ProviderType.GITHUB,
|
||||
):
|
||||
client = TestClient(app)
|
||||
yield client
|
||||
def mock_settings_store():
|
||||
with patch('openhands.server.routes.settings.SettingsStoreImpl') as mock:
|
||||
store_instance = MagicMock()
|
||||
mock.get_instance = AsyncMock(return_value=store_instance)
|
||||
store_instance.load = AsyncMock()
|
||||
store_instance.store = AsyncMock()
|
||||
yield store_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_user_id():
|
||||
with patch('openhands.server.routes.settings.get_user_id') as mock:
|
||||
mock.return_value = 'test-user'
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_validate_provider_token():
|
||||
with patch('openhands.server.routes.settings.validate_provider_token') as mock:
|
||||
|
||||
async def mock_determine(*args, **kwargs):
|
||||
return ProviderType.GITHUB
|
||||
|
||||
mock.side_effect = mock_determine
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client(mock_settings_store):
|
||||
# Mock the middleware that adds github_token
|
||||
class MockMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
settings = mock_settings_store.load.return_value
|
||||
token = None
|
||||
if settings and settings.secrets_store.provider_tokens.get(
|
||||
ProviderType.GITHUB
|
||||
):
|
||||
token = settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token
|
||||
if scope['type'] == 'http':
|
||||
scope['state'] = {'token': token}
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
# Replace the middleware
|
||||
app.middleware_stack = None # Clear existing middleware
|
||||
app.add_middleware(MockMiddleware)
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_service():
|
||||
with patch('openhands.server.routes.settings.GitHubService') as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_api_endpoints(test_client):
|
||||
"""Test that the settings API endpoints work with the new auth system"""
|
||||
async def test_settings_api_runtime_factor(
|
||||
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
|
||||
):
|
||||
# Mock the settings store to return None initially (no existing settings)
|
||||
mock_settings_store.load.return_value = None
|
||||
|
||||
# Test data with remote_runtime_resource_factor
|
||||
settings_data = {
|
||||
@@ -70,29 +92,176 @@ async def test_settings_api_endpoints(test_client):
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# The test_client fixture already handles authentication
|
||||
|
||||
# Make the POST request to store settings
|
||||
response = test_client.post('/api/settings', json=settings_data)
|
||||
|
||||
# We're not checking the exact response, just that it doesn't error
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test the GET settings endpoint
|
||||
# Verify the settings were stored with the correct runtime factor
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.remote_runtime_resource_factor == 2
|
||||
|
||||
# Mock settings store to return our settings for the GET request
|
||||
mock_settings_store.load.return_value = Settings(**settings_data)
|
||||
|
||||
# Make a GET request to retrieve settings
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
assert response.json()['remote_runtime_resource_factor'] == 2
|
||||
|
||||
# Verify that the sandbox config gets updated when settings are loaded
|
||||
with patch('openhands.server.shared.config') as mock_config:
|
||||
mock_config.sandbox = SandboxConfig()
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the sandbox config was updated with the new value
|
||||
mock_settings_store.store.assert_called()
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.remote_runtime_resource_factor == 2
|
||||
|
||||
assert isinstance(stored_settings.llm_api_key, SecretStr)
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_llm_api_key(
|
||||
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
|
||||
):
|
||||
# Mock the settings store to return None initially (no existing settings)
|
||||
mock_settings_store.load.return_value = None
|
||||
|
||||
# Test data with remote_runtime_resource_factor
|
||||
settings_data = {
|
||||
'llm_api_key': 'test-key',
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# The test_client fixture already handles authentication
|
||||
|
||||
# Make the POST request to store settings
|
||||
response = test_client.post('/api/settings', json=settings_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the settings were stored with the correct secret API key
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert isinstance(stored_settings.llm_api_key, SecretStr)
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
|
||||
|
||||
# Mock settings store to return our settings for the GET request
|
||||
mock_settings_store.load.return_value = Settings(**settings_data)
|
||||
|
||||
# Make a GET request to retrieve settings
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test updating with partial settings
|
||||
partial_settings = {
|
||||
'language': 'fr',
|
||||
'llm_model': None, # Should preserve existing value
|
||||
'llm_api_key': None, # Should preserve existing value
|
||||
# We should never expose the API key in the response
|
||||
assert 'test-key' not in response.json()
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason='Mock middleware does not seem to properly set the github_token'
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_api_set_github_token(
|
||||
mock_github_service,
|
||||
test_client,
|
||||
mock_settings_store,
|
||||
mock_get_user_id,
|
||||
mock_validate_provider_token,
|
||||
):
|
||||
# Test data with provider token set
|
||||
settings_data = {
|
||||
'language': 'en',
|
||||
'agent': 'test-agent',
|
||||
'max_iterations': 100,
|
||||
'security_analyzer': 'default',
|
||||
'confirmation_mode': True,
|
||||
'llm_model': 'test-model',
|
||||
'llm_api_key': 'test-key',
|
||||
'llm_base_url': 'https://test.com',
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
response = test_client.post('/api/settings', json=partial_settings)
|
||||
# Make the POST request to store settings
|
||||
response = test_client.post('/api/settings', json=settings_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test the unset-settings-tokens endpoint
|
||||
response = test_client.post('/api/unset-settings-tokens')
|
||||
# Verify the settings were stored with the provider token
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert (
|
||||
stored_settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test-token'
|
||||
)
|
||||
|
||||
# Mock settings store to return our settings for the GET request
|
||||
mock_settings_store.load.return_value = Settings(**settings_data)
|
||||
|
||||
# Make a GET request to retrieve settings
|
||||
response = test_client.get('/api/settings')
|
||||
data = response.json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data.get('token') is None
|
||||
assert data['token_is_set'] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_preserve_llm_fields_when_none(test_client, mock_settings_store):
|
||||
# Setup initial settings with LLM fields populated
|
||||
initial_settings = Settings(
|
||||
language='en',
|
||||
agent='test-agent',
|
||||
max_iterations=100,
|
||||
security_analyzer='default',
|
||||
confirmation_mode=True,
|
||||
llm_model='existing-model',
|
||||
llm_api_key=SecretStr('existing-key'),
|
||||
llm_base_url='https://existing.com',
|
||||
secrets_store=SecretStore(),
|
||||
)
|
||||
|
||||
# Mock the settings store to return our initial settings
|
||||
mock_settings_store.load.return_value = initial_settings
|
||||
|
||||
# Test data with None values for LLM fields
|
||||
settings_update = {
|
||||
'language': 'fr', # Change something else to verify the update happens
|
||||
'llm_model': None,
|
||||
'llm_api_key': None,
|
||||
'llm_base_url': None,
|
||||
}
|
||||
|
||||
# Make the POST request to update settings
|
||||
response = test_client.post('/api/settings', json=settings_update)
|
||||
assert response.status_code == 200
|
||||
|
||||
# We'll skip the secrets endpoints for now as they require more complex mocking # noqa: E501
|
||||
# and they're not directly related to the authentication refactoring
|
||||
# Verify that the settings were stored with preserved LLM values
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
|
||||
# Check that language was updated
|
||||
assert stored_settings.language == 'fr'
|
||||
|
||||
# Check that LLM fields were preserved and not cleared
|
||||
assert stored_settings.llm_model == 'existing-model'
|
||||
assert isinstance(stored_settings.llm_api_key, SecretStr)
|
||||
assert stored_settings.llm_api_key.get_secret_value() == 'existing-key'
|
||||
assert stored_settings.llm_base_url == 'https://existing.com'
|
||||
|
||||
# Update the mock to return our new settings for the GET request
|
||||
mock_settings_store.load.return_value = stored_settings
|
||||
|
||||
# Make a GET request to verify the updated settings
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify fields in the response
|
||||
assert data['language'] == 'fr'
|
||||
assert data['llm_model'] == 'existing-model'
|
||||
assert data['llm_base_url'] == 'https://existing.com'
|
||||
# We expect the API key not to be included in the response
|
||||
assert 'test-key' not in str(response.content)
|
||||
|
||||
@@ -23,6 +23,7 @@ async def get_settings_store(request):
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_provider_tokens_valid():
|
||||
"""Test check_provider_tokens with valid tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'github': 'valid-token'})
|
||||
|
||||
# Mock the validate_provider_token function to return GITHUB for valid tokens
|
||||
@@ -31,7 +32,7 @@ async def test_check_provider_tokens_valid():
|
||||
) as mock_validate:
|
||||
mock_validate.return_value = ProviderType.GITHUB
|
||||
|
||||
result = await check_provider_tokens(settings)
|
||||
result = await check_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return empty string for valid token
|
||||
assert result == ''
|
||||
@@ -41,6 +42,7 @@ async def test_check_provider_tokens_valid():
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_provider_tokens_invalid():
|
||||
"""Test check_provider_tokens with invalid tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'github': 'invalid-token'})
|
||||
|
||||
# Mock the validate_provider_token function to return None for invalid tokens
|
||||
@@ -49,7 +51,7 @@ async def test_check_provider_tokens_invalid():
|
||||
) as mock_validate:
|
||||
mock_validate.return_value = None
|
||||
|
||||
result = await check_provider_tokens(settings)
|
||||
result = await check_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return error message for invalid token
|
||||
assert 'Invalid token' in result
|
||||
@@ -59,9 +61,10 @@ async def test_check_provider_tokens_invalid():
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_provider_tokens_wrong_type():
|
||||
"""Test check_provider_tokens with unsupported provider type."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'unsupported': 'some-token'})
|
||||
|
||||
result = await check_provider_tokens(settings)
|
||||
result = await check_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return empty string for unsupported provider
|
||||
assert result == ''
|
||||
@@ -70,9 +73,10 @@ async def test_check_provider_tokens_wrong_type():
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_provider_tokens_no_tokens():
|
||||
"""Test check_provider_tokens with no tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={})
|
||||
|
||||
result = await check_provider_tokens(settings)
|
||||
result = await check_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return empty string when no tokens provided
|
||||
assert result == ''
|
||||
@@ -82,6 +86,7 @@ async def test_check_provider_tokens_no_tokens():
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_llm_settings_new_settings():
|
||||
"""Test store_llm_settings with new settings."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(
|
||||
llm_model='gpt-4',
|
||||
llm_api_key='test-api-key',
|
||||
@@ -89,20 +94,25 @@ async def test_store_llm_settings_new_settings():
|
||||
)
|
||||
|
||||
# Mock the settings store
|
||||
mock_store = MagicMock()
|
||||
mock_store.load = AsyncMock(return_value=None) # No existing settings
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
mock_store.load = AsyncMock(return_value=None) # No existing settings
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_llm_settings(settings, mock_store)
|
||||
result = await store_llm_settings(mock_request, settings)
|
||||
|
||||
# Should return settings with the provided values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
assert result.llm_base_url == 'https://api.example.com'
|
||||
# Should return settings with the provided values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
assert result.llm_base_url == 'https://api.example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_llm_settings_update_existing():
|
||||
"""Test store_llm_settings updates existing settings."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(
|
||||
llm_model='gpt-4',
|
||||
llm_api_key='new-api-key',
|
||||
@@ -110,118 +120,142 @@ async def test_store_llm_settings_update_existing():
|
||||
)
|
||||
|
||||
# Mock the settings store
|
||||
mock_store = MagicMock()
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Create existing settings
|
||||
existing_settings = Settings(
|
||||
llm_model='gpt-3.5',
|
||||
llm_api_key=SecretStr('old-api-key'),
|
||||
llm_base_url='https://old.example.com',
|
||||
)
|
||||
# Create existing settings
|
||||
existing_settings = Settings(
|
||||
llm_model='gpt-3.5',
|
||||
llm_api_key=SecretStr('old-api-key'),
|
||||
llm_base_url='https://old.example.com',
|
||||
)
|
||||
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_llm_settings(settings, mock_store)
|
||||
result = await store_llm_settings(mock_request, settings)
|
||||
|
||||
# Should return settings with the updated values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
assert result.llm_api_key.get_secret_value() == 'new-api-key'
|
||||
assert result.llm_base_url == 'https://new.example.com'
|
||||
# Should return settings with the updated values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
assert result.llm_api_key.get_secret_value() == 'new-api-key'
|
||||
assert result.llm_base_url == 'https://new.example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_llm_settings_partial_update():
|
||||
"""Test store_llm_settings with partial update."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(
|
||||
llm_model='gpt-4' # Only updating model
|
||||
)
|
||||
|
||||
# Mock the settings store
|
||||
mock_store = MagicMock()
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Create existing settings
|
||||
existing_settings = Settings(
|
||||
llm_model='gpt-3.5',
|
||||
llm_api_key=SecretStr('existing-api-key'),
|
||||
llm_base_url='https://existing.example.com',
|
||||
)
|
||||
# Create existing settings
|
||||
existing_settings = Settings(
|
||||
llm_model='gpt-3.5',
|
||||
llm_api_key=SecretStr('existing-api-key'),
|
||||
llm_base_url='https://existing.example.com',
|
||||
)
|
||||
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_llm_settings(settings, mock_store)
|
||||
result = await store_llm_settings(mock_request, settings)
|
||||
|
||||
# Should return settings with updated model but keep other values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
# For SecretStr objects, we need to compare the secret value
|
||||
assert result.llm_api_key.get_secret_value() == 'existing-api-key'
|
||||
assert result.llm_base_url == 'https://existing.example.com'
|
||||
# Should return settings with updated model but keep other values
|
||||
assert result.llm_model == 'gpt-4'
|
||||
# For SecretStr objects, we need to compare the secret value
|
||||
assert result.llm_api_key.get_secret_value() == 'existing-api-key'
|
||||
assert result.llm_base_url == 'https://existing.example.com'
|
||||
|
||||
|
||||
# Tests for store_provider_tokens
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_provider_tokens_new_tokens():
|
||||
"""Test store_provider_tokens with new tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'github': 'new-token'})
|
||||
|
||||
# Mock the settings store
|
||||
mock_store = MagicMock()
|
||||
mock_store.load = AsyncMock(return_value=None) # No existing settings
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
mock_store.load = AsyncMock(return_value=None) # No existing settings
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_provider_tokens(settings, mock_store)
|
||||
result = await store_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return settings with the provided tokens
|
||||
assert result.provider_tokens == {'github': 'new-token'}
|
||||
# Should return settings with the provided tokens
|
||||
assert result.provider_tokens == {'github': 'new-token'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_provider_tokens_update_existing():
|
||||
"""Test store_provider_tokens updates existing tokens."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(provider_tokens={'github': 'updated-token'})
|
||||
|
||||
# Mock the settings store
|
||||
mock_store = MagicMock()
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Create existing settings with a GitHub token
|
||||
github_token = ProviderToken(token=SecretStr('old-token'))
|
||||
provider_tokens = {ProviderType.GITHUB: github_token}
|
||||
# Create existing settings with a GitHub token
|
||||
github_token = ProviderToken(token=SecretStr('old-token'))
|
||||
provider_tokens = {ProviderType.GITHUB: github_token}
|
||||
|
||||
# Create a SecretStore with the provider tokens
|
||||
secrets_store = SecretStore(provider_tokens=provider_tokens)
|
||||
# Create a SecretStore with the provider tokens
|
||||
secrets_store = SecretStore(provider_tokens=provider_tokens)
|
||||
|
||||
# Create existing settings with the secrets store
|
||||
existing_settings = Settings(secrets_store=secrets_store)
|
||||
# Create existing settings with the secrets store
|
||||
existing_settings = Settings(secrets_store=secrets_store)
|
||||
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_provider_tokens(settings, mock_store)
|
||||
result = await store_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return settings with the updated tokens
|
||||
assert result.provider_tokens == {'github': 'updated-token'}
|
||||
# Should return settings with the updated tokens
|
||||
assert result.provider_tokens == {'github': 'updated-token'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_provider_tokens_keep_existing():
|
||||
"""Test store_provider_tokens keeps existing tokens when empty string provided."""
|
||||
mock_request = MagicMock()
|
||||
settings = POSTSettingsModel(
|
||||
provider_tokens={'github': ''} # Empty string should keep existing token
|
||||
)
|
||||
|
||||
# Mock the settings store
|
||||
mock_store = MagicMock()
|
||||
with patch(
|
||||
'openhands.server.routes.settings.SettingsStoreImpl.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_store = MagicMock()
|
||||
|
||||
# Create existing settings with a GitHub token
|
||||
github_token = ProviderToken(token=SecretStr('existing-token'))
|
||||
provider_tokens = {ProviderType.GITHUB: github_token}
|
||||
# Create existing settings with a GitHub token
|
||||
github_token = ProviderToken(token=SecretStr('existing-token'))
|
||||
provider_tokens = {ProviderType.GITHUB: github_token}
|
||||
|
||||
# Create a SecretStore with the provider tokens
|
||||
secrets_store = SecretStore(provider_tokens=provider_tokens)
|
||||
# Create a SecretStore with the provider tokens
|
||||
secrets_store = SecretStore(provider_tokens=provider_tokens)
|
||||
|
||||
# Create existing settings with the secrets store
|
||||
existing_settings = Settings(secrets_store=secrets_store)
|
||||
# Create existing settings with the secrets store
|
||||
existing_settings = Settings(secrets_store=secrets_store)
|
||||
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_store.load = AsyncMock(return_value=existing_settings)
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
result = await store_provider_tokens(settings, mock_store)
|
||||
result = await store_provider_tokens(mock_request, settings)
|
||||
|
||||
# Should return settings with the existing token preserved
|
||||
assert result.provider_tokens == {'github': 'existing-token'}
|
||||
# Should return settings with the existing token preserved
|
||||
assert result.provider_tokens == {'github': 'existing-token'}
|
||||
|
||||
Reference in New Issue
Block a user