mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
Refactor session management (#1810)
* refactor session mgmt * defer file handling to runtime * add todo * refactor sessions a bit more * remove messages logic from FE * fix up socket handshake * refactor frontend auth a bit * first pass at redoing file explorer * implement directory suffix * fix up file tree * close agent on websocket close * remove session saving * move file refresh * remove getWorkspace * plumb path/code differently * fix build issues * fix the tests * fix npm build * add session rehydration * fix event serialization * logspam * fix user message rehydration * add get_event fn * agent state restoration * change history tracking for codeact * fix responsiveness of init * fix lint * lint * delint * fix prop * update tests * logspam * lint * fix test * revert codeact * change fileService to use API * fix up session loading * delint * delint * fix integration tests * revert test * fix up access to options endpoints * fix initial files load * delint * fix file initialization * fix mock server * fixl int * fix auth for html * Update frontend/src/i18n/translation.json Co-authored-by: Xingyao Wang <xingyao6@illinois.edu> * refactor sessions and sockets * avoid reinitializing the same session * fix reconnect issue * change up intro message * more guards on reinit * rename agent_session * delint * fix a bunch of tests * delint * fix last test * remove code editor context * fix build * fix any * fix dot notation * Update frontend/src/services/api.ts Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk> * fix up error handling * Update opendevin/server/session/agent.py Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk> * Update opendevin/server/session/agent.py Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk> * Update frontend/src/services/session.ts Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk> * fix build errs * fix else * add closed state * delint * Update opendevin/server/session/session.py Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> --------- Co-authored-by: Xingyao Wang <xingyao6@illinois.edu> Co-authored-by: Graham Neubig <neubig@gmail.com> Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import { useDisclosure } from "@nextui-org/react";
|
||||
import React, { useEffect, useState } from "react";
|
||||
import React, { useEffect } from "react";
|
||||
import { Toaster } from "react-hot-toast";
|
||||
import CogTooth from "#/assets/cog-tooth";
|
||||
import ChatInterface from "#/components/chat/ChatInterface";
|
||||
@@ -8,15 +8,13 @@ import { Container, Orientation } from "#/components/Resizable";
|
||||
import Workspace from "#/components/Workspace";
|
||||
import LoadPreviousSessionModal from "#/components/modals/load-previous-session/LoadPreviousSessionModal";
|
||||
import SettingsModal from "#/components/modals/settings/SettingsModal";
|
||||
import { fetchMsgTotal } from "#/services/session";
|
||||
import Socket from "#/services/socket";
|
||||
import { ResFetchMsgTotal } from "#/types/ResponseType";
|
||||
import "./App.css";
|
||||
import AgentControlBar from "./components/AgentControlBar";
|
||||
import AgentStatusBar from "./components/AgentStatusBar";
|
||||
import Terminal from "./components/terminal/Terminal";
|
||||
import { initializeAgent } from "./services/agent";
|
||||
import { settingsAreUpToDate } from "./services/settings";
|
||||
import Session from "#/services/session";
|
||||
import { getToken } from "#/services/auth";
|
||||
import { settingsAreUpToDate } from "#/services/settings";
|
||||
|
||||
interface Props {
|
||||
setSettingOpen: (isOpen: boolean) => void;
|
||||
@@ -43,8 +41,6 @@ function Controls({ setSettingOpen }: Props): JSX.Element {
|
||||
let initOnce = false;
|
||||
|
||||
function App(): JSX.Element {
|
||||
const [isWarned, setIsWarned] = useState(false);
|
||||
|
||||
const {
|
||||
isOpen: settingsModalIsOpen,
|
||||
onOpen: onSettingsModalOpen,
|
||||
@@ -57,31 +53,18 @@ function App(): JSX.Element {
|
||||
onOpenChange: onLoadPreviousSessionModalOpenChange,
|
||||
} = useDisclosure();
|
||||
|
||||
const getMsgTotal = () => {
|
||||
if (isWarned) return;
|
||||
fetchMsgTotal()
|
||||
.then((data: ResFetchMsgTotal) => {
|
||||
if (data.msg_total > 0) {
|
||||
onLoadPreviousSessionModalOpen();
|
||||
setIsWarned(true);
|
||||
}
|
||||
})
|
||||
.catch();
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (initOnce) return;
|
||||
initOnce = true;
|
||||
|
||||
if (!settingsAreUpToDate()) {
|
||||
onSettingsModalOpen();
|
||||
} else if (getToken()) {
|
||||
onLoadPreviousSessionModalOpen();
|
||||
} else {
|
||||
initializeAgent();
|
||||
Session.startNewSession();
|
||||
}
|
||||
|
||||
Socket.registerCallback("open", [getMsgTotal]);
|
||||
|
||||
getMsgTotal();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
export async function fetchModels() {
|
||||
const response = await fetch(`/api/litellm-models`);
|
||||
return response.json();
|
||||
}
|
||||
|
||||
export async function fetchAgents() {
|
||||
const response = await fetch(`/api/agents`);
|
||||
return response.json();
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import ArrowIcon from "#/assets/arrow";
|
||||
import PauseIcon from "#/assets/pause";
|
||||
import PlayIcon from "#/assets/play";
|
||||
import { changeAgentState } from "#/services/agentStateService";
|
||||
import { clearMsgs } from "#/services/session";
|
||||
import store, { RootState } from "#/store";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import { clearMessages } from "#/state/chatSlice";
|
||||
@@ -73,7 +72,6 @@ function AgentControlBar() {
|
||||
}
|
||||
|
||||
if (action === AgentState.STOPPED) {
|
||||
clearMsgs().then().catch();
|
||||
store.dispatch(clearMessages());
|
||||
} else {
|
||||
setIsLoading(true);
|
||||
@@ -86,7 +84,6 @@ function AgentControlBar() {
|
||||
useEffect(() => {
|
||||
if (curAgentState === desiredState) {
|
||||
if (curAgentState === AgentState.STOPPED) {
|
||||
clearMsgs().then().catch();
|
||||
store.dispatch(clearMessages());
|
||||
}
|
||||
setIsLoading(false);
|
||||
|
||||
@@ -1,33 +1,24 @@
|
||||
import Editor, { Monaco } from "@monaco-editor/react";
|
||||
import { Tab, Tabs } from "@nextui-org/react";
|
||||
import type { editor } from "monaco-editor";
|
||||
import React, { useMemo, useState } from "react";
|
||||
import React, { useMemo } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { VscCode } from "react-icons/vsc";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import { useSelector } from "react-redux";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { selectFile } from "#/services/fileService";
|
||||
import { setCode } from "#/state/codeSlice";
|
||||
import { RootState } from "#/store";
|
||||
import FileExplorer from "./file-explorer/FileExplorer";
|
||||
import { CodeEditorContext } from "./CodeEditorContext";
|
||||
|
||||
function CodeEditor(): JSX.Element {
|
||||
const { t } = useTranslation();
|
||||
const [selectedFileAbsolutePath, setSelectedFileAbsolutePath] = useState("");
|
||||
const selectedFileName = useMemo(() => {
|
||||
const paths = selectedFileAbsolutePath.split("/");
|
||||
return paths[paths.length - 1];
|
||||
}, [selectedFileAbsolutePath]);
|
||||
const codeEditorContext = useMemo(
|
||||
() => ({ selectedFileAbsolutePath }),
|
||||
[selectedFileAbsolutePath],
|
||||
);
|
||||
|
||||
const dispatch = useDispatch();
|
||||
const code = useSelector((state: RootState) => state.code.code);
|
||||
const activeFilepath = useSelector((state: RootState) => state.code.path);
|
||||
|
||||
const selectedFileName = useMemo(() => {
|
||||
const paths = activeFilepath.split("/");
|
||||
return paths[paths.length - 1];
|
||||
}, [activeFilepath]);
|
||||
|
||||
const handleEditorDidMount = (
|
||||
editor: editor.IStandaloneCodeEditor,
|
||||
monaco: Monaco,
|
||||
@@ -46,57 +37,44 @@ function CodeEditor(): JSX.Element {
|
||||
monaco.editor.setTheme("my-theme");
|
||||
};
|
||||
|
||||
const updateCode = async () => {
|
||||
const newCode = await selectFile(activeFilepath);
|
||||
setSelectedFileAbsolutePath(activeFilepath);
|
||||
dispatch(setCode(newCode));
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
// FIXME: we can probably move this out of the component and into state/service
|
||||
if (activeFilepath) updateCode();
|
||||
}, [activeFilepath]);
|
||||
|
||||
return (
|
||||
<div className="flex h-full w-full bg-neutral-900 transition-all duration-500 ease-in-out">
|
||||
<CodeEditorContext.Provider value={codeEditorContext}>
|
||||
<FileExplorer />
|
||||
<div className="flex flex-col min-h-0 w-full">
|
||||
<Tabs
|
||||
disableCursorAnimation
|
||||
classNames={{
|
||||
base: "border-b border-divider border-neutral-600 mb-4",
|
||||
tabList:
|
||||
"w-full relative rounded-none bg-neutral-900 p-0 border-divider",
|
||||
cursor: "w-full bg-neutral-600 rounded-none",
|
||||
tab: "max-w-fit px-4 h-[36px]",
|
||||
tabContent: "group-data-[selected=true]:text-white",
|
||||
}}
|
||||
aria-label="Options"
|
||||
>
|
||||
<Tab
|
||||
key={selectedFileName.toLocaleLowerCase()}
|
||||
title={selectedFileName}
|
||||
<FileExplorer />
|
||||
<div className="flex flex-col min-h-0 w-full">
|
||||
<Tabs
|
||||
disableCursorAnimation
|
||||
classNames={{
|
||||
base: "border-b border-divider border-neutral-600 mb-4",
|
||||
tabList:
|
||||
"w-full relative rounded-none bg-neutral-900 p-0 border-divider",
|
||||
cursor: "w-full bg-neutral-600 rounded-none",
|
||||
tab: "max-w-fit px-4 h-[36px]",
|
||||
tabContent: "group-data-[selected=true]:text-white",
|
||||
}}
|
||||
aria-label="Options"
|
||||
>
|
||||
<Tab
|
||||
key={selectedFileName.toLocaleLowerCase()}
|
||||
title={selectedFileName}
|
||||
/>
|
||||
</Tabs>
|
||||
<div className="flex grow items-center justify-center">
|
||||
{selectedFileName === "" ? (
|
||||
<div className="flex flex-col items-center text-neutral-400">
|
||||
<VscCode size={100} />
|
||||
{t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)}
|
||||
</div>
|
||||
) : (
|
||||
<Editor
|
||||
height="100%"
|
||||
path={selectedFileName.toLocaleLowerCase()}
|
||||
defaultValue=""
|
||||
value={code}
|
||||
onMount={handleEditorDidMount}
|
||||
/>
|
||||
</Tabs>
|
||||
<div className="flex grow items-center justify-center">
|
||||
{selectedFileName === "" ? (
|
||||
<div className="flex flex-col items-center text-neutral-400">
|
||||
<VscCode size={100} />
|
||||
{t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)}
|
||||
</div>
|
||||
) : (
|
||||
<Editor
|
||||
height="100%"
|
||||
path={selectedFileName.toLocaleLowerCase()}
|
||||
defaultValue=""
|
||||
value={code}
|
||||
onMount={handleEditorDidMount}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</CodeEditorContext.Provider>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
import { createContext } from "react";
|
||||
|
||||
export const CodeEditorContext = createContext({
|
||||
selectedFileAbsolutePath: "",
|
||||
});
|
||||
@@ -5,7 +5,7 @@ import { act } from "react-dom/test-utils";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import ChatInterface from "./ChatInterface";
|
||||
import Socket from "#/services/socket";
|
||||
import Session from "#/services/session";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { addAssistantMessage } from "#/state/chatSlice";
|
||||
import AgentState from "#/types/AgentState";
|
||||
@@ -15,16 +15,17 @@ vi.mock("#/hooks/useTyping", () => ({
|
||||
useTyping: vi.fn((text: string) => text),
|
||||
}));
|
||||
|
||||
const socketSpy = vi.spyOn(Socket, "send");
|
||||
const sessionSpy = vi.spyOn(Session, "send");
|
||||
vi.spyOn(Session, "isConnected").mockImplementation(() => true);
|
||||
|
||||
// This is for the scrollview ref in Chat.tsx
|
||||
// TODO: Move this into test setup
|
||||
HTMLElement.prototype.scrollTo = vi.fn(() => {});
|
||||
|
||||
describe("ChatInterface", () => {
|
||||
it("should render the messages and input", () => {
|
||||
it("should render empty message list and input", () => {
|
||||
renderWithProviders(<ChatInterface />);
|
||||
expect(screen.queryAllByTestId("message")).toHaveLength(1); // initial welcome message only
|
||||
expect(screen.queryAllByTestId("message")).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should render the new message the user has typed", async () => {
|
||||
@@ -65,7 +66,7 @@ describe("ChatInterface", () => {
|
||||
expect(screen.getByText("Hello to you!")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should send the a start event to the Socket", () => {
|
||||
it("should send the a start event to the Session", () => {
|
||||
renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
agent: {
|
||||
@@ -83,10 +84,10 @@ describe("ChatInterface", () => {
|
||||
action: ActionType.MESSAGE,
|
||||
args: { content: "my message" },
|
||||
};
|
||||
expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
expect(sessionSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
});
|
||||
|
||||
it("should send the a user message event to the Socket", () => {
|
||||
it("should send the a user message event to the Session", () => {
|
||||
renderWithProviders(<ChatInterface />, {
|
||||
preloadedState: {
|
||||
agent: {
|
||||
@@ -104,7 +105,7 @@ describe("ChatInterface", () => {
|
||||
action: ActionType.MESSAGE,
|
||||
args: { content: "my message" },
|
||||
};
|
||||
expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
expect(sessionSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
});
|
||||
|
||||
it("should disable the user input if agent is not initialized", () => {
|
||||
|
||||
@@ -10,7 +10,7 @@ import Chat from "./Chat";
|
||||
import { RootState } from "#/store";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import { sendChatMessage } from "#/services/chatService";
|
||||
import { addUserMessage } from "#/state/chatSlice";
|
||||
import { addUserMessage, addAssistantMessage } from "#/state/chatSlice";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useScrollToBottom } from "#/hooks/useScrollToBottom";
|
||||
|
||||
@@ -58,6 +58,12 @@ function ChatInterface() {
|
||||
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
|
||||
useScrollToBottom(scrollRef);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (curAgentState === AgentState.INIT && messages.length === 0) {
|
||||
dispatch(addAssistantMessage(t(I18nKey.CHAT_INTERFACE$INITIAL_MESSAGE)));
|
||||
}
|
||||
}, [curAgentState]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full bg-neutral-800">
|
||||
<div className="flex items-center gap-2 border-b border-neutral-600 text-sm px-4 py-2">
|
||||
|
||||
@@ -7,8 +7,9 @@ import { describe, it, expect, vi, Mock } from "vitest";
|
||||
import FileExplorer from "./FileExplorer";
|
||||
import { uploadFiles, listFiles } from "#/services/fileService";
|
||||
import toast from "#/utils/toast";
|
||||
import AgentState from "#/types/AgentState";
|
||||
|
||||
const toastSpy = vi.spyOn(toast, "stickyError");
|
||||
const toastSpy = vi.spyOn(toast, "error");
|
||||
|
||||
vi.mock("../../services/fileService", async () => ({
|
||||
listFiles: vi.fn(async (path: string = "/") => {
|
||||
@@ -42,7 +43,13 @@ describe("FileExplorer", () => {
|
||||
it.todo("should render an empty workspace");
|
||||
|
||||
it.only("should refetch the workspace when clicking the refresh button", async () => {
|
||||
const { getByText } = renderWithProviders(<FileExplorer />);
|
||||
const { getByText } = renderWithProviders(<FileExplorer />, {
|
||||
preloadedState: {
|
||||
agent: {
|
||||
curAgentState: AgentState.RUNNING,
|
||||
},
|
||||
},
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(getByText("folder1")).toBeInTheDocument();
|
||||
expect(getByText("file2.ts")).toBeInTheDocument();
|
||||
|
||||
@@ -5,14 +5,16 @@ import {
|
||||
IoIosRefresh,
|
||||
IoIosCloudUpload,
|
||||
} from "react-icons/io";
|
||||
import { useDispatch } from "react-redux";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import { IoFileTray } from "react-icons/io5";
|
||||
import { twMerge } from "tailwind-merge";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import { setRefreshID } from "#/state/codeSlice";
|
||||
import { listFiles, uploadFiles } from "#/services/fileService";
|
||||
import IconButton from "../IconButton";
|
||||
import ExplorerTree from "./ExplorerTree";
|
||||
import toast from "#/utils/toast";
|
||||
import { RootState } from "#/store";
|
||||
|
||||
interface ExplorerActionsProps {
|
||||
onRefresh: () => void;
|
||||
@@ -87,6 +89,7 @@ function FileExplorer() {
|
||||
const [isHidden, setIsHidden] = React.useState(false);
|
||||
const [isDragging, setIsDragging] = React.useState(false);
|
||||
const [files, setFiles] = React.useState<string[]>([]);
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const fileInputRef = React.useRef<HTMLInputElement | null>(null);
|
||||
const dispatch = useDispatch();
|
||||
|
||||
@@ -95,6 +98,12 @@ function FileExplorer() {
|
||||
};
|
||||
|
||||
const refreshWorkspace = async () => {
|
||||
if (
|
||||
curAgentState === AgentState.LOADING ||
|
||||
curAgentState === AgentState.STOPPED
|
||||
) {
|
||||
return;
|
||||
}
|
||||
dispatch(setRefreshID(Math.random()));
|
||||
setFiles(await listFiles("/"));
|
||||
};
|
||||
@@ -104,7 +113,7 @@ function FileExplorer() {
|
||||
await uploadFiles(toAdd);
|
||||
await refreshWorkspace();
|
||||
} catch (error) {
|
||||
toast.stickyError("ws", "Error uploading file");
|
||||
toast.error("ws", "Error uploading file");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -112,7 +121,9 @@ function FileExplorer() {
|
||||
(async () => {
|
||||
await refreshWorkspace();
|
||||
})();
|
||||
}, [curAgentState]);
|
||||
|
||||
React.useEffect(() => {
|
||||
const enableDragging = () => {
|
||||
setIsDragging(true);
|
||||
};
|
||||
@@ -130,6 +141,10 @@ function FileExplorer() {
|
||||
};
|
||||
}, []);
|
||||
|
||||
if (!files.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
{isDragging && (
|
||||
|
||||
@@ -4,9 +4,8 @@ import { twMerge } from "tailwind-merge";
|
||||
import { RootState } from "#/store";
|
||||
import FolderIcon from "../FolderIcon";
|
||||
import FileIcon from "../FileIcons";
|
||||
import { listFiles } from "#/services/fileService";
|
||||
import { setActiveFilepath } from "#/state/codeSlice";
|
||||
import { CodeEditorContext } from "../CodeEditorContext";
|
||||
import { listFiles, selectFile } from "#/services/fileService";
|
||||
import { setCode, setActiveFilepath } from "#/state/codeSlice";
|
||||
|
||||
interface TitleProps {
|
||||
name: string;
|
||||
@@ -36,8 +35,8 @@ interface TreeNodeProps {
|
||||
function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
|
||||
const [isOpen, setIsOpen] = React.useState(defaultOpen);
|
||||
const [children, setChildren] = React.useState<string[] | null>(null);
|
||||
const { selectedFileAbsolutePath } = React.useContext(CodeEditorContext);
|
||||
const refreshID = useSelector((state: RootState) => state.code.refreshID);
|
||||
const activeFilepath = useSelector((state: RootState) => state.code.path);
|
||||
|
||||
const dispatch = useDispatch();
|
||||
|
||||
@@ -60,10 +59,12 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
|
||||
refreshChildren();
|
||||
}, [refreshID, isOpen]);
|
||||
|
||||
const handleClick = () => {
|
||||
const handleClick = async () => {
|
||||
if (isDirectory) {
|
||||
setIsOpen((prev) => !prev);
|
||||
} else {
|
||||
const newCode = await selectFile(path);
|
||||
dispatch(setCode(newCode));
|
||||
dispatch(setActiveFilepath(path));
|
||||
}
|
||||
};
|
||||
@@ -72,7 +73,7 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
|
||||
<div
|
||||
className={twMerge(
|
||||
"text-sm text-neutral-400",
|
||||
path === selectedFileAbsolutePath ? "bg-gray-700" : "",
|
||||
path === activeFilepath ? "bg-gray-700" : "",
|
||||
)}
|
||||
>
|
||||
<Title
|
||||
|
||||
@@ -2,54 +2,23 @@ import React from "react";
|
||||
import { act, render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import LoadPreviousSessionModal from "./LoadPreviousSessionModal";
|
||||
import { clearMsgs, fetchMsgs } from "../../../services/session";
|
||||
import { addChatMessageFromEvent } from "../../../services/chatService";
|
||||
import { handleAssistantMessage } from "../../../services/actions";
|
||||
import toast from "../../../utils/toast";
|
||||
import Session from "../../../services/session";
|
||||
|
||||
const RESUME_SESSION_BUTTON_LABEL_KEY =
|
||||
"LOAD_SESSION$RESUME_SESSION_MODAL_ACTION_LABEL";
|
||||
const START_NEW_SESSION_BUTTON_LABEL_KEY =
|
||||
"LOAD_SESSION$START_NEW_SESSION_MODAL_ACTION_LABEL";
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
fetchMsgsMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../../services/session", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("../../../services/session")>()),
|
||||
clearMsgs: vi.fn(),
|
||||
fetchMsgs: mocks.fetchMsgsMock.mockResolvedValue({
|
||||
messages: [
|
||||
{
|
||||
id: "1",
|
||||
role: "user",
|
||||
payload: { type: "action" },
|
||||
},
|
||||
{
|
||||
id: "2",
|
||||
role: "assistant",
|
||||
payload: { type: "observation" },
|
||||
},
|
||||
],
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("../../../services/chatService", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("../../../services/chatService")>()),
|
||||
addChatMessageFromEvent: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../../services/actions", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("../../../services/actions")>()),
|
||||
handleAssistantMessage: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("../../../utils/toast", () => ({
|
||||
default: {
|
||||
stickyError: vi.fn(),
|
||||
},
|
||||
}));
|
||||
vi.spyOn(Session, "isConnected").mockImplementation(() => true);
|
||||
const restoreOrStartNewSessionSpy = vi.spyOn(
|
||||
Session,
|
||||
"restoreOrStartNewSession",
|
||||
);
|
||||
|
||||
describe("LoadPreviousSession", () => {
|
||||
afterEach(() => {
|
||||
@@ -75,7 +44,6 @@ describe("LoadPreviousSession", () => {
|
||||
userEvent.click(startNewSessionButton);
|
||||
});
|
||||
|
||||
expect(clearMsgs).toHaveBeenCalledTimes(1);
|
||||
// modal should close right after clearing messages
|
||||
expect(onOpenChangeMock).toHaveBeenCalledWith(false);
|
||||
});
|
||||
@@ -93,36 +61,9 @@ describe("LoadPreviousSession", () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchMsgs).toHaveBeenCalledTimes(1);
|
||||
expect(addChatMessageFromEvent).toHaveBeenCalledTimes(1);
|
||||
expect(handleAssistantMessage).toHaveBeenCalledTimes(1);
|
||||
expect(restoreOrStartNewSessionSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
// modal should close right after fetching messages
|
||||
expect(onOpenChangeMock).toHaveBeenCalledWith(false);
|
||||
});
|
||||
|
||||
it("should show an error toast if there is an error fetching the session", async () => {
|
||||
mocks.fetchMsgsMock.mockRejectedValue(new Error("Get messages failed."));
|
||||
|
||||
render(<LoadPreviousSessionModal isOpen onOpenChange={vi.fn} />);
|
||||
|
||||
const resumeSessionButton = screen.getByRole("button", {
|
||||
name: RESUME_SESSION_BUTTON_LABEL_KEY,
|
||||
});
|
||||
|
||||
act(() => {
|
||||
userEvent.click(resumeSessionButton);
|
||||
});
|
||||
|
||||
await waitFor(async () => {
|
||||
await expect(() => fetchMsgs()).rejects.toThrow();
|
||||
expect(handleAssistantMessage).not.toHaveBeenCalled();
|
||||
expect(addChatMessageFromEvent).not.toHaveBeenCalled();
|
||||
// error toast should be shown
|
||||
expect(toast.stickyError).toHaveBeenCalledWith(
|
||||
"ws",
|
||||
"Error fetching the session",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { handleAssistantMessage } from "#/services/actions";
|
||||
import { addChatMessageFromEvent } from "#/services/chatService";
|
||||
import { clearMsgs, fetchMsgs } from "#/services/session";
|
||||
import toast from "#/utils/toast";
|
||||
import BaseModal from "../base-modal/BaseModal";
|
||||
import Session from "#/services/session";
|
||||
|
||||
interface LoadPreviousSessionModalProps {
|
||||
isOpen: boolean;
|
||||
@@ -18,28 +15,6 @@ function LoadPreviousSessionModal({
|
||||
}: LoadPreviousSessionModalProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onStartNewSession = async () => {
|
||||
await clearMsgs();
|
||||
};
|
||||
|
||||
const onResumeSession = async () => {
|
||||
try {
|
||||
const { messages } = await fetchMsgs();
|
||||
|
||||
messages.forEach((message) => {
|
||||
if (message.role === "user") {
|
||||
addChatMessageFromEvent(message.payload);
|
||||
}
|
||||
|
||||
if (message.role === "assistant") {
|
||||
handleAssistantMessage(message.payload);
|
||||
}
|
||||
});
|
||||
} catch (error) {
|
||||
toast.stickyError("ws", "Error fetching the session");
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<BaseModal
|
||||
isOpen={isOpen}
|
||||
@@ -50,13 +25,13 @@ function LoadPreviousSessionModal({
|
||||
{
|
||||
label: t(I18nKey.LOAD_SESSION$RESUME_SESSION_MODAL_ACTION_LABEL),
|
||||
className: "bg-primary rounded-lg",
|
||||
action: onResumeSession,
|
||||
action: Session.restoreOrStartNewSession,
|
||||
closeAfterAction: true,
|
||||
},
|
||||
{
|
||||
label: t(I18nKey.LOAD_SESSION$START_NEW_SESSION_MODAL_ACTION_LABEL),
|
||||
className: "bg-neutral-500 rounded-lg",
|
||||
action: onStartNewSession,
|
||||
action: Session.startNewSession,
|
||||
closeAfterAction: true,
|
||||
},
|
||||
]}
|
||||
|
||||
@@ -11,12 +11,14 @@ import {
|
||||
saveSettings,
|
||||
getDefaultSettings,
|
||||
} from "#/services/settings";
|
||||
import { initializeAgent } from "#/services/agent";
|
||||
import { fetchAgents, fetchModels } from "#/api";
|
||||
import Session from "#/services/session";
|
||||
import { fetchAgents, fetchModels } from "#/services/options";
|
||||
import SettingsModal from "./SettingsModal";
|
||||
|
||||
const toastSpy = vi.spyOn(toast, "settingsChanged");
|
||||
const i18nSpy = vi.spyOn(i18next, "changeLanguage");
|
||||
const startNewSessionSpy = vi.spyOn(Session, "startNewSession");
|
||||
vi.spyOn(Session, "isConnected").mockImplementation(() => true);
|
||||
|
||||
vi.mock("#/services/settings", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("#/services/settings")>()),
|
||||
@@ -35,12 +37,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
|
||||
saveSettings: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/services/agent", async () => ({
|
||||
initializeAgent: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/api", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("#/api")>()),
|
||||
vi.mock("#/services/options", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("#/services/options")>()),
|
||||
fetchModels: vi
|
||||
.fn()
|
||||
.mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])),
|
||||
@@ -162,7 +160,7 @@ describe("SettingsModal", () => {
|
||||
userEvent.click(saveButton);
|
||||
});
|
||||
|
||||
expect(initializeAgent).toHaveBeenCalled();
|
||||
expect(startNewSessionSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should display a toast for every change", async () => {
|
||||
|
||||
@@ -3,10 +3,10 @@ import i18next from "i18next";
|
||||
import React, { useEffect } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useSelector } from "react-redux";
|
||||
import { fetchAgents, fetchModels } from "#/api";
|
||||
import { fetchAgents, fetchModels } from "#/services/options";
|
||||
import { AvailableLanguages } from "#/i18n";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { initializeAgent } from "#/services/agent";
|
||||
import Session from "#/services/session";
|
||||
import { RootState } from "../../../store";
|
||||
import AgentState from "../../../types/AgentState";
|
||||
import {
|
||||
@@ -100,7 +100,7 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
|
||||
const updatedSettings = getSettingsDifference(settings);
|
||||
saveSettings(settings);
|
||||
i18next.changeLanguage(settings.LANGUAGE);
|
||||
initializeAgent(); // reinitialize the agent with the new settings
|
||||
Session.startNewSession();
|
||||
|
||||
const sensitiveKeys = ["LLM_API_KEY"];
|
||||
|
||||
|
||||
@@ -266,12 +266,12 @@
|
||||
"en": "Please stop the agent before editing these settings."
|
||||
},
|
||||
"LOAD_SESSION$MODAL_TITLE": {
|
||||
"en": "Unfinished Session Detected",
|
||||
"zh-CN": "检测到有未完成的会话",
|
||||
"zh-TW": "偵測到未完成的會話"
|
||||
"en": "Return to existing session?",
|
||||
"zh-CN": "是否继续未完成的会话?",
|
||||
"zh-TW": "是否繼續未完成的會話?"
|
||||
},
|
||||
"LOAD_SESSION$MODAL_CONTENT": {
|
||||
"en": "You seem to have an unfinished task. Would you like to pick up where you left off or start fresh?",
|
||||
"en": "You seem to have an ongoing session. Would you like to pick up where you left off, or start fresh?",
|
||||
"zh-CN": "您似乎有一个未完成的任务。您想继续之前的工作还是重新开始?",
|
||||
"zh-TW": "您似乎有一個未完成的任務。您想從上次離開的地方繼續還是重新開始?"
|
||||
},
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { initializeAgent } from "./agent";
|
||||
import { Settings, saveSettings } from "./settings";
|
||||
import Socket from "./socket";
|
||||
|
||||
const sendSpy = vi.spyOn(Socket, "send");
|
||||
|
||||
describe("initializeAgent", () => {
|
||||
it("Should initialize the agent with the current settings", () => {
|
||||
const settings: Settings = {
|
||||
LLM_MODEL: "llm_value",
|
||||
AGENT: "agent_value",
|
||||
LANGUAGE: "language_value",
|
||||
LLM_API_KEY: "sk-...",
|
||||
};
|
||||
|
||||
const event = {
|
||||
action: ActionType.INIT,
|
||||
args: settings,
|
||||
};
|
||||
|
||||
saveSettings(settings);
|
||||
initializeAgent();
|
||||
|
||||
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
});
|
||||
});
|
||||
@@ -1,14 +0,0 @@
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { getSettings } from "./settings";
|
||||
import Socket from "./socket";
|
||||
|
||||
/**
|
||||
* Initialize the agent with the current settings.
|
||||
* @param settings - The new settings.
|
||||
*/
|
||||
export const initializeAgent = () => {
|
||||
const settings = getSettings();
|
||||
const event = { action: ActionType.INIT, args: settings };
|
||||
const eventString = JSON.stringify(event);
|
||||
Socket.send(eventString);
|
||||
};
|
||||
@@ -1,7 +1,6 @@
|
||||
import ActionType from "#/types/ActionType";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import Socket from "./socket";
|
||||
import { initializeAgent } from "./agent";
|
||||
import Session from "./session";
|
||||
|
||||
const INIT_DELAY = 1000;
|
||||
|
||||
@@ -10,10 +9,10 @@ export function changeAgentState(state: AgentState): void {
|
||||
action: ActionType.CHANGE_AGENT_STATE,
|
||||
args: { agent_state: state },
|
||||
});
|
||||
Socket.send(eventString);
|
||||
Session.send(eventString);
|
||||
if (state === AgentState.STOPPED) {
|
||||
setTimeout(() => {
|
||||
initializeAgent();
|
||||
Session.startNewSession();
|
||||
}, INIT_DELAY);
|
||||
}
|
||||
}
|
||||
|
||||
58
frontend/src/services/api.ts
Normal file
58
frontend/src/services/api.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import { getToken } from "./auth";
|
||||
import toast from "#/utils/toast";
|
||||
|
||||
const WAIT_FOR_AUTH_DELAY_MS = 500;
|
||||
|
||||
export async function request(
|
||||
url: string,
|
||||
optionsIn: RequestInit = {},
|
||||
disableToast: boolean = false,
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
): Promise<any> {
|
||||
const options = JSON.parse(JSON.stringify(optionsIn));
|
||||
|
||||
const onFail = (msg: string) => {
|
||||
if (!disableToast) {
|
||||
toast.error("api", msg);
|
||||
}
|
||||
throw new Error(msg);
|
||||
};
|
||||
|
||||
const needsAuth = !url.startsWith("/api/options/");
|
||||
const token = getToken();
|
||||
if (!token && needsAuth) {
|
||||
return new Promise((resolve) => {
|
||||
setTimeout(() => {
|
||||
resolve(request(url, optionsIn, disableToast));
|
||||
}, WAIT_FOR_AUTH_DELAY_MS);
|
||||
});
|
||||
}
|
||||
if (token) {
|
||||
options.headers = {
|
||||
...(options.headers || {}),
|
||||
Authorization: `Bearer ${token}`,
|
||||
};
|
||||
}
|
||||
|
||||
let response = null;
|
||||
try {
|
||||
response = await fetch(url, options);
|
||||
} catch (e) {
|
||||
onFail(`Error fetching ${url}`);
|
||||
}
|
||||
if (response?.status && response?.status >= 400) {
|
||||
onFail(
|
||||
`${response.status} error while fetching ${url}: ${response?.statusText}`,
|
||||
);
|
||||
}
|
||||
if (!response?.ok) {
|
||||
onFail(`Error fetching ${url}: ${response?.statusText}`);
|
||||
}
|
||||
|
||||
try {
|
||||
return await (response && response.json());
|
||||
} catch (e) {
|
||||
onFail(`Error parsing JSON from ${url}`);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -1,18 +1,5 @@
|
||||
import * as jose from "jose";
|
||||
import type { Mock } from "vitest";
|
||||
import { fetchToken, validateToken, getToken } from "./auth";
|
||||
|
||||
vi.mock("jose", () => ({
|
||||
decodeJwt: vi.fn(),
|
||||
}));
|
||||
|
||||
// SUGGESTION: Prefer using msw for mocking requests (see https://mswjs.io/)
|
||||
global.fetch = vi.fn(() =>
|
||||
Promise.resolve({
|
||||
status: 200,
|
||||
json: () => Promise.resolve({ token: "newToken" }),
|
||||
}),
|
||||
) as Mock;
|
||||
import { getToken } from "./auth";
|
||||
|
||||
Storage.prototype.getItem = vi.fn();
|
||||
Storage.prototype.setItem = vi.fn();
|
||||
@@ -22,66 +9,12 @@ describe("Auth Service", () => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("fetchToken", () => {
|
||||
it("should fetch and return a token", async () => {
|
||||
const data = await fetchToken();
|
||||
|
||||
expect(localStorage.getItem).toHaveBeenCalledWith("token"); // Used to set Authorization header
|
||||
expect(data).toEqual({ token: "newToken" });
|
||||
expect(fetch).toHaveBeenCalledWith(`/api/auth`, {
|
||||
headers: expect.any(Headers),
|
||||
});
|
||||
});
|
||||
|
||||
it("throws an error if response status is not 200", async () => {
|
||||
(fetch as Mock).mockImplementationOnce(() =>
|
||||
Promise.resolve({ status: 401 }),
|
||||
);
|
||||
await expect(fetchToken()).rejects.toThrow("Get token failed.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("validateToken", () => {
|
||||
it("returns true for a valid token", () => {
|
||||
(jose.decodeJwt as Mock).mockReturnValue({ sid: "123" });
|
||||
expect(validateToken("validToken")).toBe(true);
|
||||
});
|
||||
|
||||
it("returns false for an invalid token", () => {
|
||||
(jose.decodeJwt as Mock).mockReturnValue({});
|
||||
expect(validateToken("invalidToken")).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when decodeJwt throws", () => {
|
||||
(jose.decodeJwt as Mock).mockImplementation(() => {
|
||||
throw new Error("Invalid token");
|
||||
});
|
||||
expect(validateToken("badToken")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getToken", () => {
|
||||
it("returns existing valid token from localStorage", async () => {
|
||||
(jose.decodeJwt as Mock).mockReturnValue({ sid: "123" });
|
||||
(Storage.prototype.getItem as Mock).mockReturnValue("existingToken");
|
||||
|
||||
const token = await getToken();
|
||||
expect(token).toBe("existingToken");
|
||||
});
|
||||
|
||||
it("fetches, validates, and stores a new token when existing token is invalid", async () => {
|
||||
(jose.decodeJwt as Mock)
|
||||
.mockReturnValueOnce({})
|
||||
.mockReturnValueOnce({ sid: "123" });
|
||||
|
||||
const token = await getToken();
|
||||
expect(token).toBe("newToken");
|
||||
expect(localStorage.setItem).toHaveBeenCalledWith("token", "newToken");
|
||||
});
|
||||
|
||||
it("throws an error when fetched token is invalid", async () => {
|
||||
(jose.decodeJwt as Mock).mockReturnValue({});
|
||||
await expect(getToken()).rejects.toThrow("Token validation failed.");
|
||||
it("should fetch and return a token", async () => {
|
||||
(Storage.prototype.getItem as Mock).mockReturnValue("newToken");
|
||||
const data = await getToken();
|
||||
expect(localStorage.getItem).toHaveBeenCalledWith("token"); // Used to set Authorization header
|
||||
expect(data).toEqual("newToken");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,44 +1,13 @@
|
||||
import * as jose from "jose";
|
||||
import { ResFetchToken } from "#/types/ResponseType";
|
||||
const TOKEN_KEY = "token";
|
||||
|
||||
const fetchToken = async (): Promise<ResFetchToken> => {
|
||||
const headers = new Headers({
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${localStorage.getItem("token")}`,
|
||||
});
|
||||
const response = await fetch(`/api/auth`, { headers });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Get token failed.");
|
||||
}
|
||||
const data: ResFetchToken = await response.json();
|
||||
return data;
|
||||
const getToken = (): string => localStorage.getItem(TOKEN_KEY) ?? "";
|
||||
|
||||
const clearToken = (): void => {
|
||||
localStorage.removeItem(TOKEN_KEY);
|
||||
};
|
||||
|
||||
export const validateToken = (token: string): boolean => {
|
||||
try {
|
||||
const claims = jose.decodeJwt(token);
|
||||
return !(claims.sid === undefined || claims.sid === "");
|
||||
} catch (error) {
|
||||
return false;
|
||||
}
|
||||
const setToken = (token: string): void => {
|
||||
localStorage.setItem(TOKEN_KEY, token);
|
||||
};
|
||||
|
||||
const getToken = async (): Promise<string> => {
|
||||
const token = localStorage.getItem("token") ?? "";
|
||||
if (validateToken(token)) {
|
||||
return token;
|
||||
}
|
||||
|
||||
const data = await fetchToken();
|
||||
if (data.token === undefined || data.token === "") {
|
||||
throw new Error("Get token failed.");
|
||||
}
|
||||
const newToken = data.token;
|
||||
if (validateToken(newToken)) {
|
||||
localStorage.setItem("token", newToken);
|
||||
return newToken;
|
||||
}
|
||||
throw new Error("Token validation failed.");
|
||||
};
|
||||
|
||||
export { getToken, fetchToken };
|
||||
export { getToken, setToken, clearToken };
|
||||
|
||||
@@ -1,28 +1,8 @@
|
||||
import store from "#/store";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { SocketMessage } from "#/types/ResponseType";
|
||||
import { ActionMessage } from "#/types/Message";
|
||||
import Socket from "./socket";
|
||||
import { addUserMessage } from "#/state/chatSlice";
|
||||
import Session from "./session";
|
||||
|
||||
export function sendChatMessage(message: string): void {
|
||||
const event = { action: ActionType.MESSAGE, args: { content: message } };
|
||||
const eventString = JSON.stringify(event);
|
||||
Socket.send(eventString);
|
||||
}
|
||||
|
||||
export function addChatMessageFromEvent(event: string | SocketMessage): void {
|
||||
try {
|
||||
let data: ActionMessage;
|
||||
if (typeof event === "string") {
|
||||
data = JSON.parse(event);
|
||||
} else {
|
||||
data = event as ActionMessage;
|
||||
}
|
||||
if (data && data.args && data.args.task) {
|
||||
store.dispatch(addUserMessage(data.args.task));
|
||||
}
|
||||
} catch (error) {
|
||||
//
|
||||
}
|
||||
Session.send(eventString);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import { request } from "./api";
|
||||
|
||||
export async function selectFile(file: string): Promise<string> {
|
||||
const res = await fetch(`/api/select-file?file=${file}`);
|
||||
const data = await res.json();
|
||||
if (res.status !== 200) {
|
||||
throw new Error(data.error);
|
||||
}
|
||||
const data = await request(`/api/select-file?file=${file}`);
|
||||
return data.code as string;
|
||||
}
|
||||
|
||||
@@ -13,20 +11,13 @@ export async function uploadFiles(files: FileList) {
|
||||
formData.append("files", files[i]);
|
||||
}
|
||||
|
||||
const res = await fetch("/api/upload-files", {
|
||||
await request("/api/upload-files", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
});
|
||||
|
||||
const data = await res.json();
|
||||
|
||||
if (res.status !== 200) {
|
||||
throw new Error(data.error || "Failed to upload files.");
|
||||
}
|
||||
}
|
||||
|
||||
export async function listFiles(path: string = "/"): Promise<string[]> {
|
||||
const res = await fetch(`/api/list-files?path=${path}`);
|
||||
const data = await res.json();
|
||||
const data = await request(`/api/list-files?path=${path}`);
|
||||
return data as string[];
|
||||
}
|
||||
|
||||
9
frontend/src/services/options.ts
Normal file
9
frontend/src/services/options.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
import { request } from "./api";
|
||||
|
||||
export async function fetchModels() {
|
||||
return request(`/api/options/models`);
|
||||
}
|
||||
|
||||
export async function fetchAgents() {
|
||||
return request(`/api/options/agents`);
|
||||
}
|
||||
@@ -1,127 +1,36 @@
|
||||
import type { Mock } from "vitest";
|
||||
import {
|
||||
ResDelMsg,
|
||||
ResFetchMsg,
|
||||
ResFetchMsgTotal,
|
||||
ResFetchMsgs,
|
||||
} from "../types/ResponseType";
|
||||
import { clearMsgs, fetchMsgTotal, fetchMsgs } from "./session";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
// SUGGESTION: Prefer using msw for mocking requests (see https://mswjs.io/)
|
||||
global.fetch = vi.fn();
|
||||
Storage.prototype.getItem = vi.fn();
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { Settings, saveSettings } from "./settings";
|
||||
import Session from "./session";
|
||||
|
||||
describe("Session Service", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
const sendSpy = vi.spyOn(Session, "send");
|
||||
const setupSpy = vi
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
.spyOn(Session as any, "_setupSocket")
|
||||
.mockImplementation(() => {
|
||||
/* eslint-disable-next-line @typescript-eslint/dot-notation */
|
||||
Session["_initializeAgent"](); // use key syntax to fix complaint about private fn
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Used to set Authorization header
|
||||
expect(localStorage.getItem).toHaveBeenCalledWith("token");
|
||||
});
|
||||
describe("startNewSession", () => {
|
||||
it("Should start a new session with the current settings", () => {
|
||||
const settings: Settings = {
|
||||
LLM_MODEL: "llm_value",
|
||||
AGENT: "agent_value",
|
||||
LANGUAGE: "language_value",
|
||||
LLM_API_KEY: "sk-...",
|
||||
};
|
||||
|
||||
describe("fetchMsgTotal", () => {
|
||||
it("should fetch and return message total", async () => {
|
||||
const expectedResult: ResFetchMsgTotal = {
|
||||
msg_total: 10,
|
||||
};
|
||||
const event = {
|
||||
action: ActionType.INIT,
|
||||
args: settings,
|
||||
};
|
||||
|
||||
(fetch as Mock).mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
status: 200,
|
||||
json: () => Promise.resolve(expectedResult),
|
||||
}),
|
||||
);
|
||||
saveSettings(settings);
|
||||
Session.startNewSession();
|
||||
|
||||
const data = await fetchMsgTotal();
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith(`/api/messages/total`, {
|
||||
headers: expect.any(Headers),
|
||||
});
|
||||
|
||||
expect(data).toEqual(expectedResult);
|
||||
});
|
||||
|
||||
it("throws an error if response status is not 200", async () => {
|
||||
// NOTE: The current implementation ONLY handles 200 status;
|
||||
// this means throwing even with a status of 201, 204, etc.
|
||||
(fetch as Mock).mockImplementationOnce(() =>
|
||||
Promise.resolve({ status: 401 }),
|
||||
);
|
||||
|
||||
await expect(fetchMsgTotal()).rejects.toThrow(
|
||||
"Get message total failed.",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("fetchMsgs", () => {
|
||||
it("should fetch and return messages", async () => {
|
||||
const expectedResult: ResFetchMsgs = {
|
||||
messages: [
|
||||
{
|
||||
id: "1",
|
||||
role: "user",
|
||||
payload: {} as ResFetchMsg["payload"],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
(fetch as Mock).mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
status: 200,
|
||||
json: () => Promise.resolve(expectedResult),
|
||||
}),
|
||||
);
|
||||
|
||||
const data = await fetchMsgs();
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith(`/api/messages`, {
|
||||
headers: expect.any(Headers),
|
||||
});
|
||||
|
||||
expect(data).toEqual(expectedResult);
|
||||
});
|
||||
|
||||
it("throws an error if response status is not 200", async () => {
|
||||
(fetch as Mock).mockImplementationOnce(() =>
|
||||
Promise.resolve({ status: 401 }),
|
||||
);
|
||||
|
||||
await expect(fetchMsgs()).rejects.toThrow("Get messages failed.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("clearMsgs", () => {
|
||||
it("should clear messages", async () => {
|
||||
const expectedResult: ResDelMsg = {
|
||||
ok: "true",
|
||||
};
|
||||
|
||||
(fetch as Mock).mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
status: 200,
|
||||
json: () => Promise.resolve(expectedResult),
|
||||
}),
|
||||
);
|
||||
|
||||
const data = await clearMsgs();
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith(`/api/messages`, {
|
||||
method: "DELETE",
|
||||
headers: expect.any(Headers),
|
||||
});
|
||||
|
||||
expect(data).toEqual(expectedResult);
|
||||
});
|
||||
|
||||
it("throws an error if response status is not 200", async () => {
|
||||
(fetch as Mock).mockImplementationOnce(() =>
|
||||
Promise.resolve({ status: 401 }),
|
||||
);
|
||||
|
||||
await expect(clearMsgs()).rejects.toThrow("Delete messages failed.");
|
||||
});
|
||||
expect(setupSpy).toHaveBeenCalledTimes(1);
|
||||
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,49 +1,165 @@
|
||||
import {
|
||||
ResDelMsg,
|
||||
ResFetchMsgs,
|
||||
ResFetchMsgTotal,
|
||||
} from "../types/ResponseType";
|
||||
import toast from "#/utils/toast";
|
||||
import { handleAssistantMessage } from "./actions";
|
||||
import { getToken, setToken, clearToken } from "./auth";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { getSettings } from "./settings";
|
||||
|
||||
const fetchMsgTotal = async (): Promise<ResFetchMsgTotal> => {
|
||||
const headers = new Headers({
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${localStorage.getItem("token")}`,
|
||||
});
|
||||
const response = await fetch(`/api/messages/total`, { headers });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Get message total failed.");
|
||||
class Session {
|
||||
private static _socket: WebSocket | null = null;
|
||||
|
||||
// callbacks contain a list of callable functions
|
||||
// event: function, like:
|
||||
// open: [function1, function2]
|
||||
// message: [function1, function2]
|
||||
private static callbacks: {
|
||||
[K in keyof WebSocketEventMap]: ((data: WebSocketEventMap[K]) => void)[];
|
||||
} = {
|
||||
open: [],
|
||||
message: [],
|
||||
error: [],
|
||||
close: [],
|
||||
};
|
||||
|
||||
private static _connecting = false;
|
||||
|
||||
private static _disconnecting = false;
|
||||
|
||||
public static restoreOrStartNewSession() {
|
||||
const token = getToken();
|
||||
if (Session.isConnected()) {
|
||||
Session.disconnect();
|
||||
}
|
||||
Session._connect(token);
|
||||
}
|
||||
const data: ResFetchMsgTotal = await response.json();
|
||||
return data;
|
||||
};
|
||||
|
||||
const fetchMsgs = async (): Promise<ResFetchMsgs> => {
|
||||
const headers = new Headers({
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${localStorage.getItem("token")}`,
|
||||
});
|
||||
const response = await fetch(`/api/messages`, { headers });
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Get messages failed.");
|
||||
public static startNewSession() {
|
||||
clearToken();
|
||||
Session.restoreOrStartNewSession();
|
||||
}
|
||||
const data: ResFetchMsgs = await response.json();
|
||||
return data;
|
||||
};
|
||||
|
||||
const clearMsgs = async (): Promise<ResDelMsg> => {
|
||||
const headers = new Headers({
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${localStorage.getItem("token")}`,
|
||||
});
|
||||
const response = await fetch(`/api/messages`, {
|
||||
method: "DELETE",
|
||||
headers,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Delete messages failed.");
|
||||
private static _initializeAgent = () => {
|
||||
const settings = getSettings();
|
||||
const event = { action: ActionType.INIT, args: settings };
|
||||
const eventString = JSON.stringify(event);
|
||||
Session.send(eventString);
|
||||
};
|
||||
|
||||
private static _connect(token: string = ""): void {
|
||||
if (Session.isConnected()) return;
|
||||
Session._connecting = true;
|
||||
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`;
|
||||
Session._socket = new WebSocket(WS_URL);
|
||||
Session._setupSocket();
|
||||
}
|
||||
const data: ResDelMsg = await response.json();
|
||||
return data;
|
||||
};
|
||||
|
||||
export { fetchMsgTotal, fetchMsgs, clearMsgs };
|
||||
private static _setupSocket(): void {
|
||||
if (!Session._socket) {
|
||||
throw new Error("Socket is not initialized.");
|
||||
}
|
||||
Session._socket.onopen = (e) => {
|
||||
toast.success("ws", "Connected to server.");
|
||||
Session._connecting = false;
|
||||
Session._initializeAgent();
|
||||
Session.callbacks.open?.forEach((callback) => {
|
||||
callback(e);
|
||||
});
|
||||
};
|
||||
|
||||
Session._socket.onmessage = (e) => {
|
||||
let data = null;
|
||||
try {
|
||||
data = JSON.parse(e.data);
|
||||
} catch (err) {
|
||||
// TODO: report the error
|
||||
console.error("Error parsing JSON data", err);
|
||||
return;
|
||||
}
|
||||
if (data.error && data.error_code === 401) {
|
||||
clearToken();
|
||||
} else if (data.token) {
|
||||
setToken(data.token);
|
||||
} else {
|
||||
handleAssistantMessage(data);
|
||||
}
|
||||
};
|
||||
|
||||
Session._socket.onerror = () => {
|
||||
const msg = "Connection failed. Retry...";
|
||||
toast.error("ws", msg);
|
||||
};
|
||||
|
||||
Session._socket.onclose = () => {
|
||||
if (!Session._disconnecting) {
|
||||
setTimeout(() => {
|
||||
Session.restoreOrStartNewSession();
|
||||
}, 3000); // Reconnect after 3 seconds
|
||||
}
|
||||
Session._disconnecting = false;
|
||||
};
|
||||
}
|
||||
|
||||
static isConnected(): boolean {
|
||||
return (
|
||||
Session._socket !== null && Session._socket.readyState === WebSocket.OPEN
|
||||
);
|
||||
}
|
||||
|
||||
static disconnect(): void {
|
||||
Session._disconnecting = true;
|
||||
if (Session._socket) {
|
||||
Session._socket.close();
|
||||
}
|
||||
Session._socket = null;
|
||||
}
|
||||
|
||||
static send(message: string): void {
|
||||
if (Session._connecting) {
|
||||
setTimeout(() => Session.send(message), 1000);
|
||||
return;
|
||||
}
|
||||
if (!Session.isConnected()) {
|
||||
throw new Error("Not connected to server.");
|
||||
}
|
||||
|
||||
if (Session.isConnected()) {
|
||||
Session._socket?.send(message);
|
||||
} else {
|
||||
const msg = "Connection failed. Retry...";
|
||||
toast.error("ws", msg);
|
||||
}
|
||||
}
|
||||
|
||||
static addEventListener(
|
||||
event: string,
|
||||
callback: (e: MessageEvent) => void,
|
||||
): void {
|
||||
Session._socket?.addEventListener(
|
||||
event as keyof WebSocketEventMap,
|
||||
callback as (
|
||||
this: WebSocket,
|
||||
ev: WebSocketEventMap[keyof WebSocketEventMap],
|
||||
) => never,
|
||||
);
|
||||
}
|
||||
|
||||
static removeEventListener(
|
||||
event: string,
|
||||
listener: (e: Event) => void,
|
||||
): void {
|
||||
Session._socket?.removeEventListener(event, listener);
|
||||
}
|
||||
|
||||
static registerCallback<K extends keyof WebSocketEventMap>(
|
||||
event: K,
|
||||
callbacks: ((data: WebSocketEventMap[K]) => void)[],
|
||||
): void {
|
||||
if (Session.callbacks[event] === undefined) {
|
||||
return;
|
||||
}
|
||||
Session.callbacks[event].push(...callbacks);
|
||||
}
|
||||
}
|
||||
|
||||
export default Session;
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
// import { toast } from "sonner";
|
||||
import toast from "#/utils/toast";
|
||||
import { handleAssistantMessage } from "./actions";
|
||||
import { getToken } from "./auth";
|
||||
|
||||
class Socket {
|
||||
private static _socket: WebSocket | null = null;
|
||||
|
||||
// callbacks contain a list of callable functions
|
||||
// event: function, like:
|
||||
// open: [function1, function2]
|
||||
// message: [function1, function2]
|
||||
private static callbacks: {
|
||||
[K in keyof WebSocketEventMap]: ((data: WebSocketEventMap[K]) => void)[];
|
||||
} = {
|
||||
open: [],
|
||||
message: [],
|
||||
error: [],
|
||||
close: [],
|
||||
};
|
||||
|
||||
private static initializing = false;
|
||||
|
||||
public static tryInitialize(): void {
|
||||
if (Socket.initializing) return;
|
||||
Socket.initializing = true;
|
||||
getToken()
|
||||
.then((token) => {
|
||||
Socket._initialize(token);
|
||||
})
|
||||
.catch(() => {
|
||||
const msg = `Connection failed. Retry...`;
|
||||
toast.stickyError("ws", msg);
|
||||
|
||||
setTimeout(() => {
|
||||
this.tryInitialize();
|
||||
}, 1500);
|
||||
});
|
||||
}
|
||||
|
||||
private static _initialize(token: string): void {
|
||||
if (Socket.isConnected()) return;
|
||||
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`;
|
||||
Socket._socket = new WebSocket(WS_URL);
|
||||
|
||||
Socket._socket.onopen = (e) => {
|
||||
toast.stickySuccess("ws", "Connected to server.");
|
||||
Socket.initializing = false;
|
||||
Socket.callbacks.open?.forEach((callback) => {
|
||||
callback(e);
|
||||
});
|
||||
};
|
||||
|
||||
Socket._socket.onmessage = (e) => {
|
||||
handleAssistantMessage(e.data);
|
||||
};
|
||||
|
||||
Socket._socket.onerror = () => {
|
||||
const msg = "Connection failed. Retry...";
|
||||
toast.stickyError("ws", msg);
|
||||
};
|
||||
|
||||
Socket._socket.onclose = () => {
|
||||
// Reconnect after a delay
|
||||
setTimeout(() => {
|
||||
Socket.tryInitialize();
|
||||
}, 3000); // Reconnect after 3 seconds
|
||||
};
|
||||
}
|
||||
|
||||
static isConnected(): boolean {
|
||||
return (
|
||||
Socket._socket !== null && Socket._socket.readyState === WebSocket.OPEN
|
||||
);
|
||||
}
|
||||
|
||||
static send(message: string): void {
|
||||
if (!Socket.isConnected()) {
|
||||
Socket.tryInitialize();
|
||||
}
|
||||
if (Socket.initializing) {
|
||||
setTimeout(() => Socket.send(message), 1000);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Socket.isConnected()) {
|
||||
Socket._socket?.send(message);
|
||||
} else {
|
||||
const msg = "Connection failed. Retry...";
|
||||
toast.stickyError("ws", msg);
|
||||
}
|
||||
}
|
||||
|
||||
static addEventListener(
|
||||
event: string,
|
||||
callback: (e: MessageEvent) => void,
|
||||
): void {
|
||||
Socket._socket?.addEventListener(
|
||||
event as keyof WebSocketEventMap,
|
||||
callback as (
|
||||
this: WebSocket,
|
||||
ev: WebSocketEventMap[keyof WebSocketEventMap],
|
||||
) => never,
|
||||
);
|
||||
}
|
||||
|
||||
static removeEventListener(
|
||||
event: string,
|
||||
listener: (e: Event) => void,
|
||||
): void {
|
||||
Socket._socket?.removeEventListener(event, listener);
|
||||
}
|
||||
|
||||
static registerCallback<K extends keyof WebSocketEventMap>(
|
||||
event: K,
|
||||
callbacks: ((data: WebSocketEventMap[K]) => void)[],
|
||||
): void {
|
||||
if (Socket.callbacks[event] === undefined) {
|
||||
return;
|
||||
}
|
||||
Socket.callbacks[event].push(...callbacks);
|
||||
}
|
||||
}
|
||||
|
||||
Socket.tryInitialize();
|
||||
|
||||
export default Socket;
|
||||
@@ -1,3 +1,5 @@
|
||||
import { request } from "./api";
|
||||
|
||||
export type Task = {
|
||||
id: string;
|
||||
goal: string;
|
||||
@@ -14,14 +16,6 @@ export enum TaskState {
|
||||
}
|
||||
|
||||
export async function getRootTask(): Promise<Task | undefined> {
|
||||
const headers = new Headers({
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${localStorage.getItem("token")}`,
|
||||
});
|
||||
const res = await fetch("/api/root_task", { headers });
|
||||
if (res.status !== 200 && res.status !== 204) {
|
||||
return undefined;
|
||||
}
|
||||
const data = (await res.json()) as Task;
|
||||
return data;
|
||||
const res = await request("/api/root_task");
|
||||
return res as Task;
|
||||
}
|
||||
|
||||
@@ -3,13 +3,7 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";
|
||||
type SliceState = { messages: Message[] };
|
||||
|
||||
const initialState: SliceState = {
|
||||
messages: [
|
||||
{
|
||||
content:
|
||||
"Hi! I'm OpenDevin, an AI Software Engineer. What would you like to build with me today?",
|
||||
sender: "assistant",
|
||||
},
|
||||
],
|
||||
messages: [],
|
||||
};
|
||||
|
||||
export const chatSlice = createSlice({
|
||||
|
||||
@@ -1,41 +1,5 @@
|
||||
import { ActionMessage, ObservationMessage } from "./Message";
|
||||
|
||||
type Role = "user" | "assistant";
|
||||
|
||||
interface ResConfigurations {
|
||||
[key: string]: string | boolean | number;
|
||||
}
|
||||
|
||||
interface ResFetchToken {
|
||||
token: string;
|
||||
}
|
||||
|
||||
interface ResFetchMsgTotal {
|
||||
msg_total: number;
|
||||
}
|
||||
|
||||
interface ResFetchMsg {
|
||||
id: string;
|
||||
role: Role;
|
||||
payload: SocketMessage;
|
||||
}
|
||||
|
||||
interface ResFetchMsgs {
|
||||
messages: ResFetchMsg[];
|
||||
}
|
||||
|
||||
interface ResDelMsg {
|
||||
ok: string;
|
||||
}
|
||||
|
||||
type SocketMessage = ActionMessage | ObservationMessage;
|
||||
|
||||
export {
|
||||
type ResConfigurations,
|
||||
type ResFetchToken,
|
||||
type ResFetchMsgTotal,
|
||||
type ResFetchMsg,
|
||||
type ResFetchMsgs,
|
||||
type ResDelMsg,
|
||||
type SocketMessage,
|
||||
};
|
||||
export { type SocketMessage };
|
||||
|
||||
@@ -3,15 +3,10 @@ import toast from "react-hot-toast";
|
||||
const idMap = new Map<string, string>();
|
||||
|
||||
export default {
|
||||
stickyError: (id: string, msg: string) => {
|
||||
error: (id: string, msg: string) => {
|
||||
if (idMap.has(id)) return; // prevent duplicate toast
|
||||
const toastId = toast.loading(msg, {
|
||||
// icon: "👏",
|
||||
// style: {
|
||||
// borderRadius: "10px",
|
||||
// background: "#333",
|
||||
// color: "#fff",
|
||||
// },
|
||||
const toastId = toast(msg, {
|
||||
duration: 4000,
|
||||
style: {
|
||||
background: "#ef4444",
|
||||
color: "#fff",
|
||||
@@ -24,12 +19,13 @@ export default {
|
||||
});
|
||||
idMap.set(id, toastId);
|
||||
},
|
||||
stickySuccess: (id: string, msg: string) => {
|
||||
success: (id: string, msg: string) => {
|
||||
const toastId = idMap.get(id);
|
||||
if (toastId === undefined) return;
|
||||
if (toastId) {
|
||||
toast.success(msg, {
|
||||
id: toastId,
|
||||
duration: 4000,
|
||||
style: {
|
||||
background: "#333",
|
||||
color: "#fff",
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
TROUBLESHOOTING_URL = 'https://opendevin.github.io/OpenDevin/modules/usage/troubleshooting'
|
||||
@@ -243,6 +243,9 @@ class AgentController:
|
||||
def get_state(self):
|
||||
return self.state
|
||||
|
||||
def set_state(self, state: State):
|
||||
self.state = state
|
||||
|
||||
def _is_stuck(self):
|
||||
# check if delegate stuck
|
||||
if self.delegate and self.delegate._is_stuck():
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from types import UnionType
|
||||
from typing import Any, ClassVar, get_args, get_origin
|
||||
@@ -173,6 +174,7 @@ class AppConfig(metaclass=Singleton):
|
||||
sandbox_user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
|
||||
sandbox_timeout: int = 120
|
||||
github_token: str | None = None
|
||||
jwt_secret: str = uuid.uuid4().hex
|
||||
debug: bool = False
|
||||
enable_auto_lint: bool = (
|
||||
False # once enabled, OpenDevin would lint files after editing
|
||||
|
||||
3
opendevin/core/const/guide_url.py
Normal file
3
opendevin/core/const/guide_url.py
Normal file
@@ -0,0 +1,3 @@
|
||||
TROUBLESHOOTING_URL = (
|
||||
'https://opendevin.github.io/OpenDevin/modules/usage/troubleshooting'
|
||||
)
|
||||
@@ -10,8 +10,8 @@ from glob import glob
|
||||
|
||||
import docker
|
||||
|
||||
from opendevin.const.guide_url import TROUBLESHOOTING_URL
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
|
||||
from opendevin.core.exceptions import SandboxInvalidBackgroundCommandError
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import CancellableStream
|
||||
|
||||
@@ -12,8 +12,8 @@ from glob import glob
|
||||
import docker
|
||||
from pexpect import exceptions, pxssh
|
||||
|
||||
from opendevin.const.guide_url import TROUBLESHOOTING_URL
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
|
||||
from opendevin.core.exceptions import SandboxInvalidBackgroundCommandError
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import CancellableStream
|
||||
|
||||
18
opendevin/runtime/e2b/filestore.py
Normal file
18
opendevin/runtime/e2b/filestore.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from opendevin.storage.files import FileStore
|
||||
|
||||
|
||||
class E2BFileStore(FileStore):
|
||||
def __init__(self, filesystem):
|
||||
self.filesystem = filesystem
|
||||
|
||||
def write(self, path: str, contents: str) -> None:
|
||||
self.filesystem.write(path, contents)
|
||||
|
||||
def read(self, path: str) -> str:
|
||||
return self.filesystem.read(path)
|
||||
|
||||
def list(self, path: str) -> list[str]:
|
||||
return self.filesystem.list(path)
|
||||
|
||||
def delete(self, path: str) -> None:
|
||||
self.filesystem.delete(path)
|
||||
@@ -13,6 +13,7 @@ from opendevin.runtime import Sandbox
|
||||
from opendevin.runtime.server.files import insert_lines, read_lines
|
||||
from opendevin.runtime.server.runtime import ServerRuntime
|
||||
|
||||
from .filestore import E2BFileStore
|
||||
from .sandbox import E2BSandbox
|
||||
|
||||
|
||||
@@ -26,25 +27,25 @@ class E2BRuntime(ServerRuntime):
|
||||
super().__init__(event_stream, sid, sandbox)
|
||||
if not isinstance(self.sandbox, E2BSandbox):
|
||||
raise ValueError('E2BRuntime requires an E2BSandbox')
|
||||
self.filesystem = self.sandbox.filesystem
|
||||
self.file_store = E2BFileStore(self.sandbox.filesystem)
|
||||
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
content = self.filesystem.read(action.path)
|
||||
content = self.file_store.read(action.path)
|
||||
lines = read_lines(content.split('\n'), action.start, action.end)
|
||||
code_view = ''.join(lines)
|
||||
return FileReadObservation(code_view, path=action.path)
|
||||
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
if action.start == 0 and action.end == -1:
|
||||
self.filesystem.write(action.path, action.content)
|
||||
self.file_store.write(action.path, action.content)
|
||||
return FileWriteObservation(content='', path=action.path)
|
||||
files = self.filesystem.list(action.path)
|
||||
files = self.file_store.list(action.path)
|
||||
if action.path in files:
|
||||
all_lines = self.filesystem.read(action.path)
|
||||
all_lines = self.file_store.read(action.path).split('\n')
|
||||
new_file = insert_lines(
|
||||
action.content.split('\n'), all_lines, action.start, action.end
|
||||
)
|
||||
self.filesystem.write(action.path, ''.join(new_file))
|
||||
self.file_store.write(action.path, ''.join(new_file))
|
||||
return FileWriteObservation('', path=action.path)
|
||||
else:
|
||||
# FIXME: we should create a new file here
|
||||
|
||||
@@ -31,6 +31,7 @@ from opendevin.runtime import (
|
||||
)
|
||||
from opendevin.runtime.browser.browser_env import BrowserEnv
|
||||
from opendevin.runtime.plugins import PluginRequirement
|
||||
from opendevin.storage import FileStore, InMemoryFileStore
|
||||
|
||||
|
||||
def create_sandbox(sid: str = 'default', sandbox_type: str = 'exec') -> Sandbox:
|
||||
@@ -55,6 +56,7 @@ class Runtime:
|
||||
"""
|
||||
|
||||
sid: str
|
||||
file_store: FileStore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -70,6 +72,7 @@ class Runtime:
|
||||
self.sandbox = sandbox
|
||||
self._is_external_sandbox = True
|
||||
self.browser = BrowserEnv()
|
||||
self.file_store = InMemoryFileStore()
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
|
||||
self._bg_task = asyncio.create_task(self._start_background_observation_loop())
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from opendevin.core.config import config
|
||||
from opendevin.events.action import (
|
||||
AgentRecallAction,
|
||||
BrowseInteractiveAction,
|
||||
@@ -15,13 +16,25 @@ from opendevin.events.observation import (
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from opendevin.events.stream import EventStream
|
||||
from opendevin.runtime import Sandbox
|
||||
from opendevin.runtime.runtime import Runtime
|
||||
from opendevin.storage.local import LocalFileStore
|
||||
|
||||
from .browse import browse
|
||||
from .files import read_file, write_file
|
||||
|
||||
|
||||
class ServerRuntime(Runtime):
|
||||
def __init__(
|
||||
self,
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
sandbox: Sandbox | None = None,
|
||||
):
|
||||
super().__init__(event_stream, sid, sandbox)
|
||||
self.file_store = LocalFileStore(config.workspace_base)
|
||||
|
||||
async def run(self, action: CmdRunAction) -> Observation:
|
||||
return self._run_command(action.command, background=action.background)
|
||||
|
||||
@@ -71,10 +84,12 @@ class ServerRuntime(Runtime):
|
||||
return IPythonRunCellObservation(content=output, code=action.code)
|
||||
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
# TODO: use self.file_store
|
||||
working_dir = self.sandbox.get_working_directory()
|
||||
return await read_file(action.path, working_dir, action.start, action.end)
|
||||
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
# TODO: use self.file_store
|
||||
working_dir = self.sandbox.get_working_directory()
|
||||
return await write_file(
|
||||
action.path, working_dir, action.content, action.start, action.end
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from .manager import AgentManager
|
||||
|
||||
agent_manager = AgentManager()
|
||||
|
||||
__all__ = ['AgentManager', 'agent_manager']
|
||||
@@ -1,161 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from opendevin.const.guide_url import TROUBLESHOOTING_URL
|
||||
from opendevin.controller import AgentController
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import ActionType, AgentState, ConfigType
|
||||
from opendevin.events.action import (
|
||||
ChangeAgentStateAction,
|
||||
NullAction,
|
||||
)
|
||||
from opendevin.events.event import Event
|
||||
from opendevin.events.observation import (
|
||||
NullObservation,
|
||||
)
|
||||
from opendevin.events.serialization.action import action_from_dict
|
||||
from opendevin.events.serialization.event import event_to_dict
|
||||
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
|
||||
from opendevin.llm.llm import LLM
|
||||
from opendevin.runtime import DockerSSHBox
|
||||
from opendevin.runtime.e2b.runtime import E2BRuntime
|
||||
from opendevin.runtime.runtime import Runtime
|
||||
from opendevin.runtime.server.runtime import ServerRuntime
|
||||
from opendevin.server.session import session_manager
|
||||
|
||||
|
||||
class AgentUnit:
|
||||
"""Represents a session with an agent.
|
||||
|
||||
Attributes:
|
||||
controller: The AgentController instance for controlling the agent.
|
||||
"""
|
||||
|
||||
sid: str
|
||||
event_stream: EventStream
|
||||
controller: Optional[AgentController] = None
|
||||
runtime: Optional[Runtime] = None
|
||||
|
||||
def __init__(self, sid):
|
||||
"""Initializes a new instance of the Session class."""
|
||||
self.sid = sid
|
||||
self.event_stream = EventStream(sid)
|
||||
self.event_stream.subscribe(EventStreamSubscriber.SERVER, self.on_event)
|
||||
if config.runtime == 'server':
|
||||
logger.info('Using server runtime')
|
||||
self.runtime = ServerRuntime(self.event_stream, sid)
|
||||
elif config.runtime == 'e2b':
|
||||
logger.info('Using E2B runtime')
|
||||
self.runtime = E2BRuntime(self.event_stream, sid)
|
||||
|
||||
async def send_error(self, message):
|
||||
"""Sends an error message to the client.
|
||||
|
||||
Args:
|
||||
message: The error message to send.
|
||||
"""
|
||||
await session_manager.send_error(self.sid, message)
|
||||
|
||||
async def send_message(self, message):
|
||||
"""Sends a message to the client.
|
||||
|
||||
Args:
|
||||
message: The message to send.
|
||||
"""
|
||||
await session_manager.send_message(self.sid, message)
|
||||
|
||||
async def send(self, data):
|
||||
"""Sends data to the client.
|
||||
|
||||
Args:
|
||||
data: The data to send.
|
||||
"""
|
||||
await session_manager.send(self.sid, data)
|
||||
|
||||
async def dispatch(self, action: str | None, data: dict):
|
||||
"""Dispatches actions to the agent from the client."""
|
||||
if action is None:
|
||||
await self.send_error('Invalid action')
|
||||
return
|
||||
|
||||
if action == ActionType.INIT:
|
||||
await self.create_controller(data)
|
||||
await self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
|
||||
)
|
||||
return
|
||||
|
||||
action_dict = data.copy()
|
||||
action_dict['action'] = action
|
||||
action_obj = action_from_dict(action_dict)
|
||||
await self.event_stream.add_event(action_obj, EventSource.USER)
|
||||
|
||||
async def create_controller(self, start_event: dict):
|
||||
"""Creates an AgentController instance.
|
||||
|
||||
Args:
|
||||
start_event: The start event data (optional).
|
||||
"""
|
||||
args = {
|
||||
key: value
|
||||
for key, value in start_event.get('args', {}).items()
|
||||
if value != ''
|
||||
} # remove empty values, prevent FE from sending empty strings
|
||||
agent_cls = args.get(ConfigType.AGENT, config.agent.name)
|
||||
model = args.get(ConfigType.LLM_MODEL, config.llm.model)
|
||||
api_key = args.get(ConfigType.LLM_API_KEY, config.llm.api_key)
|
||||
api_base = config.llm.base_url
|
||||
max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
|
||||
max_chars = args.get(ConfigType.MAX_CHARS, config.llm.max_chars)
|
||||
|
||||
logger.info(f'Creating agent {agent_cls} using LLM {model}')
|
||||
llm = LLM(model=model, api_key=api_key, base_url=api_base)
|
||||
agent = Agent.get_cls(agent_cls)(llm)
|
||||
if isinstance(agent, CodeActAgent):
|
||||
if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox):
|
||||
logger.warning(
|
||||
'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
|
||||
)
|
||||
# Initializing plugins into the runtime
|
||||
assert self.runtime is not None, 'Runtime is not initialized'
|
||||
self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
|
||||
|
||||
if self.controller is not None:
|
||||
await self.controller.close()
|
||||
try:
|
||||
self.controller = AgentController(
|
||||
sid=self.sid,
|
||||
event_stream=self.event_stream,
|
||||
agent=agent,
|
||||
max_iterations=int(max_iterations),
|
||||
max_chars=int(max_chars),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating controller: {e}')
|
||||
await self.send_error(
|
||||
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
|
||||
)
|
||||
return
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
"""Callback function for agent events.
|
||||
|
||||
Args:
|
||||
event: The agent event (Observation or Action).
|
||||
"""
|
||||
if isinstance(event, NullAction):
|
||||
return
|
||||
if isinstance(event, NullObservation):
|
||||
return
|
||||
if event.source == 'agent' and not isinstance(
|
||||
event, (NullAction, NullObservation)
|
||||
):
|
||||
await self.send(event_to_dict(event))
|
||||
|
||||
async def close(self):
|
||||
if self.controller is not None:
|
||||
await self.controller.close()
|
||||
if self.runtime is not None:
|
||||
self.runtime.close()
|
||||
@@ -1,48 +0,0 @@
|
||||
import asyncio, atexit
|
||||
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.server.session import session_manager
|
||||
|
||||
from .agent import AgentUnit
|
||||
|
||||
|
||||
class AgentManager:
|
||||
sid_to_agent: dict[str, 'AgentUnit'] = {}
|
||||
|
||||
def __init__(self):
|
||||
atexit.register(self.close)
|
||||
|
||||
def register_agent(self, sid: str):
|
||||
"""Registers a new agent.
|
||||
|
||||
Args:
|
||||
sid: The session ID of the agent.
|
||||
"""
|
||||
if sid not in self.sid_to_agent:
|
||||
self.sid_to_agent[sid] = AgentUnit(sid)
|
||||
return
|
||||
|
||||
# TODO: confirm whether the agent is alive
|
||||
|
||||
async def dispatch(self, sid: str, action: str | None, data: dict):
|
||||
"""Dispatches actions to the agent from the client."""
|
||||
if sid not in self.sid_to_agent:
|
||||
# self.register_agent(sid) # auto-register agent, may be opened later
|
||||
logger.error(f'Agent not registered: {sid}')
|
||||
await session_manager.send_error(sid, 'Agent not registered')
|
||||
return
|
||||
|
||||
await self.sid_to_agent[sid].dispatch(action, data)
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._close())
|
||||
|
||||
async def _close(self):
|
||||
logger.info(f'Closing {len(self.sid_to_agent)} agent(s)...')
|
||||
for sid, agent in self.sid_to_agent.items():
|
||||
await agent.close()
|
||||
@@ -1,12 +1,9 @@
|
||||
import os
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
|
||||
JWT_SECRET = os.getenv('JWT_SECRET', '5ecRe7')
|
||||
|
||||
|
||||
def get_sid_from_token(token: str) -> str:
|
||||
"""
|
||||
@@ -20,7 +17,7 @@ def get_sid_from_token(token: str) -> str:
|
||||
"""
|
||||
try:
|
||||
# Decode the JWT using the specified secret and algorithm
|
||||
payload = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
|
||||
payload = jwt.decode(token, config.jwt_secret, algorithms=['HS256'])
|
||||
|
||||
# Ensure the payload contains 'sid'
|
||||
if 'sid' in payload:
|
||||
@@ -41,4 +38,4 @@ def sign_token(payload: dict[str, object]) -> str:
|
||||
# "sid": sid,
|
||||
# # "exp": datetime.now(timezone.utc) + timedelta(minutes=15),
|
||||
# }
|
||||
return jwt.encode(payload, JWT_SECRET, algorithm='HS256')
|
||||
return jwt.encode(payload, config.jwt_secret, algorithm='HS256')
|
||||
|
||||
@@ -1,26 +1,25 @@
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
import litellm
|
||||
from fastapi import Depends, FastAPI, Request, Response, UploadFile, WebSocket, status
|
||||
from fastapi import FastAPI, Request, Response, UploadFile, WebSocket, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi.security import HTTPBearer
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
import agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.events.action import ChangeAgentStateAction, NullAction
|
||||
from opendevin.events.observation import AgentStateChangedObservation, NullObservation
|
||||
from opendevin.events.serialization import event_to_dict
|
||||
from opendevin.llm import bedrock
|
||||
from opendevin.server.agent import agent_manager
|
||||
from opendevin.server.auth import get_sid_from_token, sign_token
|
||||
from opendevin.server.session import message_stack, session_manager
|
||||
from opendevin.server.session import session_manager
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
@@ -34,6 +33,45 @@ app.add_middleware(
|
||||
security_scheme = HTTPBearer()
|
||||
|
||||
|
||||
@app.middleware('http')
|
||||
async def attach_session(request: Request, call_next):
|
||||
if request.url.path.startswith('/api/options/') or not request.url.path.startswith(
|
||||
'/api/'
|
||||
):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
if not request.headers.get('Authorization'):
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Missing Authorization header'},
|
||||
)
|
||||
return response
|
||||
|
||||
auth_token = request.headers.get('Authorization')
|
||||
if 'Bearer' in auth_token:
|
||||
auth_token = auth_token.split('Bearer')[1].strip()
|
||||
|
||||
request.state.sid = get_sid_from_token(auth_token)
|
||||
if request.state.sid == '':
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'Invalid token'},
|
||||
)
|
||||
return response
|
||||
|
||||
request.state.session = session_manager.get_session(request.state.sid)
|
||||
if request.state.session is None:
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Session not found'},
|
||||
)
|
||||
return response
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
# This endpoint receives events from the client (i.e. the browser)
|
||||
@app.websocket('/ws')
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
@@ -99,16 +137,41 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
```
|
||||
"""
|
||||
await websocket.accept()
|
||||
sid = get_sid_from_token(websocket.query_params.get('token') or '')
|
||||
if sid == '':
|
||||
logger.error('Failed to decode token')
|
||||
return
|
||||
session_manager.add_session(sid, websocket)
|
||||
agent_manager.register_agent(sid)
|
||||
await session_manager.loop_recv(sid, agent_manager.dispatch)
|
||||
|
||||
session = None
|
||||
if websocket.query_params.get('token'):
|
||||
token = websocket.query_params.get('token')
|
||||
sid = get_sid_from_token(token)
|
||||
|
||||
if sid == '':
|
||||
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
|
||||
await websocket.close()
|
||||
return
|
||||
else:
|
||||
sid = str(uuid.uuid4())
|
||||
token = sign_token({'sid': sid})
|
||||
|
||||
session = session_manager.add_or_restart_session(sid, websocket)
|
||||
await websocket.send_json({'token': token, 'status': 'ok'})
|
||||
|
||||
last_event_id = -1
|
||||
if websocket.query_params.get('last_event_id'):
|
||||
last_event_id = int(websocket.query_params.get('last_event_id'))
|
||||
for event in session.agent_session.event_stream.get_events(
|
||||
start_id=last_event_id + 1
|
||||
):
|
||||
if isinstance(event, NullAction) or isinstance(event, NullObservation):
|
||||
continue
|
||||
if isinstance(event, ChangeAgentStateAction) or isinstance(
|
||||
event, AgentStateChangedObservation
|
||||
):
|
||||
continue
|
||||
await websocket.send_json(event_to_dict(event))
|
||||
|
||||
await session.loop_recv()
|
||||
|
||||
|
||||
@app.get('/api/litellm-models')
|
||||
@app.get('/api/options/models')
|
||||
async def get_litellm_models():
|
||||
"""
|
||||
Get all models supported by LiteLLM.
|
||||
@@ -128,7 +191,7 @@ async def get_litellm_models():
|
||||
return list(set(model_list))
|
||||
|
||||
|
||||
@app.get('/api/agents')
|
||||
@app.get('/api/options/agents')
|
||||
async def get_agents():
|
||||
"""
|
||||
Get all agents supported by LiteLLM.
|
||||
@@ -142,89 +205,6 @@ async def get_agents():
|
||||
return agents
|
||||
|
||||
|
||||
@app.get('/api/auth')
|
||||
async def get_token(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
|
||||
):
|
||||
"""
|
||||
Generate a JWT for authentication when starting a WebSocket connection. This endpoint checks if valid credentials
|
||||
are provided and uses them to get a session ID. If no valid credentials are provided, it generates a new session ID.
|
||||
|
||||
To obtain an authentication token:
|
||||
```sh
|
||||
curl -H "Authorization: Bearer 5ecRe7" http://localhost:3000/api/auth
|
||||
```
|
||||
**Note:** If `JWT_SECRET` is set, use its value instead of `5ecRe7`.
|
||||
"""
|
||||
if credentials and credentials.credentials:
|
||||
sid = get_sid_from_token(credentials.credentials)
|
||||
if not sid:
|
||||
sid = str(uuid.uuid4())
|
||||
logger.info(
|
||||
f'Invalid or missing credentials, generating new session ID: {sid}'
|
||||
)
|
||||
else:
|
||||
sid = str(uuid.uuid4())
|
||||
logger.info(f'No credentials provided, generating new session ID: {sid}')
|
||||
|
||||
token = sign_token({'sid': sid})
|
||||
return {'token': token, 'status': 'ok'}
|
||||
|
||||
|
||||
@app.get('/api/messages')
|
||||
async def get_messages(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
|
||||
):
|
||||
"""
|
||||
Get messages.
|
||||
|
||||
To get messages:
|
||||
```sh
|
||||
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages
|
||||
```
|
||||
"""
|
||||
data = []
|
||||
sid = get_sid_from_token(credentials.credentials)
|
||||
if sid != '':
|
||||
data = message_stack.get_messages(sid)
|
||||
|
||||
return {'messages': data}
|
||||
|
||||
|
||||
@app.get('/api/messages/total')
|
||||
async def get_message_total(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
|
||||
):
|
||||
"""
|
||||
Get total message count.
|
||||
|
||||
To get the total message count:
|
||||
```sh
|
||||
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages/total
|
||||
```
|
||||
"""
|
||||
sid = get_sid_from_token(credentials.credentials)
|
||||
return {'msg_total': message_stack.get_message_total(sid)}
|
||||
|
||||
|
||||
@app.delete('/api/messages')
|
||||
async def del_messages(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
|
||||
):
|
||||
"""
|
||||
Delete messages.
|
||||
|
||||
To delete messages:
|
||||
```sh
|
||||
|
||||
curl -X DELETE -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages
|
||||
```
|
||||
"""
|
||||
sid = get_sid_from_token(credentials.credentials)
|
||||
message_stack.del_messages(sid)
|
||||
return {'ok': True}
|
||||
|
||||
|
||||
@app.get('/api/list-files')
|
||||
def list_files(request: Request, path: str = '/'):
|
||||
"""
|
||||
@@ -235,27 +215,25 @@ def list_files(request: Request, path: str = '/'):
|
||||
curl http://localhost:3000/api/list-files
|
||||
```
|
||||
"""
|
||||
if path.startswith('/'):
|
||||
path = path[1:]
|
||||
abs_path = os.path.join(config.workspace_base, path)
|
||||
try:
|
||||
files = os.listdir(abs_path)
|
||||
except Exception as e:
|
||||
logger.error(f'Error listing files: {e}', exc_info=False)
|
||||
if not request.state.session.agent_session.runtime:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Path not found'},
|
||||
content={'error': 'Runtime not yet initialized'},
|
||||
)
|
||||
|
||||
try:
|
||||
return request.state.session.agent_session.runtime.file_store.list(path)
|
||||
except Exception as e:
|
||||
logger.error(f'Error refreshing files: {e}', exc_info=False)
|
||||
error_msg = f'Error refreshing files: {e}'
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'error': error_msg},
|
||||
)
|
||||
files = [os.path.join(path, f) for f in files]
|
||||
files = [
|
||||
f + '/' if os.path.isdir(os.path.join(config.workspace_base, f)) else f
|
||||
for f in files
|
||||
]
|
||||
return files
|
||||
|
||||
|
||||
@app.get('/api/select-file')
|
||||
def select_file(file: str):
|
||||
def select_file(file: str, request: Request):
|
||||
"""
|
||||
Select a file.
|
||||
|
||||
@@ -265,12 +243,7 @@ def select_file(file: str):
|
||||
```
|
||||
"""
|
||||
try:
|
||||
workspace_base = config.workspace_base
|
||||
file_path = Path(workspace_base, file)
|
||||
# The following will check if the file is within the workspace base and throw an exception if not
|
||||
file_path.resolve().relative_to(Path(workspace_base).resolve())
|
||||
with open(file_path, 'r') as selected_file:
|
||||
content = selected_file.read()
|
||||
content = request.state.session.agent_session.runtime.file_store.read(file)
|
||||
except Exception as e:
|
||||
logger.error(f'Error opening file {file}: {e}', exc_info=False)
|
||||
error_msg = f'Error opening file: {e}'
|
||||
@@ -282,7 +255,7 @@ def select_file(file: str):
|
||||
|
||||
|
||||
@app.post('/api/upload-files')
|
||||
async def upload_files(files: list[UploadFile]):
|
||||
async def upload_file(request: Request, files: list[UploadFile]):
|
||||
"""
|
||||
Upload files to the workspace.
|
||||
|
||||
@@ -292,13 +265,11 @@ async def upload_files(files: list[UploadFile]):
|
||||
```
|
||||
"""
|
||||
try:
|
||||
workspace_base = config.workspace_base
|
||||
for file in files:
|
||||
file_path = Path(workspace_base, file.filename)
|
||||
# The following will check if the file is within the workspace base and throw an exception if not
|
||||
file_path.resolve().relative_to(Path(workspace_base).resolve())
|
||||
with open(file_path, 'wb') as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
file_contents = await file.read()
|
||||
request.state.session.agent_session.runtime.file_store.write(
|
||||
file.filename, file_contents
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error saving files: {e}', exc_info=True)
|
||||
return JSONResponse(
|
||||
@@ -309,9 +280,7 @@ async def upload_files(files: list[UploadFile]):
|
||||
|
||||
|
||||
@app.get('/api/root_task')
|
||||
def get_root_task(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
|
||||
):
|
||||
def get_root_task(request: Request):
|
||||
"""
|
||||
Get root_task.
|
||||
|
||||
@@ -320,9 +289,7 @@ def get_root_task(
|
||||
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/root_task
|
||||
```
|
||||
"""
|
||||
sid = get_sid_from_token(credentials.credentials)
|
||||
agent = agent_manager.sid_to_agent[sid]
|
||||
controller = agent.controller
|
||||
controller = request.state.session.agent_session.controller
|
||||
if controller is not None:
|
||||
state = controller.get_state()
|
||||
if state:
|
||||
|
||||
@@ -33,7 +33,7 @@ def read_root():
|
||||
return {'message': 'This is a mock server'}
|
||||
|
||||
|
||||
@app.get('/api/litellm-models')
|
||||
@app.get('/api/options/models')
|
||||
def read_llm_models():
|
||||
return [
|
||||
'gpt-4',
|
||||
@@ -43,7 +43,7 @@ def read_llm_models():
|
||||
]
|
||||
|
||||
|
||||
@app.get('/api/agents')
|
||||
@app.get('/api/options/agents')
|
||||
def read_llm_agents():
|
||||
return [
|
||||
'MonologueAgent',
|
||||
@@ -52,16 +52,6 @@ def read_llm_agents():
|
||||
]
|
||||
|
||||
|
||||
@app.get('/api/messages')
|
||||
async def get_messages():
|
||||
return {'messages': []}
|
||||
|
||||
|
||||
@app.get('/api/messages/total')
|
||||
async def get_message_total():
|
||||
return {'msg_total': 0}
|
||||
|
||||
|
||||
@app.get('/api/list-files')
|
||||
def refresh_files():
|
||||
return ['hello_world.py']
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from .manager import SessionManager
|
||||
from .msg_stack import message_stack
|
||||
from .session import Session
|
||||
|
||||
session_manager = SessionManager()
|
||||
|
||||
117
opendevin/server/session/agent.py
Normal file
117
opendevin/server/session/agent.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Optional
|
||||
|
||||
from agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from opendevin.controller import AgentController
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import config
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import ConfigType
|
||||
from opendevin.events.stream import EventStream
|
||||
from opendevin.llm.llm import LLM
|
||||
from opendevin.runtime import DockerSSHBox
|
||||
from opendevin.runtime.e2b.runtime import E2BRuntime
|
||||
from opendevin.runtime.runtime import Runtime
|
||||
from opendevin.runtime.server.runtime import ServerRuntime
|
||||
|
||||
|
||||
class AgentSession:
|
||||
"""Represents a session with an agent.
|
||||
|
||||
Attributes:
|
||||
controller: The AgentController instance for controlling the agent.
|
||||
"""
|
||||
|
||||
sid: str
|
||||
event_stream: EventStream
|
||||
controller: Optional[AgentController] = None
|
||||
runtime: Optional[Runtime] = None
|
||||
_closed: bool = False
|
||||
|
||||
def __init__(self, sid):
|
||||
"""Initializes a new instance of the Session class."""
|
||||
self.sid = sid
|
||||
self.event_stream = EventStream(sid)
|
||||
|
||||
async def start(self, start_event: dict):
|
||||
"""Starts the agent session.
|
||||
|
||||
Args:
|
||||
start_event: The start event data (optional).
|
||||
"""
|
||||
if self.controller or self.runtime:
|
||||
raise Exception(
|
||||
'Session already started. You need to close this session and start a new one.'
|
||||
)
|
||||
await self._create_runtime()
|
||||
await self._create_controller(start_event)
|
||||
|
||||
async def close(self):
|
||||
if self._closed:
|
||||
return
|
||||
if self.controller is not None:
|
||||
end_state = self.controller.get_state()
|
||||
end_state.save_to_session(self.sid)
|
||||
await self.controller.close()
|
||||
if self.runtime is not None:
|
||||
self.runtime.close()
|
||||
self._closed = True
|
||||
|
||||
async def _create_runtime(self):
|
||||
if self.runtime is not None:
|
||||
raise Exception('Runtime already created')
|
||||
if config.runtime == 'server':
|
||||
logger.info('Using server runtime')
|
||||
self.runtime = ServerRuntime(self.event_stream, self.sid)
|
||||
elif config.runtime == 'e2b':
|
||||
logger.info('Using E2B runtime')
|
||||
self.runtime = E2BRuntime(self.event_stream, self.sid)
|
||||
else:
|
||||
raise Exception(
|
||||
f'Runtime not defined in config, or is invalid: {config.runtime}'
|
||||
)
|
||||
|
||||
async def _create_controller(self, start_event: dict):
|
||||
"""Creates an AgentController instance.
|
||||
|
||||
Args:
|
||||
start_event: The start event data (optional).
|
||||
"""
|
||||
if self.controller is not None:
|
||||
raise Exception('Controller already created')
|
||||
if self.runtime is None:
|
||||
raise Exception('Runtime must be initialized before the agent controller')
|
||||
args = {
|
||||
key: value
|
||||
for key, value in start_event.get('args', {}).items()
|
||||
if value != ''
|
||||
} # remove empty values, prevent FE from sending empty strings
|
||||
agent_cls = args.get(ConfigType.AGENT, config.agent.name)
|
||||
model = args.get(ConfigType.LLM_MODEL, config.llm.model)
|
||||
api_key = args.get(ConfigType.LLM_API_KEY, config.llm.api_key)
|
||||
api_base = config.llm.base_url
|
||||
max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
|
||||
max_chars = args.get(ConfigType.MAX_CHARS, config.llm.max_chars)
|
||||
|
||||
logger.info(f'Creating agent {agent_cls} using LLM {model}')
|
||||
llm = LLM(model=model, api_key=api_key, base_url=api_base)
|
||||
agent = Agent.get_cls(agent_cls)(llm)
|
||||
if isinstance(agent, CodeActAgent):
|
||||
if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox):
|
||||
logger.warning(
|
||||
'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
|
||||
)
|
||||
self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
|
||||
|
||||
self.controller = AgentController(
|
||||
sid=self.sid,
|
||||
event_stream=self.event_stream,
|
||||
agent=agent,
|
||||
max_iterations=int(max_iterations),
|
||||
max_chars=int(max_chars),
|
||||
)
|
||||
try:
|
||||
agent_state = State.restore_from_session(self.sid)
|
||||
self.controller.set_state(agent_state)
|
||||
except Exception as e:
|
||||
print('Error restoring state', e)
|
||||
@@ -1,20 +1,12 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
|
||||
from .msg_stack import message_stack
|
||||
from .session import Session
|
||||
|
||||
CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
|
||||
SESSION_CACHE_FILE = os.path.join(CACHE_DIR, 'sessions.json')
|
||||
|
||||
|
||||
class SessionManager:
|
||||
_sessions: dict[str, Session] = {}
|
||||
@@ -22,30 +14,21 @@ class SessionManager:
|
||||
session_timeout: int = 600
|
||||
|
||||
def __init__(self):
|
||||
self._load_sessions()
|
||||
atexit.register(self.close)
|
||||
asyncio.create_task(self._cleanup_sessions())
|
||||
|
||||
def add_session(self, sid: str, ws_conn: WebSocket):
|
||||
if sid not in self._sessions:
|
||||
self._sessions[sid] = Session(sid=sid, ws=ws_conn)
|
||||
return
|
||||
self._sessions[sid].update_connection(ws_conn)
|
||||
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
|
||||
if sid in self._sessions:
|
||||
asyncio.create_task(self._sessions[sid].close())
|
||||
self._sessions[sid] = Session(sid=sid, ws=ws_conn)
|
||||
return self._sessions[sid]
|
||||
|
||||
async def loop_recv(self, sid: str, dispatch: Callable):
|
||||
print(f'Starting loop_recv for sid: {sid}')
|
||||
"""Starts listening for messages from the client."""
|
||||
def get_session(self, sid: str) -> Session | None:
|
||||
if sid not in self._sessions:
|
||||
return
|
||||
await self._sessions[sid].loop_recv(dispatch)
|
||||
|
||||
def close(self):
|
||||
logger.info('Saving sessions...')
|
||||
self._save_sessions()
|
||||
return None
|
||||
return self._sessions.get(sid)
|
||||
|
||||
async def send(self, sid: str, data: dict[str, object]) -> bool:
|
||||
"""Sends data to the client."""
|
||||
message_stack.add_message(sid, 'assistant', data)
|
||||
if sid not in self._sessions:
|
||||
return False
|
||||
return await self._sessions[sid].send(data)
|
||||
@@ -58,33 +41,6 @@ class SessionManager:
|
||||
"""Sends a message to the client."""
|
||||
return await self.send(sid, {'message': message})
|
||||
|
||||
def _save_sessions(self):
|
||||
data = {}
|
||||
for sid, conn in self._sessions.items():
|
||||
data[sid] = {
|
||||
'sid': conn.sid,
|
||||
'last_active_ts': conn.last_active_ts,
|
||||
'is_alive': conn.is_alive,
|
||||
}
|
||||
if not os.path.exists(CACHE_DIR):
|
||||
os.makedirs(CACHE_DIR)
|
||||
with open(SESSION_CACHE_FILE, 'w+') as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def _load_sessions(self):
|
||||
try:
|
||||
with open(SESSION_CACHE_FILE, 'r') as file:
|
||||
data = json.load(file)
|
||||
for sid, sdata in data.items():
|
||||
conn = Session(sid, None)
|
||||
ok = conn.load_from_data(sdata)
|
||||
if ok:
|
||||
self._sessions[sid] = conn
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def _cleanup_sessions(self):
|
||||
while True:
|
||||
current_time = time.time()
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema.action import ActionType
|
||||
|
||||
CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
|
||||
MSG_CACHE_FILE = os.path.join(CACHE_DIR, 'messages.json')
|
||||
|
||||
|
||||
class Message:
|
||||
id: str = str(uuid.uuid4())
|
||||
role: str # "user"| "assistant"
|
||||
payload: dict[str, object]
|
||||
|
||||
def __init__(self, role: str, payload: dict[str, object]):
|
||||
self.role = role
|
||||
self.payload = payload
|
||||
|
||||
def to_dict(self):
|
||||
return {'id': self.id, 'role': self.role, 'payload': self.payload}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
m = cls(data['role'], data['payload'])
|
||||
m.id = data['id']
|
||||
return m
|
||||
|
||||
|
||||
class MessageStack:
|
||||
_messages: dict[str, list[Message]] = {}
|
||||
|
||||
def __init__(self):
|
||||
self._load_messages()
|
||||
atexit.register(self.close)
|
||||
|
||||
def close(self):
|
||||
logger.info('Saving messages...')
|
||||
self._save_messages()
|
||||
|
||||
def add_message(self, sid: str, role: str, message: dict[str, object]):
|
||||
if sid not in self._messages:
|
||||
self._messages[sid] = []
|
||||
self._messages[sid].append(Message(role, message))
|
||||
|
||||
def del_messages(self, sid: str):
|
||||
if sid not in self._messages:
|
||||
return
|
||||
del self._messages[sid]
|
||||
asyncio.create_task(self._del_messages(sid))
|
||||
|
||||
def get_messages(self, sid: str) -> list[dict[str, object]]:
|
||||
if sid not in self._messages:
|
||||
return []
|
||||
return [msg.to_dict() for msg in self._messages[sid]]
|
||||
|
||||
def get_message_total(self, sid: str) -> int:
|
||||
if sid not in self._messages:
|
||||
return 0
|
||||
cnt = 0
|
||||
for msg in self._messages[sid]:
|
||||
# Ignore assistant init message for now.
|
||||
if 'action' in msg.payload and msg.payload['action'] in [
|
||||
ActionType.INIT,
|
||||
ActionType.CHANGE_AGENT_STATE,
|
||||
]:
|
||||
continue
|
||||
cnt += 1
|
||||
return cnt
|
||||
|
||||
def _save_messages(self):
|
||||
if not os.path.exists(CACHE_DIR):
|
||||
os.makedirs(CACHE_DIR)
|
||||
data = {}
|
||||
for sid, msgs in self._messages.items():
|
||||
data[sid] = [msg.to_dict() for msg in msgs]
|
||||
with open(MSG_CACHE_FILE, 'w+') as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def _load_messages(self):
|
||||
try:
|
||||
with open(MSG_CACHE_FILE, 'r') as file:
|
||||
data = json.load(file)
|
||||
for sid, msgs in data.items():
|
||||
self._messages[sid] = [Message.from_dict(msg) for msg in msgs]
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def _del_messages(self, del_sid: str):
|
||||
logger.info('Deleting messages...')
|
||||
try:
|
||||
with open(MSG_CACHE_FILE, 'r+') as file:
|
||||
data = json.load(file)
|
||||
new_data = {}
|
||||
for sid, msgs in data.items():
|
||||
if sid != del_sid:
|
||||
new_data[sid] = msgs
|
||||
# Move the file pointer to the beginning of the file to overwrite the original contents
|
||||
file.seek(0)
|
||||
# clean previous content
|
||||
file.truncate()
|
||||
json.dump(new_data, file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
|
||||
|
||||
message_stack = MessageStack()
|
||||
@@ -1,11 +1,19 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import AgentState
|
||||
from opendevin.core.schema.action import ActionType
|
||||
from opendevin.events.action import ChangeAgentStateAction, NullAction
|
||||
from opendevin.events.event import Event
|
||||
from opendevin.events.observation import AgentStateChangedObservation, NullObservation
|
||||
from opendevin.events.serialization import EventSource, event_from_dict, event_to_dict
|
||||
from opendevin.events.stream import EventStreamSubscriber
|
||||
|
||||
from .msg_stack import message_stack
|
||||
from .agent import AgentSession
|
||||
|
||||
DEL_DELT_SEC = 60 * 60 * 5
|
||||
|
||||
@@ -15,13 +23,22 @@ class Session:
|
||||
websocket: WebSocket | None
|
||||
last_active_ts: int = 0
|
||||
is_alive: bool = True
|
||||
agent_session: AgentSession
|
||||
|
||||
def __init__(self, sid: str, ws: WebSocket | None):
|
||||
self.sid = sid
|
||||
self.websocket = ws
|
||||
self.last_active_ts = int(time.time())
|
||||
self.agent_session = AgentSession(sid)
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event
|
||||
)
|
||||
|
||||
async def loop_recv(self, dispatch: Callable):
|
||||
async def close(self):
|
||||
self.is_alive = False
|
||||
await self.agent_session.close()
|
||||
|
||||
async def loop_recv(self):
|
||||
try:
|
||||
if self.websocket is None:
|
||||
return
|
||||
@@ -31,24 +48,62 @@ class Session:
|
||||
except ValueError:
|
||||
await self.send_error('Invalid JSON')
|
||||
continue
|
||||
|
||||
message_stack.add_message(self.sid, 'user', data)
|
||||
action = data.get('action', None)
|
||||
await dispatch(self.sid, action, data)
|
||||
await self.dispatch(data)
|
||||
except WebSocketDisconnect:
|
||||
self.is_alive = False
|
||||
await self.close()
|
||||
logger.info('WebSocket disconnected, sid: %s', self.sid)
|
||||
except RuntimeError as e:
|
||||
# WebSocket is not connected
|
||||
if 'WebSocket is not connected' in str(e):
|
||||
self.is_alive = False
|
||||
await self.close()
|
||||
logger.exception('Error in loop_recv: %s', e)
|
||||
|
||||
async def _initialize_agent(self, data: dict):
|
||||
await self.agent_session.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
|
||||
)
|
||||
await self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
|
||||
)
|
||||
try:
|
||||
await self.agent_session.start(data)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating controller: {e}')
|
||||
await self.send_error(
|
||||
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
|
||||
)
|
||||
return
|
||||
await self.agent_session.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
"""Callback function for agent events.
|
||||
|
||||
Args:
|
||||
event: The agent event (Observation or Action).
|
||||
"""
|
||||
if isinstance(event, NullAction):
|
||||
return
|
||||
if isinstance(event, NullObservation):
|
||||
return
|
||||
if event.source == EventSource.AGENT and not isinstance(
|
||||
event, (NullAction, NullObservation)
|
||||
):
|
||||
await self.send(event_to_dict(event))
|
||||
|
||||
async def dispatch(self, data: dict):
|
||||
action = data.get('action', '')
|
||||
if action == ActionType.INIT:
|
||||
await self._initialize_agent(data)
|
||||
return
|
||||
event = event_from_dict(data.copy())
|
||||
await self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
|
||||
async def send(self, data: dict[str, object]) -> bool:
|
||||
try:
|
||||
if self.websocket is None or not self.is_alive:
|
||||
return False
|
||||
await self.websocket.send_json(data)
|
||||
await asyncio.sleep(0.001) # This flushes the data to the client
|
||||
self.last_active_ts = int(time.time())
|
||||
return True
|
||||
except WebSocketDisconnect:
|
||||
|
||||
@@ -28,7 +28,9 @@ class LocalFileStore(FileStore):
|
||||
|
||||
def list(self, path: str) -> list[str]:
|
||||
full_path = self.get_full_path(path)
|
||||
return [os.path.join(path, f) for f in os.listdir(full_path)]
|
||||
files = [os.path.join(path, f) for f in os.listdir(full_path)]
|
||||
files = [f + '/' if os.path.isdir(self.get_full_path(f)) else f for f in files]
|
||||
return files
|
||||
|
||||
def delete(self, path: str) -> None:
|
||||
full_path = self.get_full_path(path)
|
||||
|
||||
@@ -30,6 +30,8 @@ class InMemoryFileStore(FileStore):
|
||||
files.append(file)
|
||||
else:
|
||||
dir_path = os.path.join(path, parts[0])
|
||||
if not dir_path.endswith('/'):
|
||||
dir_path += '/'
|
||||
if dir_path not in files:
|
||||
files.append(dir_path)
|
||||
return files
|
||||
|
||||
@@ -54,10 +54,8 @@ def test_deep_list(setup_env):
|
||||
store.write('foo/bar/baz.txt', 'Hello, world!')
|
||||
store.write('foo/bar/qux.txt', 'Hello, world!')
|
||||
store.write('foo/bar/quux.txt', 'Hello, world!')
|
||||
assert store.list('') == ['foo'], 'Expected foo, got {} for class {}'.format(
|
||||
store.list(''), store.__class__
|
||||
)
|
||||
assert store.list('foo') == ['foo/bar']
|
||||
assert store.list('') == ['foo/'], f'for class {store.__class__}'
|
||||
assert store.list('foo') == ['foo/bar/']
|
||||
assert (
|
||||
store.list('foo/bar').sort()
|
||||
== ['foo/bar/baz.txt', 'foo/bar/qux.txt', 'foo/bar/quux.txt'].sort()
|
||||
|
||||
Reference in New Issue
Block a user