From 07e750f03804de645bd92fcdd08b3b10ccf86db3 Mon Sep 17 00:00:00 2001
From: "sp.wack" <83104063+amanape@users.noreply.github.com>
Date: Fri, 23 Aug 2024 20:06:15 +0300
Subject: [PATCH] feat(frontend): Improve models input UI/UX in settings
(#3530)
* Create helper functions
* Add map according to litellm docs
* Create ModelSelector
* Extend model selector
* use autocomplete from nextui
* Improve keys without providers
* Handle models without a provider
* Add verified section and some empty handling
* Add support for default or previously set models
* Update tests
* Lint
* Remove modifier
* Fix typescript error
* Functionality for switching to custom model
* Add verified models
* Respond to resetting to default
* Comment
---
.../modals/settings/ModelSelector.test.tsx | 193 ++++++++++++++++++
.../modals/settings/ModelSelector.tsx | 133 ++++++++++++
.../modals/settings/SettingsForm.test.tsx | 127 +++++++++---
.../modals/settings/SettingsForm.tsx | 54 ++++-
.../modals/settings/SettingsModal.test.tsx | 61 ++++--
.../modals/settings/SettingsModal.tsx | 22 +-
frontend/src/services/session.test.ts | 36 ++++
frontend/src/services/session.ts | 10 +-
frontend/src/services/settings.test.ts | 23 +++
frontend/src/services/settings.ts | 11 +-
.../src/utils/extractModelAndProvider.test.ts | 62 ++++++
frontend/src/utils/extractModelAndProvider.ts | 49 +++++
frontend/src/utils/isNumber.test.ts | 9 +
frontend/src/utils/isNumber.ts | 2 +
frontend/src/utils/mapProvider.test.ts | 27 +++
frontend/src/utils/mapProvider.ts | 30 +++
.../utils/organizeModelsAndProviders.test.ts | 51 +++++
.../src/utils/organizeModelsAndProviders.ts | 42 ++++
frontend/src/utils/verified-models.ts | 14 ++
19 files changed, 901 insertions(+), 55 deletions(-)
create mode 100644 frontend/src/components/modals/settings/ModelSelector.test.tsx
create mode 100644 frontend/src/components/modals/settings/ModelSelector.tsx
create mode 100644 frontend/src/utils/extractModelAndProvider.test.ts
create mode 100644 frontend/src/utils/extractModelAndProvider.ts
create mode 100644 frontend/src/utils/isNumber.test.ts
create mode 100644 frontend/src/utils/isNumber.ts
create mode 100644 frontend/src/utils/mapProvider.test.ts
create mode 100644 frontend/src/utils/mapProvider.ts
create mode 100644 frontend/src/utils/organizeModelsAndProviders.test.ts
create mode 100644 frontend/src/utils/organizeModelsAndProviders.ts
create mode 100644 frontend/src/utils/verified-models.ts
diff --git a/frontend/src/components/modals/settings/ModelSelector.test.tsx b/frontend/src/components/modals/settings/ModelSelector.test.tsx
new file mode 100644
index 0000000000..2b481c7fc6
--- /dev/null
+++ b/frontend/src/components/modals/settings/ModelSelector.test.tsx
@@ -0,0 +1,193 @@
+import React from "react";
+import { describe, it, expect, vi } from "vitest";
+import { render, screen } from "@testing-library/react";
+import userEvent from "@testing-library/user-event";
+import { ModelSelector } from "./ModelSelector";
+
+describe("ModelSelector", () => {
+ const models = {
+ openai: {
+ separator: "/",
+ models: ["gpt-4o", "gpt-3.5-turbo"],
+ },
+ azure: {
+ separator: "/",
+ models: ["ada", "gpt-35-turbo"],
+ },
+ vertex_ai: {
+ separator: "/",
+ models: ["chat-bison", "chat-bison-32k"],
+ },
+ cohere: {
+ separator: ".",
+ models: ["command-r-v1:0"],
+ },
+ };
+
+ it("should display the provider selector", async () => {
+ const user = userEvent.setup();
+ const onModelChange = vi.fn();
+ render();
+
+ const selector = screen.getByLabelText("Provider");
+ expect(selector).toBeInTheDocument();
+
+ await user.click(selector);
+
+ expect(screen.getByText("OpenAI")).toBeInTheDocument();
+ expect(screen.getByText("Azure")).toBeInTheDocument();
+ expect(screen.getByText("VertexAI")).toBeInTheDocument();
+ expect(screen.getByText("cohere")).toBeInTheDocument();
+ });
+
+ it("should disable the model selector if the provider is not selected", async () => {
+ const user = userEvent.setup();
+ const onModelChange = vi.fn();
+ render();
+
+ const modelSelector = screen.getByLabelText("Model");
+ expect(modelSelector).toBeDisabled();
+
+ const providerSelector = screen.getByLabelText("Provider");
+ await user.click(providerSelector);
+
+ const vertexAI = screen.getByText("VertexAI");
+ await user.click(vertexAI);
+
+ expect(modelSelector).not.toBeDisabled();
+ });
+
+ it("should display the model selector", async () => {
+ const user = userEvent.setup();
+ const onModelChange = vi.fn();
+ render();
+
+ const providerSelector = screen.getByLabelText("Provider");
+ await user.click(providerSelector);
+
+ const azureProvider = screen.getByText("Azure");
+ await user.click(azureProvider);
+
+ const modelSelector = screen.getByLabelText("Model");
+ await user.click(modelSelector);
+
+ expect(screen.getByText("ada")).toBeInTheDocument();
+ expect(screen.getByText("gpt-35-turbo")).toBeInTheDocument();
+
+ await user.click(providerSelector);
+ const vertexProvider = screen.getByText("VertexAI");
+ await user.click(vertexProvider);
+
+ await user.click(modelSelector);
+
+ expect(screen.getByText("chat-bison")).toBeInTheDocument();
+ expect(screen.getByText("chat-bison-32k")).toBeInTheDocument();
+ });
+
+ it("should display the actual litellm model ID as the user is making the selections", async () => {
+ const user = userEvent.setup();
+ const onModelChange = vi.fn();
+ render();
+
+ const id = screen.getByTestId("model-id");
+ const providerSelector = screen.getByLabelText("Provider");
+ const modelSelector = screen.getByLabelText("Model");
+
+ expect(id).toHaveTextContent("No model selected");
+
+ await user.click(providerSelector);
+ await user.click(screen.getByText("Azure"));
+
+ expect(id).toHaveTextContent("azure/");
+
+ await user.click(modelSelector);
+ await user.click(screen.getByText("ada"));
+ expect(id).toHaveTextContent("azure/ada");
+
+ await user.click(providerSelector);
+ await user.click(screen.getByText("cohere"));
+ expect(id).toHaveTextContent("cohere.");
+
+ await user.click(modelSelector);
+ await user.click(screen.getByText("command-r-v1:0"));
+ expect(id).toHaveTextContent("cohere.command-r-v1:0");
+ });
+
+ it("should call onModelChange when the model is changed", async () => {
+ const user = userEvent.setup();
+ const onModelChange = vi.fn();
+ render();
+
+ const providerSelector = screen.getByLabelText("Provider");
+ const modelSelector = screen.getByLabelText("Model");
+
+ await user.click(providerSelector);
+ await user.click(screen.getByText("Azure"));
+
+ await user.click(modelSelector);
+ await user.click(screen.getByText("ada"));
+
+ expect(onModelChange).toHaveBeenCalledTimes(1);
+ expect(onModelChange).toHaveBeenCalledWith("azure/ada");
+
+ await user.click(modelSelector);
+ await user.click(screen.getByText("gpt-35-turbo"));
+
+ expect(onModelChange).toHaveBeenCalledTimes(2);
+ expect(onModelChange).toHaveBeenCalledWith("azure/gpt-35-turbo");
+
+ await user.click(providerSelector);
+ await user.click(screen.getByText("cohere"));
+
+ await user.click(modelSelector);
+ await user.click(screen.getByText("command-r-v1:0"));
+
+ expect(onModelChange).toHaveBeenCalledTimes(3);
+ expect(onModelChange).toHaveBeenCalledWith("cohere.command-r-v1:0");
+ });
+
+ it("should clear the model ID when the provider is cleared", async () => {
+ const user = userEvent.setup();
+ const onModelChange = vi.fn();
+ render();
+
+ const providerSelector = screen.getByLabelText("Provider");
+ const modelSelector = screen.getByLabelText("Model");
+
+ await user.click(providerSelector);
+ await user.click(screen.getByText("Azure"));
+
+ await user.click(modelSelector);
+ await user.click(screen.getByText("ada"));
+
+ expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada");
+
+ await user.clear(providerSelector);
+
+ expect(screen.getByTestId("model-id")).toHaveTextContent(
+ "No model selected",
+ );
+ });
+
+ it("should have a default value if passed", async () => {
+ const onModelChange = vi.fn();
+ render(
+ ,
+ );
+
+ expect(screen.getByTestId("model-id")).toHaveTextContent("azure/ada");
+ expect(screen.getByLabelText("Provider")).toHaveValue("Azure");
+ expect(screen.getByLabelText("Model")).toHaveValue("ada");
+ });
+
+ it.todo("should disable provider if isDisabled is true");
+
+ it.todo(
+ "should display the verified models in the correct order",
+ async () => {},
+ );
+});
diff --git a/frontend/src/components/modals/settings/ModelSelector.tsx b/frontend/src/components/modals/settings/ModelSelector.tsx
new file mode 100644
index 0000000000..f5f7c221d4
--- /dev/null
+++ b/frontend/src/components/modals/settings/ModelSelector.tsx
@@ -0,0 +1,133 @@
+import {
+ Autocomplete,
+ AutocompleteItem,
+ AutocompleteSection,
+} from "@nextui-org/react";
+import React from "react";
+import { mapProvider } from "#/utils/mapProvider";
+import { VERIFIED_MODELS, VERIFIED_PROVIDERS } from "#/utils/verified-models";
+import { extractModelAndProvider } from "#/utils/extractModelAndProvider";
+
+interface ModelSelectorProps {
+ isDisabled?: boolean;
+ models: Record;
+ onModelChange: (model: string) => void;
+ defaultModel?: string;
+}
+
+export function ModelSelector({
+ isDisabled,
+ models,
+ onModelChange,
+ defaultModel,
+}: ModelSelectorProps) {
+ const [litellmId, setLitellmId] = React.useState(null);
+ const [selectedProvider, setSelectedProvider] = React.useState(
+ null,
+ );
+ const [selectedModel, setSelectedModel] = React.useState(null);
+
+ React.useEffect(() => {
+ if (defaultModel) {
+ // runs when resetting to defaults
+ const { provider, model } = extractModelAndProvider(defaultModel);
+
+ setLitellmId(defaultModel);
+ setSelectedProvider(provider);
+ setSelectedModel(model);
+ }
+ }, [defaultModel]);
+
+ const handleChangeProvider = (provider: string) => {
+ setSelectedProvider(provider);
+ setSelectedModel(null);
+
+ const separator = models[provider]?.separator || "";
+ setLitellmId(provider + separator);
+ };
+
+ const handleChangeModel = (model: string) => {
+ const separator = models[selectedProvider || ""]?.separator || "";
+ const fullModel = selectedProvider + separator + model;
+ setLitellmId(fullModel);
+ onModelChange(fullModel);
+ setSelectedModel(model);
+ };
+
+ const clear = () => {
+ setSelectedProvider(null);
+ setLitellmId(null);
+ };
+
+ return (
+
+
+ {litellmId?.replace("other", "") || "No model selected"}
+
+
+
+
{
+ if (e?.toString()) handleChangeProvider(e.toString());
+ }}
+ onInputChange={(value) => !value && clear()}
+ defaultSelectedKey={selectedProvider ?? undefined}
+ selectedKey={selectedProvider}
+ >
+
+ {Object.keys(models)
+ .filter((provider) => VERIFIED_PROVIDERS.includes(provider))
+ .map((provider) => (
+
+ {mapProvider(provider)}
+
+ ))}
+
+
+ {Object.keys(models)
+ .filter((provider) => !VERIFIED_PROVIDERS.includes(provider))
+ .map((provider) => (
+
+ {mapProvider(provider)}
+
+ ))}
+
+
+
+
{
+ if (e?.toString()) handleChangeModel(e.toString());
+ }}
+ isDisabled={isDisabled || !selectedProvider}
+ selectedKey={selectedModel}
+ defaultSelectedKey={selectedModel ?? undefined}
+ >
+
+ {models[selectedProvider || ""]?.models
+ .filter((model) => VERIFIED_MODELS.includes(model))
+ .map((model) => (
+
+ {model}
+
+ ))}
+
+
+ {models[selectedProvider || ""]?.models
+ .filter((model) => !VERIFIED_MODELS.includes(model))
+ .map((model) => (
+
+ {model}
+
+ ))}
+
+
+
+
+ );
+}
diff --git a/frontend/src/components/modals/settings/SettingsForm.test.tsx b/frontend/src/components/modals/settings/SettingsForm.test.tsx
index b6847ff3ba..a6ac059e06 100644
--- a/frontend/src/components/modals/settings/SettingsForm.test.tsx
+++ b/frontend/src/components/modals/settings/SettingsForm.test.tsx
@@ -6,6 +6,8 @@ import { Settings } from "#/services/settings";
import SettingsForm from "./SettingsForm";
const onModelChangeMock = vi.fn();
+const onCustomModelChangeMock = vi.fn();
+const onModelTypeChangeMock = vi.fn();
const onAgentChangeMock = vi.fn();
const onLanguageChangeMock = vi.fn();
const onAPIKeyChangeMock = vi.fn();
@@ -18,7 +20,9 @@ const renderSettingsForm = (settings?: Settings) => {
disabled={false}
settings={
settings || {
- LLM_MODEL: "model1",
+ LLM_MODEL: "gpt-4o",
+ CUSTOM_LLM_MODEL: "",
+ USING_CUSTOM_MODEL: false,
AGENT: "agent1",
LANGUAGE: "en",
LLM_API_KEY: "sk-...",
@@ -26,10 +30,12 @@ const renderSettingsForm = (settings?: Settings) => {
SECURITY_ANALYZER: "analyzer1",
}
}
- models={["model1", "model2", "model3"]}
+ models={["gpt-4o", "gpt-3.5-turbo", "azure/ada"]}
agents={["agent1", "agent2", "agent3"]}
securityAnalyzers={["analyzer1", "analyzer2", "analyzer3"]}
onModelChange={onModelChangeMock}
+ onCustomModelChange={onCustomModelChangeMock}
+ onModelTypeChange={onModelTypeChangeMock}
onAgentChange={onAgentChangeMock}
onLanguageChange={onLanguageChangeMock}
onAPIKeyChange={onAPIKeyChangeMock}
@@ -43,7 +49,8 @@ describe("SettingsForm", () => {
it("should display the first values in the array by default", () => {
renderSettingsForm();
- const modelInput = screen.getByRole("combobox", { name: "model" });
+ const providerInput = screen.getByRole("combobox", { name: "Provider" });
+ const modelInput = screen.getByRole("combobox", { name: "Model" });
const agentInput = screen.getByRole("combobox", { name: "agent" });
const languageInput = screen.getByRole("combobox", { name: "language" });
const apiKeyInput = screen.getByTestId("apikey");
@@ -52,7 +59,8 @@ describe("SettingsForm", () => {
name: "securityanalyzer",
});
- expect(modelInput).toHaveValue("model1");
+ expect(providerInput).toHaveValue("OpenAI");
+ expect(modelInput).toHaveValue("gpt-4o");
expect(agentInput).toHaveValue("agent1");
expect(languageInput).toHaveValue("English");
expect(apiKeyInput).toHaveValue("sk-...");
@@ -62,7 +70,9 @@ describe("SettingsForm", () => {
it("should display the existing values if they are present", () => {
renderSettingsForm({
- LLM_MODEL: "model2",
+ LLM_MODEL: "gpt-3.5-turbo",
+ CUSTOM_LLM_MODEL: "",
+ USING_CUSTOM_MODEL: false,
AGENT: "agent2",
LANGUAGE: "es",
LLM_API_KEY: "sk-...",
@@ -70,14 +80,16 @@ describe("SettingsForm", () => {
SECURITY_ANALYZER: "analyzer2",
});
- const modelInput = screen.getByRole("combobox", { name: "model" });
+ const providerInput = screen.getByRole("combobox", { name: "Provider" });
+ const modelInput = screen.getByRole("combobox", { name: "Model" });
const agentInput = screen.getByRole("combobox", { name: "agent" });
const languageInput = screen.getByRole("combobox", { name: "language" });
const securityAnalyzerInput = screen.getByRole("combobox", {
name: "securityanalyzer",
});
- expect(modelInput).toHaveValue("model2");
+ expect(providerInput).toHaveValue("OpenAI");
+ expect(modelInput).toHaveValue("gpt-3.5-turbo");
expect(agentInput).toHaveValue("agent2");
expect(languageInput).toHaveValue("EspaƱol");
expect(securityAnalyzerInput).toHaveValue("analyzer2");
@@ -87,18 +99,22 @@ describe("SettingsForm", () => {
renderWithProviders(
{
onSecurityAnalyzerChange={onSecurityAnalyzerChangeMock}
/>,
);
- const modelInput = screen.getByRole("combobox", { name: "model" });
+
+ const providerInput = screen.getByRole("combobox", { name: "Provider" });
+ const modelInput = screen.getByRole("combobox", { name: "Model" });
const agentInput = screen.getByRole("combobox", { name: "agent" });
const languageInput = screen.getByRole("combobox", { name: "language" });
const confirmationModeInput = screen.getByTestId("confirmationmode");
@@ -114,6 +132,7 @@ describe("SettingsForm", () => {
name: "securityanalyzer",
});
+ expect(providerInput).toBeDisabled();
expect(modelInput).toBeDisabled();
expect(agentInput).toBeDisabled();
expect(languageInput).toBeDisabled();
@@ -122,22 +141,6 @@ describe("SettingsForm", () => {
});
describe("onChange handlers", () => {
- it("should call the onModelChange handler when the model changes", async () => {
- renderSettingsForm();
-
- const modelInput = screen.getByRole("combobox", { name: "model" });
- await act(async () => {
- await userEvent.click(modelInput);
- });
-
- const model3 = screen.getByText("model3");
- await act(async () => {
- await userEvent.click(model3);
- });
-
- expect(onModelChangeMock).toHaveBeenCalledWith("model3");
- });
-
it("should call the onAgentChange handler when the agent changes", async () => {
const user = userEvent.setup();
renderSettingsForm();
@@ -182,4 +185,76 @@ describe("SettingsForm", () => {
expect(onAPIKeyChangeMock).toHaveBeenCalledWith("sk-...x");
});
});
+
+ describe("Setting a custom LLM model", () => {
+ it("should display the fetched models by default", () => {
+ renderSettingsForm();
+
+ const modelSelector = screen.getByTestId("model-selector");
+ expect(modelSelector).toBeInTheDocument();
+
+ const customModelInput = screen.queryByTestId("custom-model-input");
+ expect(customModelInput).not.toBeInTheDocument();
+ });
+
+ it("should switch to the custom model input when the custom model toggle is clicked", async () => {
+ const user = userEvent.setup();
+ renderSettingsForm();
+
+ const customModelToggle = screen.getByTestId("custom-model-toggle");
+ await user.click(customModelToggle);
+
+ const modelSelector = screen.queryByTestId("model-selector");
+ expect(modelSelector).not.toBeInTheDocument();
+
+ const customModelInput = screen.getByTestId("custom-model-input");
+ expect(customModelInput).toBeInTheDocument();
+ });
+
+ it("should call the onCustomModelChange handler when the custom model input changes", async () => {
+ const user = userEvent.setup();
+ renderSettingsForm();
+
+ const customModelToggle = screen.getByTestId("custom-model-toggle");
+ await user.click(customModelToggle);
+
+ const customModelInput = screen.getByTestId("custom-model-input");
+ await userEvent.type(customModelInput, "my/custom-model");
+
+ expect(onCustomModelChangeMock).toHaveBeenCalledWith("my/custom-model");
+ expect(onModelTypeChangeMock).toHaveBeenCalledWith("custom");
+ });
+
+ it("should have custom model switched if using custom model", () => {
+ renderWithProviders(
+ ,
+ );
+
+ const customModelToggle = screen.getByTestId("custom-model-toggle");
+ expect(customModelToggle).toHaveAttribute("aria-checked", "true");
+ });
+ });
});
diff --git a/frontend/src/components/modals/settings/SettingsForm.tsx b/frontend/src/components/modals/settings/SettingsForm.tsx
index f865f659b3..d607d61731 100644
--- a/frontend/src/components/modals/settings/SettingsForm.tsx
+++ b/frontend/src/components/modals/settings/SettingsForm.tsx
@@ -6,6 +6,8 @@ import { AvailableLanguages } from "../../../i18n";
import { I18nKey } from "../../../i18n/declaration";
import { AutocompleteCombobox } from "./AutocompleteCombobox";
import { Settings } from "#/services/settings";
+import { organizeModelsAndProviders } from "#/utils/organizeModelsAndProviders";
+import { ModelSelector } from "./ModelSelector";
interface SettingsFormProps {
settings: Settings;
@@ -15,6 +17,8 @@ interface SettingsFormProps {
disabled: boolean;
onModelChange: (model: string) => void;
+ onCustomModelChange: (model: string) => void;
+ onModelTypeChange: (type: "custom" | "default") => void;
onAPIKeyChange: (apiKey: string) => void;
onAgentChange: (agent: string) => void;
onLanguageChange: (language: string) => void;
@@ -29,6 +33,8 @@ function SettingsForm({
securityAnalyzers,
disabled,
onModelChange,
+ onCustomModelChange,
+ onModelTypeChange,
onAPIKeyChange,
onAgentChange,
onLanguageChange,
@@ -38,20 +44,46 @@ function SettingsForm({
const { t } = useTranslation();
const { isOpen: isVisible, onOpenChange: onVisibleChange } = useDisclosure();
const [isAgentSelectEnabled, setIsAgentSelectEnabled] = React.useState(false);
+ const [usingCustomModel, setUsingCustomModel] = React.useState(
+ settings.USING_CUSTOM_MODEL,
+ );
+
+ const changeModelType = (type: "custom" | "default") => {
+ if (type === "custom") {
+ setUsingCustomModel(true);
+ onModelTypeChange("custom");
+ } else {
+ setUsingCustomModel(false);
+ onModelTypeChange("default");
+ }
+ };
return (
<>
- ({ value: model, label: model }))}
- defaultKey={settings.LLM_MODEL}
- onChange={(e) => {
- onModelChange(e);
- }}
- tooltip={t(I18nKey.SETTINGS$MODEL_TOOLTIP)}
- allowCustomValue // user can type in a custom LLM model that is not in the list
- disabled={disabled}
- />
+ changeModelType(value ? "custom" : "default")}
+ >
+ Use custom model
+
+ {usingCustomModel && (
+
+ )}
+ {!usingCustomModel && (
+
+ )}
({
...(await importOriginal()),
getSettings: vi.fn().mockReturnValue({
LLM_MODEL: "gpt-4o",
+ CUSTOM_LLM_MODEL: "",
+ USING_CUSTOM_MODEL: false,
AGENT: "CodeActAgent",
LANGUAGE: "en",
LLM_API_KEY: "sk-...",
@@ -32,6 +34,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
}),
getDefaultSettings: vi.fn().mockReturnValue({
LLM_MODEL: "gpt-4o",
+ CUSTOM_LLM_MODEL: "",
+ USING_CUSTOM_MODEL: false,
AGENT: "CodeActAgent",
LANGUAGE: "en",
LLM_API_KEY: "",
@@ -46,7 +50,14 @@ vi.mock("#/services/options", async (importOriginal) => ({
...(await importOriginal()),
fetchModels: vi
.fn()
- .mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])),
+ .mockResolvedValue(
+ Promise.resolve([
+ "gpt-4o",
+ "gpt-3.5-turbo",
+ "azure/ada",
+ "cohere.command-r-v1:0",
+ ]),
+ ),
fetchAgents: vi
.fn()
.mockResolvedValue(Promise.resolve(["agent1", "agent2", "agent3"])),
@@ -104,6 +115,8 @@ describe("SettingsModal", () => {
describe("onHandleSave", () => {
const initialSettings: Settings = {
LLM_MODEL: "gpt-4o",
+ CUSTOM_LLM_MODEL: "",
+ USING_CUSTOM_MODEL: false,
AGENT: "CodeActAgent",
LANGUAGE: "en",
LLM_API_KEY: "sk-...",
@@ -122,17 +135,22 @@ describe("SettingsModal", () => {
await assertModelsAndAgentsFetched();
const saveButton = screen.getByRole("button", { name: /save/i });
- const modelInput = screen.getByRole("combobox", { name: "model" });
+ const providerInput = screen.getByRole("combobox", { name: "Provider" });
+ const modelInput = screen.getByRole("combobox", { name: "Model" });
+
+ await user.click(providerInput);
+ const azure = screen.getByText("Azure");
+ await user.click(azure);
await user.click(modelInput);
- const model3 = screen.getByText("model3");
-
+ const model3 = screen.getByText("ada");
await user.click(model3);
+
await user.click(saveButton);
expect(saveSettings).toHaveBeenCalledWith({
...initialSettings,
- LLM_MODEL: "model3",
+ LLM_MODEL: "azure/ada",
});
});
@@ -146,12 +164,17 @@ describe("SettingsModal", () => {
);
const saveButton = screen.getByRole("button", { name: /save/i });
- const modelInput = screen.getByRole("combobox", { name: "model" });
+ const providerInput = screen.getByRole("combobox", { name: "Provider" });
+ const modelInput = screen.getByRole("combobox", { name: "Model" });
+
+ await user.click(providerInput);
+ const openai = screen.getByText("OpenAI");
+ await user.click(openai);
await user.click(modelInput);
- const model3 = screen.getByText("model3");
-
+ const model3 = screen.getByText("gpt-3.5-turbo");
await user.click(model3);
+
await user.click(saveButton);
expect(startNewSessionSpy).toHaveBeenCalled();
@@ -167,12 +190,17 @@ describe("SettingsModal", () => {
);
const saveButton = screen.getByRole("button", { name: /save/i });
- const modelInput = screen.getByRole("combobox", { name: "model" });
+ const providerInput = screen.getByRole("combobox", { name: "Provider" });
+ const modelInput = screen.getByRole("combobox", { name: "Model" });
+
+ await user.click(providerInput);
+ const cohere = screen.getByText("cohere");
+ await user.click(cohere);
await user.click(modelInput);
- const model3 = screen.getByText("model3");
-
+ const model3 = screen.getByText("command-r-v1:0");
await user.click(model3);
+
await user.click(saveButton);
expect(toastSpy).toHaveBeenCalledTimes(4);
@@ -213,12 +241,17 @@ describe("SettingsModal", () => {
});
const saveButton = screen.getByRole("button", { name: /save/i });
- const modelInput = screen.getByRole("combobox", { name: "model" });
+ const providerInput = screen.getByRole("combobox", { name: "Provider" });
+ const modelInput = screen.getByRole("combobox", { name: "Model" });
+
+ await user.click(providerInput);
+ const cohere = screen.getByText("cohere");
+ await user.click(cohere);
await user.click(modelInput);
- const model3 = screen.getByText("model3");
-
+ const model3 = screen.getByText("command-r-v1:0");
await user.click(model3);
+
await user.click(saveButton);
expect(onOpenChangeMock).toHaveBeenCalledWith(false);
diff --git a/frontend/src/components/modals/settings/SettingsModal.tsx b/frontend/src/components/modals/settings/SettingsModal.tsx
index 50f90192d5..5abc63a5c2 100644
--- a/frontend/src/components/modals/settings/SettingsModal.tsx
+++ b/frontend/src/components/modals/settings/SettingsModal.tsx
@@ -63,8 +63,10 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
React.useEffect(() => {
(async () => {
try {
- setModels(await fetchModels());
- setAgents(await fetchAgents());
+ const fetchedModels = await fetchModels();
+ const fetchedAgents = await fetchAgents();
+ setModels(fetchedModels);
+ setAgents(fetchedAgents);
setSecurityAnalyzers(await fetchSecurityAnalyzers());
} catch (error) {
toast.error("settings", t(I18nKey.CONFIGURATION$ERROR_FETCH_MODELS));
@@ -81,6 +83,20 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
}));
};
+ const handleCustomModelChange = (model: string) => {
+ setSettings((prev) => ({
+ ...prev,
+ CUSTOM_LLM_MODEL: model,
+ }));
+ };
+
+ const handleModelTypeChange = (type: "custom" | "default") => {
+ setSettings((prev) => ({
+ ...prev,
+ USING_CUSTOM_MODEL: type === "custom",
+ }));
+ };
+
const handleAgentChange = (agent: string) => {
setSettings((prev) => ({ ...prev, AGENT: agent }));
};
@@ -189,6 +205,8 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
agents={agents}
securityAnalyzers={securityAnalyzers}
onModelChange={handleModelChange}
+ onCustomModelChange={handleCustomModelChange}
+ onModelTypeChange={handleModelTypeChange}
onAgentChange={handleAgentChange}
onLanguageChange={handleLanguageChange}
onAPIKeyChange={handleAPIKeyChange}
diff --git a/frontend/src/services/session.test.ts b/frontend/src/services/session.test.ts
index 492fef608c..e028a3ebf7 100644
--- a/frontend/src/services/session.test.ts
+++ b/frontend/src/services/session.test.ts
@@ -11,9 +11,16 @@ const setupSpy = vi.spyOn(Session, "_setupSocket").mockImplementation(() => {
});
describe("startNewSession", () => {
+ afterEach(() => {
+ sendSpy.mockClear();
+ setupSpy.mockClear();
+ });
+
it("Should start a new session with the current settings", () => {
const settings: Settings = {
LLM_MODEL: "llm_value",
+ CUSTOM_LLM_MODEL: "",
+ USING_CUSTOM_MODEL: false,
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "sk-...",
@@ -32,4 +39,33 @@ describe("startNewSession", () => {
expect(setupSpy).toHaveBeenCalledTimes(1);
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
+
+ it("should start with the custom llm if set", () => {
+ const settings: Settings = {
+ LLM_MODEL: "llm_value",
+ CUSTOM_LLM_MODEL: "custom_llm_value",
+ USING_CUSTOM_MODEL: true,
+ AGENT: "agent_value",
+ LANGUAGE: "language_value",
+ LLM_API_KEY: "sk-...",
+ CONFIRMATION_MODE: true,
+ SECURITY_ANALYZER: "analyzer",
+ };
+
+ const event = {
+ action: ActionType.INIT,
+ args: settings,
+ };
+
+ saveSettings(settings);
+ Session.startNewSession();
+
+ expect(setupSpy).toHaveBeenCalledTimes(1);
+ expect(sendSpy).toHaveBeenCalledWith(
+ JSON.stringify({
+ ...event,
+ args: { ...settings, LLM_MODEL: "custom_llm_value" },
+ }),
+ );
+ });
});
diff --git a/frontend/src/services/session.ts b/frontend/src/services/session.ts
index 09bcdc7b0d..ab009928e8 100644
--- a/frontend/src/services/session.ts
+++ b/frontend/src/services/session.ts
@@ -46,7 +46,15 @@ class Session {
private static _initializeAgent = () => {
const settings = getSettings();
- const event = { action: ActionType.INIT, args: settings };
+ const event = {
+ action: ActionType.INIT,
+ args: {
+ ...settings,
+ LLM_MODEL: settings.USING_CUSTOM_MODEL
+ ? settings.CUSTOM_LLM_MODEL
+ : settings.LLM_MODEL,
+ },
+ };
const eventString = JSON.stringify(event);
Session.send(eventString);
};
diff --git a/frontend/src/services/settings.test.ts b/frontend/src/services/settings.test.ts
index 869ed6c2be..2a337d4d19 100644
--- a/frontend/src/services/settings.test.ts
+++ b/frontend/src/services/settings.test.ts
@@ -18,6 +18,8 @@ describe("getSettings", () => {
it("should get the stored settings", () => {
(localStorage.getItem as Mock)
.mockReturnValueOnce("llm_value")
+ .mockReturnValueOnce("custom_llm_value")
+ .mockReturnValueOnce("true")
.mockReturnValueOnce("agent_value")
.mockReturnValueOnce("language_value")
.mockReturnValueOnce("api_key")
@@ -28,6 +30,8 @@ describe("getSettings", () => {
expect(settings).toEqual({
LLM_MODEL: "llm_value",
+ CUSTOM_LLM_MODEL: "custom_llm_value",
+ USING_CUSTOM_MODEL: true,
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "api_key",
@@ -43,12 +47,16 @@ describe("getSettings", () => {
.mockReturnValueOnce(null)
.mockReturnValueOnce(null)
.mockReturnValueOnce(null)
+ .mockReturnValueOnce(null)
+ .mockReturnValueOnce(null)
.mockReturnValueOnce(null);
const settings = getSettings();
expect(settings).toEqual({
LLM_MODEL: DEFAULT_SETTINGS.LLM_MODEL,
+ CUSTOM_LLM_MODEL: "",
+ USING_CUSTOM_MODEL: DEFAULT_SETTINGS.USING_CUSTOM_MODEL,
AGENT: DEFAULT_SETTINGS.AGENT,
LANGUAGE: DEFAULT_SETTINGS.LANGUAGE,
LLM_API_KEY: "",
@@ -62,6 +70,8 @@ describe("saveSettings", () => {
it("should save the settings", () => {
const settings: Settings = {
LLM_MODEL: "llm_value",
+ CUSTOM_LLM_MODEL: "custom_llm_value",
+ USING_CUSTOM_MODEL: true,
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "some_key",
@@ -72,6 +82,14 @@ describe("saveSettings", () => {
saveSettings(settings);
expect(localStorage.setItem).toHaveBeenCalledWith("LLM_MODEL", "llm_value");
+ expect(localStorage.setItem).toHaveBeenCalledWith(
+ "CUSTOM_LLM_MODEL",
+ "custom_llm_value",
+ );
+ expect(localStorage.setItem).toHaveBeenCalledWith(
+ "USING_CUSTOM_MODEL",
+ "true",
+ );
expect(localStorage.setItem).toHaveBeenCalledWith("AGENT", "agent_value");
expect(localStorage.setItem).toHaveBeenCalledWith(
"LANGUAGE",
@@ -122,6 +140,8 @@ describe("getSettingsDifference", () => {
beforeEach(() => {
(localStorage.getItem as Mock)
.mockReturnValueOnce("llm_value")
+ .mockReturnValueOnce("custom_llm_value")
+ .mockReturnValueOnce("false")
.mockReturnValueOnce("agent_value")
.mockReturnValueOnce("language_value");
});
@@ -129,6 +149,8 @@ describe("getSettingsDifference", () => {
it("should return updated settings", () => {
const settings = {
LLM_MODEL: "new_llm_value",
+ CUSTOM_LLM_MODEL: "custom_llm_value",
+ USING_CUSTOM_MODEL: true,
AGENT: "new_agent_value",
LANGUAGE: "language_value",
};
@@ -136,6 +158,7 @@ describe("getSettingsDifference", () => {
const updatedSettings = getSettingsDifference(settings);
expect(updatedSettings).toEqual({
+ USING_CUSTOM_MODEL: true,
LLM_MODEL: "new_llm_value",
AGENT: "new_agent_value",
});
diff --git a/frontend/src/services/settings.ts b/frontend/src/services/settings.ts
index ec9dcc67ec..d554d44681 100644
--- a/frontend/src/services/settings.ts
+++ b/frontend/src/services/settings.ts
@@ -2,6 +2,8 @@ const LATEST_SETTINGS_VERSION = 1;
export type Settings = {
LLM_MODEL: string;
+ CUSTOM_LLM_MODEL: string;
+ USING_CUSTOM_MODEL: boolean;
AGENT: string;
LANGUAGE: string;
LLM_API_KEY: string;
@@ -12,7 +14,9 @@ export type Settings = {
type SettingsInput = Settings[keyof Settings];
export const DEFAULT_SETTINGS: Settings = {
- LLM_MODEL: "gpt-4o",
+ LLM_MODEL: "openai/gpt-4o",
+ CUSTOM_LLM_MODEL: "",
+ USING_CUSTOM_MODEL: false,
AGENT: "CodeActAgent",
LANGUAGE: "en",
LLM_API_KEY: "",
@@ -54,6 +58,9 @@ export const getDefaultSettings = (): Settings => DEFAULT_SETTINGS;
*/
export const getSettings = (): Settings => {
const model = localStorage.getItem("LLM_MODEL");
+ const customModel = localStorage.getItem("CUSTOM_LLM_MODEL");
+ const usingCustomModel =
+ localStorage.getItem("USING_CUSTOM_MODEL") === "true";
const agent = localStorage.getItem("AGENT");
const language = localStorage.getItem("LANGUAGE");
const apiKey = localStorage.getItem("LLM_API_KEY");
@@ -62,6 +69,8 @@ export const getSettings = (): Settings => {
return {
LLM_MODEL: model || DEFAULT_SETTINGS.LLM_MODEL,
+ CUSTOM_LLM_MODEL: customModel || DEFAULT_SETTINGS.CUSTOM_LLM_MODEL,
+ USING_CUSTOM_MODEL: usingCustomModel || DEFAULT_SETTINGS.USING_CUSTOM_MODEL,
AGENT: agent || DEFAULT_SETTINGS.AGENT,
LANGUAGE: language || DEFAULT_SETTINGS.LANGUAGE,
LLM_API_KEY: apiKey || DEFAULT_SETTINGS.LLM_API_KEY,
diff --git a/frontend/src/utils/extractModelAndProvider.test.ts b/frontend/src/utils/extractModelAndProvider.test.ts
new file mode 100644
index 0000000000..dd43d68caa
--- /dev/null
+++ b/frontend/src/utils/extractModelAndProvider.test.ts
@@ -0,0 +1,62 @@
+import { describe, it, expect } from "vitest";
+import { extractModelAndProvider } from "./extractModelAndProvider";
+
+describe("extractModelAndProvider", () => {
+ it("should work", () => {
+ expect(extractModelAndProvider("azure/ada")).toEqual({
+ provider: "azure",
+ model: "ada",
+ separator: "/",
+ });
+
+ expect(
+ extractModelAndProvider("azure/standard/1024-x-1024/dall-e-2"),
+ ).toEqual({
+ provider: "azure",
+ model: "standard/1024-x-1024/dall-e-2",
+ separator: "/",
+ });
+
+ expect(extractModelAndProvider("vertex_ai_beta/chat-bison")).toEqual({
+ provider: "vertex_ai_beta",
+ model: "chat-bison",
+ separator: "/",
+ });
+
+ expect(extractModelAndProvider("cohere.command-r-v1:0")).toEqual({
+ provider: "cohere",
+ model: "command-r-v1:0",
+ separator: ".",
+ });
+
+ expect(
+ extractModelAndProvider(
+ "cloudflare/@cf/mistral/mistral-7b-instruct-v0.1",
+ ),
+ ).toEqual({
+ provider: "cloudflare",
+ model: "@cf/mistral/mistral-7b-instruct-v0.1",
+ separator: "/",
+ });
+
+ expect(extractModelAndProvider("together-ai-21.1b-41b")).toEqual({
+ provider: "",
+ model: "together-ai-21.1b-41b",
+ separator: "",
+ });
+ });
+
+ it("should add provider for popular models", () => {
+ expect(extractModelAndProvider("gpt-3.5-turbo")).toEqual({
+ provider: "openai",
+ model: "gpt-3.5-turbo",
+ separator: "/",
+ });
+
+ expect(extractModelAndProvider("gpt-4o")).toEqual({
+ provider: "openai",
+ model: "gpt-4o",
+ separator: "/",
+ });
+ });
+});
diff --git a/frontend/src/utils/extractModelAndProvider.ts b/frontend/src/utils/extractModelAndProvider.ts
new file mode 100644
index 0000000000..cd8dcecf65
--- /dev/null
+++ b/frontend/src/utils/extractModelAndProvider.ts
@@ -0,0 +1,49 @@
+import { isNumber } from "./isNumber";
+import { VERIFIED_OPENAI_MODELS } from "./verified-models";
+
+/**
+ * Checks if the split array is actually a version number.
+ * @param split The split array of the model string
+ * @returns Boolean indicating if the split is actually a version number
+ *
+ * @example
+ * const split = ["gpt-3", "5-turbo"] // incorrectly split from "gpt-3.5-turbo"
+ * splitIsActuallyVersion(split) // returns true
+ */
+const splitIsActuallyVersion = (split: string[]) =>
+ split[1] && split[1][0] && isNumber(split[1][0]);
+
+/**
+ * Given a model string, extract the provider and model name. Currently the supported separators are "/" and "."
+ * @param model The model string
+ * @returns An object containing the provider, model name, and separator
+ *
+ * @example
+ * extractModelAndProvider("azure/ada")
+ * // returns { provider: "azure", model: "ada", separator: "/" }
+ *
+ * extractModelAndProvider("cohere.command-r-v1:0")
+ * // returns { provider: "cohere", model: "command-r-v1:0", separator: "." }
+ */
+export const extractModelAndProvider = (model: string) => {
+ let separator = "/";
+ let split = model.split(separator);
+ if (split.length === 1) {
+ // no "/" separator found, try with "."
+ separator = ".";
+ split = model.split(separator);
+ if (splitIsActuallyVersion(split)) {
+ split = [split.join(separator)]; // undo the split
+ }
+ }
+ if (split.length === 1) {
+ // no "/" or "." separator found
+ if (VERIFIED_OPENAI_MODELS.includes(split[0])) {
+ return { provider: "openai", model: split[0], separator: "/" };
+ }
+ // return as model only
+ return { provider: "", model, separator: "" };
+ }
+ const [provider, ...modelId] = split;
+ return { provider, model: modelId.join(separator), separator };
+};
diff --git a/frontend/src/utils/isNumber.test.ts b/frontend/src/utils/isNumber.test.ts
new file mode 100644
index 0000000000..6b3640a88d
--- /dev/null
+++ b/frontend/src/utils/isNumber.test.ts
@@ -0,0 +1,9 @@
+import { test, expect } from "vitest";
+import { isNumber } from "./isNumber";
+
+test("isNumber", () => {
+ expect(isNumber(1)).toBe(true);
+ expect(isNumber(0)).toBe(true);
+ expect(isNumber("3")).toBe(true);
+ expect(isNumber("0")).toBe(true);
+});
diff --git a/frontend/src/utils/isNumber.ts b/frontend/src/utils/isNumber.ts
new file mode 100644
index 0000000000..8ac961d9c0
--- /dev/null
+++ b/frontend/src/utils/isNumber.ts
@@ -0,0 +1,2 @@
+export const isNumber = (value: string | number): boolean =>
+ !Number.isNaN(Number(value));
diff --git a/frontend/src/utils/mapProvider.test.ts b/frontend/src/utils/mapProvider.test.ts
new file mode 100644
index 0000000000..10d3a52a2b
--- /dev/null
+++ b/frontend/src/utils/mapProvider.test.ts
@@ -0,0 +1,27 @@
+import { test, expect } from "vitest";
+import { mapProvider } from "./mapProvider";
+
+test("mapProvider", () => {
+ expect(mapProvider("azure")).toBe("Azure");
+ expect(mapProvider("azure_ai")).toBe("Azure AI Studio");
+ expect(mapProvider("vertex_ai")).toBe("VertexAI");
+ expect(mapProvider("palm")).toBe("PaLM");
+ expect(mapProvider("gemini")).toBe("Gemini");
+ expect(mapProvider("anthropic")).toBe("Anthropic");
+ expect(mapProvider("sagemaker")).toBe("AWS SageMaker");
+ expect(mapProvider("bedrock")).toBe("AWS Bedrock");
+ expect(mapProvider("mistral")).toBe("Mistral AI");
+ expect(mapProvider("anyscale")).toBe("Anyscale");
+ expect(mapProvider("databricks")).toBe("Databricks");
+ expect(mapProvider("ollama")).toBe("Ollama");
+ expect(mapProvider("perlexity")).toBe("Perplexity AI");
+ expect(mapProvider("friendliai")).toBe("FriendliAI");
+ expect(mapProvider("groq")).toBe("Groq");
+ expect(mapProvider("fireworks_ai")).toBe("Fireworks AI");
+ expect(mapProvider("cloudflare")).toBe("Cloudflare Workers AI");
+ expect(mapProvider("deepinfra")).toBe("DeepInfra");
+ expect(mapProvider("ai21")).toBe("AI21");
+ expect(mapProvider("replicate")).toBe("Replicate");
+ expect(mapProvider("voyage")).toBe("Voyage AI");
+ expect(mapProvider("openrouter")).toBe("OpenRouter");
+});
diff --git a/frontend/src/utils/mapProvider.ts b/frontend/src/utils/mapProvider.ts
new file mode 100644
index 0000000000..28a50dee71
--- /dev/null
+++ b/frontend/src/utils/mapProvider.ts
@@ -0,0 +1,30 @@
+export const MAP_PROVIDER = {
+ openai: "OpenAI",
+ azure: "Azure",
+ azure_ai: "Azure AI Studio",
+ vertex_ai: "VertexAI",
+ palm: "PaLM",
+ gemini: "Gemini",
+ anthropic: "Anthropic",
+ sagemaker: "AWS SageMaker",
+ bedrock: "AWS Bedrock",
+ mistral: "Mistral AI",
+ anyscale: "Anyscale",
+ databricks: "Databricks",
+ ollama: "Ollama",
+ perlexity: "Perplexity AI",
+ friendliai: "FriendliAI",
+ groq: "Groq",
+ fireworks_ai: "Fireworks AI",
+ cloudflare: "Cloudflare Workers AI",
+ deepinfra: "DeepInfra",
+ ai21: "AI21",
+ replicate: "Replicate",
+ voyage: "Voyage AI",
+ openrouter: "OpenRouter",
+};
+
+export const mapProvider = (provider: string) =>
+ Object.keys(MAP_PROVIDER).includes(provider)
+ ? MAP_PROVIDER[provider as keyof typeof MAP_PROVIDER]
+ : provider;
diff --git a/frontend/src/utils/organizeModelsAndProviders.test.ts b/frontend/src/utils/organizeModelsAndProviders.test.ts
new file mode 100644
index 0000000000..53dfcede7d
--- /dev/null
+++ b/frontend/src/utils/organizeModelsAndProviders.test.ts
@@ -0,0 +1,51 @@
+import { test } from "vitest";
+import { organizeModelsAndProviders } from "./organizeModelsAndProviders";
+
+test("organizeModelsAndProviders", () => {
+ const models = [
+ "azure/ada",
+ "azure/gpt-35-turbo",
+ "azure/gpt-3-turbo",
+ "azure/standard/1024-x-1024/dall-e-2",
+ "vertex_ai_beta/chat-bison",
+ "vertex_ai_beta/chat-bison-32k",
+ "sagemaker/meta-textgeneration-llama-2-13b",
+ "cohere.command-r-v1:0",
+ "cloudflare/@cf/mistral/mistral-7b-instruct-v0.1",
+ "gpt-4o",
+ "together-ai-21.1b-41b",
+ "gpt-3.5-turbo",
+ ];
+
+ const object = organizeModelsAndProviders(models);
+
+ expect(object).toEqual({
+ azure: {
+ separator: "/",
+ models: [
+ "ada",
+ "gpt-35-turbo",
+ "gpt-3-turbo",
+ "standard/1024-x-1024/dall-e-2",
+ ],
+ },
+ vertex_ai_beta: {
+ separator: "/",
+ models: ["chat-bison", "chat-bison-32k"],
+ },
+ sagemaker: { separator: "/", models: ["meta-textgeneration-llama-2-13b"] },
+ cohere: { separator: ".", models: ["command-r-v1:0"] },
+ cloudflare: {
+ separator: "/",
+ models: ["@cf/mistral/mistral-7b-instruct-v0.1"],
+ },
+ openai: {
+ separator: "/",
+ models: ["gpt-4o", "gpt-3.5-turbo"],
+ },
+ other: {
+ separator: "",
+ models: ["together-ai-21.1b-41b"],
+ },
+ });
+});
diff --git a/frontend/src/utils/organizeModelsAndProviders.ts b/frontend/src/utils/organizeModelsAndProviders.ts
new file mode 100644
index 0000000000..61958a5783
--- /dev/null
+++ b/frontend/src/utils/organizeModelsAndProviders.ts
@@ -0,0 +1,42 @@
+import { extractModelAndProvider } from "./extractModelAndProvider";
+
+/**
+ * Given a list of models, organize them by provider
+ * @param models The list of models
+ * @returns An object containing the provider and models
+ *
+ * @example
+ * const models = [
+ * "azure/ada",
+ * "azure/gpt-35-turbo",
+ * "cohere.command-r-v1:0",
+ * ];
+ *
+ * organizeModelsAndProviders(models);
+ * // returns {
+ * // azure: {
+ * // separator: "/",
+ * // models: ["ada", "gpt-35-turbo"],
+ * // },
+ * // cohere: {
+ * // separator: ".",
+ * // models: ["command-r-v1:0"],
+ * // },
+ * // }
+ */
+export const organizeModelsAndProviders = (models: string[]) => {
+ const object: Record = {};
+ models.forEach((model) => {
+ const {
+ separator,
+ provider,
+ model: modelId,
+ } = extractModelAndProvider(model);
+ const key = provider || "other";
+ if (!object[key]) {
+ object[key] = { separator, models: [] };
+ }
+ object[key].models.push(modelId);
+ });
+ return object;
+};
diff --git a/frontend/src/utils/verified-models.ts b/frontend/src/utils/verified-models.ts
new file mode 100644
index 0000000000..2efffaa314
--- /dev/null
+++ b/frontend/src/utils/verified-models.ts
@@ -0,0 +1,14 @@
+// Here are the list of verified models and providers that we know work well with OpenHands.
+export const VERIFIED_PROVIDERS = ["openai", "azure", "anthropic"];
+export const VERIFIED_MODELS = ["gpt-4o", "claude-3-5-sonnet-20240620-v1:0"];
+
+// LiteLLM does not return OpenAI models with the provider, so we list them here to set them ourselves for consistency
+// (e.g., they return `gpt-4o` instead of `openai/gpt-4o`)
+export const VERIFIED_OPENAI_MODELS = [
+ "gpt-4o",
+ "gpt-4o-mini",
+ "gpt-4-turbo",
+ "gpt-4",
+ "gpt-4-32k",
+ "gpt-3.5-turbo",
+];