From 07e750f03804de645bd92fcdd08b3b10ccf86db3 Mon Sep 17 00:00:00 2001 From: "sp.wack" <83104063+amanape@users.noreply.github.com> Date: Fri, 23 Aug 2024 20:06:15 +0300 Subject: [PATCH] feat(frontend): Improve models input UI/UX in settings (#3530) * Create helper functions * Add map according to litellm docs * Create ModelSelector * Extend model selector * use autocomplete from nextui * Improve keys without providers * Handle models without a provider * Add verified section and some empty handling * Add support for default or previously set models * Update tests * Lint * Remove modifier * Fix typescript error * Functionality for switching to custom model * Add verified models * Respond to resetting to default * Comment --- .../modals/settings/ModelSelector.test.tsx | 193 ++++++++++++++++++ .../modals/settings/ModelSelector.tsx | 133 ++++++++++++ .../modals/settings/SettingsForm.test.tsx | 127 +++++++++--- .../modals/settings/SettingsForm.tsx | 54 ++++- .../modals/settings/SettingsModal.test.tsx | 61 ++++-- .../modals/settings/SettingsModal.tsx | 22 +- frontend/src/services/session.test.ts | 36 ++++ frontend/src/services/session.ts | 10 +- frontend/src/services/settings.test.ts | 23 +++ frontend/src/services/settings.ts | 11 +- .../src/utils/extractModelAndProvider.test.ts | 62 ++++++ frontend/src/utils/extractModelAndProvider.ts | 49 +++++ frontend/src/utils/isNumber.test.ts | 9 + frontend/src/utils/isNumber.ts | 2 + frontend/src/utils/mapProvider.test.ts | 27 +++ frontend/src/utils/mapProvider.ts | 30 +++ .../utils/organizeModelsAndProviders.test.ts | 51 +++++ .../src/utils/organizeModelsAndProviders.ts | 42 ++++ frontend/src/utils/verified-models.ts | 14 ++ 19 files changed, 901 insertions(+), 55 deletions(-) create mode 100644 frontend/src/components/modals/settings/ModelSelector.test.tsx create mode 100644 frontend/src/components/modals/settings/ModelSelector.tsx create mode 100644 frontend/src/utils/extractModelAndProvider.test.ts create mode 100644 frontend/src/utils/extractModelAndProvider.ts create mode 100644 frontend/src/utils/isNumber.test.ts create mode 100644 frontend/src/utils/isNumber.ts create mode 100644 frontend/src/utils/mapProvider.test.ts create mode 100644 frontend/src/utils/mapProvider.ts create mode 100644 frontend/src/utils/organizeModelsAndProviders.test.ts create mode 100644 frontend/src/utils/organizeModelsAndProviders.ts create mode 100644 frontend/src/utils/verified-models.ts diff --git a/frontend/src/components/modals/settings/ModelSelector.test.tsx b/frontend/src/components/modals/settings/ModelSelector.test.tsx new file mode 100644 index 0000000000..2b481c7fc6 --- /dev/null +++ b/frontend/src/components/modals/settings/ModelSelector.test.tsx @@ -0,0 +1,193 @@ +import React from "react"; +import { describe, it, expect, vi } from "vitest"; +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { ModelSelector } from "./ModelSelector"; + +describe("ModelSelector", () => { + const models = { + openai: { + separator: "/", + models: ["gpt-4o", "gpt-3.5-turbo"], + }, + azure: { + separator: "/", + models: ["ada", "gpt-35-turbo"], + }, + vertex_ai: { + separator: "/", + models: ["chat-bison", "chat-bison-32k"], + }, + cohere: { + separator: ".", + models: ["command-r-v1:0"], + }, + }; + + it("should display the provider selector", async () => { + const user = userEvent.setup(); + const onModelChange = vi.fn(); + render(); + + const selector = screen.getByLabelText("Provider"); + expect(selector).toBeInTheDocument(); + + await user.click(selector); + + expect(screen.getByText("OpenAI")).toBeInTheDocument(); + expect(screen.getByText("Azure")).toBeInTheDocument(); + expect(screen.getByText("VertexAI")).toBeInTheDocument(); + expect(screen.getByText("cohere")).toBeInTheDocument(); + }); + + it("should disable the model selector if the provider is not selected", async () => { + const user = userEvent.setup(); + const onModelChange = vi.fn(); + render(); + + const modelSelector = screen.getByLabelText("Model"); + expect(modelSelector).toBeDisabled(); + + const providerSelector = screen.getByLabelText("Provider"); + await user.click(providerSelector); + + const vertexAI = screen.getByText("VertexAI"); + await user.click(vertexAI); + + expect(modelSelector).not.toBeDisabled(); + }); + + it("should display the model selector", async () => { + const user = userEvent.setup(); + const onModelChange = vi.fn(); + render(); + + const providerSelector = screen.getByLabelText("Provider"); + await user.click(providerSelector); + + const azureProvider = screen.getByText("Azure"); + await user.click(azureProvider); + + const modelSelector = screen.getByLabelText("Model"); + await user.click(modelSelector); + + expect(screen.getByText("ada")).toBeInTheDocument(); + expect(screen.getByText("gpt-35-turbo")).toBeInTheDocument(); + + await user.click(providerSelector); + const vertexProvider = screen.getByText("VertexAI"); + await user.click(vertexProvider); + + await user.click(modelSelector); + + expect(screen.getByText("chat-bison")).toBeInTheDocument(); + expect(screen.getByText("chat-bison-32k")).toBeInTheDocument(); + }); + + it("should display the actual litellm model ID as the user is making the selections", async () => { + const user = userEvent.setup(); + const onModelChange = vi.fn(); + render(); + + const id = screen.getByTestId("model-id"); + const providerSelector = screen.getByLabelText("Provider"); + const modelSelector = screen.getByLabelText("Model"); + + expect(id).toHaveTextContent("No model selected"); + + await user.click(providerSelector); + await user.click(screen.getByText("Azure")); + + expect(id).toHaveTextContent("azure/"); + + await user.click(modelSelector); + await user.click(screen.getByText("ada")); + expect(id).toHaveTextContent("azure/ada"); + + await user.click(providerSelector); + await user.click(screen.getByText("cohere")); + expect(id).toHaveTextContent("cohere."); + + await user.click(modelSelector); + await user.click(screen.getByText("command-r-v1:0")); + expect(id).toHaveTextContent("cohere.command-r-v1:0"); + }); + + it("should call onModelChange when the model is changed", async () => { + const user = userEvent.setup(); + const onModelChange = vi.fn(); + render(); + + const providerSelector = screen.getByLabelText("Provider"); + const modelSelector = screen.getByLabelText("Model"); + + await user.click(providerSelector); + await user.click(screen.getByText("Azure")); + + await user.click(modelSelector); + await user.click(screen.getByText("ada")); + + expect(onModelChange).toHaveBeenCalledTimes(1); + expect(onModelChange).toHaveBeenCalledWith("azure/ada"); + + await user.click(modelSelector); + await user.click(screen.getByText("gpt-35-turbo")); + + expect(onModelChange).toHaveBeenCalledTimes(2); + expect(onModelChange).toHaveBeenCalledWith("azure/gpt-35-turbo"); + + await user.click(providerSelector); + await user.click(screen.getByText("cohere")); + + await user.click(modelSelector); + await user.click(screen.getByText("command-r-v1:0")); + + expect(onModelChange).toHaveBeenCalledTimes(3); + expect(onModelChange).toHaveBeenCalledWith("cohere.command-r-v1:0"); + }); + + it("should clear the model ID when the provider is cleared", async () => { + const user = userEvent.setup(); + const onModelChange = vi.fn(); + render(); + + const providerSelector = screen.getByLabelText("Provider"); + const modelSelector = screen.getByLabelText("Model"); + + await user.click(providerSelector); + await user.click(screen.getByText("Azure")); + + await user.click(modelSelector); + await user.click(screen.getByText("ada")); + + expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada"); + + await user.clear(providerSelector); + + expect(screen.getByTestId("model-id")).toHaveTextContent( + "No model selected", + ); + }); + + it("should have a default value if passed", async () => { + const onModelChange = vi.fn(); + render( + , + ); + + expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada"); + expect(screen.getByLabelText("Provider")).toHaveValue("Azure"); + expect(screen.getByLabelText("Model")).toHaveValue("ada"); + }); + + it.todo("should disable provider if isDisabled is true"); + + it.todo( + "should display the verified models in the correct order", + async () => {}, + ); +}); diff --git a/frontend/src/components/modals/settings/ModelSelector.tsx b/frontend/src/components/modals/settings/ModelSelector.tsx new file mode 100644 index 0000000000..f5f7c221d4 --- /dev/null +++ b/frontend/src/components/modals/settings/ModelSelector.tsx @@ -0,0 +1,133 @@ +import { + Autocomplete, + AutocompleteItem, + AutocompleteSection, +} from "@nextui-org/react"; +import React from "react"; +import { mapProvider } from "#/utils/mapProvider"; +import { VERIFIED_MODELS, VERIFIED_PROVIDERS } from "#/utils/verified-models"; +import { extractModelAndProvider } from "#/utils/extractModelAndProvider"; + +interface ModelSelectorProps { + isDisabled?: boolean; + models: Record; + onModelChange: (model: string) => void; + defaultModel?: string; +} + +export function ModelSelector({ + isDisabled, + models, + onModelChange, + defaultModel, +}: ModelSelectorProps) { + const [litellmId, setLitellmId] = React.useState(null); + const [selectedProvider, setSelectedProvider] = React.useState( + null, + ); + const [selectedModel, setSelectedModel] = React.useState(null); + + React.useEffect(() => { + if (defaultModel) { + // runs when resetting to defaults + const { provider, model } = extractModelAndProvider(defaultModel); + + setLitellmId(defaultModel); + setSelectedProvider(provider); + setSelectedModel(model); + } + }, [defaultModel]); + + const handleChangeProvider = (provider: string) => { + setSelectedProvider(provider); + setSelectedModel(null); + + const separator = models[provider]?.separator || ""; + setLitellmId(provider + separator); + }; + + const handleChangeModel = (model: string) => { + const separator = models[selectedProvider || ""]?.separator || ""; + const fullModel = selectedProvider + separator + model; + setLitellmId(fullModel); + onModelChange(fullModel); + setSelectedModel(model); + }; + + const clear = () => { + setSelectedProvider(null); + setLitellmId(null); + }; + + return ( +
+ + {litellmId?.replace("other", "") || "No model selected"} + + +
+ { + if (e?.toString()) handleChangeProvider(e.toString()); + }} + onInputChange={(value) => !value && clear()} + defaultSelectedKey={selectedProvider ?? undefined} + selectedKey={selectedProvider} + > + + {Object.keys(models) + .filter((provider) => VERIFIED_PROVIDERS.includes(provider)) + .map((provider) => ( + + {mapProvider(provider)} + + ))} + + + {Object.keys(models) + .filter((provider) => !VERIFIED_PROVIDERS.includes(provider)) + .map((provider) => ( + + {mapProvider(provider)} + + ))} + + + + { + if (e?.toString()) handleChangeModel(e.toString()); + }} + isDisabled={isDisabled || !selectedProvider} + selectedKey={selectedModel} + defaultSelectedKey={selectedModel ?? undefined} + > + + {models[selectedProvider || ""]?.models + .filter((model) => VERIFIED_MODELS.includes(model)) + .map((model) => ( + + {model} + + ))} + + + {models[selectedProvider || ""]?.models + .filter((model) => !VERIFIED_MODELS.includes(model)) + .map((model) => ( + + {model} + + ))} + + +
+
+ ); +} diff --git a/frontend/src/components/modals/settings/SettingsForm.test.tsx b/frontend/src/components/modals/settings/SettingsForm.test.tsx index b6847ff3ba..a6ac059e06 100644 --- a/frontend/src/components/modals/settings/SettingsForm.test.tsx +++ b/frontend/src/components/modals/settings/SettingsForm.test.tsx @@ -6,6 +6,8 @@ import { Settings } from "#/services/settings"; import SettingsForm from "./SettingsForm"; const onModelChangeMock = vi.fn(); +const onCustomModelChangeMock = vi.fn(); +const onModelTypeChangeMock = vi.fn(); const onAgentChangeMock = vi.fn(); const onLanguageChangeMock = vi.fn(); const onAPIKeyChangeMock = vi.fn(); @@ -18,7 +20,9 @@ const renderSettingsForm = (settings?: Settings) => { disabled={false} settings={ settings || { - LLM_MODEL: "model1", + LLM_MODEL: "gpt-4o", + CUSTOM_LLM_MODEL: "", + USING_CUSTOM_MODEL: false, AGENT: "agent1", LANGUAGE: "en", LLM_API_KEY: "sk-...", @@ -26,10 +30,12 @@ const renderSettingsForm = (settings?: Settings) => { SECURITY_ANALYZER: "analyzer1", } } - models={["model1", "model2", "model3"]} + models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]} agents={["agent1", "agent2", "agent3"]} securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]} onModelChange={onModelChangeMock} + onCustomModelChange={onCustomModelChangeMock} + onModelTypeChange={onModelTypeChangeMock} onAgentChange={onAgentChangeMock} onLanguageChange={onLanguageChangeMock} onAPIKeyChange={onAPIKeyChangeMock} @@ -43,7 +49,8 @@ describe("SettingsForm", () => { it("should display the first values in the array by default", () => { renderSettingsForm(); - const modelInput = screen.getByRole("combobox", { name: "model" }); + const providerInput = screen.getByRole("combobox", { name: "Provider" }); + const modelInput = screen.getByRole("combobox", { name: "Model" }); const agentInput = screen.getByRole("combobox", { name: "agent" }); const languageInput = screen.getByRole("combobox", { name: "language" }); const apiKeyInput = screen.getByTestId("apikey"); @@ -52,7 +59,8 @@ describe("SettingsForm", () => { name: "securityanalyzer", }); - expect(modelInput).toHaveValue("model1"); + expect(providerInput).toHaveValue("OpenAI"); + expect(modelInput).toHaveValue("gpt-4o"); expect(agentInput).toHaveValue("agent1"); expect(languageInput).toHaveValue("English"); expect(apiKeyInput).toHaveValue("sk-..."); @@ -62,7 +70,9 @@ describe("SettingsForm", () => { it("should display the existing values if they are present", () => { renderSettingsForm({ - LLM_MODEL: "model2", + LLM_MODEL: "gpt-3.5-turbo", + CUSTOM_LLM_MODEL: "", + USING_CUSTOM_MODEL: false, AGENT: "agent2", LANGUAGE: "es", LLM_API_KEY: "sk-...", @@ -70,14 +80,16 @@ describe("SettingsForm", () => { SECURITY_ANALYZER: "analyzer2", }); - const modelInput = screen.getByRole("combobox", { name: "model" }); + const providerInput = screen.getByRole("combobox", { name: "Provider" }); + const modelInput = screen.getByRole("combobox", { name: "Model" }); const agentInput = screen.getByRole("combobox", { name: "agent" }); const languageInput = screen.getByRole("combobox", { name: "language" }); const securityAnalyzerInput = screen.getByRole("combobox", { name: "securityanalyzer", }); - expect(modelInput).toHaveValue("model2"); + expect(providerInput).toHaveValue("OpenAI"); + expect(modelInput).toHaveValue("gpt-3.5-turbo"); expect(agentInput).toHaveValue("agent2"); expect(languageInput).toHaveValue("EspaƱol"); expect(securityAnalyzerInput).toHaveValue("analyzer2"); @@ -87,18 +99,22 @@ describe("SettingsForm", () => { renderWithProviders( { onSecurityAnalyzerChange={onSecurityAnalyzerChangeMock} />, ); - const modelInput = screen.getByRole("combobox", { name: "model" }); + + const providerInput = screen.getByRole("combobox", { name: "Provider" }); + const modelInput = screen.getByRole("combobox", { name: "Model" }); const agentInput = screen.getByRole("combobox", { name: "agent" }); const languageInput = screen.getByRole("combobox", { name: "language" }); const confirmationModeInput = screen.getByTestId("confirmationmode"); @@ -114,6 +132,7 @@ describe("SettingsForm", () => { name: "securityanalyzer", }); + expect(providerInput).toBeDisabled(); expect(modelInput).toBeDisabled(); expect(agentInput).toBeDisabled(); expect(languageInput).toBeDisabled(); @@ -122,22 +141,6 @@ describe("SettingsForm", () => { }); describe("onChange handlers", () => { - it("should call the onModelChange handler when the model changes", async () => { - renderSettingsForm(); - - const modelInput = screen.getByRole("combobox", { name: "model" }); - await act(async () => { - await userEvent.click(modelInput); - }); - - const model3 = screen.getByText("model3"); - await act(async () => { - await userEvent.click(model3); - }); - - expect(onModelChangeMock).toHaveBeenCalledWith("model3"); - }); - it("should call the onAgentChange handler when the agent changes", async () => { const user = userEvent.setup(); renderSettingsForm(); @@ -182,4 +185,76 @@ describe("SettingsForm", () => { expect(onAPIKeyChangeMock).toHaveBeenCalledWith("sk-...x"); }); }); + + describe("Setting a custom LLM model", () => { + it("should display the fetched models by default", () => { + renderSettingsForm(); + + const modelSelector = screen.getByTestId("model-selector"); + expect(modelSelector).toBeInTheDocument(); + + const customModelInput = screen.queryByTestId("custom-model-input"); + expect(customModelInput).not.toBeInTheDocument(); + }); + + it("should switch to the custom model input when the custom model toggle is clicked", async () => { + const user = userEvent.setup(); + renderSettingsForm(); + + const customModelToggle = screen.getByTestId("custom-model-toggle"); + await user.click(customModelToggle); + + const modelSelector = screen.queryByTestId("model-selector"); + expect(modelSelector).not.toBeInTheDocument(); + + const customModelInput = screen.getByTestId("custom-model-input"); + expect(customModelInput).toBeInTheDocument(); + }); + + it("should call the onCustomModelChange handler when the custom model input changes", async () => { + const user = userEvent.setup(); + renderSettingsForm(); + + const customModelToggle = screen.getByTestId("custom-model-toggle"); + await user.click(customModelToggle); + + const customModelInput = screen.getByTestId("custom-model-input"); + await userEvent.type(customModelInput, "my/custom-model"); + + expect(onCustomModelChangeMock).toHaveBeenCalledWith("my/custom-model"); + expect(onModelTypeChangeMock).toHaveBeenCalledWith("custom"); + }); + + it("should have custom model switched if using custom model", () => { + renderWithProviders( + , + ); + + const customModelToggle = screen.getByTestId("custom-model-toggle"); + expect(customModelToggle).toHaveAttribute("aria-checked", "true"); + }); + }); }); diff --git a/frontend/src/components/modals/settings/SettingsForm.tsx b/frontend/src/components/modals/settings/SettingsForm.tsx index f865f659b3..d607d61731 100644 --- a/frontend/src/components/modals/settings/SettingsForm.tsx +++ b/frontend/src/components/modals/settings/SettingsForm.tsx @@ -6,6 +6,8 @@ import { AvailableLanguages } from "../../../i18n"; import { I18nKey } from "../../../i18n/declaration"; import { AutocompleteCombobox } from "./AutocompleteCombobox"; import { Settings } from "#/services/settings"; +import { organizeModelsAndProviders } from "#/utils/organizeModelsAndProviders"; +import { ModelSelector } from "./ModelSelector"; interface SettingsFormProps { settings: Settings; @@ -15,6 +17,8 @@ interface SettingsFormProps { disabled: boolean; onModelChange: (model: string) => void; + onCustomModelChange: (model: string) => void; + onModelTypeChange: (type: "custom" | "default") => void; onAPIKeyChange: (apiKey: string) => void; onAgentChange: (agent: string) => void; onLanguageChange: (language: string) => void; @@ -29,6 +33,8 @@ function SettingsForm({ securityAnalyzers, disabled, onModelChange, + onCustomModelChange, + onModelTypeChange, onAPIKeyChange, onAgentChange, onLanguageChange, @@ -38,20 +44,46 @@ function SettingsForm({ const { t } = useTranslation(); const { isOpen: isVisible, onOpenChange: onVisibleChange } = useDisclosure(); const [isAgentSelectEnabled, setIsAgentSelectEnabled] = React.useState(false); + const [usingCustomModel, setUsingCustomModel] = React.useState( + settings.USING_CUSTOM_MODEL, + ); + + const changeModelType = (type: "custom" | "default") => { + if (type === "custom") { + setUsingCustomModel(true); + onModelTypeChange("custom"); + } else { + setUsingCustomModel(false); + onModelTypeChange("default"); + } + }; return ( <> - ({ value: model, label: model }))} - defaultKey={settings.LLM_MODEL} - onChange={(e) => { - onModelChange(e); - }} - tooltip={t(I18nKey.SETTINGS$MODEL_TOOLTIP)} - allowCustomValue // user can type in a custom LLM model that is not in the list - disabled={disabled} - /> + changeModelType(value ? "custom" : "default")} + > + Use custom model + + {usingCustomModel && ( + + )} + {!usingCustomModel && ( + + )} ({ ...(await importOriginal()), getSettings: vi.fn().mockReturnValue({ LLM_MODEL: "gpt-4o", + CUSTOM_LLM_MODEL: "", + USING_CUSTOM_MODEL: false, AGENT: "CodeActAgent", LANGUAGE: "en", LLM_API_KEY: "sk-...", @@ -32,6 +34,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({ }), getDefaultSettings: vi.fn().mockReturnValue({ LLM_MODEL: "gpt-4o", + CUSTOM_LLM_MODEL: "", + USING_CUSTOM_MODEL: false, AGENT: "CodeActAgent", LANGUAGE: "en", LLM_API_KEY: "", @@ -46,7 +50,14 @@ vi.mock("#/services/options", async (importOriginal) => ({ ...(await importOriginal()), fetchModels: vi .fn() - .mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])), + .mockResolvedValue( + Promise.resolve([ + "gpt-4o", + "gpt-3.5-turbo", + "azure/ada", + "cohere.command-r-v1:0", + ]), + ), fetchAgents: vi .fn() .mockResolvedValue(Promise.resolve(["agent1", "agent2", "agent3"])), @@ -104,6 +115,8 @@ describe("SettingsModal", () => { describe("onHandleSave", () => { const initialSettings: Settings = { LLM_MODEL: "gpt-4o", + CUSTOM_LLM_MODEL: "", + USING_CUSTOM_MODEL: false, AGENT: "CodeActAgent", LANGUAGE: "en", LLM_API_KEY: "sk-...", @@ -122,17 +135,22 @@ describe("SettingsModal", () => { await assertModelsAndAgentsFetched(); const saveButton = screen.getByRole("button", { name: /save/i }); - const modelInput = screen.getByRole("combobox", { name: "model" }); + const providerInput = screen.getByRole("combobox", { name: "Provider" }); + const modelInput = screen.getByRole("combobox", { name: "Model" }); + + await user.click(providerInput); + const azure = screen.getByText("Azure"); + await user.click(azure); await user.click(modelInput); - const model3 = screen.getByText("model3"); - + const model3 = screen.getByText("ada"); await user.click(model3); + await user.click(saveButton); expect(saveSettings).toHaveBeenCalledWith({ ...initialSettings, - LLM_MODEL: "model3", + LLM_MODEL: "azure/ada", }); }); @@ -146,12 +164,17 @@ describe("SettingsModal", () => { ); const saveButton = screen.getByRole("button", { name: /save/i }); - const modelInput = screen.getByRole("combobox", { name: "model" }); + const providerInput = screen.getByRole("combobox", { name: "Provider" }); + const modelInput = screen.getByRole("combobox", { name: "Model" }); + + await user.click(providerInput); + const openai = screen.getByText("OpenAI"); + await user.click(openai); await user.click(modelInput); - const model3 = screen.getByText("model3"); - + const model3 = screen.getByText("gpt-3.5-turbo"); await user.click(model3); + await user.click(saveButton); expect(startNewSessionSpy).toHaveBeenCalled(); @@ -167,12 +190,17 @@ describe("SettingsModal", () => { ); const saveButton = screen.getByRole("button", { name: /save/i }); - const modelInput = screen.getByRole("combobox", { name: "model" }); + const providerInput = screen.getByRole("combobox", { name: "Provider" }); + const modelInput = screen.getByRole("combobox", { name: "Model" }); + + await user.click(providerInput); + const cohere = screen.getByText("cohere"); + await user.click(cohere); await user.click(modelInput); - const model3 = screen.getByText("model3"); - + const model3 = screen.getByText("command-r-v1:0"); await user.click(model3); + await user.click(saveButton); expect(toastSpy).toHaveBeenCalledTimes(4); @@ -213,12 +241,17 @@ describe("SettingsModal", () => { }); const saveButton = screen.getByRole("button", { name: /save/i }); - const modelInput = screen.getByRole("combobox", { name: "model" }); + const providerInput = screen.getByRole("combobox", { name: "Provider" }); + const modelInput = screen.getByRole("combobox", { name: "Model" }); + + await user.click(providerInput); + const cohere = screen.getByText("cohere"); + await user.click(cohere); await user.click(modelInput); - const model3 = screen.getByText("model3"); - + const model3 = screen.getByText("command-r-v1:0"); await user.click(model3); + await user.click(saveButton); expect(onOpenChangeMock).toHaveBeenCalledWith(false); diff --git a/frontend/src/components/modals/settings/SettingsModal.tsx b/frontend/src/components/modals/settings/SettingsModal.tsx index 50f90192d5..5abc63a5c2 100644 --- a/frontend/src/components/modals/settings/SettingsModal.tsx +++ b/frontend/src/components/modals/settings/SettingsModal.tsx @@ -63,8 +63,10 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) { React.useEffect(() => { (async () => { try { - setModels(await fetchModels()); - setAgents(await fetchAgents()); + const fetchedModels = await fetchModels(); + const fetchedAgents = await fetchAgents(); + setModels(fetchedModels); + setAgents(fetchedAgents); setSecurityAnalyzers(await fetchSecurityAnalyzers()); } catch (error) { toast.error("settings", t(I18nKey.CONFIGURATION$ERROR_FETCH_MODELS)); @@ -81,6 +83,20 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) { })); }; + const handleCustomModelChange = (model: string) => { + setSettings((prev) => ({ + ...prev, + CUSTOM_LLM_MODEL: model, + })); + }; + + const handleModelTypeChange = (type: "custom" | "default") => { + setSettings((prev) => ({ + ...prev, + USING_CUSTOM_MODEL: type === "custom", + })); + }; + const handleAgentChange = (agent: string) => { setSettings((prev) => ({ ...prev, AGENT: agent })); }; @@ -189,6 +205,8 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) { agents={agents} securityAnalyzers={securityAnalyzers} onModelChange={handleModelChange} + onCustomModelChange={handleCustomModelChange} + onModelTypeChange={handleModelTypeChange} onAgentChange={handleAgentChange} onLanguageChange={handleLanguageChange} onAPIKeyChange={handleAPIKeyChange} diff --git a/frontend/src/services/session.test.ts b/frontend/src/services/session.test.ts index 492fef608c..e028a3ebf7 100644 --- a/frontend/src/services/session.test.ts +++ b/frontend/src/services/session.test.ts @@ -11,9 +11,16 @@ const setupSpy = vi.spyOn(Session, "_setupSocket").mockImplementation(() => { }); describe("startNewSession", () => { + afterEach(() => { + sendSpy.mockClear(); + setupSpy.mockClear(); + }); + it("Should start a new session with the current settings", () => { const settings: Settings = { LLM_MODEL: "llm_value", + CUSTOM_LLM_MODEL: "", + USING_CUSTOM_MODEL: false, AGENT: "agent_value", LANGUAGE: "language_value", LLM_API_KEY: "sk-...", @@ -32,4 +39,33 @@ describe("startNewSession", () => { expect(setupSpy).toHaveBeenCalledTimes(1); expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event)); }); + + it("should start with the custom llm if set", () => { + const settings: Settings = { + LLM_MODEL: "llm_value", + CUSTOM_LLM_MODEL: "custom_llm_value", + USING_CUSTOM_MODEL: true, + AGENT: "agent_value", + LANGUAGE: "language_value", + LLM_API_KEY: "sk-...", + CONFIRMATION_MODE: true, + SECURITY_ANALYZER: "analyzer", + }; + + const event = { + action: ActionType.INIT, + args: settings, + }; + + saveSettings(settings); + Session.startNewSession(); + + expect(setupSpy).toHaveBeenCalledTimes(1); + expect(sendSpy).toHaveBeenCalledWith( + JSON.stringify({ + ...event, + args: { ...settings, LLM_MODEL: "custom_llm_value" }, + }), + ); + }); }); diff --git a/frontend/src/services/session.ts b/frontend/src/services/session.ts index 09bcdc7b0d..ab009928e8 100644 --- a/frontend/src/services/session.ts +++ b/frontend/src/services/session.ts @@ -46,7 +46,15 @@ class Session { private static _initializeAgent = () => { const settings = getSettings(); - const event = { action: ActionType.INIT, args: settings }; + const event = { + action: ActionType.INIT, + args: { + ...settings, + LLM_MODEL: settings.USING_CUSTOM_MODEL + ? settings.CUSTOM_LLM_MODEL + : settings.LLM_MODEL, + }, + }; const eventString = JSON.stringify(event); Session.send(eventString); }; diff --git a/frontend/src/services/settings.test.ts b/frontend/src/services/settings.test.ts index 869ed6c2be..2a337d4d19 100644 --- a/frontend/src/services/settings.test.ts +++ b/frontend/src/services/settings.test.ts @@ -18,6 +18,8 @@ describe("getSettings", () => { it("should get the stored settings", () => { (localStorage.getItem as Mock) .mockReturnValueOnce("llm_value") + .mockReturnValueOnce("custom_llm_value") + .mockReturnValueOnce("true") .mockReturnValueOnce("agent_value") .mockReturnValueOnce("language_value") .mockReturnValueOnce("api_key") @@ -28,6 +30,8 @@ describe("getSettings", () => { expect(settings).toEqual({ LLM_MODEL: "llm_value", + CUSTOM_LLM_MODEL: "custom_llm_value", + USING_CUSTOM_MODEL: true, AGENT: "agent_value", LANGUAGE: "language_value", LLM_API_KEY: "api_key", @@ -43,12 +47,16 @@ describe("getSettings", () => { .mockReturnValueOnce(null) .mockReturnValueOnce(null) .mockReturnValueOnce(null) + .mockReturnValueOnce(null) + .mockReturnValueOnce(null) .mockReturnValueOnce(null); const settings = getSettings(); expect(settings).toEqual({ LLM_MODEL: DEFAULT_SETTINGS.LLM_MODEL, + CUSTOM_LLM_MODEL: "", + USING_CUSTOM_MODEL: DEFAULT_SETTINGS.USING_CUSTOM_MODEL, AGENT: DEFAULT_SETTINGS.AGENT, LANGUAGE: DEFAULT_SETTINGS.LANGUAGE, LLM_API_KEY: "", @@ -62,6 +70,8 @@ describe("saveSettings", () => { it("should save the settings", () => { const settings: Settings = { LLM_MODEL: "llm_value", + CUSTOM_LLM_MODEL: "custom_llm_value", + USING_CUSTOM_MODEL: true, AGENT: "agent_value", LANGUAGE: "language_value", LLM_API_KEY: "some_key", @@ -72,6 +82,14 @@ describe("saveSettings", () => { saveSettings(settings); expect(localStorage.setItem).toHaveBeenCalledWith("LLM_MODEL", "llm_value"); + expect(localStorage.setItem).toHaveBeenCalledWith( + "CUSTOM_LLM_MODEL", + "custom_llm_value", + ); + expect(localStorage.setItem).toHaveBeenCalledWith( + "USING_CUSTOM_MODEL", + "true", + ); expect(localStorage.setItem).toHaveBeenCalledWith("AGENT", "agent_value"); expect(localStorage.setItem).toHaveBeenCalledWith( "LANGUAGE", @@ -122,6 +140,8 @@ describe("getSettingsDifference", () => { beforeEach(() => { (localStorage.getItem as Mock) .mockReturnValueOnce("llm_value") + .mockReturnValueOnce("custom_llm_value") + .mockReturnValueOnce("false") .mockReturnValueOnce("agent_value") .mockReturnValueOnce("language_value"); }); @@ -129,6 +149,8 @@ describe("getSettingsDifference", () => { it("should return updated settings", () => { const settings = { LLM_MODEL: "new_llm_value", + CUSTOM_LLM_MODEL: "custom_llm_value", + USING_CUSTOM_MODEL: true, AGENT: "new_agent_value", LANGUAGE: "language_value", }; @@ -136,6 +158,7 @@ describe("getSettingsDifference", () => { const updatedSettings = getSettingsDifference(settings); expect(updatedSettings).toEqual({ + USING_CUSTOM_MODEL: true, LLM_MODEL: "new_llm_value", AGENT: "new_agent_value", }); diff --git a/frontend/src/services/settings.ts b/frontend/src/services/settings.ts index ec9dcc67ec..d554d44681 100644 --- a/frontend/src/services/settings.ts +++ b/frontend/src/services/settings.ts @@ -2,6 +2,8 @@ const LATEST_SETTINGS_VERSION = 1; export type Settings = { LLM_MODEL: string; + CUSTOM_LLM_MODEL: string; + USING_CUSTOM_MODEL: boolean; AGENT: string; LANGUAGE: string; LLM_API_KEY: string; @@ -12,7 +14,9 @@ export type Settings = { type SettingsInput = Settings[keyof Settings]; export const DEFAULT_SETTINGS: Settings = { - LLM_MODEL: "gpt-4o", + LLM_MODEL: "openai/gpt-4o", + CUSTOM_LLM_MODEL: "", + USING_CUSTOM_MODEL: false, AGENT: "CodeActAgent", LANGUAGE: "en", LLM_API_KEY: "", @@ -54,6 +58,9 @@ export const getDefaultSettings = (): Settings => DEFAULT_SETTINGS; */ export const getSettings = (): Settings => { const model = localStorage.getItem("LLM_MODEL"); + const customModel = localStorage.getItem("CUSTOM_LLM_MODEL"); + const usingCustomModel = + localStorage.getItem("USING_CUSTOM_MODEL") === "true"; const agent = localStorage.getItem("AGENT"); const language = localStorage.getItem("LANGUAGE"); const apiKey = localStorage.getItem("LLM_API_KEY"); @@ -62,6 +69,8 @@ export const getSettings = (): Settings => { return { LLM_MODEL: model || DEFAULT_SETTINGS.LLM_MODEL, + CUSTOM_LLM_MODEL: customModel || DEFAULT_SETTINGS.CUSTOM_LLM_MODEL, + USING_CUSTOM_MODEL: usingCustomModel || DEFAULT_SETTINGS.USING_CUSTOM_MODEL, AGENT: agent || DEFAULT_SETTINGS.AGENT, LANGUAGE: language || DEFAULT_SETTINGS.LANGUAGE, LLM_API_KEY: apiKey || DEFAULT_SETTINGS.LLM_API_KEY, diff --git a/frontend/src/utils/extractModelAndProvider.test.ts b/frontend/src/utils/extractModelAndProvider.test.ts new file mode 100644 index 0000000000..dd43d68caa --- /dev/null +++ b/frontend/src/utils/extractModelAndProvider.test.ts @@ -0,0 +1,62 @@ +import { describe, it, expect } from "vitest"; +import { extractModelAndProvider } from "./extractModelAndProvider"; + +describe("extractModelAndProvider", () => { + it("should work", () => { + expect(extractModelAndProvider("azure/ada")).toEqual({ + provider: "azure", + model: "ada", + separator: "/", + }); + + expect( + extractModelAndProvider("azure/standard/1024-x-1024/dall-e-2"), + ).toEqual({ + provider: "azure", + model: "standard/1024-x-1024/dall-e-2", + separator: "/", + }); + + expect(extractModelAndProvider("vertex_ai_beta/chat-bison")).toEqual({ + provider: "vertex_ai_beta", + model: "chat-bison", + separator: "/", + }); + + expect(extractModelAndProvider("cohere.command-r-v1:0")).toEqual({ + provider: "cohere", + model: "command-r-v1:0", + separator: ".", + }); + + expect( + extractModelAndProvider( + "cloudflare/@cf/mistral/mistral-7b-instruct-v0.1", + ), + ).toEqual({ + provider: "cloudflare", + model: "@cf/mistral/mistral-7b-instruct-v0.1", + separator: "/", + }); + + expect(extractModelAndProvider("together-ai-21.1b-41b")).toEqual({ + provider: "", + model: "together-ai-21.1b-41b", + separator: "", + }); + }); + + it("should add provider for popular models", () => { + expect(extractModelAndProvider("gpt-3.5-turbo")).toEqual({ + provider: "openai", + model: "gpt-3.5-turbo", + separator: "/", + }); + + expect(extractModelAndProvider("gpt-4o")).toEqual({ + provider: "openai", + model: "gpt-4o", + separator: "/", + }); + }); +}); diff --git a/frontend/src/utils/extractModelAndProvider.ts b/frontend/src/utils/extractModelAndProvider.ts new file mode 100644 index 0000000000..cd8dcecf65 --- /dev/null +++ b/frontend/src/utils/extractModelAndProvider.ts @@ -0,0 +1,49 @@ +import { isNumber } from "./isNumber"; +import { VERIFIED_OPENAI_MODELS } from "./verified-models"; + +/** + * Checks if the split array is actually a version number. + * @param split The split array of the model string + * @returns Boolean indicating if the split is actually a version number + * + * @example + * const split = ["gpt-3", "5-turbo"] // incorrectly split from "gpt-3.5-turbo" + * splitIsActuallyVersion(split) // returns true + */ +const splitIsActuallyVersion = (split: string[]) => + split[1] && split[1][0] && isNumber(split[1][0]); + +/** + * Given a model string, extract the provider and model name. Currently the supported separators are "/" and "." + * @param model The model string + * @returns An object containing the provider, model name, and separator + * + * @example + * extractModelAndProvider("azure/ada") + * // returns { provider: "azure", model: "ada", separator: "/" } + * + * extractModelAndProvider("cohere.command-r-v1:0") + * // returns { provider: "cohere", model: "command-r-v1:0", separator: "." } + */ +export const extractModelAndProvider = (model: string) => { + let separator = "/"; + let split = model.split(separator); + if (split.length === 1) { + // no "/" separator found, try with "." + separator = "."; + split = model.split(separator); + if (splitIsActuallyVersion(split)) { + split = [split.join(separator)]; // undo the split + } + } + if (split.length === 1) { + // no "/" or "." separator found + if (VERIFIED_OPENAI_MODELS.includes(split[0])) { + return { provider: "openai", model: split[0], separator: "/" }; + } + // return as model only + return { provider: "", model, separator: "" }; + } + const [provider, ...modelId] = split; + return { provider, model: modelId.join(separator), separator }; +}; diff --git a/frontend/src/utils/isNumber.test.ts b/frontend/src/utils/isNumber.test.ts new file mode 100644 index 0000000000..6b3640a88d --- /dev/null +++ b/frontend/src/utils/isNumber.test.ts @@ -0,0 +1,9 @@ +import { test, expect } from "vitest"; +import { isNumber } from "./isNumber"; + +test("isNumber", () => { + expect(isNumber(1)).toBe(true); + expect(isNumber(0)).toBe(true); + expect(isNumber("3")).toBe(true); + expect(isNumber("0")).toBe(true); +}); diff --git a/frontend/src/utils/isNumber.ts b/frontend/src/utils/isNumber.ts new file mode 100644 index 0000000000..8ac961d9c0 --- /dev/null +++ b/frontend/src/utils/isNumber.ts @@ -0,0 +1,2 @@ +export const isNumber = (value: string | number): boolean => + !Number.isNaN(Number(value)); diff --git a/frontend/src/utils/mapProvider.test.ts b/frontend/src/utils/mapProvider.test.ts new file mode 100644 index 0000000000..10d3a52a2b --- /dev/null +++ b/frontend/src/utils/mapProvider.test.ts @@ -0,0 +1,27 @@ +import { test, expect } from "vitest"; +import { mapProvider } from "./mapProvider"; + +test("mapProvider", () => { + expect(mapProvider("azure")).toBe("Azure"); + expect(mapProvider("azure_ai")).toBe("Azure AI Studio"); + expect(mapProvider("vertex_ai")).toBe("VertexAI"); + expect(mapProvider("palm")).toBe("PaLM"); + expect(mapProvider("gemini")).toBe("Gemini"); + expect(mapProvider("anthropic")).toBe("Anthropic"); + expect(mapProvider("sagemaker")).toBe("AWS SageMaker"); + expect(mapProvider("bedrock")).toBe("AWS Bedrock"); + expect(mapProvider("mistral")).toBe("Mistral AI"); + expect(mapProvider("anyscale")).toBe("Anyscale"); + expect(mapProvider("databricks")).toBe("Databricks"); + expect(mapProvider("ollama")).toBe("Ollama"); + expect(mapProvider("perlexity")).toBe("Perplexity AI"); + expect(mapProvider("friendliai")).toBe("FriendliAI"); + expect(mapProvider("groq")).toBe("Groq"); + expect(mapProvider("fireworks_ai")).toBe("Fireworks AI"); + expect(mapProvider("cloudflare")).toBe("Cloudflare Workers AI"); + expect(mapProvider("deepinfra")).toBe("DeepInfra"); + expect(mapProvider("ai21")).toBe("AI21"); + expect(mapProvider("replicate")).toBe("Replicate"); + expect(mapProvider("voyage")).toBe("Voyage AI"); + expect(mapProvider("openrouter")).toBe("OpenRouter"); +}); diff --git a/frontend/src/utils/mapProvider.ts b/frontend/src/utils/mapProvider.ts new file mode 100644 index 0000000000..28a50dee71 --- /dev/null +++ b/frontend/src/utils/mapProvider.ts @@ -0,0 +1,30 @@ +export const MAP_PROVIDER = { + openai: "OpenAI", + azure: "Azure", + azure_ai: "Azure AI Studio", + vertex_ai: "VertexAI", + palm: "PaLM", + gemini: "Gemini", + anthropic: "Anthropic", + sagemaker: "AWS SageMaker", + bedrock: "AWS Bedrock", + mistral: "Mistral AI", + anyscale: "Anyscale", + databricks: "Databricks", + ollama: "Ollama", + perlexity: "Perplexity AI", + friendliai: "FriendliAI", + groq: "Groq", + fireworks_ai: "Fireworks AI", + cloudflare: "Cloudflare Workers AI", + deepinfra: "DeepInfra", + ai21: "AI21", + replicate: "Replicate", + voyage: "Voyage AI", + openrouter: "OpenRouter", +}; + +export const mapProvider = (provider: string) => + Object.keys(MAP_PROVIDER).includes(provider) + ? MAP_PROVIDER[provider as keyof typeof MAP_PROVIDER] + : provider; diff --git a/frontend/src/utils/organizeModelsAndProviders.test.ts b/frontend/src/utils/organizeModelsAndProviders.test.ts new file mode 100644 index 0000000000..53dfcede7d --- /dev/null +++ b/frontend/src/utils/organizeModelsAndProviders.test.ts @@ -0,0 +1,51 @@ +import { test } from "vitest"; +import { organizeModelsAndProviders } from "./organizeModelsAndProviders"; + +test("organizeModelsAndProviders", () => { + const models = [ + "azure/ada", + "azure/gpt-35-turbo", + "azure/gpt-3-turbo", + "azure/standard/1024-x-1024/dall-e-2", + "vertex_ai_beta/chat-bison", + "vertex_ai_beta/chat-bison-32k", + "sagemaker/meta-textgeneration-llama-2-13b", + "cohere.command-r-v1:0", + "cloudflare/@cf/mistral/mistral-7b-instruct-v0.1", + "gpt-4o", + "together-ai-21.1b-41b", + "gpt-3.5-turbo", + ]; + + const object = organizeModelsAndProviders(models); + + expect(object).toEqual({ + azure: { + separator: "/", + models: [ + "ada", + "gpt-35-turbo", + "gpt-3-turbo", + "standard/1024-x-1024/dall-e-2", + ], + }, + vertex_ai_beta: { + separator: "/", + models: ["chat-bison", "chat-bison-32k"], + }, + sagemaker: { separator: "/", models: ["meta-textgeneration-llama-2-13b"] }, + cohere: { separator: ".", models: ["command-r-v1:0"] }, + cloudflare: { + separator: "/", + models: ["@cf/mistral/mistral-7b-instruct-v0.1"], + }, + openai: { + separator: "/", + models: ["gpt-4o", "gpt-3.5-turbo"], + }, + other: { + separator: "", + models: ["together-ai-21.1b-41b"], + }, + }); +}); diff --git a/frontend/src/utils/organizeModelsAndProviders.ts b/frontend/src/utils/organizeModelsAndProviders.ts new file mode 100644 index 0000000000..61958a5783 --- /dev/null +++ b/frontend/src/utils/organizeModelsAndProviders.ts @@ -0,0 +1,42 @@ +import { extractModelAndProvider } from "./extractModelAndProvider"; + +/** + * Given a list of models, organize them by provider + * @param models The list of models + * @returns An object containing the provider and models + * + * @example + * const models = [ + * "azure/ada", + * "azure/gpt-35-turbo", + * "cohere.command-r-v1:0", + * ]; + * + * organizeModelsAndProviders(models); + * // returns { + * // azure: { + * // separator: "/", + * // models: ["ada", "gpt-35-turbo"], + * // }, + * // cohere: { + * // separator: ".", + * // models: ["command-r-v1:0"], + * // }, + * // } + */ +export const organizeModelsAndProviders = (models: string[]) => { + const object: Record = {}; + models.forEach((model) => { + const { + separator, + provider, + model: modelId, + } = extractModelAndProvider(model); + const key = provider || "other"; + if (!object[key]) { + object[key] = { separator, models: [] }; + } + object[key].models.push(modelId); + }); + return object; +}; diff --git a/frontend/src/utils/verified-models.ts b/frontend/src/utils/verified-models.ts new file mode 100644 index 0000000000..2efffaa314 --- /dev/null +++ b/frontend/src/utils/verified-models.ts @@ -0,0 +1,14 @@ +// Here are the list of verified models and providers that we know work well with OpenHands. +export const VERIFIED_PROVIDERS = ["openai", "azure", "anthropic"]; +export const VERIFIED_MODELS = ["gpt-4o", "claude-3-5-sonnet-20240620-v1:0"]; + +// LiteLLM does not return OpenAI models with the provider, so we list them here to set them ourselves for consistency +// (e.g., they return `gpt-4o` instead of `openai/gpt-4o`) +export const VERIFIED_OPENAI_MODELS = [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "gpt-4", + "gpt-4-32k", + "gpt-3.5-turbo", +];