diff --git a/frontend/__tests__/context/ws-client-provider.test.tsx b/frontend/__tests__/context/ws-client-provider.test.tsx new file mode 100644 index 0000000000..e67a0ea56d --- /dev/null +++ b/frontend/__tests__/context/ws-client-provider.test.tsx @@ -0,0 +1,30 @@ +import { describe, it, expect, vi } from "vitest"; +import { render, screen } from "@testing-library/react"; +import * as ChatSlice from "#/state/chat-slice"; +import { + updateStatusWhenErrorMessagePresent, +} from "#/context/ws-client-provider"; + +describe("Propagate error message", () => { + it("should do nothing when no message was passed from server", () => { + const addErrorMessageSpy = vi.spyOn(ChatSlice, "addErrorMessage"); + updateStatusWhenErrorMessagePresent(null) + updateStatusWhenErrorMessagePresent(undefined) + updateStatusWhenErrorMessagePresent({}) + updateStatusWhenErrorMessagePresent({message: null}) + + expect(addErrorMessageSpy).not.toHaveBeenCalled(); + }); + + it("should display error to user when present", () => { + const message = "We have a problem!" + const addErrorMessageSpy = vi.spyOn(ChatSlice, "addErrorMessage") + updateStatusWhenErrorMessagePresent({message}) + + expect(addErrorMessageSpy).toHaveBeenCalledWith({ + message, + status_update: true, + type: 'error' + }); + }); +}); diff --git a/frontend/src/context/ws-client-provider.tsx b/frontend/src/context/ws-client-provider.tsx index 177a11e590..e061835074 100644 --- a/frontend/src/context/ws-client-provider.tsx +++ b/frontend/src/context/ws-client-provider.tsx @@ -2,7 +2,10 @@ import posthog from "posthog-js"; import React from "react"; import { io, Socket } from "socket.io-client"; import EventLogger from "#/utils/event-logger"; -import { handleAssistantMessage } from "#/services/actions"; +import { + handleAssistantMessage, + handleStatusMessage, +} from "#/services/actions"; import { useRate } from "#/hooks/use-rate"; import { OpenHandsParsedEvent } from "#/types/core"; import { @@ -64,6 +67,21 @@ interface WsClientProviderProps { conversationId: string; } +export function updateStatusWhenErrorMessagePresent(data: unknown) { + if ( + data && + typeof data === "object" && + "message" in data && + typeof data.message === "string" + ) { + handleStatusMessage({ + type: "error", + message: data.message, + status_update: true, + }); + } +} + export function WsClientProvider({ conversationId, children, @@ -101,7 +119,7 @@ export function WsClientProvider({ handleAssistantMessage(event); } - function handleDisconnect() { + function handleDisconnect(data: unknown) { setStatus(WsClientProviderStatus.DISCONNECTED); const sio = sioRef.current; if (!sio) { @@ -109,11 +127,13 @@ export function WsClientProvider({ } sio.io.opts.query = sio.io.opts.query || {}; sio.io.opts.query.latest_event_id = lastEventRef.current?.id; + updateStatusWhenErrorMessagePresent(data); } - function handleError() { - posthog.capture("socket_error"); + function handleError(data: unknown) { setStatus(WsClientProviderStatus.DISCONNECTED); + updateStatusWhenErrorMessagePresent(data); + posthog.capture("socket_error"); } React.useEffect(() => { diff --git a/frontend/src/services/actions.ts b/frontend/src/services/actions.ts index b19fc92383..8061c7242b 100644 --- a/frontend/src/services/actions.ts +++ b/frontend/src/services/actions.ts @@ -75,6 +75,7 @@ export function handleActionMessage(message: ActionMessage) { if (message.args && message.args.thought) { store.dispatch(addAssistantMessage(message.args.thought)); } + // Need to convert ActionMessage to RejectAction // @ts-expect-error TODO: fix store.dispatch(addAssistantAction(message)); } diff --git a/frontend/src/state/chat-slice.ts b/frontend/src/state/chat-slice.ts index 87ce8e52de..5bfffb62d4 100644 --- a/frontend/src/state/chat-slice.ts +++ b/frontend/src/state/chat-slice.ts @@ -73,7 +73,7 @@ export const chatSlice = createSlice({ state.messages.push(message); }, - addAssistantMessage(state, action: PayloadAction) { + addAssistantMessage(state: SliceState, action: PayloadAction) { const message: Message = { type: "thought", sender: "assistant", @@ -85,7 +85,10 @@ export const chatSlice = createSlice({ state.messages.push(message); }, - addAssistantAction(state, action: PayloadAction) { + addAssistantAction( + state: SliceState, + action: PayloadAction, + ) { const actionID = action.payload.action; if (!HANDLED_ACTIONS.includes(actionID)) { return; @@ -125,7 +128,7 @@ export const chatSlice = createSlice({ }, addAssistantObservation( - state, + state: SliceState, observation: PayloadAction, ) { const observationID = observation.payload.observation; @@ -179,7 +182,7 @@ export const chatSlice = createSlice({ }, addErrorMessage( - state, + state: SliceState, action: PayloadAction<{ id?: string; message: string }>, ) { const { id, message } = action.payload; @@ -192,7 +195,7 @@ export const chatSlice = createSlice({ }); }, - clearMessages(state) { + clearMessages(state: SliceState) { state.messages = []; }, }, diff --git a/frontend/src/types/message.tsx b/frontend/src/types/message.tsx index eb1e1bfcf9..95fca56e23 100644 --- a/frontend/src/types/message.tsx +++ b/frontend/src/types/message.tsx @@ -43,6 +43,6 @@ export interface ObservationMessage { export interface StatusMessage { status_update: true; type: string; - id: string; + id?: string; message: string; }