mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
[Fix]: Move suggest task prompts to BE (#8109)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -46,10 +46,12 @@ describe("HomeHeader", () => {
|
||||
await userEvent.click(launchButton);
|
||||
|
||||
expect(createConversationSpy).toHaveBeenCalledExactlyOnceWith(
|
||||
"gui",
|
||||
undefined,
|
||||
undefined,
|
||||
[],
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
// expect to be redirected to /conversations/:conversationId
|
||||
|
||||
@@ -155,7 +155,7 @@ describe("RepoConnector", () => {
|
||||
|
||||
// select a repository from the dropdown
|
||||
const dropdown = await waitFor(() =>
|
||||
within(repoConnector).getByTestId("repo-dropdown")
|
||||
within(repoConnector).getByTestId("repo-dropdown"),
|
||||
);
|
||||
await userEvent.click(dropdown);
|
||||
|
||||
@@ -164,6 +164,7 @@ describe("RepoConnector", () => {
|
||||
await userEvent.click(launchButton);
|
||||
|
||||
expect(createConversationSpy).toHaveBeenCalledExactlyOnceWith(
|
||||
"gui",
|
||||
{
|
||||
full_name: "rbren/polaris",
|
||||
git_provider: "github",
|
||||
@@ -173,6 +174,7 @@ describe("RepoConnector", () => {
|
||||
undefined,
|
||||
[],
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -11,12 +11,6 @@ import { AuthProvider } from "#/context/auth-context";
|
||||
import { TaskCard } from "#/components/features/home/tasks/task-card";
|
||||
import * as GitService from "#/api/git";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import {
|
||||
getFailingChecksPrompt,
|
||||
getMergeConflictPrompt,
|
||||
getOpenIssuePrompt,
|
||||
getUnresolvedCommentsPrompt,
|
||||
} from "#/components/features/home/tasks/get-prompt-for-query";
|
||||
|
||||
const MOCK_TASK_1: SuggestedTask = {
|
||||
issue_number: 123,
|
||||
@@ -101,7 +95,7 @@ describe("TaskCard", () => {
|
||||
expect(createConversationSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe("creating conversation prompts", () => {
|
||||
describe("creating suggested task conversation", () => {
|
||||
beforeEach(() => {
|
||||
const retrieveUserGitRepositoriesSpy = vi.spyOn(
|
||||
GitService,
|
||||
@@ -113,7 +107,7 @@ describe("TaskCard", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should call create conversation with the merge conflict prompt", async () => {
|
||||
it("should call create conversation with suggest task trigger and selected suggested task", async () => {
|
||||
const createConversationSpy = vi.spyOn(OpenHands, "createConversation");
|
||||
|
||||
renderTaskCard(MOCK_TASK_1);
|
||||
@@ -122,74 +116,12 @@ describe("TaskCard", () => {
|
||||
await userEvent.click(launchButton);
|
||||
|
||||
expect(createConversationSpy).toHaveBeenCalledWith(
|
||||
"suggested_task",
|
||||
MOCK_RESPOSITORIES[0],
|
||||
getMergeConflictPrompt(
|
||||
MOCK_TASK_1.git_provider,
|
||||
MOCK_TASK_1.issue_number,
|
||||
MOCK_TASK_1.repo,
|
||||
),
|
||||
[],
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it("should call create conversation with the failing checks prompt", async () => {
|
||||
const createConversationSpy = vi.spyOn(OpenHands, "createConversation");
|
||||
|
||||
renderTaskCard(MOCK_TASK_2);
|
||||
|
||||
const launchButton = screen.getByTestId("task-launch-button");
|
||||
await userEvent.click(launchButton);
|
||||
|
||||
expect(createConversationSpy).toHaveBeenCalledWith(
|
||||
MOCK_RESPOSITORIES[1],
|
||||
getFailingChecksPrompt(
|
||||
MOCK_TASK_2.git_provider,
|
||||
MOCK_TASK_2.issue_number,
|
||||
MOCK_TASK_2.repo,
|
||||
),
|
||||
[],
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it("should call create conversation with the unresolved comments prompt", async () => {
|
||||
const createConversationSpy = vi.spyOn(OpenHands, "createConversation");
|
||||
|
||||
renderTaskCard(MOCK_TASK_3);
|
||||
|
||||
const launchButton = screen.getByTestId("task-launch-button");
|
||||
await userEvent.click(launchButton);
|
||||
|
||||
expect(createConversationSpy).toHaveBeenCalledWith(
|
||||
MOCK_RESPOSITORIES[2],
|
||||
getUnresolvedCommentsPrompt(
|
||||
MOCK_TASK_3.git_provider,
|
||||
MOCK_TASK_3.issue_number,
|
||||
MOCK_TASK_3.repo,
|
||||
),
|
||||
[],
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it("should call create conversation with the open issue prompt", async () => {
|
||||
const createConversationSpy = vi.spyOn(OpenHands, "createConversation");
|
||||
|
||||
renderTaskCard(MOCK_TASK_4);
|
||||
|
||||
const launchButton = screen.getByTestId("task-launch-button");
|
||||
await userEvent.click(launchButton);
|
||||
|
||||
expect(createConversationSpy).toHaveBeenCalledWith(
|
||||
MOCK_RESPOSITORIES[3],
|
||||
getOpenIssuePrompt(
|
||||
MOCK_TASK_4.git_provider,
|
||||
MOCK_TASK_4.issue_number,
|
||||
MOCK_TASK_4.repo,
|
||||
),
|
||||
undefined,
|
||||
[],
|
||||
undefined,
|
||||
MOCK_TASK_1,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -10,10 +10,12 @@ import {
|
||||
GetTrajectoryResponse,
|
||||
GitChangeDiff,
|
||||
GitChange,
|
||||
ConversationTrigger,
|
||||
} from "./open-hands.types";
|
||||
import { openHands } from "./open-hands-axios";
|
||||
import { ApiSettings, PostApiSettings } from "#/types/settings";
|
||||
import { GitUser, GitRepository } from "#/types/git";
|
||||
import { SuggestedTask } from "#/components/features/home/tasks/task.types";
|
||||
|
||||
class OpenHands {
|
||||
/**
|
||||
@@ -149,17 +151,21 @@ class OpenHands {
|
||||
}
|
||||
|
||||
static async createConversation(
|
||||
conversation_trigger: ConversationTrigger = "gui",
|
||||
selectedRepository?: GitRepository,
|
||||
initialUserMsg?: string,
|
||||
imageUrls?: string[],
|
||||
replayJson?: string,
|
||||
suggested_task?: SuggestedTask,
|
||||
): Promise<Conversation> {
|
||||
const body = {
|
||||
conversation_trigger,
|
||||
selected_repository: selectedRepository,
|
||||
selected_branch: undefined,
|
||||
initial_user_msg: initialUserMsg,
|
||||
image_urls: imageUrls,
|
||||
replay_json: replayJson,
|
||||
suggested_task,
|
||||
};
|
||||
|
||||
const { data } = await openHands.post<Conversation>(
|
||||
|
||||
@@ -70,7 +70,7 @@ export interface AuthenticateResponse {
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export type ConversationTrigger = "resolver" | "gui";
|
||||
export type ConversationTrigger = "resolver" | "gui" | "suggested_task";
|
||||
|
||||
export interface Conversation {
|
||||
conversation_id: string;
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
import React from "react";
|
||||
import { useDispatch } from "react-redux";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
|
||||
import { setInitialPrompt } from "#/state/initial-query-slice";
|
||||
|
||||
const INITIAL_PROMPT = "";
|
||||
|
||||
export function CodeNotInGitLink() {
|
||||
const dispatch = useDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { mutate: createConversation } = useCreateConversation();
|
||||
|
||||
const handleStartFromScratch = () => {
|
||||
// Set the initial prompt and create a new conversation
|
||||
dispatch(setInitialPrompt(INITIAL_PROMPT));
|
||||
createConversation({ q: INITIAL_PROMPT });
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="text-xs text-neutral-400">
|
||||
{t(I18nKey.GITHUB$CODE_NOT_IN_GITHUB)}{" "}
|
||||
<span
|
||||
onClick={handleStartFromScratch}
|
||||
className="underline cursor-pointer"
|
||||
>
|
||||
{t(I18nKey.GITHUB$START_FROM_SCRATCH)}
|
||||
</span>{" "}
|
||||
{t(I18nKey.GITHUB$VSCODE_LINK_DESCRIPTION)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -28,7 +28,7 @@ export function HomeHeader() {
|
||||
testId="header-launch-button"
|
||||
variant="primary"
|
||||
type="button"
|
||||
onClick={() => createConversation({})}
|
||||
onClick={() => createConversation({ conversation_trigger: "gui" })}
|
||||
isDisabled={isCreatingConversation}
|
||||
>
|
||||
{!isCreatingConversation && "Launch from Scratch"}
|
||||
|
||||
@@ -142,7 +142,12 @@ export function RepositorySelectionForm({
|
||||
isLoadingRepositories ||
|
||||
isRepositoriesError
|
||||
}
|
||||
onClick={() => createConversation({ selectedRepository })}
|
||||
onClick={() =>
|
||||
createConversation({
|
||||
selectedRepository,
|
||||
conversation_trigger: "gui",
|
||||
})
|
||||
}
|
||||
>
|
||||
{!isCreatingConversation && "Launch"}
|
||||
{isCreatingConversation && t("HOME$LOADING")}
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
import { Provider } from "#/types/settings";
|
||||
import { SuggestedTaskType } from "./task.types";
|
||||
|
||||
// Helper function to get provider-specific terminology
|
||||
const getProviderTerms = (git_provider: Provider) => {
|
||||
if (git_provider === "gitlab") {
|
||||
return {
|
||||
requestType: "Merge Request",
|
||||
requestTypeShort: "MR",
|
||||
apiName: "GitLab API",
|
||||
tokenEnvVar: "GITLAB_TOKEN",
|
||||
ciSystem: "CI pipelines",
|
||||
ciProvider: "GitLab",
|
||||
requestVerb: "merge request",
|
||||
};
|
||||
}
|
||||
return {
|
||||
requestType: "Pull Request",
|
||||
requestTypeShort: "PR",
|
||||
apiName: "GitHub API",
|
||||
tokenEnvVar: "GITHUB_TOKEN",
|
||||
ciSystem: "GitHub Actions",
|
||||
ciProvider: "GitHub",
|
||||
requestVerb: "pull request",
|
||||
};
|
||||
};
|
||||
|
||||
export const getMergeConflictPrompt = (
|
||||
git_provider: Provider,
|
||||
issueNumber: number,
|
||||
repo: string,
|
||||
) => {
|
||||
const terms = getProviderTerms(git_provider);
|
||||
|
||||
return `You are working on ${terms.requestType} #${issueNumber} in repository ${repo}. You need to fix the merge conflicts.
|
||||
Use the ${terms.apiName} with the ${terms.tokenEnvVar} environment variable to retrieve the ${terms.requestTypeShort} details. Check out the branch from that ${terms.requestVerb} and look at the diff versus the base branch of the ${terms.requestTypeShort} to understand the ${terms.requestTypeShort}'s intention.
|
||||
Then resolve the merge conflicts. If you aren't sure what the right solution is, look back through the commit history at the commits that introduced the conflict and resolve them accordingly.`;
|
||||
};
|
||||
|
||||
export const getFailingChecksPrompt = (
|
||||
git_provider: Provider,
|
||||
issueNumber: number,
|
||||
repo: string,
|
||||
) => {
|
||||
const terms = getProviderTerms(git_provider);
|
||||
|
||||
return `You are working on ${terms.requestType} #${issueNumber} in repository ${repo}. You need to fix the failing CI checks.
|
||||
Use the ${terms.apiName} with the ${terms.tokenEnvVar} environment variable to retrieve the ${terms.requestTypeShort} details. Check out the branch from that ${terms.requestVerb} and look at the diff versus the base branch of the ${terms.requestTypeShort} to understand the ${terms.requestTypeShort}'s intention.
|
||||
Then use the ${terms.apiName} to look at the ${terms.ciSystem} that are failing on the most recent commit. Try and reproduce the failure locally.
|
||||
Get things working locally, then push your changes. Sleep for 30 seconds at a time until the ${terms.ciProvider} ${terms.ciSystem.toLowerCase()} have run again. If they are still failing, repeat the process.`;
|
||||
};
|
||||
|
||||
export const getUnresolvedCommentsPrompt = (
|
||||
git_provider: Provider,
|
||||
issueNumber: number,
|
||||
repo: string,
|
||||
) => {
|
||||
const terms = getProviderTerms(git_provider);
|
||||
|
||||
return `You are working on ${terms.requestType} #${issueNumber} in repository ${repo}. You need to resolve the remaining comments from reviewers.
|
||||
Use the ${terms.apiName} with the ${terms.tokenEnvVar} environment variable to retrieve the ${terms.requestTypeShort} details. Check out the branch from that ${terms.requestVerb} and look at the diff versus the base branch of the ${terms.requestTypeShort} to understand the ${terms.requestTypeShort}'s intention.
|
||||
Then use the ${terms.apiName} to retrieve all the feedback on the ${terms.requestTypeShort} so far. If anything hasn't been addressed, address it and commit your changes back to the same branch.`;
|
||||
};
|
||||
|
||||
export const getOpenIssuePrompt = (
|
||||
git_provider: Provider,
|
||||
issueNumber: number,
|
||||
repo: string,
|
||||
) => {
|
||||
const terms = getProviderTerms(git_provider);
|
||||
|
||||
return `You are working on Issue #${issueNumber} in repository ${repo}. Your goal is to fix the issue.
|
||||
Use the ${terms.apiName} with the ${terms.tokenEnvVar} environment variable to retrieve the issue details and any comments on the issue. Then check out a new branch and investigate what changes will need to be made.
|
||||
Finally, make the required changes and open up a ${terms.requestVerb}. Be sure to reference the issue in the ${terms.requestTypeShort} description.`;
|
||||
};
|
||||
|
||||
export const getPromptForQuery = (
|
||||
git_provider: Provider,
|
||||
type: SuggestedTaskType,
|
||||
issueNumber: number,
|
||||
repo: string,
|
||||
) => {
|
||||
switch (type) {
|
||||
case "MERGE_CONFLICTS":
|
||||
return getMergeConflictPrompt(git_provider, issueNumber, repo);
|
||||
case "FAILING_CHECKS":
|
||||
return getFailingChecksPrompt(git_provider, issueNumber, repo);
|
||||
case "UNRESOLVED_COMMENTS":
|
||||
return getUnresolvedCommentsPrompt(git_provider, issueNumber, repo);
|
||||
case "OPEN_ISSUE":
|
||||
return getOpenIssuePrompt(git_provider, issueNumber, repo);
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
};
|
||||
@@ -4,7 +4,6 @@ import { useIsCreatingConversation } from "#/hooks/use-is-creating-conversation"
|
||||
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { useUserRepositories } from "#/hooks/query/use-user-repositories";
|
||||
import { getPromptForQuery } from "./get-prompt-for-query";
|
||||
import { TaskIssueNumber } from "./task-issue-number";
|
||||
import { Provider } from "#/types/settings";
|
||||
|
||||
@@ -40,16 +39,11 @@ export function TaskCard({ task }: TaskCardProps) {
|
||||
|
||||
const handleLaunchConversation = () => {
|
||||
const repo = getRepo(task.repo, task.git_provider);
|
||||
const query = getPromptForQuery(
|
||||
task.git_provider,
|
||||
task.task_type,
|
||||
task.issue_number,
|
||||
task.repo,
|
||||
);
|
||||
|
||||
return createConversation({
|
||||
conversation_trigger: "suggested_task",
|
||||
selectedRepository: repo,
|
||||
q: query,
|
||||
suggested_task: task,
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ export function TaskForm({ ref }: TaskFormProps) {
|
||||
const formData = new FormData(event.currentTarget);
|
||||
|
||||
const q = formData.get("q")?.toString();
|
||||
createConversation({ q });
|
||||
createConversation({ q, conversation_trigger: "gui" });
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -6,6 +6,8 @@ import OpenHands from "#/api/open-hands";
|
||||
import { setInitialPrompt } from "#/state/initial-query-slice";
|
||||
import { RootState } from "#/store";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import { ConversationTrigger } from "#/api/open-hands.types";
|
||||
import { SuggestedTask } from "#/components/features/home/tasks/task.types";
|
||||
|
||||
export const useCreateConversation = () => {
|
||||
const navigate = useNavigate();
|
||||
@@ -19,16 +21,20 @@ export const useCreateConversation = () => {
|
||||
return useMutation({
|
||||
mutationKey: ["create-conversation"],
|
||||
mutationFn: async (variables: {
|
||||
conversation_trigger: ConversationTrigger;
|
||||
q?: string;
|
||||
selectedRepository?: GitRepository | null;
|
||||
suggested_task?: SuggestedTask;
|
||||
}) => {
|
||||
if (variables.q) dispatch(setInitialPrompt(variables.q));
|
||||
|
||||
return OpenHands.createConversation(
|
||||
variables.conversation_trigger,
|
||||
variables.selectedRepository || undefined,
|
||||
variables.q,
|
||||
files,
|
||||
replayJson || undefined,
|
||||
variables.suggested_task || undefined,
|
||||
);
|
||||
},
|
||||
onSuccess: async ({ conversation_id: conversationId }, { q }) => {
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
|
||||
from httpx import AsyncClient, HTTPError, HTTPStatusError
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -29,6 +30,57 @@ class SuggestedTask(BaseModel):
|
||||
issue_number: int
|
||||
title: str
|
||||
|
||||
def get_provider_terms(self) -> dict:
|
||||
if self.git_provider == ProviderType.GITLAB:
|
||||
return {
|
||||
'requestType': 'Merge Request',
|
||||
'requestTypeShort': 'MR',
|
||||
'apiName': 'GitLab API',
|
||||
'tokenEnvVar': 'GITLAB_TOKEN',
|
||||
'ciSystem': 'CI pipelines',
|
||||
'ciProvider': 'GitLab',
|
||||
'requestVerb': 'merge request',
|
||||
}
|
||||
elif self.git_provider == ProviderType.GITHUB:
|
||||
return {
|
||||
'requestType': 'Pull Request',
|
||||
'requestTypeShort': 'PR',
|
||||
'apiName': 'GitHub API',
|
||||
'tokenEnvVar': 'GITHUB_TOKEN',
|
||||
'ciSystem': 'GitHub Actions',
|
||||
'ciProvider': 'GitHub',
|
||||
'requestVerb': 'pull request',
|
||||
}
|
||||
|
||||
raise ValueError(f'Provider {self.git_provider} for suggested task prompts')
|
||||
|
||||
def get_prompt_for_task(
|
||||
self,
|
||||
) -> str:
|
||||
task_type = self.task_type
|
||||
issue_number = self.issue_number
|
||||
repo = self.repo
|
||||
|
||||
env = Environment(
|
||||
loader=FileSystemLoader('openhands/integrations/templates/suggested_task')
|
||||
)
|
||||
|
||||
template = None
|
||||
if task_type == TaskType.MERGE_CONFLICTS:
|
||||
template = env.get_template('merge_conflict_prompt.j2')
|
||||
elif task_type == TaskType.FAILING_CHECKS:
|
||||
template = env.get_template('failing_checks_prompt.j2')
|
||||
elif task_type == TaskType.UNRESOLVED_COMMENTS:
|
||||
template = env.get_template('unresolved_comments_prompt.j2')
|
||||
elif task_type == TaskType.OPEN_ISSUE:
|
||||
template = env.get_template('open_issue_prompt.j2')
|
||||
else:
|
||||
raise ValueError(f'Unsupported task type: {task_type}')
|
||||
|
||||
terms = self.get_provider_terms()
|
||||
|
||||
return template.render(issue_number=issue_number, repo=repo, **terms)
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: int
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }}. You need to fix the failing CI checks.
|
||||
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
|
||||
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
|
||||
Then use the {{ apiName }} to look at the {{ ciSystem }} that are failing on the most recent commit. Try and reproduce the failure locally.
|
||||
Get things working locally, then push your changes. Sleep for 30 seconds at a time until the {{ ciProvider }} {{ ciSystem.lower() }} have run again.
|
||||
If they are still failing, repeat the process.
|
||||
@@ -0,0 +1,4 @@
|
||||
You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }}. You need to fix the merge conflicts.
|
||||
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
|
||||
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
|
||||
Then resolve the merge conflicts. If you aren't sure what the right solution is, look back through the commit history at the commits that introduced the conflict and resolve them accordingly.
|
||||
@@ -0,0 +1,4 @@
|
||||
You are working on Issue #{{ issue_number }} in repository {{ repo }}. Your goal is to fix the issue.
|
||||
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the issue details and any comments on the issue.
|
||||
Then check out a new branch and investigate what changes will need to be made.
|
||||
Finally, make the required changes and open up a {{ requestVerb }}. Be sure to reference the issue in the {{ requestTypeShort }} description.
|
||||
@@ -0,0 +1,5 @@
|
||||
You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }}. You need to resolve the remaining comments from reviewers.
|
||||
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
|
||||
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
|
||||
Then use the {{ apiName }} to retrieve all the feedback on the {{ requestTypeShort }} so far.
|
||||
If anything hasn't been addressed, address it and commit your changes back to the same branch.
|
||||
@@ -13,7 +13,7 @@ from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
)
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.integrations.service_types import Repository, SuggestedTask
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
@@ -42,16 +42,19 @@ from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.conversation_summary import generate_conversation_title
|
||||
|
||||
|
||||
app = APIRouter(prefix='/api')
|
||||
|
||||
|
||||
class InitSessionRequest(BaseModel):
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI
|
||||
selected_repository: Repository | None = None
|
||||
selected_branch: str | None = None
|
||||
initial_user_msg: str | None = None
|
||||
image_urls: list[str] | None = None
|
||||
replay_json: str | None = None
|
||||
|
||||
suggested_task: SuggestedTask | None = None
|
||||
|
||||
|
||||
async def _create_new_conversation(
|
||||
user_id: str | None,
|
||||
@@ -64,9 +67,10 @@ async def _create_new_conversation(
|
||||
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
|
||||
attach_convo_id: bool = False,
|
||||
):
|
||||
print("trigger", conversation_trigger)
|
||||
logger.info(
|
||||
'Creating conversation',
|
||||
extra={'signal': 'create_conversation', 'user_id': user_id},
|
||||
extra={'signal': 'create_conversation', 'user_id': user_id, 'trigger': conversation_trigger.value},
|
||||
)
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
@@ -169,17 +173,24 @@ async def new_conversation(
|
||||
initial_user_msg = data.initial_user_msg
|
||||
image_urls = data.image_urls or []
|
||||
replay_json = data.replay_json
|
||||
suggested_task = data.suggested_task
|
||||
conversation_trigger = data.conversation_trigger
|
||||
|
||||
if suggested_task:
|
||||
initial_user_msg = suggested_task.get_prompt_for_task()
|
||||
conversation_trigger = ConversationTrigger.SUGGESTED_TASK
|
||||
|
||||
try:
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
user_id,
|
||||
provider_tokens,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
initial_user_msg,
|
||||
image_urls,
|
||||
replay_json,
|
||||
user_id=user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
initial_user_msg=initial_user_msg,
|
||||
image_urls=image_urls,
|
||||
replay_json=replay_json,
|
||||
conversation_trigger=conversation_trigger
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
|
||||
@@ -6,6 +6,7 @@ from enum import Enum
|
||||
class ConversationTrigger(Enum):
|
||||
RESOLVER = 'resolver'
|
||||
GUI = 'gui'
|
||||
SUGGESTED_TASK = 'suggested_task'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -4,18 +4,31 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.integrations.service_types import (
|
||||
ProviderType,
|
||||
Repository,
|
||||
SuggestedTask,
|
||||
TaskType,
|
||||
)
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
)
|
||||
from openhands.server.routes.manage_conversations import (
|
||||
InitSessionRequest,
|
||||
delete_conversation,
|
||||
get_conversation,
|
||||
new_conversation,
|
||||
search_conversations,
|
||||
update_conversation,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.storage.locations import get_conversation_metadata_filename
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
@@ -218,6 +231,213 @@ async def test_update_conversation():
|
||||
assert saved_metadata.title == 'New Title'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_success():
|
||||
"""Test successful creation of a new conversation."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function directly
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
|
||||
# Create test data
|
||||
test_repo = Repository(
|
||||
id=12345,
|
||||
full_name='test/repo',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI,
|
||||
selected_repository=test_repo,
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
image_urls=['https://example.com/image.jpg'],
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await new_conversation(
|
||||
data=test_request, user_id='test_user', provider_tokens={}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 200
|
||||
assert (
|
||||
response.body.decode('utf-8')
|
||||
== '{"status":"ok","conversation_id":"test_conversation_id"}'
|
||||
)
|
||||
|
||||
# Verify that _create_new_conversation was called with the correct arguments
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert call_args['user_id'] == 'test_user'
|
||||
assert call_args['selected_repository'] == test_repo
|
||||
assert call_args['selected_branch'] == 'main'
|
||||
assert call_args['initial_user_msg'] == 'Hello, agent!'
|
||||
assert call_args['image_urls'] == ['https://example.com/image.jpg']
|
||||
assert call_args['conversation_trigger'] == ConversationTrigger.GUI
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_with_suggested_task():
|
||||
"""Test creating a new conversation with a suggested task."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function directly
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to return a conversation ID
|
||||
mock_create_conversation.return_value = 'test_conversation_id'
|
||||
|
||||
# Mock SuggestedTask.get_prompt_for_task
|
||||
with patch(
|
||||
'openhands.integrations.service_types.SuggestedTask.get_prompt_for_task'
|
||||
) as mock_get_prompt:
|
||||
mock_get_prompt.return_value = (
|
||||
'Please fix the failing checks in PR #123'
|
||||
)
|
||||
|
||||
# Create test data
|
||||
test_repo = Repository(
|
||||
id=12345,
|
||||
full_name='test/repo',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
test_task = SuggestedTask(
|
||||
git_provider=ProviderType.GITHUB,
|
||||
task_type=TaskType.FAILING_CHECKS,
|
||||
repo='test/repo',
|
||||
issue_number=123,
|
||||
title='Fix failing checks',
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.SUGGESTED_TASK,
|
||||
selected_repository=test_repo,
|
||||
selected_branch='main',
|
||||
suggested_task=test_task,
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await new_conversation(
|
||||
data=test_request, user_id='test_user', provider_tokens={}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 200
|
||||
assert (
|
||||
response.body.decode('utf-8')
|
||||
== '{"status":"ok","conversation_id":"test_conversation_id"}'
|
||||
)
|
||||
|
||||
# Verify that _create_new_conversation was called with the correct arguments
|
||||
mock_create_conversation.assert_called_once()
|
||||
call_args = mock_create_conversation.call_args[1]
|
||||
assert call_args['user_id'] == 'test_user'
|
||||
assert call_args['selected_repository'] == test_repo
|
||||
assert call_args['selected_branch'] == 'main'
|
||||
assert (
|
||||
call_args['initial_user_msg']
|
||||
== 'Please fix the failing checks in PR #123'
|
||||
)
|
||||
assert (
|
||||
call_args['conversation_trigger']
|
||||
== ConversationTrigger.SUGGESTED_TASK
|
||||
)
|
||||
|
||||
# Verify that get_prompt_for_task was called
|
||||
mock_get_prompt.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_missing_settings():
|
||||
"""Test creating a new conversation when settings are missing."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function to raise MissingSettingsError
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to raise MissingSettingsError
|
||||
mock_create_conversation.side_effect = MissingSettingsError(
|
||||
'Settings not found'
|
||||
)
|
||||
|
||||
# Create test data
|
||||
test_repo = Repository(
|
||||
id=12345,
|
||||
full_name='test/repo',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI,
|
||||
selected_repository=test_repo,
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await new_conversation(
|
||||
data=test_request, user_id='test_user', provider_tokens={}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 400
|
||||
assert 'Settings not found' in response.body.decode('utf-8')
|
||||
assert 'CONFIGURATION$SETTINGS_NOT_FOUND' in response.body.decode('utf-8')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_invalid_api_key():
|
||||
"""Test creating a new conversation with an invalid API key."""
|
||||
with _patch_store():
|
||||
# Mock the _create_new_conversation function to raise LLMAuthenticationError
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations._create_new_conversation'
|
||||
) as mock_create_conversation:
|
||||
# Set up the mock to raise LLMAuthenticationError
|
||||
mock_create_conversation.side_effect = LLMAuthenticationError(
|
||||
'Error authenticating with the LLM provider. Please check your API key'
|
||||
)
|
||||
|
||||
# Create test data
|
||||
test_repo = Repository(
|
||||
id=12345,
|
||||
full_name='test/repo',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
test_request = InitSessionRequest(
|
||||
conversation_trigger=ConversationTrigger.GUI,
|
||||
selected_repository=test_repo,
|
||||
selected_branch='main',
|
||||
initial_user_msg='Hello, agent!',
|
||||
)
|
||||
|
||||
# Call new_conversation
|
||||
response = await new_conversation(
|
||||
data=test_request, user_id='test_user', provider_tokens={}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 400
|
||||
assert 'Error authenticating with the LLM provider' in response.body.decode(
|
||||
'utf-8'
|
||||
)
|
||||
assert 'STATUS$ERROR_LLM_AUTHENTICATION' in response.body.decode('utf-8')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_conversation():
|
||||
with _patch_store():
|
||||
|
||||
Reference in New Issue
Block a user