diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx
index 44bdb55af5..ddcf8489ce 100644
--- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx
@@ -12,12 +12,17 @@ const mockSetCopilotChatMode = vi.fn((mode: string) => {
mockCopilotMode = mode;
});
+let mockCopilotLlmModel = "standard";
+const mockSetCopilotLlmModel = vi.fn((model: string) => {
+ mockCopilotLlmModel = model;
+});
+
vi.mock("@/app/(platform)/copilot/store", () => ({
useCopilotUIStore: () => ({
copilotChatMode: mockCopilotMode,
setCopilotChatMode: mockSetCopilotChatMode,
- copilotLlmModel: "standard",
- setCopilotLlmModel: vi.fn(),
+ copilotLlmModel: mockCopilotLlmModel,
+ setCopilotLlmModel: mockSetCopilotLlmModel,
initialPrompt: null,
setInitialPrompt: vi.fn(),
}),
@@ -109,6 +114,7 @@ afterEach(() => {
cleanup();
vi.clearAllMocks();
mockCopilotMode = "extended_thinking";
+ mockCopilotLlmModel = "standard";
});
describe("ChatInput mode toggle", () => {
@@ -189,3 +195,69 @@ describe("ChatInput mode toggle", () => {
);
});
});
+
+describe("ChatInput model toggle", () => {
+ it("renders model toggle button when flag is enabled", () => {
+ mockFlagValue = true;
+ render();
+ expect(screen.getByLabelText(/switch to advanced model/i)).toBeDefined();
+ });
+
+ it("does not render model toggle when flag is disabled", () => {
+ mockFlagValue = false;
+ render();
+ expect(
+ screen.queryByLabelText(/switch to (advanced|standard) model/i),
+ ).toBeNull();
+ });
+
+ it("toggles from standard to advanced on click", () => {
+ mockFlagValue = true;
+ mockCopilotLlmModel = "standard";
+ render();
+ fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
+ expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("advanced");
+ });
+
+ it("toggles from advanced to standard on click", () => {
+ mockFlagValue = true;
+ mockCopilotLlmModel = "advanced";
+ render();
+ fireEvent.click(screen.getByLabelText(/switch to standard model/i));
+ expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
+ });
+
+ it("hides model toggle when streaming", () => {
+ mockFlagValue = true;
+ render();
+ expect(
+ screen.queryByLabelText(/switch to (advanced|standard) model/i),
+ ).toBeNull();
+ });
+
+ it("shows a toast when switching to advanced", async () => {
+ const { toast } = await import("@/components/molecules/Toast/use-toast");
+ mockFlagValue = true;
+ mockCopilotLlmModel = "standard";
+ render();
+ fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
+ expect(toast).toHaveBeenCalledWith(
+ expect.objectContaining({
+ title: expect.stringMatching(/switched to advanced model/i),
+ }),
+ );
+ });
+
+ it("shows a toast when switching to standard", async () => {
+ const { toast } = await import("@/components/molecules/Toast/use-toast");
+ mockFlagValue = true;
+ mockCopilotLlmModel = "advanced";
+ render();
+ fireEvent.click(screen.getByLabelText(/switch to standard model/i));
+ expect(toast).toHaveBeenCalledWith(
+ expect.objectContaining({
+ title: expect.stringMatching(/switched to standard model/i),
+ }),
+ );
+ });
+});