[Fix]: Move suggest task prompts to BE (#8109)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra
2025-04-27 16:32:15 -04:00
committed by GitHub
parent 70f469b0c1
commit 391ba1d988
23 changed files with 347 additions and 225 deletions

View File

@@ -46,10 +46,12 @@ describe("HomeHeader", () => {
await userEvent.click(launchButton);
expect(createConversationSpy).toHaveBeenCalledExactlyOnceWith(
"gui",
undefined,
undefined,
[],
undefined,
undefined,
);
// expect to be redirected to /conversations/:conversationId

View File

@@ -155,7 +155,7 @@ describe("RepoConnector", () => {
// select a repository from the dropdown
const dropdown = await waitFor(() =>
within(repoConnector).getByTestId("repo-dropdown")
within(repoConnector).getByTestId("repo-dropdown"),
);
await userEvent.click(dropdown);
@@ -164,6 +164,7 @@ describe("RepoConnector", () => {
await userEvent.click(launchButton);
expect(createConversationSpy).toHaveBeenCalledExactlyOnceWith(
"gui",
{
full_name: "rbren/polaris",
git_provider: "github",
@@ -173,6 +174,7 @@ describe("RepoConnector", () => {
undefined,
[],
undefined,
undefined,
);
});

View File

@@ -11,12 +11,6 @@ import { AuthProvider } from "#/context/auth-context";
import { TaskCard } from "#/components/features/home/tasks/task-card";
import * as GitService from "#/api/git";
import { GitRepository } from "#/types/git";
import {
getFailingChecksPrompt,
getMergeConflictPrompt,
getOpenIssuePrompt,
getUnresolvedCommentsPrompt,
} from "#/components/features/home/tasks/get-prompt-for-query";
const MOCK_TASK_1: SuggestedTask = {
issue_number: 123,
@@ -101,7 +95,7 @@ describe("TaskCard", () => {
expect(createConversationSpy).toHaveBeenCalled();
});
describe("creating conversation prompts", () => {
describe("creating suggested task conversation", () => {
beforeEach(() => {
const retrieveUserGitRepositoriesSpy = vi.spyOn(
GitService,
@@ -113,7 +107,7 @@ describe("TaskCard", () => {
});
});
it("should call create conversation with the merge conflict prompt", async () => {
it("should call create conversation with suggest task trigger and selected suggested task", async () => {
const createConversationSpy = vi.spyOn(OpenHands, "createConversation");
renderTaskCard(MOCK_TASK_1);
@@ -122,74 +116,12 @@ describe("TaskCard", () => {
await userEvent.click(launchButton);
expect(createConversationSpy).toHaveBeenCalledWith(
"suggested_task",
MOCK_RESPOSITORIES[0],
getMergeConflictPrompt(
MOCK_TASK_1.git_provider,
MOCK_TASK_1.issue_number,
MOCK_TASK_1.repo,
),
[],
undefined,
);
});
it("should call create conversation with the failing checks prompt", async () => {
const createConversationSpy = vi.spyOn(OpenHands, "createConversation");
renderTaskCard(MOCK_TASK_2);
const launchButton = screen.getByTestId("task-launch-button");
await userEvent.click(launchButton);
expect(createConversationSpy).toHaveBeenCalledWith(
MOCK_RESPOSITORIES[1],
getFailingChecksPrompt(
MOCK_TASK_2.git_provider,
MOCK_TASK_2.issue_number,
MOCK_TASK_2.repo,
),
[],
undefined,
);
});
it("should call create conversation with the unresolved comments prompt", async () => {
const createConversationSpy = vi.spyOn(OpenHands, "createConversation");
renderTaskCard(MOCK_TASK_3);
const launchButton = screen.getByTestId("task-launch-button");
await userEvent.click(launchButton);
expect(createConversationSpy).toHaveBeenCalledWith(
MOCK_RESPOSITORIES[2],
getUnresolvedCommentsPrompt(
MOCK_TASK_3.git_provider,
MOCK_TASK_3.issue_number,
MOCK_TASK_3.repo,
),
[],
undefined,
);
});
it("should call create conversation with the open issue prompt", async () => {
const createConversationSpy = vi.spyOn(OpenHands, "createConversation");
renderTaskCard(MOCK_TASK_4);
const launchButton = screen.getByTestId("task-launch-button");
await userEvent.click(launchButton);
expect(createConversationSpy).toHaveBeenCalledWith(
MOCK_RESPOSITORIES[3],
getOpenIssuePrompt(
MOCK_TASK_4.git_provider,
MOCK_TASK_4.issue_number,
MOCK_TASK_4.repo,
),
undefined,
[],
undefined,
MOCK_TASK_1,
);
});
});

View File

@@ -10,10 +10,12 @@ import {
GetTrajectoryResponse,
GitChangeDiff,
GitChange,
ConversationTrigger,
} from "./open-hands.types";
import { openHands } from "./open-hands-axios";
import { ApiSettings, PostApiSettings } from "#/types/settings";
import { GitUser, GitRepository } from "#/types/git";
import { SuggestedTask } from "#/components/features/home/tasks/task.types";
class OpenHands {
/**
@@ -149,17 +151,21 @@ class OpenHands {
}
static async createConversation(
conversation_trigger: ConversationTrigger = "gui",
selectedRepository?: GitRepository,
initialUserMsg?: string,
imageUrls?: string[],
replayJson?: string,
suggested_task?: SuggestedTask,
): Promise<Conversation> {
const body = {
conversation_trigger,
selected_repository: selectedRepository,
selected_branch: undefined,
initial_user_msg: initialUserMsg,
image_urls: imageUrls,
replay_json: replayJson,
suggested_task,
};
const { data } = await openHands.post<Conversation>(

View File

@@ -70,7 +70,7 @@ export interface AuthenticateResponse {
error?: string;
}
export type ConversationTrigger = "resolver" | "gui";
export type ConversationTrigger = "resolver" | "gui" | "suggested_task";
export interface Conversation {
conversation_id: string;

View File

@@ -1,33 +0,0 @@
import React from "react";
import { useDispatch } from "react-redux";
import { useTranslation } from "react-i18next";
import { I18nKey } from "#/i18n/declaration";
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
import { setInitialPrompt } from "#/state/initial-query-slice";
const INITIAL_PROMPT = "";
export function CodeNotInGitLink() {
const dispatch = useDispatch();
const { t } = useTranslation();
const { mutate: createConversation } = useCreateConversation();
const handleStartFromScratch = () => {
// Set the initial prompt and create a new conversation
dispatch(setInitialPrompt(INITIAL_PROMPT));
createConversation({ q: INITIAL_PROMPT });
};
return (
<div className="text-xs text-neutral-400">
{t(I18nKey.GITHUB$CODE_NOT_IN_GITHUB)}{" "}
<span
onClick={handleStartFromScratch}
className="underline cursor-pointer"
>
{t(I18nKey.GITHUB$START_FROM_SCRATCH)}
</span>{" "}
{t(I18nKey.GITHUB$VSCODE_LINK_DESCRIPTION)}
</div>
);
}

View File

@@ -28,7 +28,7 @@ export function HomeHeader() {
testId="header-launch-button"
variant="primary"
type="button"
onClick={() => createConversation({})}
onClick={() => createConversation({ conversation_trigger: "gui" })}
isDisabled={isCreatingConversation}
>
{!isCreatingConversation && "Launch from Scratch"}

View File

@@ -142,7 +142,12 @@ export function RepositorySelectionForm({
isLoadingRepositories ||
isRepositoriesError
}
onClick={() => createConversation({ selectedRepository })}
onClick={() =>
createConversation({
selectedRepository,
conversation_trigger: "gui",
})
}
>
{!isCreatingConversation && "Launch"}
{isCreatingConversation && t("HOME$LOADING")}

View File

@@ -1,95 +0,0 @@
import { Provider } from "#/types/settings";
import { SuggestedTaskType } from "./task.types";
// Helper function to get provider-specific terminology
const getProviderTerms = (git_provider: Provider) => {
if (git_provider === "gitlab") {
return {
requestType: "Merge Request",
requestTypeShort: "MR",
apiName: "GitLab API",
tokenEnvVar: "GITLAB_TOKEN",
ciSystem: "CI pipelines",
ciProvider: "GitLab",
requestVerb: "merge request",
};
}
return {
requestType: "Pull Request",
requestTypeShort: "PR",
apiName: "GitHub API",
tokenEnvVar: "GITHUB_TOKEN",
ciSystem: "GitHub Actions",
ciProvider: "GitHub",
requestVerb: "pull request",
};
};
export const getMergeConflictPrompt = (
git_provider: Provider,
issueNumber: number,
repo: string,
) => {
const terms = getProviderTerms(git_provider);
return `You are working on ${terms.requestType} #${issueNumber} in repository ${repo}. You need to fix the merge conflicts.
Use the ${terms.apiName} with the ${terms.tokenEnvVar} environment variable to retrieve the ${terms.requestTypeShort} details. Check out the branch from that ${terms.requestVerb} and look at the diff versus the base branch of the ${terms.requestTypeShort} to understand the ${terms.requestTypeShort}'s intention.
Then resolve the merge conflicts. If you aren't sure what the right solution is, look back through the commit history at the commits that introduced the conflict and resolve them accordingly.`;
};
export const getFailingChecksPrompt = (
git_provider: Provider,
issueNumber: number,
repo: string,
) => {
const terms = getProviderTerms(git_provider);
return `You are working on ${terms.requestType} #${issueNumber} in repository ${repo}. You need to fix the failing CI checks.
Use the ${terms.apiName} with the ${terms.tokenEnvVar} environment variable to retrieve the ${terms.requestTypeShort} details. Check out the branch from that ${terms.requestVerb} and look at the diff versus the base branch of the ${terms.requestTypeShort} to understand the ${terms.requestTypeShort}'s intention.
Then use the ${terms.apiName} to look at the ${terms.ciSystem} that are failing on the most recent commit. Try and reproduce the failure locally.
Get things working locally, then push your changes. Sleep for 30 seconds at a time until the ${terms.ciProvider} ${terms.ciSystem.toLowerCase()} have run again. If they are still failing, repeat the process.`;
};
export const getUnresolvedCommentsPrompt = (
git_provider: Provider,
issueNumber: number,
repo: string,
) => {
const terms = getProviderTerms(git_provider);
return `You are working on ${terms.requestType} #${issueNumber} in repository ${repo}. You need to resolve the remaining comments from reviewers.
Use the ${terms.apiName} with the ${terms.tokenEnvVar} environment variable to retrieve the ${terms.requestTypeShort} details. Check out the branch from that ${terms.requestVerb} and look at the diff versus the base branch of the ${terms.requestTypeShort} to understand the ${terms.requestTypeShort}'s intention.
Then use the ${terms.apiName} to retrieve all the feedback on the ${terms.requestTypeShort} so far. If anything hasn't been addressed, address it and commit your changes back to the same branch.`;
};
export const getOpenIssuePrompt = (
git_provider: Provider,
issueNumber: number,
repo: string,
) => {
const terms = getProviderTerms(git_provider);
return `You are working on Issue #${issueNumber} in repository ${repo}. Your goal is to fix the issue.
Use the ${terms.apiName} with the ${terms.tokenEnvVar} environment variable to retrieve the issue details and any comments on the issue. Then check out a new branch and investigate what changes will need to be made.
Finally, make the required changes and open up a ${terms.requestVerb}. Be sure to reference the issue in the ${terms.requestTypeShort} description.`;
};
export const getPromptForQuery = (
git_provider: Provider,
type: SuggestedTaskType,
issueNumber: number,
repo: string,
) => {
switch (type) {
case "MERGE_CONFLICTS":
return getMergeConflictPrompt(git_provider, issueNumber, repo);
case "FAILING_CHECKS":
return getFailingChecksPrompt(git_provider, issueNumber, repo);
case "UNRESOLVED_COMMENTS":
return getUnresolvedCommentsPrompt(git_provider, issueNumber, repo);
case "OPEN_ISSUE":
return getOpenIssuePrompt(git_provider, issueNumber, repo);
default:
return "";
}
};

View File

@@ -4,7 +4,6 @@ import { useIsCreatingConversation } from "#/hooks/use-is-creating-conversation"
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
import { cn } from "#/utils/utils";
import { useUserRepositories } from "#/hooks/query/use-user-repositories";
import { getPromptForQuery } from "./get-prompt-for-query";
import { TaskIssueNumber } from "./task-issue-number";
import { Provider } from "#/types/settings";
@@ -40,16 +39,11 @@ export function TaskCard({ task }: TaskCardProps) {
const handleLaunchConversation = () => {
const repo = getRepo(task.repo, task.git_provider);
const query = getPromptForQuery(
task.git_provider,
task.task_type,
task.issue_number,
task.repo,
);
return createConversation({
conversation_trigger: "suggested_task",
selectedRepository: repo,
q: query,
suggested_task: task,
});
};

View File

@@ -52,7 +52,7 @@ export function TaskForm({ ref }: TaskFormProps) {
const formData = new FormData(event.currentTarget);
const q = formData.get("q")?.toString();
createConversation({ q });
createConversation({ q, conversation_trigger: "gui" });
};
return (

View File

@@ -6,6 +6,8 @@ import OpenHands from "#/api/open-hands";
import { setInitialPrompt } from "#/state/initial-query-slice";
import { RootState } from "#/store";
import { GitRepository } from "#/types/git";
import { ConversationTrigger } from "#/api/open-hands.types";
import { SuggestedTask } from "#/components/features/home/tasks/task.types";
export const useCreateConversation = () => {
const navigate = useNavigate();
@@ -19,16 +21,20 @@ export const useCreateConversation = () => {
return useMutation({
mutationKey: ["create-conversation"],
mutationFn: async (variables: {
conversation_trigger: ConversationTrigger;
q?: string;
selectedRepository?: GitRepository | null;
suggested_task?: SuggestedTask;
}) => {
if (variables.q) dispatch(setInitialPrompt(variables.q));
return OpenHands.createConversation(
variables.conversation_trigger,
variables.selectedRepository || undefined,
variables.q,
files,
replayJson || undefined,
variables.suggested_task || undefined,
);
},
onSuccess: async ({ conversation_id: conversationId }, { q }) => {

View File

@@ -3,6 +3,7 @@ from enum import Enum
from typing import Any, Protocol
from httpx import AsyncClient, HTTPError, HTTPStatusError
from jinja2 import Environment, FileSystemLoader
from pydantic import BaseModel, SecretStr
from openhands.core.logger import openhands_logger as logger
@@ -29,6 +30,57 @@ class SuggestedTask(BaseModel):
issue_number: int
title: str
def get_provider_terms(self) -> dict:
if self.git_provider == ProviderType.GITLAB:
return {
'requestType': 'Merge Request',
'requestTypeShort': 'MR',
'apiName': 'GitLab API',
'tokenEnvVar': 'GITLAB_TOKEN',
'ciSystem': 'CI pipelines',
'ciProvider': 'GitLab',
'requestVerb': 'merge request',
}
elif self.git_provider == ProviderType.GITHUB:
return {
'requestType': 'Pull Request',
'requestTypeShort': 'PR',
'apiName': 'GitHub API',
'tokenEnvVar': 'GITHUB_TOKEN',
'ciSystem': 'GitHub Actions',
'ciProvider': 'GitHub',
'requestVerb': 'pull request',
}
raise ValueError(f'Provider {self.git_provider} for suggested task prompts')
def get_prompt_for_task(
self,
) -> str:
task_type = self.task_type
issue_number = self.issue_number
repo = self.repo
env = Environment(
loader=FileSystemLoader('openhands/integrations/templates/suggested_task')
)
template = None
if task_type == TaskType.MERGE_CONFLICTS:
template = env.get_template('merge_conflict_prompt.j2')
elif task_type == TaskType.FAILING_CHECKS:
template = env.get_template('failing_checks_prompt.j2')
elif task_type == TaskType.UNRESOLVED_COMMENTS:
template = env.get_template('unresolved_comments_prompt.j2')
elif task_type == TaskType.OPEN_ISSUE:
template = env.get_template('open_issue_prompt.j2')
else:
raise ValueError(f'Unsupported task type: {task_type}')
terms = self.get_provider_terms()
return template.render(issue_number=issue_number, repo=repo, **terms)
class User(BaseModel):
id: int

View File

@@ -0,0 +1,6 @@
You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }}. You need to fix the failing CI checks.
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then use the {{ apiName }} to look at the {{ ciSystem }} that are failing on the most recent commit. Try and reproduce the failure locally.
Get things working locally, then push your changes. Sleep for 30 seconds at a time until the {{ ciProvider }} {{ ciSystem.lower() }} have run again.
If they are still failing, repeat the process.

View File

@@ -0,0 +1,4 @@
You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }}. You need to fix the merge conflicts.
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then resolve the merge conflicts. If you aren't sure what the right solution is, look back through the commit history at the commits that introduced the conflict and resolve them accordingly.

View File

@@ -0,0 +1,4 @@
You are working on Issue #{{ issue_number }} in repository {{ repo }}. Your goal is to fix the issue.
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the issue details and any comments on the issue.
Then check out a new branch and investigate what changes will need to be made.
Finally, make the required changes and open up a {{ requestVerb }}. Be sure to reference the issue in the {{ requestTypeShort }} description.

View File

@@ -0,0 +1,5 @@
You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }}. You need to resolve the remaining comments from reviewers.
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then use the {{ apiName }} to retrieve all the feedback on the {{ requestTypeShort }} so far.
If anything hasn't been addressed, address it and commit your changes back to the same branch.

View File

@@ -13,7 +13,7 @@ from openhands.events.stream import EventStream
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
)
from openhands.integrations.service_types import Repository
from openhands.integrations.service_types import Repository, SuggestedTask
from openhands.runtime import get_runtime_cls
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
@@ -42,16 +42,19 @@ from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.utils.async_utils import wait_all
from openhands.utils.conversation_summary import generate_conversation_title
app = APIRouter(prefix='/api')
class InitSessionRequest(BaseModel):
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI
selected_repository: Repository | None = None
selected_branch: str | None = None
initial_user_msg: str | None = None
image_urls: list[str] | None = None
replay_json: str | None = None
suggested_task: SuggestedTask | None = None
async def _create_new_conversation(
user_id: str | None,
@@ -64,9 +67,10 @@ async def _create_new_conversation(
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
attach_convo_id: bool = False,
):
print("trigger", conversation_trigger)
logger.info(
'Creating conversation',
extra={'signal': 'create_conversation', 'user_id': user_id},
extra={'signal': 'create_conversation', 'user_id': user_id, 'trigger': conversation_trigger.value},
)
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
@@ -169,17 +173,24 @@ async def new_conversation(
initial_user_msg = data.initial_user_msg
image_urls = data.image_urls or []
replay_json = data.replay_json
suggested_task = data.suggested_task
conversation_trigger = data.conversation_trigger
if suggested_task:
initial_user_msg = suggested_task.get_prompt_for_task()
conversation_trigger = ConversationTrigger.SUGGESTED_TASK
try:
# Create conversation with initial message
conversation_id = await _create_new_conversation(
user_id,
provider_tokens,
selected_repository,
selected_branch,
initial_user_msg,
image_urls,
replay_json,
user_id=user_id,
git_provider_tokens=provider_tokens,
selected_repository=selected_repository,
selected_branch=selected_branch,
initial_user_msg=initial_user_msg,
image_urls=image_urls,
replay_json=replay_json,
conversation_trigger=conversation_trigger
)
return JSONResponse(

View File

@@ -6,6 +6,7 @@ from enum import Enum
class ConversationTrigger(Enum):
RESOLVER = 'resolver'
GUI = 'gui'
SUGGESTED_TASK = 'suggested_task'
@dataclass

View File

@@ -4,18 +4,31 @@ from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.responses import JSONResponse
from openhands.integrations.service_types import (
ProviderType,
Repository,
SuggestedTask,
TaskType,
)
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
)
from openhands.server.routes.manage_conversations import (
InitSessionRequest,
delete_conversation,
get_conversation,
new_conversation,
search_conversations,
update_conversation,
)
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.storage.data_models.conversation_metadata import (
ConversationMetadata,
ConversationTrigger,
)
from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.storage.locations import get_conversation_metadata_filename
from openhands.storage.memory import InMemoryFileStore
@@ -218,6 +231,213 @@ async def test_update_conversation():
assert saved_metadata.title == 'New Title'
@pytest.mark.asyncio
async def test_new_conversation_success():
"""Test successful creation of a new conversation."""
with _patch_store():
# Mock the _create_new_conversation function directly
with patch(
'openhands.server.routes.manage_conversations._create_new_conversation'
) as mock_create_conversation:
# Set up the mock to return a conversation ID
mock_create_conversation.return_value = 'test_conversation_id'
# Create test data
test_repo = Repository(
id=12345,
full_name='test/repo',
git_provider=ProviderType.GITHUB,
is_public=True,
)
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI,
selected_repository=test_repo,
selected_branch='main',
initial_user_msg='Hello, agent!',
image_urls=['https://example.com/image.jpg'],
)
# Call new_conversation
response = await new_conversation(
data=test_request, user_id='test_user', provider_tokens={}
)
# Verify the response
assert isinstance(response, JSONResponse)
assert response.status_code == 200
assert (
response.body.decode('utf-8')
== '{"status":"ok","conversation_id":"test_conversation_id"}'
)
# Verify that _create_new_conversation was called with the correct arguments
mock_create_conversation.assert_called_once()
call_args = mock_create_conversation.call_args[1]
assert call_args['user_id'] == 'test_user'
assert call_args['selected_repository'] == test_repo
assert call_args['selected_branch'] == 'main'
assert call_args['initial_user_msg'] == 'Hello, agent!'
assert call_args['image_urls'] == ['https://example.com/image.jpg']
assert call_args['conversation_trigger'] == ConversationTrigger.GUI
@pytest.mark.asyncio
async def test_new_conversation_with_suggested_task():
"""Test creating a new conversation with a suggested task."""
with _patch_store():
# Mock the _create_new_conversation function directly
with patch(
'openhands.server.routes.manage_conversations._create_new_conversation'
) as mock_create_conversation:
# Set up the mock to return a conversation ID
mock_create_conversation.return_value = 'test_conversation_id'
# Mock SuggestedTask.get_prompt_for_task
with patch(
'openhands.integrations.service_types.SuggestedTask.get_prompt_for_task'
) as mock_get_prompt:
mock_get_prompt.return_value = (
'Please fix the failing checks in PR #123'
)
# Create test data
test_repo = Repository(
id=12345,
full_name='test/repo',
git_provider=ProviderType.GITHUB,
is_public=True,
)
test_task = SuggestedTask(
git_provider=ProviderType.GITHUB,
task_type=TaskType.FAILING_CHECKS,
repo='test/repo',
issue_number=123,
title='Fix failing checks',
)
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.SUGGESTED_TASK,
selected_repository=test_repo,
selected_branch='main',
suggested_task=test_task,
)
# Call new_conversation
response = await new_conversation(
data=test_request, user_id='test_user', provider_tokens={}
)
# Verify the response
assert isinstance(response, JSONResponse)
assert response.status_code == 200
assert (
response.body.decode('utf-8')
== '{"status":"ok","conversation_id":"test_conversation_id"}'
)
# Verify that _create_new_conversation was called with the correct arguments
mock_create_conversation.assert_called_once()
call_args = mock_create_conversation.call_args[1]
assert call_args['user_id'] == 'test_user'
assert call_args['selected_repository'] == test_repo
assert call_args['selected_branch'] == 'main'
assert (
call_args['initial_user_msg']
== 'Please fix the failing checks in PR #123'
)
assert (
call_args['conversation_trigger']
== ConversationTrigger.SUGGESTED_TASK
)
# Verify that get_prompt_for_task was called
mock_get_prompt.assert_called_once()
@pytest.mark.asyncio
async def test_new_conversation_missing_settings():
"""Test creating a new conversation when settings are missing."""
with _patch_store():
# Mock the _create_new_conversation function to raise MissingSettingsError
with patch(
'openhands.server.routes.manage_conversations._create_new_conversation'
) as mock_create_conversation:
# Set up the mock to raise MissingSettingsError
mock_create_conversation.side_effect = MissingSettingsError(
'Settings not found'
)
# Create test data
test_repo = Repository(
id=12345,
full_name='test/repo',
git_provider=ProviderType.GITHUB,
is_public=True,
)
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI,
selected_repository=test_repo,
selected_branch='main',
initial_user_msg='Hello, agent!',
)
# Call new_conversation
response = await new_conversation(
data=test_request, user_id='test_user', provider_tokens={}
)
# Verify the response
assert isinstance(response, JSONResponse)
assert response.status_code == 400
assert 'Settings not found' in response.body.decode('utf-8')
assert 'CONFIGURATION$SETTINGS_NOT_FOUND' in response.body.decode('utf-8')
@pytest.mark.asyncio
async def test_new_conversation_invalid_api_key():
"""Test creating a new conversation with an invalid API key."""
with _patch_store():
# Mock the _create_new_conversation function to raise LLMAuthenticationError
with patch(
'openhands.server.routes.manage_conversations._create_new_conversation'
) as mock_create_conversation:
# Set up the mock to raise LLMAuthenticationError
mock_create_conversation.side_effect = LLMAuthenticationError(
'Error authenticating with the LLM provider. Please check your API key'
)
# Create test data
test_repo = Repository(
id=12345,
full_name='test/repo',
git_provider=ProviderType.GITHUB,
is_public=True,
)
test_request = InitSessionRequest(
conversation_trigger=ConversationTrigger.GUI,
selected_repository=test_repo,
selected_branch='main',
initial_user_msg='Hello, agent!',
)
# Call new_conversation
response = await new_conversation(
data=test_request, user_id='test_user', provider_tokens={}
)
# Verify the response
assert isinstance(response, JSONResponse)
assert response.status_code == 400
assert 'Error authenticating with the LLM provider' in response.body.decode(
'utf-8'
)
assert 'STATUS$ERROR_LLM_AUTHENTICATION' in response.body.decode('utf-8')
@pytest.mark.asyncio
async def test_delete_conversation():
with _patch_store():