mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
17 Commits
external-l
...
allow-mess
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a2e7d6270 | ||
|
|
f5c58adaf7 | ||
|
|
c6cb025afe | ||
|
|
cb8214676e | ||
|
|
fdfb7308b8 | ||
|
|
4785de91b0 | ||
|
|
effd2b7d06 | ||
|
|
608dd8f2c2 | ||
|
|
6d0c03509e | ||
|
|
3e1070bbe9 | ||
|
|
2045350720 | ||
|
|
5f83d4cf9a | ||
|
|
d5a996a9e1 | ||
|
|
b0d38bbeb8 | ||
|
|
ed50b3ee8f | ||
|
|
4e5ed36213 | ||
|
|
9060452af6 |
@@ -4,6 +4,7 @@ import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { AnalyticsConsentFormModal } from "#/components/features/analytics/analytics-consent-form-modal";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { SettingsProvider } from "#/context/settings-context";
|
||||
import { AuthProvider } from "#/context/auth-context";
|
||||
|
||||
describe("AnalyticsConsentFormModal", () => {
|
||||
@@ -16,7 +17,7 @@ describe("AnalyticsConsentFormModal", () => {
|
||||
wrapper: ({ children }) => (
|
||||
<AuthProvider>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
<SettingsProvider>{children}</SettingsProvider>
|
||||
</QueryClientProvider>
|
||||
</AuthProvider>
|
||||
),
|
||||
|
||||
@@ -18,6 +18,7 @@ describe("ConversationCard", () => {
|
||||
const onClick = vi.fn();
|
||||
const onDelete = vi.fn();
|
||||
const onChangeTitle = vi.fn();
|
||||
const onDownloadWorkspace = vi.fn();
|
||||
|
||||
beforeAll(() => {
|
||||
vi.stubGlobal("window", { open: vi.fn() });
|
||||
@@ -268,7 +269,30 @@ describe("ConversationCard", () => {
|
||||
expect(onClick).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should call onDownloadWorkspace when the download button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(
|
||||
<ConversationCard
|
||||
onClick={onClick}
|
||||
onDelete={onDelete}
|
||||
onChangeTitle={onChangeTitle}
|
||||
onDownloadWorkspace={onDownloadWorkspace}
|
||||
title="Conversation 1"
|
||||
selectedRepository={null}
|
||||
lastUpdatedAt="2021-10-01T12:00:00Z"
|
||||
/>,
|
||||
);
|
||||
|
||||
const ellipsisButton = screen.getByTestId("ellipsis-button");
|
||||
await user.click(ellipsisButton);
|
||||
|
||||
const menu = screen.getByTestId("context-menu");
|
||||
const downloadButton = within(menu).getByTestId("download-button");
|
||||
|
||||
await user.click(downloadButton);
|
||||
|
||||
expect(onDownloadWorkspace).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should not display the edit or delete options if the handler is not provided", async () => {
|
||||
const user = userEvent.setup();
|
||||
@@ -313,6 +337,7 @@ describe("ConversationCard", () => {
|
||||
onClick={onClick}
|
||||
onDelete={onDelete}
|
||||
onChangeTitle={onChangeTitle}
|
||||
onDownloadWorkspace={onDownloadWorkspace}
|
||||
title="Conversation 1"
|
||||
selectedRepository={null}
|
||||
lastUpdatedAt="2021-10-01T12:00:00Z"
|
||||
@@ -325,6 +350,7 @@ describe("ConversationCard", () => {
|
||||
<ConversationCard
|
||||
onClick={onClick}
|
||||
onDelete={onDelete}
|
||||
onDownloadWorkspace={onDownloadWorkspace}
|
||||
title="Conversation 1"
|
||||
selectedRepository={null}
|
||||
lastUpdatedAt="2021-10-01T12:00:00Z"
|
||||
@@ -333,6 +359,18 @@ describe("ConversationCard", () => {
|
||||
|
||||
expect(screen.getByTestId("ellipsis-button")).toBeInTheDocument();
|
||||
|
||||
rerender(
|
||||
<ConversationCard
|
||||
onClick={onClick}
|
||||
onDownloadWorkspace={onDownloadWorkspace}
|
||||
title="Conversation 1"
|
||||
selectedRepository={null}
|
||||
lastUpdatedAt="2021-10-01T12:00:00Z"
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("ellipsis-button")).toBeInTheDocument();
|
||||
|
||||
rerender(
|
||||
<ConversationCard
|
||||
onClick={onClick}
|
||||
|
||||
@@ -59,7 +59,7 @@ describe("TrajectoryActions", () => {
|
||||
expect(onNegativeFeedback).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should call onExportTrajectory when export button is clicked", async () => {
|
||||
it("should call onExportTrajectory when negative feedback is clicked", async () => {
|
||||
renderWithProviders(
|
||||
<TrajectoryActions
|
||||
onPositiveFeedback={onPositiveFeedback}
|
||||
|
||||
@@ -3,18 +3,15 @@ import { describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
import { AuthProvider } from "#/context/auth-context";
|
||||
|
||||
describe("useSaveSettings", () => {
|
||||
it("should send an empty string for llm_api_key if an empty string is passed, otherwise undefined", async () => {
|
||||
const saveSettingsSpy = vi.spyOn(OpenHands, "saveSettings");
|
||||
const { result } = renderHook(() => useSaveSettings(), {
|
||||
wrapper: ({ children }) => (
|
||||
<AuthProvider>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
</AuthProvider>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
|
||||
@@ -55,8 +55,7 @@ describe("frontend/routes/_oh", () => {
|
||||
});
|
||||
});
|
||||
|
||||
// FIXME: This test fails when it shouldn't be, please investigate
|
||||
it.skip("should render and capture the user's consent if oss mode", async () => {
|
||||
it("should render and capture the user's consent if oss mode", async () => {
|
||||
const user = userEvent.setup();
|
||||
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
|
||||
const getSettingsSpy = vi.spyOn(OpenHands, "getSettings");
|
||||
|
||||
@@ -4,7 +4,7 @@ import {
|
||||
} from "#/components/shared/modals/confirmation-modals/base-modal";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
import { ModalBody } from "#/components/shared/modals/modal-body";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
import { useCurrentSettings } from "#/context/settings-context";
|
||||
import { handleCaptureConsent } from "#/utils/handle-capture-consent";
|
||||
import { BrandButton } from "../settings/brand-button";
|
||||
|
||||
@@ -15,14 +15,14 @@ interface AnalyticsConsentFormModalProps {
|
||||
export function AnalyticsConsentFormModal({
|
||||
onClose,
|
||||
}: AnalyticsConsentFormModalProps) {
|
||||
const { mutate: saveUserSettings } = useSaveSettings();
|
||||
const { saveUserSettings } = useCurrentSettings();
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
const formData = new FormData(e.currentTarget);
|
||||
const analytics = formData.get("analytics") === "on";
|
||||
|
||||
saveUserSettings(
|
||||
await saveUserSettings(
|
||||
{ user_consents_to_analytics: analytics },
|
||||
{
|
||||
onSuccess: () => {
|
||||
|
||||
@@ -2,6 +2,7 @@ import posthog from "posthog-js";
|
||||
import React from "react";
|
||||
import { useSelector } from "react-redux";
|
||||
import { SuggestionItem } from "#/components/features/suggestions/suggestion-item";
|
||||
import { DownloadModal } from "#/components/shared/download-modal";
|
||||
import type { RootState } from "#/store";
|
||||
import { useAuth } from "#/context/auth-context";
|
||||
|
||||
@@ -17,11 +18,21 @@ export function ActionSuggestions({
|
||||
(state: RootState) => state.initialQuery,
|
||||
);
|
||||
|
||||
const [isDownloading, setIsDownloading] = React.useState(false);
|
||||
const [hasPullRequest, setHasPullRequest] = React.useState(false);
|
||||
|
||||
const handleDownloadClose = () => {
|
||||
setIsDownloading(false);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-2 mb-2">
|
||||
{githubTokenIsSet && selectedRepository && (
|
||||
<DownloadModal
|
||||
initialPath=""
|
||||
onClose={handleDownloadClose}
|
||||
isOpen={isDownloading}
|
||||
/>
|
||||
{githubTokenIsSet && selectedRepository ? (
|
||||
<div className="flex flex-row gap-2 justify-center w-full">
|
||||
{!hasPullRequest ? (
|
||||
<>
|
||||
@@ -63,6 +74,21 @@ export function ActionSuggestions({
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<SuggestionItem
|
||||
suggestion={{
|
||||
label: !isDownloading
|
||||
? "Download files"
|
||||
: "Downloading, please wait...",
|
||||
value: "Download files",
|
||||
}}
|
||||
onClick={() => {
|
||||
posthog.capture("download_workspace_button_clicked");
|
||||
if (!isDownloading) {
|
||||
setIsDownloading(true);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -13,7 +13,10 @@ import { generateAgentStateChangeEvent } from "#/services/agent-state-service";
|
||||
import { FeedbackModal } from "../feedback/feedback-modal";
|
||||
import { useScrollToBottom } from "#/hooks/use-scroll-to-bottom";
|
||||
import { TypingIndicator } from "./typing-indicator";
|
||||
import { useWsClient } from "#/context/ws-client-provider";
|
||||
import {
|
||||
useWsClient,
|
||||
WsClientProviderStatus,
|
||||
} from "#/context/ws-client-provider";
|
||||
import { Messages } from "./messages";
|
||||
import { ChatSuggestions } from "./chat-suggestions";
|
||||
import { ActionSuggestions } from "./action-suggestions";
|
||||
@@ -21,7 +24,7 @@ import { ContinueButton } from "#/components/shared/buttons/continue-button";
|
||||
import { ScrollToBottomButton } from "#/components/shared/buttons/scroll-to-bottom-button";
|
||||
import { LoadingSpinner } from "#/components/shared/loading-spinner";
|
||||
import { useGetTrajectory } from "#/hooks/mutation/use-get-trajectory";
|
||||
import { downloadTrajectory } from "#/utils/download-trajectory";
|
||||
import { downloadTrajectory } from "#/utils/download-files";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
|
||||
function getEntryPoint(
|
||||
@@ -34,7 +37,7 @@ function getEntryPoint(
|
||||
}
|
||||
|
||||
export function ChatInterface() {
|
||||
const { send, isLoadingMessages } = useWsClient();
|
||||
const { send, isLoadingMessages, status, pendingMessages } = useWsClient();
|
||||
const dispatch = useDispatch();
|
||||
const scrollRef = React.useRef<HTMLDivElement>(null);
|
||||
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
|
||||
@@ -54,6 +57,9 @@ export function ChatInterface() {
|
||||
const params = useParams();
|
||||
const { mutate: getTrajectory } = useGetTrajectory();
|
||||
|
||||
const isClientDisconnected = status === WsClientProviderStatus.DISCONNECTED;
|
||||
const hasPendingMessages = pendingMessages.length > 0;
|
||||
|
||||
const handleSendMessage = async (content: string, files: File[]) => {
|
||||
if (messages.length === 0) {
|
||||
posthog.capture("initial_query_submitted", {
|
||||
@@ -76,7 +82,15 @@ export function ChatInterface() {
|
||||
const timestamp = new Date().toISOString();
|
||||
const pending = true;
|
||||
dispatch(addUserMessage({ content, imageUrls, timestamp, pending }));
|
||||
send(createChatMessage(content, imageUrls, timestamp));
|
||||
|
||||
// Create and send the chat message
|
||||
const chatMessage = createChatMessage(content, imageUrls, timestamp);
|
||||
send(chatMessage);
|
||||
|
||||
// Send the agent state change event immediately
|
||||
// The backend will handle the ordering and queueing
|
||||
send(generateAgentStateChangeEvent(AgentState.RUNNING));
|
||||
|
||||
setMessageToSend(null);
|
||||
};
|
||||
|
||||
@@ -131,8 +145,20 @@ export function ChatInterface() {
|
||||
className="flex flex-col grow overflow-y-auto overflow-x-hidden px-4 pt-4 gap-2"
|
||||
>
|
||||
{isLoadingMessages && (
|
||||
<div className="flex justify-center">
|
||||
<div className="flex flex-col items-center gap-2">
|
||||
<LoadingSpinner size="small" />
|
||||
{isClientDisconnected && (
|
||||
<div className="text-sm text-neutral-400">
|
||||
Waiting for client to become ready...
|
||||
{hasPendingMessages && (
|
||||
<div className="text-xs text-neutral-500 mt-1">
|
||||
{pendingMessages.length} message
|
||||
{pendingMessages.length !== 1 ? "s" : ""} will be sent when
|
||||
connected
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -179,7 +205,7 @@ export function ChatInterface() {
|
||||
onSubmit={handleSendMessage}
|
||||
onStop={handleStop}
|
||||
isDisabled={
|
||||
curAgentState === AgentState.LOADING ||
|
||||
// Allow input even when loading, but not during confirmation
|
||||
curAgentState === AgentState.AWAITING_USER_CONFIRMATION
|
||||
}
|
||||
mode={curAgentState === AgentState.RUNNING ? "stop" : "submit"}
|
||||
|
||||
@@ -22,10 +22,11 @@ export function AgentStatusBar() {
|
||||
const { t, i18n } = useTranslation();
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const { curStatusMessage } = useSelector((state: RootState) => state.status);
|
||||
const { status } = useWsClient();
|
||||
const { status, pendingMessages } = useWsClient();
|
||||
const { notify } = useNotification();
|
||||
|
||||
const [statusMessage, setStatusMessage] = React.useState<string>("");
|
||||
const hasPendingMessages = pendingMessages.length > 0;
|
||||
|
||||
const updateStatusMessage = () => {
|
||||
let message = curStatusMessage.message || "";
|
||||
@@ -71,7 +72,13 @@ export function AgentStatusBar() {
|
||||
|
||||
React.useEffect(() => {
|
||||
if (status === WsClientProviderStatus.DISCONNECTED) {
|
||||
setStatusMessage("Connecting...");
|
||||
if (hasPendingMessages) {
|
||||
setStatusMessage(
|
||||
`Connecting... (${pendingMessages.length} pending message${pendingMessages.length !== 1 ? "s" : ""})`,
|
||||
);
|
||||
} else {
|
||||
setStatusMessage("Connecting...");
|
||||
}
|
||||
} else {
|
||||
setStatusMessage(AGENT_STATUS_MAP[curAgentState].message);
|
||||
if (notificationStates.includes(curAgentState)) {
|
||||
@@ -87,7 +94,7 @@ export function AgentStatusBar() {
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [curAgentState, notify, t]);
|
||||
}, [curAgentState, status, pendingMessages.length, notify, t]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col items-center">
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { useParams } from "react-router";
|
||||
import React from "react";
|
||||
import posthog from "posthog-js";
|
||||
import { AgentControlBar } from "./agent-control-bar";
|
||||
import { AgentStatusBar } from "./agent-status-bar";
|
||||
import { SecurityLock } from "./security-lock";
|
||||
import { useUserConversation } from "#/hooks/query/use-user-conversation";
|
||||
import { ConversationCard } from "../conversation-panel/conversation-card";
|
||||
import { DownloadModal } from "#/components/shared/download-modal";
|
||||
|
||||
interface ControlsProps {
|
||||
setSecurityOpen: (isOpen: boolean) => void;
|
||||
@@ -17,6 +19,13 @@ export function Controls({ setSecurityOpen, showSecurityLock }: ControlsProps) {
|
||||
params.conversationId ?? null,
|
||||
);
|
||||
|
||||
const [downloading, setDownloading] = React.useState(false);
|
||||
|
||||
const handleDownloadWorkspace = () => {
|
||||
posthog.capture("download_workspace_button_clicked");
|
||||
setDownloading(true);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
@@ -30,11 +39,17 @@ export function Controls({ setSecurityOpen, showSecurityLock }: ControlsProps) {
|
||||
|
||||
<ConversationCard
|
||||
variant="compact"
|
||||
onDownloadWorkspace={handleDownloadWorkspace}
|
||||
title={conversation?.title ?? ""}
|
||||
lastUpdatedAt={conversation?.created_at ?? ""}
|
||||
selectedRepository={conversation?.selected_repository ?? null}
|
||||
status={conversation?.status}
|
||||
conversationId={conversation?.conversation_id}
|
||||
/>
|
||||
|
||||
<DownloadModal
|
||||
initialPath=""
|
||||
onClose={() => setDownloading(false)}
|
||||
isOpen={downloading}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -7,7 +7,7 @@ interface ConversationCardContextMenuProps {
|
||||
onClose: () => void;
|
||||
onDelete?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onEdit?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownload?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
position?: "top" | "bottom";
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ export function ConversationCardContextMenu({
|
||||
onClose,
|
||||
onDelete,
|
||||
onEdit,
|
||||
onDownloadViaVSCode,
|
||||
onDownload,
|
||||
position = "bottom",
|
||||
}: ConversationCardContextMenuProps) {
|
||||
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
|
||||
@@ -40,12 +40,9 @@ export function ConversationCardContextMenu({
|
||||
Edit Title
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
{onDownloadViaVSCode && (
|
||||
<ContextMenuListItem
|
||||
testId="download-vscode-button"
|
||||
onClick={onDownloadViaVSCode}
|
||||
>
|
||||
Download via VS Code
|
||||
{onDownload && (
|
||||
<ContextMenuListItem testId="download-button" onClick={onDownload}>
|
||||
Download Workspace
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
</ContextMenu>
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import React from "react";
|
||||
import posthog from "posthog-js";
|
||||
import { formatTimeDelta } from "#/utils/format-time-delta";
|
||||
import { ConversationRepoLink } from "./conversation-repo-link";
|
||||
import {
|
||||
@@ -14,34 +13,31 @@ interface ConversationCardProps {
|
||||
onClick?: () => void;
|
||||
onDelete?: () => void;
|
||||
onChangeTitle?: (title: string) => void;
|
||||
onDownloadWorkspace?: () => void;
|
||||
isActive?: boolean;
|
||||
title: string;
|
||||
selectedRepository: string | null;
|
||||
lastUpdatedAt: string; // ISO 8601
|
||||
status?: ProjectStatus;
|
||||
variant?: "compact" | "default";
|
||||
conversationId?: string; // Optional conversation ID for VS Code URL
|
||||
}
|
||||
|
||||
export function ConversationCard({
|
||||
onClick,
|
||||
onDelete,
|
||||
onChangeTitle,
|
||||
onDownloadWorkspace,
|
||||
isActive,
|
||||
title,
|
||||
selectedRepository,
|
||||
lastUpdatedAt,
|
||||
status = "STOPPED",
|
||||
variant = "default",
|
||||
conversationId,
|
||||
}: ConversationCardProps) {
|
||||
const [contextMenuVisible, setContextMenuVisible] = React.useState(false);
|
||||
const [titleMode, setTitleMode] = React.useState<"view" | "edit">("view");
|
||||
const inputRef = React.useRef<HTMLInputElement>(null);
|
||||
|
||||
// We don't use the VS Code URL hook directly here to avoid test failures
|
||||
// Instead, we'll add the download button conditionally
|
||||
|
||||
const handleBlur = () => {
|
||||
if (inputRef.current?.value) {
|
||||
const trimmed = inputRef.current.value.trim();
|
||||
@@ -82,32 +78,9 @@ export function ConversationCard({
|
||||
setContextMenuVisible(false);
|
||||
};
|
||||
|
||||
const handleDownloadViaVSCode = async (
|
||||
event: React.MouseEvent<HTMLButtonElement>,
|
||||
) => {
|
||||
event.preventDefault();
|
||||
const handleDownload = (event: React.MouseEvent<HTMLButtonElement>) => {
|
||||
event.stopPropagation();
|
||||
posthog.capture("download_via_vscode_button_clicked");
|
||||
|
||||
// Fetch the VS Code URL from the API
|
||||
if (conversationId) {
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/conversations/${conversationId}/vscode-url`,
|
||||
);
|
||||
const data = await response.json();
|
||||
|
||||
if (data.vscode_url) {
|
||||
window.open(data.vscode_url, "_blank");
|
||||
} else {
|
||||
console.error("VS Code URL not available", data.error);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch VS Code URL", error);
|
||||
}
|
||||
}
|
||||
|
||||
setContextMenuVisible(false);
|
||||
onDownloadWorkspace?.();
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
@@ -116,11 +89,7 @@ export function ConversationCard({
|
||||
}
|
||||
}, [titleMode]);
|
||||
|
||||
const hasContextMenu = !!(
|
||||
onDelete ||
|
||||
onChangeTitle ||
|
||||
conversationId // If we have a conversation ID, we can show the download button
|
||||
);
|
||||
const hasContextMenu = !!(onDelete || onChangeTitle || onDownloadWorkspace);
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -176,9 +145,7 @@ export function ConversationCard({
|
||||
onClose={() => setContextMenuVisible(false)}
|
||||
onDelete={onDelete && handleDelete}
|
||||
onEdit={onChangeTitle && handleEdit}
|
||||
onDownloadViaVSCode={
|
||||
conversationId ? handleDownloadViaVSCode : undefined
|
||||
}
|
||||
onDownload={onDownloadWorkspace && handleDownload}
|
||||
position={variant === "compact" ? "top" : "bottom"}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -10,6 +10,7 @@ import { DocsButton } from "#/components/shared/buttons/docs-button";
|
||||
import { ExitProjectButton } from "#/components/shared/buttons/exit-project-button";
|
||||
import { SettingsButton } from "#/components/shared/buttons/settings-button";
|
||||
import { SettingsModal } from "#/components/shared/modals/settings/settings-modal";
|
||||
import { useCurrentSettings } from "#/context/settings-context";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { ConversationPanel } from "../conversation-panel/conversation-panel";
|
||||
import { useEndSession } from "#/hooks/use-end-session";
|
||||
@@ -22,7 +23,6 @@ import { useConfig } from "#/hooks/query/use-config";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
import { HIDE_LLM_SETTINGS } from "#/utils/feature-flags";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
|
||||
export function Sidebar() {
|
||||
const location = useLocation();
|
||||
@@ -31,13 +31,12 @@ export function Sidebar() {
|
||||
const user = useGitHubUser();
|
||||
const { data: config } = useConfig();
|
||||
const {
|
||||
data: settings,
|
||||
error: settingsError,
|
||||
isError: settingsIsError,
|
||||
isFetching: isFetchingSettings,
|
||||
} = useSettings();
|
||||
const { mutateAsync: logout } = useLogout();
|
||||
const { mutate: saveUserSettings } = useSaveSettings();
|
||||
const { settings, saveUserSettings } = useCurrentSettings();
|
||||
|
||||
const [settingsModalIsOpen, setSettingsModalIsOpen] = React.useState(false);
|
||||
|
||||
@@ -79,7 +78,7 @@ export function Sidebar() {
|
||||
|
||||
const handleLogout = async () => {
|
||||
if (config?.APP_MODE === "saas") await logout();
|
||||
else saveUserSettings({ unset_github_token: true });
|
||||
else await saveUserSettings({ unset_github_token: true });
|
||||
posthog.reset();
|
||||
};
|
||||
|
||||
|
||||
33
frontend/src/components/shared/download-modal.tsx
Normal file
33
frontend/src/components/shared/download-modal.tsx
Normal file
@@ -0,0 +1,33 @@
|
||||
import { useDownloadProgress } from "#/hooks/use-download-progress";
|
||||
import { DownloadProgress } from "./download-progress";
|
||||
|
||||
interface DownloadModalProps {
|
||||
initialPath: string;
|
||||
onClose: () => void;
|
||||
isOpen: boolean;
|
||||
}
|
||||
|
||||
function ActiveDownload({
|
||||
initialPath,
|
||||
onClose,
|
||||
}: {
|
||||
initialPath: string;
|
||||
onClose: () => void;
|
||||
}) {
|
||||
const { progress, cancelDownload } = useDownloadProgress(
|
||||
initialPath,
|
||||
onClose,
|
||||
);
|
||||
|
||||
return <DownloadProgress progress={progress} onCancel={cancelDownload} />;
|
||||
}
|
||||
|
||||
export function DownloadModal({
|
||||
initialPath,
|
||||
onClose,
|
||||
isOpen,
|
||||
}: DownloadModalProps) {
|
||||
if (!isOpen) return null;
|
||||
|
||||
return <ActiveDownload initialPath={initialPath} onClose={onClose} />;
|
||||
}
|
||||
94
frontend/src/components/shared/download-progress.tsx
Normal file
94
frontend/src/components/shared/download-progress.tsx
Normal file
@@ -0,0 +1,94 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
|
||||
export interface DownloadProgressState {
|
||||
filesTotal: number;
|
||||
filesDownloaded: number;
|
||||
currentFile: string;
|
||||
totalBytesDownloaded: number;
|
||||
bytesDownloadedPerSecond: number;
|
||||
isDiscoveringFiles: boolean;
|
||||
}
|
||||
|
||||
interface DownloadProgressProps {
|
||||
progress: DownloadProgressState;
|
||||
onCancel: () => void;
|
||||
}
|
||||
|
||||
export function DownloadProgress({
|
||||
progress,
|
||||
onCancel,
|
||||
}: DownloadProgressProps) {
|
||||
const { t } = useTranslation();
|
||||
const formatBytes = (bytes: number) => {
|
||||
const units = ["B", "KB", "MB", "GB"];
|
||||
let size = bytes;
|
||||
let unitIndex = 0;
|
||||
while (size >= 1024 && unitIndex < units.length - 1) {
|
||||
size /= 1024;
|
||||
unitIndex += 1;
|
||||
}
|
||||
return `${size.toFixed(1)} ${units[unitIndex]}`;
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-20">
|
||||
<div className="bg-[#1C1C1C] rounded-lg p-6 max-w-md w-full mx-4 border border-[#525252]">
|
||||
<div className="mb-4">
|
||||
<h3 className="text-lg font-semibold mb-2 text-white">
|
||||
{progress.isDiscoveringFiles
|
||||
? t(I18nKey.DOWNLOAD$PREPARING)
|
||||
: t(I18nKey.DOWNLOAD$DOWNLOADING)}
|
||||
</h3>
|
||||
<p className="text-sm text-gray-400 truncate">
|
||||
{progress.isDiscoveringFiles
|
||||
? t(I18nKey.DOWNLOAD$FOUND_FILES, { count: progress.filesTotal })
|
||||
: progress.currentFile}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="mb-4">
|
||||
<div className="h-2 bg-[#2C2C2C] rounded-full overflow-hidden">
|
||||
{progress.isDiscoveringFiles ? (
|
||||
<div
|
||||
className="h-full bg-blue-500 animate-pulse"
|
||||
style={{ width: "100%" }}
|
||||
/>
|
||||
) : (
|
||||
<div
|
||||
className="h-full bg-blue-500 transition-all duration-300"
|
||||
style={{
|
||||
width: `${(progress.filesDownloaded / progress.filesTotal) * 100}%`,
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-between text-sm text-gray-400">
|
||||
<span>
|
||||
{progress.isDiscoveringFiles
|
||||
? t(I18nKey.DOWNLOAD$SCANNING)
|
||||
: t(I18nKey.DOWNLOAD$FILES_PROGRESS, {
|
||||
downloaded: progress.filesDownloaded,
|
||||
total: progress.filesTotal,
|
||||
})}
|
||||
</span>
|
||||
{!progress.isDiscoveringFiles && (
|
||||
<span>{formatBytes(progress.bytesDownloadedPerSecond)}/s</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="mt-4 flex justify-end">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onCancel}
|
||||
className="px-4 py-2 text-sm text-gray-400 hover:text-white transition-colors"
|
||||
>
|
||||
{t(I18nKey.DOWNLOAD$CANCEL)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -9,12 +9,12 @@ import { extractSettings } from "#/utils/settings-utils";
|
||||
import { useEndSession } from "#/hooks/use-end-session";
|
||||
import { ModalBackdrop } from "../modal-backdrop";
|
||||
import { ModelSelector } from "./model-selector";
|
||||
import { useCurrentSettings } from "#/context/settings-context";
|
||||
import { Settings } from "#/types/settings";
|
||||
import { BrandButton } from "#/components/features/settings/brand-button";
|
||||
import { KeyStatusIcon } from "#/components/features/settings/key-status-icon";
|
||||
import { SettingsInput } from "#/components/features/settings/settings-input";
|
||||
import { HelpLink } from "#/components/features/settings/help-link";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
|
||||
interface SettingsFormProps {
|
||||
settings: Settings;
|
||||
@@ -23,7 +23,7 @@ interface SettingsFormProps {
|
||||
}
|
||||
|
||||
export function SettingsForm({ settings, models, onClose }: SettingsFormProps) {
|
||||
const { mutate: saveUserSettings } = useSaveSettings();
|
||||
const { saveUserSettings } = useCurrentSettings();
|
||||
const endSession = useEndSession();
|
||||
|
||||
const location = useLocation();
|
||||
|
||||
74
frontend/src/context/settings-context.tsx
Normal file
74
frontend/src/context/settings-context.tsx
Normal file
@@ -0,0 +1,74 @@
|
||||
import React from "react";
|
||||
import { MutateOptions } from "@tanstack/react-query";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
import { PostSettings, Settings } from "#/types/settings";
|
||||
import { retrieveAxiosErrorMessage } from "#/utils/retrieve-axios-error-message";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
|
||||
type SaveUserSettingsConfig = {
|
||||
onSuccess: MutateOptions<void, Error, Partial<PostSettings>>["onSuccess"];
|
||||
};
|
||||
|
||||
interface SettingsContextType {
|
||||
saveUserSettings: (
|
||||
newSettings: Partial<PostSettings>,
|
||||
config?: SaveUserSettingsConfig,
|
||||
) => Promise<void>;
|
||||
settings: Settings | undefined;
|
||||
}
|
||||
|
||||
const SettingsContext = React.createContext<SettingsContextType | undefined>(
|
||||
undefined,
|
||||
);
|
||||
|
||||
interface SettingsProviderProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export function SettingsProvider({ children }: SettingsProviderProps) {
|
||||
const { data: userSettings } = useSettings();
|
||||
const { mutateAsync: saveSettings } = useSaveSettings();
|
||||
|
||||
const saveUserSettings = async (
|
||||
newSettings: Partial<PostSettings>,
|
||||
config?: SaveUserSettingsConfig,
|
||||
) => {
|
||||
const updatedSettings: Partial<PostSettings> = {
|
||||
...userSettings,
|
||||
...newSettings,
|
||||
};
|
||||
|
||||
if (updatedSettings.LLM_API_KEY === "**********") {
|
||||
delete updatedSettings.LLM_API_KEY;
|
||||
}
|
||||
|
||||
await saveSettings(updatedSettings, {
|
||||
onSuccess: config?.onSuccess,
|
||||
onError: (error) => {
|
||||
const errorMessage = retrieveAxiosErrorMessage(error);
|
||||
displayErrorToast(errorMessage);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const value = React.useMemo(
|
||||
() => ({
|
||||
saveUserSettings,
|
||||
settings: userSettings,
|
||||
}),
|
||||
[saveUserSettings, userSettings],
|
||||
);
|
||||
|
||||
return <SettingsContext value={value}>{children}</SettingsContext>;
|
||||
}
|
||||
|
||||
export function useCurrentSettings() {
|
||||
const context = React.useContext(SettingsContext);
|
||||
if (context === undefined) {
|
||||
throw new Error(
|
||||
"useCurrentSettings must be used within a SettingsProvider",
|
||||
);
|
||||
}
|
||||
return context;
|
||||
}
|
||||
@@ -49,12 +49,14 @@ interface UseWsClient {
|
||||
isLoadingMessages: boolean;
|
||||
events: Record<string, unknown>[];
|
||||
send: (event: Record<string, unknown>) => void;
|
||||
pendingMessages: Record<string, unknown>[];
|
||||
}
|
||||
|
||||
const WsClientContext = React.createContext<UseWsClient>({
|
||||
status: WsClientProviderStatus.DISCONNECTED,
|
||||
isLoadingMessages: true,
|
||||
events: [],
|
||||
pendingMessages: [],
|
||||
send: () => {
|
||||
throw new Error("not connected");
|
||||
},
|
||||
@@ -109,26 +111,43 @@ export function WsClientProvider({
|
||||
WsClientProviderStatus.DISCONNECTED,
|
||||
);
|
||||
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
|
||||
const [pendingMessages, setPendingMessages] = React.useState<
|
||||
Record<string, unknown>[]
|
||||
>([]);
|
||||
const lastEventRef = React.useRef<Record<string, unknown> | null>(null);
|
||||
|
||||
const messageRateHandler = useRate({ threshold: 250 });
|
||||
|
||||
// Private function to queue messages for later sending
|
||||
const queueMessage = (event: Record<string, unknown>) => {
|
||||
EventLogger.info(`Queueing message: ${JSON.stringify(event)}`);
|
||||
setPendingMessages((prev) => [...prev, event]);
|
||||
};
|
||||
|
||||
function send(event: Record<string, unknown>) {
|
||||
if (!sioRef.current) {
|
||||
EventLogger.error("WebSocket is not connected.");
|
||||
EventLogger.info("WebSocket is not connected, queueing message");
|
||||
queueMessage(event);
|
||||
return;
|
||||
}
|
||||
|
||||
// Send the message to the backend
|
||||
EventLogger.info(`Sending message: ${JSON.stringify(event)}`);
|
||||
sioRef.current.emit("oh_action", event);
|
||||
}
|
||||
|
||||
function handleConnect() {
|
||||
setStatus(WsClientProviderStatus.CONNECTED);
|
||||
EventLogger.info(
|
||||
`WebSocket connected. Pending messages: ${pendingMessages.length}`,
|
||||
);
|
||||
}
|
||||
|
||||
function handleMessage(event: Record<string, unknown>) {
|
||||
if (isOpenHandsEvent(event) && isMessageAction(event)) {
|
||||
messageRateHandler.record(new Date().getTime());
|
||||
}
|
||||
|
||||
setEvents((prevEvents) => [...prevEvents, event]);
|
||||
if (!Number.isNaN(parseInt(event.id as string, 10))) {
|
||||
lastEventRef.current = event;
|
||||
@@ -145,14 +164,39 @@ export function WsClientProvider({
|
||||
}
|
||||
sio.io.opts.query = sio.io.opts.query || {};
|
||||
sio.io.opts.query.latest_event_id = lastEventRef.current?.id;
|
||||
EventLogger.info(
|
||||
`WebSocket disconnected. Latest event ID: ${lastEventRef.current?.id}`,
|
||||
);
|
||||
updateStatusWhenErrorMessagePresent(data);
|
||||
}
|
||||
|
||||
function handleError(data: unknown) {
|
||||
setStatus(WsClientProviderStatus.DISCONNECTED);
|
||||
EventLogger.error(`WebSocket connection error: ${JSON.stringify(data)}`);
|
||||
updateStatusWhenErrorMessagePresent(data);
|
||||
}
|
||||
|
||||
// Process any pending messages when the WebSocket connects
|
||||
React.useEffect(() => {
|
||||
if (
|
||||
status === WsClientProviderStatus.CONNECTED &&
|
||||
pendingMessages.length > 0 &&
|
||||
sioRef.current
|
||||
) {
|
||||
// We're connected and have pending messages
|
||||
EventLogger.info(
|
||||
`Connected! Sending ${pendingMessages.length} queued messages`,
|
||||
);
|
||||
|
||||
pendingMessages.forEach((event) => {
|
||||
sioRef.current?.emit("oh_action", event);
|
||||
});
|
||||
|
||||
setPendingMessages([]);
|
||||
EventLogger.info("All queued messages sent, queue cleared");
|
||||
}
|
||||
}, [status, pendingMessages.length]);
|
||||
|
||||
React.useEffect(() => {
|
||||
lastEventRef.current = null;
|
||||
}, [conversationId]);
|
||||
@@ -210,9 +254,10 @@ export function WsClientProvider({
|
||||
status,
|
||||
isLoadingMessages: messageRateHandler.isUnderThreshold,
|
||||
events,
|
||||
pendingMessages,
|
||||
send,
|
||||
}),
|
||||
[status, messageRateHandler.isUnderThreshold, events],
|
||||
[status, messageRateHandler.isUnderThreshold, events, pendingMessages],
|
||||
);
|
||||
|
||||
return <WsClientContext value={value}>{children}</WsClientContext>;
|
||||
|
||||
@@ -16,6 +16,7 @@ import store from "./store";
|
||||
import { useConfig } from "./hooks/query/use-config";
|
||||
import { AuthProvider } from "./context/auth-context";
|
||||
import { queryClientConfig } from "./query-client-config";
|
||||
import { SettingsProvider } from "./context/settings-context";
|
||||
|
||||
function PosthogInit() {
|
||||
const { data: config } = useConfig();
|
||||
@@ -55,8 +56,10 @@ prepareApp().then(() =>
|
||||
<Provider store={store}>
|
||||
<AuthProvider>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<HydratedRouter />
|
||||
<PosthogInit />
|
||||
<SettingsProvider>
|
||||
<HydratedRouter />
|
||||
<PosthogInit />
|
||||
</SettingsProvider>
|
||||
</QueryClientProvider>
|
||||
</AuthProvider>
|
||||
</Provider>
|
||||
|
||||
@@ -2,7 +2,6 @@ import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { DEFAULT_SETTINGS } from "#/services/settings";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { PostSettings, PostApiSettings } from "#/types/settings";
|
||||
import { useSettings } from "../query/use-settings";
|
||||
|
||||
const saveSettingsMutationFn = async (settings: Partial<PostSettings>) => {
|
||||
const resetLlmApiKey = settings.LLM_API_KEY === "";
|
||||
@@ -30,25 +29,9 @@ const saveSettingsMutationFn = async (settings: Partial<PostSettings>) => {
|
||||
|
||||
export const useSaveSettings = () => {
|
||||
const queryClient = useQueryClient();
|
||||
const { data: currentSettings } = useSettings();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async (settings: Partial<PostSettings>) => {
|
||||
const newSettings = { ...currentSettings, ...settings };
|
||||
|
||||
// Temp hack for reset logic
|
||||
if (
|
||||
settings.LLM_API_KEY === undefined &&
|
||||
settings.LLM_BASE_URL === undefined &&
|
||||
settings.LLM_MODEL === undefined
|
||||
) {
|
||||
delete newSettings.LLM_API_KEY;
|
||||
delete newSettings.LLM_BASE_URL;
|
||||
delete newSettings.LLM_MODEL;
|
||||
}
|
||||
|
||||
await saveSettingsMutationFn(newSettings);
|
||||
},
|
||||
mutationFn: saveSettingsMutationFn,
|
||||
onSuccess: async () => {
|
||||
await queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
},
|
||||
|
||||
@@ -5,13 +5,13 @@ import { useConfig } from "./use-config";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { useAuth } from "#/context/auth-context";
|
||||
import { useLogout } from "../mutation/use-logout";
|
||||
import { useSaveSettings } from "../mutation/use-save-settings";
|
||||
import { useCurrentSettings } from "#/context/settings-context";
|
||||
|
||||
export const useGitHubUser = () => {
|
||||
const { githubTokenIsSet } = useAuth();
|
||||
const { setGitHubTokenIsSet } = useAuth();
|
||||
const { mutateAsync: logout } = useLogout();
|
||||
const { mutate: saveUserSettings } = useSaveSettings();
|
||||
const { saveUserSettings } = useCurrentSettings();
|
||||
const { data: config } = useConfig();
|
||||
|
||||
const user = useQuery({
|
||||
@@ -38,7 +38,7 @@ export const useGitHubUser = () => {
|
||||
const handleLogout = async () => {
|
||||
if (config?.APP_MODE === "saas") await logout();
|
||||
else {
|
||||
saveUserSettings({ unset_github_token: true });
|
||||
await saveUserSettings({ unset_github_token: true });
|
||||
setGitHubTokenIsSet(false);
|
||||
}
|
||||
posthog.reset();
|
||||
|
||||
@@ -44,13 +44,13 @@ export const useSettings = () => {
|
||||
});
|
||||
|
||||
React.useEffect(() => {
|
||||
if (query.isFetched && query.data?.LLM_API_KEY) {
|
||||
if (query.data?.LLM_API_KEY) {
|
||||
posthog.capture("user_activated");
|
||||
}
|
||||
}, [query.data?.LLM_API_KEY, query.isFetched]);
|
||||
}, [query.data?.LLM_API_KEY]);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (query.isFetched) setGitHubTokenIsSet(!!query.data?.GITHUB_TOKEN_IS_SET);
|
||||
setGitHubTokenIsSet(!!query.data?.GITHUB_TOKEN_IS_SET);
|
||||
}, [query.data?.GITHUB_TOKEN_IS_SET, query.isFetched]);
|
||||
|
||||
// We want to return the defaults if the settings aren't found so the user can still see the
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import { useCurrentSettings } from "#/context/settings-context";
|
||||
import { useLogout } from "./mutation/use-logout";
|
||||
import { useSaveSettings } from "./mutation/use-save-settings";
|
||||
import { useConfig } from "./query/use-config";
|
||||
|
||||
export const useAppLogout = () => {
|
||||
const { data: config } = useConfig();
|
||||
const { mutateAsync: logout } = useLogout();
|
||||
const { mutate: saveUserSettings } = useSaveSettings();
|
||||
const { saveUserSettings } = useCurrentSettings();
|
||||
|
||||
const handleLogout = async () => {
|
||||
if (config?.APP_MODE === "saas") await logout();
|
||||
else saveUserSettings({ unset_github_token: true });
|
||||
else await saveUserSettings({ unset_github_token: true });
|
||||
};
|
||||
|
||||
return { handleLogout };
|
||||
|
||||
80
frontend/src/hooks/use-download-progress.ts
Normal file
80
frontend/src/hooks/use-download-progress.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { downloadFiles } from "#/utils/download-files";
|
||||
import { DownloadProgressState } from "#/components/shared/download-progress";
|
||||
import { useConversation } from "#/context/conversation-context";
|
||||
|
||||
export const INITIAL_PROGRESS: DownloadProgressState = {
|
||||
filesTotal: 0,
|
||||
filesDownloaded: 0,
|
||||
currentFile: "",
|
||||
totalBytesDownloaded: 0,
|
||||
bytesDownloadedPerSecond: 0,
|
||||
isDiscoveringFiles: true,
|
||||
};
|
||||
|
||||
export function useDownloadProgress(
|
||||
initialPath: string | undefined,
|
||||
onClose: () => void,
|
||||
) {
|
||||
const [isStarted, setIsStarted] = useState(false);
|
||||
const [progress, setProgress] =
|
||||
useState<DownloadProgressState>(INITIAL_PROGRESS);
|
||||
const progressRef = useRef<DownloadProgressState>(INITIAL_PROGRESS);
|
||||
const abortController = useRef<AbortController>(null);
|
||||
const { conversationId } = useConversation();
|
||||
|
||||
// Create AbortController on mount
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
abortController.current = controller;
|
||||
// Initialize progress ref with initial state
|
||||
progressRef.current = INITIAL_PROGRESS;
|
||||
return () => {
|
||||
controller.abort();
|
||||
abortController.current = null;
|
||||
};
|
||||
}, []); // Empty deps array - only run on mount/unmount
|
||||
|
||||
// Start download when isStarted becomes true
|
||||
useEffect(() => {
|
||||
if (!isStarted) {
|
||||
setIsStarted(true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!abortController.current) return;
|
||||
|
||||
// Start download
|
||||
const download = async () => {
|
||||
try {
|
||||
await downloadFiles(conversationId, initialPath, {
|
||||
onProgress: (p) => {
|
||||
// Update both the ref and state
|
||||
progressRef.current = { ...p };
|
||||
setProgress((prev: DownloadProgressState) => ({ ...prev, ...p }));
|
||||
},
|
||||
signal: abortController.current!.signal,
|
||||
});
|
||||
onClose();
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.message === "Download cancelled") {
|
||||
onClose();
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
};
|
||||
download();
|
||||
}, [initialPath, onClose, isStarted]);
|
||||
|
||||
// No longer need startDownload as it's handled in useEffect
|
||||
|
||||
const cancelDownload = useCallback(() => {
|
||||
abortController.current?.abort();
|
||||
}, []);
|
||||
|
||||
return {
|
||||
progress,
|
||||
cancelDownload,
|
||||
};
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
import React from "react";
|
||||
import { useCurrentSettings } from "#/context/settings-context";
|
||||
import { handleCaptureConsent } from "#/utils/handle-capture-consent";
|
||||
import { useSaveSettings } from "./mutation/use-save-settings";
|
||||
|
||||
export const useMigrateUserConsent = () => {
|
||||
const { mutate: saveUserSettings } = useSaveSettings();
|
||||
const { saveUserSettings } = useCurrentSettings();
|
||||
|
||||
/**
|
||||
* Migrate user consent to the settings store on the server.
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useCallback, useRef } from "react";
|
||||
import notificationSound from "../assets/notification.mp3";
|
||||
import { useSettings } from "./query/use-settings";
|
||||
import { useCurrentSettings } from "../context/settings-context";
|
||||
|
||||
export const useNotification = () => {
|
||||
const { data: settings } = useSettings();
|
||||
const { settings } = useCurrentSettings();
|
||||
const audioRef = useRef<HTMLAudioElement | undefined>(undefined);
|
||||
|
||||
// Initialize audio only in browser environment
|
||||
|
||||
356
frontend/src/utils/download-files.ts
Normal file
356
frontend/src/utils/download-files.ts
Normal file
@@ -0,0 +1,356 @@
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { downloadWorkspace } from "./download-workspace";
|
||||
|
||||
interface DownloadProgress {
|
||||
filesTotal: number;
|
||||
filesDownloaded: number;
|
||||
currentFile: string;
|
||||
totalBytesDownloaded: number;
|
||||
bytesDownloadedPerSecond: number;
|
||||
isDiscoveringFiles: boolean;
|
||||
}
|
||||
|
||||
interface DownloadOptions {
|
||||
onProgress?: (progress: DownloadProgress) => void;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the File System Access API is supported
|
||||
*/
|
||||
function isFileSystemAccessSupported(): boolean {
|
||||
return "showDirectoryPicker" in window;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the Save File Picker API is supported
|
||||
*/
|
||||
function isSaveFilePickerSupported(): boolean {
|
||||
return "showSaveFilePicker" in window;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates subdirectories and returns the final directory handle
|
||||
*/
|
||||
async function createSubdirectories(
|
||||
baseHandle: FileSystemDirectoryHandle,
|
||||
pathParts: string[],
|
||||
): Promise<FileSystemDirectoryHandle> {
|
||||
return pathParts.reduce(async (promise, part) => {
|
||||
const handle = await promise;
|
||||
return handle.getDirectoryHandle(part, { create: true });
|
||||
}, Promise.resolve(baseHandle));
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively gets all files in a directory
|
||||
*/
|
||||
async function getAllFiles(
|
||||
conversationID: string,
|
||||
path: string,
|
||||
progress: DownloadProgress,
|
||||
options?: DownloadOptions,
|
||||
): Promise<string[]> {
|
||||
const entries = await OpenHands.getFiles(conversationID, path);
|
||||
|
||||
const processEntry = async (entry: string): Promise<string[]> => {
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Download cancelled");
|
||||
}
|
||||
|
||||
const fullPath = path + entry;
|
||||
if (entry.endsWith("/")) {
|
||||
const subEntries = await OpenHands.getFiles(conversationID, fullPath);
|
||||
const subFilesPromises = subEntries.map((subEntry) =>
|
||||
processEntry(subEntry),
|
||||
);
|
||||
const subFilesArrays = await Promise.all(subFilesPromises);
|
||||
return subFilesArrays.flat();
|
||||
}
|
||||
const updatedProgress = {
|
||||
...progress,
|
||||
filesTotal: progress.filesTotal + 1,
|
||||
currentFile: fullPath,
|
||||
};
|
||||
options?.onProgress?.(updatedProgress);
|
||||
return [fullPath];
|
||||
};
|
||||
|
||||
const filePromises = entries.map((entry) => processEntry(entry));
|
||||
const fileArrays = await Promise.all(filePromises);
|
||||
|
||||
const updatedProgress = {
|
||||
...progress,
|
||||
isDiscoveringFiles: false,
|
||||
};
|
||||
options?.onProgress?.(updatedProgress);
|
||||
|
||||
return fileArrays.flat();
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a batch of files
|
||||
*/
|
||||
async function processBatch(
|
||||
conversationID: string,
|
||||
batch: string[],
|
||||
directoryHandle: FileSystemDirectoryHandle,
|
||||
progress: DownloadProgress,
|
||||
startTime: number,
|
||||
completedFiles: number,
|
||||
totalBytes: number,
|
||||
options?: DownloadOptions,
|
||||
): Promise<{ newCompleted: number; newBytes: number }> {
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Download cancelled");
|
||||
}
|
||||
|
||||
// Process files in the batch in parallel
|
||||
const results = await Promise.all(
|
||||
batch.map(async (path) => {
|
||||
try {
|
||||
const newProgress = {
|
||||
...progress,
|
||||
currentFile: path,
|
||||
isDiscoveringFiles: false,
|
||||
filesDownloaded: completedFiles,
|
||||
totalBytesDownloaded: totalBytes,
|
||||
bytesDownloadedPerSecond:
|
||||
totalBytes / ((Date.now() - startTime) / 1000),
|
||||
};
|
||||
options?.onProgress?.(newProgress);
|
||||
|
||||
const content = await OpenHands.getFile(conversationID, path);
|
||||
|
||||
// Save to the selected directory preserving structure
|
||||
const pathParts = path.split("/").filter(Boolean);
|
||||
const fileName = pathParts.pop() || "file";
|
||||
const dirHandle =
|
||||
pathParts.length > 0
|
||||
? await createSubdirectories(directoryHandle, pathParts)
|
||||
: directoryHandle;
|
||||
|
||||
// Create and write the file
|
||||
const fileHandle = await dirHandle.getFileHandle(fileName, {
|
||||
create: true,
|
||||
});
|
||||
const writable = await fileHandle.createWritable();
|
||||
await writable.write(content);
|
||||
await writable.close();
|
||||
|
||||
// Return the size of this file
|
||||
return new Blob([content]).size;
|
||||
} catch (error) {
|
||||
// Silently handle file processing errors and return 0 bytes
|
||||
return 0;
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Calculate batch totals
|
||||
const batchBytes = results.reduce((sum, size) => sum + size, 0);
|
||||
const newTotalBytes = totalBytes + batchBytes;
|
||||
const newCompleted =
|
||||
completedFiles + results.filter((size) => size > 0).length;
|
||||
|
||||
// Update progress with batch results
|
||||
const updatedProgress = {
|
||||
...progress,
|
||||
filesDownloaded: newCompleted,
|
||||
totalBytesDownloaded: newTotalBytes,
|
||||
bytesDownloadedPerSecond: newTotalBytes / ((Date.now() - startTime) / 1000),
|
||||
isDiscoveringFiles: false,
|
||||
};
|
||||
options?.onProgress?.(updatedProgress);
|
||||
|
||||
return {
|
||||
newCompleted,
|
||||
newBytes: newTotalBytes,
|
||||
};
|
||||
}
|
||||
|
||||
export async function downloadTrajectory(
|
||||
conversationId: string,
|
||||
data: unknown[] | null,
|
||||
): Promise<void> {
|
||||
try {
|
||||
if (!isSaveFilePickerSupported()) {
|
||||
throw new Error(
|
||||
"Your browser doesn't support downloading folders. Please use Chrome, Edge, or another browser that supports the File System Access API.",
|
||||
);
|
||||
}
|
||||
const options = {
|
||||
suggestedName: `trajectory-${conversationId}.json`,
|
||||
types: [
|
||||
{
|
||||
description: "JSON File",
|
||||
accept: {
|
||||
"application/json": [".json"],
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const handle = await window.showSaveFilePicker(options);
|
||||
const writable = await handle.createWritable();
|
||||
await writable.write(JSON.stringify(data, null, 2));
|
||||
await writable.close();
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
`Failed to download file: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads files from the workspace one by one
|
||||
* @param initialPath Initial path to start downloading from. If not provided, downloads from root
|
||||
* @param options Download options including progress callback and abort signal
|
||||
*/
|
||||
export async function downloadFiles(
|
||||
conversationID: string,
|
||||
initialPath?: string,
|
||||
options?: DownloadOptions,
|
||||
): Promise<void> {
|
||||
const startTime = Date.now();
|
||||
const progress: DownloadProgress = {
|
||||
filesTotal: 0, // Will be updated during file discovery
|
||||
filesDownloaded: 0,
|
||||
currentFile: "",
|
||||
totalBytesDownloaded: 0,
|
||||
bytesDownloadedPerSecond: 0,
|
||||
isDiscoveringFiles: true,
|
||||
};
|
||||
|
||||
try {
|
||||
// Check if File System Access API is supported
|
||||
if (!isFileSystemAccessSupported()) {
|
||||
throw new Error(
|
||||
"Your browser doesn't support downloading folders. Please use Chrome, Edge, or another browser that supports the File System Access API.",
|
||||
);
|
||||
}
|
||||
|
||||
// Show directory picker first
|
||||
let directoryHandle: FileSystemDirectoryHandle;
|
||||
try {
|
||||
directoryHandle = await window.showDirectoryPicker();
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
throw new Error("Download cancelled");
|
||||
}
|
||||
if (error instanceof Error && error.name === "SecurityError") {
|
||||
throw new Error(
|
||||
"Permission denied. Please allow access to the download location when prompted.",
|
||||
);
|
||||
}
|
||||
throw new Error("Failed to select download location. Please try again.");
|
||||
}
|
||||
|
||||
// Then recursively get all files
|
||||
const files = await getAllFiles(
|
||||
conversationID,
|
||||
initialPath || "",
|
||||
progress,
|
||||
options,
|
||||
);
|
||||
|
||||
// Set isDiscoveringFiles to false now that we have the full list and preserve filesTotal
|
||||
const finalTotal = progress.filesTotal;
|
||||
options?.onProgress?.({
|
||||
...progress,
|
||||
filesTotal: finalTotal,
|
||||
isDiscoveringFiles: false,
|
||||
});
|
||||
|
||||
// Verify we still have permission after the potentially long file scan
|
||||
try {
|
||||
// Try to create and write to a test file to verify permissions
|
||||
const testHandle = await directoryHandle.getFileHandle(
|
||||
".openhands-test",
|
||||
{ create: true },
|
||||
);
|
||||
const writable = await testHandle.createWritable();
|
||||
await writable.close();
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.message.includes("User activation is required")
|
||||
) {
|
||||
// Ask for permission again
|
||||
try {
|
||||
directoryHandle = await window.showDirectoryPicker();
|
||||
} catch (permissionError) {
|
||||
if (
|
||||
permissionError instanceof Error &&
|
||||
permissionError.name === "AbortError"
|
||||
) {
|
||||
throw new Error("Download cancelled");
|
||||
}
|
||||
if (
|
||||
permissionError instanceof Error &&
|
||||
permissionError.name === "SecurityError"
|
||||
) {
|
||||
throw new Error(
|
||||
"Permission denied. Please allow access to the download location when prompted.",
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
"Failed to select download location. Please try again.",
|
||||
);
|
||||
}
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// Process files in parallel batches to avoid overwhelming the browser
|
||||
const BATCH_SIZE = 5;
|
||||
const batches = Array.from(
|
||||
{ length: Math.ceil(files.length / BATCH_SIZE) },
|
||||
(_, i) => files.slice(i * BATCH_SIZE, (i + 1) * BATCH_SIZE),
|
||||
);
|
||||
|
||||
// Keep track of completed files across all batches
|
||||
let completedFiles = 0;
|
||||
let totalBytesDownloaded = 0;
|
||||
|
||||
// Process batches sequentially to maintain order and avoid overwhelming the browser
|
||||
await batches.reduce(
|
||||
(promise, batch) =>
|
||||
promise.then(async () => {
|
||||
const { newCompleted, newBytes } = await processBatch(
|
||||
conversationID,
|
||||
batch,
|
||||
directoryHandle,
|
||||
progress,
|
||||
startTime,
|
||||
completedFiles,
|
||||
totalBytesDownloaded,
|
||||
options,
|
||||
);
|
||||
completedFiles = newCompleted;
|
||||
totalBytesDownloaded = newBytes;
|
||||
}),
|
||||
Promise.resolve(),
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.message === "Download cancelled") {
|
||||
throw error;
|
||||
}
|
||||
// Fallback to old style download
|
||||
if (
|
||||
error instanceof Error &&
|
||||
(error.message.includes("browser doesn't support") ||
|
||||
error.message.includes("Failed to select") ||
|
||||
error.message.includes("Permission denied"))
|
||||
) {
|
||||
await downloadWorkspace(conversationID);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, wrap it with a generic message
|
||||
throw new Error(
|
||||
`Failed to download files: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
function isSaveFilePickerSupported(): boolean {
|
||||
return typeof window !== "undefined" && "showSaveFilePicker" in window;
|
||||
}
|
||||
|
||||
export async function downloadTrajectory(
|
||||
conversationId: string,
|
||||
data: unknown[] | null,
|
||||
): Promise<void> {
|
||||
if (!isSaveFilePickerSupported()) {
|
||||
throw new Error(
|
||||
"Your browser doesn't support downloading files. Please use Chrome, Edge, or another browser that supports the File System Access API.",
|
||||
);
|
||||
}
|
||||
const options: SaveFilePickerOptions = {
|
||||
suggestedName: `trajectory-${conversationId}.json`,
|
||||
types: [
|
||||
{
|
||||
description: "JSON File",
|
||||
accept: {
|
||||
"application/json": [".json"],
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const fileHandle = await window.showSaveFilePicker(options);
|
||||
const writable = await fileHandle.createWritable();
|
||||
await writable.write(JSON.stringify(data, null, 2));
|
||||
await writable.close();
|
||||
}
|
||||
16
frontend/src/utils/download-workspace.ts
Normal file
16
frontend/src/utils/download-workspace.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
/**
|
||||
* Downloads the current workspace as a .zip file.
|
||||
*/
|
||||
export const downloadWorkspace = async (conversationId: string) => {
|
||||
const blob = await OpenHands.getWorkspaceZip(conversationId);
|
||||
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement("a");
|
||||
link.href = url;
|
||||
link.setAttribute("download", "workspace.zip");
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
link.parentNode?.removeChild(link);
|
||||
};
|
||||
@@ -1,18 +1,28 @@
|
||||
/* eslint-disable no-console */
|
||||
|
||||
/**
|
||||
* A utility class for logging events. This class will only log events in development mode.
|
||||
* A utility class for logging events. This class will log events in development mode
|
||||
* and can be forced to log in any environment by setting FORCE_LOGGING to true.
|
||||
*/
|
||||
class EventLogger {
|
||||
static isDevMode = process.env.NODE_ENV === "development";
|
||||
|
||||
static FORCE_LOGGING = false; // Set to false for production, true only for debugging
|
||||
|
||||
static shouldLog() {
|
||||
return this.isDevMode || this.FORCE_LOGGING;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format and log a message event
|
||||
* @param event The raw event object
|
||||
*/
|
||||
static message(event: MessageEvent) {
|
||||
if (this.isDevMode) {
|
||||
console.warn(JSON.stringify(JSON.parse(event.data.toString()), null, 2));
|
||||
if (this.shouldLog()) {
|
||||
console.warn(
|
||||
"[OpenHands]",
|
||||
JSON.stringify(JSON.parse(event.data.toString()), null, 2),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,8 +32,8 @@ class EventLogger {
|
||||
* @param name The name of the event
|
||||
*/
|
||||
static event(event: Event, name?: string) {
|
||||
if (this.isDevMode) {
|
||||
console.warn(name || "EVENT", event);
|
||||
if (this.shouldLog()) {
|
||||
console.warn("[OpenHands]", name || "EVENT", event);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,8 +42,18 @@ class EventLogger {
|
||||
* @param warning The warning message
|
||||
*/
|
||||
static warning(warning: string) {
|
||||
if (this.isDevMode) {
|
||||
console.warn(warning);
|
||||
if (this.shouldLog()) {
|
||||
console.warn("[OpenHands]", warning);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Log an info message
|
||||
* @param info The info message
|
||||
*/
|
||||
static info(info: string) {
|
||||
if (this.shouldLog()) {
|
||||
console.info("[OpenHands]", info);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,8 +62,8 @@ class EventLogger {
|
||||
* @param error The error message
|
||||
*/
|
||||
static error(error: string) {
|
||||
if (this.isDevMode) {
|
||||
console.error(error);
|
||||
if (this.shouldLog()) {
|
||||
console.error("[OpenHands]", error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import { vi } from "vitest";
|
||||
import { AppStore, RootState, rootReducer } from "./src/store";
|
||||
import { AuthProvider } from "#/context/auth-context";
|
||||
import { ConversationProvider } from "#/context/conversation-context";
|
||||
import { SettingsProvider } from "#/context/settings-context";
|
||||
|
||||
// Mock useParams before importing components
|
||||
vi.mock("react-router", async () => {
|
||||
@@ -65,7 +66,7 @@ export function renderWithProviders(
|
||||
function Wrapper({ children }: PropsWithChildren) {
|
||||
return (
|
||||
<Provider store={store}>
|
||||
<AuthProvider initialGithubTokenIsSet>
|
||||
<AuthProvider initialGithubTokenIsSet={true}>
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
@@ -73,9 +74,11 @@ export function renderWithProviders(
|
||||
})
|
||||
}
|
||||
>
|
||||
<ConversationProvider>
|
||||
<I18nextProvider i18n={i18n}>{children}</I18nextProvider>
|
||||
</ConversationProvider>
|
||||
<SettingsProvider>
|
||||
<ConversationProvider>
|
||||
<I18nextProvider i18n={i18n}>{children}</I18nextProvider>
|
||||
</ConversationProvider>
|
||||
</SettingsProvider>
|
||||
</QueryClientProvider>
|
||||
</AuthProvider>
|
||||
</Provider>
|
||||
|
||||
@@ -93,7 +93,6 @@ class AgentController:
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
_cached_first_user_message: MessageAction | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -114,7 +113,7 @@ class AgentController:
|
||||
"""Initializes a new instance of the AgentController class.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to control. The agent should already have an initialized LLM.
|
||||
agent: The agent instance to control.
|
||||
event_stream: The event stream to publish events to.
|
||||
max_iterations: The maximum number of iterations the agent can run.
|
||||
max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop.
|
||||
@@ -602,7 +601,6 @@ class AgentController:
|
||||
agent_cls: Type[Agent] = Agent.get_cls(action.agent)
|
||||
agent_config = self.agent_configs.get(action.agent, self.agent.config)
|
||||
llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
|
||||
# Create a new LLM instance for the delegate with its own config
|
||||
llm = LLM(config=llm_config, retry_listener=self._notify_on_llm_retry)
|
||||
delegate_agent = agent_cls(llm=llm, config=agent_config)
|
||||
state = State(
|
||||
@@ -1208,19 +1206,15 @@ class AgentController:
|
||||
Returns:
|
||||
MessageAction | None: The first user message, or None if no user message found
|
||||
"""
|
||||
# Return cached message if any
|
||||
if self._cached_first_user_message is not None:
|
||||
return self._cached_first_user_message
|
||||
# Find the first user message from the appropriate starting point
|
||||
user_messages = list(self.event_stream.get_events(start_id=self.state.start_id))
|
||||
|
||||
# Find the first user message
|
||||
self._cached_first_user_message = next(
|
||||
# Get and return the first user message
|
||||
return next(
|
||||
(
|
||||
e
|
||||
for e in self.event_stream.get_events(
|
||||
start_id=self.state.start_id,
|
||||
)
|
||||
for e in user_messages
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
return self._cached_first_user_message
|
||||
|
||||
@@ -15,7 +15,6 @@ from openhands.core.config import (
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.event import Event
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroAgent
|
||||
@@ -102,23 +101,11 @@ def initialize_repository_for_runtime(
|
||||
github_token = (
|
||||
SecretStr(os.environ.get('GITHUB_TOKEN')) if not github_token else github_token
|
||||
)
|
||||
|
||||
secret_store = (
|
||||
SecretStore(
|
||||
provider_tokens={
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))
|
||||
}
|
||||
)
|
||||
if github_token
|
||||
else None
|
||||
)
|
||||
provider_tokens = secret_store.provider_tokens if secret_store else None
|
||||
|
||||
repo_directory = None
|
||||
if selected_repository and provider_tokens:
|
||||
if selected_repository and github_token:
|
||||
logger.debug(f'Selected repository {selected_repository}.')
|
||||
repo_directory = runtime.clone_repo(
|
||||
provider_tokens,
|
||||
github_token,
|
||||
selected_repository,
|
||||
None,
|
||||
)
|
||||
@@ -166,17 +153,12 @@ def create_memory(
|
||||
return memory
|
||||
|
||||
|
||||
def create_agent(config: AppConfig, llm: LLM | None = None) -> Agent:
|
||||
def create_agent(config: AppConfig) -> Agent:
|
||||
agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
|
||||
agent_config = config.get_agent_config(config.default_agent)
|
||||
|
||||
# Create LLM if not provided
|
||||
if llm is None:
|
||||
llm_config = config.get_llm_config_from_agent(config.default_agent)
|
||||
llm = LLM(config=llm_config)
|
||||
|
||||
llm_config = config.get_llm_config_from_agent(config.default_agent)
|
||||
agent = agent_cls(
|
||||
llm=llm,
|
||||
llm=LLM(config=llm_config),
|
||||
config=agent_config,
|
||||
)
|
||||
|
||||
@@ -202,7 +184,6 @@ def create_controller(
|
||||
except Exception as e:
|
||||
logger.debug(f'Cannot restore agent state: {e}')
|
||||
|
||||
# The agent already has an initialized LLM
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
max_iterations=config.max_iterations,
|
||||
|
||||
@@ -22,7 +22,6 @@ class GitLabService(GitService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
@@ -47,7 +46,7 @@ class GitLabService(GitService):
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
async def get_latest_token(self) -> SecretStr:
|
||||
return self.token
|
||||
|
||||
async def _fetch_data(
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Coroutine, Literal, overload
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -14,9 +13,6 @@ from pydantic import (
|
||||
)
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.commands import CmdRunAction
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.service_types import (
|
||||
@@ -134,23 +130,15 @@ class ProviderHandler:
|
||||
def __init__(
|
||||
self,
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
external_auth_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
):
|
||||
if not isinstance(provider_tokens, MappingProxyType):
|
||||
raise TypeError(
|
||||
f'provider_tokens must be a MappingProxyType, got {type(provider_tokens).__name__}'
|
||||
)
|
||||
|
||||
self.service_class_map: dict[ProviderType, type[GitService]] = {
|
||||
ProviderType.GITHUB: GithubServiceImpl,
|
||||
ProviderType.GITLAB: GitLabServiceImpl,
|
||||
}
|
||||
|
||||
self.external_auth_id = external_auth_id
|
||||
# Create immutable copy through SecretStore
|
||||
self.external_auth_token = external_auth_token
|
||||
self.external_token_manager = external_token_manager
|
||||
self._provider_tokens = provider_tokens
|
||||
|
||||
@property
|
||||
@@ -164,10 +152,8 @@ class ProviderHandler:
|
||||
service_class = self.service_class_map[provider]
|
||||
return service_class(
|
||||
user_id=token.user_id,
|
||||
external_auth_id=self.external_auth_id,
|
||||
external_auth_token=self.external_auth_token,
|
||||
token=token.token,
|
||||
external_token_manager=self.external_token_manager,
|
||||
)
|
||||
|
||||
async def get_user(self) -> User:
|
||||
@@ -180,12 +166,14 @@ class ProviderHandler:
|
||||
continue
|
||||
raise AuthenticationError('Need valid provider token')
|
||||
|
||||
async def _get_latest_provider_token(
|
||||
self, provider: ProviderType
|
||||
) -> SecretStr | None:
|
||||
"""Get latest token from service"""
|
||||
service = self._get_service(provider)
|
||||
return await service.get_latest_token()
|
||||
async def get_latest_provider_tokens(self) -> dict[ProviderType, SecretStr]:
|
||||
"""Get latest token from services"""
|
||||
tokens = {}
|
||||
for provider in self.provider_tokens:
|
||||
service = self._get_service(provider)
|
||||
tokens[provider] = await service.get_latest_token()
|
||||
|
||||
return tokens
|
||||
|
||||
async def get_repositories(
|
||||
self, page: int, per_page: int, sort: str, installation_id: int | None
|
||||
@@ -202,120 +190,3 @@ class ProviderHandler:
|
||||
except Exception:
|
||||
continue
|
||||
return all_repos
|
||||
|
||||
async def set_event_stream_secrets(
|
||||
self,
|
||||
event_stream: EventStream,
|
||||
env_vars: dict[ProviderType, SecretStr] | None = None,
|
||||
):
|
||||
"""
|
||||
This ensures that the latest provider tokens are masked from the event stream
|
||||
It is called when the provider tokens are first initialized in the runtime or when tokens are re-exported with the latest working ones
|
||||
|
||||
Args:
|
||||
event_stream: Agent session's event stream
|
||||
env_vars: Dict of providers and their tokens that require updating
|
||||
"""
|
||||
if env_vars:
|
||||
exposed_env_vars = self.expose_env_vars(env_vars)
|
||||
else:
|
||||
exposed_env_vars = await self.get_env_vars(expose_secrets=True)
|
||||
event_stream.set_secrets(exposed_env_vars)
|
||||
|
||||
def expose_env_vars(
|
||||
self, env_secrets: dict[ProviderType, SecretStr]
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Return string values instead of typed values for environment secrets
|
||||
Called just before exporting secrets to runtime, or setting secrets in the event stream
|
||||
"""
|
||||
exposed_envs = {}
|
||||
for provider, token in env_secrets.items():
|
||||
env_key = ProviderHandler.get_provider_env_key(provider)
|
||||
exposed_envs[env_key] = token.get_secret_value()
|
||||
|
||||
return exposed_envs
|
||||
|
||||
@overload
|
||||
def get_env_vars(
|
||||
self,
|
||||
expose_secrets: Literal[True],
|
||||
providers: list[ProviderType] | None = ...,
|
||||
get_latest: bool = False,
|
||||
) -> Coroutine[Any, Any, dict[str, str]]: ...
|
||||
|
||||
@overload
|
||||
def get_env_vars(
|
||||
self,
|
||||
expose_secrets: Literal[False],
|
||||
providers: list[ProviderType] | None = ...,
|
||||
get_latest: bool = False,
|
||||
) -> Coroutine[Any, Any, dict[ProviderType, SecretStr]]: ...
|
||||
|
||||
async def get_env_vars(
|
||||
self,
|
||||
expose_secrets: bool = False,
|
||||
providers: list[ProviderType] | None = None,
|
||||
get_latest: bool = False,
|
||||
) -> dict[ProviderType, SecretStr] | dict[str, str]:
|
||||
"""
|
||||
Retrieves the provider tokens from ProviderHandler object
|
||||
This is used when initializing/exporting new provider tokens in the runtime
|
||||
|
||||
Args:
|
||||
expose_secrets: Flag which returns strings instead of secrets
|
||||
providers: Return provider tokens for the list passed in, otherwise return all available providers
|
||||
get_latest: Get the latest working token for the providers if True, otherwise get the existing ones
|
||||
"""
|
||||
|
||||
if not self.provider_tokens:
|
||||
return {}
|
||||
|
||||
env_vars: dict[ProviderType, SecretStr] = {}
|
||||
all_providers = [provider for provider in ProviderType]
|
||||
provider_list = providers if providers else all_providers
|
||||
|
||||
for provider in provider_list:
|
||||
if provider in self.provider_tokens:
|
||||
token = (
|
||||
self.provider_tokens[provider].token
|
||||
if self.provider_tokens
|
||||
else SecretStr('')
|
||||
)
|
||||
|
||||
if get_latest:
|
||||
token = await self._get_latest_provider_token(provider)
|
||||
|
||||
if token:
|
||||
env_vars[provider] = token
|
||||
|
||||
if not expose_secrets:
|
||||
return env_vars
|
||||
|
||||
return self.expose_env_vars(env_vars)
|
||||
|
||||
@classmethod
|
||||
def check_cmd_action_for_provider_token_ref(
|
||||
cls, event: Action
|
||||
) -> list[ProviderType]:
|
||||
"""
|
||||
Detect if agent run action is using a provider token (e.g $GITHUB_TOKEN)
|
||||
Returns a list of providers which are called by the agent
|
||||
"""
|
||||
|
||||
if not isinstance(event, CmdRunAction):
|
||||
return []
|
||||
|
||||
called_providers = []
|
||||
for provider in ProviderType:
|
||||
if ProviderHandler.get_provider_env_key(provider) in event.command.lower():
|
||||
called_providers.append(provider)
|
||||
|
||||
return called_providers
|
||||
|
||||
@classmethod
|
||||
def get_provider_env_key(cls, provider: ProviderType) -> str:
|
||||
"""
|
||||
Map ProviderType value to the environment variable name in the runtime
|
||||
"""
|
||||
return f'{provider.value}_token'.lower()
|
||||
|
||||
@@ -52,17 +52,16 @@ class GitService(Protocol):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
user_id: str | None,
|
||||
token: SecretStr | None,
|
||||
external_auth_token: SecretStr | None,
|
||||
external_token_manager: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the service with authentication details"""
|
||||
...
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
"""Get latest working token of the user"""
|
||||
async def get_latest_token(self) -> SecretStr:
|
||||
"""Get latest working token of the users"""
|
||||
...
|
||||
|
||||
async def get_user(self) -> User:
|
||||
|
||||
@@ -9,8 +9,7 @@ import string
|
||||
import tempfile
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from types import MappingProxyType
|
||||
from typing import Callable, cast
|
||||
from typing import Callable
|
||||
from zipfile import ZipFile
|
||||
|
||||
from pydantic import SecretStr
|
||||
@@ -42,11 +41,7 @@ from openhands.events.observation import (
|
||||
UserRejectObservation,
|
||||
)
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
ProviderType,
|
||||
)
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.microagent import (
|
||||
BaseMicroAgent,
|
||||
load_microagents_from_dir,
|
||||
@@ -57,11 +52,7 @@ from openhands.runtime.plugins import (
|
||||
VSCodeRequirement,
|
||||
)
|
||||
from openhands.runtime.utils.edit import FileEditRuntimeMixin
|
||||
from openhands.utils.async_utils import (
|
||||
GENERAL_TIMEOUT,
|
||||
call_async_from_sync,
|
||||
call_sync_from_async,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
STATUS_MESSAGES = {
|
||||
'STATUS$STARTING_RUNTIME': 'Starting runtime...',
|
||||
@@ -107,7 +98,6 @@ class Runtime(FileEditRuntimeMixin):
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = False,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
@@ -131,17 +121,6 @@ class Runtime(FileEditRuntimeMixin):
|
||||
if env_vars is not None:
|
||||
self.initial_env_vars.update(env_vars)
|
||||
|
||||
self.provider_handler = ProviderHandler(
|
||||
provider_tokens=git_provider_tokens
|
||||
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({})),
|
||||
external_auth_id=user_id,
|
||||
external_token_manager=True,
|
||||
)
|
||||
raw_env_vars: dict[str, str] = call_async_from_sync(
|
||||
self.provider_handler.get_env_vars, GENERAL_TIMEOUT, True, None, False
|
||||
)
|
||||
self.initial_env_vars.update(raw_env_vars)
|
||||
|
||||
self._vscode_enabled = any(
|
||||
isinstance(plugin, VSCodeRequirement) for plugin in self.plugins
|
||||
)
|
||||
@@ -192,8 +171,6 @@ class Runtime(FileEditRuntimeMixin):
|
||||
# ====================================================================
|
||||
|
||||
def add_env_vars(self, env_vars: dict[str, str]) -> None:
|
||||
env_vars = {key.upper(): value for key, value in env_vars.items()}
|
||||
|
||||
# Add env vars to the IPython shell (if Jupyter is used)
|
||||
if any(isinstance(plugin, JupyterRequirement) for plugin in self.plugins):
|
||||
code = 'import os\n'
|
||||
@@ -239,62 +216,56 @@ class Runtime(FileEditRuntimeMixin):
|
||||
if isinstance(event, Action):
|
||||
asyncio.get_event_loop().run_until_complete(self._handle_action(event))
|
||||
|
||||
async def _export_latest_git_provider_tokens(self, event: Action) -> None:
|
||||
"""
|
||||
Refresh runtime provider tokens when agent attemps to run action with provider token
|
||||
"""
|
||||
if not self.user_id:
|
||||
return
|
||||
|
||||
providers_called = ProviderHandler.check_cmd_action_for_provider_token_ref(
|
||||
event
|
||||
)
|
||||
|
||||
if not providers_called:
|
||||
return
|
||||
|
||||
logger.info(f'Fetching latest github token for runtime: {self.sid}')
|
||||
env_vars = await self.provider_handler.get_env_vars(
|
||||
providers=providers_called, expose_secrets=False, get_latest=True
|
||||
)
|
||||
|
||||
# This statement is to debug expired github tokens, and will be removed later
|
||||
if ProviderType.GITHUB not in env_vars:
|
||||
logger.error(f'Failed to refresh github token for runtime: {self.sid}')
|
||||
return
|
||||
|
||||
if len(env_vars) == 0:
|
||||
return
|
||||
|
||||
raw_token = env_vars[ProviderType.GITHUB].get_secret_value()
|
||||
if not self.prev_token:
|
||||
logger.info(
|
||||
f'Setting github token in runtime: {self.sid}\nToken value: {raw_token[0:5]}; length: {len(raw_token)}'
|
||||
)
|
||||
elif self.prev_token.get_secret_value() != raw_token:
|
||||
logger.info(
|
||||
f'Setting new github token in runtime {self.sid}\nToken value: {raw_token[0:5]}; length: {len(raw_token)}'
|
||||
)
|
||||
|
||||
self.prev_token = SecretStr(raw_token)
|
||||
|
||||
try:
|
||||
await self.provider_handler.set_event_stream_secrets(
|
||||
self.event_stream, env_vars=env_vars
|
||||
)
|
||||
self.add_env_vars(self.provider_handler.expose_env_vars(env_vars))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed export latest github token to runtime: {self.sid}, {e}'
|
||||
)
|
||||
|
||||
async def _handle_action(self, event: Action) -> None:
|
||||
if event.timeout is None:
|
||||
# We don't block the command if this is a default timeout action
|
||||
event.set_hard_timeout(self.config.sandbox.timeout, blocking=False)
|
||||
assert event.timeout is not None
|
||||
try:
|
||||
await self._export_latest_git_provider_tokens(event)
|
||||
if isinstance(event, CmdRunAction):
|
||||
if self.user_id and 'GITHUB_TOKEN' in event.command:
|
||||
gh_client = GithubServiceImpl(
|
||||
external_auth_id=self.user_id, external_token_manager=True
|
||||
)
|
||||
logger.info(f'Fetching latest github token for runtime: {self.sid}')
|
||||
token = await gh_client.get_latest_token()
|
||||
if not token:
|
||||
logger.error(
|
||||
f'Failed to refresh github token for runtime: {self.sid}'
|
||||
)
|
||||
|
||||
if token:
|
||||
raw_token = token.get_secret_value()
|
||||
|
||||
if not self.prev_token:
|
||||
logger.info(
|
||||
f'Setting github token in runtime: {self.sid}\nToken value: {raw_token[0:5]}; length: {len(raw_token)}'
|
||||
)
|
||||
|
||||
elif self.prev_token.get_secret_value() != raw_token:
|
||||
logger.info(
|
||||
f'Setting new github token in runtime {self.sid}\nToken value: {raw_token[0:5]}; length: {len(raw_token)}'
|
||||
)
|
||||
|
||||
self.prev_token = token
|
||||
|
||||
env_vars = {
|
||||
'GITHUB_TOKEN': raw_token,
|
||||
}
|
||||
|
||||
try:
|
||||
self.add_env_vars(env_vars)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed export latest github token to runtime: {self.sid}, {e}'
|
||||
)
|
||||
|
||||
self.event_stream.update_secrets(
|
||||
{
|
||||
'github_token': raw_token,
|
||||
}
|
||||
)
|
||||
|
||||
observation: Observation = await call_sync_from_async(
|
||||
self.run_action, event
|
||||
)
|
||||
@@ -322,20 +293,14 @@ class Runtime(FileEditRuntimeMixin):
|
||||
|
||||
def clone_repo(
|
||||
self,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
github_token: SecretStr,
|
||||
selected_repository: str,
|
||||
selected_branch: str | None,
|
||||
) -> str:
|
||||
if (
|
||||
ProviderType.GITHUB not in git_provider_tokens
|
||||
or not git_provider_tokens[ProviderType.GITHUB].token
|
||||
or not selected_repository
|
||||
):
|
||||
if not github_token or not selected_repository:
|
||||
raise ValueError(
|
||||
'github_token and selected_repository must be provided to clone a repository'
|
||||
)
|
||||
|
||||
github_token: SecretStr = git_provider_tokens[ProviderType.GITHUB].token
|
||||
url = f'https://{github_token.get_secret_value()}@github.com/{selected_repository}.git'
|
||||
dir_name = selected_repository.split('/')[-1]
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
from abc import abstractmethod
|
||||
@@ -35,7 +36,6 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.request import send_request
|
||||
@@ -60,7 +60,6 @@ class ActionExecutionClient(Runtime):
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None
|
||||
):
|
||||
self.session = HttpSession()
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
@@ -77,7 +76,6 @@ class ActionExecutionClient(Runtime):
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
git_provider_tokens,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -151,10 +149,7 @@ class ActionExecutionClient(Runtime):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix='.zip', delete=False
|
||||
) as temp_file:
|
||||
for chunk in response.iter_content(chunk_size=16 * 1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
temp_file.write(chunk)
|
||||
temp_file.flush()
|
||||
shutil.copyfileobj(response.raw, temp_file, length=16 * 1024)
|
||||
return Path(temp_file.name)
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('Copy operation timed out')
|
||||
|
||||
@@ -16,7 +16,6 @@ from openhands.core.exceptions import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
@@ -48,7 +47,6 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
config,
|
||||
@@ -60,7 +58,6 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
git_provider_tokens,
|
||||
)
|
||||
if self.config.sandbox.api_key is None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -28,6 +28,7 @@ def get_github_token(request: Request) -> SecretStr | None:
|
||||
|
||||
def get_github_user_id(request: Request) -> str | None:
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].user_id
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber, session_exists
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
@@ -115,22 +116,27 @@ class StandaloneConversationManager(ConversationManager):
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
github_user_id: str | None,
|
||||
) -> EventStream:
|
||||
):
|
||||
logger.info(
|
||||
f'join_conversation:{sid}:{connection_id}',
|
||||
extra={'session_id': sid, 'user_id': user_id},
|
||||
)
|
||||
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
|
||||
self._local_connection_id_to_session_id[connection_id] = sid
|
||||
event_stream = await self.maybe_start_agent_loop(
|
||||
sid, settings, user_id, github_user_id=github_user_id
|
||||
)
|
||||
event_stream = await self._get_event_stream(sid, user_id)
|
||||
if not event_stream:
|
||||
logger.error(
|
||||
f'No event stream after joining conversation: {sid}',
|
||||
extra={'session_id': sid},
|
||||
return await self.maybe_start_agent_loop(
|
||||
sid, settings, user_id, github_user_id=github_user_id
|
||||
)
|
||||
raise RuntimeError(f'no_event_stream:{sid}')
|
||||
for event in event_stream.get_events(reverse=True):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in (
|
||||
AgentState.STOPPED.value,
|
||||
AgentState.ERROR.value,
|
||||
):
|
||||
await self.close_session(sid)
|
||||
return await self.maybe_start_agent_loop(sid, settings, user_id)
|
||||
break
|
||||
return event_stream
|
||||
|
||||
async def detach_from_conversation(self, conversation: Conversation):
|
||||
@@ -324,6 +330,15 @@ class StandaloneConversationManager(ConversationManager):
|
||||
|
||||
session = self._local_agent_loops_by_sid.get(sid)
|
||||
if session:
|
||||
# Check if the session is ready to process actions
|
||||
if not session.is_ready():
|
||||
logger.info(
|
||||
f'Session not ready, queueing action: {data}',
|
||||
extra={'session_id': sid},
|
||||
)
|
||||
session.queue_action(data)
|
||||
return
|
||||
|
||||
await session.dispatch(data)
|
||||
return
|
||||
|
||||
|
||||
@@ -54,13 +54,10 @@ async def connect(connection_id: str, environ):
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id, connection_id, settings, user_id, github_user_id
|
||||
)
|
||||
logger.info(
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
)
|
||||
|
||||
agent_state_changed = None
|
||||
async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
|
||||
async for event in async_stream:
|
||||
logger.info(f'oh_event: {event.__class__.__name__}')
|
||||
if isinstance(
|
||||
event,
|
||||
(NullAction, NullObservation, RecallAction, RecallObservation),
|
||||
@@ -72,7 +69,6 @@ async def connect(connection_id: str, environ):
|
||||
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
|
||||
if agent_state_changed:
|
||||
await sio.emit('oh_event', event_to_dict(agent_state_changed), to=connection_id)
|
||||
logger.info(f'Finished replaying event stream for conversation {conversation_id}')
|
||||
|
||||
|
||||
@sio.event
|
||||
|
||||
@@ -3,15 +3,15 @@ from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Body, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
)
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import (
|
||||
get_access_token,
|
||||
get_github_user_id,
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
@@ -44,7 +44,7 @@ class InitSessionRequest(BaseModel):
|
||||
|
||||
async def _create_new_conversation(
|
||||
user_id: str | None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
|
||||
token: SecretStr | None,
|
||||
selected_repository: str | None,
|
||||
selected_branch: str | None,
|
||||
initial_user_msg: str | None,
|
||||
@@ -78,7 +78,7 @@ async def _create_new_conversation(
|
||||
logger.warn('Settings not present, not starting conversation')
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
session_init_args['provider_token'] = token
|
||||
session_init_args['selected_repository'] = selected_repository
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
@@ -146,7 +146,19 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
using the returned conversation ID.
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
user_id = None
|
||||
github_token = None
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
user_id = token.user_id
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=user_id,
|
||||
external_auth_token=get_access_token(request),
|
||||
token=token.token,
|
||||
)
|
||||
github_token = await gh_client.get_latest_token()
|
||||
|
||||
selected_repository = data.selected_repository
|
||||
selected_branch = data.selected_branch
|
||||
initial_user_msg = data.initial_user_msg
|
||||
@@ -156,7 +168,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
get_user_id(request),
|
||||
provider_tokens,
|
||||
github_token,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
initial_user_msg,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from logging import LoggerAdapter
|
||||
from types import MappingProxyType
|
||||
from typing import Callable, cast
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
@@ -14,7 +15,6 @@ from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import ChangeAgentStateAction, MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.microagent import BaseMicroAgent
|
||||
from openhands.runtime import get_runtime_cls
|
||||
@@ -78,10 +78,10 @@ class AgentSession:
|
||||
config: AppConfig,
|
||||
agent: Agent,
|
||||
max_iterations: int,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
selected_repository: str | None = None,
|
||||
selected_branch: str | None = None,
|
||||
initial_message: MessageAction | None = None,
|
||||
@@ -115,7 +115,7 @@ class AgentSession:
|
||||
runtime_name=runtime_name,
|
||||
config=config,
|
||||
agent=agent,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
github_token=github_token,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
)
|
||||
@@ -137,10 +137,12 @@ class AgentSession:
|
||||
repo_directory=repo_directory,
|
||||
)
|
||||
|
||||
if git_provider_tokens:
|
||||
provider_handler = ProviderHandler(provider_tokens=git_provider_tokens)
|
||||
await provider_handler.set_event_stream_secrets(self.event_stream)
|
||||
|
||||
if github_token:
|
||||
self.event_stream.set_secrets(
|
||||
{
|
||||
'github_token': github_token.get_secret_value(),
|
||||
}
|
||||
)
|
||||
if not self._closed:
|
||||
if initial_message:
|
||||
self.event_stream.add_event(initial_message, EventSource.USER)
|
||||
@@ -210,7 +212,7 @@ class AgentSession:
|
||||
runtime_name: str,
|
||||
config: AppConfig,
|
||||
agent: Agent,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
selected_repository: str | None = None,
|
||||
selected_branch: str | None = None,
|
||||
) -> bool:
|
||||
@@ -230,36 +232,29 @@ class AgentSession:
|
||||
|
||||
self.logger.debug(f'Initializing runtime `{runtime_name}` now...')
|
||||
runtime_cls = get_runtime_cls(runtime_name)
|
||||
env_vars = (
|
||||
{
|
||||
'GITHUB_TOKEN': github_token.get_secret_value(),
|
||||
}
|
||||
if github_token
|
||||
else None
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
if runtime_cls == RemoteRuntime:
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_callback=self._status_callback,
|
||||
headless_mode=False,
|
||||
attach_to_existing=False,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
else:
|
||||
provider_handler = ProviderHandler(
|
||||
provider_tokens=git_provider_tokens
|
||||
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({}))
|
||||
)
|
||||
env_vars = await provider_handler.get_env_vars(expose_secrets=True)
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_callback=self._status_callback,
|
||||
headless_mode=False,
|
||||
attach_to_existing=False,
|
||||
env_vars=env_vars,
|
||||
)
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_callback=self._status_callback,
|
||||
headless_mode=False,
|
||||
attach_to_existing=False,
|
||||
env_vars=env_vars,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# FIXME: this sleep is a terrible hack.
|
||||
# This is to give the websocket a second to connect, so that
|
||||
@@ -276,10 +271,10 @@ class AgentSession:
|
||||
)
|
||||
return False
|
||||
|
||||
if selected_repository and git_provider_tokens:
|
||||
if selected_repository:
|
||||
await call_sync_from_async(
|
||||
self.runtime.clone_repo,
|
||||
git_provider_tokens,
|
||||
github_token,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
)
|
||||
@@ -330,7 +325,6 @@ class AgentSession:
|
||||
)
|
||||
self.logger.debug(msg)
|
||||
|
||||
# The agent already has an initialized LLM
|
||||
controller = AgentController(
|
||||
sid=self.sid,
|
||||
event_stream=self.event_stream,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from types import MappingProxyType
|
||||
from pydantic import Field
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
@@ -10,10 +8,6 @@ class ConversationInitData(Settings):
|
||||
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
|
||||
"""
|
||||
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = Field(default=None, frozen=True)
|
||||
provider_token: SecretStr | None = Field(default=None)
|
||||
selected_repository: str | None = Field(default=None)
|
||||
selected_branch: str | None = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
'arbitrary_types_allowed': True,
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from dataclasses import field
|
||||
from logging import LoggerAdapter
|
||||
from typing import List
|
||||
|
||||
import socketio
|
||||
|
||||
@@ -12,6 +14,7 @@ from openhands.core.config.condenser_config import (
|
||||
)
|
||||
from openhands.core.const.guide_url import TROUBLESHOOTING_URL
|
||||
from openhands.core.logger import OpenHandsLoggerAdapter
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events.action import MessageAction, NullAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
@@ -23,7 +26,6 @@ from openhands.events.observation import (
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
@@ -44,6 +46,9 @@ class Session:
|
||||
file_store: FileStore
|
||||
user_id: str | None
|
||||
logger: LoggerAdapter
|
||||
_pending_actions: List[dict] = []
|
||||
_is_ready: bool = False
|
||||
_ready_event: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -71,6 +76,16 @@ class Session:
|
||||
self.config = deepcopy(config)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.user_id = user_id
|
||||
self._pending_actions = []
|
||||
self._is_ready = False
|
||||
self._ready_event = asyncio.Event()
|
||||
|
||||
# Subscribe to agent state changes to detect when the agent is ready
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._on_agent_state_change,
|
||||
f'{self.sid}_state_change',
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
if self.sio:
|
||||
@@ -87,6 +102,11 @@ class Session:
|
||||
async def initialize_agent(
|
||||
self, settings: Settings, initial_message: MessageAction | None
|
||||
):
|
||||
# Reset the ready state when initializing a new agent
|
||||
self._is_ready = False
|
||||
self._ready_event.clear()
|
||||
|
||||
# Set the agent state to LOADING
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING),
|
||||
EventSource.ENVIRONMENT,
|
||||
@@ -134,11 +154,11 @@ class Session:
|
||||
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
|
||||
git_provider_tokens = None
|
||||
provider_token = None
|
||||
selected_repository = None
|
||||
selected_branch = None
|
||||
if isinstance(settings, ConversationInitData):
|
||||
git_provider_tokens = settings.git_provider_tokens
|
||||
provider_token = settings.provider_token
|
||||
selected_repository = settings.selected_repository
|
||||
selected_branch = settings.selected_branch
|
||||
|
||||
@@ -151,7 +171,7 @@ class Session:
|
||||
max_budget_per_task=self.config.max_budget_per_task,
|
||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
github_token=provider_token,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
initial_message=initial_message,
|
||||
@@ -280,3 +300,55 @@ class Session:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._send_status_message(msg_type, id, message), self.loop
|
||||
)
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if the session is ready to process actions."""
|
||||
return self._is_ready
|
||||
|
||||
def queue_action(self, action_data: dict):
|
||||
"""Queue an action to be processed when the session is ready."""
|
||||
logger.info(f'Queueing action for session {self.sid}: {action_data}')
|
||||
self._pending_actions.append(action_data)
|
||||
|
||||
# Start a task to process the queue when the session becomes ready
|
||||
asyncio.run_coroutine_threadsafe(self._process_queue_when_ready(), self.loop)
|
||||
|
||||
async def _process_queue_when_ready(self):
|
||||
"""Process the queue of actions when the session becomes ready."""
|
||||
if not self._ready_event.is_set():
|
||||
try:
|
||||
# Wait for the session to become ready
|
||||
await asyncio.wait_for(self._ready_event.wait(), timeout=60)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f'Timeout waiting for session {self.sid} to become ready'
|
||||
)
|
||||
return
|
||||
|
||||
# Process all pending actions
|
||||
if self._pending_actions:
|
||||
logger.info(
|
||||
f'Processing {len(self._pending_actions)} queued actions for session {self.sid}'
|
||||
)
|
||||
|
||||
# Process all pending actions
|
||||
for action_data in self._pending_actions:
|
||||
logger.info(f'Processing queued action: {action_data}')
|
||||
await self.dispatch(action_data)
|
||||
|
||||
# Clear the queue
|
||||
self._pending_actions = []
|
||||
|
||||
def _on_agent_state_change(self, event: Event):
|
||||
"""Handle agent state change events to detect when the agent is ready."""
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
# Check if the agent state indicates it's ready
|
||||
if event.agent_state in ['idle', 'ready']:
|
||||
logger.info(f'Agent for session {self.sid} is now ready')
|
||||
self._is_ready = True
|
||||
self._ready_event.set()
|
||||
|
||||
# Process any pending actions
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._process_queue_when_ready(), self.loop
|
||||
)
|
||||
|
||||
142
poetry.lock
generated
142
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@@ -113,7 +113,7 @@ propcache = ">=0.2.0"
|
||||
yarl = ">=1.17.0,<2.0"
|
||||
|
||||
[package.extras]
|
||||
speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
|
||||
speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"]
|
||||
|
||||
[[package]]
|
||||
name = "aiolimiter"
|
||||
@@ -224,7 +224,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""}
|
||||
|
||||
[package.extras]
|
||||
doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"]
|
||||
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""]
|
||||
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"]
|
||||
trio = ["trio (>=0.26.1)"]
|
||||
|
||||
[[package]]
|
||||
@@ -360,12 +360,12 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
|
||||
cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
|
||||
dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
|
||||
benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
|
||||
tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
|
||||
tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""]
|
||||
tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
|
||||
|
||||
[[package]]
|
||||
name = "babel"
|
||||
@@ -380,7 +380,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["backports.zoneinfo ; python_version < \"3.9\"", "freezegun (>=1.0,<2.0)", "jinja2 (>=3.0)", "pytest (>=6.0)", "pytest-cov", "pytz", "setuptools", "tzdata ; sys_platform == \"win32\""]
|
||||
dev = ["backports.zoneinfo", "freezegun (>=1.0,<2.0)", "jinja2 (>=3.0)", "pytest (>=6.0)", "pytest-cov", "pytz", "setuptools", "tzdata"]
|
||||
|
||||
[[package]]
|
||||
name = "bashlex"
|
||||
@@ -408,9 +408,9 @@ files = [
|
||||
|
||||
[package.extras]
|
||||
all = ["typing-extensions (>=3.10.0.0)"]
|
||||
dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "mypy (>=0.800) ; platform_python_implementation != \"PyPy\"", "numpy ; sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.1.0)", "tox (>=3.20.1)", "typing-extensions ; python_version < \"3.9.0\""]
|
||||
dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "mypy (>=0.800)", "numpy", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.1.0)", "tox (>=3.20.1)", "typing-extensions"]
|
||||
doc-rtd = ["furo (==2022.6.21)", "sphinx (==4.1.0)"]
|
||||
test-tox = ["mypy (>=0.800) ; platform_python_implementation != \"PyPy\"", "numpy ; sys_platform != \"darwin\" and platform_python_implementation != \"PyPy\"", "pytest (>=4.0.0)", "sphinx", "typing-extensions ; python_version < \"3.9.0\""]
|
||||
test-tox = ["mypy (>=0.800)", "numpy", "pytest (>=4.0.0)", "sphinx", "typing-extensions"]
|
||||
test-tox-coverage = ["coverage (>=5.5)"]
|
||||
|
||||
[[package]]
|
||||
@@ -699,7 +699,7 @@ pyproject_hooks = "*"
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"]
|
||||
test = ["build[uv,virtualenv]", "filelock (>=3)", "pytest (>=6.2.4)", "pytest-cov (>=2.12)", "pytest-mock (>=2)", "pytest-rerunfailures (>=9.1)", "pytest-xdist (>=1.34)", "setuptools (>=42.0.0) ; python_version < \"3.10\"", "setuptools (>=56.0.0) ; python_version == \"3.10\"", "setuptools (>=56.0.0) ; python_version == \"3.11\"", "setuptools (>=67.8.0) ; python_version >= \"3.12\"", "wheel (>=0.36.0)"]
|
||||
test = ["build[uv,virtualenv]", "filelock (>=3)", "pytest (>=6.2.4)", "pytest-cov (>=2.12)", "pytest-mock (>=2)", "pytest-rerunfailures (>=9.1)", "pytest-xdist (>=1.34)", "setuptools (>=42.0.0)", "setuptools (>=56.0.0)", "setuptools (>=56.0.0)", "setuptools (>=67.8.0)", "wheel (>=0.36.0)"]
|
||||
typing = ["build[uv]", "importlib-metadata (>=5.1)", "mypy (>=1.9.0,<1.10.0)", "tomli", "typing-extensions (>=3.7.4.3)"]
|
||||
uv = ["uv (>=0.1.18)"]
|
||||
virtualenv = ["virtualenv (>=20.0.35)"]
|
||||
@@ -1168,7 +1168,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
|
||||
toml = ["tomli"]
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
@@ -1215,10 +1215,10 @@ files = [
|
||||
cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=3.0.0) ; python_version >= \"3.8\""]
|
||||
docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=3.0.0)"]
|
||||
docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"]
|
||||
nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_version >= \"3.8\""]
|
||||
pep8test = ["check-sdist ; python_version >= \"3.8\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"]
|
||||
nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2)"]
|
||||
pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"]
|
||||
sdist = ["build (>=1.0.0)"]
|
||||
ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==44.0.1)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
@@ -1285,17 +1285,17 @@ tqdm = ">=4.66.3"
|
||||
xxhash = "*"
|
||||
|
||||
[package.extras]
|
||||
audio = ["librosa", "soundfile (>=0.12.1)", "soxr (>=0.4.0) ; python_version >= \"3.9\""]
|
||||
audio = ["librosa", "soundfile (>=0.12.1)", "soxr (>=0.4.0)"]
|
||||
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
|
||||
dev = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14) ; sys_platform != \"win32\"", "jaxlib (>=0.3.14) ; sys_platform != \"win32\"", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0) ; python_version >= \"3.9\"", "sqlalchemy", "tensorflow (>=2.16.0) ; python_version >= \"3.10\"", "tensorflow (>=2.6.0)", "tensorflow (>=2.6.0) ; python_version < \"3.10\"", "tiktoken", "torch", "torch (>=2.0.0)", "torchdata", "transformers", "transformers (>=4.42.0)", "zstandard"]
|
||||
dev = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tensorflow (>=2.16.0)", "tensorflow (>=2.6.0)", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "torchdata", "transformers", "transformers (>=4.42.0)", "zstandard"]
|
||||
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
|
||||
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
|
||||
quality = ["ruff (>=0.3.0)"]
|
||||
s3 = ["s3fs"]
|
||||
tensorflow = ["tensorflow (>=2.6.0)"]
|
||||
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
|
||||
tests = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14) ; sys_platform != \"win32\"", "jaxlib (>=0.3.14) ; sys_platform != \"win32\"", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0) ; python_version >= \"3.9\"", "sqlalchemy", "tensorflow (>=2.16.0) ; python_version >= \"3.10\"", "tensorflow (>=2.6.0) ; python_version < \"3.10\"", "tiktoken", "torch (>=2.0.0)", "torchdata", "transformers (>=4.42.0)", "zstandard"]
|
||||
tests-numpy2 = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "jax (>=0.3.14) ; sys_platform != \"win32\"", "jaxlib (>=0.3.14) ; sys_platform != \"win32\"", "joblib (<1.3.0)", "joblibspark", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0) ; python_version >= \"3.9\"", "sqlalchemy", "tiktoken", "torch (>=2.0.0)", "torchdata", "transformers (>=4.42.0)", "zstandard"]
|
||||
tests = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tensorflow (>=2.16.0)", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "torchdata", "transformers (>=4.42.0)", "zstandard"]
|
||||
tests-numpy2 = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tiktoken", "torch (>=2.0.0)", "torchdata", "transformers (>=4.42.0)", "zstandard"]
|
||||
torch = ["torch"]
|
||||
vision = ["Pillow (>=9.4.0)"]
|
||||
|
||||
@@ -1417,7 +1417,7 @@ files = [
|
||||
wrapt = ">=1.10,<2"
|
||||
|
||||
[package.extras]
|
||||
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools ; python_version >= \"3.12\"", "tox"]
|
||||
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "dill"
|
||||
@@ -1627,7 +1627,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""]
|
||||
tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"]
|
||||
|
||||
[[package]]
|
||||
name = "faker"
|
||||
@@ -1725,7 +1725,7 @@ files = [
|
||||
[package.extras]
|
||||
docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
|
||||
testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"]
|
||||
typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""]
|
||||
typing = ["typing-extensions (>=4.12.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "flake8"
|
||||
@@ -1828,18 +1828,18 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
all = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\"", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0) ; python_version <= \"3.12\"", "xattr ; sys_platform == \"darwin\"", "zopfli (>=0.1.4)"]
|
||||
all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"]
|
||||
graphite = ["lz4 (>=1.7.4.2)"]
|
||||
interpolatable = ["munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\""]
|
||||
interpolatable = ["munkres", "pycairo", "scipy"]
|
||||
lxml = ["lxml (>=4.0)"]
|
||||
pathops = ["skia-pathops (>=0.5.0)"]
|
||||
plot = ["matplotlib"]
|
||||
repacker = ["uharfbuzz (>=0.23.0)"]
|
||||
symfont = ["sympy"]
|
||||
type1 = ["xattr ; sys_platform == \"darwin\""]
|
||||
type1 = ["xattr"]
|
||||
ufo = ["fs (>=2.2.0,<3)"]
|
||||
unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""]
|
||||
woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"]
|
||||
unicode = ["unicodedata2 (>=15.1.0)"]
|
||||
woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "fqdn"
|
||||
@@ -2103,11 +2103,11 @@ greenlet = {version = ">=3.0rc3", markers = "platform_python_implementation == \
|
||||
"zope.interface" = "*"
|
||||
|
||||
[package.extras]
|
||||
dnspython = ["dnspython (>=1.16.0,<2.0) ; python_version < \"3.10\"", "idna ; python_version < \"3.10\""]
|
||||
dnspython = ["dnspython (>=1.16.0,<2.0)", "idna"]
|
||||
docs = ["furo", "repoze.sphinx.autointerface", "sphinx", "sphinxcontrib-programoutput", "zope.schema"]
|
||||
monitor = ["psutil (>=5.7.0) ; sys_platform != \"win32\" or platform_python_implementation == \"CPython\""]
|
||||
recommended = ["cffi (>=1.12.2) ; platform_python_implementation == \"CPython\"", "dnspython (>=1.16.0,<2.0) ; python_version < \"3.10\"", "idna ; python_version < \"3.10\"", "psutil (>=5.7.0) ; sys_platform != \"win32\" or platform_python_implementation == \"CPython\""]
|
||||
test = ["cffi (>=1.12.2) ; platform_python_implementation == \"CPython\"", "coverage (>=5.0) ; sys_platform != \"win32\"", "dnspython (>=1.16.0,<2.0) ; python_version < \"3.10\"", "idna ; python_version < \"3.10\"", "objgraph", "psutil (>=5.7.0) ; sys_platform != \"win32\" or platform_python_implementation == \"CPython\"", "requests"]
|
||||
monitor = ["psutil (>=5.7.0)"]
|
||||
recommended = ["cffi (>=1.12.2)", "dnspython (>=1.16.0,<2.0)", "idna", "psutil (>=5.7.0)"]
|
||||
test = ["cffi (>=1.12.2)", "coverage (>=5.0)", "dnspython (>=1.16.0,<2.0)", "idna", "objgraph", "psutil (>=5.7.0)", "requests"]
|
||||
|
||||
[[package]]
|
||||
name = "ghapi"
|
||||
@@ -2160,7 +2160,7 @@ gitdb = ">=4.0.1,<5"
|
||||
|
||||
[package.extras]
|
||||
doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"]
|
||||
test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""]
|
||||
test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
|
||||
|
||||
[[package]]
|
||||
name = "google-ai-generativelanguage"
|
||||
@@ -2179,7 +2179,7 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extr
|
||||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
|
||||
proto-plus = [
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev"},
|
||||
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
|
||||
|
||||
@@ -2202,14 +2202,14 @@ grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_versi
|
||||
grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
|
||||
proto-plus = [
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev"},
|
||||
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
|
||||
]
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
|
||||
requests = ">=2.18.0,<3.0.0.dev0"
|
||||
|
||||
[package.extras]
|
||||
async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"]
|
||||
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev) ; python_version >= \"3.11\"", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0) ; python_version >= \"3.11\""]
|
||||
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
|
||||
grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
|
||||
grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
|
||||
|
||||
@@ -2324,10 +2324,10 @@ ag2-testing = ["absl-py", "ag2[gemini]", "cloudpickle (>=3.0,<4.0)", "google-clo
|
||||
agent-engines = ["cloudpickle (>=3.0,<4.0)", "google-cloud-logging (<4)", "google-cloud-trace (<2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "packaging (>=24.0)", "pydantic (>=2.10,<3)", "typing-extensions"]
|
||||
autologging = ["mlflow (>=1.27.0,<=2.16.0)"]
|
||||
cloud-profiler = ["tensorboard-plugin-profile (>=2.4.0,<2.18.0)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"]
|
||||
datasets = ["pyarrow (>=10.0.1) ; python_version == \"3.11\"", "pyarrow (>=14.0.0) ; python_version >= \"3.12\"", "pyarrow (>=3.0.0,<8.0dev) ; python_version < \"3.11\""]
|
||||
datasets = ["pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)"]
|
||||
endpoint = ["requests (>=2.28.1)"]
|
||||
evaluation = ["pandas (>=1.0.0)", "scikit-learn (<1.6.0) ; python_version <= \"3.10\"", "scikit-learn ; python_version > \"3.10\"", "tqdm (>=4.23.0)"]
|
||||
full = ["docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.114.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.16.0)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1) ; python_version == \"3.11\"", "pyarrow (>=14.0.0) ; python_version >= \"3.12\"", "pyarrow (>=3.0.0,<8.0dev) ; python_version < \"3.11\"", "pyarrow (>=6.0.1)", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<2.10.dev0 || >=2.33.dev0,<=2.33.0) ; python_version < \"3.11\"", "ray[default] (>=2.5,<=2.33.0) ; python_version == \"3.11\"", "requests (>=2.28.1)", "scikit-learn (<1.6.0) ; python_version <= \"3.10\"", "scikit-learn ; python_version > \"3.10\"", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<2.18.0)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev) ; python_version <= \"3.11\"", "tensorflow (>=2.4.0,<3.0.0dev)", "tqdm (>=4.23.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)"]
|
||||
evaluation = ["pandas (>=1.0.0)", "scikit-learn", "scikit-learn (<1.6.0)", "tqdm (>=4.23.0)"]
|
||||
full = ["docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.114.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.16.0)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<2.10.dev0 || >=2.33.dev0,<=2.33.0)", "ray[default] (>=2.5,<=2.33.0)", "requests (>=2.28.1)", "scikit-learn", "scikit-learn (<1.6.0)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<2.18.0)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "tqdm (>=4.23.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)"]
|
||||
langchain = ["langchain (>=0.3,<0.4)", "langchain-core (>=0.3,<0.4)", "langchain-google-vertexai (>=2,<3)", "langgraph (>=0.2.45,<0.3)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)"]
|
||||
langchain-testing = ["absl-py", "cloudpickle (>=3.0,<4.0)", "google-cloud-trace (<2)", "langchain (>=0.3,<0.4)", "langchain-core (>=0.3,<0.4)", "langchain-google-vertexai (>=2,<3)", "langgraph (>=0.2.45,<0.3)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)", "pytest-xdist", "typing-extensions"]
|
||||
lit = ["explainable-ai-sdk (>=1.0.0)", "lit-nlp (==0.4.0)", "pandas (>=1.0.0)", "tensorflow (>=2.3.0,<3.0.0dev)"]
|
||||
@@ -2335,11 +2335,11 @@ metadata = ["numpy (>=1.15.0)", "pandas (>=1.0.0)"]
|
||||
pipelines = ["pyyaml (>=5.3.1,<7)"]
|
||||
prediction = ["docker (>=5.0.3)", "fastapi (>=0.71.0,<=0.114.0)", "httpx (>=0.23.0,<0.25.0)", "starlette (>=0.17.1)", "uvicorn[standard] (>=0.16.0)"]
|
||||
private-endpoints = ["requests (>=2.28.1)", "urllib3 (>=1.21.1,<1.27)"]
|
||||
ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0)", "pyarrow (>=6.0.1)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<2.10.dev0 || >=2.33.dev0,<=2.33.0) ; python_version < \"3.11\"", "ray[default] (>=2.5,<=2.33.0) ; python_version == \"3.11\"", "setuptools (<70.0.0)"]
|
||||
ray-testing = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0)", "pyarrow (>=6.0.1)", "pytest-xdist", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<2.10.dev0 || >=2.33.dev0,<=2.33.0) ; python_version < \"3.11\"", "ray[default] (>=2.5,<=2.33.0) ; python_version == \"3.11\"", "ray[train]", "scikit-learn (<1.6.0)", "setuptools (<70.0.0)", "tensorflow", "torch (>=2.0.0,<2.1.0)", "xgboost", "xgboost-ray"]
|
||||
ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0)", "pyarrow (>=6.0.1)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<2.10.dev0 || >=2.33.dev0,<=2.33.0)", "ray[default] (>=2.5,<=2.33.0)", "setuptools (<70.0.0)"]
|
||||
ray-testing = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0)", "pyarrow (>=6.0.1)", "pytest-xdist", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<2.10.dev0 || >=2.33.dev0,<=2.33.0)", "ray[default] (>=2.5,<=2.33.0)", "ray[train]", "scikit-learn (<1.6.0)", "setuptools (<70.0.0)", "tensorflow", "torch (>=2.0.0,<2.1.0)", "xgboost", "xgboost-ray"]
|
||||
reasoningengine = ["cloudpickle (>=3.0,<4.0)", "google-cloud-trace (<2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)", "typing-extensions"]
|
||||
tensorboard = ["tensorboard-plugin-profile (>=2.4.0,<2.18.0)", "tensorflow (>=2.3.0,<3.0.0dev) ; python_version <= \"3.11\"", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"]
|
||||
testing = ["aiohttp", "bigframes ; python_version >= \"3.10\"", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.114.0)", "google-api-core (>=2.11,<3.0.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "ipython", "kfp (>=2.6.0,<3.0.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.16.0)", "nltk", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1) ; python_version == \"3.11\"", "pyarrow (>=14.0.0) ; python_version >= \"3.12\"", "pyarrow (>=3.0.0,<8.0dev) ; python_version < \"3.11\"", "pyarrow (>=6.0.1)", "pytest-asyncio", "pytest-xdist", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<2.10.dev0 || >=2.33.dev0,<=2.33.0) ; python_version < \"3.11\"", "ray[default] (>=2.5,<=2.33.0) ; python_version == \"3.11\"", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn (<1.6.0) ; python_version <= \"3.10\"", "scikit-learn ; python_version > \"3.10\"", "sentencepiece (>=0.2.0)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<2.18.0)", "tensorflow (==2.13.0) ; python_version <= \"3.11\"", "tensorflow (==2.16.1) ; python_version > \"3.11\"", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev) ; python_version <= \"3.11\"", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0) ; python_version <= \"3.11\"", "torch (>=2.2.0) ; python_version > \"3.11\"", "tqdm (>=4.23.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost"]
|
||||
tensorboard = ["tensorboard-plugin-profile (>=2.4.0,<2.18.0)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"]
|
||||
testing = ["aiohttp", "bigframes", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.114.0)", "google-api-core (>=2.11,<3.0.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "ipython", "kfp (>=2.6.0,<3.0.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.16.0)", "nltk", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pytest-asyncio", "pytest-xdist", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<2.10.dev0 || >=2.33.dev0,<=2.33.0)", "ray[default] (>=2.5,<=2.33.0)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "scikit-learn (<1.6.0)", "sentencepiece (>=0.2.0)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<2.18.0)", "tensorflow (==2.13.0)", "tensorflow (==2.16.1)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "torch (>=2.2.0)", "tqdm (>=4.23.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost"]
|
||||
tokenization = ["sentencepiece (>=0.2.0)"]
|
||||
vizier = ["google-vizier (>=0.1.6)"]
|
||||
xai = ["tensorflow (>=2.3.0,<3.0.0dev)"]
|
||||
@@ -2368,12 +2368,12 @@ requests = ">=2.21.0,<3.0.0dev"
|
||||
[package.extras]
|
||||
all = ["google-cloud-bigquery[bigquery-v2,bqstorage,geopandas,ipython,ipywidgets,opentelemetry,pandas,tqdm]"]
|
||||
bigquery-v2 = ["proto-plus (>=1.22.3,<2.0.0dev)", "protobuf (>=3.20.2,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev)"]
|
||||
bqstorage = ["google-cloud-bigquery-storage (>=2.6.0,<3.0.0dev)", "grpcio (>=1.47.0,<2.0dev)", "grpcio (>=1.49.1,<2.0dev) ; python_version >= \"3.11\"", "pyarrow (>=3.0.0)"]
|
||||
bqstorage = ["google-cloud-bigquery-storage (>=2.6.0,<3.0.0dev)", "grpcio (>=1.47.0,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "pyarrow (>=3.0.0)"]
|
||||
geopandas = ["Shapely (>=1.8.4,<3.0.0dev)", "geopandas (>=0.9.0,<2.0dev)"]
|
||||
ipython = ["bigquery-magics (>=0.1.0)"]
|
||||
ipywidgets = ["ipykernel (>=6.0.0)", "ipywidgets (>=7.7.0)"]
|
||||
opentelemetry = ["opentelemetry-api (>=1.1.0)", "opentelemetry-instrumentation (>=0.20b0)", "opentelemetry-sdk (>=1.1.0)"]
|
||||
pandas = ["db-dtypes (>=0.3.0,<2.0.0dev)", "importlib-metadata (>=1.0.0) ; python_version < \"3.8\"", "pandas (>=1.1.0)", "pyarrow (>=3.0.0)"]
|
||||
pandas = ["db-dtypes (>=0.3.0,<2.0.0dev)", "importlib-metadata (>=1.0.0)", "pandas (>=1.1.0)", "pyarrow (>=3.0.0)"]
|
||||
tqdm = ["tqdm (>=4.7.4,<5.0.0dev)"]
|
||||
|
||||
[[package]]
|
||||
@@ -2893,7 +2893,7 @@ httpcore = "==1.*"
|
||||
idna = "*"
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
|
||||
brotli = ["brotli", "brotlicffi"]
|
||||
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
@@ -3028,7 +3028,7 @@ zipp = ">=0.5"
|
||||
[package.extras]
|
||||
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
perf = ["ipython"]
|
||||
testing = ["flufl.flake8", "importlib-resources (>=1.3) ; python_version < \"3.9\"", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy ; platform_python_implementation != \"PyPy\"", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
|
||||
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
@@ -3102,7 +3102,7 @@ traitlets = ">=5.13.0"
|
||||
[package.extras]
|
||||
all = ["ipython[black,doc,kernel,matplotlib,nbconvert,nbformat,notebook,parallel,qtconsole]", "ipython[test,test-extra]"]
|
||||
black = ["black"]
|
||||
doc = ["docrepr", "exceptiongroup", "intersphinx_registry", "ipykernel", "ipython[test]", "matplotlib", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "sphinxcontrib-jquery", "tomli ; python_version < \"3.11\"", "typing_extensions"]
|
||||
doc = ["docrepr", "exceptiongroup", "intersphinx_registry", "ipykernel", "ipython[test]", "matplotlib", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "sphinxcontrib-jquery", "tomli", "typing_extensions"]
|
||||
kernel = ["ipykernel"]
|
||||
matplotlib = ["matplotlib"]
|
||||
nbconvert = ["nbconvert"]
|
||||
@@ -3415,7 +3415,7 @@ traitlets = ">=5.3"
|
||||
|
||||
[package.extras]
|
||||
docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"]
|
||||
test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko ; sys_platform == \"win32\"", "pre-commit", "pytest (<8.2.0)", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"]
|
||||
test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest (<8.2.0)", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"]
|
||||
|
||||
[[package]]
|
||||
name = "jupyter-core"
|
||||
@@ -4415,7 +4415,7 @@ files = [
|
||||
[package.extras]
|
||||
develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"]
|
||||
docs = ["sphinx"]
|
||||
gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""]
|
||||
gmpy = ["gmpy2 (>=2.1.0a4)"]
|
||||
tests = ["pytest (>=4.6)"]
|
||||
|
||||
[[package]]
|
||||
@@ -4813,7 +4813,7 @@ tornado = ">=6.2.0"
|
||||
[package.extras]
|
||||
dev = ["hatch", "pre-commit"]
|
||||
docs = ["myst-parser", "nbsphinx", "pydata-sphinx-theme", "sphinx (>=1.3.6)", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"]
|
||||
test = ["importlib-resources (>=5.0) ; python_version < \"3.10\"", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.27.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"]
|
||||
test = ["importlib-resources (>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.27.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"]
|
||||
|
||||
[[package]]
|
||||
name = "notebook-shim"
|
||||
@@ -5323,7 +5323,7 @@ docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline
|
||||
fpx = ["olefile"]
|
||||
mic = ["olefile"]
|
||||
tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout", "trove-classifiers (>=2024.10.12)"]
|
||||
typing = ["typing-extensions ; python_version < \"3.10\""]
|
||||
typing = ["typing-extensions"]
|
||||
xmp = ["defusedxml"]
|
||||
|
||||
[[package]]
|
||||
@@ -5805,7 +5805,7 @@ typing-extensions = ">=4.12.2"
|
||||
|
||||
[package.extras]
|
||||
email = ["email-validator (>=2.0.0)"]
|
||||
timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""]
|
||||
timezone = ["tzdata"]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic-core"
|
||||
@@ -5938,7 +5938,7 @@ numpy = ">=1.16.4"
|
||||
|
||||
[package.extras]
|
||||
carto = ["pydeck-carto"]
|
||||
jupyter = ["ipykernel (>=5.1.2) ; python_version >= \"3.4\"", "ipython (>=5.8.0) ; python_version < \"3.4\"", "ipywidgets (>=7,<8)", "traitlets (>=4.3.2)"]
|
||||
jupyter = ["ipykernel (>=5.1.2)", "ipython (>=5.8.0)", "ipywidgets (>=7,<8)", "traitlets (>=4.3.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyee"
|
||||
@@ -5956,7 +5956,7 @@ files = [
|
||||
typing-extensions = "*"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "flake8", "flake8-black", "isort", "jupyter-console", "mkdocs", "mkdocs-include-markdown-plugin", "mkdocstrings[python]", "pytest", "pytest-asyncio ; python_version >= \"3.4\"", "pytest-trio ; python_version >= \"3.7\"", "toml", "tox", "trio", "trio ; python_version > \"3.6\"", "trio-typing ; python_version > \"3.6\"", "twine", "twisted", "validate-pyproject[all]"]
|
||||
dev = ["black", "flake8", "flake8-black", "isort", "jupyter-console", "mkdocs", "mkdocs-include-markdown-plugin", "mkdocstrings[python]", "pytest", "pytest-asyncio", "pytest-trio", "toml", "tox", "trio", "trio", "trio-typing", "twine", "twisted", "validate-pyproject[all]"]
|
||||
|
||||
[[package]]
|
||||
name = "pyflakes"
|
||||
@@ -6338,7 +6338,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["backports.zoneinfo ; python_version < \"3.9\"", "black", "build", "freezegun", "mdx_truly_sane_lists", "mike", "mkdocs", "mkdocs-awesome-pages-plugin", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-material (>=8.5)", "mkdocstrings[python]", "msgspec ; implementation_name != \"pypy\"", "mypy", "orjson ; implementation_name != \"pypy\"", "pylint", "pytest", "tzdata", "validate-pyproject[all]"]
|
||||
dev = ["backports.zoneinfo", "black", "build", "freezegun", "mdx_truly_sane_lists", "mike", "mkdocs", "mkdocs-awesome-pages-plugin", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-material (>=8.5)", "mkdocstrings[python]", "msgspec", "mypy", "orjson", "pylint", "pytest", "tzdata", "validate-pyproject[all]"]
|
||||
|
||||
[[package]]
|
||||
name = "python-levenshtein"
|
||||
@@ -7370,7 +7370,7 @@ tifffile = ">=2022.8.12"
|
||||
[package.extras]
|
||||
build = ["Cython (>=3.0.8)", "build (>=1.2.1)", "meson-python (>=0.16)", "ninja (>=1.11.1.1)", "numpy (>=2.0)", "pythran (>=0.16)", "spin (==0.13)"]
|
||||
data = ["pooch (>=1.6.0)"]
|
||||
developer = ["ipython", "pre-commit", "tomli ; python_version < \"3.11\""]
|
||||
developer = ["ipython", "pre-commit", "tomli"]
|
||||
docs = ["PyWavelets (>=1.6)", "dask[array] (>=2023.2.0)", "intersphinx-registry (>=0.2411.14)", "ipykernel", "ipywidgets", "kaleido (==0.2.1)", "matplotlib (>=3.7)", "myst-parser", "numpydoc (>=1.7)", "pandas (>=2.0)", "plotly (>=5.20)", "pooch (>=1.6)", "pydata-sphinx-theme (>=0.16)", "pytest-doctestplus", "scikit-learn (>=1.2)", "seaborn (>=0.11)", "sphinx (>=8.0)", "sphinx-copybutton", "sphinx-gallery[parallel] (>=0.18)", "sphinx_design (>=0.5)", "tifffile (>=2022.8.12)"]
|
||||
optional = ["PyWavelets (>=1.6)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=1.1.1)", "dask[array] (>=2023.2.0)", "matplotlib (>=3.7)", "pooch (>=1.6.0)", "pyamg (>=5.2)", "scikit-learn (>=1.2)"]
|
||||
test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=8)", "pytest-cov (>=2.11.0)", "pytest-doctestplus", "pytest-faulthandler", "pytest-localserver"]
|
||||
@@ -7437,7 +7437,7 @@ numpy = ">=1.23.5,<2.5"
|
||||
[package.extras]
|
||||
dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"]
|
||||
doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.16.5)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.0.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)"]
|
||||
test = ["Cython", "array-api-strict (>=2.0,<2.1.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja ; sys_platform != \"emscripten\"", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
||||
test = ["Cython", "array-api-strict (>=2.0,<2.1.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
||||
|
||||
[[package]]
|
||||
name = "seaborn"
|
||||
@@ -7474,9 +7474,9 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
nativelib = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\"", "pywin32 ; sys_platform == \"win32\""]
|
||||
objc = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\""]
|
||||
win32 = ["pywin32 ; sys_platform == \"win32\""]
|
||||
nativelib = ["pyobjc-framework-Cocoa", "pywin32"]
|
||||
objc = ["pyobjc-framework-Cocoa"]
|
||||
win32 = ["pywin32"]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
@@ -7491,13 +7491,13 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
|
||||
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
|
||||
core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
|
||||
cover = ["pytest-cov"]
|
||||
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
|
||||
enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
|
||||
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
|
||||
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
|
||||
type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "shapely"
|
||||
@@ -7747,7 +7747,7 @@ typing-extensions = ">=4.4.0,<5"
|
||||
watchdog = {version = ">=2.1.5,<7", markers = "platform_system != \"Darwin\""}
|
||||
|
||||
[package.extras]
|
||||
snowflake = ["snowflake-connector-python (>=3.3.0) ; python_version < \"3.12\"", "snowflake-snowpark-python[modin] (>=1.17.0) ; python_version < \"3.12\""]
|
||||
snowflake = ["snowflake-connector-python (>=3.3.0)", "snowflake-snowpark-python[modin] (>=1.17.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "strenum"
|
||||
@@ -8605,7 +8605,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""]
|
||||
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
|
||||
h2 = ["h2 (>=4,<5)"]
|
||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
@@ -8627,7 +8627,7 @@ click = ">=7.0"
|
||||
h11 = ">=0.8"
|
||||
|
||||
[package.extras]
|
||||
standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
||||
standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
@@ -8648,7 +8648,7 @@ platformdirs = ">=3.9.1,<5"
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
|
||||
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""]
|
||||
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]
|
||||
|
||||
[[package]]
|
||||
name = "watchdog"
|
||||
@@ -9232,11 +9232,11 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
|
||||
cover = ["pytest-cov"]
|
||||
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"]
|
||||
test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"]
|
||||
type = ["pytest-mypy"]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -100,7 +100,6 @@ reportlab = "*"
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@@ -130,7 +129,6 @@ ignore = ["D1"]
|
||||
convention = "google"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
@@ -993,7 +993,6 @@ async def test_first_user_message_with_identical_content():
|
||||
"""
|
||||
Test that _first_user_message correctly identifies the first user message
|
||||
even when multiple messages have identical content but different IDs.
|
||||
Also verifies that the result is properly cached.
|
||||
|
||||
The issue we're checking is that the comparison (action == self._first_user_message())
|
||||
should correctly differentiate between messages with the same content but different IDs.
|
||||
@@ -1039,20 +1038,6 @@ async def test_first_user_message_with_identical_content():
|
||||
second_message.id != first_user_message.id
|
||||
) # This should be False, but may be True if there's a bug
|
||||
|
||||
# Verify caching behavior
|
||||
assert (
|
||||
controller._cached_first_user_message is not None
|
||||
) # Cache should be populated
|
||||
assert (
|
||||
controller._cached_first_user_message is first_user_message
|
||||
) # Cache should store the same object
|
||||
|
||||
# Mock get_events to verify it's not called again
|
||||
with patch.object(event_stream, 'get_events') as mock_get_events:
|
||||
cached_message = controller._first_user_message()
|
||||
assert cached_message is first_user_message # Should return cached object
|
||||
mock_get_events.assert_not_called() # Should not call get_events again
|
||||
|
||||
await controller.close()
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from types import MappingProxyType
|
||||
import pytest
|
||||
from pydantic import SecretStr, ValidationError
|
||||
|
||||
from openhands.events.action.commands import CmdRunAction
|
||||
from openhands.integrations.provider import (
|
||||
ProviderHandler,
|
||||
ProviderToken,
|
||||
@@ -228,146 +227,3 @@ def test_token_conversion():
|
||||
)
|
||||
|
||||
assert len(store6.provider_tokens.keys()) == 0
|
||||
|
||||
|
||||
def test_provider_handler_type_enforcement():
|
||||
with pytest.raises((TypeError)):
|
||||
ProviderHandler(provider_tokens={'a': 'b'})
|
||||
|
||||
|
||||
def test_expose_env_vars():
|
||||
"""Test that expose_env_vars correctly exposes secrets as strings"""
|
||||
tokens = MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('test_token')),
|
||||
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab_token')),
|
||||
}
|
||||
)
|
||||
handler = ProviderHandler(provider_tokens=tokens)
|
||||
|
||||
# Test with specific provider tokens
|
||||
env_secrets = {
|
||||
ProviderType.GITHUB: SecretStr('gh_token'),
|
||||
ProviderType.GITLAB: SecretStr('gl_token'),
|
||||
}
|
||||
exposed = handler.expose_env_vars(env_secrets)
|
||||
|
||||
assert exposed['github_token'] == 'gh_token'
|
||||
assert exposed['gitlab_token'] == 'gl_token'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_env_vars():
|
||||
"""Test get_env_vars with different configurations"""
|
||||
tokens = MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('test_token')),
|
||||
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab_token')),
|
||||
}
|
||||
)
|
||||
handler = ProviderHandler(provider_tokens=tokens)
|
||||
|
||||
# Test getting all tokens unexposed
|
||||
env_vars = await handler.get_env_vars(expose_secrets=False)
|
||||
assert isinstance(env_vars, dict)
|
||||
assert isinstance(env_vars[ProviderType.GITHUB], SecretStr)
|
||||
assert env_vars[ProviderType.GITHUB].get_secret_value() == 'test_token'
|
||||
assert env_vars[ProviderType.GITLAB].get_secret_value() == 'gitlab_token'
|
||||
|
||||
# Test getting specific providers
|
||||
env_vars = await handler.get_env_vars(
|
||||
expose_secrets=False, providers=[ProviderType.GITHUB]
|
||||
)
|
||||
assert len(env_vars) == 1
|
||||
assert ProviderType.GITHUB in env_vars
|
||||
assert ProviderType.GITLAB not in env_vars
|
||||
|
||||
# Test exposed secrets
|
||||
exposed_vars = await handler.get_env_vars(expose_secrets=True)
|
||||
assert isinstance(exposed_vars, dict)
|
||||
assert exposed_vars['github_token'] == 'test_token'
|
||||
assert exposed_vars['gitlab_token'] == 'gitlab_token'
|
||||
|
||||
# Test empty tokens
|
||||
empty_handler = ProviderHandler(provider_tokens=MappingProxyType({}))
|
||||
empty_vars = await empty_handler.get_env_vars()
|
||||
assert empty_vars == {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_stream():
|
||||
"""Fixture for event stream testing"""
|
||||
|
||||
class TestEventStream:
|
||||
def __init__(self):
|
||||
self.secrets = {}
|
||||
|
||||
def set_secrets(self, secrets):
|
||||
self.secrets = secrets
|
||||
|
||||
return TestEventStream()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_event_stream_secrets(event_stream):
|
||||
"""Test setting secrets in event stream"""
|
||||
tokens = MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('test_token')),
|
||||
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab_token')),
|
||||
}
|
||||
)
|
||||
handler = ProviderHandler(provider_tokens=tokens)
|
||||
|
||||
# Test with provided env_vars
|
||||
env_vars = {
|
||||
ProviderType.GITHUB: SecretStr('new_token'),
|
||||
ProviderType.GITLAB: SecretStr('new_gitlab_token'),
|
||||
}
|
||||
await handler.set_event_stream_secrets(event_stream, env_vars)
|
||||
assert event_stream.secrets == {
|
||||
'github_token': 'new_token',
|
||||
'gitlab_token': 'new_gitlab_token',
|
||||
}
|
||||
|
||||
# Test without env_vars (using existing tokens)
|
||||
await handler.set_event_stream_secrets(event_stream)
|
||||
assert event_stream.secrets == {
|
||||
'github_token': 'test_token',
|
||||
'gitlab_token': 'gitlab_token',
|
||||
}
|
||||
|
||||
|
||||
def test_check_cmd_action_for_provider_token_ref():
|
||||
"""Test detection of provider tokens in command actions"""
|
||||
|
||||
# Test command with GitHub token
|
||||
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
|
||||
providers = ProviderHandler.check_cmd_action_for_provider_token_ref(cmd)
|
||||
assert ProviderType.GITHUB in providers
|
||||
assert len(providers) == 1
|
||||
|
||||
# Test command with multiple tokens
|
||||
cmd = CmdRunAction(command='echo $GITHUB_TOKEN && echo $GITLAB_TOKEN')
|
||||
providers = ProviderHandler.check_cmd_action_for_provider_token_ref(cmd)
|
||||
assert ProviderType.GITHUB in providers
|
||||
assert ProviderType.GITLAB in providers
|
||||
assert len(providers) == 2
|
||||
|
||||
# Test command without tokens
|
||||
cmd = CmdRunAction(command='echo "Hello"')
|
||||
providers = ProviderHandler.check_cmd_action_for_provider_token_ref(cmd)
|
||||
assert len(providers) == 0
|
||||
|
||||
# Test non-command action
|
||||
from openhands.events.action import MessageAction
|
||||
|
||||
msg = MessageAction(content='test')
|
||||
providers = ProviderHandler.check_cmd_action_for_provider_token_ref(msg)
|
||||
assert len(providers) == 0
|
||||
|
||||
|
||||
def test_get_provider_env_key():
|
||||
"""Test provider environment key generation"""
|
||||
assert ProviderHandler.get_provider_env_key(ProviderType.GITHUB) == 'github_token'
|
||||
assert ProviderHandler.get_provider_env_key(ProviderType.GITLAB) == 'gitlab_token'
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
from types import MappingProxyType
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.action.commands import CmdRunAction
|
||||
from openhands.events.observation import NullObservation, Observation
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
class TestRuntime(Runtime):
|
||||
"""A concrete implementation of Runtime for testing"""
|
||||
|
||||
async def connect(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def browse(self, action):
|
||||
return NullObservation()
|
||||
|
||||
def browse_interactive(self, action):
|
||||
return NullObservation()
|
||||
|
||||
def run(self, action):
|
||||
return NullObservation()
|
||||
|
||||
def run_ipython(self, action):
|
||||
return NullObservation()
|
||||
|
||||
def read(self, action):
|
||||
return NullObservation()
|
||||
|
||||
def write(self, action):
|
||||
return NullObservation()
|
||||
|
||||
def copy_from(self, path):
|
||||
return ''
|
||||
|
||||
def copy_to(self, path, content):
|
||||
pass
|
||||
|
||||
def list_files(self, path):
|
||||
return []
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
return NullObservation()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
|
||||
return str(tmp_path_factory.mktemp('test_event_stream'))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runtime(temp_dir):
|
||||
"""Fixture for runtime testing"""
|
||||
config = AppConfig()
|
||||
git_provider_tokens = MappingProxyType(
|
||||
{ProviderType.GITHUB: ProviderToken(token=SecretStr('test_token'))}
|
||||
)
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
runtime = TestRuntime(
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid='test',
|
||||
user_id='test_user',
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
)
|
||||
return runtime
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_latest_git_provider_tokens_no_user_id(temp_dir):
|
||||
"""Test that no token export happens when user_id is not set"""
|
||||
config = AppConfig()
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
runtime = TestRuntime(config=config, event_stream=event_stream, sid='test')
|
||||
|
||||
# Create a command that would normally trigger token export
|
||||
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
|
||||
|
||||
# This should not raise any errors and should return None
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
|
||||
# Verify no secrets were set
|
||||
assert not event_stream.secrets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_latest_git_provider_tokens_no_token_ref(temp_dir):
|
||||
"""Test that no token export happens when command doesn't reference tokens"""
|
||||
config = AppConfig()
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
runtime = TestRuntime(
|
||||
config=config, event_stream=event_stream, sid='test', user_id='test_user'
|
||||
)
|
||||
|
||||
# Create a command that doesn't reference any tokens
|
||||
cmd = CmdRunAction(command='echo "hello"')
|
||||
|
||||
# This should not raise any errors and should return None
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
|
||||
# Verify no secrets were set
|
||||
assert not event_stream.secrets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_latest_git_provider_tokens_success(runtime):
|
||||
"""Test successful token export when command references tokens"""
|
||||
# Create a command that references the GitHub token
|
||||
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
|
||||
|
||||
# Export the tokens
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
|
||||
# Verify that the token was exported to the event stream
|
||||
assert runtime.event_stream.secrets == {'github_token': 'test_token'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
|
||||
"""Test token export with multiple token references"""
|
||||
config = AppConfig()
|
||||
# Initialize with both GitHub and GitLab tokens
|
||||
git_provider_tokens = MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(token=SecretStr('github_token')),
|
||||
ProviderType.GITLAB: ProviderToken(token=SecretStr('gitlab_token')),
|
||||
}
|
||||
)
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
runtime = TestRuntime(
|
||||
config=config,
|
||||
event_stream=event_stream,
|
||||
sid='test',
|
||||
user_id='test_user',
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
)
|
||||
|
||||
# Create a command that references multiple tokens
|
||||
cmd = CmdRunAction(command='echo $GITHUB_TOKEN && echo $GITLAB_TOKEN')
|
||||
|
||||
# Export the tokens
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
|
||||
# Verify that both tokens were exported
|
||||
assert event_stream.secrets == {
|
||||
'github_token': 'github_token',
|
||||
'gitlab_token': 'gitlab_token',
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_latest_git_provider_tokens_token_update(runtime):
|
||||
"""Test that token updates are handled correctly"""
|
||||
# First export with initial token
|
||||
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
|
||||
# Update the token
|
||||
new_token = 'new_test_token'
|
||||
runtime.provider_handler._provider_tokens = MappingProxyType(
|
||||
{ProviderType.GITHUB: ProviderToken(token=SecretStr(new_token))}
|
||||
)
|
||||
|
||||
# Export again with updated token
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
|
||||
# Verify that the new token was exported
|
||||
assert runtime.event_stream.secrets == {'github_token': new_token}
|
||||
@@ -44,8 +44,6 @@ async def test_init_new_local_session():
|
||||
sio = get_mock_sio()
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = set()
|
||||
is_agent_loop_running_mock = AsyncMock()
|
||||
is_agent_loop_running_mock.return_value = True
|
||||
with (
|
||||
patch(
|
||||
'openhands.server.conversation_manager.standalone_conversation_manager.Session',
|
||||
@@ -62,19 +60,9 @@ async def test_init_new_local_session():
|
||||
await conversation_manager.maybe_start_agent_loop(
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager.is_agent_loop_running',
|
||||
is_agent_loop_running_mock,
|
||||
),
|
||||
):
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
1,
|
||||
'12345',
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1, '12345'
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 1
|
||||
|
||||
@@ -88,8 +76,6 @@ async def test_join_local_session():
|
||||
sio = get_mock_sio()
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = set()
|
||||
is_agent_loop_running_mock = AsyncMock()
|
||||
is_agent_loop_running_mock.return_value = True
|
||||
with (
|
||||
patch(
|
||||
'openhands.server.conversation_manager.standalone_conversation_manager.Session',
|
||||
@@ -106,26 +92,20 @@ async def test_join_local_session():
|
||||
await conversation_manager.maybe_start_agent_loop(
|
||||
'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager.is_agent_loop_running',
|
||||
is_agent_loop_running_mock,
|
||||
),
|
||||
):
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
None,
|
||||
'12345',
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
None,
|
||||
'12345',
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
None,
|
||||
'12345',
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
None,
|
||||
'12345',
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user