Refactor session management (#1810)

* refactor session mgmt

* defer file handling to runtime

* add todo

* refactor sessions a bit more

* remove messages logic from FE

* fix up socket handshake

* refactor frontend auth a bit

* first pass at redoing file explorer

* implement directory suffix

* fix up file tree

* close agent on websocket close

* remove session saving

* move file refresh

* remove getWorkspace

* plumb path/code differently

* fix build issues

* fix the tests

* fix npm build

* add session rehydration

* fix event serialization

* logspam

* fix user message rehydration

* add get_event fn

* agent state restoration

* change history tracking for codeact

* fix responsiveness of init

* fix lint

* lint

* delint

* fix prop

* update tests

* logspam

* lint

* fix test

* revert codeact

* change fileService to use API

* fix up session loading

* delint

* delint

* fix integration tests

* revert test

* fix up access to options endpoints

* fix initial files load

* delint

* fix file initialization

* fix mock server

* fixl int

* fix auth for html

* Update frontend/src/i18n/translation.json

Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>

* refactor sessions and sockets

* avoid reinitializing the same session

* fix reconnect issue

* change up intro message

* more guards on reinit

* rename agent_session

* delint

* fix a bunch of tests

* delint

* fix last test

* remove code editor context

* fix build

* fix any

* fix dot notation

* Update frontend/src/services/api.ts

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>

* fix up error handling

* Update opendevin/server/session/agent.py

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>

* Update opendevin/server/session/agent.py

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>

* Update frontend/src/services/session.ts

Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>

* fix build errs

* fix else

* add closed state

* delint

* Update opendevin/server/session/session.py

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>

---------

Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Boxuan Li <liboxuan@connect.hku.hk>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
Robert Brennan
2024-05-22 14:33:16 -04:00
committed by GitHub
parent 37354dbc83
commit 5bdacf738d
55 changed files with 764 additions and 1337 deletions

View File

@@ -1,5 +1,5 @@
import { useDisclosure } from "@nextui-org/react";
import React, { useEffect, useState } from "react";
import React, { useEffect } from "react";
import { Toaster } from "react-hot-toast";
import CogTooth from "#/assets/cog-tooth";
import ChatInterface from "#/components/chat/ChatInterface";
@@ -8,15 +8,13 @@ import { Container, Orientation } from "#/components/Resizable";
import Workspace from "#/components/Workspace";
import LoadPreviousSessionModal from "#/components/modals/load-previous-session/LoadPreviousSessionModal";
import SettingsModal from "#/components/modals/settings/SettingsModal";
import { fetchMsgTotal } from "#/services/session";
import Socket from "#/services/socket";
import { ResFetchMsgTotal } from "#/types/ResponseType";
import "./App.css";
import AgentControlBar from "./components/AgentControlBar";
import AgentStatusBar from "./components/AgentStatusBar";
import Terminal from "./components/terminal/Terminal";
import { initializeAgent } from "./services/agent";
import { settingsAreUpToDate } from "./services/settings";
import Session from "#/services/session";
import { getToken } from "#/services/auth";
import { settingsAreUpToDate } from "#/services/settings";
interface Props {
setSettingOpen: (isOpen: boolean) => void;
@@ -43,8 +41,6 @@ function Controls({ setSettingOpen }: Props): JSX.Element {
let initOnce = false;
function App(): JSX.Element {
const [isWarned, setIsWarned] = useState(false);
const {
isOpen: settingsModalIsOpen,
onOpen: onSettingsModalOpen,
@@ -57,31 +53,18 @@ function App(): JSX.Element {
onOpenChange: onLoadPreviousSessionModalOpenChange,
} = useDisclosure();
const getMsgTotal = () => {
if (isWarned) return;
fetchMsgTotal()
.then((data: ResFetchMsgTotal) => {
if (data.msg_total > 0) {
onLoadPreviousSessionModalOpen();
setIsWarned(true);
}
})
.catch();
};
useEffect(() => {
if (initOnce) return;
initOnce = true;
if (!settingsAreUpToDate()) {
onSettingsModalOpen();
} else if (getToken()) {
onLoadPreviousSessionModalOpen();
} else {
initializeAgent();
Session.startNewSession();
}
Socket.registerCallback("open", [getMsgTotal]);
getMsgTotal();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

View File

@@ -1,9 +0,0 @@
export async function fetchModels() {
const response = await fetch(`/api/litellm-models`);
return response.json();
}
export async function fetchAgents() {
const response = await fetch(`/api/agents`);
return response.json();
}

View File

@@ -5,7 +5,6 @@ import ArrowIcon from "#/assets/arrow";
import PauseIcon from "#/assets/pause";
import PlayIcon from "#/assets/play";
import { changeAgentState } from "#/services/agentStateService";
import { clearMsgs } from "#/services/session";
import store, { RootState } from "#/store";
import AgentState from "#/types/AgentState";
import { clearMessages } from "#/state/chatSlice";
@@ -73,7 +72,6 @@ function AgentControlBar() {
}
if (action === AgentState.STOPPED) {
clearMsgs().then().catch();
store.dispatch(clearMessages());
} else {
setIsLoading(true);
@@ -86,7 +84,6 @@ function AgentControlBar() {
useEffect(() => {
if (curAgentState === desiredState) {
if (curAgentState === AgentState.STOPPED) {
clearMsgs().then().catch();
store.dispatch(clearMessages());
}
setIsLoading(false);

View File

@@ -1,33 +1,24 @@
import Editor, { Monaco } from "@monaco-editor/react";
import { Tab, Tabs } from "@nextui-org/react";
import type { editor } from "monaco-editor";
import React, { useMemo, useState } from "react";
import React, { useMemo } from "react";
import { useTranslation } from "react-i18next";
import { VscCode } from "react-icons/vsc";
import { useDispatch, useSelector } from "react-redux";
import { useSelector } from "react-redux";
import { I18nKey } from "#/i18n/declaration";
import { selectFile } from "#/services/fileService";
import { setCode } from "#/state/codeSlice";
import { RootState } from "#/store";
import FileExplorer from "./file-explorer/FileExplorer";
import { CodeEditorContext } from "./CodeEditorContext";
function CodeEditor(): JSX.Element {
const { t } = useTranslation();
const [selectedFileAbsolutePath, setSelectedFileAbsolutePath] = useState("");
const selectedFileName = useMemo(() => {
const paths = selectedFileAbsolutePath.split("/");
return paths[paths.length - 1];
}, [selectedFileAbsolutePath]);
const codeEditorContext = useMemo(
() => ({ selectedFileAbsolutePath }),
[selectedFileAbsolutePath],
);
const dispatch = useDispatch();
const code = useSelector((state: RootState) => state.code.code);
const activeFilepath = useSelector((state: RootState) => state.code.path);
const selectedFileName = useMemo(() => {
const paths = activeFilepath.split("/");
return paths[paths.length - 1];
}, [activeFilepath]);
const handleEditorDidMount = (
editor: editor.IStandaloneCodeEditor,
monaco: Monaco,
@@ -46,57 +37,44 @@ function CodeEditor(): JSX.Element {
monaco.editor.setTheme("my-theme");
};
const updateCode = async () => {
const newCode = await selectFile(activeFilepath);
setSelectedFileAbsolutePath(activeFilepath);
dispatch(setCode(newCode));
};
React.useEffect(() => {
// FIXME: we can probably move this out of the component and into state/service
if (activeFilepath) updateCode();
}, [activeFilepath]);
return (
<div className="flex h-full w-full bg-neutral-900 transition-all duration-500 ease-in-out">
<CodeEditorContext.Provider value={codeEditorContext}>
<FileExplorer />
<div className="flex flex-col min-h-0 w-full">
<Tabs
disableCursorAnimation
classNames={{
base: "border-b border-divider border-neutral-600 mb-4",
tabList:
"w-full relative rounded-none bg-neutral-900 p-0 border-divider",
cursor: "w-full bg-neutral-600 rounded-none",
tab: "max-w-fit px-4 h-[36px]",
tabContent: "group-data-[selected=true]:text-white",
}}
aria-label="Options"
>
<Tab
key={selectedFileName.toLocaleLowerCase()}
title={selectedFileName}
<FileExplorer />
<div className="flex flex-col min-h-0 w-full">
<Tabs
disableCursorAnimation
classNames={{
base: "border-b border-divider border-neutral-600 mb-4",
tabList:
"w-full relative rounded-none bg-neutral-900 p-0 border-divider",
cursor: "w-full bg-neutral-600 rounded-none",
tab: "max-w-fit px-4 h-[36px]",
tabContent: "group-data-[selected=true]:text-white",
}}
aria-label="Options"
>
<Tab
key={selectedFileName.toLocaleLowerCase()}
title={selectedFileName}
/>
</Tabs>
<div className="flex grow items-center justify-center">
{selectedFileName === "" ? (
<div className="flex flex-col items-center text-neutral-400">
<VscCode size={100} />
{t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)}
</div>
) : (
<Editor
height="100%"
path={selectedFileName.toLocaleLowerCase()}
defaultValue=""
value={code}
onMount={handleEditorDidMount}
/>
</Tabs>
<div className="flex grow items-center justify-center">
{selectedFileName === "" ? (
<div className="flex flex-col items-center text-neutral-400">
<VscCode size={100} />
{t(I18nKey.CODE_EDITOR$EMPTY_MESSAGE)}
</div>
) : (
<Editor
height="100%"
path={selectedFileName.toLocaleLowerCase()}
defaultValue=""
value={code}
onMount={handleEditorDidMount}
/>
)}
</div>
)}
</div>
</CodeEditorContext.Provider>
</div>
</div>
);
}

View File

@@ -1,5 +0,0 @@
import { createContext } from "react";
export const CodeEditorContext = createContext({
selectedFileAbsolutePath: "",
});

View File

@@ -5,7 +5,7 @@ import { act } from "react-dom/test-utils";
import userEvent from "@testing-library/user-event";
import { renderWithProviders } from "test-utils";
import ChatInterface from "./ChatInterface";
import Socket from "#/services/socket";
import Session from "#/services/session";
import ActionType from "#/types/ActionType";
import { addAssistantMessage } from "#/state/chatSlice";
import AgentState from "#/types/AgentState";
@@ -15,16 +15,17 @@ vi.mock("#/hooks/useTyping", () => ({
useTyping: vi.fn((text: string) => text),
}));
const socketSpy = vi.spyOn(Socket, "send");
const sessionSpy = vi.spyOn(Session, "send");
vi.spyOn(Session, "isConnected").mockImplementation(() => true);
// This is for the scrollview ref in Chat.tsx
// TODO: Move this into test setup
HTMLElement.prototype.scrollTo = vi.fn(() => {});
describe("ChatInterface", () => {
it("should render the messages and input", () => {
it("should render empty message list and input", () => {
renderWithProviders(<ChatInterface />);
expect(screen.queryAllByTestId("message")).toHaveLength(1); // initial welcome message only
expect(screen.queryAllByTestId("message")).toHaveLength(0);
});
it("should render the new message the user has typed", async () => {
@@ -65,7 +66,7 @@ describe("ChatInterface", () => {
expect(screen.getByText("Hello to you!")).toBeInTheDocument();
});
it("should send the a start event to the Socket", () => {
it("should send the a start event to the Session", () => {
renderWithProviders(<ChatInterface />, {
preloadedState: {
agent: {
@@ -83,10 +84,10 @@ describe("ChatInterface", () => {
action: ActionType.MESSAGE,
args: { content: "my message" },
};
expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event));
expect(sessionSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
it("should send the a user message event to the Socket", () => {
it("should send the a user message event to the Session", () => {
renderWithProviders(<ChatInterface />, {
preloadedState: {
agent: {
@@ -104,7 +105,7 @@ describe("ChatInterface", () => {
action: ActionType.MESSAGE,
args: { content: "my message" },
};
expect(socketSpy).toHaveBeenCalledWith(JSON.stringify(event));
expect(sessionSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
it("should disable the user input if agent is not initialized", () => {

View File

@@ -10,7 +10,7 @@ import Chat from "./Chat";
import { RootState } from "#/store";
import AgentState from "#/types/AgentState";
import { sendChatMessage } from "#/services/chatService";
import { addUserMessage } from "#/state/chatSlice";
import { addUserMessage, addAssistantMessage } from "#/state/chatSlice";
import { I18nKey } from "#/i18n/declaration";
import { useScrollToBottom } from "#/hooks/useScrollToBottom";
@@ -58,6 +58,12 @@ function ChatInterface() {
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
useScrollToBottom(scrollRef);
React.useEffect(() => {
if (curAgentState === AgentState.INIT && messages.length === 0) {
dispatch(addAssistantMessage(t(I18nKey.CHAT_INTERFACE$INITIAL_MESSAGE)));
}
}, [curAgentState]);
return (
<div className="flex flex-col h-full bg-neutral-800">
<div className="flex items-center gap-2 border-b border-neutral-600 text-sm px-4 py-2">

View File

@@ -7,8 +7,9 @@ import { describe, it, expect, vi, Mock } from "vitest";
import FileExplorer from "./FileExplorer";
import { uploadFiles, listFiles } from "#/services/fileService";
import toast from "#/utils/toast";
import AgentState from "#/types/AgentState";
const toastSpy = vi.spyOn(toast, "stickyError");
const toastSpy = vi.spyOn(toast, "error");
vi.mock("../../services/fileService", async () => ({
listFiles: vi.fn(async (path: string = "/") => {
@@ -42,7 +43,13 @@ describe("FileExplorer", () => {
it.todo("should render an empty workspace");
it.only("should refetch the workspace when clicking the refresh button", async () => {
const { getByText } = renderWithProviders(<FileExplorer />);
const { getByText } = renderWithProviders(<FileExplorer />, {
preloadedState: {
agent: {
curAgentState: AgentState.RUNNING,
},
},
});
await waitFor(() => {
expect(getByText("folder1")).toBeInTheDocument();
expect(getByText("file2.ts")).toBeInTheDocument();

View File

@@ -5,14 +5,16 @@ import {
IoIosRefresh,
IoIosCloudUpload,
} from "react-icons/io";
import { useDispatch } from "react-redux";
import { useDispatch, useSelector } from "react-redux";
import { IoFileTray } from "react-icons/io5";
import { twMerge } from "tailwind-merge";
import AgentState from "#/types/AgentState";
import { setRefreshID } from "#/state/codeSlice";
import { listFiles, uploadFiles } from "#/services/fileService";
import IconButton from "../IconButton";
import ExplorerTree from "./ExplorerTree";
import toast from "#/utils/toast";
import { RootState } from "#/store";
interface ExplorerActionsProps {
onRefresh: () => void;
@@ -87,6 +89,7 @@ function FileExplorer() {
const [isHidden, setIsHidden] = React.useState(false);
const [isDragging, setIsDragging] = React.useState(false);
const [files, setFiles] = React.useState<string[]>([]);
const { curAgentState } = useSelector((state: RootState) => state.agent);
const fileInputRef = React.useRef<HTMLInputElement | null>(null);
const dispatch = useDispatch();
@@ -95,6 +98,12 @@ function FileExplorer() {
};
const refreshWorkspace = async () => {
if (
curAgentState === AgentState.LOADING ||
curAgentState === AgentState.STOPPED
) {
return;
}
dispatch(setRefreshID(Math.random()));
setFiles(await listFiles("/"));
};
@@ -104,7 +113,7 @@ function FileExplorer() {
await uploadFiles(toAdd);
await refreshWorkspace();
} catch (error) {
toast.stickyError("ws", "Error uploading file");
toast.error("ws", "Error uploading file");
}
};
@@ -112,7 +121,9 @@ function FileExplorer() {
(async () => {
await refreshWorkspace();
})();
}, [curAgentState]);
React.useEffect(() => {
const enableDragging = () => {
setIsDragging(true);
};
@@ -130,6 +141,10 @@ function FileExplorer() {
};
}, []);
if (!files.length) {
return null;
}
return (
<div className="relative">
{isDragging && (

View File

@@ -4,9 +4,8 @@ import { twMerge } from "tailwind-merge";
import { RootState } from "#/store";
import FolderIcon from "../FolderIcon";
import FileIcon from "../FileIcons";
import { listFiles } from "#/services/fileService";
import { setActiveFilepath } from "#/state/codeSlice";
import { CodeEditorContext } from "../CodeEditorContext";
import { listFiles, selectFile } from "#/services/fileService";
import { setCode, setActiveFilepath } from "#/state/codeSlice";
interface TitleProps {
name: string;
@@ -36,8 +35,8 @@ interface TreeNodeProps {
function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
const [isOpen, setIsOpen] = React.useState(defaultOpen);
const [children, setChildren] = React.useState<string[] | null>(null);
const { selectedFileAbsolutePath } = React.useContext(CodeEditorContext);
const refreshID = useSelector((state: RootState) => state.code.refreshID);
const activeFilepath = useSelector((state: RootState) => state.code.path);
const dispatch = useDispatch();
@@ -60,10 +59,12 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
refreshChildren();
}, [refreshID, isOpen]);
const handleClick = () => {
const handleClick = async () => {
if (isDirectory) {
setIsOpen((prev) => !prev);
} else {
const newCode = await selectFile(path);
dispatch(setCode(newCode));
dispatch(setActiveFilepath(path));
}
};
@@ -72,7 +73,7 @@ function TreeNode({ path, defaultOpen = false }: TreeNodeProps) {
<div
className={twMerge(
"text-sm text-neutral-400",
path === selectedFileAbsolutePath ? "bg-gray-700" : "",
path === activeFilepath ? "bg-gray-700" : "",
)}
>
<Title

View File

@@ -2,54 +2,23 @@ import React from "react";
import { act, render, screen, waitFor } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import LoadPreviousSessionModal from "./LoadPreviousSessionModal";
import { clearMsgs, fetchMsgs } from "../../../services/session";
import { addChatMessageFromEvent } from "../../../services/chatService";
import { handleAssistantMessage } from "../../../services/actions";
import toast from "../../../utils/toast";
import Session from "../../../services/session";
const RESUME_SESSION_BUTTON_LABEL_KEY =
"LOAD_SESSION$RESUME_SESSION_MODAL_ACTION_LABEL";
const START_NEW_SESSION_BUTTON_LABEL_KEY =
"LOAD_SESSION$START_NEW_SESSION_MODAL_ACTION_LABEL";
const mocks = vi.hoisted(() => ({
fetchMsgsMock: vi.fn(),
}));
vi.mock("../../../services/session", async (importOriginal) => ({
...(await importOriginal<typeof import("../../../services/session")>()),
clearMsgs: vi.fn(),
fetchMsgs: mocks.fetchMsgsMock.mockResolvedValue({
messages: [
{
id: "1",
role: "user",
payload: { type: "action" },
},
{
id: "2",
role: "assistant",
payload: { type: "observation" },
},
],
}),
}));
vi.mock("../../../services/chatService", async (importOriginal) => ({
...(await importOriginal<typeof import("../../../services/chatService")>()),
addChatMessageFromEvent: vi.fn(),
}));
vi.mock("../../../services/actions", async (importOriginal) => ({
...(await importOriginal<typeof import("../../../services/actions")>()),
handleAssistantMessage: vi.fn(),
}));
vi.mock("../../../utils/toast", () => ({
default: {
stickyError: vi.fn(),
},
}));
vi.spyOn(Session, "isConnected").mockImplementation(() => true);
const restoreOrStartNewSessionSpy = vi.spyOn(
Session,
"restoreOrStartNewSession",
);
describe("LoadPreviousSession", () => {
afterEach(() => {
@@ -75,7 +44,6 @@ describe("LoadPreviousSession", () => {
userEvent.click(startNewSessionButton);
});
expect(clearMsgs).toHaveBeenCalledTimes(1);
// modal should close right after clearing messages
expect(onOpenChangeMock).toHaveBeenCalledWith(false);
});
@@ -93,36 +61,9 @@ describe("LoadPreviousSession", () => {
});
await waitFor(() => {
expect(fetchMsgs).toHaveBeenCalledTimes(1);
expect(addChatMessageFromEvent).toHaveBeenCalledTimes(1);
expect(handleAssistantMessage).toHaveBeenCalledTimes(1);
expect(restoreOrStartNewSessionSpy).toHaveBeenCalledTimes(1);
});
// modal should close right after fetching messages
expect(onOpenChangeMock).toHaveBeenCalledWith(false);
});
it("should show an error toast if there is an error fetching the session", async () => {
mocks.fetchMsgsMock.mockRejectedValue(new Error("Get messages failed."));
render(<LoadPreviousSessionModal isOpen onOpenChange={vi.fn} />);
const resumeSessionButton = screen.getByRole("button", {
name: RESUME_SESSION_BUTTON_LABEL_KEY,
});
act(() => {
userEvent.click(resumeSessionButton);
});
await waitFor(async () => {
await expect(() => fetchMsgs()).rejects.toThrow();
expect(handleAssistantMessage).not.toHaveBeenCalled();
expect(addChatMessageFromEvent).not.toHaveBeenCalled();
// error toast should be shown
expect(toast.stickyError).toHaveBeenCalledWith(
"ws",
"Error fetching the session",
);
});
});
});

View File

@@ -1,11 +1,8 @@
import React from "react";
import { useTranslation } from "react-i18next";
import { I18nKey } from "#/i18n/declaration";
import { handleAssistantMessage } from "#/services/actions";
import { addChatMessageFromEvent } from "#/services/chatService";
import { clearMsgs, fetchMsgs } from "#/services/session";
import toast from "#/utils/toast";
import BaseModal from "../base-modal/BaseModal";
import Session from "#/services/session";
interface LoadPreviousSessionModalProps {
isOpen: boolean;
@@ -18,28 +15,6 @@ function LoadPreviousSessionModal({
}: LoadPreviousSessionModalProps) {
const { t } = useTranslation();
const onStartNewSession = async () => {
await clearMsgs();
};
const onResumeSession = async () => {
try {
const { messages } = await fetchMsgs();
messages.forEach((message) => {
if (message.role === "user") {
addChatMessageFromEvent(message.payload);
}
if (message.role === "assistant") {
handleAssistantMessage(message.payload);
}
});
} catch (error) {
toast.stickyError("ws", "Error fetching the session");
}
};
return (
<BaseModal
isOpen={isOpen}
@@ -50,13 +25,13 @@ function LoadPreviousSessionModal({
{
label: t(I18nKey.LOAD_SESSION$RESUME_SESSION_MODAL_ACTION_LABEL),
className: "bg-primary rounded-lg",
action: onResumeSession,
action: Session.restoreOrStartNewSession,
closeAfterAction: true,
},
{
label: t(I18nKey.LOAD_SESSION$START_NEW_SESSION_MODAL_ACTION_LABEL),
className: "bg-neutral-500 rounded-lg",
action: onStartNewSession,
action: Session.startNewSession,
closeAfterAction: true,
},
]}

View File

@@ -11,12 +11,14 @@ import {
saveSettings,
getDefaultSettings,
} from "#/services/settings";
import { initializeAgent } from "#/services/agent";
import { fetchAgents, fetchModels } from "#/api";
import Session from "#/services/session";
import { fetchAgents, fetchModels } from "#/services/options";
import SettingsModal from "./SettingsModal";
const toastSpy = vi.spyOn(toast, "settingsChanged");
const i18nSpy = vi.spyOn(i18next, "changeLanguage");
const startNewSessionSpy = vi.spyOn(Session, "startNewSession");
vi.spyOn(Session, "isConnected").mockImplementation(() => true);
vi.mock("#/services/settings", async (importOriginal) => ({
...(await importOriginal<typeof import("#/services/settings")>()),
@@ -35,12 +37,8 @@ vi.mock("#/services/settings", async (importOriginal) => ({
saveSettings: vi.fn(),
}));
vi.mock("#/services/agent", async () => ({
initializeAgent: vi.fn(),
}));
vi.mock("#/api", async (importOriginal) => ({
...(await importOriginal<typeof import("#/api")>()),
vi.mock("#/services/options", async (importOriginal) => ({
...(await importOriginal<typeof import("#/services/options")>()),
fetchModels: vi
.fn()
.mockResolvedValue(Promise.resolve(["model1", "model2", "model3"])),
@@ -162,7 +160,7 @@ describe("SettingsModal", () => {
userEvent.click(saveButton);
});
expect(initializeAgent).toHaveBeenCalled();
expect(startNewSessionSpy).toHaveBeenCalled();
});
it("should display a toast for every change", async () => {

View File

@@ -3,10 +3,10 @@ import i18next from "i18next";
import React, { useEffect } from "react";
import { useTranslation } from "react-i18next";
import { useSelector } from "react-redux";
import { fetchAgents, fetchModels } from "#/api";
import { fetchAgents, fetchModels } from "#/services/options";
import { AvailableLanguages } from "#/i18n";
import { I18nKey } from "#/i18n/declaration";
import { initializeAgent } from "#/services/agent";
import Session from "#/services/session";
import { RootState } from "../../../store";
import AgentState from "../../../types/AgentState";
import {
@@ -100,7 +100,7 @@ function SettingsModal({ isOpen, onOpenChange }: SettingsProps) {
const updatedSettings = getSettingsDifference(settings);
saveSettings(settings);
i18next.changeLanguage(settings.LANGUAGE);
initializeAgent(); // reinitialize the agent with the new settings
Session.startNewSession();
const sensitiveKeys = ["LLM_API_KEY"];

View File

@@ -266,12 +266,12 @@
"en": "Please stop the agent before editing these settings."
},
"LOAD_SESSION$MODAL_TITLE": {
"en": "Unfinished Session Detected",
"zh-CN": "检测到有未完成的会话",
"zh-TW": "偵測到未完成的會話"
"en": "Return to existing session?",
"zh-CN": "是否继续未完成的会话?",
"zh-TW": "是否繼續未完成的會話?"
},
"LOAD_SESSION$MODAL_CONTENT": {
"en": "You seem to have an unfinished task. Would you like to pick up where you left off or start fresh?",
"en": "You seem to have an ongoing session. Would you like to pick up where you left off, or start fresh?",
"zh-CN": "您似乎有一个未完成的任务。您想继续之前的工作还是重新开始?",
"zh-TW": "您似乎有一個未完成的任務。您想從上次離開的地方繼續還是重新開始?"
},

View File

@@ -1,29 +0,0 @@
import { describe, expect, it, vi } from "vitest";
import ActionType from "#/types/ActionType";
import { initializeAgent } from "./agent";
import { Settings, saveSettings } from "./settings";
import Socket from "./socket";
const sendSpy = vi.spyOn(Socket, "send");
describe("initializeAgent", () => {
it("Should initialize the agent with the current settings", () => {
const settings: Settings = {
LLM_MODEL: "llm_value",
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "sk-...",
};
const event = {
action: ActionType.INIT,
args: settings,
};
saveSettings(settings);
initializeAgent();
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
});

View File

@@ -1,14 +0,0 @@
import ActionType from "#/types/ActionType";
import { getSettings } from "./settings";
import Socket from "./socket";
/**
* Initialize the agent with the current settings.
* @param settings - The new settings.
*/
export const initializeAgent = () => {
const settings = getSettings();
const event = { action: ActionType.INIT, args: settings };
const eventString = JSON.stringify(event);
Socket.send(eventString);
};

View File

@@ -1,7 +1,6 @@
import ActionType from "#/types/ActionType";
import AgentState from "#/types/AgentState";
import Socket from "./socket";
import { initializeAgent } from "./agent";
import Session from "./session";
const INIT_DELAY = 1000;
@@ -10,10 +9,10 @@ export function changeAgentState(state: AgentState): void {
action: ActionType.CHANGE_AGENT_STATE,
args: { agent_state: state },
});
Socket.send(eventString);
Session.send(eventString);
if (state === AgentState.STOPPED) {
setTimeout(() => {
initializeAgent();
Session.startNewSession();
}, INIT_DELAY);
}
}

View File

@@ -0,0 +1,58 @@
import { getToken } from "./auth";
import toast from "#/utils/toast";
const WAIT_FOR_AUTH_DELAY_MS = 500;
export async function request(
url: string,
optionsIn: RequestInit = {},
disableToast: boolean = false,
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
): Promise<any> {
const options = JSON.parse(JSON.stringify(optionsIn));
const onFail = (msg: string) => {
if (!disableToast) {
toast.error("api", msg);
}
throw new Error(msg);
};
const needsAuth = !url.startsWith("/api/options/");
const token = getToken();
if (!token && needsAuth) {
return new Promise((resolve) => {
setTimeout(() => {
resolve(request(url, optionsIn, disableToast));
}, WAIT_FOR_AUTH_DELAY_MS);
});
}
if (token) {
options.headers = {
...(options.headers || {}),
Authorization: `Bearer ${token}`,
};
}
let response = null;
try {
response = await fetch(url, options);
} catch (e) {
onFail(`Error fetching ${url}`);
}
if (response?.status && response?.status >= 400) {
onFail(
`${response.status} error while fetching ${url}: ${response?.statusText}`,
);
}
if (!response?.ok) {
onFail(`Error fetching ${url}: ${response?.statusText}`);
}
try {
return await (response && response.json());
} catch (e) {
onFail(`Error parsing JSON from ${url}`);
}
return null;
}

View File

@@ -1,18 +1,5 @@
import * as jose from "jose";
import type { Mock } from "vitest";
import { fetchToken, validateToken, getToken } from "./auth";
vi.mock("jose", () => ({
decodeJwt: vi.fn(),
}));
// SUGGESTION: Prefer using msw for mocking requests (see https://mswjs.io/)
global.fetch = vi.fn(() =>
Promise.resolve({
status: 200,
json: () => Promise.resolve({ token: "newToken" }),
}),
) as Mock;
import { getToken } from "./auth";
Storage.prototype.getItem = vi.fn();
Storage.prototype.setItem = vi.fn();
@@ -22,66 +9,12 @@ describe("Auth Service", () => {
vi.clearAllMocks();
});
describe("fetchToken", () => {
it("should fetch and return a token", async () => {
const data = await fetchToken();
expect(localStorage.getItem).toHaveBeenCalledWith("token"); // Used to set Authorization header
expect(data).toEqual({ token: "newToken" });
expect(fetch).toHaveBeenCalledWith(`/api/auth`, {
headers: expect.any(Headers),
});
});
it("throws an error if response status is not 200", async () => {
(fetch as Mock).mockImplementationOnce(() =>
Promise.resolve({ status: 401 }),
);
await expect(fetchToken()).rejects.toThrow("Get token failed.");
});
});
describe("validateToken", () => {
it("returns true for a valid token", () => {
(jose.decodeJwt as Mock).mockReturnValue({ sid: "123" });
expect(validateToken("validToken")).toBe(true);
});
it("returns false for an invalid token", () => {
(jose.decodeJwt as Mock).mockReturnValue({});
expect(validateToken("invalidToken")).toBe(false);
});
it("returns false when decodeJwt throws", () => {
(jose.decodeJwt as Mock).mockImplementation(() => {
throw new Error("Invalid token");
});
expect(validateToken("badToken")).toBe(false);
});
});
describe("getToken", () => {
it("returns existing valid token from localStorage", async () => {
(jose.decodeJwt as Mock).mockReturnValue({ sid: "123" });
(Storage.prototype.getItem as Mock).mockReturnValue("existingToken");
const token = await getToken();
expect(token).toBe("existingToken");
});
it("fetches, validates, and stores a new token when existing token is invalid", async () => {
(jose.decodeJwt as Mock)
.mockReturnValueOnce({})
.mockReturnValueOnce({ sid: "123" });
const token = await getToken();
expect(token).toBe("newToken");
expect(localStorage.setItem).toHaveBeenCalledWith("token", "newToken");
});
it("throws an error when fetched token is invalid", async () => {
(jose.decodeJwt as Mock).mockReturnValue({});
await expect(getToken()).rejects.toThrow("Token validation failed.");
it("should fetch and return a token", async () => {
(Storage.prototype.getItem as Mock).mockReturnValue("newToken");
const data = await getToken();
expect(localStorage.getItem).toHaveBeenCalledWith("token"); // Used to set Authorization header
expect(data).toEqual("newToken");
});
});
});

View File

@@ -1,44 +1,13 @@
import * as jose from "jose";
import { ResFetchToken } from "#/types/ResponseType";
const TOKEN_KEY = "token";
const fetchToken = async (): Promise<ResFetchToken> => {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const response = await fetch(`/api/auth`, { headers });
if (response.status !== 200) {
throw new Error("Get token failed.");
}
const data: ResFetchToken = await response.json();
return data;
const getToken = (): string => localStorage.getItem(TOKEN_KEY) ?? "";
const clearToken = (): void => {
localStorage.removeItem(TOKEN_KEY);
};
export const validateToken = (token: string): boolean => {
try {
const claims = jose.decodeJwt(token);
return !(claims.sid === undefined || claims.sid === "");
} catch (error) {
return false;
}
const setToken = (token: string): void => {
localStorage.setItem(TOKEN_KEY, token);
};
const getToken = async (): Promise<string> => {
const token = localStorage.getItem("token") ?? "";
if (validateToken(token)) {
return token;
}
const data = await fetchToken();
if (data.token === undefined || data.token === "") {
throw new Error("Get token failed.");
}
const newToken = data.token;
if (validateToken(newToken)) {
localStorage.setItem("token", newToken);
return newToken;
}
throw new Error("Token validation failed.");
};
export { getToken, fetchToken };
export { getToken, setToken, clearToken };

View File

@@ -1,28 +1,8 @@
import store from "#/store";
import ActionType from "#/types/ActionType";
import { SocketMessage } from "#/types/ResponseType";
import { ActionMessage } from "#/types/Message";
import Socket from "./socket";
import { addUserMessage } from "#/state/chatSlice";
import Session from "./session";
export function sendChatMessage(message: string): void {
const event = { action: ActionType.MESSAGE, args: { content: message } };
const eventString = JSON.stringify(event);
Socket.send(eventString);
}
export function addChatMessageFromEvent(event: string | SocketMessage): void {
try {
let data: ActionMessage;
if (typeof event === "string") {
data = JSON.parse(event);
} else {
data = event as ActionMessage;
}
if (data && data.args && data.args.task) {
store.dispatch(addUserMessage(data.args.task));
}
} catch (error) {
//
}
Session.send(eventString);
}

View File

@@ -1,9 +1,7 @@
import { request } from "./api";
export async function selectFile(file: string): Promise<string> {
const res = await fetch(`/api/select-file?file=${file}`);
const data = await res.json();
if (res.status !== 200) {
throw new Error(data.error);
}
const data = await request(`/api/select-file?file=${file}`);
return data.code as string;
}
@@ -13,20 +11,13 @@ export async function uploadFiles(files: FileList) {
formData.append("files", files[i]);
}
const res = await fetch("/api/upload-files", {
await request("/api/upload-files", {
method: "POST",
body: formData,
});
const data = await res.json();
if (res.status !== 200) {
throw new Error(data.error || "Failed to upload files.");
}
}
export async function listFiles(path: string = "/"): Promise<string[]> {
const res = await fetch(`/api/list-files?path=${path}`);
const data = await res.json();
const data = await request(`/api/list-files?path=${path}`);
return data as string[];
}

View File

@@ -0,0 +1,9 @@
import { request } from "./api";
export async function fetchModels() {
return request(`/api/options/models`);
}
export async function fetchAgents() {
return request(`/api/options/agents`);
}

View File

@@ -1,127 +1,36 @@
import type { Mock } from "vitest";
import {
ResDelMsg,
ResFetchMsg,
ResFetchMsgTotal,
ResFetchMsgs,
} from "../types/ResponseType";
import { clearMsgs, fetchMsgTotal, fetchMsgs } from "./session";
import { describe, expect, it, vi } from "vitest";
// SUGGESTION: Prefer using msw for mocking requests (see https://mswjs.io/)
global.fetch = vi.fn();
Storage.prototype.getItem = vi.fn();
import ActionType from "#/types/ActionType";
import { Settings, saveSettings } from "./settings";
import Session from "./session";
describe("Session Service", () => {
beforeEach(() => {
vi.clearAllMocks();
const sendSpy = vi.spyOn(Session, "send");
const setupSpy = vi
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
.spyOn(Session as any, "_setupSocket")
.mockImplementation(() => {
/* eslint-disable-next-line @typescript-eslint/dot-notation */
Session["_initializeAgent"](); // use key syntax to fix complaint about private fn
});
afterEach(() => {
// Used to set Authorization header
expect(localStorage.getItem).toHaveBeenCalledWith("token");
});
describe("startNewSession", () => {
it("Should start a new session with the current settings", () => {
const settings: Settings = {
LLM_MODEL: "llm_value",
AGENT: "agent_value",
LANGUAGE: "language_value",
LLM_API_KEY: "sk-...",
};
describe("fetchMsgTotal", () => {
it("should fetch and return message total", async () => {
const expectedResult: ResFetchMsgTotal = {
msg_total: 10,
};
const event = {
action: ActionType.INIT,
args: settings,
};
(fetch as Mock).mockImplementationOnce(() =>
Promise.resolve({
status: 200,
json: () => Promise.resolve(expectedResult),
}),
);
saveSettings(settings);
Session.startNewSession();
const data = await fetchMsgTotal();
expect(fetch).toHaveBeenCalledWith(`/api/messages/total`, {
headers: expect.any(Headers),
});
expect(data).toEqual(expectedResult);
});
it("throws an error if response status is not 200", async () => {
// NOTE: The current implementation ONLY handles 200 status;
// this means throwing even with a status of 201, 204, etc.
(fetch as Mock).mockImplementationOnce(() =>
Promise.resolve({ status: 401 }),
);
await expect(fetchMsgTotal()).rejects.toThrow(
"Get message total failed.",
);
});
});
describe("fetchMsgs", () => {
it("should fetch and return messages", async () => {
const expectedResult: ResFetchMsgs = {
messages: [
{
id: "1",
role: "user",
payload: {} as ResFetchMsg["payload"],
},
],
};
(fetch as Mock).mockImplementationOnce(() =>
Promise.resolve({
status: 200,
json: () => Promise.resolve(expectedResult),
}),
);
const data = await fetchMsgs();
expect(fetch).toHaveBeenCalledWith(`/api/messages`, {
headers: expect.any(Headers),
});
expect(data).toEqual(expectedResult);
});
it("throws an error if response status is not 200", async () => {
(fetch as Mock).mockImplementationOnce(() =>
Promise.resolve({ status: 401 }),
);
await expect(fetchMsgs()).rejects.toThrow("Get messages failed.");
});
});
describe("clearMsgs", () => {
it("should clear messages", async () => {
const expectedResult: ResDelMsg = {
ok: "true",
};
(fetch as Mock).mockImplementationOnce(() =>
Promise.resolve({
status: 200,
json: () => Promise.resolve(expectedResult),
}),
);
const data = await clearMsgs();
expect(fetch).toHaveBeenCalledWith(`/api/messages`, {
method: "DELETE",
headers: expect.any(Headers),
});
expect(data).toEqual(expectedResult);
});
it("throws an error if response status is not 200", async () => {
(fetch as Mock).mockImplementationOnce(() =>
Promise.resolve({ status: 401 }),
);
await expect(clearMsgs()).rejects.toThrow("Delete messages failed.");
});
expect(setupSpy).toHaveBeenCalledTimes(1);
expect(sendSpy).toHaveBeenCalledWith(JSON.stringify(event));
});
});

View File

@@ -1,49 +1,165 @@
import {
ResDelMsg,
ResFetchMsgs,
ResFetchMsgTotal,
} from "../types/ResponseType";
import toast from "#/utils/toast";
import { handleAssistantMessage } from "./actions";
import { getToken, setToken, clearToken } from "./auth";
import ActionType from "#/types/ActionType";
import { getSettings } from "./settings";
const fetchMsgTotal = async (): Promise<ResFetchMsgTotal> => {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const response = await fetch(`/api/messages/total`, { headers });
if (response.status !== 200) {
throw new Error("Get message total failed.");
class Session {
private static _socket: WebSocket | null = null;
// callbacks contain a list of callable functions
// event: function, like:
// open: [function1, function2]
// message: [function1, function2]
private static callbacks: {
[K in keyof WebSocketEventMap]: ((data: WebSocketEventMap[K]) => void)[];
} = {
open: [],
message: [],
error: [],
close: [],
};
private static _connecting = false;
private static _disconnecting = false;
public static restoreOrStartNewSession() {
const token = getToken();
if (Session.isConnected()) {
Session.disconnect();
}
Session._connect(token);
}
const data: ResFetchMsgTotal = await response.json();
return data;
};
const fetchMsgs = async (): Promise<ResFetchMsgs> => {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const response = await fetch(`/api/messages`, { headers });
if (response.status !== 200) {
throw new Error("Get messages failed.");
public static startNewSession() {
clearToken();
Session.restoreOrStartNewSession();
}
const data: ResFetchMsgs = await response.json();
return data;
};
const clearMsgs = async (): Promise<ResDelMsg> => {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const response = await fetch(`/api/messages`, {
method: "DELETE",
headers,
});
if (response.status !== 200) {
throw new Error("Delete messages failed.");
private static _initializeAgent = () => {
const settings = getSettings();
const event = { action: ActionType.INIT, args: settings };
const eventString = JSON.stringify(event);
Session.send(eventString);
};
private static _connect(token: string = ""): void {
if (Session.isConnected()) return;
Session._connecting = true;
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`;
Session._socket = new WebSocket(WS_URL);
Session._setupSocket();
}
const data: ResDelMsg = await response.json();
return data;
};
export { fetchMsgTotal, fetchMsgs, clearMsgs };
private static _setupSocket(): void {
if (!Session._socket) {
throw new Error("Socket is not initialized.");
}
Session._socket.onopen = (e) => {
toast.success("ws", "Connected to server.");
Session._connecting = false;
Session._initializeAgent();
Session.callbacks.open?.forEach((callback) => {
callback(e);
});
};
Session._socket.onmessage = (e) => {
let data = null;
try {
data = JSON.parse(e.data);
} catch (err) {
// TODO: report the error
console.error("Error parsing JSON data", err);
return;
}
if (data.error && data.error_code === 401) {
clearToken();
} else if (data.token) {
setToken(data.token);
} else {
handleAssistantMessage(data);
}
};
Session._socket.onerror = () => {
const msg = "Connection failed. Retry...";
toast.error("ws", msg);
};
Session._socket.onclose = () => {
if (!Session._disconnecting) {
setTimeout(() => {
Session.restoreOrStartNewSession();
}, 3000); // Reconnect after 3 seconds
}
Session._disconnecting = false;
};
}
static isConnected(): boolean {
return (
Session._socket !== null && Session._socket.readyState === WebSocket.OPEN
);
}
static disconnect(): void {
Session._disconnecting = true;
if (Session._socket) {
Session._socket.close();
}
Session._socket = null;
}
static send(message: string): void {
if (Session._connecting) {
setTimeout(() => Session.send(message), 1000);
return;
}
if (!Session.isConnected()) {
throw new Error("Not connected to server.");
}
if (Session.isConnected()) {
Session._socket?.send(message);
} else {
const msg = "Connection failed. Retry...";
toast.error("ws", msg);
}
}
static addEventListener(
event: string,
callback: (e: MessageEvent) => void,
): void {
Session._socket?.addEventListener(
event as keyof WebSocketEventMap,
callback as (
this: WebSocket,
ev: WebSocketEventMap[keyof WebSocketEventMap],
) => never,
);
}
static removeEventListener(
event: string,
listener: (e: Event) => void,
): void {
Session._socket?.removeEventListener(event, listener);
}
static registerCallback<K extends keyof WebSocketEventMap>(
event: K,
callbacks: ((data: WebSocketEventMap[K]) => void)[],
): void {
if (Session.callbacks[event] === undefined) {
return;
}
Session.callbacks[event].push(...callbacks);
}
}
export default Session;

View File

@@ -1,129 +0,0 @@
// import { toast } from "sonner";
import toast from "#/utils/toast";
import { handleAssistantMessage } from "./actions";
import { getToken } from "./auth";
class Socket {
private static _socket: WebSocket | null = null;
// callbacks contain a list of callable functions
// event: function, like:
// open: [function1, function2]
// message: [function1, function2]
private static callbacks: {
[K in keyof WebSocketEventMap]: ((data: WebSocketEventMap[K]) => void)[];
} = {
open: [],
message: [],
error: [],
close: [],
};
private static initializing = false;
public static tryInitialize(): void {
if (Socket.initializing) return;
Socket.initializing = true;
getToken()
.then((token) => {
Socket._initialize(token);
})
.catch(() => {
const msg = `Connection failed. Retry...`;
toast.stickyError("ws", msg);
setTimeout(() => {
this.tryInitialize();
}, 1500);
});
}
private static _initialize(token: string): void {
if (Socket.isConnected()) return;
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`;
Socket._socket = new WebSocket(WS_URL);
Socket._socket.onopen = (e) => {
toast.stickySuccess("ws", "Connected to server.");
Socket.initializing = false;
Socket.callbacks.open?.forEach((callback) => {
callback(e);
});
};
Socket._socket.onmessage = (e) => {
handleAssistantMessage(e.data);
};
Socket._socket.onerror = () => {
const msg = "Connection failed. Retry...";
toast.stickyError("ws", msg);
};
Socket._socket.onclose = () => {
// Reconnect after a delay
setTimeout(() => {
Socket.tryInitialize();
}, 3000); // Reconnect after 3 seconds
};
}
static isConnected(): boolean {
return (
Socket._socket !== null && Socket._socket.readyState === WebSocket.OPEN
);
}
static send(message: string): void {
if (!Socket.isConnected()) {
Socket.tryInitialize();
}
if (Socket.initializing) {
setTimeout(() => Socket.send(message), 1000);
return;
}
if (Socket.isConnected()) {
Socket._socket?.send(message);
} else {
const msg = "Connection failed. Retry...";
toast.stickyError("ws", msg);
}
}
static addEventListener(
event: string,
callback: (e: MessageEvent) => void,
): void {
Socket._socket?.addEventListener(
event as keyof WebSocketEventMap,
callback as (
this: WebSocket,
ev: WebSocketEventMap[keyof WebSocketEventMap],
) => never,
);
}
static removeEventListener(
event: string,
listener: (e: Event) => void,
): void {
Socket._socket?.removeEventListener(event, listener);
}
static registerCallback<K extends keyof WebSocketEventMap>(
event: K,
callbacks: ((data: WebSocketEventMap[K]) => void)[],
): void {
if (Socket.callbacks[event] === undefined) {
return;
}
Socket.callbacks[event].push(...callbacks);
}
}
Socket.tryInitialize();
export default Socket;

View File

@@ -1,3 +1,5 @@
import { request } from "./api";
export type Task = {
id: string;
goal: string;
@@ -14,14 +16,6 @@ export enum TaskState {
}
export async function getRootTask(): Promise<Task | undefined> {
const headers = new Headers({
"Content-Type": "application/json",
Authorization: `Bearer ${localStorage.getItem("token")}`,
});
const res = await fetch("/api/root_task", { headers });
if (res.status !== 200 && res.status !== 204) {
return undefined;
}
const data = (await res.json()) as Task;
return data;
const res = await request("/api/root_task");
return res as Task;
}

View File

@@ -3,13 +3,7 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";
type SliceState = { messages: Message[] };
const initialState: SliceState = {
messages: [
{
content:
"Hi! I'm OpenDevin, an AI Software Engineer. What would you like to build with me today?",
sender: "assistant",
},
],
messages: [],
};
export const chatSlice = createSlice({

View File

@@ -1,41 +1,5 @@
import { ActionMessage, ObservationMessage } from "./Message";
type Role = "user" | "assistant";
interface ResConfigurations {
[key: string]: string | boolean | number;
}
interface ResFetchToken {
token: string;
}
interface ResFetchMsgTotal {
msg_total: number;
}
interface ResFetchMsg {
id: string;
role: Role;
payload: SocketMessage;
}
interface ResFetchMsgs {
messages: ResFetchMsg[];
}
interface ResDelMsg {
ok: string;
}
type SocketMessage = ActionMessage | ObservationMessage;
export {
type ResConfigurations,
type ResFetchToken,
type ResFetchMsgTotal,
type ResFetchMsg,
type ResFetchMsgs,
type ResDelMsg,
type SocketMessage,
};
export { type SocketMessage };

View File

@@ -3,15 +3,10 @@ import toast from "react-hot-toast";
const idMap = new Map<string, string>();
export default {
stickyError: (id: string, msg: string) => {
error: (id: string, msg: string) => {
if (idMap.has(id)) return; // prevent duplicate toast
const toastId = toast.loading(msg, {
// icon: "👏",
// style: {
// borderRadius: "10px",
// background: "#333",
// color: "#fff",
// },
const toastId = toast(msg, {
duration: 4000,
style: {
background: "#ef4444",
color: "#fff",
@@ -24,12 +19,13 @@ export default {
});
idMap.set(id, toastId);
},
stickySuccess: (id: string, msg: string) => {
success: (id: string, msg: string) => {
const toastId = idMap.get(id);
if (toastId === undefined) return;
if (toastId) {
toast.success(msg, {
id: toastId,
duration: 4000,
style: {
background: "#333",
color: "#fff",

View File

@@ -1 +0,0 @@
TROUBLESHOOTING_URL = 'https://opendevin.github.io/OpenDevin/modules/usage/troubleshooting'

View File

@@ -243,6 +243,9 @@ class AgentController:
def get_state(self):
return self.state
def set_state(self, state: State):
self.state = state
def _is_stuck(self):
# check if delegate stuck
if self.delegate and self.delegate._is_stuck():

View File

@@ -3,6 +3,7 @@ import logging
import os
import pathlib
import platform
import uuid
from dataclasses import dataclass, field, fields, is_dataclass
from types import UnionType
from typing import Any, ClassVar, get_args, get_origin
@@ -173,6 +174,7 @@ class AppConfig(metaclass=Singleton):
sandbox_user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
sandbox_timeout: int = 120
github_token: str | None = None
jwt_secret: str = uuid.uuid4().hex
debug: bool = False
enable_auto_lint: bool = (
False # once enabled, OpenDevin would lint files after editing

View File

@@ -0,0 +1,3 @@
TROUBLESHOOTING_URL = (
'https://opendevin.github.io/OpenDevin/modules/usage/troubleshooting'
)

View File

@@ -10,8 +10,8 @@ from glob import glob
import docker
from opendevin.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.config import config
from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.exceptions import SandboxInvalidBackgroundCommandError
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import CancellableStream

View File

@@ -12,8 +12,8 @@ from glob import glob
import docker
from pexpect import exceptions, pxssh
from opendevin.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.config import config
from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.exceptions import SandboxInvalidBackgroundCommandError
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import CancellableStream

View File

@@ -0,0 +1,18 @@
from opendevin.storage.files import FileStore
class E2BFileStore(FileStore):
def __init__(self, filesystem):
self.filesystem = filesystem
def write(self, path: str, contents: str) -> None:
self.filesystem.write(path, contents)
def read(self, path: str) -> str:
return self.filesystem.read(path)
def list(self, path: str) -> list[str]:
return self.filesystem.list(path)
def delete(self, path: str) -> None:
self.filesystem.delete(path)

View File

@@ -13,6 +13,7 @@ from opendevin.runtime import Sandbox
from opendevin.runtime.server.files import insert_lines, read_lines
from opendevin.runtime.server.runtime import ServerRuntime
from .filestore import E2BFileStore
from .sandbox import E2BSandbox
@@ -26,25 +27,25 @@ class E2BRuntime(ServerRuntime):
super().__init__(event_stream, sid, sandbox)
if not isinstance(self.sandbox, E2BSandbox):
raise ValueError('E2BRuntime requires an E2BSandbox')
self.filesystem = self.sandbox.filesystem
self.file_store = E2BFileStore(self.sandbox.filesystem)
async def read(self, action: FileReadAction) -> Observation:
content = self.filesystem.read(action.path)
content = self.file_store.read(action.path)
lines = read_lines(content.split('\n'), action.start, action.end)
code_view = ''.join(lines)
return FileReadObservation(code_view, path=action.path)
async def write(self, action: FileWriteAction) -> Observation:
if action.start == 0 and action.end == -1:
self.filesystem.write(action.path, action.content)
self.file_store.write(action.path, action.content)
return FileWriteObservation(content='', path=action.path)
files = self.filesystem.list(action.path)
files = self.file_store.list(action.path)
if action.path in files:
all_lines = self.filesystem.read(action.path)
all_lines = self.file_store.read(action.path).split('\n')
new_file = insert_lines(
action.content.split('\n'), all_lines, action.start, action.end
)
self.filesystem.write(action.path, ''.join(new_file))
self.file_store.write(action.path, ''.join(new_file))
return FileWriteObservation('', path=action.path)
else:
# FIXME: we should create a new file here

View File

@@ -31,6 +31,7 @@ from opendevin.runtime import (
)
from opendevin.runtime.browser.browser_env import BrowserEnv
from opendevin.runtime.plugins import PluginRequirement
from opendevin.storage import FileStore, InMemoryFileStore
def create_sandbox(sid: str = 'default', sandbox_type: str = 'exec') -> Sandbox:
@@ -55,6 +56,7 @@ class Runtime:
"""
sid: str
file_store: FileStore
def __init__(
self,
@@ -70,6 +72,7 @@ class Runtime:
self.sandbox = sandbox
self._is_external_sandbox = True
self.browser = BrowserEnv()
self.file_store = InMemoryFileStore()
self.event_stream = event_stream
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
self._bg_task = asyncio.create_task(self._start_background_observation_loop())

View File

@@ -1,3 +1,4 @@
from opendevin.core.config import config
from opendevin.events.action import (
AgentRecallAction,
BrowseInteractiveAction,
@@ -15,13 +16,25 @@ from opendevin.events.observation import (
NullObservation,
Observation,
)
from opendevin.events.stream import EventStream
from opendevin.runtime import Sandbox
from opendevin.runtime.runtime import Runtime
from opendevin.storage.local import LocalFileStore
from .browse import browse
from .files import read_file, write_file
class ServerRuntime(Runtime):
def __init__(
self,
event_stream: EventStream,
sid: str = 'default',
sandbox: Sandbox | None = None,
):
super().__init__(event_stream, sid, sandbox)
self.file_store = LocalFileStore(config.workspace_base)
async def run(self, action: CmdRunAction) -> Observation:
return self._run_command(action.command, background=action.background)
@@ -71,10 +84,12 @@ class ServerRuntime(Runtime):
return IPythonRunCellObservation(content=output, code=action.code)
async def read(self, action: FileReadAction) -> Observation:
# TODO: use self.file_store
working_dir = self.sandbox.get_working_directory()
return await read_file(action.path, working_dir, action.start, action.end)
async def write(self, action: FileWriteAction) -> Observation:
# TODO: use self.file_store
working_dir = self.sandbox.get_working_directory()
return await write_file(
action.path, working_dir, action.content, action.start, action.end

View File

@@ -1,5 +0,0 @@
from .manager import AgentManager
agent_manager = AgentManager()
__all__ = ['AgentManager', 'agent_manager']

View File

@@ -1,161 +0,0 @@
from typing import Optional
from agenthub.codeact_agent.codeact_agent import CodeActAgent
from opendevin.const.guide_url import TROUBLESHOOTING_URL
from opendevin.controller import AgentController
from opendevin.controller.agent import Agent
from opendevin.core.config import config
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import ActionType, AgentState, ConfigType
from opendevin.events.action import (
ChangeAgentStateAction,
NullAction,
)
from opendevin.events.event import Event
from opendevin.events.observation import (
NullObservation,
)
from opendevin.events.serialization.action import action_from_dict
from opendevin.events.serialization.event import event_to_dict
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
from opendevin.llm.llm import LLM
from opendevin.runtime import DockerSSHBox
from opendevin.runtime.e2b.runtime import E2BRuntime
from opendevin.runtime.runtime import Runtime
from opendevin.runtime.server.runtime import ServerRuntime
from opendevin.server.session import session_manager
class AgentUnit:
"""Represents a session with an agent.
Attributes:
controller: The AgentController instance for controlling the agent.
"""
sid: str
event_stream: EventStream
controller: Optional[AgentController] = None
runtime: Optional[Runtime] = None
def __init__(self, sid):
"""Initializes a new instance of the Session class."""
self.sid = sid
self.event_stream = EventStream(sid)
self.event_stream.subscribe(EventStreamSubscriber.SERVER, self.on_event)
if config.runtime == 'server':
logger.info('Using server runtime')
self.runtime = ServerRuntime(self.event_stream, sid)
elif config.runtime == 'e2b':
logger.info('Using E2B runtime')
self.runtime = E2BRuntime(self.event_stream, sid)
async def send_error(self, message):
"""Sends an error message to the client.
Args:
message: The error message to send.
"""
await session_manager.send_error(self.sid, message)
async def send_message(self, message):
"""Sends a message to the client.
Args:
message: The message to send.
"""
await session_manager.send_message(self.sid, message)
async def send(self, data):
"""Sends data to the client.
Args:
data: The data to send.
"""
await session_manager.send(self.sid, data)
async def dispatch(self, action: str | None, data: dict):
"""Dispatches actions to the agent from the client."""
if action is None:
await self.send_error('Invalid action')
return
if action == ActionType.INIT:
await self.create_controller(data)
await self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
)
return
action_dict = data.copy()
action_dict['action'] = action
action_obj = action_from_dict(action_dict)
await self.event_stream.add_event(action_obj, EventSource.USER)
async def create_controller(self, start_event: dict):
"""Creates an AgentController instance.
Args:
start_event: The start event data (optional).
"""
args = {
key: value
for key, value in start_event.get('args', {}).items()
if value != ''
} # remove empty values, prevent FE from sending empty strings
agent_cls = args.get(ConfigType.AGENT, config.agent.name)
model = args.get(ConfigType.LLM_MODEL, config.llm.model)
api_key = args.get(ConfigType.LLM_API_KEY, config.llm.api_key)
api_base = config.llm.base_url
max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
max_chars = args.get(ConfigType.MAX_CHARS, config.llm.max_chars)
logger.info(f'Creating agent {agent_cls} using LLM {model}')
llm = LLM(model=model, api_key=api_key, base_url=api_base)
agent = Agent.get_cls(agent_cls)(llm)
if isinstance(agent, CodeActAgent):
if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox):
logger.warning(
'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
)
# Initializing plugins into the runtime
assert self.runtime is not None, 'Runtime is not initialized'
self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
if self.controller is not None:
await self.controller.close()
try:
self.controller = AgentController(
sid=self.sid,
event_stream=self.event_stream,
agent=agent,
max_iterations=int(max_iterations),
max_chars=int(max_chars),
)
except Exception as e:
logger.exception(f'Error creating controller: {e}')
await self.send_error(
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
)
return
async def on_event(self, event: Event):
"""Callback function for agent events.
Args:
event: The agent event (Observation or Action).
"""
if isinstance(event, NullAction):
return
if isinstance(event, NullObservation):
return
if event.source == 'agent' and not isinstance(
event, (NullAction, NullObservation)
):
await self.send(event_to_dict(event))
async def close(self):
if self.controller is not None:
await self.controller.close()
if self.runtime is not None:
self.runtime.close()

View File

@@ -1,48 +0,0 @@
import asyncio, atexit
from opendevin.core.logger import opendevin_logger as logger
from opendevin.server.session import session_manager
from .agent import AgentUnit
class AgentManager:
sid_to_agent: dict[str, 'AgentUnit'] = {}
def __init__(self):
atexit.register(self.close)
def register_agent(self, sid: str):
"""Registers a new agent.
Args:
sid: The session ID of the agent.
"""
if sid not in self.sid_to_agent:
self.sid_to_agent[sid] = AgentUnit(sid)
return
# TODO: confirm whether the agent is alive
async def dispatch(self, sid: str, action: str | None, data: dict):
"""Dispatches actions to the agent from the client."""
if sid not in self.sid_to_agent:
# self.register_agent(sid) # auto-register agent, may be opened later
logger.error(f'Agent not registered: {sid}')
await session_manager.send_error(sid, 'Agent not registered')
return
await self.sid_to_agent[sid].dispatch(action, data)
def close(self):
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._close())
async def _close(self):
logger.info(f'Closing {len(self.sid_to_agent)} agent(s)...')
for sid, agent in self.sid_to_agent.items():
await agent.close()

View File

@@ -1,12 +1,9 @@
import os
import jwt
from jwt.exceptions import InvalidTokenError
from opendevin.core.config import config
from opendevin.core.logger import opendevin_logger as logger
JWT_SECRET = os.getenv('JWT_SECRET', '5ecRe7')
def get_sid_from_token(token: str) -> str:
"""
@@ -20,7 +17,7 @@ def get_sid_from_token(token: str) -> str:
"""
try:
# Decode the JWT using the specified secret and algorithm
payload = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
payload = jwt.decode(token, config.jwt_secret, algorithms=['HS256'])
# Ensure the payload contains 'sid'
if 'sid' in payload:
@@ -41,4 +38,4 @@ def sign_token(payload: dict[str, object]) -> str:
# "sid": sid,
# # "exp": datetime.now(timezone.utc) + timedelta(minutes=15),
# }
return jwt.encode(payload, JWT_SECRET, algorithm='HS256')
return jwt.encode(payload, config.jwt_secret, algorithm='HS256')

View File

@@ -1,26 +1,25 @@
import os
import shutil
import uuid
import warnings
from pathlib import Path
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from fastapi import Depends, FastAPI, Request, Response, UploadFile, WebSocket, status
from fastapi import FastAPI, Request, Response, UploadFile, WebSocket, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import HTTPBearer
from fastapi.staticfiles import StaticFiles
import agenthub # noqa F401 (we import this to get the agents registered)
from opendevin.controller.agent import Agent
from opendevin.core.config import config
from opendevin.core.logger import opendevin_logger as logger
from opendevin.events.action import ChangeAgentStateAction, NullAction
from opendevin.events.observation import AgentStateChangedObservation, NullObservation
from opendevin.events.serialization import event_to_dict
from opendevin.llm import bedrock
from opendevin.server.agent import agent_manager
from opendevin.server.auth import get_sid_from_token, sign_token
from opendevin.server.session import message_stack, session_manager
from opendevin.server.session import session_manager
app = FastAPI()
app.add_middleware(
@@ -34,6 +33,45 @@ app.add_middleware(
security_scheme = HTTPBearer()
@app.middleware('http')
async def attach_session(request: Request, call_next):
if request.url.path.startswith('/api/options/') or not request.url.path.startswith(
'/api/'
):
response = await call_next(request)
return response
if not request.headers.get('Authorization'):
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Missing Authorization header'},
)
return response
auth_token = request.headers.get('Authorization')
if 'Bearer' in auth_token:
auth_token = auth_token.split('Bearer')[1].strip()
request.state.sid = get_sid_from_token(auth_token)
if request.state.sid == '':
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'Invalid token'},
)
return response
request.state.session = session_manager.get_session(request.state.sid)
if request.state.session is None:
response = JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Session not found'},
)
return response
response = await call_next(request)
return response
# This endpoint receives events from the client (i.e. the browser)
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
@@ -99,16 +137,41 @@ async def websocket_endpoint(websocket: WebSocket):
```
"""
await websocket.accept()
sid = get_sid_from_token(websocket.query_params.get('token') or '')
if sid == '':
logger.error('Failed to decode token')
return
session_manager.add_session(sid, websocket)
agent_manager.register_agent(sid)
await session_manager.loop_recv(sid, agent_manager.dispatch)
session = None
if websocket.query_params.get('token'):
token = websocket.query_params.get('token')
sid = get_sid_from_token(token)
if sid == '':
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
await websocket.close()
return
else:
sid = str(uuid.uuid4())
token = sign_token({'sid': sid})
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': token, 'status': 'ok'})
last_event_id = -1
if websocket.query_params.get('last_event_id'):
last_event_id = int(websocket.query_params.get('last_event_id'))
for event in session.agent_session.event_stream.get_events(
start_id=last_event_id + 1
):
if isinstance(event, NullAction) or isinstance(event, NullObservation):
continue
if isinstance(event, ChangeAgentStateAction) or isinstance(
event, AgentStateChangedObservation
):
continue
await websocket.send_json(event_to_dict(event))
await session.loop_recv()
@app.get('/api/litellm-models')
@app.get('/api/options/models')
async def get_litellm_models():
"""
Get all models supported by LiteLLM.
@@ -128,7 +191,7 @@ async def get_litellm_models():
return list(set(model_list))
@app.get('/api/agents')
@app.get('/api/options/agents')
async def get_agents():
"""
Get all agents supported by LiteLLM.
@@ -142,89 +205,6 @@ async def get_agents():
return agents
@app.get('/api/auth')
async def get_token(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
"""
Generate a JWT for authentication when starting a WebSocket connection. This endpoint checks if valid credentials
are provided and uses them to get a session ID. If no valid credentials are provided, it generates a new session ID.
To obtain an authentication token:
```sh
curl -H "Authorization: Bearer 5ecRe7" http://localhost:3000/api/auth
```
**Note:** If `JWT_SECRET` is set, use its value instead of `5ecRe7`.
"""
if credentials and credentials.credentials:
sid = get_sid_from_token(credentials.credentials)
if not sid:
sid = str(uuid.uuid4())
logger.info(
f'Invalid or missing credentials, generating new session ID: {sid}'
)
else:
sid = str(uuid.uuid4())
logger.info(f'No credentials provided, generating new session ID: {sid}')
token = sign_token({'sid': sid})
return {'token': token, 'status': 'ok'}
@app.get('/api/messages')
async def get_messages(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
"""
Get messages.
To get messages:
```sh
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages
```
"""
data = []
sid = get_sid_from_token(credentials.credentials)
if sid != '':
data = message_stack.get_messages(sid)
return {'messages': data}
@app.get('/api/messages/total')
async def get_message_total(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
"""
Get total message count.
To get the total message count:
```sh
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages/total
```
"""
sid = get_sid_from_token(credentials.credentials)
return {'msg_total': message_stack.get_message_total(sid)}
@app.delete('/api/messages')
async def del_messages(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
"""
Delete messages.
To delete messages:
```sh
curl -X DELETE -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/messages
```
"""
sid = get_sid_from_token(credentials.credentials)
message_stack.del_messages(sid)
return {'ok': True}
@app.get('/api/list-files')
def list_files(request: Request, path: str = '/'):
"""
@@ -235,27 +215,25 @@ def list_files(request: Request, path: str = '/'):
curl http://localhost:3000/api/list-files
```
"""
if path.startswith('/'):
path = path[1:]
abs_path = os.path.join(config.workspace_base, path)
try:
files = os.listdir(abs_path)
except Exception as e:
logger.error(f'Error listing files: {e}', exc_info=False)
if not request.state.session.agent_session.runtime:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Path not found'},
content={'error': 'Runtime not yet initialized'},
)
try:
return request.state.session.agent_session.runtime.file_store.list(path)
except Exception as e:
logger.error(f'Error refreshing files: {e}', exc_info=False)
error_msg = f'Error refreshing files: {e}'
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': error_msg},
)
files = [os.path.join(path, f) for f in files]
files = [
f + '/' if os.path.isdir(os.path.join(config.workspace_base, f)) else f
for f in files
]
return files
@app.get('/api/select-file')
def select_file(file: str):
def select_file(file: str, request: Request):
"""
Select a file.
@@ -265,12 +243,7 @@ def select_file(file: str):
```
"""
try:
workspace_base = config.workspace_base
file_path = Path(workspace_base, file)
# The following will check if the file is within the workspace base and throw an exception if not
file_path.resolve().relative_to(Path(workspace_base).resolve())
with open(file_path, 'r') as selected_file:
content = selected_file.read()
content = request.state.session.agent_session.runtime.file_store.read(file)
except Exception as e:
logger.error(f'Error opening file {file}: {e}', exc_info=False)
error_msg = f'Error opening file: {e}'
@@ -282,7 +255,7 @@ def select_file(file: str):
@app.post('/api/upload-files')
async def upload_files(files: list[UploadFile]):
async def upload_file(request: Request, files: list[UploadFile]):
"""
Upload files to the workspace.
@@ -292,13 +265,11 @@ async def upload_files(files: list[UploadFile]):
```
"""
try:
workspace_base = config.workspace_base
for file in files:
file_path = Path(workspace_base, file.filename)
# The following will check if the file is within the workspace base and throw an exception if not
file_path.resolve().relative_to(Path(workspace_base).resolve())
with open(file_path, 'wb') as buffer:
shutil.copyfileobj(file.file, buffer)
file_contents = await file.read()
request.state.session.agent_session.runtime.file_store.write(
file.filename, file_contents
)
except Exception as e:
logger.error(f'Error saving files: {e}', exc_info=True)
return JSONResponse(
@@ -309,9 +280,7 @@ async def upload_files(files: list[UploadFile]):
@app.get('/api/root_task')
def get_root_task(
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
):
def get_root_task(request: Request):
"""
Get root_task.
@@ -320,9 +289,7 @@ def get_root_task(
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/root_task
```
"""
sid = get_sid_from_token(credentials.credentials)
agent = agent_manager.sid_to_agent[sid]
controller = agent.controller
controller = request.state.session.agent_session.controller
if controller is not None:
state = controller.get_state()
if state:

View File

@@ -33,7 +33,7 @@ def read_root():
return {'message': 'This is a mock server'}
@app.get('/api/litellm-models')
@app.get('/api/options/models')
def read_llm_models():
return [
'gpt-4',
@@ -43,7 +43,7 @@ def read_llm_models():
]
@app.get('/api/agents')
@app.get('/api/options/agents')
def read_llm_agents():
return [
'MonologueAgent',
@@ -52,16 +52,6 @@ def read_llm_agents():
]
@app.get('/api/messages')
async def get_messages():
return {'messages': []}
@app.get('/api/messages/total')
async def get_message_total():
return {'msg_total': 0}
@app.get('/api/list-files')
def refresh_files():
return ['hello_world.py']

View File

@@ -1,5 +1,4 @@
from .manager import SessionManager
from .msg_stack import message_stack
from .session import Session
session_manager = SessionManager()

View File

@@ -0,0 +1,117 @@
from typing import Optional
from agenthub.codeact_agent.codeact_agent import CodeActAgent
from opendevin.controller import AgentController
from opendevin.controller.agent import Agent
from opendevin.controller.state.state import State
from opendevin.core.config import config
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import ConfigType
from opendevin.events.stream import EventStream
from opendevin.llm.llm import LLM
from opendevin.runtime import DockerSSHBox
from opendevin.runtime.e2b.runtime import E2BRuntime
from opendevin.runtime.runtime import Runtime
from opendevin.runtime.server.runtime import ServerRuntime
class AgentSession:
"""Represents a session with an agent.
Attributes:
controller: The AgentController instance for controlling the agent.
"""
sid: str
event_stream: EventStream
controller: Optional[AgentController] = None
runtime: Optional[Runtime] = None
_closed: bool = False
def __init__(self, sid):
"""Initializes a new instance of the Session class."""
self.sid = sid
self.event_stream = EventStream(sid)
async def start(self, start_event: dict):
"""Starts the agent session.
Args:
start_event: The start event data (optional).
"""
if self.controller or self.runtime:
raise Exception(
'Session already started. You need to close this session and start a new one.'
)
await self._create_runtime()
await self._create_controller(start_event)
async def close(self):
if self._closed:
return
if self.controller is not None:
end_state = self.controller.get_state()
end_state.save_to_session(self.sid)
await self.controller.close()
if self.runtime is not None:
self.runtime.close()
self._closed = True
async def _create_runtime(self):
if self.runtime is not None:
raise Exception('Runtime already created')
if config.runtime == 'server':
logger.info('Using server runtime')
self.runtime = ServerRuntime(self.event_stream, self.sid)
elif config.runtime == 'e2b':
logger.info('Using E2B runtime')
self.runtime = E2BRuntime(self.event_stream, self.sid)
else:
raise Exception(
f'Runtime not defined in config, or is invalid: {config.runtime}'
)
async def _create_controller(self, start_event: dict):
"""Creates an AgentController instance.
Args:
start_event: The start event data (optional).
"""
if self.controller is not None:
raise Exception('Controller already created')
if self.runtime is None:
raise Exception('Runtime must be initialized before the agent controller')
args = {
key: value
for key, value in start_event.get('args', {}).items()
if value != ''
} # remove empty values, prevent FE from sending empty strings
agent_cls = args.get(ConfigType.AGENT, config.agent.name)
model = args.get(ConfigType.LLM_MODEL, config.llm.model)
api_key = args.get(ConfigType.LLM_API_KEY, config.llm.api_key)
api_base = config.llm.base_url
max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
max_chars = args.get(ConfigType.MAX_CHARS, config.llm.max_chars)
logger.info(f'Creating agent {agent_cls} using LLM {model}')
llm = LLM(model=model, api_key=api_key, base_url=api_base)
agent = Agent.get_cls(agent_cls)(llm)
if isinstance(agent, CodeActAgent):
if not self.runtime or not isinstance(self.runtime.sandbox, DockerSSHBox):
logger.warning(
'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
)
self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
self.controller = AgentController(
sid=self.sid,
event_stream=self.event_stream,
agent=agent,
max_iterations=int(max_iterations),
max_chars=int(max_chars),
)
try:
agent_state = State.restore_from_session(self.sid)
self.controller.set_state(agent_state)
except Exception as e:
print('Error restoring state', e)

View File

@@ -1,20 +1,12 @@
import asyncio
import atexit
import json
import os
import time
from typing import Callable
from fastapi import WebSocket
from opendevin.core.logger import opendevin_logger as logger
from .msg_stack import message_stack
from .session import Session
CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
SESSION_CACHE_FILE = os.path.join(CACHE_DIR, 'sessions.json')
class SessionManager:
_sessions: dict[str, Session] = {}
@@ -22,30 +14,21 @@ class SessionManager:
session_timeout: int = 600
def __init__(self):
self._load_sessions()
atexit.register(self.close)
asyncio.create_task(self._cleanup_sessions())
def add_session(self, sid: str, ws_conn: WebSocket):
if sid not in self._sessions:
self._sessions[sid] = Session(sid=sid, ws=ws_conn)
return
self._sessions[sid].update_connection(ws_conn)
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
if sid in self._sessions:
asyncio.create_task(self._sessions[sid].close())
self._sessions[sid] = Session(sid=sid, ws=ws_conn)
return self._sessions[sid]
async def loop_recv(self, sid: str, dispatch: Callable):
print(f'Starting loop_recv for sid: {sid}')
"""Starts listening for messages from the client."""
def get_session(self, sid: str) -> Session | None:
if sid not in self._sessions:
return
await self._sessions[sid].loop_recv(dispatch)
def close(self):
logger.info('Saving sessions...')
self._save_sessions()
return None
return self._sessions.get(sid)
async def send(self, sid: str, data: dict[str, object]) -> bool:
"""Sends data to the client."""
message_stack.add_message(sid, 'assistant', data)
if sid not in self._sessions:
return False
return await self._sessions[sid].send(data)
@@ -58,33 +41,6 @@ class SessionManager:
"""Sends a message to the client."""
return await self.send(sid, {'message': message})
def _save_sessions(self):
data = {}
for sid, conn in self._sessions.items():
data[sid] = {
'sid': conn.sid,
'last_active_ts': conn.last_active_ts,
'is_alive': conn.is_alive,
}
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
with open(SESSION_CACHE_FILE, 'w+') as file:
json.dump(data, file)
def _load_sessions(self):
try:
with open(SESSION_CACHE_FILE, 'r') as file:
data = json.load(file)
for sid, sdata in data.items():
conn = Session(sid, None)
ok = conn.load_from_data(sdata)
if ok:
self._sessions[sid] = conn
except FileNotFoundError:
pass
except json.decoder.JSONDecodeError:
pass
async def _cleanup_sessions(self):
while True:
current_time = time.time()

View File

@@ -1,114 +0,0 @@
import asyncio
import atexit
import json
import os
import uuid
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema.action import ActionType
CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
MSG_CACHE_FILE = os.path.join(CACHE_DIR, 'messages.json')
class Message:
id: str = str(uuid.uuid4())
role: str # "user"| "assistant"
payload: dict[str, object]
def __init__(self, role: str, payload: dict[str, object]):
self.role = role
self.payload = payload
def to_dict(self):
return {'id': self.id, 'role': self.role, 'payload': self.payload}
@classmethod
def from_dict(cls, data: dict):
m = cls(data['role'], data['payload'])
m.id = data['id']
return m
class MessageStack:
_messages: dict[str, list[Message]] = {}
def __init__(self):
self._load_messages()
atexit.register(self.close)
def close(self):
logger.info('Saving messages...')
self._save_messages()
def add_message(self, sid: str, role: str, message: dict[str, object]):
if sid not in self._messages:
self._messages[sid] = []
self._messages[sid].append(Message(role, message))
def del_messages(self, sid: str):
if sid not in self._messages:
return
del self._messages[sid]
asyncio.create_task(self._del_messages(sid))
def get_messages(self, sid: str) -> list[dict[str, object]]:
if sid not in self._messages:
return []
return [msg.to_dict() for msg in self._messages[sid]]
def get_message_total(self, sid: str) -> int:
if sid not in self._messages:
return 0
cnt = 0
for msg in self._messages[sid]:
# Ignore assistant init message for now.
if 'action' in msg.payload and msg.payload['action'] in [
ActionType.INIT,
ActionType.CHANGE_AGENT_STATE,
]:
continue
cnt += 1
return cnt
def _save_messages(self):
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
data = {}
for sid, msgs in self._messages.items():
data[sid] = [msg.to_dict() for msg in msgs]
with open(MSG_CACHE_FILE, 'w+') as file:
json.dump(data, file)
def _load_messages(self):
try:
with open(MSG_CACHE_FILE, 'r') as file:
data = json.load(file)
for sid, msgs in data.items():
self._messages[sid] = [Message.from_dict(msg) for msg in msgs]
except FileNotFoundError:
pass
except json.decoder.JSONDecodeError:
pass
async def _del_messages(self, del_sid: str):
logger.info('Deleting messages...')
try:
with open(MSG_CACHE_FILE, 'r+') as file:
data = json.load(file)
new_data = {}
for sid, msgs in data.items():
if sid != del_sid:
new_data[sid] = msgs
# Move the file pointer to the beginning of the file to overwrite the original contents
file.seek(0)
# clean previous content
file.truncate()
json.dump(new_data, file)
except FileNotFoundError:
pass
except json.decoder.JSONDecodeError:
pass
message_stack = MessageStack()

View File

@@ -1,11 +1,19 @@
import asyncio
import time
from typing import Callable
from fastapi import WebSocket, WebSocketDisconnect
from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import AgentState
from opendevin.core.schema.action import ActionType
from opendevin.events.action import ChangeAgentStateAction, NullAction
from opendevin.events.event import Event
from opendevin.events.observation import AgentStateChangedObservation, NullObservation
from opendevin.events.serialization import EventSource, event_from_dict, event_to_dict
from opendevin.events.stream import EventStreamSubscriber
from .msg_stack import message_stack
from .agent import AgentSession
DEL_DELT_SEC = 60 * 60 * 5
@@ -15,13 +23,22 @@ class Session:
websocket: WebSocket | None
last_active_ts: int = 0
is_alive: bool = True
agent_session: AgentSession
def __init__(self, sid: str, ws: WebSocket | None):
self.sid = sid
self.websocket = ws
self.last_active_ts = int(time.time())
self.agent_session = AgentSession(sid)
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event
)
async def loop_recv(self, dispatch: Callable):
async def close(self):
self.is_alive = False
await self.agent_session.close()
async def loop_recv(self):
try:
if self.websocket is None:
return
@@ -31,24 +48,62 @@ class Session:
except ValueError:
await self.send_error('Invalid JSON')
continue
message_stack.add_message(self.sid, 'user', data)
action = data.get('action', None)
await dispatch(self.sid, action, data)
await self.dispatch(data)
except WebSocketDisconnect:
self.is_alive = False
await self.close()
logger.info('WebSocket disconnected, sid: %s', self.sid)
except RuntimeError as e:
# WebSocket is not connected
if 'WebSocket is not connected' in str(e):
self.is_alive = False
await self.close()
logger.exception('Error in loop_recv: %s', e)
async def _initialize_agent(self, data: dict):
await self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
)
await self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
)
try:
await self.agent_session.start(data)
except Exception as e:
logger.exception(f'Error creating controller: {e}')
await self.send_error(
f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
)
return
await self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.USER
)
async def on_event(self, event: Event):
"""Callback function for agent events.
Args:
event: The agent event (Observation or Action).
"""
if isinstance(event, NullAction):
return
if isinstance(event, NullObservation):
return
if event.source == EventSource.AGENT and not isinstance(
event, (NullAction, NullObservation)
):
await self.send(event_to_dict(event))
async def dispatch(self, data: dict):
action = data.get('action', '')
if action == ActionType.INIT:
await self._initialize_agent(data)
return
event = event_from_dict(data.copy())
await self.agent_session.event_stream.add_event(event, EventSource.USER)
async def send(self, data: dict[str, object]) -> bool:
try:
if self.websocket is None or not self.is_alive:
return False
await self.websocket.send_json(data)
await asyncio.sleep(0.001) # This flushes the data to the client
self.last_active_ts = int(time.time())
return True
except WebSocketDisconnect:

View File

@@ -28,7 +28,9 @@ class LocalFileStore(FileStore):
def list(self, path: str) -> list[str]:
full_path = self.get_full_path(path)
return [os.path.join(path, f) for f in os.listdir(full_path)]
files = [os.path.join(path, f) for f in os.listdir(full_path)]
files = [f + '/' if os.path.isdir(self.get_full_path(f)) else f for f in files]
return files
def delete(self, path: str) -> None:
full_path = self.get_full_path(path)

View File

@@ -30,6 +30,8 @@ class InMemoryFileStore(FileStore):
files.append(file)
else:
dir_path = os.path.join(path, parts[0])
if not dir_path.endswith('/'):
dir_path += '/'
if dir_path not in files:
files.append(dir_path)
return files

View File

@@ -54,10 +54,8 @@ def test_deep_list(setup_env):
store.write('foo/bar/baz.txt', 'Hello, world!')
store.write('foo/bar/qux.txt', 'Hello, world!')
store.write('foo/bar/quux.txt', 'Hello, world!')
assert store.list('') == ['foo'], 'Expected foo, got {} for class {}'.format(
store.list(''), store.__class__
)
assert store.list('foo') == ['foo/bar']
assert store.list('') == ['foo/'], f'for class {store.__class__}'
assert store.list('foo') == ['foo/bar/']
assert (
store.list('foo/bar').sort()
== ['foo/bar/baz.txt', 'foo/bar/qux.txt', 'foo/bar/quux.txt'].sort()