diff --git a/enterprise/migrations/versions/081_add_parent_conversation_id.py b/enterprise/migrations/versions/081_add_parent_conversation_id.py new file mode 100644 index 0000000000..b27c444632 --- /dev/null +++ b/enterprise/migrations/versions/081_add_parent_conversation_id.py @@ -0,0 +1,41 @@ +"""add parent_conversation_id to conversation_metadata + +Revision ID: 081 +Revises: 080 +Create Date: 2025-11-06 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '081' +down_revision: Union[str, None] = '080' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column( + 'conversation_metadata', + sa.Column('parent_conversation_id', sa.String(), nullable=True), + ) + op.create_index( + op.f('ix_conversation_metadata_parent_conversation_id'), + 'conversation_metadata', + ['parent_conversation_id'], + unique=False, + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_index( + op.f('ix_conversation_metadata_parent_conversation_id'), + table_name='conversation_metadata', + ) + op.drop_column('conversation_metadata', 'parent_conversation_id') diff --git a/enterprise/storage/saas_conversation_store.py b/enterprise/storage/saas_conversation_store.py index 916c4009ad..d1dfc2a2d6 100644 --- a/enterprise/storage/saas_conversation_store.py +++ b/enterprise/storage/saas_conversation_store.py @@ -82,6 +82,7 @@ class SaasConversationStore(ConversationStore): kwargs.pop('reasoning_tokens', None) kwargs.pop('context_window', None) kwargs.pop('per_turn_token', None) + kwargs.pop('parent_conversation_id', None) return ConversationMetadata(**kwargs) diff --git a/frontend/__tests__/posthog-tracking.test.tsx b/frontend/__tests__/posthog-tracking.test.tsx new file mode 100644 index 0000000000..5d76649013 --- /dev/null +++ 
b/frontend/__tests__/posthog-tracking.test.tsx @@ -0,0 +1,233 @@ +import { + describe, + it, + expect, + beforeAll, + afterAll, + afterEach, + vi, +} from "vitest"; +import { screen, waitFor, render, cleanup } from "@testing-library/react"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { createMockAgentErrorEvent } from "#/mocks/mock-ws-helpers"; +import { ConversationWebSocketProvider } from "#/contexts/conversation-websocket-context"; +import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup"; +import { ConnectionStatusComponent } from "./helpers/websocket-test-components"; + +// Mock the tracking function +const mockTrackCreditLimitReached = vi.fn(); + +// Mock useTracking hook +vi.mock("#/hooks/use-tracking", () => ({ + useTracking: () => ({ + trackCreditLimitReached: mockTrackCreditLimitReached, + trackLoginButtonClick: vi.fn(), + trackConversationCreated: vi.fn(), + trackPushButtonClick: vi.fn(), + trackPullButtonClick: vi.fn(), + trackCreatePrButtonClick: vi.fn(), + trackGitProviderConnected: vi.fn(), + trackUserSignupCompleted: vi.fn(), + trackCreditsPurchased: vi.fn(), + }), +})); + +// Mock useActiveConversation hook +vi.mock("#/hooks/query/use-active-conversation", () => ({ + useActiveConversation: () => ({ + data: null, + isLoading: false, + error: null, + }), +})); + +// MSW WebSocket mock setup +const { wsLink, server: mswServer } = conversationWebSocketTestSetup(); + +beforeAll(() => { + // The global MSW server from vitest.setup.ts is already running + // We just need to start our WebSocket-specific server + mswServer.listen({ onUnhandledRequest: "bypass" }); +}); + +afterEach(() => { + // Clear all mocks before each test + mockTrackCreditLimitReached.mockClear(); + mswServer.resetHandlers(); + // Clean up any React components + cleanup(); +}); + +afterAll(async () => { + // Close the WebSocket MSW server + mswServer.close(); + + // Give time for any pending WebSocket connections to close. 
This is very important to prevent serious memory leaks + await new Promise((resolve) => { + setTimeout(resolve, 500); + }); +}); + +// Helper function to render components with all necessary providers +function renderWithProviders( + children: React.ReactNode, + conversationId = "test-conversation-123", + conversationUrl = "http://localhost:3000/api/conversations/test-conversation-123", +) { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, + }); + + return render( + + + {children} + + , + ); +} + +describe("PostHog Analytics Tracking", () => { + describe("Credit Limit Tracking", () => { + it("should track credit_limit_reached when AgentErrorEvent contains budget error", async () => { + // Create a mock AgentErrorEvent with budget-related error message + const mockBudgetErrorEvent = createMockAgentErrorEvent({ + error: "ExceededBudget: Task exceeded maximum budget of $10.00", + }); + + // Set up MSW to send the budget error event when connection is established + mswServer.use( + wsLink.addEventListener("connection", ({ client, server }) => { + server.connect(); + // Send the mock budget error event after connection + client.send(JSON.stringify(mockBudgetErrorEvent)); + }), + ); + + // Render with all providers + renderWithProviders(); + + // Wait for connection to be established + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + // Wait for the tracking event to be captured + await waitFor(() => { + expect(mockTrackCreditLimitReached).toHaveBeenCalledWith( + expect.objectContaining({ + conversationId: "test-conversation-123", + }), + ); + }); + }); + + it("should track credit_limit_reached when AgentErrorEvent contains 'credit' keyword", async () => { + // Create error with "credit" keyword (case-insensitive) + const mockCreditErrorEvent = createMockAgentErrorEvent({ + error: "Insufficient CREDIT to complete this operation", + 
}); + + mswServer.use( + wsLink.addEventListener("connection", ({ client, server }) => { + server.connect(); + client.send(JSON.stringify(mockCreditErrorEvent)); + }), + ); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + await waitFor(() => { + expect(mockTrackCreditLimitReached).toHaveBeenCalledWith( + expect.objectContaining({ + conversationId: "test-conversation-123", + }), + ); + }); + }); + + it("should NOT track credit_limit_reached for non-budget errors", async () => { + // Create a regular error without budget/credit keywords + const mockRegularErrorEvent = createMockAgentErrorEvent({ + error: "Failed to execute command: Permission denied", + }); + + mswServer.use( + wsLink.addEventListener("connection", ({ client, server }) => { + server.connect(); + client.send(JSON.stringify(mockRegularErrorEvent)); + }), + ); + + renderWithProviders(); + + // Wait for connection and error to be processed + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + // Verify that credit_limit_reached was NOT tracked + expect(mockTrackCreditLimitReached).not.toHaveBeenCalled(); + }); + + it("should only track credit_limit_reached once per error event", async () => { + const mockBudgetErrorEvent = createMockAgentErrorEvent({ + error: "Budget exceeded: $10.00 limit reached", + }); + + mswServer.use( + wsLink.addEventListener("connection", ({ client, server }) => { + server.connect(); + // Send the same error event twice + client.send(JSON.stringify(mockBudgetErrorEvent)); + client.send( + JSON.stringify({ ...mockBudgetErrorEvent, id: "different-id" }), + ); + }), + ); + + renderWithProviders(); + + await waitFor(() => { + expect(screen.getByTestId("connection-state")).toHaveTextContent( + "OPEN", + ); + }); + + await waitFor(() => { + expect(mockTrackCreditLimitReached).toHaveBeenCalledTimes(2); + }); + + // Both calls 
should be for credit_limit_reached (once per event) + expect(mockTrackCreditLimitReached).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + conversationId: "test-conversation-123", + }), + ); + expect(mockTrackCreditLimitReached).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + conversationId: "test-conversation-123", + }), + ); + }); + }); +}); diff --git a/frontend/public/android-chrome-192x192.png b/frontend/public/android-chrome-192x192.png index 23f4f4fd24..31d5801adb 100644 Binary files a/frontend/public/android-chrome-192x192.png and b/frontend/public/android-chrome-192x192.png differ diff --git a/frontend/public/android-chrome-512x512.png b/frontend/public/android-chrome-512x512.png index 1fe76e4196..57e1544c5f 100644 Binary files a/frontend/public/android-chrome-512x512.png and b/frontend/public/android-chrome-512x512.png differ diff --git a/frontend/public/apple-touch-icon.png b/frontend/public/apple-touch-icon.png index d6146fed32..31d5801adb 100644 Binary files a/frontend/public/apple-touch-icon.png and b/frontend/public/apple-touch-icon.png differ diff --git a/frontend/public/favicon-16x16.png b/frontend/public/favicon-16x16.png index 5db772fa15..4f230c5981 100644 Binary files a/frontend/public/favicon-16x16.png and b/frontend/public/favicon-16x16.png differ diff --git a/frontend/public/favicon-32x32.png b/frontend/public/favicon-32x32.png index bb75b8b65f..1f874a817d 100644 Binary files a/frontend/public/favicon-32x32.png and b/frontend/public/favicon-32x32.png differ diff --git a/frontend/public/favicon.ico b/frontend/public/favicon.ico index 680e72b56f..502a6b37cd 100644 Binary files a/frontend/public/favicon.ico and b/frontend/public/favicon.ico differ diff --git a/frontend/public/safari-pinned-tab.svg b/frontend/public/safari-pinned-tab.svg index fb271c3449..daa0090f0f 100644 --- a/frontend/public/safari-pinned-tab.svg +++ b/frontend/public/safari-pinned-tab.svg @@ -1,32 +1,7 @@ - - safari-pinned-tab-svg - - - - - - - - - - - 
- - - - - - - - - - - - - - - + + + + + + diff --git a/frontend/src/api/conversation-service/v1-conversation-service.api.ts b/frontend/src/api/conversation-service/v1-conversation-service.api.ts index 717228c79f..5ca7daf09a 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.api.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.api.ts @@ -60,6 +60,8 @@ class V1ConversationService { selected_branch?: string, conversationInstructions?: string, trigger?: ConversationTrigger, + parent_conversation_id?: string, + agent_type?: "default" | "plan", ): Promise { const body: V1AppConversationStartRequest = { selected_repository: selectedRepository, @@ -67,6 +69,8 @@ class V1ConversationService { selected_branch, title: conversationInstructions, trigger, + parent_conversation_id: parent_conversation_id || null, + agent_type, }; // Add initial message if provided @@ -111,11 +115,11 @@ class V1ConversationService { * Search for start tasks (ongoing tasks that haven't completed yet) * Use this to find tasks that were started but the user navigated away * - * Note: Backend only supports filtering by limit. To filter by repository/trigger, + * Note: Backend supports filtering by limit and created_at__gte. To filter by repository/trigger, * filter the results client-side after fetching. 
* * @param limit Maximum number of tasks to return (max 100) - * @returns Array of start tasks + * @returns Array of start tasks from the last 20 minutes */ static async searchStartTasks( limit: number = 100, @@ -123,6 +127,10 @@ class V1ConversationService { const params = new URLSearchParams(); params.append("limit", limit.toString()); + // Only get tasks from the last 20 minutes + const twentyMinutesAgo = new Date(Date.now() - 20 * 60 * 1000); + params.append("created_at__gte", twentyMinutesAgo.toISOString()); + const { data } = await openHands.get( `/api/v1/app-conversations/start-tasks/search?${params.toString()}`, ); diff --git a/frontend/src/api/conversation-service/v1-conversation-service.types.ts b/frontend/src/api/conversation-service/v1-conversation-service.types.ts index b48ce5bd6b..3441448472 100644 --- a/frontend/src/api/conversation-service/v1-conversation-service.types.ts +++ b/frontend/src/api/conversation-service/v1-conversation-service.types.ts @@ -30,6 +30,8 @@ export interface V1AppConversationStartRequest { title?: string | null; trigger?: ConversationTrigger | null; pr_number?: number[]; + parent_conversation_id?: string | null; + agent_type?: "default" | "plan"; } export type V1AppConversationStartTaskStatus = diff --git a/frontend/src/api/open-hands.types.ts b/frontend/src/api/open-hands.types.ts index 9a30e46027..47d34fe567 100644 --- a/frontend/src/api/open-hands.types.ts +++ b/frontend/src/api/open-hands.types.ts @@ -77,6 +77,7 @@ export interface Conversation { session_api_key: string | null; pr_number?: number[] | null; conversation_version?: "V0" | "V1"; + sub_conversation_ids?: string[]; } export interface ResultSet { diff --git a/frontend/src/assets/branding/openhands-logo.svg b/frontend/src/assets/branding/openhands-logo.svg index a079e0aa51..3aa40d4404 100644 --- a/frontend/src/assets/branding/openhands-logo.svg +++ b/frontend/src/assets/branding/openhands-logo.svg @@ -1,16 +1,9 @@ - - - - - - - - - - - - - - - + + + + + + + + 
diff --git a/frontend/src/components/features/chat/change-agent-button.tsx b/frontend/src/components/features/chat/change-agent-button.tsx index 706f582b59..6d41f5cfc1 100644 --- a/frontend/src/components/features/chat/change-agent-button.tsx +++ b/frontend/src/components/features/chat/change-agent-button.tsx @@ -1,4 +1,4 @@ -import React, { useMemo, useEffect } from "react"; +import React, { useMemo, useEffect, useState } from "react"; import { useTranslation } from "react-i18next"; import { Typography } from "#/ui/typography"; import { I18nKey } from "#/i18n/declaration"; @@ -11,10 +11,12 @@ import { cn } from "#/utils/utils"; import { USE_PLANNING_AGENT } from "#/utils/feature-flags"; import { useAgentState } from "#/hooks/use-agent-state"; import { AgentState } from "#/types/agent-state"; +import { useActiveConversation } from "#/hooks/query/use-active-conversation"; +import { useCreateConversation } from "#/hooks/mutation/use-create-conversation"; +import { displaySuccessToast } from "#/utils/custom-toast-handlers"; export function ChangeAgentButton() { - const { t } = useTranslation(); - const [contextMenuOpen, setContextMenuOpen] = React.useState(false); + const [contextMenuOpen, setContextMenuOpen] = useState(false); const conversationMode = useConversationStore( (state) => state.conversationMode, @@ -28,8 +30,14 @@ export function ChangeAgentButton() { const { curAgentState } = useAgentState(); + const { t } = useTranslation(); + const isAgentRunning = curAgentState === AgentState.RUNNING; + const { data: conversation } = useActiveConversation(); + const { mutate: createConversation, isPending: isCreatingConversation } = + useCreateConversation(); + // Close context menu when agent starts running useEffect(() => { if (isAgentRunning && contextMenuOpen) { @@ -37,6 +45,75 @@ export function ChangeAgentButton() { } }, [isAgentRunning, contextMenuOpen]); + const handlePlanClick = ( + event: React.MouseEvent | KeyboardEvent, + ) => { + event.preventDefault(); + 
event.stopPropagation(); + + // Set conversation mode to "plan" immediately + setConversationMode("plan"); + + // Skip creation if a sub-conversation already exists or there is no active conversation + if ( + (conversation?.sub_conversation_ids && + conversation.sub_conversation_ids.length > 0) || + !conversation?.conversation_id + ) { + // Nothing to create when either condition is true + return; + } + + // Create a new sub-conversation if we have a current conversation ID + createConversation( + { + parentConversationId: conversation.conversation_id, + agentType: "plan", + }, + { + onSuccess: () => + displaySuccessToast( + t(I18nKey.PLANNING_AGENTT$PLANNING_AGENT_INITIALIZED), + ), + }, + ); + }; + + // Handle Shift + Tab keyboard shortcut to cycle through modes + useEffect(() => { + if (!shouldUsePlanningAgent || isAgentRunning) { + return undefined; + } + + const handleKeyDown = (event: KeyboardEvent) => { + // Check for Shift + Tab combination + if (event.shiftKey && event.key === "Tab") { + // Prevent default tab navigation behavior + event.preventDefault(); + event.stopPropagation(); + + // Cycle between modes: code -> plan -> code + const nextMode = conversationMode === "code" ? 
"plan" : "code"; + if (nextMode === "plan") { + handlePlanClick(event); + } else { + setConversationMode(nextMode); + } + } + }; + + document.addEventListener("keydown", handleKeyDown); + + return () => { + document.removeEventListener("keydown", handleKeyDown); + }; + }, [ + shouldUsePlanningAgent, + isAgentRunning, + conversationMode, + setConversationMode, + ]); + const handleButtonClick = (event: React.MouseEvent) => { event.preventDefault(); event.stopPropagation(); @@ -49,12 +126,6 @@ export function ChangeAgentButton() { setConversationMode("code"); }; - const handlePlanClick = (event: React.MouseEvent) => { - event.preventDefault(); - event.stopPropagation(); - setConversationMode("plan"); - }; - const isExecutionAgent = conversationMode === "code"; const buttonLabel = useMemo(() => { @@ -71,6 +142,8 @@ export function ChangeAgentButton() { return ; }, [isExecutionAgent]); + const isButtonDisabled = isAgentRunning || isCreatingConversation; + if (!shouldUsePlanningAgent) { return null; } @@ -80,11 +153,11 @@ export function ChangeAgentButton() {
+ + + {description} + +
+ ); +} diff --git a/frontend/src/components/features/conversation-panel/conversation-card/conversation-version-badge.tsx b/frontend/src/components/features/conversation-panel/conversation-card/conversation-version-badge.tsx index 371b5f98aa..9626368fb5 100644 --- a/frontend/src/components/features/conversation-panel/conversation-card/conversation-version-badge.tsx +++ b/frontend/src/components/features/conversation-panel/conversation-card/conversation-version-badge.tsx @@ -25,10 +25,7 @@ export function ConversationVersionBadge({ diff --git a/frontend/src/contexts/conversation-websocket-context.tsx b/frontend/src/contexts/conversation-websocket-context.tsx index a9f29fb426..bb56e1497e 100644 --- a/frontend/src/contexts/conversation-websocket-context.tsx +++ b/frontend/src/contexts/conversation-websocket-context.tsx @@ -27,8 +27,10 @@ import { } from "#/types/v1/type-guards"; import { handleActionEventCacheInvalidation } from "#/utils/cache-utils"; import { buildWebSocketUrl } from "#/utils/websocket-url"; +import { isBudgetOrCreditError } from "#/utils/error-handler"; import type { V1SendMessageRequest } from "#/api/conversation-service/v1-conversation-service.types"; import EventService from "#/api/event-service/event-service.api"; +import { useTracking } from "#/hooks/use-tracking"; // eslint-disable-next-line @typescript-eslint/naming-convention export type V1_WebSocketConnectionState = @@ -69,6 +71,7 @@ export function ConversationWebSocketProvider({ const { removeOptimisticUserMessage } = useOptimisticUserMessageStore(); const { setExecutionStatus } = useV1ConversationStateStore(); const { appendInput, appendOutput } = useCommandStore(); + const { trackCreditLimitReached } = useTracking(); // History loading state const [isLoadingHistory, setIsLoadingHistory] = useState(true); @@ -132,6 +135,13 @@ export function ConversationWebSocketProvider({ // Handle AgentErrorEvent specifically if (isAgentErrorEvent(event)) { setErrorMessage(event.error); + + // Track 
credit limit reached if the error is budget-related + if (isBudgetOrCreditError(event.error)) { + trackCreditLimitReached({ + conversationId: conversationId || "unknown", + }); + } } // Clear optimistic user message when a user message is confirmed diff --git a/frontend/src/hooks/mutation/use-create-conversation.ts b/frontend/src/hooks/mutation/use-create-conversation.ts index 4baba32802..a44d921e44 100644 --- a/frontend/src/hooks/mutation/use-create-conversation.ts +++ b/frontend/src/hooks/mutation/use-create-conversation.ts @@ -17,6 +17,8 @@ interface CreateConversationVariables { suggestedTask?: SuggestedTask; conversationInstructions?: string; createMicroagent?: CreateMicroagent; + parentConversationId?: string; + agentType?: "default" | "plan"; } // Response type that combines both V1 and legacy responses @@ -44,6 +46,8 @@ export const useCreateConversation = () => { suggestedTask, conversationInstructions, createMicroagent, + parentConversationId, + agentType, } = variables; const useV1 = USE_V1_CONVERSATION_API() && !createMicroagent; @@ -57,6 +61,8 @@ export const useCreateConversation = () => { repository?.branch, conversationInstructions, undefined, // trigger - will be set by backend + parentConversationId, + agentType, ); // Return a special task ID that the frontend will recognize diff --git a/frontend/src/hooks/query/use-sub-conversations.ts b/frontend/src/hooks/query/use-sub-conversations.ts new file mode 100644 index 0000000000..53e6c84d45 --- /dev/null +++ b/frontend/src/hooks/query/use-sub-conversations.ts @@ -0,0 +1,39 @@ +import { useQuery } from "@tanstack/react-query"; +import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api"; +import { V1AppConversation } from "#/api/conversation-service/v1-conversation-service.types"; + +const FIVE_MINUTES = 1000 * 60 * 5; +const FIFTEEN_MINUTES = 1000 * 60 * 15; + +/** + * React hook to fetch sub-conversations by their IDs + * + * @param subConversationIds Array of 
sub-conversation IDs to fetch + * @returns React Query result with sub-conversation data, loading, and error states + * + * @example + * ```tsx + * const { data: subConversations, isLoading, isError } = useSubConversations( + * conversation.sub_conversation_ids || [] + * ); + * ``` + */ +export const useSubConversations = ( + subConversationIds: string[] | null | undefined, +) => { + const ids = subConversationIds || []; + + return useQuery<(V1AppConversation | null)[]>({ + queryKey: ["v1", "sub-conversations", ids], + queryFn: async () => { + if (ids.length === 0) { + return []; + } + return V1ConversationService.batchGetAppConversations(ids); + }, + enabled: ids.length > 0, + staleTime: FIVE_MINUTES, + gcTime: FIFTEEN_MINUTES, + retry: false, + }); +}; diff --git a/frontend/src/hooks/use-sync-posthog-consent.ts b/frontend/src/hooks/use-sync-posthog-consent.ts new file mode 100644 index 0000000000..615aa9a1bf --- /dev/null +++ b/frontend/src/hooks/use-sync-posthog-consent.ts @@ -0,0 +1,41 @@ +import React from "react"; +import { usePostHog } from "posthog-js/react"; +import { handleCaptureConsent } from "#/utils/handle-capture-consent"; +import { useSettings } from "./query/use-settings"; + +/** + * Hook to sync PostHog opt-in/out state with backend setting on mount. + * This ensures that if the backend setting changes (e.g., via API or different client), + * the PostHog instance reflects the current user preference. 
+ */ +export const useSyncPostHogConsent = () => { + const posthog = usePostHog(); + const { data: settings } = useSettings(); + const hasSyncedRef = React.useRef(false); + + React.useEffect(() => { + // Only run once when both PostHog and settings are available + if (!posthog || settings === undefined || hasSyncedRef.current) { + return; + } + + const backendConsent = settings.USER_CONSENTS_TO_ANALYTICS; + + // Only sync if there's a backend preference set + if (backendConsent !== null) { + const posthogHasOptedIn = posthog.has_opted_in_capturing(); + const posthogHasOptedOut = posthog.has_opted_out_capturing(); + + // Check if PostHog state is out of sync with backend + const needsSync = + (backendConsent === true && !posthogHasOptedIn) || + (backendConsent === false && !posthogHasOptedOut); + + if (needsSync) { + handleCaptureConsent(posthog, backendConsent); + } + + hasSyncedRef.current = true; + } + }, [posthog, settings]); +}; diff --git a/frontend/src/hooks/use-tracking.ts b/frontend/src/hooks/use-tracking.ts index 4b7959c1dd..0dfc0f0705 100644 --- a/frontend/src/hooks/use-tracking.ts +++ b/frontend/src/hooks/use-tracking.ts @@ -67,6 +67,38 @@ export const useTracking = () => { }); }; + const trackUserSignupCompleted = () => { + posthog.capture("user_signup_completed", { + signup_timestamp: new Date().toISOString(), + ...commonProperties, + }); + }; + + const trackCreditsPurchased = ({ + amountUsd, + stripeSessionId, + }: { + amountUsd: number; + stripeSessionId: string; + }) => { + posthog.capture("credits_purchased", { + amount_usd: amountUsd, + stripe_session_id: stripeSessionId, + ...commonProperties, + }); + }; + + const trackCreditLimitReached = ({ + conversationId, + }: { + conversationId: string; + }) => { + posthog.capture("credit_limit_reached", { + conversation_id: conversationId, + ...commonProperties, + }); + }; + return { trackLoginButtonClick, trackConversationCreated, @@ -74,5 +106,8 @@ export const useTracking = () => { trackPullButtonClick, 
trackCreatePrButtonClick, trackGitProviderConnected, + trackUserSignupCompleted, + trackCreditsPurchased, + trackCreditLimitReached, }; }; diff --git a/frontend/src/i18n/declaration.ts b/frontend/src/i18n/declaration.ts index 7fabded8df..8eff691a29 100644 --- a/frontend/src/i18n/declaration.ts +++ b/frontend/src/i18n/declaration.ts @@ -944,4 +944,7 @@ export enum I18nKey { COMMON$ASK = "COMMON$ASK", COMMON$PLAN = "COMMON$PLAN", COMMON$LET_S_WORK_ON_A_PLAN = "COMMON$LET_S_WORK_ON_A_PLAN", + COMMON$CODE_AGENT_DESCRIPTION = "COMMON$CODE_AGENT_DESCRIPTION", + COMMON$PLAN_AGENT_DESCRIPTION = "COMMON$PLAN_AGENT_DESCRIPTION", + PLANNING_AGENTT$PLANNING_AGENT_INITIALIZED = "PLANNING_AGENTT$PLANNING_AGENT_INITIALIZED", } diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index 992bc69ca7..3dbec85c2f 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -15102,5 +15102,53 @@ "tr": "Bir plan üzerinde çalışalım", "de": "Lassen Sie uns an einem Plan arbeiten", "uk": "Давайте розробимо план" + }, + "COMMON$CODE_AGENT_DESCRIPTION": { + "en": "Write, edit, and debug with AI assistance in real time.", + "ja": "AIの支援をリアルタイムで受けながら、コードの作成、編集、デバッグを行いましょう。", + "zh-CN": "实时在 AI 协助下编写、编辑和调试。", + "zh-TW": "即時在 AI 協助下編寫、編輯和除錯。", + "ko-KR": "AI의 지원을 받아 실시간으로 작성, 편집 및 디버깅하세요.", + "no": "Skriv, rediger og feilsøk med AI-assistanse i sanntid.", + "it": "Scrivi, modifica e esegui il debug con assistenza AI in tempo reale.", + "pt": "Escreva, edite e depure com assistência de IA em tempo real.", + "es": "Escribe, edita y depura con ayuda de IA en tiempo real.", + "ar": "اكتب وعدّل وصحّح الأخطاء بمساعدة الذكاء الاصطناعي في الوقت الفعلي.", + "fr": "Rédigez, modifiez et déboguez avec l’aide de l’IA en temps réel.", + "tr": "AI desteğiyle gerçek zamanlı olarak yazın, düzenleyin ve hata ayıklayın.", + "de": "Schreiben, bearbeiten und debuggen Sie mit KI-Unterstützung in Echtzeit.", + "uk": "Пишіть, редагуйте та налагоджуйте з 
підтримкою ШІ у реальному часі." + }, + "COMMON$PLAN_AGENT_DESCRIPTION": { + "en": "Outline goals, structure tasks, and map your next steps.", + "ja": "目標を明確にし、タスクを構造化し、次のステップを計画しましょう。", + "zh-CN": "概述目标、结构化任务,并规划下一步。", + "zh-TW": "概述目標、結構化任務,並規劃下一步。", + "ko-KR": "목표를 개요하고, 작업을 구조화하며, 다음 단계를 구상하세요.", + "no": "Skisser mål, strukturer oppgaver og planlegg dine neste steg.", + "it": "Definisci gli obiettivi, struttura le attività e pianifica i prossimi passi.", + "pt": "Esboce objetivos, estruture tarefas e trace seus próximos passos.", + "es": "Define objetivos, estructura tareas y planifica tus próximos pasos.", + "ar": "حدد الأهداف، نظم المهام، وارسم خطواتك التالية.", + "fr": "Dressez des objectifs, structurez vos tâches et planifiez vos prochaines étapes.", + "tr": "Hedefleri belirtin, görevleri yapılandırın ve sonraki adımlarınızı belirleyin.", + "de": "Umreißen Sie Ziele, strukturieren Sie Aufgaben und planen Sie Ihre nächsten Schritte.", + "uk": "Окресліть цілі, структуруйте завдання та сплануйте наступні кроки." 
+ }, + "PLANNING_AGENTT$PLANNING_AGENT_INITIALIZED": { + "en": "Planning agent initialized", + "ja": "プランニングエージェントが初期化されました", + "zh-CN": "规划代理已初始化", + "zh-TW": "規劃代理已初始化", + "ko-KR": "계획 에이전트가 초기화되었습니다", + "no": "Planleggingsagent er initialisert", + "it": "Agente di pianificazione inizializzato", + "pt": "Agente de planejamento inicializado", + "es": "Agente de planificación inicializado", + "ar": "تم تهيئة وكيل التخطيط", + "fr": "Agent de planification initialisé", + "tr": "Planlama ajanı başlatıldı", + "de": "Planungsagent wurde initialisiert", + "uk": "Агент планування ініціалізовано" } } diff --git a/frontend/src/routes/accept-tos.tsx b/frontend/src/routes/accept-tos.tsx index 773f7ba2ee..f723f2a5f6 100644 --- a/frontend/src/routes/accept-tos.tsx +++ b/frontend/src/routes/accept-tos.tsx @@ -10,6 +10,7 @@ import { BrandButton } from "#/components/features/settings/brand-button"; import { handleCaptureConsent } from "#/utils/handle-capture-consent"; import { openHands } from "#/api/open-hands-axios"; import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop"; +import { useTracking } from "#/hooks/use-tracking"; export default function AcceptTOS() { const posthog = usePostHog(); @@ -17,6 +18,7 @@ export default function AcceptTOS() { const navigate = useNavigate(); const [searchParams] = useSearchParams(); const [isTosAccepted, setIsTosAccepted] = React.useState(false); + const { trackUserSignupCompleted } = useTracking(); // Get the redirect URL from the query parameters const redirectUrl = searchParams.get("redirect_url") || "/"; @@ -33,6 +35,9 @@ export default function AcceptTOS() { }); }, onSuccess: (response) => { + // Track user signup completion + trackUserSignupCompleted(); + // Get the redirect URL from the response const finalRedirectUrl = response.data.redirect_url || redirectUrl; diff --git a/frontend/src/routes/billing.tsx b/frontend/src/routes/billing.tsx index fdd410f6c4..c004d93dee 100644 --- a/frontend/src/routes/billing.tsx +++ 
b/frontend/src/routes/billing.tsx @@ -7,21 +7,35 @@ import { displaySuccessToast, } from "#/utils/custom-toast-handlers"; import { I18nKey } from "#/i18n/declaration"; +import { useTracking } from "#/hooks/use-tracking"; function BillingSettingsScreen() { const { t } = useTranslation(); const [searchParams, setSearchParams] = useSearchParams(); + const { trackCreditsPurchased } = useTracking(); const checkoutStatus = searchParams.get("checkout"); React.useEffect(() => { if (checkoutStatus === "success") { + // Get purchase details from URL params + const amount = searchParams.get("amount"); + const sessionId = searchParams.get("session_id"); + + // Track credits purchased if we have the necessary data + if (amount && sessionId) { + trackCreditsPurchased({ + amountUsd: parseFloat(amount), + stripeSessionId: sessionId, + }); + } + displaySuccessToast(t(I18nKey.PAYMENT$SUCCESS)); } else if (checkoutStatus === "cancel") { displayErrorToast(t(I18nKey.PAYMENT$CANCELLED)); } setSearchParams({}); - }, [checkoutStatus]); + }, [checkoutStatus, searchParams, setSearchParams, t, trackCreditsPurchased]); return ; } diff --git a/frontend/src/routes/root-layout.tsx b/frontend/src/routes/root-layout.tsx index 930451dae9..264ae541c8 100644 --- a/frontend/src/routes/root-layout.tsx +++ b/frontend/src/routes/root-layout.tsx @@ -25,6 +25,7 @@ import { useIsOnTosPage } from "#/hooks/use-is-on-tos-page"; import { useAutoLogin } from "#/hooks/use-auto-login"; import { useAuthCallback } from "#/hooks/use-auth-callback"; import { useReoTracking } from "#/hooks/use-reo-tracking"; +import { useSyncPostHogConsent } from "#/hooks/use-sync-posthog-consent"; import { LOCAL_STORAGE_KEYS } from "#/utils/local-storage"; import { EmailVerificationGuard } from "#/components/features/guards/email-verification-guard"; import { MaintenanceBanner } from "#/components/features/maintenance/maintenance-banner"; @@ -100,6 +101,9 @@ export default function MainApp() { // Initialize Reo.dev tracking in SaaS 
mode useReoTracking(); + // Sync PostHog opt-in/out state with backend setting on mount + useSyncPostHogConsent(); + React.useEffect(() => { // Don't change language when on TOS page if (!isOnTosPage && settings?.LANGUAGE) { diff --git a/frontend/src/utils/error-handler.ts b/frontend/src/utils/error-handler.ts index 385881e0ce..d479853b6d 100644 --- a/frontend/src/utils/error-handler.ts +++ b/frontend/src/utils/error-handler.ts @@ -50,3 +50,11 @@ export function showChatError({ status_update: true, }); } + +/** + * Checks if an error message indicates a budget or credit limit issue + */ +export function isBudgetOrCreditError(errorMessage: string): boolean { + const lowerCaseError = errorMessage.toLowerCase(); + return lowerCaseError.includes("budget") || lowerCaseError.includes("credit"); +} diff --git a/microagents/agent-builder.md b/microagents/agent-builder.md new file mode 100644 index 0000000000..a1b7931e9e --- /dev/null +++ b/microagents/agent-builder.md @@ -0,0 +1,39 @@ +--- +name: agent_sdk_builder +version: 1.0.0 +author: openhands +agent: CodeActAgent +triggers: + - /agent-builder +inputs: + - name: INITIAL_PROMPT + description: "Initial SDK requirements" +--- + +# Agent Builder and Interviewer Role + +You are an expert requirements gatherer and agent builder. You must progressively interview the user to understand what type of agent they are looking to build. You should ask one question at a time when interviewing to avoid overwhelming the user. + +Please refer to the user's initial prompt: {INITIAL_PROMPT} + +If {INITIAL_PROMPT} is blank, your first interview question should be: "Please provide a brief description of the type of agent you are looking to build." + +# Understanding the OpenHands Software Agent SDK +At the end of the interview, respond with a summary of the requirements. Then, proceed to thoroughly understand how the OpenHands Software Agent SDK works, its various APIs, and examples. 
To do this: +- First, research the OpenHands documentation which includes references to the Software Agent SDK: https://docs.openhands.dev/llms.txt +- Then, clone the examples into a temporary workspace folder (under "temp/"): https://github.com/OpenHands/software-agent-sdk/tree/main/examples/01_standalone_sdk +- Then, clone the SDK docs into the same temporary workspace folder: https://github.com/OpenHands/docs/tree/main/sdk + +After analyzing the OpenHands Agent SDK, you may optionally ask additional clarifying questions if they are important for the technical design of the agent. + +# Generating the SDK Plan +You can then proceed to build a technical implementation plan based on the user requirements and your understanding of how the OpenHands Agent SDK works. +- The plan should be stored in "plan/SDK_PLAN.md" from the root of the workspace. +- Also create a visual representation of how the agent should work based on the SDK_PLAN.md. This should look like a flow diagram with nodes and edges. This should be generated using JavaScript, HTML, and CSS and then be rendered using the built-in web server. Store this in the plan/ directory. + +# Implementing the Plan +After the plan is generated, please ask the user if they are ready to generate the SDK implementation. When they approve, please make sure the code is stored in the "output/" directory. Make sure the code provides logging that a user can see in the terminal. Ideally, the SDK is a single Python file. + +Additional guidelines: +- Users can configure their LLM API Key using an environment variable named "LLM_API_KEY" +- Unless otherwise specified, default to this model: openhands/claude-sonnet-4-20250514. This is configurable through the LLM_BASE_MODEL environment variable. 
class AgentType(Enum):
    """Kind of agent a conversation is started with.

    Members carry lowercase string values ('default', 'plan'), so a member
    can be looked up from its value, e.g. ``AgentType('plan')``.
    """

    # Regular agent; also the default used by the start request model.
    DEFAULT = 'default'
    # Planning agent.
    PLAN = 'plan'
--git a/openhands/app_server/app_conversation/app_conversation_router.py b/openhands/app_server/app_conversation/app_conversation_router.py index 83596b64a5..b66d998362 100644 --- a/openhands/app_server/app_conversation/app_conversation_router.py +++ b/openhands/app_server/app_conversation/app_conversation_router.py @@ -99,6 +99,12 @@ async def search_app_conversations( lte=100, ), ] = 100, + include_sub_conversations: Annotated[ + bool, + Query( + title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.' + ), + ] = False, app_conversation_service: AppConversationService = ( app_conversation_service_dependency ), @@ -114,6 +120,7 @@ async def search_app_conversations( updated_at__lt=updated_at__lt, page_id=page_id, limit=limit, + include_sub_conversations=include_sub_conversations, ) @@ -193,7 +200,8 @@ async def stream_app_conversation_start( user_context: UserContext = user_context_dependency, ) -> list[AppConversationStartTask]: """Start an app conversation start task and stream updates from it. 
- Leaves the connection open until either the conversation starts or there was an error""" + Leaves the connection open until either the conversation starts or there was an error + """ response = StreamingResponse( _stream_app_conversation_start(request, user_context), media_type='application/json', @@ -207,6 +215,10 @@ async def search_app_conversation_start_tasks( UUID | None, Query(title='Filter by conversation ID equal to this value'), ] = None, + created_at__gte: Annotated[ + datetime | None, + Query(title='Filter by created_at greater than or equal to this datetime'), + ] = None, sort_order: Annotated[ AppConversationStartTaskSortOrder, Query(title='Sort order for the results'), @@ -233,6 +245,7 @@ async def search_app_conversation_start_tasks( return ( await app_conversation_start_task_service.search_app_conversation_start_tasks( conversation_id__eq=conversation_id__eq, + created_at__gte=created_at__gte, sort_order=sort_order, page_id=page_id, limit=limit, @@ -246,6 +259,10 @@ async def count_app_conversation_start_tasks( UUID | None, Query(title='Filter by conversation ID equal to this value'), ] = None, + created_at__gte: Annotated[ + datetime | None, + Query(title='Filter by created_at greater than or equal to this datetime'), + ] = None, app_conversation_start_task_service: AppConversationStartTaskService = ( app_conversation_start_task_service_dependency ), @@ -253,6 +270,7 @@ async def count_app_conversation_start_tasks( """Count conversation start tasks matching the given filters.""" return await app_conversation_start_task_service.count_app_conversation_start_tasks( conversation_id__eq=conversation_id__eq, + created_at__gte=created_at__gte, ) diff --git a/openhands/app_server/app_conversation/app_conversation_service.py b/openhands/app_server/app_conversation/app_conversation_service.py index d910856c76..8c39a66ae5 100644 --- a/openhands/app_server/app_conversation/app_conversation_service.py +++ 
b/openhands/app_server/app_conversation/app_conversation_service.py @@ -30,6 +30,7 @@ class AppConversationService(ABC): sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC, page_id: str | None = None, limit: int = 100, + include_sub_conversations: bool = False, ) -> AppConversationPage: """Search for sandboxed conversations.""" diff --git a/openhands/app_server/app_conversation/app_conversation_start_task_service.py b/openhands/app_server/app_conversation/app_conversation_start_task_service.py index 05229411f5..230b26cd8f 100644 --- a/openhands/app_server/app_conversation/app_conversation_start_task_service.py +++ b/openhands/app_server/app_conversation/app_conversation_start_task_service.py @@ -1,5 +1,6 @@ import asyncio from abc import ABC, abstractmethod +from datetime import datetime from uuid import UUID from openhands.app_server.app_conversation.app_conversation_models import ( @@ -18,6 +19,7 @@ class AppConversationStartTaskService(ABC): async def search_app_conversation_start_tasks( self, conversation_id__eq: UUID | None = None, + created_at__gte: datetime | None = None, sort_order: AppConversationStartTaskSortOrder = AppConversationStartTaskSortOrder.CREATED_AT_DESC, page_id: str | None = None, limit: int = 100, @@ -28,6 +30,7 @@ class AppConversationStartTaskService(ABC): async def count_app_conversation_start_tasks( self, conversation_id__eq: UUID | None = None, + created_at__gte: datetime | None = None, ) -> int: """Count conversation start tasks.""" diff --git a/openhands/app_server/app_conversation/live_status_app_conversation_service.py b/openhands/app_server/app_conversation/live_status_app_conversation_service.py index cc10d254e7..b57cf3013e 100644 --- a/openhands/app_server/app_conversation/live_status_app_conversation_service.py +++ b/openhands/app_server/app_conversation/live_status_app_conversation_service.py @@ -21,6 +21,7 @@ from openhands.app_server.app_conversation.app_conversation_info_service import 
AppConversationInfoService, ) from openhands.app_server.app_conversation.app_conversation_models import ( + AgentType, AppConversation, AppConversationInfo, AppConversationPage, @@ -62,6 +63,9 @@ from openhands.app_server.sandbox.sandbox_spec_service import SandboxSpecService from openhands.app_server.services.injector import InjectorState from openhands.app_server.services.jwt_service import JwtService from openhands.app_server.user.user_context import UserContext +from openhands.app_server.utils.docker_utils import ( + replace_localhost_hostname_for_docker, +) from openhands.experiments.experiment_manager import ExperimentManagerImpl from openhands.integrations.provider import ProviderType from openhands.sdk import LocalWorkspace @@ -70,6 +74,7 @@ from openhands.sdk.llm import LLM from openhands.sdk.security.confirmation_policy import AlwaysConfirm from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace from openhands.tools.preset.default import get_default_agent +from openhands.tools.preset.planning import get_planning_agent _conversation_info_type_adapter = TypeAdapter(list[ConversationInfo | None]) _logger = logging.getLogger(__name__) @@ -103,6 +108,7 @@ class LiveStatusAppConversationService(GitAppConversationService): sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC, page_id: str | None = None, limit: int = 20, + include_sub_conversations: bool = False, ) -> AppConversationPage: """Search for sandboxed conversations.""" page = await self.app_conversation_info_service.search_app_conversation_info( @@ -114,6 +120,7 @@ class LiveStatusAppConversationService(GitAppConversationService): sort_order=sort_order, page_id=page_id, limit=limit, + include_sub_conversations=include_sub_conversations, ) conversations: list[AppConversation] = await self._build_app_conversations( page.items @@ -168,6 +175,20 @@ class LiveStatusAppConversationService(GitAppConversationService): ) -> 
def _inherit_configuration_from_parent(
    self, request: AppConversationStartRequest, parent_info: AppConversationInfo
) -> None:
    """Fill unset fields of a sub-conversation request from its parent.

    The request is modified in place. A field counts as "unset" when its
    current value is falsy (``None``, empty string, ...).

    Inherited fields:
    - ``sandbox_id`` (so the sub-conversation shares the parent's
      workspace/environment)
    - ``selected_repository``, ``selected_branch``, ``git_provider``
    - ``llm_model`` (only when the parent actually has one)

    Args:
        request: The conversation start request to update.
        parent_info: The parent conversation info to copy values from.
    """
    # These fields are copied whenever the request leaves them unset, even if
    # the parent's own value is empty.
    always_inherited = (
        'sandbox_id',
        'selected_repository',
        'selected_branch',
        'git_provider',
    )
    for field_name in always_inherited:
        if getattr(request, field_name):
            continue
        setattr(request, field_name, getattr(parent_info, field_name))

    # The model is only taken over when the parent has one to offer, so an
    # unset request value is never replaced by an empty parent value.
    if not request.llm_model:
        parent_model = parent_info.llm_model
        if parent_model:
            request.llm_model = parent_model
async def _get_sub_conversation_ids(
    self, parent_conversation_id: UUID
) -> list[UUID]:
    """Return the IDs of all conversations whose parent is the given one.

    The lookup starts from ``self._secure_select()`` and adds a filter on the
    stored ``parent_conversation_id`` column, which holds the parent's ID as
    a string.

    Args:
        parent_conversation_id: The ID of the parent conversation.

    Returns:
        One UUID per matching stored row (empty when there are none).
    """
    children_query = (await self._secure_select()).where(
        StoredConversationMetadata.parent_conversation_id
        == str(parent_conversation_id)
    )
    execution = await self.db_session.execute(children_query)

    child_ids: list[UUID] = []
    for stored in execution.scalars().all():
        # Stored IDs are strings; convert back to UUID for callers.
        child_ids.append(UUID(stored.conversation_id))
    return child_ids
sub_conversation_ids: list[UUID] | None = None, + ) -> AppConversationInfo: # V1 conversations should always have a sandbox_id sandbox_id = stored.sandbox_id assert sandbox_id is not None @@ -354,6 +399,12 @@ class SQLAppConversationInfoService(AppConversationInfoService): pr_number=stored.pr_number, llm_model=stored.llm_model, metrics=metrics, + parent_conversation_id=( + UUID(stored.parent_conversation_id) + if stored.parent_conversation_id + else None + ), + sub_conversation_ids=sub_conversation_ids or [], created_at=created_at, updated_at=updated_at, ) diff --git a/openhands/app_server/app_conversation/sql_app_conversation_start_task_service.py b/openhands/app_server/app_conversation/sql_app_conversation_start_task_service.py index 91b48ab781..4913e795bb 100644 --- a/openhands/app_server/app_conversation/sql_app_conversation_start_task_service.py +++ b/openhands/app_server/app_conversation/sql_app_conversation_start_task_service.py @@ -18,6 +18,7 @@ from __future__ import annotations import logging from dataclasses import dataclass +from datetime import datetime from typing import AsyncGenerator from uuid import UUID @@ -75,6 +76,7 @@ class SQLAppConversationStartTaskService(AppConversationStartTaskService): async def search_app_conversation_start_tasks( self, conversation_id__eq: UUID | None = None, + created_at__gte: datetime | None = None, sort_order: AppConversationStartTaskSortOrder = AppConversationStartTaskSortOrder.CREATED_AT_DESC, page_id: str | None = None, limit: int = 100, @@ -95,6 +97,12 @@ class SQLAppConversationStartTaskService(AppConversationStartTaskService): == conversation_id__eq ) + # Apply created_at__gte filter + if created_at__gte is not None: + query = query.where( + StoredAppConversationStartTask.created_at >= created_at__gte + ) + # Add sort order if sort_order == AppConversationStartTaskSortOrder.CREATED_AT: query = query.order_by(StoredAppConversationStartTask.created_at) @@ -139,6 +147,7 @@ class 
SQLAppConversationStartTaskService(AppConversationStartTaskService): async def count_app_conversation_start_tasks( self, conversation_id__eq: UUID | None = None, + created_at__gte: datetime | None = None, ) -> int: """Count conversation start tasks.""" query = select(func.count(StoredAppConversationStartTask.id)) @@ -156,6 +165,12 @@ class SQLAppConversationStartTaskService(AppConversationStartTaskService): == conversation_id__eq ) + # Apply created_at__gte filter + if created_at__gte is not None: + query = query.where( + StoredAppConversationStartTask.created_at >= created_at__gte + ) + result = await self.session.execute(query) count = result.scalar() return count or 0 diff --git a/openhands/app_server/app_lifespan/alembic/versions/003.py b/openhands/app_server/app_lifespan/alembic/versions/003.py index e959907939..6879b4358f 100644 --- a/openhands/app_server/app_lifespan/alembic/versions/003.py +++ b/openhands/app_server/app_lifespan/alembic/versions/003.py @@ -1,8 +1,8 @@ -"""Update conversation_metadata table to match StoredConversationMetadata dataclass +"""add parent_conversation_id to conversation_metadata Revision ID: 003 Revises: 002 -Create Date: 2025-11-11 00:00:00.000000 +Create Date: 2025-11-06 00:00:00.000000 """ @@ -13,32 +13,29 @@ from alembic import op # revision identifiers, used by Alembic. 
revision: str = '003' -down_revision: Union[str, Sequence[str], None] = '002' +down_revision: Union[str, None] = '002' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" - # Drop columns that are not in the StoredConversationMetadata dataclass - op.drop_column('conversation_metadata', 'github_user_id') - op.alter_column( + op.add_column( 'conversation_metadata', - 'user_id', - existing_type=sa.String(), - nullable=True, + sa.Column('parent_conversation_id', sa.String(), nullable=True), + ) + op.create_index( + op.f('ix_conversation_metadata_parent_conversation_id'), + 'conversation_metadata', + ['parent_conversation_id'], + unique=False, ) def downgrade() -> None: """Downgrade schema.""" - # Add back the dropped columns - op.add_column( - 'conversation_metadata', sa.Column('github_user_id', sa.String(), nullable=True) - ) - op.alter_column( - 'conversation_metadata', - 'user_id', - existing_type=sa.String(), - nullable=False, + op.drop_index( + op.f('ix_conversation_metadata_parent_conversation_id'), + table_name='conversation_metadata', ) + op.drop_column('conversation_metadata', 'parent_conversation_id') diff --git a/openhands/app_server/sandbox/docker_sandbox_service.py b/openhands/app_server/sandbox/docker_sandbox_service.py index 0f6ea51916..2d7ff39c44 100644 --- a/openhands/app_server/sandbox/docker_sandbox_service.py +++ b/openhands/app_server/sandbox/docker_sandbox_service.py @@ -32,6 +32,9 @@ from openhands.app_server.sandbox.sandbox_service import ( ) from openhands.app_server.sandbox.sandbox_spec_service import SandboxSpecService from openhands.app_server.services.injector import InjectorState +from openhands.app_server.utils.docker_utils import ( + replace_localhost_hostname_for_docker, +) _logger = logging.getLogger(__name__) SESSION_API_KEY_VARIABLE = 'OH_SESSION_API_KEYS_0' @@ -185,6 +188,9 @@ class DockerSandboxService(SandboxService): if 
exposed_url.name == AGENT_SERVER ) try: + # When running in Docker, replace localhost hostname with host.docker.internal for internal requests + app_server_url = replace_localhost_hostname_for_docker(app_server_url) + response = await self.httpx_client.get( f'{app_server_url}{self.health_check_path}' ) @@ -192,7 +198,7 @@ class DockerSandboxService(SandboxService): except asyncio.CancelledError: raise except Exception as exc: - _logger.info(f'Sandbox server not running: {exc}') + _logger.info(f'Sandbox server not running: {app_server_url} : {exc}') sandbox_info.status = SandboxStatus.ERROR sandbox_info.exposed_urls = None sandbox_info.session_api_key = None diff --git a/openhands/app_server/sandbox/docker_sandbox_spec_service.py b/openhands/app_server/sandbox/docker_sandbox_spec_service.py index cd42cbfe70..7504cec537 100644 --- a/openhands/app_server/sandbox/docker_sandbox_spec_service.py +++ b/openhands/app_server/sandbox/docker_sandbox_spec_service.py @@ -81,10 +81,49 @@ class DockerSandboxSpecServiceInjector(SandboxSpecServiceInjector): try: docker_client.images.get(spec.id) except docker.errors.ImageNotFound: - _logger.info(f'⬇️ Pulling Docker Image: {spec.id}') - # Pull in a background thread to prevent locking up the main runloop - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, docker_client.images.pull, spec.id) - _logger.info(f'⬇️ Finished Pulling Docker Image: {spec.id}') + _logger.info(f'⬇️ Pulling Docker Image: {spec.id}') + await self._pull_with_progress_logging(docker_client, spec.id) + _logger.info(f'⬇️ Finished Pulling Docker Image: {spec.id}') except docker.errors.APIError as exc: raise SandboxError(f'Error Getting Docker Image: {spec.id}') from exc + + async def _pull_with_progress_logging( + self, docker_client: docker.DockerClient, image_id: str + ): + """Pull Docker image with periodic progress logging every 5 seconds.""" + # Event to signal when pull is complete + pull_complete = asyncio.Event() + + async def 
periodic_logger(): + """Log progress message every 5 seconds until pull is complete.""" + while not pull_complete.is_set(): + try: + await asyncio.wait_for(pull_complete.wait(), timeout=5.0) + break # Pull completed + except asyncio.TimeoutError: + # 5 seconds elapsed, log progress message + _logger.info(f'🔄 Downloading Docker Image: {image_id}...') + + async def pull_image(): + """Perform the actual Docker image pull.""" + try: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, docker_client.images.pull, image_id) + finally: + pull_complete.set() + + # Run both tasks concurrently + logger_task = asyncio.create_task(periodic_logger()) + pull_task = asyncio.create_task(pull_image()) + + try: + # Wait for pull to complete + await pull_task + finally: + # Ensure logger task is cancelled if still running + if not logger_task.done(): + logger_task.cancel() + try: + await logger_task + except asyncio.CancelledError: + pass diff --git a/openhands/app_server/sandbox/remote_sandbox_service.py b/openhands/app_server/sandbox/remote_sandbox_service.py index c7d444c4ec..c96b7362c0 100644 --- a/openhands/app_server/sandbox/remote_sandbox_service.py +++ b/openhands/app_server/sandbox/remote_sandbox_service.py @@ -318,7 +318,6 @@ class RemoteSandboxService(SandboxService): created_at=utc_now(), ) self.db_session.add(stored_sandbox) - await self.db_session.commit() # Prepare environment variables environment = await self._init_environment(sandbox_spec, sandbox_id) @@ -407,7 +406,6 @@ class RemoteSandboxService(SandboxService): if not stored_sandbox: return False await self.db_session.delete(stored_sandbox) - await self.db_session.commit() runtime_data = await self._get_runtime(sandbox_id) response = await self._send_runtime_api_request( 'POST', diff --git a/openhands/app_server/sandbox/sandbox_spec_service.py b/openhands/app_server/sandbox/sandbox_spec_service.py index fd091ca130..dad14297c5 100644 --- a/openhands/app_server/sandbox/sandbox_spec_service.py +++ 
def replace_localhost_hostname_for_docker(
    url: str, replacement: str = 'host.docker.internal'
) -> str:
    """Replace a ``localhost`` hostname in a URL when running inside Docker.

    When ``is_running_in_docker()`` is false the URL is returned unchanged.
    Otherwise, if the URL's host is ``localhost`` (compared case-insensitively,
    as ``urlparse`` lower-cases the hostname), only the host component is
    swapped for ``replacement``. Userinfo, port, path, query and fragment are
    preserved.

    Args:
        url: The URL to process.
        replacement: Hostname to use instead of localhost
            (default: 'host.docker.internal').

    Returns:
        The URL with its localhost host replaced, or the original URL when not
        running in Docker or when the host is not localhost.
    """
    if not is_running_in_docker():
        return url

    parsed = urlparse(url)
    if parsed.hostname != 'localhost':
        return url

    # netloc has the shape '[userinfo@]host[:port]'. Rebuild it around the
    # host component instead of doing a textual replace: a plain
    # str.replace would hit "localhost" inside the userinfo first and would
    # miss a differently-cased host such as "LOCALHOST".
    userinfo, at_sign, host_and_port = parsed.netloc.rpartition('@')
    port_suffix = host_and_port[len('localhost'):]
    netloc = f'{userinfo}{at_sign}{replacement}{port_suffix}'
    return urlunparse(parsed._replace(netloc=netloc))
-from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Depends, Query, Request, status from fastapi.responses import JSONResponse from jinja2 import Environment, FileSystemLoader from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy.ext.asyncio import AsyncSession from openhands.app_server.app_conversation.app_conversation_info_service import ( AppConversationInfoService, @@ -24,9 +27,11 @@ from openhands.app_server.app_conversation.app_conversation_service import ( from openhands.app_server.config import ( depends_app_conversation_info_service, depends_app_conversation_service, + depends_db_session, depends_sandbox_service, ) from openhands.app_server.sandbox.sandbox_service import SandboxService +from openhands.app_server.services.db_session_injector import set_db_session_keep_open from openhands.core.config.llm_config import LLMConfig from openhands.core.config.mcp_config import MCPConfig from openhands.core.logger import openhands_logger as logger @@ -99,6 +104,7 @@ app = APIRouter(prefix='/api', dependencies=get_dependencies()) app_conversation_service_dependency = depends_app_conversation_service() app_conversation_info_service_dependency = depends_app_conversation_info_service() sandbox_service_dependency = depends_sandbox_service() +db_session_dependency = depends_db_session() def _filter_conversations_by_age( @@ -304,6 +310,12 @@ async def search_conversations( limit: int = 20, selected_repository: str | None = None, conversation_trigger: ConversationTrigger | None = None, + include_sub_conversations: Annotated[ + bool, + Query( + title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.' 
+ ), + ] = False, conversation_store: ConversationStore = Depends(get_conversation_store), app_conversation_service: AppConversationService = app_conversation_service_dependency, ) -> ConversationInfoResultSet: @@ -338,6 +350,7 @@ async def search_conversations( limit=limit, # Apply age filter at the service level if possible created_at__gte=age_filter_date, + include_sub_conversations=include_sub_conversations, ) # Convert V1 conversations to ConversationInfo format @@ -467,16 +480,22 @@ async def get_conversation( @app.delete('/conversations/{conversation_id}') async def delete_conversation( + request: Request, conversation_id: str = Depends(validate_conversation_id), user_id: str | None = Depends(get_user_id), app_conversation_service: AppConversationService = app_conversation_service_dependency, + app_conversation_info_service: AppConversationInfoService = app_conversation_info_service_dependency, sandbox_service: SandboxService = sandbox_service_dependency, + db_session: AsyncSession = db_session_dependency, ) -> bool: + set_db_session_keep_open(request.state, True) # Try V1 conversation first v1_result = await _try_delete_v1_conversation( conversation_id, app_conversation_service, + app_conversation_info_service, sandbox_service, + db_session, ) if v1_result is not None: return v1_result @@ -488,23 +507,32 @@ async def delete_conversation( async def _try_delete_v1_conversation( conversation_id: str, app_conversation_service: AppConversationService, + app_conversation_info_service: AppConversationInfoService, sandbox_service: SandboxService, + db_session: AsyncSession, ) -> bool | None: """Try to delete a V1 conversation. 
Returns None if not a V1 conversation.""" result = None try: conversation_uuid = uuid.UUID(conversation_id) # Check if it's a V1 conversation by trying to get it - app_conversation = await app_conversation_service.get_app_conversation( - conversation_uuid + app_conversation_info = ( + await app_conversation_info_service.get_app_conversation_info( + conversation_uuid + ) ) - if app_conversation: + if app_conversation_info: # This is a V1 conversation, delete it using the app conversation service # Pass the conversation ID for secure deletion result = await app_conversation_service.delete_app_conversation( - app_conversation.id + app_conversation_info.id + ) + # Delete the sandbox in the background + asyncio.create_task( + _delete_sandbox_and_close_connection( + sandbox_service, app_conversation_info.sandbox_id, db_session + ) ) - await sandbox_service.delete_sandbox(app_conversation.sandbox_id) except (ValueError, TypeError): # Not a valid UUID, continue with V0 logic pass @@ -515,6 +543,16 @@ async def _try_delete_v1_conversation( return result +async def _delete_sandbox_and_close_connection( + sandbox_service: SandboxService, sandbox_id: str, db_session: AsyncSession +): + try: + await sandbox_service.delete_sandbox(sandbox_id) + await db_session.commit() + finally: + await db_session.aclose() + + async def _delete_v0_conversation(conversation_id: str, user_id: str | None) -> bool: """Delete a V0 conversation using the legacy logic.""" conversation_store = await ConversationStoreImpl.get_instance(config, user_id) @@ -1157,6 +1195,7 @@ async def _fetch_v1_conversations_safe( app_conversation_service: App conversation service for V1 v1_page_id: Page ID for V1 pagination limit: Maximum number of results + include_sub_conversations: If True, include sub-conversations in results Returns: Tuple of (v1_conversations, v1_next_page_id) @@ -1432,4 +1471,7 @@ def _to_conversation_info(app_conversation: AppConversation) -> ConversationInfo 
created_at=app_conversation.created_at, pr_number=app_conversation.pr_number, conversation_version='V1', + sub_conversation_ids=[ + sub_id.hex for sub_id in app_conversation.sub_conversation_ids + ], ) diff --git a/tests/unit/app_server/test_docker_sandbox_service.py b/tests/unit/app_server/test_docker_sandbox_service.py index f79988773a..d428b3b664 100644 --- a/tests/unit/app_server/test_docker_sandbox_service.py +++ b/tests/unit/app_server/test_docker_sandbox_service.py @@ -697,10 +697,14 @@ class TestDockerSandboxService: assert result is not None assert isinstance(result.created_at, datetime) + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) async def test_container_to_checked_sandbox_info_health_check_success( - self, service, mock_running_container + self, mock_is_docker, service, mock_running_container ): - """Test health check success.""" + """Test health check success when running in Docker.""" # Setup service.httpx_client.get.return_value.raise_for_status.return_value = None @@ -715,7 +719,34 @@ class TestDockerSandboxService: assert result.exposed_urls is not None assert result.session_api_key == 'session_key_123' - # Verify health check was called + # Verify health check was called with Docker-internal URL + service.httpx_client.get.assert_called_once_with( + 'http://host.docker.internal:12345/health' + ) + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=False, + ) + async def test_container_to_checked_sandbox_info_health_check_success_not_in_docker( + self, mock_is_docker, service, mock_running_container + ): + """Test health check success when not running in Docker.""" + # Setup + service.httpx_client.get.return_value.raise_for_status.return_value = None + + # Execute + result = await service._container_to_checked_sandbox_info( + mock_running_container + ) + + # Verify + assert result is not None + assert result.status == SandboxStatus.RUNNING + assert 
result.exposed_urls is not None + assert result.session_api_key == 'session_key_123' + + # Verify health check was called with original localhost URL service.httpx_client.get.assert_called_once_with( 'http://localhost:12345/health' ) diff --git a/tests/unit/app_server/test_docker_sandbox_spec_service_injector.py b/tests/unit/app_server/test_docker_sandbox_spec_service_injector.py index 1df987c56f..059cc27e8a 100644 --- a/tests/unit/app_server/test_docker_sandbox_spec_service_injector.py +++ b/tests/unit/app_server/test_docker_sandbox_spec_service_injector.py @@ -447,3 +447,85 @@ class TestDockerSandboxSpecServiceInjector: # Verify no Docker operations were performed mock_get_docker_client.assert_not_called() mock_docker_client.images.get.assert_not_called() + + @patch('openhands.app_server.sandbox.docker_sandbox_spec_service.get_docker_client') + @patch('openhands.app_server.sandbox.docker_sandbox_spec_service._logger') + async def test_pull_with_progress_logging( + self, mock_logger, mock_get_docker_client, sample_spec + ): + """Test that periodic progress logging occurs during image pull.""" + # Setup + mock_docker_client = MagicMock() + mock_get_docker_client.return_value = mock_docker_client + mock_docker_client.images.get.side_effect = ImageNotFound('Image not found') + + # Create a future that will be resolved after some delay to simulate slow pull + pull_future = asyncio.Future() + + async def delayed_pull_completion(): + # Wait for multiple logging intervals to pass + await asyncio.sleep(12) # 12 seconds = 2 logging intervals (5s each) + pull_future.set_result(MagicMock()) + + # Start the delayed completion task + asyncio.create_task(delayed_pull_completion()) + + # Mock the executor to return our delayed future + with patch('asyncio.get_running_loop') as mock_get_loop: + mock_loop = MagicMock() + mock_get_loop.return_value = mock_loop + mock_loop.run_in_executor.return_value = pull_future + + injector = DockerSandboxSpecServiceInjector() + + # Execute + 
await injector.pull_spec_if_missing(sample_spec) + + # Verify that progress logging occurred + # Should have initial pull message, progress messages, and completion message + progress_calls = [ + call + for call in mock_logger.info.call_args_list + if '🔄 Downloading Docker Image:' in str(call) + ] + + # Should have at least 2 progress log messages (every 5 seconds for 12 seconds) + assert len(progress_calls) >= 2 + + # Verify the progress message format + for call in progress_calls: + assert '🔄 Downloading Docker Image: test-image:latest...' in str(call) + + @patch('openhands.app_server.sandbox.docker_sandbox_spec_service.get_docker_client') + @patch('openhands.app_server.sandbox.docker_sandbox_spec_service._logger') + async def test_pull_with_progress_logging_fast_pull( + self, mock_logger, mock_get_docker_client, sample_spec + ): + """Test that no progress logging occurs for fast pulls (< 5 seconds).""" + # Setup + mock_docker_client = MagicMock() + mock_get_docker_client.return_value = mock_docker_client + mock_docker_client.images.get.side_effect = ImageNotFound('Image not found') + + # Mock fast pull (completes immediately) + with patch('asyncio.get_running_loop') as mock_get_loop: + mock_loop = MagicMock() + mock_get_loop.return_value = mock_loop + fast_future = asyncio.Future() + fast_future.set_result(MagicMock()) + mock_loop.run_in_executor.return_value = fast_future + + injector = DockerSandboxSpecServiceInjector() + + # Execute + await injector.pull_spec_if_missing(sample_spec) + + # Verify that no progress logging occurred (only start/end messages) + progress_calls = [ + call + for call in mock_logger.info.call_args_list + if '🔄 Downloading Docker Image:' in str(call) + ] + + # Should have no progress log messages for fast pulls + assert len(progress_calls) == 0 diff --git a/tests/unit/app_server/test_docker_utils.py b/tests/unit/app_server/test_docker_utils.py new file mode 100644 index 0000000000..127c6dfc6e --- /dev/null +++ 
b/tests/unit/app_server/test_docker_utils.py @@ -0,0 +1,297 @@ +from unittest.mock import patch + +from openhands.app_server.utils.docker_utils import ( + replace_localhost_hostname_for_docker, +) + + +class TestReplaceLocalhostHostnameForDocker: + """Test cases for replace_localhost_hostname_for_docker function.""" + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_replace_localhost_basic_in_docker(self, mock_is_docker): + """Test basic localhost replacement when running in Docker.""" + # Basic HTTP URL + result = replace_localhost_hostname_for_docker('http://localhost:8080') + assert result == 'http://host.docker.internal:8080' + + # HTTPS URL + result = replace_localhost_hostname_for_docker('https://localhost:443') + assert result == 'https://host.docker.internal:443' + + # No port specified + result = replace_localhost_hostname_for_docker('http://localhost') + assert result == 'http://host.docker.internal' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=False, + ) + def test_replace_localhost_basic_not_in_docker(self, mock_is_docker): + """Test that localhost is NOT replaced when not running in Docker.""" + # Basic HTTP URL + result = replace_localhost_hostname_for_docker('http://localhost:8080') + assert result == 'http://localhost:8080' + + # HTTPS URL + result = replace_localhost_hostname_for_docker('https://localhost:443') + assert result == 'https://localhost:443' + + # No port specified + result = replace_localhost_hostname_for_docker('http://localhost') + assert result == 'http://localhost' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_replace_localhost_with_path_and_query(self, mock_is_docker): + """Test localhost replacement preserving path and query parameters.""" + # With path + result = replace_localhost_hostname_for_docker( + 'http://localhost:3000/api/health' + ) + assert 
result == 'http://host.docker.internal:3000/api/health' + + # With query parameters containing localhost + result = replace_localhost_hostname_for_docker( + 'http://localhost:8080/path?param=localhost&other=value' + ) + assert ( + result + == 'http://host.docker.internal:8080/path?param=localhost&other=value' + ) + + # With path containing localhost + result = replace_localhost_hostname_for_docker( + 'http://localhost:9000/localhost/endpoint' + ) + assert result == 'http://host.docker.internal:9000/localhost/endpoint' + + # With fragment + result = replace_localhost_hostname_for_docker( + 'http://localhost:8080/path#localhost' + ) + assert result == 'http://host.docker.internal:8080/path#localhost' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_replace_localhost_with_authentication(self, mock_is_docker): + """Test localhost replacement with authentication in URL.""" + result = replace_localhost_hostname_for_docker( + 'http://user:pass@localhost:8080/path' + ) + assert result == 'http://user:pass@host.docker.internal:8080/path' + + result = replace_localhost_hostname_for_docker( + 'https://admin:secret@localhost:443/admin' + ) + assert result == 'https://admin:secret@host.docker.internal:443/admin' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_replace_localhost_different_protocols(self, mock_is_docker): + """Test localhost replacement with different protocols.""" + # FTP + result = replace_localhost_hostname_for_docker('ftp://localhost:21/files') + assert result == 'ftp://host.docker.internal:21/files' + + # WebSocket + result = replace_localhost_hostname_for_docker('ws://localhost:8080/socket') + assert result == 'ws://host.docker.internal:8080/socket' + + # WebSocket Secure + result = replace_localhost_hostname_for_docker( + 'wss://localhost:443/secure-socket' + ) + assert result == 'wss://host.docker.internal:443/secure-socket' + 
+ @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_no_replacement_for_non_localhost(self, mock_is_docker): + """Test that non-localhost hostnames are not replaced even when in Docker.""" + # IP address + result = replace_localhost_hostname_for_docker('http://127.0.0.1:8080') + assert result == 'http://127.0.0.1:8080' + + # Different hostname + result = replace_localhost_hostname_for_docker('http://example.com:8080') + assert result == 'http://example.com:8080' + + # Hostname containing localhost but not exact match + result = replace_localhost_hostname_for_docker('http://mylocalhost:8080') + assert result == 'http://mylocalhost:8080' + + # Subdomain of localhost + result = replace_localhost_hostname_for_docker('http://api.localhost:8080') + assert result == 'http://api.localhost:8080' + + # localhost as subdomain + result = replace_localhost_hostname_for_docker( + 'http://localhost.example.com:8080' + ) + assert result == 'http://localhost.example.com:8080' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_custom_replacement_hostname(self, mock_is_docker): + """Test using custom replacement hostname.""" + result = replace_localhost_hostname_for_docker( + 'http://localhost:8080', 'custom.host' + ) + assert result == 'http://custom.host:8080' + + result = replace_localhost_hostname_for_docker( + 'https://localhost:443/path', 'internal.docker' + ) + assert result == 'https://internal.docker:443/path' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_edge_cases_in_docker(self, mock_is_docker): + """Test edge cases and malformed URLs when in Docker.""" + # Empty string + result = replace_localhost_hostname_for_docker('') + assert result == '' + + # Malformed URL (no protocol) + result = replace_localhost_hostname_for_docker('localhost:8080') + assert result == 'localhost:8080' + + # 
Just hostname + result = replace_localhost_hostname_for_docker('localhost') + assert result == 'localhost' + + # URL with no hostname + result = replace_localhost_hostname_for_docker('http://:8080/path') + assert result == 'http://:8080/path' + + # Invalid URL structure + result = replace_localhost_hostname_for_docker('not-a-url') + assert result == 'not-a-url' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=False, + ) + def test_edge_cases_not_in_docker(self, mock_is_docker): + """Test edge cases and malformed URLs when not in Docker.""" + # Empty string + result = replace_localhost_hostname_for_docker('') + assert result == '' + + # Malformed URL (no protocol) + result = replace_localhost_hostname_for_docker('localhost:8080') + assert result == 'localhost:8080' + + # Just hostname + result = replace_localhost_hostname_for_docker('localhost') + assert result == 'localhost' + + # URL with no hostname + result = replace_localhost_hostname_for_docker('http://:8080/path') + assert result == 'http://:8080/path' + + # Invalid URL structure + result = replace_localhost_hostname_for_docker('not-a-url') + assert result == 'not-a-url' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_complex_urls(self, mock_is_docker): + """Test complex URL scenarios.""" + # Multiple query parameters and fragments + complex_url = 'http://localhost:8080/api/v1/health?timeout=30&retry=3&host=localhost#section' + result = replace_localhost_hostname_for_docker(complex_url) + expected = 'http://host.docker.internal:8080/api/v1/health?timeout=30&retry=3&host=localhost#section' + assert result == expected + + # URL with encoded characters + encoded_url = ( + 'http://localhost:8080/path%20with%20spaces?param=value%20with%20spaces' + ) + result = replace_localhost_hostname_for_docker(encoded_url) + expected = 'http://host.docker.internal:8080/path%20with%20spaces?param=value%20with%20spaces' 
+ assert result == expected + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_integration_with_docker_detection_in_docker(self, mock_is_docker): + """Test integration scenario similar to actual usage when in Docker.""" + # Simulate the actual usage pattern in the code + app_server_url = 'http://localhost:35375' + + # This is how it's used in the actual code + internal_url = replace_localhost_hostname_for_docker(app_server_url) + + assert internal_url == 'http://host.docker.internal:35375' + + # Test with health check path appended + health_check_url = f'{internal_url}/health' + assert health_check_url == 'http://host.docker.internal:35375/health' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=False, + ) + def test_integration_with_docker_detection_not_in_docker(self, mock_is_docker): + """Test integration scenario similar to actual usage when not in Docker.""" + # Simulate the actual usage pattern in the code + app_server_url = 'http://localhost:35375' + + # This is how it's used in the actual code + internal_url = replace_localhost_hostname_for_docker(app_server_url) + + # Should return original URL when not in Docker + assert internal_url == 'http://localhost:35375' + + # Test with health check path appended + health_check_url = f'{internal_url}/health' + assert health_check_url == 'http://localhost:35375/health' + + @patch( + 'openhands.app_server.utils.docker_utils.is_running_in_docker', + return_value=True, + ) + def test_preserves_original_url_structure(self, mock_is_docker): + """Test that all URL components are preserved correctly.""" + original_url = 'https://user:pass@localhost:8443/api/v1/endpoint?param1=value1&param2=value2#fragment' + result = replace_localhost_hostname_for_docker(original_url) + expected = 'https://user:pass@host.docker.internal:8443/api/v1/endpoint?param1=value1&param2=value2#fragment' + + assert result == expected + + # Verify each 
component is preserved + from urllib.parse import urlparse + + original_parsed = urlparse(original_url) + result_parsed = urlparse(result) + + assert original_parsed.scheme == result_parsed.scheme + assert original_parsed.username == result_parsed.username + assert original_parsed.password == result_parsed.password + assert original_parsed.port == result_parsed.port + assert original_parsed.path == result_parsed.path + assert original_parsed.query == result_parsed.query + assert original_parsed.fragment == result_parsed.fragment + + # Only hostname should be different + assert original_parsed.hostname == 'localhost' + assert result_parsed.hostname == 'host.docker.internal' diff --git a/tests/unit/app_server/test_remote_sandbox_service.py b/tests/unit/app_server/test_remote_sandbox_service.py index 1d917cc760..567ecad2e3 100644 --- a/tests/unit/app_server/test_remote_sandbox_service.py +++ b/tests/unit/app_server/test_remote_sandbox_service.py @@ -435,7 +435,7 @@ class TestSandboxLifecycle: 9 ) # max_num_sandboxes - 1 remote_sandbox_service.db_session.add.assert_called_once() - remote_sandbox_service.db_session.commit.assert_called_once() + remote_sandbox_service.db_session.commit.assert_not_called() @pytest.mark.asyncio async def test_start_sandbox_with_specific_spec( @@ -627,7 +627,7 @@ class TestSandboxLifecycle: # Verify assert result is True remote_sandbox_service.db_session.delete.assert_called_once_with(stored_sandbox) - remote_sandbox_service.db_session.commit.assert_called_once() + remote_sandbox_service.db_session.commit.assert_not_called() remote_sandbox_service.httpx_client.request.assert_called_once_with( 'POST', 'https://api.example.com/stop', diff --git a/tests/unit/app_server/test_sql_app_conversation_info_service.py b/tests/unit/app_server/test_sql_app_conversation_info_service.py index 0315ff2604..2b741d984f 100644 --- a/tests/unit/app_server/test_sql_app_conversation_info_service.py +++ 
b/tests/unit/app_server/test_sql_app_conversation_info_service.py @@ -563,3 +563,383 @@ class TestSQLAppConversationInfoService: created_at__gte=start_time, created_at__lt=end_time ) assert count == 2 + + @pytest.mark.asyncio + async def test_search_excludes_sub_conversations_by_default( + self, + service: SQLAppConversationInfoService, + ): + """Test that search excludes sub-conversations by default.""" + # Create a parent conversation + parent_id = uuid4() + parent_info = AppConversationInfo( + id=parent_id, + created_by_user_id='test_user_123', + sandbox_id='sandbox_parent', + title='Parent Conversation', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + + # Create sub-conversations + sub_info_1 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub1', + title='Sub Conversation 1', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + + sub_info_2 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub2', + title='Sub Conversation 2', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await service.save_app_conversation_info(parent_info) + await service.save_app_conversation_info(sub_info_1) + await service.save_app_conversation_info(sub_info_2) + + # Search without include_sub_conversations (default False) + page = await service.search_app_conversation_info() + + # Should only return the parent conversation + assert len(page.items) == 1 + assert page.items[0].id == parent_id + assert page.items[0].title == 'Parent Conversation' + assert page.items[0].parent_conversation_id is None + + @pytest.mark.asyncio + async 
def test_search_includes_sub_conversations_when_flag_true( + self, + service: SQLAppConversationInfoService, + ): + """Test that search includes sub-conversations when include_sub_conversations=True.""" + # Create a parent conversation + parent_id = uuid4() + parent_info = AppConversationInfo( + id=parent_id, + created_by_user_id='test_user_123', + sandbox_id='sandbox_parent', + title='Parent Conversation', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + + # Create sub-conversations + sub_info_1 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub1', + title='Sub Conversation 1', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + + sub_info_2 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub2', + title='Sub Conversation 2', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await service.save_app_conversation_info(parent_info) + await service.save_app_conversation_info(sub_info_1) + await service.save_app_conversation_info(sub_info_2) + + # Search with include_sub_conversations=True + page = await service.search_app_conversation_info( + include_sub_conversations=True + ) + + # Should return all conversations (1 parent + 2 sub-conversations) + assert len(page.items) == 3 + + # Verify all conversations are present + conversation_ids = {item.id for item in page.items} + assert parent_id in conversation_ids + assert sub_info_1.id in conversation_ids + assert sub_info_2.id in conversation_ids + + # Verify parent conversation has no parent_conversation_id + parent_item = next(item for item in page.items if 
item.id == parent_id) + assert parent_item.parent_conversation_id is None + + # Verify sub-conversations have parent_conversation_id set + sub_item_1 = next(item for item in page.items if item.id == sub_info_1.id) + assert sub_item_1.parent_conversation_id == parent_id + + sub_item_2 = next(item for item in page.items if item.id == sub_info_2.id) + assert sub_item_2.parent_conversation_id == parent_id + + @pytest.mark.asyncio + async def test_search_sub_conversations_with_filters( + self, + service: SQLAppConversationInfoService, + ): + """Test that include_sub_conversations works correctly with other filters.""" + # Create a parent conversation + parent_id = uuid4() + parent_info = AppConversationInfo( + id=parent_id, + created_by_user_id='test_user_123', + sandbox_id='sandbox_parent', + title='Parent Conversation', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + + # Create sub-conversations with different titles + sub_info_1 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub1', + title='Sub Conversation Alpha', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + + sub_info_2 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub2', + title='Sub Conversation Beta', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await service.save_app_conversation_info(parent_info) + await service.save_app_conversation_info(sub_info_1) + await service.save_app_conversation_info(sub_info_2) + + # Search with title filter and include_sub_conversations=False (default) + page = await 
service.search_app_conversation_info(title__contains='Alpha') + # Should only find parent if it matches, but parent doesn't have "Alpha" + # So should find nothing or only sub if we include them + assert len(page.items) == 0 + + # Search with title filter and include_sub_conversations=True + page = await service.search_app_conversation_info( + title__contains='Alpha', include_sub_conversations=True + ) + # Should find the sub-conversation with "Alpha" in title + assert len(page.items) == 1 + assert page.items[0].title == 'Sub Conversation Alpha' + assert page.items[0].parent_conversation_id == parent_id + + # Search with title filter for "Parent" and include_sub_conversations=True + page = await service.search_app_conversation_info( + title__contains='Parent', include_sub_conversations=True + ) + # Should find the parent conversation + assert len(page.items) == 1 + assert page.items[0].title == 'Parent Conversation' + assert page.items[0].parent_conversation_id is None + + @pytest.mark.asyncio + async def test_search_sub_conversations_with_date_filters( + self, + service: SQLAppConversationInfoService, + ): + """Test that include_sub_conversations works correctly with date filters.""" + # Create a parent conversation + parent_id = uuid4() + parent_info = AppConversationInfo( + id=parent_id, + created_by_user_id='test_user_123', + sandbox_id='sandbox_parent', + title='Parent Conversation', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + + # Create sub-conversations at different times + sub_info_1 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub1', + title='Sub Conversation 1', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + + sub_info_2 = AppConversationInfo( + id=uuid4(), + 
created_by_user_id='test_user_123', + sandbox_id='sandbox_sub2', + title='Sub Conversation 2', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await service.save_app_conversation_info(parent_info) + await service.save_app_conversation_info(sub_info_1) + await service.save_app_conversation_info(sub_info_2) + + # Search with date filter and include_sub_conversations=False (default) + cutoff_time = datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc) + page = await service.search_app_conversation_info(created_at__gte=cutoff_time) + # Should only return parent if it matches the filter, but parent is at 12:00 + assert len(page.items) == 0 + + # Search with date filter and include_sub_conversations=True + page = await service.search_app_conversation_info( + created_at__gte=cutoff_time, include_sub_conversations=True + ) + # Should find sub-conversations created after cutoff (sub_info_2 at 14:00) + assert len(page.items) == 1 + assert page.items[0].id == sub_info_2.id + assert page.items[0].parent_conversation_id == parent_id + + @pytest.mark.asyncio + async def test_search_multiple_parents_with_sub_conversations( + self, + service: SQLAppConversationInfoService, + ): + """Test search with multiple parent conversations and their sub-conversations.""" + # Create first parent conversation + parent1_id = uuid4() + parent1_info = AppConversationInfo( + id=parent1_id, + created_by_user_id='test_user_123', + sandbox_id='sandbox_parent1', + title='Parent 1', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + + # Create second parent conversation + parent2_id = uuid4() + parent2_info = AppConversationInfo( + id=parent2_id, + created_by_user_id='test_user_123', + sandbox_id='sandbox_parent2', + title='Parent 2', + created_at=datetime(2024, 
1, 1, 13, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc), + ) + + # Create sub-conversations for parent1 + sub1_1 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub1_1', + title='Sub 1-1', + parent_conversation_id=parent1_id, + created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc), + ) + + # Create sub-conversations for parent2 + sub2_1 = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id='sandbox_sub2_1', + title='Sub 2-1', + parent_conversation_id=parent2_id, + created_at=datetime(2024, 1, 1, 15, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 15, 30, 0, tzinfo=timezone.utc), + ) + + # Save all conversations + await service.save_app_conversation_info(parent1_info) + await service.save_app_conversation_info(parent2_info) + await service.save_app_conversation_info(sub1_1) + await service.save_app_conversation_info(sub2_1) + + # Search without include_sub_conversations (default False) + page = await service.search_app_conversation_info() + # Should only return the 2 parent conversations + assert len(page.items) == 2 + conversation_ids = {item.id for item in page.items} + assert parent1_id in conversation_ids + assert parent2_id in conversation_ids + assert sub1_1.id not in conversation_ids + assert sub2_1.id not in conversation_ids + + # Search with include_sub_conversations=True + page = await service.search_app_conversation_info( + include_sub_conversations=True + ) + # Should return all 4 conversations (2 parents + 2 sub-conversations) + assert len(page.items) == 4 + conversation_ids = {item.id for item in page.items} + assert parent1_id in conversation_ids + assert parent2_id in conversation_ids + assert sub1_1.id in conversation_ids + assert sub2_1.id in conversation_ids + + @pytest.mark.asyncio + async def 
test_search_sub_conversations_with_pagination( + self, + service: SQLAppConversationInfoService, + ): + """Test that include_sub_conversations works correctly with pagination.""" + # Create a parent conversation + parent_id = uuid4() + parent_info = AppConversationInfo( + id=parent_id, + created_by_user_id='test_user_123', + sandbox_id='sandbox_parent', + title='Parent Conversation', + created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc), + ) + + # Create multiple sub-conversations + sub_conversations = [] + for i in range(5): + sub_info = AppConversationInfo( + id=uuid4(), + created_by_user_id='test_user_123', + sandbox_id=f'sandbox_sub{i}', + title=f'Sub Conversation {i}', + parent_conversation_id=parent_id, + created_at=datetime(2024, 1, 1, 13 + i, 0, 0, tzinfo=timezone.utc), + updated_at=datetime(2024, 1, 1, 13 + i, 30, 0, tzinfo=timezone.utc), + ) + sub_conversations.append(sub_info) + await service.save_app_conversation_info(sub_info) + + # Save parent + await service.save_app_conversation_info(parent_info) + + # Search with include_sub_conversations=True and pagination + page1 = await service.search_app_conversation_info( + include_sub_conversations=True, limit=3 + ) + # Should return the first 3 of the 6 saved conversations (1 parent + 5 subs) + assert len(page1.items) == 3 + assert page1.next_page_id is not None + + # Get next page + page2 = await service.search_app_conversation_info( + include_sub_conversations=True, limit=3, page_id=page1.next_page_id + ) + # Should return remaining items + assert len(page2.items) == 3 + assert page2.next_page_id is None + + # Verify all conversations are present across pages + all_ids = {item.id for item in page1.items} | {item.id for item in page2.items} + assert parent_id in all_ids + for sub_info in sub_conversations: + assert sub_info.id in all_ids diff --git a/tests/unit/app_server/test_sql_app_conversation_start_task_service.py
b/tests/unit/app_server/test_sql_app_conversation_start_task_service.py index 017f4f1fc8..943595e141 100644 --- a/tests/unit/app_server/test_sql_app_conversation_start_task_service.py +++ b/tests/unit/app_server/test_sql_app_conversation_start_task_service.py @@ -639,3 +639,145 @@ class TestSQLAppConversationStartTaskService: user2_count = await user2_service.count_app_conversation_start_tasks() assert user2_count == 1 + + async def test_search_app_conversation_start_tasks_with_created_at_gte_filter( + self, + service: SQLAppConversationStartTaskService, + sample_request: AppConversationStartRequest, + ): + """Test search with created_at__gte filter.""" + from datetime import timedelta + + from openhands.agent_server.models import utc_now + + # Create tasks with different creation times + base_time = utc_now() + + # Task 1: created 2 hours ago + task1 = AppConversationStartTask( + id=uuid4(), + created_by_user_id='user1', + status=AppConversationStartTaskStatus.WORKING, + request=sample_request, + ) + task1.created_at = base_time - timedelta(hours=2) + await service.save_app_conversation_start_task(task1) + + # Task 2: created 1 hour ago + task2 = AppConversationStartTask( + id=uuid4(), + created_by_user_id='user1', + status=AppConversationStartTaskStatus.READY, + request=sample_request, + ) + task2.created_at = base_time - timedelta(hours=1) + await service.save_app_conversation_start_task(task2) + + # Task 3: created 30 minutes ago + task3 = AppConversationStartTask( + id=uuid4(), + created_by_user_id='user1', + status=AppConversationStartTaskStatus.WORKING, + request=sample_request, + ) + task3.created_at = base_time - timedelta(minutes=30) + await service.save_app_conversation_start_task(task3) + + # Search for tasks created in the last 90 minutes + filter_time = base_time - timedelta(minutes=90) + result = await service.search_app_conversation_start_tasks( + created_at__gte=filter_time + ) + + # Should return task2 and task3 (created within last 90 minutes) + 
assert len(result.items) == 2 + task_ids = [task.id for task in result.items] + assert task2.id in task_ids + assert task3.id in task_ids + assert task1.id not in task_ids + + # Test count with the same filter + count = await service.count_app_conversation_start_tasks( + created_at__gte=filter_time + ) + assert count == 2 + + # Search for tasks created in the last 45 minutes + filter_time_recent = base_time - timedelta(minutes=45) + result_recent = await service.search_app_conversation_start_tasks( + created_at__gte=filter_time_recent + ) + + # Should return only task3 + assert len(result_recent.items) == 1 + assert result_recent.items[0].id == task3.id + + # Test count with recent filter + count_recent = await service.count_app_conversation_start_tasks( + created_at__gte=filter_time_recent + ) + assert count_recent == 1 + + async def test_search_app_conversation_start_tasks_combined_filters( + self, + service: SQLAppConversationStartTaskService, + sample_request: AppConversationStartRequest, + ): + """Test search with both conversation_id and created_at__gte filters.""" + from datetime import timedelta + + from openhands.agent_server.models import utc_now + + conversation_id1 = uuid4() + conversation_id2 = uuid4() + base_time = utc_now() + + # Task 1: conversation_id1, created 2 hours ago + task1 = AppConversationStartTask( + id=uuid4(), + created_by_user_id='user1', + status=AppConversationStartTaskStatus.WORKING, + app_conversation_id=conversation_id1, + request=sample_request, + ) + task1.created_at = base_time - timedelta(hours=2) + await service.save_app_conversation_start_task(task1) + + # Task 2: conversation_id1, created 30 minutes ago + task2 = AppConversationStartTask( + id=uuid4(), + created_by_user_id='user1', + status=AppConversationStartTaskStatus.READY, + app_conversation_id=conversation_id1, + request=sample_request, + ) + task2.created_at = base_time - timedelta(minutes=30) + await service.save_app_conversation_start_task(task2) + + # Task 3: 
conversation_id2, created 30 minutes ago + task3 = AppConversationStartTask( + id=uuid4(), + created_by_user_id='user1', + status=AppConversationStartTaskStatus.WORKING, + app_conversation_id=conversation_id2, + request=sample_request, + ) + task3.created_at = base_time - timedelta(minutes=30) + await service.save_app_conversation_start_task(task3) + + # Search for tasks with conversation_id1 created in the last hour + filter_time = base_time - timedelta(hours=1) + result = await service.search_app_conversation_start_tasks( + conversation_id__eq=conversation_id1, created_at__gte=filter_time + ) + + # Should return only task2 (conversation_id1 and created within last hour) + assert len(result.items) == 1 + assert result.items[0].id == task2.id + assert result.items[0].app_conversation_id == conversation_id1 + + # Test count with combined filters + count = await service.count_app_conversation_start_tasks( + conversation_id__eq=conversation_id1, created_at__gte=filter_time + ) + assert count == 1 diff --git a/tests/unit/server/data_models/test_conversation.py b/tests/unit/server/data_models/test_conversation.py index 0917dc1fac..c3cf34dac3 100644 --- a/tests/unit/server/data_models/test_conversation.py +++ b/tests/unit/server/data_models/test_conversation.py @@ -911,10 +911,16 @@ async def test_delete_conversation(): # Create a mock app conversation service mock_app_conversation_service = MagicMock() - mock_app_conversation_service.get_app_conversation = AsyncMock( + + # Create a mock app conversation info service + mock_app_conversation_info_service = MagicMock() + mock_app_conversation_info_service.get_app_conversation_info = AsyncMock( return_value=None ) + # Create a mock sandbox service + mock_sandbox_service = MagicMock() + # Mock the conversation manager with patch( 'openhands.server.routes.manage_conversations.conversation_manager' @@ -932,9 +938,12 @@ async def test_delete_conversation(): # Call delete_conversation result = await delete_conversation( + 
request=MagicMock(), conversation_id='some_conversation_id', user_id='12345', app_conversation_service=mock_app_conversation_service, + app_conversation_info_service=mock_app_conversation_info_service, + sandbox_service=mock_sandbox_service, ) # Verify the result @@ -972,42 +981,63 @@ async def test_delete_v1_conversation_success(): mock_service = MagicMock() mock_service_dep.return_value = mock_service - # Mock the conversation exists - mock_app_conversation = AppConversation( - id=conversation_uuid, - created_by_user_id='test_user', - sandbox_id='test-sandbox-id', - title='Test V1 Conversation', - sandbox_status=SandboxStatus.RUNNING, - execution_status=ConversationExecutionStatus.RUNNING, - session_api_key='test-api-key', - selected_repository='test/repo', - selected_branch='main', - git_provider=ProviderType.GITHUB, - trigger=ConversationTrigger.GUI, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - ) - mock_service.get_app_conversation = AsyncMock( - return_value=mock_app_conversation - ) - mock_service.delete_app_conversation = AsyncMock(return_value=True) + # Mock the app conversation info service + with patch( + 'openhands.server.routes.manage_conversations.app_conversation_info_service_dependency' + ) as mock_info_service_dep: + mock_info_service = MagicMock() + mock_info_service_dep.return_value = mock_info_service - # Call delete_conversation with V1 conversation ID - result = await delete_conversation( - conversation_id=conversation_id, - user_id='test_user', - app_conversation_service=mock_service, - ) + # Mock the sandbox service + with patch( + 'openhands.server.routes.manage_conversations.sandbox_service_dependency' + ) as mock_sandbox_service_dep: + mock_sandbox_service = MagicMock() + mock_sandbox_service_dep.return_value = mock_sandbox_service - # Verify the result - assert result is True + # Mock the conversation info exists + mock_app_conversation_info = AppConversation( + id=conversation_uuid, + 
created_by_user_id='test_user', + sandbox_id='test-sandbox-id', + title='Test V1 Conversation', + sandbox_status=SandboxStatus.RUNNING, + execution_status=ConversationExecutionStatus.RUNNING, + session_api_key='test-api-key', + selected_repository='test/repo', + selected_branch='main', + git_provider=ProviderType.GITHUB, + trigger=ConversationTrigger.GUI, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + mock_info_service.get_app_conversation_info = AsyncMock( + return_value=mock_app_conversation_info + ) + mock_service.delete_app_conversation = AsyncMock(return_value=True) - # Verify that get_app_conversation was called - mock_service.get_app_conversation.assert_called_once_with(conversation_uuid) + # Call delete_conversation with V1 conversation ID + result = await delete_conversation( + request=MagicMock(), + conversation_id=conversation_id, + user_id='test_user', + app_conversation_service=mock_service, + app_conversation_info_service=mock_info_service, + sandbox_service=mock_sandbox_service, + ) - # Verify that delete_app_conversation was called with the conversation ID - mock_service.delete_app_conversation.assert_called_once_with(conversation_uuid) + # Verify the result + assert result is True + + # Verify that get_app_conversation_info was called + mock_info_service.get_app_conversation_info.assert_called_once_with( + conversation_uuid + ) + + # Verify that delete_app_conversation was called with the conversation ID + mock_service.delete_app_conversation.assert_called_once_with( + conversation_uuid + ) @pytest.mark.asyncio @@ -1025,25 +1055,46 @@ async def test_delete_v1_conversation_not_found(): mock_service = MagicMock() mock_service_dep.return_value = mock_service - # Mock the conversation doesn't exist - mock_service.get_app_conversation = AsyncMock(return_value=None) - mock_service.delete_app_conversation = AsyncMock(return_value=False) + # Mock the app conversation info service + with patch( + 
'openhands.server.routes.manage_conversations.app_conversation_info_service_dependency' + ) as mock_info_service_dep: + mock_info_service = MagicMock() + mock_info_service_dep.return_value = mock_info_service - # Call delete_conversation with V1 conversation ID - result = await delete_conversation( - conversation_id=conversation_id, - user_id='test_user', - app_conversation_service=mock_service, - ) + # Mock the sandbox service + with patch( + 'openhands.server.routes.manage_conversations.sandbox_service_dependency' + ) as mock_sandbox_service_dep: + mock_sandbox_service = MagicMock() + mock_sandbox_service_dep.return_value = mock_sandbox_service - # Verify the result - assert result is False + # Mock the conversation doesn't exist + mock_info_service.get_app_conversation_info = AsyncMock( + return_value=None + ) + mock_service.delete_app_conversation = AsyncMock(return_value=False) - # Verify that get_app_conversation was called - mock_service.get_app_conversation.assert_called_once_with(conversation_uuid) + # Call delete_conversation with V1 conversation ID + result = await delete_conversation( + request=MagicMock(), + conversation_id=conversation_id, + user_id='test_user', + app_conversation_service=mock_service, + app_conversation_info_service=mock_info_service, + sandbox_service=mock_sandbox_service, + ) - # Verify that delete_app_conversation was NOT called - mock_service.delete_app_conversation.assert_not_called() + # Verify the result + assert result is False + + # Verify that get_app_conversation_info was called + mock_info_service.get_app_conversation_info.assert_called_once_with( + conversation_uuid + ) + + # Verify that delete_app_conversation was NOT called + mock_service.delete_app_conversation.assert_not_called() @pytest.mark.asyncio @@ -1091,19 +1142,40 @@ async def test_delete_v1_conversation_invalid_uuid(): mock_runtime_cls.delete = AsyncMock() mock_get_runtime_cls.return_value = mock_runtime_cls - # Call delete_conversation - result = await 
delete_conversation( - conversation_id=conversation_id, - user_id='test_user', - app_conversation_service=mock_service, - ) + # Mock the app conversation info service + with patch( + 'openhands.server.routes.manage_conversations.app_conversation_info_service_dependency' + ) as mock_info_service_dep: + mock_info_service = MagicMock() + mock_info_service_dep.return_value = mock_info_service - # Verify the result - assert result is True + # Mock the sandbox service + with patch( + 'openhands.server.routes.manage_conversations.sandbox_service_dependency' + ) as mock_sandbox_service_dep: + mock_sandbox_service = MagicMock() + mock_sandbox_service_dep.return_value = mock_sandbox_service - # Verify V0 logic was used - mock_store.delete_metadata.assert_called_once_with(conversation_id) - mock_runtime_cls.delete.assert_called_once_with(conversation_id) + # Call delete_conversation + result = await delete_conversation( + request=MagicMock(), + conversation_id=conversation_id, + user_id='test_user', + app_conversation_service=mock_service, + app_conversation_info_service=mock_info_service, + sandbox_service=mock_sandbox_service, + ) + + # Verify the result + assert result is True + + # Verify V0 logic was used + mock_store.delete_metadata.assert_called_once_with( + conversation_id + ) + mock_runtime_cls.delete.assert_called_once_with( + conversation_id + ) @pytest.mark.asyncio @@ -1121,57 +1193,84 @@ async def test_delete_v1_conversation_service_error(): mock_service = MagicMock() mock_service_dep.return_value = mock_service - # Mock service error - mock_service.get_app_conversation = AsyncMock( - side_effect=Exception('Service error') - ) - - # Mock V0 conversation logic as fallback + # Mock the app conversation info service with patch( - 'openhands.server.routes.manage_conversations.ConversationStoreImpl.get_instance' - ) as mock_get_instance: - mock_store = MagicMock() - mock_store.get_metadata = AsyncMock( - return_value=ConversationMetadata( - 
conversation_id=conversation_id, - title='Test V0 Conversation', - created_at=datetime.fromisoformat('2025-01-01T00:00:00+00:00'), - last_updated_at=datetime.fromisoformat('2025-01-01T00:01:00+00:00'), - selected_repository='test/repo', - user_id='test_user', - ) - ) - mock_store.delete_metadata = AsyncMock() - mock_get_instance.return_value = mock_store + 'openhands.server.routes.manage_conversations.app_conversation_info_service_dependency' + ) as mock_info_service_dep: + mock_info_service = MagicMock() + mock_info_service_dep.return_value = mock_info_service - # Mock conversation manager + # Mock the sandbox service with patch( - 'openhands.server.routes.manage_conversations.conversation_manager' - ) as mock_manager: - mock_manager.is_agent_loop_running = AsyncMock(return_value=False) - mock_manager.get_connections = AsyncMock(return_value={}) + 'openhands.server.routes.manage_conversations.sandbox_service_dependency' + ) as mock_sandbox_service_dep: + mock_sandbox_service = MagicMock() + mock_sandbox_service_dep.return_value = mock_sandbox_service - # Mock runtime + # Mock service error + mock_info_service.get_app_conversation_info = AsyncMock( + side_effect=Exception('Service error') + ) + + # Mock V0 conversation logic as fallback with patch( - 'openhands.server.routes.manage_conversations.get_runtime_cls' - ) as mock_get_runtime_cls: - mock_runtime_cls = MagicMock() - mock_runtime_cls.delete = AsyncMock() - mock_get_runtime_cls.return_value = mock_runtime_cls - - # Call delete_conversation - result = await delete_conversation( - conversation_id=conversation_id, - user_id='test_user', - app_conversation_service=mock_service, + 'openhands.server.routes.manage_conversations.ConversationStoreImpl.get_instance' + ) as mock_get_instance: + mock_store = MagicMock() + mock_store.get_metadata = AsyncMock( + return_value=ConversationMetadata( + conversation_id=conversation_id, + title='Test V0 Conversation', + created_at=datetime.fromisoformat( + 
'2025-01-01T00:00:00+00:00' + ), + last_updated_at=datetime.fromisoformat( + '2025-01-01T00:01:00+00:00' + ), + selected_repository='test/repo', + user_id='test_user', + ) ) + mock_store.delete_metadata = AsyncMock() + mock_get_instance.return_value = mock_store - # Verify the result (should fallback to V0) - assert result is True + # Mock conversation manager + with patch( + 'openhands.server.routes.manage_conversations.conversation_manager' + ) as mock_manager: + mock_manager.is_agent_loop_running = AsyncMock( + return_value=False + ) + mock_manager.get_connections = AsyncMock(return_value={}) - # Verify V0 logic was used - mock_store.delete_metadata.assert_called_once_with(conversation_id) - mock_runtime_cls.delete.assert_called_once_with(conversation_id) + # Mock runtime + with patch( + 'openhands.server.routes.manage_conversations.get_runtime_cls' + ) as mock_get_runtime_cls: + mock_runtime_cls = MagicMock() + mock_runtime_cls.delete = AsyncMock() + mock_get_runtime_cls.return_value = mock_runtime_cls + + # Call delete_conversation + result = await delete_conversation( + request=MagicMock(), + conversation_id=conversation_id, + user_id='test_user', + app_conversation_service=mock_service, + app_conversation_info_service=mock_info_service, + sandbox_service=mock_sandbox_service, + ) + + # Verify the result (should fallback to V0) + assert result is True + + # Verify V0 logic was used + mock_store.delete_metadata.assert_called_once_with( + conversation_id + ) + mock_runtime_cls.delete.assert_called_once_with( + conversation_id + ) @pytest.mark.asyncio @@ -1195,42 +1294,63 @@ async def test_delete_v1_conversation_with_agent_server(): mock_service = MagicMock() mock_service_dep.return_value = mock_service - # Mock the conversation exists with running sandbox - mock_app_conversation = AppConversation( - id=conversation_uuid, - created_by_user_id='test_user', - sandbox_id='test-sandbox-id', - title='Test V1 Conversation', - sandbox_status=SandboxStatus.RUNNING, - 
execution_status=ConversationExecutionStatus.RUNNING, - session_api_key='test-api-key', - selected_repository='test/repo', - selected_branch='main', - git_provider=ProviderType.GITHUB, - trigger=ConversationTrigger.GUI, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - ) - mock_service.get_app_conversation = AsyncMock( - return_value=mock_app_conversation - ) - mock_service.delete_app_conversation = AsyncMock(return_value=True) + # Mock the app conversation info service + with patch( + 'openhands.server.routes.manage_conversations.app_conversation_info_service_dependency' + ) as mock_info_service_dep: + mock_info_service = MagicMock() + mock_info_service_dep.return_value = mock_info_service - # Call delete_conversation with V1 conversation ID - result = await delete_conversation( - conversation_id=conversation_id, - user_id='test_user', - app_conversation_service=mock_service, - ) + # Mock the sandbox service + with patch( + 'openhands.server.routes.manage_conversations.sandbox_service_dependency' + ) as mock_sandbox_service_dep: + mock_sandbox_service = MagicMock() + mock_sandbox_service_dep.return_value = mock_sandbox_service - # Verify the result - assert result is True + # Mock the conversation exists with running sandbox + mock_app_conversation_info = AppConversation( + id=conversation_uuid, + created_by_user_id='test_user', + sandbox_id='test-sandbox-id', + title='Test V1 Conversation', + sandbox_status=SandboxStatus.RUNNING, + execution_status=ConversationExecutionStatus.RUNNING, + session_api_key='test-api-key', + selected_repository='test/repo', + selected_branch='main', + git_provider=ProviderType.GITHUB, + trigger=ConversationTrigger.GUI, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + mock_info_service.get_app_conversation_info = AsyncMock( + return_value=mock_app_conversation_info + ) + mock_service.delete_app_conversation = AsyncMock(return_value=True) - # Verify that 
get_app_conversation was called - mock_service.get_app_conversation.assert_called_once_with(conversation_uuid) + # Call delete_conversation with V1 conversation ID + result = await delete_conversation( + request=MagicMock(), + conversation_id=conversation_id, + user_id='test_user', + app_conversation_service=mock_service, + app_conversation_info_service=mock_info_service, + sandbox_service=mock_sandbox_service, + ) - # Verify that delete_app_conversation was called with the conversation ID - mock_service.delete_app_conversation.assert_called_once_with(conversation_uuid) + # Verify the result + assert result is True + + # Verify that get_app_conversation_info was called + mock_info_service.get_app_conversation_info.assert_called_once_with( + conversation_uuid + ) + + # Verify that delete_app_conversation was called with the conversation ID + mock_service.delete_app_conversation.assert_called_once_with( + conversation_uuid + ) @pytest.mark.asyncio diff --git a/tests/unit/server/routes/test_conversation_routes.py b/tests/unit/server/routes/test_conversation_routes.py index f909e44cc8..343894cefa 100644 --- a/tests/unit/server/routes/test_conversation_routes.py +++ b/tests/unit/server/routes/test_conversation_routes.py @@ -11,10 +11,21 @@ from openhands.app_server.app_conversation.app_conversation_info_service import AppConversationInfoService, ) from openhands.app_server.app_conversation.app_conversation_models import ( + AgentType, AppConversationInfo, + AppConversationPage, + AppConversationStartRequest, + AppConversationStartTask, + AppConversationStartTaskStatus, +) +from openhands.app_server.app_conversation.app_conversation_service import ( + AppConversationService, ) from openhands.microagent.microagent import KnowledgeMicroagent, RepoMicroagent from openhands.microagent.types import MicroagentMetadata, MicroagentType +from openhands.server.data_models.conversation_info_result_set import ( + ConversationInfoResultSet, +) from 
openhands.server.routes.conversation import ( AddMessageRequest, add_message, @@ -22,11 +33,15 @@ from openhands.server.routes.conversation import ( ) from openhands.server.routes.manage_conversations import ( UpdateConversationRequest, + search_conversations, update_conversation, ) from openhands.server.session.conversation import ServerConversation from openhands.storage.conversation.conversation_store import ConversationStore -from openhands.storage.data_models.conversation_metadata import ConversationMetadata +from openhands.storage.data_models.conversation_metadata import ( + ConversationMetadata, + ConversationTrigger, +) @pytest.mark.asyncio @@ -1125,3 +1140,322 @@ async def test_add_message_empty_message(): call_args = mock_manager.send_event_to_conversation.call_args message_data = call_args[0][1] assert message_data['args']['content'] == '' + + +@pytest.mark.sub_conversation +@pytest.mark.asyncio +async def test_create_sub_conversation_with_planning_agent(): + """Test creating a sub-conversation from a parent conversation with planning agent.""" + from uuid import uuid4 + + parent_conversation_id = uuid4() + user_id = 'test_user_456' + sandbox_id = 'test_sandbox_123' + + # Create mock parent conversation info + parent_info = AppConversationInfo( + id=parent_conversation_id, + created_by_user_id=user_id, + sandbox_id=sandbox_id, + selected_repository='test/repo', + selected_branch='main', + git_provider=None, + title='Parent Conversation', + llm_model='anthropic/claude-3-5-sonnet-20241022', + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + # Create sub-conversation request with planning agent + sub_conversation_request = AppConversationStartRequest( + parent_conversation_id=parent_conversation_id, + agent_type=AgentType.PLAN, + initial_message=None, + ) + + # Create mock app conversation service + mock_app_conversation_service = MagicMock(spec=AppConversationService) + mock_app_conversation_info_service = 
MagicMock(spec=AppConversationInfoService) + + # Mock the service to return parent info + mock_app_conversation_info_service.get_app_conversation_info = AsyncMock( + return_value=parent_info + ) + + # Mock the start_app_conversation method to return a task + async def mock_start_generator(request): + task = AppConversationStartTask( + id=uuid4(), + created_by_user_id=user_id, + status=AppConversationStartTaskStatus.READY, + app_conversation_id=uuid4(), + sandbox_id=sandbox_id, + agent_server_url='http://agent-server:8000', + request=request, + ) + yield task + + mock_app_conversation_service.start_app_conversation = mock_start_generator + + # Test the service method directly + async for task in mock_app_conversation_service.start_app_conversation( + sub_conversation_request + ): + # Verify the task was created with planning agent + assert task is not None + assert task.status == AppConversationStartTaskStatus.READY + assert task.request.agent_type == AgentType.PLAN + assert task.request.parent_conversation_id == parent_conversation_id + assert task.sandbox_id == sandbox_id + break + + +@pytest.mark.asyncio +async def test_search_conversations_include_sub_conversations_default_false(): + """Test that include_sub_conversations defaults to False when not provided.""" + with patch('openhands.server.routes.manage_conversations.config') as mock_config: + mock_config.conversation_max_age_seconds = 864000 # 10 days + with patch( + 'openhands.server.routes.manage_conversations.conversation_manager' + ) as mock_manager: + + async def mock_get_running_agent_loops(*args, **kwargs): + return set() + + async def mock_get_connections(*args, **kwargs): + return {} + + async def get_agent_loop_info(*args, **kwargs): + return [] + + mock_manager.get_running_agent_loops = mock_get_running_agent_loops + mock_manager.get_connections = mock_get_connections + mock_manager.get_agent_loop_info = get_agent_loop_info + with patch( + 'openhands.server.routes.manage_conversations.datetime' + ) 
as mock_datetime:
+                mock_datetime.now.return_value = datetime.fromisoformat(
+                    '2025-01-01T00:00:00+00:00'
+                )
+                mock_datetime.fromisoformat = datetime.fromisoformat
+                mock_datetime.timezone = timezone
+
+                # Mock the conversation store
+                mock_store = MagicMock()
+                mock_store.search = AsyncMock(
+                    return_value=ConversationInfoResultSet(results=[])
+                )
+
+                # Create a mock app conversation service
+                mock_app_conversation_service = AsyncMock()
+                mock_app_conversation_service.search_app_conversations.return_value = (
+                    AppConversationPage(items=[])
+                )
+
+                # Call search_conversations without include_sub_conversations parameter
+                await search_conversations(
+                    page_id=None,
+                    limit=20,
+                    selected_repository=None,
+                    conversation_trigger=None,
+                    conversation_store=mock_store,
+                    app_conversation_service=mock_app_conversation_service,
+                )
+
+                # Verify that search_app_conversations was called with include_sub_conversations=False (default)
+                mock_app_conversation_service.search_app_conversations.assert_called_once()
+                call_kwargs = (
+                    mock_app_conversation_service.search_app_conversations.call_args[1]
+                )
+                assert call_kwargs.get('include_sub_conversations') is False
+
+
+@pytest.mark.asyncio
+async def test_search_conversations_include_sub_conversations_explicit_false():
+    """Check that an explicit include_sub_conversations=False is forwarded.
+
+    The route is called with empty mocked backends, and the keyword arguments
+    received by ``search_app_conversations`` are inspected afterwards.
+    """
+
+    # Async stand-ins installed on the patched conversation manager below.
+    async def _no_running_loops(*args, **kwargs):
+        return set()
+
+    async def _no_connections(*args, **kwargs):
+        return {}
+
+    async def _no_loop_info(*args, **kwargs):
+        return []
+
+    with patch(
+        'openhands.server.routes.manage_conversations.config'
+    ) as cfg, patch(
+        'openhands.server.routes.manage_conversations.conversation_manager'
+    ) as manager, patch(
+        'openhands.server.routes.manage_conversations.datetime'
+    ) as frozen_datetime:
+        cfg.conversation_max_age_seconds = 864000  # 10 days
+
+        manager.get_running_agent_loops = _no_running_loops
+        manager.get_connections = _no_connections
+        manager.get_agent_loop_info = _no_loop_info
+
+        # Fix the value returned by now(); keep the real fromisoformat/timezone
+        # reachable through the patched name.
+        frozen_datetime.now.return_value = datetime.fromisoformat(
+            '2025-01-01T00:00:00+00:00'
+        )
+        frozen_datetime.fromisoformat = datetime.fromisoformat
+        frozen_datetime.timezone = timezone
+
+        # Conversation store whose search returns an empty result set.
+        empty_results = ConversationInfoResultSet(results=[])
+        store = MagicMock()
+        store.search = AsyncMock(return_value=empty_results)
+
+        # App conversation service whose search returns an empty page.
+        empty_page = AppConversationPage(items=[])
+        app_service = AsyncMock()
+        app_service.search_app_conversations.return_value = empty_page
+
+        await search_conversations(
+            page_id=None,
+            limit=20,
+            selected_repository=None,
+            conversation_trigger=None,
+            include_sub_conversations=False,
+            conversation_store=store,
+            app_conversation_service=app_service,
+        )
+
+        # Exactly one call, carrying the flag as False.
+        service_search = app_service.search_app_conversations
+        service_search.assert_called_once()
+        forwarded_kwargs = service_search.call_args.kwargs
+        assert forwarded_kwargs.get('include_sub_conversations') is False
+
+
+@pytest.mark.asyncio
+async def test_search_conversations_include_sub_conversations_explicit_true():
+    """Check that an explicit include_sub_conversations=True is forwarded.
+
+    The route is called with empty mocked backends, and the keyword arguments
+    received by ``search_app_conversations`` are inspected afterwards.
+    """
+
+    # Async stand-ins installed on the patched conversation manager below.
+    async def _no_running_loops(*args, **kwargs):
+        return set()
+
+    async def _no_connections(*args, **kwargs):
+        return {}
+
+    async def _no_loop_info(*args, **kwargs):
+        return []
+
+    with patch(
+        'openhands.server.routes.manage_conversations.config'
+    ) as cfg, patch(
+        'openhands.server.routes.manage_conversations.conversation_manager'
+    ) as manager, patch(
+        'openhands.server.routes.manage_conversations.datetime'
+    ) as frozen_datetime:
+        cfg.conversation_max_age_seconds = 864000  # 10 days
+
+        manager.get_running_agent_loops = _no_running_loops
+        manager.get_connections = _no_connections
+        manager.get_agent_loop_info = _no_loop_info
+
+        # Fix the value returned by now(); keep the real fromisoformat/timezone
+        # reachable through the patched name.
+        frozen_datetime.now.return_value = datetime.fromisoformat(
+            '2025-01-01T00:00:00+00:00'
+        )
+        frozen_datetime.fromisoformat = datetime.fromisoformat
+        frozen_datetime.timezone = timezone
+
+        # Conversation store whose search returns an empty result set.
+        empty_results = ConversationInfoResultSet(results=[])
+        store = MagicMock()
+        store.search = AsyncMock(return_value=empty_results)
+
+        # App conversation service whose search returns an empty page.
+        empty_page = AppConversationPage(items=[])
+        app_service = AsyncMock()
+        app_service.search_app_conversations.return_value = empty_page
+
+        await search_conversations(
+            page_id=None,
+            limit=20,
+            selected_repository=None,
+            conversation_trigger=None,
+            include_sub_conversations=True,
+            conversation_store=store,
+            app_conversation_service=app_service,
+        )
+
+        # Exactly one call, carrying the flag as True.
+        service_search = app_service.search_app_conversations
+        service_search.assert_called_once()
+        forwarded_kwargs = service_search.call_args.kwargs
+        assert forwarded_kwargs.get('include_sub_conversations') is True
+
+
+@pytest.mark.asyncio
+async def test_search_conversations_include_sub_conversations_with_other_filters():
+    """Test that include_sub_conversations works correctly with other filters."""
+    with patch('openhands.server.routes.manage_conversations.config') as mock_config:
+        mock_config.conversation_max_age_seconds = 864000  # 10 days
+        with patch(
+            'openhands.server.routes.manage_conversations.conversation_manager'
+        ) as mock_manager:
+
+            async def mock_get_running_agent_loops(*args, **kwargs):
+                return set()
+
+            async def mock_get_connections(*args, **kwargs):
+                return {}
+
+            async def get_agent_loop_info(*args, **kwargs):
+                return []
+
+            mock_manager.get_running_agent_loops = mock_get_running_agent_loops
+            mock_manager.get_connections = mock_get_connections
+            mock_manager.get_agent_loop_info = get_agent_loop_info
+            with patch(
+                'openhands.server.routes.manage_conversations.datetime'
+            ) as mock_datetime:
+                mock_datetime.now.return_value = datetime.fromisoformat(
+                    '2025-01-01T00:00:00+00:00'
+                )
+                mock_datetime.fromisoformat = datetime.fromisoformat
+                mock_datetime.timezone = timezone
+
+                # Mock the conversation store
+                mock_store = MagicMock()
+                mock_store.search = AsyncMock(
+                    return_value=ConversationInfoResultSet(results=[])
+                )
+
+                # Create a mock app conversation service
+                mock_app_conversation_service = AsyncMock()
+                mock_app_conversation_service.search_app_conversations.return_value = (
+                    AppConversationPage(items=[])
+                )
+
+                # Create a valid base64-encoded page_id for testing
+                import base64
+
+                page_id_data = json.dumps({'v0': None, 'v1': 'test_v1_page_id'})
+                encoded_page_id = base64.b64encode(page_id_data.encode()).decode()
+
+                # Call search_conversations with include_sub_conversations and other filters
+                await search_conversations(
+                    page_id=encoded_page_id,
+                    limit=50,
+                    selected_repository='test/repo',
+                    conversation_trigger=ConversationTrigger.GUI,
+                    include_sub_conversations=True,
+                    conversation_store=mock_store,
+                    app_conversation_service=mock_app_conversation_service,
+                )
+
+                # Verify that search_app_conversations was called with all parameters including include_sub_conversations=True
+                mock_app_conversation_service.search_app_conversations.assert_called_once()
+                call_kwargs = (
+                    mock_app_conversation_service.search_app_conversations.call_args[1]
+                )
+                assert call_kwargs.get('include_sub_conversations') is True
+                assert call_kwargs.get('page_id') == 'test_v1_page_id'
+                assert call_kwargs.get('limit') == 50