mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
12 Commits
self-hoste
...
gitlab-doc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc1305fc0e | ||
|
|
f985a46cfd | ||
|
|
9ca96afe29 | ||
|
|
7f3f44432e | ||
|
|
35b6b8ae2f | ||
|
|
52110305b3 | ||
|
|
877644be8c | ||
|
|
3bc85eb7ac | ||
|
|
5fa01ed278 | ||
|
|
1f747232cf | ||
|
|
4b1ed30e97 | ||
|
|
998de564cd |
1
docs/modules/usage/how-to/gitlab-runner.md
Normal file
1
docs/modules/usage/how-to/gitlab-runner.md
Normal file
@@ -0,0 +1 @@
|
||||
# Using GitLab CI Runners
|
||||
@@ -91,6 +91,13 @@ describe("HomeScreen", () => {
|
||||
screen.getByTestId("task-suggestions");
|
||||
});
|
||||
|
||||
it("should have responsive layout for mobile and desktop screens", async () => {
|
||||
renderHomeScreen();
|
||||
|
||||
const mainContainer = screen.getByTestId("home-screen").querySelector("main");
|
||||
expect(mainContainer).toHaveClass("flex", "flex-col", "md:flex-row");
|
||||
});
|
||||
|
||||
it("should filter the suggested tasks based on the selected repository", async () => {
|
||||
const retrieveUserGitRepositoriesSpy = vi.spyOn(
|
||||
GitService,
|
||||
|
||||
@@ -6,55 +6,38 @@ import { KeyStatusIcon } from "../key-status-icon";
|
||||
|
||||
interface GitHubTokenInputProps {
|
||||
onChange: (value: string) => void;
|
||||
onBaseDomainChange?: (value: string) => void;
|
||||
isGitHubTokenSet: boolean;
|
||||
name: string;
|
||||
baseDomainSet?: string | null;
|
||||
isSaas: boolean;
|
||||
}
|
||||
|
||||
export function GitHubTokenInput({
|
||||
onChange,
|
||||
onBaseDomainChange,
|
||||
isGitHubTokenSet,
|
||||
name,
|
||||
baseDomainSet,
|
||||
isSaas,
|
||||
}: GitHubTokenInputProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
{!isSaas && (
|
||||
<SettingsInput
|
||||
testId={name}
|
||||
name={name}
|
||||
onChange={onChange}
|
||||
label={t(I18nKey.GITHUB$TOKEN_LABEL)}
|
||||
type="password"
|
||||
className="w-[680px]"
|
||||
placeholder={isGitHubTokenSet ? "<hidden>" : ""}
|
||||
startContent={
|
||||
isGitHubTokenSet && (
|
||||
<KeyStatusIcon
|
||||
testId="gh-set-token-indicator"
|
||||
isSet={isGitHubTokenSet}
|
||||
/>
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
<SettingsInput
|
||||
onChange={onBaseDomainChange || (() => {})}
|
||||
label={t(I18nKey.GITHUB$BASE_DOMAIN_LABEL)}
|
||||
type="text"
|
||||
testId={name}
|
||||
name={name}
|
||||
onChange={onChange}
|
||||
label={t(I18nKey.GITHUB$TOKEN_LABEL)}
|
||||
type="password"
|
||||
className="w-[680px]"
|
||||
placeholder={"github.com"}
|
||||
defaultValue={baseDomainSet ? baseDomainSet : undefined}
|
||||
placeholder={isGitHubTokenSet ? "<hidden>" : ""}
|
||||
startContent={
|
||||
isGitHubTokenSet && (
|
||||
<KeyStatusIcon
|
||||
testId="gh-set-token-indicator"
|
||||
isSet={isGitHubTokenSet}
|
||||
/>
|
||||
)
|
||||
}
|
||||
/>
|
||||
|
||||
{!isSaas && <GitHubTokenHelpAnchor />}
|
||||
<GitHubTokenHelpAnchor />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -6,55 +6,38 @@ import { KeyStatusIcon } from "../key-status-icon";
|
||||
|
||||
interface GitLabTokenInputProps {
|
||||
onChange: (value: string) => void;
|
||||
onBaseDomainChange?: (value: string) => void;
|
||||
isGitLabTokenSet: boolean;
|
||||
name: string;
|
||||
baseDomainSet?: string | null;
|
||||
isSaas: boolean;
|
||||
}
|
||||
|
||||
export function GitLabTokenInput({
|
||||
onChange,
|
||||
onBaseDomainChange,
|
||||
isGitLabTokenSet,
|
||||
name,
|
||||
baseDomainSet,
|
||||
isSaas,
|
||||
}: GitLabTokenInputProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
{!isSaas && (
|
||||
<SettingsInput
|
||||
testId={name}
|
||||
name={name}
|
||||
onChange={onChange}
|
||||
label={t(I18nKey.GITLAB$TOKEN_LABEL)}
|
||||
type="password"
|
||||
className="w-[680px]"
|
||||
placeholder={isGitLabTokenSet ? "<hidden>" : ""}
|
||||
startContent={
|
||||
isGitLabTokenSet && (
|
||||
<KeyStatusIcon
|
||||
testId="gl-set-token-indicator"
|
||||
isSet={isGitLabTokenSet}
|
||||
/>
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
<SettingsInput
|
||||
onChange={onBaseDomainChange || (() => {})}
|
||||
label={t(I18nKey.GITLAB$BASE_DOMAIN_LABEL)}
|
||||
type="text"
|
||||
testId={name}
|
||||
name={name}
|
||||
onChange={onChange}
|
||||
label={t(I18nKey.GITLAB$TOKEN_LABEL)}
|
||||
type="password"
|
||||
className="w-[680px]"
|
||||
placeholder={"gitlab.com"}
|
||||
defaultValue={baseDomainSet ? baseDomainSet : undefined}
|
||||
placeholder={isGitLabTokenSet ? "<hidden>" : ""}
|
||||
startContent={
|
||||
isGitLabTokenSet && (
|
||||
<KeyStatusIcon
|
||||
testId="gl-set-token-indicator"
|
||||
isSet={isGitLabTokenSet}
|
||||
/>
|
||||
)
|
||||
}
|
||||
/>
|
||||
|
||||
{!isSaas && <GitLabTokenHelpAnchor />}
|
||||
<GitLabTokenHelpAnchor />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -54,11 +54,13 @@ export const useSettings = () => {
|
||||
React.useEffect(() => {
|
||||
if (query.data?.PROVIDER_TOKENS_SET) {
|
||||
const providers = query.data.PROVIDER_TOKENS_SET;
|
||||
const setProviders = Object.keys(providers) as Array<
|
||||
keyof typeof providers
|
||||
>;
|
||||
const setProviders = (
|
||||
Object.keys(providers) as Array<keyof typeof providers>
|
||||
).filter((key) => providers[key]);
|
||||
setProviderTokensSet(setProviders);
|
||||
const atLeastOneSet = setProviders.length > 0;
|
||||
const atLeastOneSet = Object.values(query.data.PROVIDER_TOKENS_SET).some(
|
||||
(value) => value,
|
||||
);
|
||||
setProvidersAreSet(atLeastOneSet);
|
||||
}
|
||||
}, [query.data?.PROVIDER_TOKENS_SET, query.isFetched]);
|
||||
|
||||
@@ -104,7 +104,6 @@ export enum I18nKey {
|
||||
EXIT_PROJECT$TITLE = "EXIT_PROJECT$TITLE",
|
||||
LANGUAGE$LABEL = "LANGUAGE$LABEL",
|
||||
GITHUB$TOKEN_LABEL = "GITHUB$TOKEN_LABEL",
|
||||
GITHUB$BASE_DOMAIN_LABEL = "GITHUB$BASE_DOMAIN_LABEL",
|
||||
GITHUB$TOKEN_OPTIONAL = "GITHUB$TOKEN_OPTIONAL",
|
||||
GITHUB$GET_TOKEN = "GITHUB$GET_TOKEN",
|
||||
GITHUB$TOKEN_HELP_TEXT = "GITHUB$TOKEN_HELP_TEXT",
|
||||
@@ -451,7 +450,6 @@ export enum I18nKey {
|
||||
MODEL_SELECTOR$VERIFIED = "MODEL_SELECTOR$VERIFIED",
|
||||
MODEL_SELECTOR$OTHERS = "MODEL_SELECTOR$OTHERS",
|
||||
GITLAB$TOKEN_LABEL = "GITLAB$TOKEN_LABEL",
|
||||
GITLAB$BASE_DOMAIN_LABEL = "GITLAB$BASE_DOMAIN_LABEL",
|
||||
GITLAB$GET_TOKEN = "GITLAB$GET_TOKEN",
|
||||
GITLAB$TOKEN_HELP_TEXT = "GITLAB$TOKEN_HELP_TEXT",
|
||||
GITLAB$TOKEN_LINK_TEXT = "GITLAB$TOKEN_LINK_TEXT",
|
||||
|
||||
@@ -1569,21 +1569,6 @@
|
||||
"tr": "GitHub Jetonu",
|
||||
"de": "GitHub-Token"
|
||||
},
|
||||
"GITHUB$BASE_DOMAIN_LABEL": {
|
||||
"en": "GitHub Base Domain",
|
||||
"ja": "GitHub ベースドメイン",
|
||||
"zh-CN": "GitHub 基础域名",
|
||||
"zh-TW": "GitHub 基礎網域",
|
||||
"ko-KR": "GitHub 기본 도메인",
|
||||
"no": "GitHub Base Domain",
|
||||
"it": "Dominio Base GitHub",
|
||||
"pt": "Domínio Base do GitHub",
|
||||
"es": "Dominio Base de GitHub",
|
||||
"ar": "نطاق GitHub الأساسي",
|
||||
"fr": "Domaine de Base GitHub",
|
||||
"tr": "GitHub Temel Alan Adı",
|
||||
"de": "GitHub Basis-Domain"
|
||||
},
|
||||
"GITHUB$TOKEN_OPTIONAL": {
|
||||
"en": "GitHub Token (Optional)",
|
||||
"ja": "GitHubトークン(任意)",
|
||||
@@ -6484,21 +6469,6 @@
|
||||
"tr": "GitLab Jetonu",
|
||||
"de": "GitLab-Token"
|
||||
},
|
||||
"GITLAB$BASE_DOMAIN_LABEL": {
|
||||
"en": "GitLab Base Domain",
|
||||
"ja": "GitLab ベースドメイン",
|
||||
"zh-CN": "GitLab 基础域名",
|
||||
"zh-TW": "GitLab 基礎網域",
|
||||
"ko-KR": "GitLab 기본 도메인",
|
||||
"no": "GitLab Base Domain",
|
||||
"it": "Dominio Base GitLab",
|
||||
"pt": "Domínio Base do GitLab",
|
||||
"es": "Dominio Base de GitLab",
|
||||
"ar": "نطاق GitLab الأساسي",
|
||||
"fr": "Domaine de Base GitLab",
|
||||
"tr": "GitLab Temel Alan Adı",
|
||||
"de": "GitLab Basis-Domain"
|
||||
},
|
||||
"GITLAB$GET_TOKEN": {
|
||||
"en": "Generate a token on",
|
||||
"ja": "トークンを生成する",
|
||||
|
||||
@@ -15,13 +15,11 @@ import {
|
||||
} from "#/utils/custom-toast-handlers";
|
||||
import { retrieveAxiosErrorMessage } from "#/utils/retrieve-axios-error-message";
|
||||
import { GitSettingInputsSkeleton } from "#/components/features/settings/git-settings/github-settings-inputs-skeleton";
|
||||
import { useAuth } from "#/context/auth-context";
|
||||
|
||||
function GitSettingsScreen() {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { mutate: saveSettings, isPending } = useSaveSettings();
|
||||
const { providerTokensSet } = useAuth();
|
||||
const { mutate: disconnectGitTokens } = useLogout();
|
||||
|
||||
const { data: settings, isLoading } = useSettings();
|
||||
@@ -31,17 +29,10 @@ function GitSettingsScreen() {
|
||||
React.useState(false);
|
||||
const [gitlabTokenInputHasValue, setGitlabTokenInputHasValue] =
|
||||
React.useState(false);
|
||||
const [githubBaseDomainInputHasValue, setGithubBaseDomainInputHasValue] =
|
||||
React.useState(false);
|
||||
const [gitlabBaseDomainInputHasValue, setGitlabBaseDomainInputHasValue] =
|
||||
React.useState(false);
|
||||
|
||||
const isSaas = config?.APP_MODE === "saas";
|
||||
const isGitHubTokenSet = providerTokensSet.includes("github");
|
||||
const isGitLabTokenSet = providerTokensSet.includes("gitlab");
|
||||
|
||||
const existingGithubBaseDomain = settings?.PROVIDER_TOKENS_SET["github"];
|
||||
const existingGitlabBaseDomain = settings?.PROVIDER_TOKENS_SET["gitlab"];
|
||||
const isGitHubTokenSet = !!settings?.PROVIDER_TOKENS_SET.github;
|
||||
const isGitLabTokenSet = !!settings?.PROVIDER_TOKENS_SET.gitlab;
|
||||
|
||||
const formAction = async (formData: FormData) => {
|
||||
const disconnectButtonClicked =
|
||||
@@ -54,22 +45,12 @@ function GitSettingsScreen() {
|
||||
|
||||
const githubToken = formData.get("github-token-input")?.toString() || "";
|
||||
const gitlabToken = formData.get("gitlab-token-input")?.toString() || "";
|
||||
const githubBaseDomain =
|
||||
formData.get("github-base-domain-input")?.toString() || "";
|
||||
const gitlabBaseDomain =
|
||||
formData.get("gitlab-base-domain-input")?.toString() || "";
|
||||
|
||||
saveSettings(
|
||||
{
|
||||
provider_tokens: {
|
||||
github: {
|
||||
token: githubToken,
|
||||
base_domain: githubBaseDomain || null,
|
||||
},
|
||||
gitlab: {
|
||||
token: gitlabToken,
|
||||
base_domain: gitlabBaseDomain || null,
|
||||
},
|
||||
github: githubToken,
|
||||
gitlab: gitlabToken,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -83,19 +64,12 @@ function GitSettingsScreen() {
|
||||
onSettled: () => {
|
||||
setGithubTokenInputHasValue(false);
|
||||
setGitlabTokenInputHasValue(false);
|
||||
setGithubBaseDomainInputHasValue(false);
|
||||
setGitlabBaseDomainInputHasValue(false);
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
const formIsClean =
|
||||
!githubTokenInputHasValue &&
|
||||
!gitlabTokenInputHasValue &&
|
||||
!githubBaseDomainInputHasValue &&
|
||||
!gitlabBaseDomainInputHasValue;
|
||||
|
||||
const formIsClean = !githubTokenInputHasValue && !gitlabTokenInputHasValue;
|
||||
const shouldRenderExternalConfigureButtons = isSaas && config.APP_SLUG;
|
||||
|
||||
return (
|
||||
@@ -110,32 +84,22 @@ function GitSettingsScreen() {
|
||||
<ConfigureGitHubRepositoriesAnchor slug={config.APP_SLUG!} />
|
||||
)}
|
||||
|
||||
{!isLoading && (
|
||||
{!isSaas && !isLoading && (
|
||||
<div className="p-9 flex flex-col gap-12">
|
||||
<GitHubTokenInput
|
||||
name="github-token-input"
|
||||
baseDomainSet={existingGithubBaseDomain}
|
||||
isGitHubTokenSet={isGitHubTokenSet}
|
||||
onChange={(value) => {
|
||||
setGithubTokenInputHasValue(!!value);
|
||||
}}
|
||||
onBaseDomainChange={(value) => {
|
||||
setGithubBaseDomainInputHasValue(!!value);
|
||||
}}
|
||||
isSaas={isSaas}
|
||||
/>
|
||||
|
||||
<GitLabTokenInput
|
||||
name="gitlab-token-input"
|
||||
baseDomainSet={existingGitlabBaseDomain}
|
||||
isGitLabTokenSet={isGitLabTokenSet}
|
||||
onChange={(value) => {
|
||||
setGitlabTokenInputHasValue(!!value);
|
||||
}}
|
||||
onBaseDomainChange={(value) => {
|
||||
setGitlabBaseDomainInputHasValue(!!value);
|
||||
}}
|
||||
isSaas={isSaas}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -22,7 +22,7 @@ function HomeScreen() {
|
||||
|
||||
<hr className="border-[#717888]" />
|
||||
|
||||
<main className="flex justify-between gap-4">
|
||||
<main className="flex flex-col md:flex-row justify-between gap-4">
|
||||
<RepoConnector
|
||||
onRepoSelection={(title) => setSelectedRepoTitle(title)}
|
||||
/>
|
||||
|
||||
@@ -11,13 +11,13 @@ export const DEFAULT_SETTINGS: Settings = {
|
||||
CONFIRMATION_MODE: false,
|
||||
SECURITY_ANALYZER: "",
|
||||
REMOTE_RUNTIME_RESOURCE_FACTOR: 1,
|
||||
PROVIDER_TOKENS_SET: { github: null, gitlab: null },
|
||||
PROVIDER_TOKENS_SET: { github: false, gitlab: false },
|
||||
ENABLE_DEFAULT_CONDENSER: true,
|
||||
ENABLE_SOUND_NOTIFICATIONS: false,
|
||||
USER_CONSENTS_TO_ANALYTICS: false,
|
||||
PROVIDER_TOKENS: {
|
||||
github: { token: "", base_domain: null },
|
||||
gitlab: { token: "", base_domain: null },
|
||||
github: "",
|
||||
gitlab: "",
|
||||
},
|
||||
IS_NEW_USER: true,
|
||||
};
|
||||
|
||||
@@ -5,11 +5,6 @@ export const ProviderOptions = {
|
||||
|
||||
export type Provider = keyof typeof ProviderOptions;
|
||||
|
||||
export type ProviderToken = {
|
||||
token: string;
|
||||
base_domain: string | null;
|
||||
};
|
||||
|
||||
export type Settings = {
|
||||
LLM_MODEL: string;
|
||||
LLM_BASE_URL: string;
|
||||
@@ -19,11 +14,11 @@ export type Settings = {
|
||||
CONFIRMATION_MODE: boolean;
|
||||
SECURITY_ANALYZER: string;
|
||||
REMOTE_RUNTIME_RESOURCE_FACTOR: number | null;
|
||||
PROVIDER_TOKENS_SET: Record<Provider, string | null>;
|
||||
PROVIDER_TOKENS_SET: Record<Provider, boolean>;
|
||||
ENABLE_DEFAULT_CONDENSER: boolean;
|
||||
ENABLE_SOUND_NOTIFICATIONS: boolean;
|
||||
USER_CONSENTS_TO_ANALYTICS: boolean | null;
|
||||
PROVIDER_TOKENS: Record<Provider, ProviderToken>;
|
||||
PROVIDER_TOKENS: Record<Provider, string>;
|
||||
IS_NEW_USER?: boolean;
|
||||
};
|
||||
|
||||
@@ -40,17 +35,17 @@ export type ApiSettings = {
|
||||
enable_default_condenser: boolean;
|
||||
enable_sound_notifications: boolean;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
provider_tokens: Record<Provider, ProviderToken>;
|
||||
provider_tokens_set: Record<Provider, string | null>;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
provider_tokens_set: Record<Provider, boolean>;
|
||||
};
|
||||
|
||||
export type PostSettings = Settings & {
|
||||
provider_tokens: Record<Provider, ProviderToken>;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
llm_api_key?: string | null;
|
||||
};
|
||||
|
||||
export type PostApiSettings = ApiSettings & {
|
||||
provider_tokens: Record<Provider, ProviderToken>;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
};
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
---
|
||||
name: add_openhands_repo_instruction
|
||||
version: 1.0.0
|
||||
author: openhands
|
||||
agent: CodeActAgent
|
||||
inputs:
|
||||
- name: REPO_FOLDER_NAME
|
||||
description: "Name of the repository folder under /workspace"
|
||||
required: false
|
||||
---
|
||||
|
||||
Please browse the current repository under /workspace/{{ REPO_FOLDER_NAME }}, look at the documentation and relevant code, and understand the purpose of this repository.
|
||||
|
||||
Specifically, I want you to create a `.openhands/microagents/repo.md` file. This file should contain succinct information that summarizes (1) the purpose of this repository, (2) the general setup of this repo, and (3) a brief description of the structure of this repo.
|
||||
|
||||
Here's an example:
|
||||
```markdown
|
||||
---
|
||||
name: repo
|
||||
type: repo
|
||||
agent: CodeActAgent
|
||||
---
|
||||
|
||||
This repository contains the code for OpenHands, an automated AI software engineer. It has a Python backend
|
||||
(in the `openhands` directory) and React frontend (in the `frontend` directory).
|
||||
|
||||
## General Setup:
|
||||
To set up the entire repo, including frontend and backend, run `make build`.
|
||||
You don't need to do this unless the user asks you to, or if you're trying to run the entire application.
|
||||
|
||||
Before pushing any changes, you should ensure that any lint errors or simple test errors have been fixed.
|
||||
|
||||
* If you've made changes to the backend, you should run `pre-commit run --all-files --config ./dev_config/python/.pre-commit-config.yaml`
|
||||
* If you've made changes to the frontend, you should run `cd frontend && npm run lint:fix && npm run build ; cd ..`
|
||||
|
||||
If either command fails, it may have automatically fixed some issues. You should fix any issues that weren't automatically fixed,
|
||||
then re-run the command to ensure it passes.
|
||||
|
||||
## Repository Structure
|
||||
Backend:
|
||||
- Located in the `openhands` directory
|
||||
- Testing:
|
||||
- All tests are in `tests/unit/test_*.py`
|
||||
- To test new code, run `poetry run pytest tests/unit/test_xxx.py` where `xxx` is the appropriate file for the current functionality
|
||||
- Write all tests with pytest
|
||||
|
||||
Frontend:
|
||||
- Located in the `frontend` directory
|
||||
- Prerequisites: A recent version of NodeJS / NPM
|
||||
- Setup: Run `npm install` in the frontend directory
|
||||
- Testing:
|
||||
- Run tests: `npm run test`
|
||||
- To run specific tests: `npm run test -- -t "TestName"`
|
||||
- Building:
|
||||
- Build for production: `npm run build`
|
||||
- Environment Variables:
|
||||
- Set in `frontend/.env` or as environment variables
|
||||
- Available variables: VITE_BACKEND_HOST, VITE_USE_TLS, VITE_INSECURE_SKIP_VERIFY, VITE_FRONTEND_PORT
|
||||
- Internationalization:
|
||||
- Generate i18n declaration file: `npm run make-i18n`
|
||||
```
|
||||
|
||||
Now, please write a similar markdown for the current repository.
|
||||
Read all the GitHub workflows under .github/ of the repository (if this folder exists) to understand the CI checks (e.g., linter, pre-commit), and include those in the repo.md file.
|
||||
@@ -1,19 +0,0 @@
|
||||
---
|
||||
name: address_pr_comments
|
||||
version: 1.0.0
|
||||
author: openhands
|
||||
agent: CodeActAgent
|
||||
inputs:
|
||||
- name: PR_URL
|
||||
description: "URL of the pull request"
|
||||
required: true
|
||||
- name: BRANCH_NAME
|
||||
description: "Branch name corresponding to the pull request"
|
||||
required: true
|
||||
---
|
||||
|
||||
First, check the branch {{ BRANCH_NAME }} and read the diff against the main branch to understand the purpose.
|
||||
|
||||
This branch corresponds to this PR {{ PR_URL }}
|
||||
|
||||
Next, you should use the GitHub API to read the reviews and comments on this PR and address them.
|
||||
32
microagents/agent_memory.md
Normal file
32
microagents/agent_memory.md
Normal file
@@ -0,0 +1,32 @@
|
||||
---
|
||||
name: agent_memory
|
||||
type: knowledge
|
||||
version: 1.0.0
|
||||
agent: CodeActAgent
|
||||
triggers:
|
||||
- /remember
|
||||
---
|
||||
|
||||
* Repository memory: Use .openhands/microagents/repo.md under each repository root to store and access important information.
|
||||
- If this file exists, it will be added to your context automatically.
|
||||
- If missing, you should create it unless the user has explicitly asked you to not do so.
|
||||
|
||||
* Store and maintain **general knowledge** that will be helpful for most future tasks:
|
||||
1. Repository structure
|
||||
2. Common commands (build, lint, test, pre-commit, etc.)
|
||||
3. Code style preferences
|
||||
4. Workflows and best practices
|
||||
5. Any other repository-specific knowledge you learn
|
||||
|
||||
* IMPORTANT: ONLY LOG the information that would be helpful for different future tasks, for example, how to configure the settings, how to set up the repository. Do NOT add issue-specific information (e.g., what specific error you ran into and how you fixed it).
|
||||
|
||||
* When adding new information:
|
||||
- ALWAYS ask for user confirmation first by listing the exact items (numbered 1, 2, 3, etc.) you plan to save to repo.md
|
||||
- Only save the items the user approves (they may ask you to save a subset)
|
||||
- Ensure it integrates nicely with existing knowledge in repo.md
|
||||
- Reorganize the content if needed to maintain clarity and organization
|
||||
- Group related information together under appropriate sections or headings
|
||||
- If you've only explored a portion of the codebase, clearly note this limitation in the repository structure documentation
|
||||
- If you don't know the essential commands for working with the repository, such as lint or typecheck, ask the user and suggest adding them to repo.md for future reference (with permission)
|
||||
|
||||
When you receive this message, please review and summarize your recent actions and observations, then present a list of valuable information that should be saved in repo.md to the user.
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
name: get_test_to_pass
|
||||
version: 1.0.0
|
||||
author: openhands
|
||||
agent: CodeActAgent
|
||||
inputs:
|
||||
- name: BRANCH_NAME
|
||||
description: "Branch for the agent to work on"
|
||||
required: true
|
||||
- name: TEST_COMMAND_TO_RUN
|
||||
description: "The test command you want the agent to work on. For example, `pytest tests/unit/test_bash_parsing.py`"
|
||||
required: true
|
||||
- name: FUNCTION_TO_FIX
|
||||
description: "The name of function to fix"
|
||||
required: false
|
||||
- name: FILE_FOR_FUNCTION
|
||||
description: "The path of the file that contains the function"
|
||||
required: false
|
||||
---
|
||||
|
||||
Can you check out branch "{{ BRANCH_NAME }}", and run {{ TEST_COMMAND_TO_RUN }}.
|
||||
|
||||
{%- if FUNCTION_TO_FIX and FILE_FOR_FUNCTION %}
|
||||
Help me fix these tests to pass by fixing the {{ FUNCTION_TO_FIX }} function in file {{ FILE_FOR_FUNCTION }}.
|
||||
{%- endif %}
|
||||
|
||||
PLEASE DO NOT modify the tests yourself -- let me know if you think some of the tests are incorrect.
|
||||
@@ -1,21 +0,0 @@
|
||||
---
|
||||
name: update_pr_description
|
||||
version: 1.0.0
|
||||
author: openhands
|
||||
agent: CodeActAgent
|
||||
inputs:
|
||||
- name: PR_URL
|
||||
description: "URL of the pull request"
|
||||
type: string
|
||||
required: true
|
||||
validation:
|
||||
pattern: "^https://github.com/.+/.+/pull/[0-9]+$"
|
||||
- name: BRANCH_NAME
|
||||
description: "Branch name corresponding to the pull request"
|
||||
type: string
|
||||
required: true
|
||||
---
|
||||
|
||||
Please check the branch "{{ BRANCH_NAME }}" and look at the diff against the main branch. This branch belongs to this PR "{{ PR_URL }}".
|
||||
|
||||
Once you understand the purpose of the diff, please use the GitHub API to read the existing PR description, and update it to better reflect the changes we've made when necessary.
|
||||
@@ -1,21 +0,0 @@
|
||||
---
|
||||
name: update_test_for_new_implementation
|
||||
version: 1.0.0
|
||||
author: openhands
|
||||
agent: CodeActAgent
|
||||
inputs:
|
||||
- name: BRANCH_NAME
|
||||
description: "Branch for the agent to work on"
|
||||
required: true
|
||||
- name: TEST_COMMAND_TO_RUN
|
||||
description: "The test command you want the agent to work on. For example, `pytest tests/unit/test_bash_parsing.py`"
|
||||
required: true
|
||||
---
|
||||
|
||||
Can you check out branch "{{ BRANCH_NAME }}", and run {{ TEST_COMMAND_TO_RUN }}.
|
||||
|
||||
{%- if FUNCTION_TO_FIX and FILE_FOR_FUNCTION %}
|
||||
Help me fix these tests to pass by fixing the {{ FUNCTION_TO_FIX }} function in file {{ FILE_FOR_FUNCTION }}.
|
||||
{%- endif %}
|
||||
|
||||
PLEASE DO NOT modify the tests yourself -- let me know if you think some of the tests are incorrect.
|
||||
@@ -20,10 +20,7 @@ from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
AgentFinishAction,
|
||||
)
|
||||
from openhands.events.action import Action, AgentFinishAction, MessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.condenser import Condenser
|
||||
@@ -173,7 +170,8 @@ class CodeActAgent(Agent):
|
||||
f'Processing {len(condensed_history)} events from a total of {len(state.history)} events'
|
||||
)
|
||||
|
||||
messages = self._get_messages(condensed_history)
|
||||
initial_user_message = self._get_initial_user_message(state.history)
|
||||
messages = self._get_messages(condensed_history, initial_user_message)
|
||||
params: dict = {
|
||||
'messages': self.llm.format_messages_for_llm(messages),
|
||||
}
|
||||
@@ -216,7 +214,29 @@ class CodeActAgent(Agent):
|
||||
self.pending_actions.append(action)
|
||||
return self.pending_actions.popleft()
|
||||
|
||||
def _get_messages(self, events: list[Event]) -> list[Message]:
|
||||
def _get_initial_user_message(self, history: list[Event]) -> MessageAction:
|
||||
"""Finds the initial user message action from the full history."""
|
||||
initial_user_message: MessageAction | None = None
|
||||
for event in history:
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
initial_user_message = event
|
||||
break
|
||||
|
||||
if initial_user_message is None:
|
||||
# This should not happen in a valid conversation
|
||||
logger.error(
|
||||
f'CRITICAL: Could not find the initial user MessageAction in the full {len(history)} events history.'
|
||||
)
|
||||
# Depending on desired robustness, could raise error or create a dummy action
|
||||
# and log the error
|
||||
raise ValueError(
|
||||
'Initial user message not found in history. Please report this issue.'
|
||||
)
|
||||
return initial_user_message
|
||||
|
||||
def _get_messages(
|
||||
self, events: list[Event], initial_user_message: MessageAction
|
||||
) -> list[Message]:
|
||||
"""Constructs the message history for the LLM conversation.
|
||||
|
||||
This method builds a structured conversation history by processing events from the state
|
||||
@@ -253,6 +273,7 @@ class CodeActAgent(Agent):
|
||||
# Use ConversationMemory to process events (including SystemMessageAction)
|
||||
messages = self.conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=self.llm.config.max_message_chars,
|
||||
vision_is_active=self.llm.vision_is_active(),
|
||||
)
|
||||
|
||||
@@ -190,7 +190,7 @@ class AgentController:
|
||||
logger.debug(f'System message got from agent: {system_message}')
|
||||
if system_message:
|
||||
self.event_stream.add_event(system_message, EventSource.AGENT)
|
||||
logger.debug(f'System message added to event stream: {system_message}')
|
||||
logger.info(f'System message added to event stream: {system_message}')
|
||||
|
||||
async def close(self, set_stop_state: bool = True) -> None:
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
@@ -1020,7 +1020,7 @@ class AgentController:
|
||||
self.state.start_id = 0
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
'info',
|
||||
f'AgentController {self.id} - created new state. start_id: {self.state.start_id}',
|
||||
)
|
||||
else:
|
||||
@@ -1030,7 +1030,7 @@ class AgentController:
|
||||
self.state.start_id = 0
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
'info',
|
||||
f'AgentController {self.id} initializing history from event {self.state.start_id}',
|
||||
)
|
||||
|
||||
@@ -1143,70 +1143,169 @@ class AgentController:
|
||||
|
||||
def _handle_long_context_error(self) -> None:
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
kept_event_ids = {
|
||||
e.id for e in self._apply_conversation_window(self.state.history)
|
||||
}
|
||||
kept_events = self._apply_conversation_window()
|
||||
kept_event_ids = {e.id for e in kept_events}
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Context window exceeded. Keeping events with IDs: {kept_event_ids}',
|
||||
)
|
||||
|
||||
# The events to forget are those that are not in the kept set
|
||||
forgotten_event_ids = {e.id for e in self.state.history} - kept_event_ids
|
||||
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
self.state.start_id = self.state.history[0].id
|
||||
if len(kept_event_ids) == 0:
|
||||
self.log(
|
||||
'warning',
|
||||
'No events kept after applying conversation window. This should not happen.',
|
||||
)
|
||||
|
||||
# verify that the first event id in kept_event_ids is the same as the start_id
|
||||
if len(kept_event_ids) > 0 and self.state.history[0].id not in kept_event_ids:
|
||||
self.log(
|
||||
'warning',
|
||||
f'First event after applying conversation window was not kept: {self.state.history[0].id} not in {kept_event_ids}',
|
||||
)
|
||||
|
||||
# Add an error event to trigger another step by the agent
|
||||
self.event_stream.add_event(
|
||||
CondensationAction(
|
||||
forgotten_events_start_id=min(forgotten_event_ids),
|
||||
forgotten_events_end_id=max(forgotten_event_ids),
|
||||
forgotten_events_start_id=min(forgotten_event_ids)
|
||||
if forgotten_event_ids
|
||||
else 0,
|
||||
forgotten_events_end_id=max(forgotten_event_ids)
|
||||
if forgotten_event_ids
|
||||
else 0,
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
|
||||
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
|
||||
def _apply_conversation_window(self) -> list[Event]:
|
||||
"""Cuts history roughly in half when context window is exceeded.
|
||||
|
||||
It preserves action-observation pairs and ensures that the first user message is always included.
|
||||
It preserves action-observation pairs and ensures that the system message,
|
||||
the first user message, and its associated recall observation are always included
|
||||
at the beginning of the context window.
|
||||
|
||||
The algorithm:
|
||||
1. Cut history in half
|
||||
2. Check first event in new history:
|
||||
- If Observation: find and include its Action
|
||||
- If MessageAction: ensure its related Action-Observation pair isn't split
|
||||
3. Always include the first user message
|
||||
1. Identify essential initial events: System Message, First User Message, Recall Observation.
|
||||
2. Determine the slice of recent events to potentially keep.
|
||||
3. Validate the start of the recent slice for dangling observations.
|
||||
4. Combine essential events and validated recent events, ensuring essentials come first.
|
||||
|
||||
Args:
|
||||
events: List of events to filter
|
||||
|
||||
Returns:
|
||||
Filtered list of events keeping newest half while preserving pairs
|
||||
Filtered list of events keeping newest half while preserving pairs and essential initial events.
|
||||
"""
|
||||
if not events:
|
||||
return events
|
||||
if not self.state.history:
|
||||
return []
|
||||
|
||||
# Find first user message - we'll need to ensure it's included
|
||||
first_user_msg = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
history = self.state.history
|
||||
|
||||
# 1. Identify essential initial events
|
||||
system_message: SystemMessageAction | None = None
|
||||
first_user_msg: MessageAction | None = None
|
||||
recall_action: RecallAction | None = None
|
||||
recall_observation: Observation | None = None
|
||||
|
||||
# Find System Message (should be the first event, if it exists)
|
||||
system_message = next(
|
||||
(e for e in history if isinstance(e, SystemMessageAction)), None
|
||||
)
|
||||
assert (
|
||||
system_message is None
|
||||
or isinstance(system_message, SystemMessageAction)
|
||||
and system_message.id == history[0].id
|
||||
)
|
||||
|
||||
# cut in half
|
||||
mid_point = max(1, len(events) // 2)
|
||||
kept_events = events[mid_point:]
|
||||
if len(kept_events) > 0 and isinstance(kept_events[0], Observation):
|
||||
kept_events = kept_events[1:]
|
||||
# Find First User Message, which MUST exist
|
||||
first_user_msg = self._first_user_message()
|
||||
if first_user_msg is None:
|
||||
raise RuntimeError('No first user message found in the event stream.')
|
||||
|
||||
# Ensure first user message is included
|
||||
if first_user_msg and first_user_msg not in kept_events:
|
||||
kept_events = [first_user_msg] + kept_events
|
||||
first_user_msg_index = -1
|
||||
for i, event in enumerate(history):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
first_user_msg = event
|
||||
first_user_msg_index = i
|
||||
break
|
||||
|
||||
# start_id points to first user message
|
||||
# Find Recall Action and Observation related to the First User Message
|
||||
if first_user_msg is not None and first_user_msg_index != -1:
|
||||
# Look for RecallAction after the first user message
|
||||
for i in range(first_user_msg_index + 1, len(history)):
|
||||
event = history[i]
|
||||
if (
|
||||
isinstance(event, RecallAction)
|
||||
and event.query == first_user_msg.content
|
||||
):
|
||||
# Found RecallAction, now look for its Observation
|
||||
recall_action = event
|
||||
for j in range(i + 1, len(history)):
|
||||
obs_event = history[j]
|
||||
# Check for Observation caused by this RecallAction
|
||||
if (
|
||||
isinstance(obs_event, Observation)
|
||||
and obs_event.cause == recall_action.id
|
||||
):
|
||||
recall_observation = obs_event
|
||||
break # Found the observation, stop inner loop
|
||||
break # Found the recall action (and maybe obs), stop outer loop
|
||||
|
||||
essential_events: list[Event] = []
|
||||
if system_message:
|
||||
essential_events.append(system_message)
|
||||
if first_user_msg:
|
||||
self.state.start_id = first_user_msg.id
|
||||
essential_events.append(first_user_msg)
|
||||
# Also keep the RecallAction that triggered the essential RecallObservation
|
||||
if recall_action:
|
||||
essential_events.append(recall_action)
|
||||
if recall_observation:
|
||||
essential_events.append(recall_observation)
|
||||
|
||||
return kept_events
|
||||
# 2. Determine the slice of recent events to potentially keep
|
||||
num_non_essential_events = len(history) - len(essential_events)
|
||||
# Keep roughly half of the non-essential events, minimum 1
|
||||
num_recent_to_keep = max(1, num_non_essential_events // 2)
|
||||
|
||||
# Calculate the starting index for the recent slice
|
||||
slice_start_index = len(history) - num_recent_to_keep
|
||||
slice_start_index = max(0, slice_start_index) # Ensure index is not negative
|
||||
recent_events_slice = history[slice_start_index:]
|
||||
|
||||
# 3. Validate the start of the recent slice for dangling observations
|
||||
# IMPORTANT: Most observations in history are tool call results, which cannot be without their action, or we get an LLM API error
|
||||
first_valid_event_index = 0
|
||||
for i, event in enumerate(recent_events_slice):
|
||||
if isinstance(event, Observation):
|
||||
first_valid_event_index += 1
|
||||
else:
|
||||
break
|
||||
# If all events in the slice are dangling observations, we need to keep at least one
|
||||
if first_valid_event_index == len(recent_events_slice):
|
||||
self.log(
|
||||
'warning',
|
||||
'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.',
|
||||
)
|
||||
|
||||
# Adjust the recent_events_slice if dangling observations were found at the start
|
||||
if first_valid_event_index < len(recent_events_slice):
|
||||
validated_recent_events = recent_events_slice[first_valid_event_index:]
|
||||
if first_valid_event_index > 0:
|
||||
self.log(
|
||||
'debug',
|
||||
f'Removed {first_valid_event_index} dangling observation(s) from the start of recent event slice.',
|
||||
)
|
||||
else:
|
||||
validated_recent_events = []
|
||||
|
||||
# 4. Combine essential events and validated recent events
|
||||
events_to_keep: list[Event] = essential_events + validated_recent_events
|
||||
self.log('debug', f'History truncated. Kept {len(events_to_keep)} events.')
|
||||
|
||||
return events_to_keep
|
||||
|
||||
def _is_stuck(self) -> bool:
|
||||
"""Checks if the agent or its delegate is stuck in a loop.
|
||||
|
||||
@@ -14,12 +14,14 @@ from openhands.core.cli_commands import (
|
||||
)
|
||||
from openhands.core.cli_tui import (
|
||||
UsageMetrics,
|
||||
display_agent_running_message,
|
||||
display_banner,
|
||||
display_event,
|
||||
display_initial_user_prompt,
|
||||
display_initialization_animation,
|
||||
display_runtime_initialization_message,
|
||||
display_welcome_message,
|
||||
process_agent_pause,
|
||||
read_confirmation_input,
|
||||
read_prompt_input,
|
||||
)
|
||||
@@ -99,6 +101,7 @@ async def run_session(
|
||||
|
||||
sid = str(uuid4())
|
||||
is_loaded = asyncio.Event()
|
||||
is_paused = asyncio.Event()
|
||||
|
||||
# Show runtime initialization message
|
||||
display_runtime_initialization_message(config.runtime)
|
||||
@@ -124,10 +127,12 @@ async def run_session(
|
||||
|
||||
usage_metrics = UsageMetrics()
|
||||
|
||||
async def prompt_for_next_task():
|
||||
async def prompt_for_next_task(agent_state: str):
|
||||
nonlocal reload_microagents, new_session_requested
|
||||
while True:
|
||||
next_message = await read_prompt_input(config.cli_multiline_input)
|
||||
next_message = await read_prompt_input(
|
||||
agent_state, multiline=config.cli_multiline_input
|
||||
)
|
||||
|
||||
if not next_message.strip():
|
||||
continue
|
||||
@@ -150,14 +155,23 @@ async def run_session(
|
||||
return
|
||||
|
||||
async def on_event_async(event: Event) -> None:
|
||||
nonlocal reload_microagents
|
||||
nonlocal reload_microagents, is_paused
|
||||
display_event(event, config)
|
||||
update_usage_metrics(event, usage_metrics)
|
||||
|
||||
# Pause the agent if the pause event is set (if Ctrl-P is pressed)
|
||||
if is_paused.is_set():
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.PAUSED),
|
||||
EventSource.USER,
|
||||
)
|
||||
is_paused.clear()
|
||||
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in [
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
AgentState.PAUSED,
|
||||
]:
|
||||
# Reload microagents after initialization of repo.md
|
||||
if reload_microagents:
|
||||
@@ -166,20 +180,28 @@ async def run_session(
|
||||
)
|
||||
memory.load_user_workspace_microagents(microagents)
|
||||
reload_microagents = False
|
||||
await prompt_for_next_task()
|
||||
await prompt_for_next_task(event.agent_state)
|
||||
|
||||
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||
user_confirmed = await read_confirmation_input()
|
||||
if user_confirmed:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_REJECTED),
|
||||
EventSource.USER,
|
||||
)
|
||||
# Only display the confirmation prompt if the agent is not paused
|
||||
if not is_paused.is_set():
|
||||
user_confirmed = await read_confirmation_input()
|
||||
if user_confirmed:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.USER_REJECTED),
|
||||
EventSource.USER,
|
||||
)
|
||||
|
||||
if event.agent_state == AgentState.RUNNING:
|
||||
# Enable pause/resume functionality only if the confirmation mode is disabled
|
||||
if not config.security.confirmation_mode:
|
||||
display_agent_running_message()
|
||||
loop.create_task(process_agent_pause(is_paused))
|
||||
|
||||
def on_event(event: Event) -> None:
|
||||
loop.create_task(on_event_async(event))
|
||||
@@ -212,7 +234,7 @@ async def run_session(
|
||||
clear()
|
||||
|
||||
# Show OpenHands banner and session ID
|
||||
display_banner(session_id=sid, is_loaded=is_loaded)
|
||||
display_banner(session_id=sid)
|
||||
|
||||
# Show OpenHands welcome
|
||||
display_welcome_message()
|
||||
@@ -225,7 +247,7 @@ async def run_session(
|
||||
)
|
||||
else:
|
||||
# Otherwise prompt for the user's first message right away
|
||||
asyncio.create_task(prompt_for_next_task())
|
||||
asyncio.create_task(prompt_for_next_task(''))
|
||||
|
||||
await run_agent_until_done(
|
||||
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
|
||||
|
||||
@@ -70,6 +70,8 @@ async def handle_commands(
|
||||
)
|
||||
elif command == '/settings':
|
||||
await handle_settings_command(config, settings_store)
|
||||
elif command == '/resume':
|
||||
close_repl, new_session_requested = await handle_resume_command(event_stream)
|
||||
else:
|
||||
close_repl = True
|
||||
action = MessageAction(content=command)
|
||||
@@ -183,6 +185,28 @@ async def handle_settings_command(
|
||||
await modify_llm_settings_advanced(config, settings_store)
|
||||
|
||||
|
||||
# FIXME: Currently there's an issue with the actual 'resume' behavior.
|
||||
# Setting the agent state to RUNNING will currently freeze the agent without continuing with the rest of the task.
|
||||
# This is a workaround to handle the resume command for the time being. Replace user message with the state change event once the issue is fixed.
|
||||
async def handle_resume_command(
|
||||
event_stream: EventStream,
|
||||
) -> tuple[bool, bool]:
|
||||
close_repl = True
|
||||
new_session_requested = False
|
||||
|
||||
event_stream.add_event(
|
||||
MessageAction(content='continue'),
|
||||
EventSource.USER,
|
||||
)
|
||||
|
||||
# event_stream.add_event(
|
||||
# ChangeAgentStateAction(AgentState.RUNNING),
|
||||
# EventSource.ENVIRONMENT,
|
||||
# )
|
||||
|
||||
return close_repl, new_session_requested
|
||||
|
||||
|
||||
async def init_repository(current_dir: str) -> bool:
|
||||
repo_file_path = Path(current_dir) / '.openhands' / 'microagents' / 'repo.md'
|
||||
init_repo = False
|
||||
|
||||
@@ -10,7 +10,9 @@ from prompt_toolkit import PromptSession, print_formatted_text
|
||||
from prompt_toolkit.application import Application
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples
|
||||
from prompt_toolkit.input import create_input
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from prompt_toolkit.keys import Keys
|
||||
from prompt_toolkit.layout.containers import HSplit, Window
|
||||
from prompt_toolkit.layout.controls import FormattedTextControl
|
||||
from prompt_toolkit.layout.layout import Layout
|
||||
@@ -22,6 +24,7 @@ from prompt_toolkit.widgets import Frame, TextArea
|
||||
|
||||
from openhands import __version__
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
@@ -32,6 +35,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
@@ -56,6 +60,7 @@ COMMANDS = {
|
||||
'/status': 'Display session details and usage metrics',
|
||||
'/new': 'Create a new session',
|
||||
'/settings': 'Display and modify current settings',
|
||||
'/resume': 'Resume the agent',
|
||||
}
|
||||
|
||||
|
||||
@@ -114,7 +119,7 @@ def display_initialization_animation(text, is_loaded: asyncio.Event):
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def display_banner(session_id: str, is_loaded: asyncio.Event):
|
||||
def display_banner(session_id: str):
|
||||
print_formatted_text(
|
||||
HTML(r"""<gold>
|
||||
___ _ _ _
|
||||
@@ -129,11 +134,8 @@ def display_banner(session_id: str, is_loaded: asyncio.Event):
|
||||
|
||||
print_formatted_text(HTML(f'<grey>OpenHands CLI v{__version__}</grey>'))
|
||||
|
||||
banner_text = (
|
||||
'Initialized session' if is_loaded.is_set() else 'Initializing session'
|
||||
)
|
||||
print_formatted_text('')
|
||||
print_formatted_text(HTML(f'<grey>{banner_text} {session_id}</grey>'))
|
||||
print_formatted_text(HTML(f'<grey>Initialized session {session_id}</grey>'))
|
||||
print_formatted_text('')
|
||||
|
||||
|
||||
@@ -177,6 +179,8 @@ def display_event(event: Event, config: AppConfig) -> None:
|
||||
display_file_edit(event)
|
||||
if isinstance(event, FileReadObservation):
|
||||
display_file_read(event)
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
display_agent_paused_message(event.agent_state)
|
||||
|
||||
|
||||
def display_message(message: str):
|
||||
@@ -389,77 +393,58 @@ def display_status(usage_metrics: UsageMetrics, session_id: str):
|
||||
display_usage_metrics(usage_metrics)
|
||||
|
||||
|
||||
def display_agent_running_message():
|
||||
print_formatted_text('')
|
||||
print_formatted_text(
|
||||
HTML('<gold>Agent running...</gold> <grey>(Ctrl-P to pause)</grey>')
|
||||
)
|
||||
|
||||
|
||||
def display_agent_paused_message(agent_state: str):
|
||||
if agent_state != AgentState.PAUSED:
|
||||
return
|
||||
print_formatted_text('')
|
||||
print_formatted_text(
|
||||
HTML('<gold>Agent paused</gold> <grey>(type /resume to resume)</grey>')
|
||||
)
|
||||
|
||||
|
||||
# Common input functions
|
||||
class CommandCompleter(Completer):
|
||||
"""Custom completer for commands."""
|
||||
|
||||
def __init__(self, agent_state: str):
|
||||
super().__init__()
|
||||
self.agent_state = agent_state
|
||||
|
||||
def get_completions(self, document, complete_event):
|
||||
text = document.text
|
||||
|
||||
# Only show completions if the user has typed '/'
|
||||
text = document.text_before_cursor.lstrip()
|
||||
if text.startswith('/'):
|
||||
# If just '/' is typed, show all commands
|
||||
if text == '/':
|
||||
for command, description in COMMANDS.items():
|
||||
available_commands = dict(COMMANDS)
|
||||
if self.agent_state != AgentState.PAUSED:
|
||||
available_commands.pop('/resume', None)
|
||||
|
||||
for command, description in available_commands.items():
|
||||
if command.startswith(text):
|
||||
yield Completion(
|
||||
command[1:], # Remove the leading '/' as it's already typed
|
||||
start_position=0,
|
||||
display=f'{command} - {description}',
|
||||
command,
|
||||
start_position=-len(text),
|
||||
display_meta=description,
|
||||
style='bg:ansidarkgray fg:ansiwhite',
|
||||
)
|
||||
# Otherwise show matching commands
|
||||
else:
|
||||
for command, description in COMMANDS.items():
|
||||
if command.startswith(text):
|
||||
yield Completion(
|
||||
command[len(text) :], # Complete the remaining part
|
||||
start_position=0,
|
||||
display=f'{command} - {description}',
|
||||
)
|
||||
|
||||
|
||||
prompt_session = PromptSession(style=DEFAULT_STYLE)
|
||||
|
||||
# RPrompt animation related variables
|
||||
SPINNER_FRAMES = [
|
||||
'[ ■□□□ ]',
|
||||
'[ □■□□ ]',
|
||||
'[ □□■□ ]',
|
||||
'[ □□□■ ]',
|
||||
'[ □□■□ ]',
|
||||
'[ □■□□ ]',
|
||||
]
|
||||
ANIMATION_INTERVAL = 0.2 # seconds
|
||||
|
||||
current_frame_index = 0
|
||||
last_update_time = time.monotonic()
|
||||
def create_prompt_session():
|
||||
return PromptSession(style=DEFAULT_STYLE)
|
||||
|
||||
|
||||
# RPrompt function for the user confirmation
|
||||
def get_rprompt() -> FormattedText:
|
||||
"""
|
||||
Returns the current animation frame for the rprompt.
|
||||
This function is called by prompt_toolkit during rendering.
|
||||
"""
|
||||
global current_frame_index, last_update_time
|
||||
|
||||
# Only update the frame if enough time has passed
|
||||
# This prevents excessive recalculation during rendering
|
||||
now = time.monotonic()
|
||||
if now - last_update_time > ANIMATION_INTERVAL:
|
||||
current_frame_index = (current_frame_index + 1) % len(SPINNER_FRAMES)
|
||||
last_update_time = now
|
||||
|
||||
# Return the frame wrapped in FormattedText
|
||||
return FormattedText(
|
||||
[
|
||||
('', ' '), # Add a space before the spinner
|
||||
(COLOR_GOLD, SPINNER_FRAMES[current_frame_index]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def read_prompt_input(multiline=False):
|
||||
async def read_prompt_input(agent_state: str, multiline=False):
|
||||
try:
|
||||
prompt_session = create_prompt_session()
|
||||
prompt_session.completer = (
|
||||
CommandCompleter(agent_state) if not multiline else None
|
||||
)
|
||||
|
||||
if multiline:
|
||||
kb = KeyBindings()
|
||||
|
||||
@@ -470,38 +455,54 @@ async def read_prompt_input(multiline=False):
|
||||
with patch_stdout():
|
||||
print_formatted_text('')
|
||||
message = await prompt_session.prompt_async(
|
||||
'Enter your message and press Ctrl+D to finish:\n',
|
||||
HTML(
|
||||
'<gold>Enter your message and press Ctrl-D to finish:</gold>\n'
|
||||
),
|
||||
multiline=True,
|
||||
key_bindings=kb,
|
||||
)
|
||||
else:
|
||||
with patch_stdout():
|
||||
print_formatted_text('')
|
||||
prompt_session.completer = CommandCompleter()
|
||||
message = await prompt_session.prompt_async(
|
||||
'> ',
|
||||
HTML('<gold>> </gold>'),
|
||||
)
|
||||
return message
|
||||
return message if message is not None else ''
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return '/exit'
|
||||
|
||||
|
||||
async def read_confirmation_input():
|
||||
async def read_confirmation_input() -> bool:
|
||||
try:
|
||||
prompt_session = create_prompt_session()
|
||||
|
||||
with patch_stdout():
|
||||
prompt_session.completer = None
|
||||
confirmation = await prompt_session.prompt_async(
|
||||
'Proceed with action? (y)es/(n)o > ',
|
||||
rprompt=get_rprompt,
|
||||
refresh_interval=ANIMATION_INTERVAL / 2,
|
||||
print_formatted_text('')
|
||||
confirmation: str = await prompt_session.prompt_async(
|
||||
HTML('<gold>Proceed with action? (y)es/(n)o > </gold>'),
|
||||
)
|
||||
prompt_session.rprompt = None
|
||||
confirmation = confirmation.strip().lower()
|
||||
|
||||
confirmation = '' if confirmation is None else confirmation.strip().lower()
|
||||
return confirmation in ['y', 'yes']
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return False
|
||||
|
||||
|
||||
async def process_agent_pause(done: asyncio.Event) -> None:
|
||||
input = create_input()
|
||||
|
||||
def keys_ready():
|
||||
for key_press in input.read_keys():
|
||||
if key_press.key == Keys.ControlP:
|
||||
print_formatted_text('')
|
||||
print_formatted_text(HTML('<gold>Pausing the agent...</gold>'))
|
||||
done.set()
|
||||
|
||||
with input.raw_mode():
|
||||
with input.attach(keys_ready):
|
||||
await done.wait()
|
||||
|
||||
|
||||
def cli_confirm(
|
||||
question: str = 'Are you sure?', choices: list[str] | None = None
|
||||
) -> int:
|
||||
|
||||
@@ -41,7 +41,7 @@ class GitHubService(BaseGitService, GitService):
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
if base_domain and base_domain != "github.com":
|
||||
if base_domain:
|
||||
self.BASE_URL = f'https://{base_domain}/api/v3'
|
||||
|
||||
@property
|
||||
|
||||
@@ -34,7 +34,6 @@ from openhands.server.types import AppMode
|
||||
class ProviderToken(BaseModel):
|
||||
token: SecretStr | None = Field(default=None)
|
||||
user_id: str | None = Field(default=None)
|
||||
base_domain: str | None = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
'frozen': True, # Makes the entire model immutable
|
||||
@@ -44,20 +43,15 @@ class ProviderToken(BaseModel):
|
||||
@classmethod
|
||||
def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
|
||||
"""Factory method to create a ProviderToken from various input types"""
|
||||
if isinstance(token_value, cls):
|
||||
if isinstance(token_value, ProviderToken):
|
||||
return token_value
|
||||
elif isinstance(token_value, dict):
|
||||
token_str = token_value.get('token')
|
||||
user_id = token_value.get('user_id')
|
||||
base_domain = token_value.get('base_domain')
|
||||
return cls(
|
||||
token=SecretStr(token_str) if token_str is not None else None,
|
||||
user_id=user_id,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
return cls(token=SecretStr(token_str), user_id=user_id)
|
||||
|
||||
else:
|
||||
raise ValueError('Unsupported Provider token type')
|
||||
raise ValueError('Unsupport Provider token type')
|
||||
|
||||
|
||||
PROVIDER_TOKEN_TYPE = MappingProxyType[ProviderType, ProviderToken]
|
||||
@@ -104,7 +98,6 @@ class SecretStore(BaseModel):
|
||||
if expose_secrets
|
||||
else pydantic_encoder(provider_token.token),
|
||||
'user_id': provider_token.user_id,
|
||||
'base_domain': provider_token.base_domain,
|
||||
}
|
||||
|
||||
return tokens
|
||||
|
||||
@@ -54,6 +54,7 @@ class ConversationMemory:
|
||||
def process_events(
|
||||
self,
|
||||
condensed_history: list[Event],
|
||||
initial_user_action: MessageAction,
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
) -> list[Message]:
|
||||
@@ -66,12 +67,14 @@ class ConversationMemory:
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
initial_user_action: The initial user message action, if available. Used to ensure the conversation starts correctly.
|
||||
"""
|
||||
|
||||
events = condensed_history
|
||||
|
||||
# Ensure the system message exists (handles legacy cases)
|
||||
# Ensure the event list starts with SystemMessageAction, then MessageAction(source='user')
|
||||
self._ensure_system_message(events)
|
||||
self._ensure_initial_user_message(events, initial_user_action)
|
||||
|
||||
# log visual browsing status
|
||||
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
|
||||
@@ -699,6 +702,43 @@ class ConversationMemory:
|
||||
system_message = SystemMessageAction(content=system_prompt)
|
||||
# Insert the system message directly at the beginning of the events list
|
||||
events.insert(0, system_message)
|
||||
logger.debug(
|
||||
logger.info(
|
||||
'[ConversationMemory] Added SystemMessageAction for backward compatibility'
|
||||
)
|
||||
|
||||
def _ensure_initial_user_message(
|
||||
self, events: list[Event], initial_user_action: MessageAction
|
||||
) -> None:
|
||||
"""Checks if the second event is a user MessageAction and inserts the provided one if needed."""
|
||||
if (
|
||||
not events
|
||||
): # Should have system message from previous step, but safety check
|
||||
logger.error('Cannot ensure initial user message: event list is empty.')
|
||||
# Or raise? Let's log for now, _ensure_system_message should handle this.
|
||||
return
|
||||
|
||||
# We expect events[0] to be SystemMessageAction after _ensure_system_message
|
||||
if len(events) == 1:
|
||||
# Only system message exists
|
||||
logger.info(
|
||||
'Initial user message action was missing. Inserting the initial user message.'
|
||||
)
|
||||
events.insert(1, initial_user_action)
|
||||
elif not isinstance(events[1], MessageAction) or events[1].source != 'user':
|
||||
# The second event exists but is not the correct initial user message action.
|
||||
# We will insert the correct one provided.
|
||||
logger.info(
|
||||
'Second event was not the initial user message action. Inserting correct one at index 1.'
|
||||
)
|
||||
|
||||
# Insert the user message event at index 1. This will be the second message as LLM APIs expect
|
||||
# but something was wrong with the history, so log all we can.
|
||||
events.insert(1, initial_user_action)
|
||||
|
||||
# Else: events[1] is already a user MessageAction.
|
||||
# Check if it matches the one provided (if any discrepancy, log warning but proceed).
|
||||
elif events[1] != initial_user_action:
|
||||
logger.debug(
|
||||
'The user MessageAction at index 1 does not match the provided initial_user_action. '
|
||||
'Proceeding with the one found in condensed history.'
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
@@ -65,6 +66,8 @@ class View(BaseModel):
|
||||
break
|
||||
|
||||
if summary is not None and summary_offset is not None:
|
||||
logger.info(f'Inserting summary at offset {summary_offset}')
|
||||
|
||||
kept_events.insert(
|
||||
summary_offset, AgentCondensationObservation(content=summary)
|
||||
)
|
||||
|
||||
@@ -1,191 +1,141 @@
|
||||
# OpenHands Github & Gitlab Issue Resolver 🙌
|
||||
# OpenHands GitHub & GitLab Issue Resolver 🙌
|
||||
|
||||
Need help resolving a GitHub issue but don't have the time to do it yourself? Let an AI agent help you out!
|
||||
Need help resolving GitHub or GitLab issues? Let an AI agent help you out!
|
||||
|
||||
This tool allows you to use open-source AI agents based on [OpenHands](https://github.com/all-hands-ai/openhands)
|
||||
to attempt to resolve GitHub issues automatically. While it can handle multiple issues, it's primarily designed
|
||||
to help you resolve one issue at a time with high quality.
|
||||
This tool uses [OpenHands](https://github.com/all-hands-ai/openhands) AI agents to automatically resolve issues in your repositories. It's designed to handle one issue at a time with high quality.
|
||||
|
||||
Getting started is simple - just follow the instructions below.
|
||||
## 1. Setting Up for GitHub (Action Workflow)
|
||||
|
||||
## Using the GitHub Actions Workflow
|
||||
### Prerequisites
|
||||
|
||||
This repository includes a GitHub Actions workflow that can automatically attempt to fix individual issues labeled with 'fix-me'.
|
||||
Follow these steps to use this workflow in your own repository:
|
||||
- [Create a personal access token](https://github.com/settings/tokens?type=beta) with read/write scope for
|
||||
|
||||
1. [Create a personal access token](https://github.com/settings/tokens?type=beta) with read/write scope for "contents", "issues", "pull requests", and "workflows"
|
||||
- "contents"
|
||||
- "issues"
|
||||
- "pull requests"
|
||||
- "workflows"
|
||||
|
||||
Note: If you're working with an organizational repository, you may need to configure the organization's personal access token policy first. See [Setting a personal access token policy for your organization](https://docs.github.com/en/organizations/managing-programmatic-access-to-your-organization/setting-a-personal-access-token-policy-for-your-organization) for details.
|
||||
- Create an LLM API key (e.g., [Claude API](https://www.anthropic.com/api))
|
||||
|
||||
2. Create an API key for the [Claude API](https://www.anthropic.com/api) (recommended) or another supported LLM service
|
||||
### Installation
|
||||
|
||||
3. Copy `examples/openhands-resolver.yml` to your repository's `.github/workflows/` directory
|
||||
1. Copy `examples/openhands-resolver.yml` to your repository's `.github/workflows/` directory
|
||||
|
||||
4. Configure repository permissions:
|
||||
- Go to `Settings -> Actions -> General -> Workflow permissions`
|
||||
- Select "Read and write permissions"
|
||||
- Enable "Allow GitHub Actions to create and approve pull requests"
|
||||
2. Configure repository permissions:
|
||||
|
||||
Note: If the "Read and write permissions" option is greyed out:
|
||||
- First check if permissions need to be set at the organization level
|
||||
- If still greyed out at the organization level, permissions need to be set in the [Enterprise policy settings](https://docs.github.com/en/enterprise-cloud@latest/admin/enforcing-policies/enforcing-policies-for-your-enterprise/enforcing-policies-for-github-actions-in-your-enterprise#enforcing-a-policy-for-workflow-permissions-in-your-enterprise)
|
||||
- Go to `Settings -> Actions -> General -> Workflow permissions`
|
||||
- Select **Read and write permissions**
|
||||
- Enable **Allow GitHub Actions to create and approve pull requests**
|
||||
|
||||
> If "Read and write permissions" is greyed out:
|
||||
>
|
||||
> - Check organization settings first
|
||||
> - Otherwise, permissions might need to be set in [Enterprise policy settings](https://docs.github.com/en/enterprise-cloud@latest/admin/enforcing-policies/enforcing-policies-for-your-enterprise/enforcing-policies-for-github-actions-in-your-enterprise#enforcing-a-policy-for-workflow-permissions-in-your-enterprise)
|
||||
|
||||
3. Set up [GitHub secrets](https://docs.github.com/en/actions/security-for-github-actions/security-guides/using-secrets-in-github-actions):
|
||||
|
||||
5. Set up [GitHub secrets](https://docs.github.com/en/actions/security-for-github-actions/security-guides/using-secrets-in-github-actions):
|
||||
- Required:
|
||||
- `LLM_API_KEY`: Your LLM API key
|
||||
- `LLM_API_KEY`: Your LLM API key
|
||||
- Optional:
|
||||
- `PAT_USERNAME`: GitHub username for the personal access token
|
||||
- `PAT_TOKEN`: The personal access token
|
||||
- `LLM_BASE_URL`: Base URL for LLM API (only if using a proxy)
|
||||
- [See how to customize more configurations](https://docs.all-hands.dev/modules/usage/how-to/github-action#custom-configurations)
|
||||
|
||||
Note: You can set these secrets at the organization level to use across multiple repositories.
|
||||
## 2. Setting up GitLab (CI Runner)
|
||||
|
||||
6. Set up any [custom configurations required](https://docs.all-hands.dev/modules/usage/how-to/github-action#custom-configurations)
|
||||
### Prerequisites
|
||||
|
||||
7. Usage:
|
||||
There are two ways to trigger the OpenHands agent:
|
||||
Create a GitLab Personal Access Token with API, read/write access
|
||||
|
||||
a. Using the 'fix-me' label:
|
||||
- Add the 'fix-me' label to any issue you want the AI to resolve
|
||||
- The agent will consider all comments in the issue thread when resolving
|
||||
- The workflow will:
|
||||
1. Attempt to resolve the issue using OpenHands
|
||||
2. Create a draft PR if successful, or push a branch if unsuccessful
|
||||
3. Comment on the issue with the results
|
||||
4. Remove the 'fix-me' label once processed
|
||||
### Installation
|
||||
|
||||
b. Using `@openhands-agent` mention:
|
||||
- Create a new comment containing `@openhands-agent` in any issue
|
||||
- The agent will only consider the comment where it's mentioned
|
||||
- The workflow will:
|
||||
1. Attempt to resolve the issue based on the specific comment
|
||||
2. Create a draft PR if successful, or push a branch if unsuccessful
|
||||
3. Comment on the issue with the results
|
||||
## 3. Triggering OpenHands Agent
|
||||
|
||||
Need help? Feel free to [open an issue](https://github.com/all-hands-ai/openhands/issues) or email us at [contact@all-hands.dev](mailto:contact@all-hands.dev).
|
||||
You can trigger OpenHands in two ways (both work for GitHub and GitLab):
|
||||
|
||||
## Manual Installation
|
||||
Using the 'fix-me' label:
|
||||
|
||||
If you prefer to run the resolver programmatically instead of using GitHub Actions, follow these steps:
|
||||
- Add the 'fix-me' label to any issue you want the AI to resolve
|
||||
- The agent will consider all comments in the issue thread when resolving
|
||||
|
||||
1. Install the package:
|
||||
Using `@openhands-agent` in an issue/PR comment:
|
||||
|
||||
- Create a new comment containing `@openhands-agent`
|
||||
- The agent will only consider the comment + comment thread where it's mentioned
|
||||
|
||||
## 4. Running Locally
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install openhands-ai
|
||||
```
|
||||
|
||||
2. Create a GitHub or GitLab access token:
|
||||
- Create a GitHub access token
|
||||
- Visit [GitHub's token settings](https://github.com/settings/personal-access-tokens/new)
|
||||
- Create a fine-grained token with these scopes:
|
||||
- "Content"
|
||||
- "Pull requests"
|
||||
- "Issues"
|
||||
- "Workflows"
|
||||
- If you don't have push access to the target repo, you can fork it first
|
||||
### Setup
|
||||
|
||||
- Create a GitLab access token
|
||||
- Visit [GitLab's token settings](https://gitlab.com/-/user_settings/personal_access_tokens)
|
||||
- Create a fine-grained token with these scopes:
|
||||
- 'api'
|
||||
- 'read_api'
|
||||
- 'read_user'
|
||||
- 'read_repository'
|
||||
- 'write_repository'
|
||||
Create a GitHub or GitLab access token with appropriate permissions
|
||||
|
||||
3. Set up environment variables:
|
||||
Set up environment variables:
|
||||
|
||||
```bash
|
||||
|
||||
# GitHub credentials
|
||||
|
||||
export GITHUB_TOKEN="your-github-token"
|
||||
export GIT_USERNAME="your-github-username" # Optional, defaults to token owner
|
||||
|
||||
# GitLab credentials if you're using GitLab repo
|
||||
export GIT_USERNAME="your-github-username"
|
||||
|
||||
# GitLab credentials (if using GitLab)
|
||||
export GITLAB_TOKEN="your-gitlab-token"
|
||||
export GIT_USERNAME="your-gitlab-username" # Optional, defaults to token owner
|
||||
export GIT_USERNAME="your-gitlab-username"
|
||||
|
||||
# LLM configuration
|
||||
|
||||
export LLM_MODEL="anthropic/claude-3-5-sonnet-20241022" # Recommended
|
||||
export LLM_MODEL="anthropic/claude-3-5-sonnet-20241022"
|
||||
export LLM_API_KEY="your-llm-api-key"
|
||||
export LLM_BASE_URL="your-api-url" # Optional, for API proxies
|
||||
export LLM_BASE_URL="your-api-url" # Optional
|
||||
```
|
||||
|
||||
Note: OpenHands works best with powerful models like Anthropic's Claude or OpenAI's GPT-4. While other models are supported, they may not perform as well for complex issue resolution.
|
||||
### Resolving Issues
|
||||
|
||||
## Resolving Issues
|
||||
|
||||
The resolver can automatically attempt to fix a single issue in your repository using the following command:
|
||||
Resolve a single issue:
|
||||
|
||||
```bash
|
||||
python -m openhands.resolver.resolve_issue --selected-repo [OWNER]/[REPO] --issue-number [NUMBER]
|
||||
```
|
||||
|
||||
For instance, if you want to resolve issue #100 in this repo, you would run:
|
||||
### Responding to PR Comments
|
||||
|
||||
```bash
|
||||
python -m openhands.resolver.resolve_issue --selected-repo all-hands-ai/openhands --issue-number 100
|
||||
```
|
||||
|
||||
The output will be written to the `output/` directory.
|
||||
|
||||
If you've installed the package from source using poetry, you can use:
|
||||
|
||||
```bash
|
||||
poetry run python openhands/resolver/resolve_issue.py --selected-repo all-hands-ai/openhands --issue-number 100
|
||||
```
|
||||
|
||||
## Responding to PR Comments
|
||||
|
||||
The resolver can also respond to comments on pull requests using:
|
||||
Respond to comments on pull requests:
|
||||
|
||||
```bash
|
||||
python -m openhands.resolver.send_pull_request --issue-number PR_NUMBER --issue-type pr
|
||||
```
|
||||
|
||||
This functionality is available both through the GitHub Actions workflow and when running the resolver locally.
|
||||
### Visualizing Results
|
||||
|
||||
## Visualizing successful PRs
|
||||
|
||||
To find successful PRs, you can run the following command:
|
||||
View successful PRs:
|
||||
|
||||
```bash
|
||||
grep '"success":true' output/output.jsonl | sed 's/.*\("number":[0-9]*\).*/\1/g'
|
||||
```
|
||||
|
||||
Then you can go through and visualize the ones you'd like.
|
||||
Visualize specific PR:
|
||||
|
||||
```bash
|
||||
python -m openhands.resolver.visualize_resolver_output --issue-number ISSUE_NUMBER --vis-method json
|
||||
```
|
||||
|
||||
## Uploading PRs
|
||||
### Uploading PRs
|
||||
|
||||
If you find any PRs that were successful, you can upload them.
|
||||
There are three ways you can upload:
|
||||
|
||||
1. `branch` - upload a branch without creating a PR
|
||||
2. `draft` - create a draft PR
|
||||
3. `ready` - create a non-draft PR that's ready for review
|
||||
Upload your changes in one of three ways:
|
||||
|
||||
```bash
|
||||
python -m openhands.resolver.send_pull_request --issue-number ISSUE_NUMBER --username YOUR_GITHUB_OR_GITLAB_USERNAME --pr-type draft
|
||||
python -m openhands.resolver.send_pull_request --issue-number ISSUE_NUMBER --username YOUR_GITHUB_OR_GITLAB_USERNAME --pr-type [branch|draft|ready]
|
||||
```
|
||||
|
||||
If you want to upload to a fork, you can do so by specifying the `fork-owner`:
|
||||
## Custom Instructions
|
||||
|
||||
```bash
|
||||
python -m openhands.resolver.send_pull_request --issue-number ISSUE_NUMBER --username YOUR_GITHUB_OR_GITLAB_USERNAME --pr-type draft --fork-owner YOUR_GITHUB_OR_GITLAB_USERNAME
|
||||
```
|
||||
|
||||
## Providing Custom Instructions
|
||||
|
||||
You can customize how the AI agent approaches issue resolution by adding a repository microagent file at `.openhands/microagents/repo.md` in your repository. This file's contents will be automatically loaded in the prompt when working with your repository. For more information about repository microagents, see [Repository Instructions](https://github.com/All-Hands-AI/OpenHands/tree/main/microagents#2-repository-instructions-private).
|
||||
Add repository-specific instructions by creating a file at `.openhands/microagents/repo.md` in your repository. For more information about repository microagents, see [Repository Instructions](https://github.com/All-Hands-AI/OpenHands/tree/main/microagents#2-repository-instructions-private).
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If you have any issues, please open an issue on this GitHub or GitLab repo; we're happy to help!
|
||||
If you have any issues, please open an issue on this GitHub repo; we're happy to help!
|
||||
Alternatively, you can [email us](mailto:contact@all-hands.dev) or join the OpenHands Slack workspace (see [the README](/README.md) for an invite link).
|
||||
|
||||
@@ -20,6 +20,7 @@ from openhands.server.shared import config
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
get_user_settings,
|
||||
get_user_settings_store,
|
||||
)
|
||||
@@ -30,6 +31,7 @@ app = APIRouter(prefix='/api')
|
||||
|
||||
@app.get('/settings', response_model=GETSettingsModel)
|
||||
async def load_settings(
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
settings: Settings | None = Depends(get_user_settings),
|
||||
) -> GETSettingsModel | JSONResponse:
|
||||
@@ -41,10 +43,18 @@ async def load_settings(
|
||||
)
|
||||
|
||||
provider_tokens_set = {}
|
||||
|
||||
if bool(user_id):
|
||||
provider_tokens_set[ProviderType.GITHUB.value] = True
|
||||
|
||||
if provider_tokens:
|
||||
for provider_type, provider_token in provider_tokens.items():
|
||||
if provider_token.token or provider_token.user_id:
|
||||
provider_tokens_set[provider_type] = provider_token.base_domain
|
||||
all_provider_types = [provider.value for provider in ProviderType]
|
||||
provider_tokens_types = [provider.value for provider in provider_tokens]
|
||||
for provider_type in all_provider_types:
|
||||
if provider_type in provider_tokens_types:
|
||||
provider_tokens_set[provider_type] = True
|
||||
else:
|
||||
provider_tokens_set[provider_type] = False
|
||||
|
||||
settings_with_token_data = GETSettingsModel(
|
||||
**settings.model_dump(exclude='secrets_store'),
|
||||
@@ -208,80 +218,66 @@ async def reset_settings() -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
async def check_provider_tokens(settings: POSTSettingsModel, existing_settings: Settings | None) -> str:
|
||||
async def check_provider_tokens(settings: POSTSettingsModel) -> str:
|
||||
if settings.provider_tokens:
|
||||
# Remove extraneous token types
|
||||
provider_types = [provider for provider in ProviderType]
|
||||
provider_types = [provider.value for provider in ProviderType]
|
||||
settings.provider_tokens = {
|
||||
k: v for k, v in settings.provider_tokens.items() if k in provider_types
|
||||
}
|
||||
|
||||
# Determine whether tokens are valid
|
||||
for provider_type, provider_token in settings.provider_tokens.items():
|
||||
token_value = provider_token
|
||||
existing_token = existing_settings.secrets_store.provider_tokens.get(provider_type, None) if existing_settings else None
|
||||
|
||||
# Use incoming value otherwise default to existing value
|
||||
token = SecretStr("")
|
||||
if token_value.token:
|
||||
token = token_value.token
|
||||
elif existing_token and existing_token.token:
|
||||
token = existing_token.token
|
||||
|
||||
if not token:
|
||||
continue
|
||||
|
||||
base_domain = provider_token.base_domain # FE should always send latest base_domain param
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
token,
|
||||
base_domain
|
||||
)
|
||||
|
||||
|
||||
if not confirmed_token_type or confirmed_token_type != provider_type:
|
||||
return f'Invalid {provider_type.value} token or base domain.'
|
||||
for token_type, token_value in settings.provider_tokens.items():
|
||||
if token_value:
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
SecretStr(token_value)
|
||||
)
|
||||
if not confirmed_token_type or confirmed_token_type.value != token_type:
|
||||
return f'Invalid token. Please make sure it is a valid {token_type} token.'
|
||||
|
||||
return ''
|
||||
|
||||
|
||||
async def store_provider_tokens(
|
||||
settings: POSTSettingsModel, existing_settings: Settings
|
||||
settings: POSTSettingsModel, settings_store: SettingsStore
|
||||
):
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
if settings.provider_tokens:
|
||||
existing_providers = [
|
||||
provider
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in settings.provider_tokens.items():
|
||||
if provider in existing_providers and not token_value.token:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in list(settings.provider_tokens.items()):
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if existing_token:
|
||||
updated_token = ProviderToken(
|
||||
token=existing_token.token,
|
||||
user_id=existing_token.user_id,
|
||||
base_domain=token_value.base_domain
|
||||
)
|
||||
|
||||
settings.provider_tokens[provider] = updated_token
|
||||
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
else: # nothing passed in means keep current settings
|
||||
settings.provider_tokens = dict(existing_settings.secrets_store.provider_tokens)
|
||||
provider_tokens = existing_settings.secrets_store.provider_tokens
|
||||
settings.provider_tokens = {
|
||||
provider.value: data.token.get_secret_value() if data.token else None
|
||||
for provider, data in provider_tokens.items()
|
||||
}
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
async def store_llm_settings(
|
||||
settings: POSTSettingsModel, existing_settings: Settings
|
||||
settings: POSTSettingsModel, settings_store: SettingsStore
|
||||
) -> POSTSettingsModel:
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
# Keep existing LLM settings if not provided
|
||||
@@ -299,10 +295,9 @@ async def store_llm_settings(
|
||||
async def store_settings(
|
||||
settings: POSTSettingsModel,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
existing_settings: Settings | None = Depends(get_user_settings),
|
||||
) -> JSONResponse:
|
||||
# Check provider tokens are valid
|
||||
provider_err_msg = await check_provider_tokens(settings, existing_settings)
|
||||
provider_err_msg = await check_provider_tokens(settings)
|
||||
if provider_err_msg:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -310,9 +305,11 @@ async def store_settings(
|
||||
)
|
||||
|
||||
try:
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
settings = await store_llm_settings(settings, existing_settings)
|
||||
settings = await store_llm_settings(settings, settings_store)
|
||||
|
||||
# Keep existing analytics consent if not provided
|
||||
if settings.user_consents_to_analytics is None:
|
||||
@@ -320,7 +317,7 @@ async def store_settings(
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
settings = await store_provider_tokens(settings, existing_settings)
|
||||
settings = await store_provider_tokens(settings, settings_store)
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
@@ -360,9 +357,17 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
|
||||
|
||||
# Create new provider tokens immutably
|
||||
if settings_with_token_data.provider_tokens:
|
||||
tokens = {}
|
||||
for token_type, token_value in settings_with_token_data.provider_tokens.items():
|
||||
if token_value:
|
||||
provider = ProviderType(token_type)
|
||||
tokens[provider] = ProviderToken(
|
||||
token=SecretStr(token_value), user_id=None
|
||||
)
|
||||
|
||||
# Create new SecretStore with tokens
|
||||
settings = settings.model_copy(
|
||||
update={'secrets_store': SecretStore(provider_tokens=settings_with_token_data.provider_tokens)}
|
||||
update={'secrets_store': SecretStore(provider_tokens=tokens)}
|
||||
)
|
||||
|
||||
return settings
|
||||
|
||||
@@ -5,8 +5,6 @@ from pydantic import (
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from openhands.integrations.provider import ProviderToken
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@@ -15,7 +13,7 @@ class POSTSettingsModel(Settings):
|
||||
Settings for POST requests
|
||||
"""
|
||||
|
||||
provider_tokens: dict[ProviderType, ProviderToken] = {}
|
||||
provider_tokens: dict[str, str] = {}
|
||||
|
||||
|
||||
class POSTSettingsCustomSecrets(BaseModel):
|
||||
@@ -31,14 +29,9 @@ class GETSettingsModel(Settings):
|
||||
Settings with additional token data for the frontend
|
||||
"""
|
||||
|
||||
provider_tokens_set: dict[ProviderType, str | None] | None = (
|
||||
None # Provider Type and base domain key-value pair
|
||||
)
|
||||
provider_tokens_set: dict[str, bool] | None = None
|
||||
llm_api_key_set: bool
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class GETSettingsCustomSecrets(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -22,7 +22,6 @@ from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.observation.commands import CmdOutputObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
@@ -765,7 +764,7 @@ async def test_context_window_exceeded_error_handling(
|
||||
|
||||
# We do that by playing the role of the recall module -- subscribe to the
|
||||
# event stream and respond to recall actions by inserting fake recall
|
||||
# obesrvations.
|
||||
# observations.
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = RecallObservation(
|
||||
@@ -807,13 +806,19 @@ async def test_context_window_exceeded_error_handling(
|
||||
# size (because we return a message action, which triggers a recall, which
|
||||
# triggers a recall response). But if the pre/post-views are on the turn
|
||||
# when we throw the context window exceeded error, we should see the
|
||||
# post-step view compressed.
|
||||
# post-step view compressed (or rather, a CondensationAction added).
|
||||
for index, (first_view, second_view) in enumerate(
|
||||
zip(step_state.views[:-1], step_state.views[1:])
|
||||
):
|
||||
if index == error_after:
|
||||
assert len(first_view) > len(second_view)
|
||||
# Verify that the CondensationAction is present in the second view (after error)
|
||||
# but not in the first view (before error)
|
||||
assert not any(isinstance(e, CondensationAction) for e in first_view.events)
|
||||
assert any(isinstance(e, CondensationAction) for e in second_view.events)
|
||||
# The length might not strictly decrease due to CondensationAction being added
|
||||
assert len(first_view) == len(second_view)
|
||||
else:
|
||||
# Before the error, the view length should increase
|
||||
assert len(first_view) < len(second_view)
|
||||
|
||||
# The final state's history should contain:
|
||||
@@ -886,7 +891,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
def step(self, state: State):
|
||||
# If the state has more than one message and we haven't errored yet,
|
||||
# throw the context window exceeded error
|
||||
if len(state.history) > 3 and not self.has_errored:
|
||||
if len(state.history) > 5 and not self.has_errored:
|
||||
error = ContextWindowExceededError(
|
||||
message='prompt is too long: 233885 tokens > 200000 maximum',
|
||||
model='',
|
||||
@@ -1467,126 +1472,6 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
|
||||
), 'should_step should return False for NullObservation with cause = 0'
|
||||
|
||||
|
||||
def test_apply_conversation_window_basic(mock_event_stream, mock_agent):
|
||||
"""Test that the _apply_conversation_window method correctly prunes a list of events."""
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_apply_conversation_window_basic',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create a sequence of events with IDs
|
||||
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
# Add agent question
|
||||
agent_msg = MessageAction(
|
||||
content='What task would you like me to perform?', wait_for_response=True
|
||||
)
|
||||
agent_msg._source = EventSource.AGENT
|
||||
agent_msg._id = 2
|
||||
|
||||
# Add user response
|
||||
user_response = MessageAction(
|
||||
content='Please list all files and show me current directory',
|
||||
wait_for_response=False,
|
||||
)
|
||||
user_response._source = EventSource.USER
|
||||
user_response._id = 3
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 4
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
|
||||
obs1._id = 5
|
||||
obs1._cause = 4
|
||||
|
||||
cmd2 = CmdRunAction(command='pwd')
|
||||
cmd2._id = 6
|
||||
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=6)
|
||||
obs2._id = 7
|
||||
obs2._cause = 6
|
||||
|
||||
events = [first_msg, agent_msg, user_response, cmd1, obs1, cmd2, obs2]
|
||||
|
||||
# Apply truncation
|
||||
truncated = controller._apply_conversation_window(events)
|
||||
|
||||
# Verify truncation occurred
|
||||
# Should keep first user message and roughly half of other events
|
||||
assert (
|
||||
3 <= len(truncated) < len(events)
|
||||
) # First message + at least one action-observation pair
|
||||
assert truncated[0] == first_msg # First message always preserved
|
||||
assert controller.state.start_id == first_msg._id
|
||||
|
||||
# Verify pairs aren't split
|
||||
for i, event in enumerate(truncated[1:]):
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
assert any(e._id == event._cause for e in truncated[: i + 1])
|
||||
|
||||
|
||||
def test_history_restoration_after_truncation(mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create events with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
events = [first_msg]
|
||||
for i in range(5):
|
||||
cmd = CmdRunAction(command=f'cmd{i}')
|
||||
cmd._id = i + 2
|
||||
obs = CmdOutputObservation(
|
||||
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
|
||||
)
|
||||
obs._cause = cmd._id
|
||||
events.extend([cmd, obs])
|
||||
|
||||
# Set up initial history
|
||||
controller.state.history = events.copy()
|
||||
|
||||
# Force truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Save state
|
||||
saved_start_id = controller.state.start_id
|
||||
saved_history_len = len(controller.state.history)
|
||||
|
||||
# Set up mock event stream for new controller
|
||||
mock_event_stream.get_events.return_value = controller.state.history
|
||||
|
||||
# Create new controller with saved state
|
||||
new_controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
new_controller.state.start_id = saved_start_id
|
||||
new_controller.state.history = mock_event_stream.get_events()
|
||||
|
||||
# Verify restoration
|
||||
assert len(new_controller.state.history) == saved_history_len
|
||||
assert new_controller.state.history[0] == first_msg
|
||||
assert new_controller.state.start_id == saved_start_id
|
||||
|
||||
|
||||
def test_system_message_in_event_stream(mock_agent, test_event_stream):
|
||||
"""Test that SystemMessageAction is added to event stream in AgentController."""
|
||||
_ = AgentController(
|
||||
|
||||
569
tests/unit/test_agent_history.py
Normal file
569
tests/unit/test_agent_history.py
Normal file
@@ -0,0 +1,569 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import CmdRunAction, MessageAction, RecallAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
Observation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
# Helper function to create events with sequential IDs and causes
|
||||
def create_events(event_data):
    """Build event objects from lightweight spec dicts, for use as test history.

    Args:
        event_data: iterable of dicts. Each dict must contain ``'type'`` (the
            event class to instantiate) and may contain:

            * ``'source'`` - stored on ``event._source``; defaults to
              ``EventSource.AGENT``.
            * ``'cause_id'`` - stored on ``event._cause``.
            * constructor fields, chosen by type:
              ``query`` / ``recall_type`` for ``RecallAction``;
              ``content`` / ``recall_type`` for ``RecallObservation``;
              ``command`` for ``CmdRunAction``;
              ``content`` / ``command`` / ``command_id`` / ``metadata`` for
              ``CmdOutputObservation``;
              ``content`` for every other type.

    Returns:
        A list of events in input order, with ``_id`` assigned sequentially
        starting at 1.
    """
    events = []

    for event_id, data in enumerate(event_data, start=1):
        event_type = data['type']
        kwargs = {}  # Arguments for the event constructor

        if event_type == RecallAction:
            kwargs['query'] = data.get('query', '')
            kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE)
        elif event_type == RecallObservation:
            kwargs['content'] = data.get('content', '')
            kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE)
        elif event_type == CmdRunAction:
            kwargs['command'] = data.get('command', '')
        elif event_type == CmdOutputObservation:
            kwargs['content'] = data.get('content', '')
            kwargs['command'] = data.get('command', '')
            # Resolve command_id up front so the observation is constructed
            # exactly once: an explicit command_id wins, otherwise cause_id is
            # forwarded as command_id (the previous version did this by
            # constructing the event a second time).
            if 'command_id' in data:
                kwargs['command_id'] = data['command_id']
            elif 'cause_id' in data:
                kwargs['command_id'] = data['cause_id']
            if 'metadata' in data:
                kwargs['metadata'] = data['metadata']
        else:  # MessageAction, SystemMessageAction, etc.
            kwargs['content'] = data.get('content', '')

        event = event_type(**kwargs)

        # Internal attributes are not constructor arguments, so set them
        # directly on the instance after construction.
        event._id = event_id
        event._source = data.get('source', EventSource.AGENT)
        if 'cause_id' in data:
            event._cause = data['cause_id']

        events.append(event)

    return events
|
||||
|
||||
|
||||
@pytest.fixture
def controller_fixture():
    """Provide an ``AgentController`` wired to mocks, plus its stubbed first user message.

    Returns:
        ``(controller, first_user_message)``. ``controller._first_user_message``
        is replaced by a mock returning ``first_user_message``, so a test can
        set ``first_user_message.id`` to choose which event counts as the
        first user message.
    """
    # Agent double: mocked LLM with real metrics and config objects attached.
    llm = MagicMock(spec=LLM)
    agent = MagicMock(spec=Agent)
    agent.llm = llm
    agent.llm.metrics = Metrics()
    agent.llm.config = AppConfig().get_llm_config()
    agent.config = AppConfig().get_agent_config('CodeActAgent')

    # Event stream double backed by an in-memory file store.
    stream = MagicMock(spec=EventStream)
    stream.sid = 'test_sid'
    stream.file_store = InMemoryFileStore({})
    # get_latest_event_id must yield an int rather than a MagicMock.
    stream.get_latest_event_id.return_value = -1

    controller = AgentController(
        agent=agent,
        event_stream=stream,
        max_iterations=10,
        sid='test_sid',
    )
    # Start every test from a fresh, empty state.
    controller.state = State(session_id='test_sid')

    # Stub the lookup on the instance itself.
    first_user_message = MagicMock(spec=MessageAction)
    controller._first_user_message = MagicMock(return_value=first_user_message)

    return controller, first_user_message
|
||||
|
||||
|
||||
# =============================================
|
||||
# Test Cases for _apply_conversation_window
|
||||
# =============================================
|
||||
|
||||
|
||||
def test_basic_truncation(controller_fixture):
    """Ten-event history is cut down to the essentials plus the last command pair.

    Walk-through carried over from the original test comments:
      * history length 10; essentials = [sys(1), user(2), recall_act(3),
        recall_obs(4)]
      * non-essential count = 6, so max(1, 6 // 2) = 3 recent events are kept
      * recent slice = history[7:] = [obs(8), cmd(9), obs(10)]
      * the leading observation (8) is dropped -> [cmd(9), obs(10)]
      * expected ids: [1, 2, 3, 4, 9, 10]
    """
    controller, first_user_msg = controller_fixture

    def command_pair(command, output, action_id):
        # A CmdRunAction spec followed by the observation it caused.
        return [
            {'type': CmdRunAction, 'command': command},
            {
                'type': CmdOutputObservation,
                'content': output,
                'command': command,
                'cause_id': action_id,
            },
        ]

    specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {
            'type': MessageAction,
            'content': 'User Task 1',
            'source': EventSource.USER,
        },  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
        *command_pair('ls', 'file1', 5),  # 5, 6
        *command_pair('pwd', '/dir', 7),  # 7, 8
        *command_pair('cat file1', 'content', 9),  # 9, 10
    ]
    controller.state.history = create_events(specs)
    first_user_msg.id = 2  # the mocked first user message is event 2

    kept = controller._apply_conversation_window()

    assert len(kept) == 6
    kept_ids = [event.id for event in kept]
    assert kept_ids == [1, 2, 3, 4, 9, 10]
    # The recent part starts at index 4 (cmd 9) and must not open with an
    # observation whose action was cut away.
    assert not isinstance(kept[4], Observation)
|
||||
|
||||
|
||||
def test_no_system_message(controller_fixture):
    """Truncation of a history that has no leading system message.

    Walk-through carried over from the original test comments:
      * history length 9; essentials = [user(1), recall_act(2), recall_obs(3)]
      * non-essential count = 6, so max(1, 6 // 2) = 3 recent events are kept
      * recent slice = history[6:] = [obs(7), cmd(8), obs(9)]
      * the leading observation (7) is dropped -> [cmd(8), obs(9)]
      * expected ids: [1, 2, 3, 8, 9]
    """
    controller, first_user_msg = controller_fixture

    def command_pair(command, output, action_id):
        # A CmdRunAction spec followed by the observation it caused.
        return [
            {'type': CmdRunAction, 'command': command},
            {
                'type': CmdOutputObservation,
                'content': output,
                'command': command,
                'cause_id': action_id,
            },
        ]

    specs = [
        {
            'type': MessageAction,
            'content': 'User Task 1',
            'source': EventSource.USER,
        },  # 1
        {'type': RecallAction, 'query': 'User Task 1'},  # 2
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 2},  # 3
        *command_pair('ls', 'file1', 4),  # 4, 5
        *command_pair('pwd', '/dir', 6),  # 6, 7
        *command_pair('cat file1', 'content', 8),  # 8, 9
    ]
    controller.state.history = create_events(specs)
    first_user_msg.id = 1

    kept = controller._apply_conversation_window()

    assert len(kept) == 5
    kept_ids = [event.id for event in kept]
    assert kept_ids == [1, 2, 3, 8, 9]
|
||||
|
||||
|
||||
def test_no_recall_observation(controller_fixture):
    """Truncation when a RecallAction exists but its RecallObservation does not.

    Expected outcome: ids [1, 2, 3, 8, 9] - system message, first user
    message, the recall action, and the final command/observation pair.

    Walk-through carried over from the original test comments:
      * history length 9; three recent events are kept
      * recent slice = history[6:] = [obs(7), cmd(8), obs(9)]
      * the leading observation (7) is dropped -> [cmd(8), obs(9)]

    NOTE(review): the original comments listed the essentials as only
    [sys(1), user(2)] ("RecallAction not essential here"), yet the asserted
    result retains RecallAction(3). Confirm against
    _apply_conversation_window which rule is actually in force.
    """
    controller, first_user_msg = controller_fixture

    def command_pair(command, output, action_id):
        # A CmdRunAction spec followed by the observation it caused.
        return [
            {'type': CmdRunAction, 'command': command},
            {
                'type': CmdOutputObservation,
                'content': output,
                'command': command,
                'cause_id': action_id,
            },
        ]

    specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {
            'type': MessageAction,
            'content': 'User Task 1',
            'source': EventSource.USER,
        },  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3 (no matching observation)
        *command_pair('ls', 'file1', 4),  # 4, 5
        *command_pair('pwd', '/dir', 6),  # 6, 7
        *command_pair('cat file1', 'content', 8),  # 8, 9
    ]
    controller.state.history = create_events(specs)
    first_user_msg.id = 2

    kept = controller._apply_conversation_window()

    assert len(kept) == 5
    kept_ids = [event.id for event in kept]
    assert kept_ids == [1, 2, 3, 8, 9]
|
||||
|
||||
|
||||
def test_short_history_no_truncation(controller_fixture):
    """A six-event history collapses to just its four essential events.

    Walk-through carried over from the original test comments:
      * history length 6; essentials = [sys(1), user(2), recall_act(3),
        recall_obs(4)]
      * non-essential count = 2, so max(1, 2 // 2) = 1 recent event is kept
      * recent slice = history[5:] = [obs(6)]
      * that lone leading observation is dropped -> []
      * expected ids: [1, 2, 3, 4]
    """
    controller, first_user_msg = controller_fixture

    specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {
            'type': MessageAction,
            'content': 'User Task 1',
            'source': EventSource.USER,
        },  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
        {'type': CmdRunAction, 'command': 'ls'},  # 5
        {
            'type': CmdOutputObservation,
            'content': 'file1',
            'command': 'ls',
            'cause_id': 5,
        },  # 6
    ]
    controller.state.history = create_events(specs)
    first_user_msg.id = 2

    kept = controller._apply_conversation_window()

    assert len(kept) == 4
    kept_ids = [event.id for event in kept]
    assert kept_ids == [1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_only_essential_events(controller_fixture):
    """A history made up solely of essential events is returned unchanged.

    Walk-through carried over from the original test comments:
      * history length 4, all four events are essentials
      * non-essential count = 0, so max(1, 0 // 2) = 1 recent event is kept
      * recent slice = history[3:] = [recall_obs(4)]
      * that leading observation is dropped from the slice -> []
      * expected ids: [1, 2, 3, 4]
    """
    controller, first_user_msg = controller_fixture

    specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {
            'type': MessageAction,
            'content': 'User Task 1',
            'source': EventSource.USER,
        },  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
    ]
    controller.state.history = create_events(specs)
    first_user_msg.id = 2

    kept = controller._apply_conversation_window()

    assert len(kept) == 4
    kept_ids = [event.id for event in kept]
    assert kept_ids == [1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_dangling_observations_at_cut_point(controller_fixture):
    """Observations with no cause, sitting right after the essentials, are discarded.

    Walk-through carried over from the original test comments:
      * history length 10; essentials = [sys(1), user(2), recall_act(3),
        recall_obs(4)]
      * non-essential count = 6, so max(1, 6 // 2) = 3 recent events are kept
      * recent slice = history[7:] = [obs1(8), cmd2(9), obs2(10)]
      * the leading observation (8) is dropped -> [cmd2(9), obs2(10)]
      * expected ids: [1, 2, 3, 4, 9, 10]
    """
    controller, first_user_msg = controller_fixture

    def dangling(output):
        # Observation spec with no 'cause_id', i.e. no action it belongs to.
        return {
            'type': CmdOutputObservation,
            'content': output,
            'command': 'cmd_unknown',
        }

    def command_pair(command, output, action_id):
        # A CmdRunAction spec followed by the observation it caused.
        return [
            {'type': CmdRunAction, 'command': command},
            {
                'type': CmdOutputObservation,
                'content': output,
                'command': command,
                'cause_id': action_id,
            },
        ]

    specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {
            'type': MessageAction,
            'content': 'User Task 1',
            'source': EventSource.USER,
        },  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
        # --- non-essential events begin here ---
        dangling('dangle1'),  # 5 (dangling)
        dangling('dangle2'),  # 6 (dangling)
        *command_pair('cmd1', 'obs1', 7),  # 7, 8
        *command_pair('cmd2', 'obs2', 9),  # 9, 10
    ]  # 10 events total
    controller.state.history = create_events(specs)
    first_user_msg.id = 2

    kept = controller._apply_conversation_window()

    assert len(kept) == 6
    kept_ids = [event.id for event in kept]
    # Ids 5 and 6 (the dangling observations) are absent from the result.
    assert kept_ids == [1, 2, 3, 4, 9, 10]
|
||||
|
||||
|
||||
def test_only_dangling_observations_in_recent_slice(controller_fixture):
    """When every recent event is a dangling observation, only essentials remain and a warning is logged.

    Walk-through carried over from the original test comments:
      * history length 6; essentials = [sys(1), user(2), recall_act(3),
        recall_obs(4)]
      * non-essential count = 2, so max(1, 2 // 2) = 1 recent event is kept
      * recent slice = history[5:] = [dangle2(6)]
      * that leading observation is dropped -> []
      * expected ids: [1, 2, 3, 4]
    """
    controller, first_user_msg = controller_fixture

    def dangling(output):
        # Observation spec with no 'cause_id', i.e. no action it belongs to.
        return {
            'type': CmdOutputObservation,
            'content': output,
            'command': 'cmd_unknown',
        }

    specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        {
            'type': MessageAction,
            'content': 'User Task 1',
            'source': EventSource.USER,
        },  # 2
        {'type': RecallAction, 'query': 'User Task 1'},  # 3
        {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3},  # 4
        # --- non-essential events begin here ---
        dangling('dangle1'),  # 5 (dangling)
        dangling('dangle2'),  # 6 (dangling)
    ]  # 6 events total
    controller.state.history = create_events(specs)
    first_user_msg.id = 2

    # Intercept the controller module's logger.warning for the duration of the call.
    warning_target = 'openhands.controller.agent_controller.logger.warning'
    with patch(warning_target) as warning_spy:
        kept = controller._apply_conversation_window()

    assert len(kept) == 4
    kept_ids = [event.id for event in kept]
    # Ids 5 and 6 (the dangling observations) are absent from the result.
    assert kept_ids == [1, 2, 3, 4]

    # Exactly one warning must have been emitted.
    assert warning_spy.call_count == 1

    # Check only the stable parts of the call, so extra keyword arguments
    # (e.g. stacklevel) do not break the test.
    positional, keyword = warning_spy.call_args
    expected_fragment = 'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.'
    assert expected_fragment in positional[0]
    assert 'extra' in keyword
    assert keyword['extra'].get('session_id') == 'test_sid'
|
||||
|
||||
|
||||
def test_empty_history(controller_fixture):
    """An empty history yields an empty result."""
    controller = controller_fixture[0]
    controller.state.history = []

    assert controller._apply_conversation_window() == []
|
||||
|
||||
|
||||
def test_multiple_user_messages(controller_fixture):
    """Only the first user message is treated as essential; a later one can be cut.

    Walk-through carried over from the original test comments:
      * history length 11; essentials = [sys(1), user1(2), recall_act1(3),
        recall_obs1(4)]
      * non-essential count = 7, so max(1, 7 // 2) = 3 recent events are kept
      * recent slice = history[8:] = [recall_obs2(9), cmd2(10), obs2(11)]
      * the leading observation (9) is dropped -> [cmd2(10), obs2(11)]
      * expected ids: [1, 2, 3, 4, 10, 11]
    """
    controller, first_user_msg = controller_fixture

    def user_turn(task, recall_result, recall_action_id):
        # User message, the recall it triggers, and the recall's observation.
        return [
            {
                'type': MessageAction,
                'content': task,
                'source': EventSource.USER,
            },
            {'type': RecallAction, 'query': task},
            {
                'type': RecallObservation,
                'content': recall_result,
                'cause_id': recall_action_id,
            },
        ]

    def command_pair(command, output, action_id):
        # A CmdRunAction spec followed by the observation it caused.
        return [
            {'type': CmdRunAction, 'command': command},
            {
                'type': CmdOutputObservation,
                'content': output,
                'command': command,
                'cause_id': action_id,
            },
        ]

    specs = [
        {'type': SystemMessageAction, 'content': 'System Prompt'},  # 1
        *user_turn('User Task 1', 'Recall result 1', 3),  # 2 (first user msg), 3, 4
        *command_pair('cmd1', 'obs1', 5),  # 5, 6
        *user_turn('User Task 2', 'Recall result 2', 8),  # 7 (second user msg), 8, 9
        *command_pair('cmd2', 'obs2', 10),  # 10, 11
    ]  # 11 events total
    controller.state.history = create_events(specs)
    first_user_msg.id = 2  # pin the first user message to event 2

    kept = controller._apply_conversation_window()

    assert len(kept) == 6
    kept_ids = [event.id for event in kept]
    assert kept_ids == [1, 2, 3, 4, 10, 11]

    # The second user message (id 7) is gone; the first (id 2) survives.
    assert 7 not in kept_ids
    assert 2 in kept_ids
|
||||
@@ -44,6 +44,7 @@ from openhands.events.observation.commands import (
|
||||
)
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.condenser import View
|
||||
|
||||
|
||||
@pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
|
||||
@@ -97,6 +98,12 @@ def test_reset(agent):
|
||||
action._source = EventSource.AGENT
|
||||
agent.pending_actions.append(action)
|
||||
|
||||
# Create a mock state with initial user message
|
||||
mock_state = Mock(spec=State)
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
mock_state.history = [initial_user_message]
|
||||
|
||||
# Reset
|
||||
agent.reset()
|
||||
|
||||
@@ -110,8 +117,14 @@ def test_step_with_pending_actions(agent):
|
||||
pending_action._source = EventSource.AGENT
|
||||
agent.pending_actions.append(pending_action)
|
||||
|
||||
# Create a mock state with initial user message
|
||||
mock_state = Mock(spec=State)
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
mock_state.history = [initial_user_message]
|
||||
|
||||
# Step should return the pending action
|
||||
result = agent.step(Mock())
|
||||
result = agent.step(mock_state)
|
||||
assert result == pending_action
|
||||
assert len(agent.pending_actions) == 0
|
||||
|
||||
@@ -260,6 +273,11 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
mock_state.latest_user_message_llm_metrics = None
|
||||
mock_state.latest_user_message_tool_call_metadata = None
|
||||
|
||||
# Add initial user message to history
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
mock_state.history = [initial_user_message]
|
||||
|
||||
action = agent.step(mock_state)
|
||||
assert isinstance(action, MessageAction)
|
||||
assert action.content == 'Task completed'
|
||||
@@ -330,42 +348,56 @@ def test_mismatched_tool_call_events_and_auto_add_system_message(
|
||||
)
|
||||
|
||||
action = CmdRunAction('foo')
|
||||
action._source = 'agent'
|
||||
action._source = EventSource.AGENT
|
||||
action.tool_call_metadata = tool_call_metadata
|
||||
|
||||
observation = CmdOutputObservation(content='', command_id=0, command='foo')
|
||||
observation.tool_call_metadata = tool_call_metadata
|
||||
|
||||
# Add initial user message
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
|
||||
# When both events are provided, the agent should get three messages:
|
||||
# 1. The system message (added automatically for backward compatibility)
|
||||
# 2. The action message
|
||||
# 3. The observation message
|
||||
mock_state.history = [action, observation]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 3
|
||||
mock_state.history = [initial_user_message, action, observation]
|
||||
messages = agent._get_messages(mock_state.history, initial_user_message)
|
||||
assert len(messages) == 4 # System + initial user + action + observation
|
||||
assert messages[0].role == 'system' # First message should be the system message
|
||||
assert messages[1].role == 'assistant' # Second message should be the action
|
||||
assert messages[2].role == 'tool' # Third message should be the observation
|
||||
assert (
|
||||
messages[1].role == 'user'
|
||||
) # Second message should be the initial user message
|
||||
assert messages[2].role == 'assistant' # Third message should be the action
|
||||
assert messages[3].role == 'tool' # Fourth message should be the observation
|
||||
|
||||
# The same should hold if the events are presented out-of-order
|
||||
mock_state.history = [observation, action]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 3
|
||||
mock_state.history = [initial_user_message, observation, action]
|
||||
messages = agent._get_messages(mock_state.history, initial_user_message)
|
||||
assert len(messages) == 4
|
||||
assert messages[0].role == 'system' # First message should be the system message
|
||||
assert (
|
||||
messages[1].role == 'user'
|
||||
) # Second message should be the initial user message
|
||||
|
||||
# If only one of the two events is present, then we should just get the system message
|
||||
# plus any valid message from the event
|
||||
mock_state.history = [action]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
mock_state.history = [initial_user_message, action]
|
||||
messages = agent._get_messages(mock_state.history, initial_user_message)
|
||||
assert (
|
||||
len(messages) == 1
|
||||
) # Only system message, action is waiting for its observation
|
||||
len(messages) == 2
|
||||
) # System + initial user message, action is waiting for its observation
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[1].role == 'user'
|
||||
|
||||
mock_state.history = [observation]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 1 # Only system message, observation has no matching action
|
||||
mock_state.history = [initial_user_message, observation]
|
||||
messages = agent._get_messages(mock_state.history, initial_user_message)
|
||||
assert (
|
||||
len(messages) == 2
|
||||
) # System + initial user message, observation has no matching action
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[1].role == 'user'
|
||||
|
||||
|
||||
def test_grep_tool():
|
||||
@@ -470,3 +502,19 @@ def test_get_system_message():
|
||||
assert len(result.tools) > 0
|
||||
assert any(tool['function']['name'] == 'execute_bash' for tool in result.tools)
|
||||
assert result._source == EventSource.AGENT
|
||||
|
||||
|
||||
def test_step_raises_error_if_no_initial_user_message(
    agent: CodeActAgent, mock_state: State
):
    """``agent.step`` must raise ``ValueError`` when the history holds no user message."""
    # History consists of a single agent-sourced message, so there is no
    # user MessageAction to find.
    agent_only_message = MessageAction(content='Assistant message')
    agent_only_message._source = EventSource.AGENT
    mock_state.history = [agent_only_message]

    # Replace the condenser with a stub that hands the history back untouched.
    passthrough_condenser = Mock()
    passthrough_condenser.condensed_history.return_value = View(
        events=mock_state.history
    )
    agent.condenser = passthrough_condenser

    with pytest.raises(ValueError, match='Initial user message not found'):
        agent.step(mock_state)
|
||||
|
||||
@@ -8,6 +8,7 @@ from openhands.core.cli_commands import (
|
||||
handle_help_command,
|
||||
handle_init_command,
|
||||
handle_new_command,
|
||||
handle_resume_command,
|
||||
handle_settings_command,
|
||||
handle_status_command,
|
||||
)
|
||||
@@ -461,3 +462,27 @@ class TestHandleSettingsCommand:
|
||||
# Verify correct behavior
|
||||
mock_display_settings.assert_called_once_with(config)
|
||||
mock_cli_confirm.assert_called_once()
|
||||
|
||||
|
||||
class TestHandleResumeCommand:
    """Tests for ``handle_resume_command``."""

    @pytest.mark.asyncio
    async def test_handle_resume_command(self):
        """Test that handle_resume_command adds the 'continue' message to the event stream."""
        stream = MagicMock(spec=EventStream)

        outcome = await handle_resume_command(stream)
        close_repl, new_session_requested = outcome

        # Exactly one event was queued, passed as two positional arguments:
        # the action and its source.
        stream.add_event.assert_called_once()
        (queued_action, queued_source), _ = stream.add_event.call_args

        assert isinstance(queued_action, MessageAction)
        assert queued_action.content == 'continue'
        assert queued_source == EventSource.USER

        # Expected flags: close the REPL, do not request a new session.
        assert close_repl is True
        assert new_session_requested is False
|
||||
|
||||
325
tests/unit/test_cli_pause_resume.py
Normal file
325
tests/unit/test_cli_pause_resume.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.keys import Keys
|
||||
|
||||
from openhands.core.cli_tui import process_agent_pause
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import ChangeAgentStateAction
|
||||
from openhands.events.observation import AgentStateChangedObservation
|
||||
from openhands.events.stream import EventStream
|
||||
|
||||
|
||||
class TestProcessAgentPause:
    """Tests for ``process_agent_pause``."""

    @pytest.mark.asyncio
    @patch('openhands.core.cli_tui.create_input')
    @patch('openhands.core.cli_tui.print_formatted_text')
    async def test_process_agent_pause_ctrl_p(self, mock_print, mock_create_input):
        """Test that process_agent_pause sets the done event when Ctrl+P is pressed."""
        done = asyncio.Event()

        # create_input() hands back our fake input object.
        fake_input = MagicMock()
        mock_create_input.return_value = fake_input

        # Context manager returned by fake_input.raw_mode().
        raw_mode_cm = MagicMock()
        raw_mode_cm.__enter__ = MagicMock()
        raw_mode_cm.__exit__ = MagicMock()
        fake_input.raw_mode.return_value = raw_mode_cm

        # Context manager returned by fake_input.attach(callback).
        attach_cm = MagicMock()
        attach_cm.__enter__ = MagicMock()
        attach_cm.__exit__ = MagicMock()
        fake_input.attach.return_value = attach_cm

        # Record the callback handed to attach() so the test can fire it.
        registered_callbacks = []

        def remember_callback(callback):
            registered_callbacks.append(callback)
            return attach_cm

        fake_input.attach.side_effect = remember_callback

        # Run the function under test in the background and give it a moment
        # to register its callback.
        pause_task = asyncio.create_task(process_agent_pause(done))
        await asyncio.sleep(0.1)

        keys_ready = registered_callbacks[-1] if registered_callbacks else None
        assert keys_ready is not None

        # The next read_keys() reports a single Ctrl+P key press.
        ctrl_p = MagicMock()
        ctrl_p.key = Keys.ControlP
        fake_input.read_keys.return_value = [ctrl_p]

        # Fire the callback by hand to simulate the key press.
        keys_ready()

        assert done.is_set()

        # Two prints are expected: an empty line, then the HTML pause notice.
        assert mock_print.call_count == 2
        assert mock_print.call_args_list[0] == call('')

        notice_args, _ = mock_print.call_args_list[1]
        notice = notice_args[0]
        assert isinstance(notice, HTML)
        assert 'Pausing the agent' in str(notice)

        # Stop the background task; cancellation is the expected way out.
        pause_task.cancel()
        try:
            await pause_task
        except asyncio.CancelledError:
            pass
|
||||
|
||||
|
||||
class TestCliPauseResumeInRunSession:
    """Checks the pause branch of an ``on_event_async``-style callback."""

    @pytest.mark.asyncio
    async def test_on_event_async_pause_processing(self):
        """A set ``is_paused`` flag produces one PAUSED action and is then cleared.

        The callback exercised here is defined locally to mimic the environment
        of ``on_event_async``; the display and metrics helpers are patched.
        """
        incoming = MagicMock()
        stream = MagicMock()
        settings = MagicMock()
        is_paused = asyncio.Event()
        reload_microagents = False

        with patch('openhands.core.cli.display_event') as shown:
            with patch('openhands.core.cli.update_usage_metrics') as metered:

                async def on_event_async_test(event):
                    nonlocal reload_microagents, is_paused
                    shown(event, settings)
                    metered(event, usage_metrics=MagicMock())

                    # Nothing more to do unless a pause was requested (Ctrl-P).
                    if not is_paused.is_set():
                        return
                    stream.add_event(
                        ChangeAgentStateAction(AgentState.PAUSED),
                        EventSource.USER,
                    )
                    is_paused.clear()

                # Request the pause, then deliver a single event.
                is_paused.set()
                await on_event_async_test(incoming)

                # Exactly one (action, source) pair must have been queued.
                stream.add_event.assert_called_once()
                positional = stream.add_event.call_args[0]
                action, source = positional

                assert isinstance(action, ChangeAgentStateAction)
                assert action.agent_state == AgentState.PAUSED
                assert source == EventSource.USER

                # The flag is reset once the pause has been handled.
                assert not is_paused.is_set()
|
||||
|
||||
|
||||
class TestCliCommandsPauseResume:
    """Dispatch of the ``/resume`` command through ``handle_commands``."""

    @pytest.mark.asyncio
    @patch('openhands.core.cli_commands.handle_resume_command')
    async def test_handle_commands_resume(self, mock_handle_resume):
        """``/resume`` is routed to ``handle_resume_command`` with the event stream."""
        # Imported here, after the patch decorator has been applied.
        from openhands.core.cli_commands import handle_commands

        stream = MagicMock(spec=EventStream)
        # Arguments in the positional order handle_commands is called with.
        call_args = (
            '/resume',
            stream,
            MagicMock(),  # usage_metrics
            'test-session-id',  # sid
            MagicMock(),  # config
            '/test/dir',  # current_dir
            MagicMock(),  # settings_store
        )

        # The patched resume handler reports a pair of False flags.
        mock_handle_resume.return_value = (False, False)

        outcome = await handle_commands(*call_args)
        close_repl, reload_microagents, new_session_requested = outcome

        # Only the event stream is forwarded to the resume handler.
        mock_handle_resume.assert_called_once_with(stream)

        assert close_repl is False
        assert reload_microagents is False
        assert new_session_requested is False
|
||||
|
||||
|
||||
class TestAgentStatePauseResume:
    """Agent-state driven pause/resume behaviour.

    Each test runs a simplified, locally defined stand-in for
    ``on_event_async`` rather than the function nested inside ``run_session``.
    """

    @pytest.mark.asyncio
    @patch('openhands.core.cli.display_agent_running_message')
    @patch('openhands.core.cli.process_agent_pause')
    async def test_agent_running_enables_pause(
        self, mock_process_agent_pause, mock_display_message
    ):
        """A RUNNING observation with confirmation mode off starts pause handling."""
        is_paused = asyncio.Event()
        loop = MagicMock()

        config = MagicMock()
        config.security.confirmation_mode = False

        running_event = MagicMock()
        # AgentStateChangedObservation requires a content argument.
        running_event.observation = AgentStateChangedObservation(
            agent_state=AgentState.RUNNING, content='Agent state changed to RUNNING'
        )

        async def on_event_async_test(event):
            observation = event.observation
            if not isinstance(observation, AgentStateChangedObservation):
                return
            if observation.agent_state != AgentState.RUNNING:
                return
            # Pause/resume is only offered when confirmation mode is disabled.
            if config.security.confirmation_mode:
                return
            mock_display_message()
            loop.create_task(mock_process_agent_pause(is_paused))

        await on_event_async_test(running_event)

        # The running banner was shown once.
        mock_display_message.assert_called_once()
        # The pause watcher received the shared flag ...
        mock_process_agent_pause.assert_called_once_with(is_paused)
        # ... and was scheduled on the loop.
        loop.create_task.assert_called_once()

    @pytest.mark.asyncio
    @patch('openhands.core.cli.display_event')
    @patch('openhands.core.cli.update_usage_metrics')
    async def test_pause_event_changes_agent_state(
        self, mock_update_metrics, mock_display_event
    ):
        """With ``is_paused`` set, a PAUSED state change is added to the stream."""
        incoming = MagicMock()
        stream = MagicMock()
        settings = MagicMock()
        reload_microagents = False

        is_paused = asyncio.Event()
        is_paused.set()

        async def on_event_async_test(event):
            nonlocal reload_microagents
            mock_display_event(event, settings)
            mock_update_metrics(event, MagicMock())

            # Nothing more to do unless a pause was requested (Ctrl-P).
            if not is_paused.is_set():
                return
            stream.add_event(
                ChangeAgentStateAction(AgentState.PAUSED),
                EventSource.USER,
            )
            is_paused.clear()

        await on_event_async_test(incoming)

        # Exactly one (action, source) pair must have been queued.
        stream.add_event.assert_called_once()
        positional = stream.add_event.call_args[0]
        action, source = positional

        assert isinstance(action, ChangeAgentStateAction)
        assert action.agent_state == AgentState.PAUSED
        assert source == EventSource.USER

        # The flag is reset once the pause has been handled.
        assert not is_paused.is_set()

    @pytest.mark.asyncio
    async def test_paused_agent_awaits_input(self):
        """A PAUSED observation reaches the branch that would prompt the user."""
        paused_event = MagicMock()
        # AgentStateChangedObservation requires a content argument.
        paused_event.observation = AgentStateChangedObservation(
            agent_state=AgentState.PAUSED, content='Agent state changed to PAUSED'
        )
        runtime = MagicMock()
        memory = MagicMock()
        reload_microagents = False
        prompt_task = MagicMock()

        # States that lead to the prompt branch in this stand-in.
        prompting_states = (
            AgentState.AWAITING_USER_INPUT,
            AgentState.FINISHED,
            AgentState.PAUSED,
        )

        async def on_event_async_test(event):
            nonlocal reload_microagents, prompt_task

            observation = event.observation
            if not isinstance(observation, AgentStateChangedObservation):
                return
            if observation.agent_state not in prompting_states:
                return

            # Reload microagents after initialization of repo.md
            if reload_microagents:
                microagents = runtime.get_microagents_from_selected_repo(None)
                memory.load_user_workspace_microagents(microagents)
                reload_microagents = False

            # prompt_for_next_task is a nested function in cli.py, so this
            # marker only records that the code path was reached.
            prompt_task = 'Prompt for next task would be called here'

        await on_event_async_test(paused_event)

        # The marker proves the prompt branch executed.
        assert prompt_task == 'Prompt for next task would be called here'
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from openhands.core.cli_tui import (
|
||||
@@ -52,12 +51,9 @@ class TestDisplayFunctions:
|
||||
|
||||
@patch('openhands.core.cli_tui.print_formatted_text')
|
||||
def test_display_banner(self, mock_print):
|
||||
# Create a mock loaded event
|
||||
is_loaded = asyncio.Event()
|
||||
is_loaded.set()
|
||||
session_id = 'test-session-id'
|
||||
|
||||
display_banner(session_id, is_loaded)
|
||||
display_banner(session_id)
|
||||
|
||||
# Verify banner calls
|
||||
assert mock_print.call_count >= 3
|
||||
|
||||
@@ -100,6 +100,7 @@ def test_process_events_with_message_action(conversation_memory):
|
||||
# Process events
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[system_message, user_message, assistant_message],
|
||||
initial_user_action=user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
@@ -108,10 +109,178 @@ def test_process_events_with_message_action(conversation_memory):
|
||||
assert len(messages) == 3
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System message'
|
||||
|
||||
|
||||
# Test cases for _ensure_system_message
|
||||
def test_ensure_system_message_adds_if_missing(conversation_memory):
    """A history without a system message gets one prepended in place."""
    lone_user_msg = MessageAction(content='User message')
    lone_user_msg._source = EventSource.USER
    history = [lone_user_msg]

    conversation_memory._ensure_system_message(history)

    assert len(history) == 2
    first, second = history
    assert isinstance(first, SystemMessageAction)
    assert first.content == 'System message'  # From fixture
    # The original event is still present, now in second position.
    assert isinstance(second, MessageAction)
|
||||
|
||||
|
||||
def test_ensure_system_message_does_nothing_if_present(conversation_memory):
    """A history that already starts with a system message is left untouched."""
    existing_system = SystemMessageAction(content='Existing system message')
    user_msg = MessageAction(content='User message')
    user_msg._source = EventSource.USER

    history = [existing_system, user_msg]
    # Shallow copy taken before the call so the two can be compared afterwards.
    snapshot = history[:]

    conversation_memory._ensure_system_message(history)

    assert history == snapshot
|
||||
|
||||
|
||||
# Test cases for _ensure_initial_user_message
|
||||
@pytest.fixture
def initial_user_action():
    """Provide a user-sourced ``MessageAction`` used as the initial user message."""
    action = MessageAction(content='Initial User Message')
    action._source = EventSource.USER
    return action
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_adds_if_only_system(
    conversation_memory, initial_user_action
):
    """With only a system message present, the initial user message is appended."""
    sys_msg = SystemMessageAction(content='System')
    sys_msg._source = EventSource.AGENT
    history = [sys_msg]

    conversation_memory._ensure_initial_user_message(history, initial_user_action)

    assert len(history) == 2
    assert history[0] == sys_msg
    assert history[1] == initial_user_action
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_correct_already_present(
    conversation_memory, initial_user_action
):
    """Nothing changes when the initial user message already sits at index 1."""
    sys_msg = SystemMessageAction(content='System')
    trailing_msg = MessageAction(content='Assistant')
    # NOTE(review): the content says 'Assistant' but the source is USER;
    # kept as in the original - confirm whether AGENT was intended.
    trailing_msg._source = EventSource.USER

    history = [sys_msg, initial_user_action, trailing_msg]
    snapshot = history[:]

    conversation_memory._ensure_initial_user_message(history, initial_user_action)

    assert history == snapshot
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_incorrect_at_index_1(
    conversation_memory, initial_user_action
):
    """An agent message at index 1 is pushed back by the inserted initial user message."""
    sys_msg = SystemMessageAction(content='System')
    agent_msg = MessageAction(content='Assistant')
    agent_msg._source = EventSource.AGENT
    history = [sys_msg, agent_msg]

    conversation_memory._ensure_initial_user_message(history, initial_user_action)

    assert len(history) == 3
    assert history[0] == sys_msg
    assert history[1] == initial_user_action  # Correct one inserted
    assert history[2] == agent_msg  # Original second message shifted
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_correct_present_later(
    conversation_memory, initial_user_action
):
    """Test inserting the correct initial user message at index 1 even if it exists later.

    NOTE(review): as written, the history never contains ``initial_user_action``
    before the call, so the "exists later" scenario named above is not actually
    set up - confirm whether a trailing copy was meant to be added.
    """
    sys_msg = SystemMessageAction(content='System')
    agent_msg = MessageAction(content='Assistant')
    agent_msg._source = EventSource.AGENT
    history = [sys_msg, agent_msg]

    # Both helpers run back to back on the same list.
    conversation_memory._ensure_system_message(history)
    conversation_memory._ensure_initial_user_message(history, initial_user_action)

    assert len(history) == 3
    assert history[0] == sys_msg
    assert history[1] == initial_user_action  # Correct one inserted at index 1
    assert history[2] == agent_msg  # Original second message shifted
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_different_user_msg_at_index_1(
    conversation_memory, initial_user_action
):
    """A *different* user message already at index 1 is kept; nothing is inserted."""
    sys_msg = SystemMessageAction(content='System')
    other_user_msg = MessageAction(content='Different User Message')
    other_user_msg._source = EventSource.USER
    history = [sys_msg, other_user_msg]

    conversation_memory._ensure_initial_user_message(history, initial_user_action)

    assert len(history) == 2
    assert history[0] == sys_msg
    assert history[1] == other_user_msg  # Original second message remains
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_different_user_msg_at_index_1_and_orphaned_obs(
    conversation_memory, initial_user_action
):
    """
    Test process_events when a different user message is at index 1 AND
    an orphaned observation (with tool_call_metadata but no matching action) exists.

    Expect: System msg followed by the different user msg, left in place at index 1.
    The orphaned observation should be filtered out.
    """
    system_message = SystemMessageAction(content='System')
    different_user_message = MessageAction(content='Different User Message')
    different_user_message._source = EventSource.USER

    # Create an orphaned observation (no matching action/tool call request will exist)
    # Use a dictionary that mimics ModelResponse structure to satisfy Pydantic
    mock_response = {
        'id': 'mock_response_id',
        'choices': [{'message': {'content': None, 'tool_calls': []}}],
        'created': 0,
        'model': '',
        'object': '',
        'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0},
    }
    orphaned_obs = CmdOutputObservation(
        command='orphan_cmd',
        content='Orphaned output',
        command_id=99,
        exit_code=0,
    )
    orphaned_obs.tool_call_metadata = ToolCallMetadata(
        tool_call_id='orphan_call_id',
        function_name='execute_bash',
        model_response=mock_response,
        total_calls_in_response=1,
    )

    # Initial events list: system, different user message, orphaned observation
    events = [system_message, different_user_message, orphaned_obs]

    # Call the main process_events method
    messages = conversation_memory.process_events(
        condensed_history=events,
        initial_user_action=initial_user_action,  # Provide the *correct* initial action
        max_message_chars=None,
        vision_is_active=False,
    )

    # Assertions on the final messages list.
    # The length check also implicitly asserts that orphaned_obs was filtered out.
    assert len(messages) == 2

    # 1. System message should be first
    assert messages[0].role == 'system'
    assert messages[0].content[0].text == 'System'

    # 2. The different user message should be left at index 1
    assert messages[1].role == 'user'
    assert messages[1].content[0].text == different_user_message.content
|
||||
|
||||
|
||||
def test_process_events_with_cmd_output_observation(conversation_memory):
|
||||
@@ -125,14 +294,17 @@ def test_process_events_with_cmd_output_observation(conversation_memory):
|
||||
),
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -148,14 +320,17 @@ def test_process_events_with_ipython_run_cell_observation(conversation_memory):
|
||||
content='IPython output\n',
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -172,14 +347,17 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
|
||||
content='Content', outputs={'content': 'Delegated agent output'}
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -189,14 +367,17 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
|
||||
def test_process_events_with_error_observation(conversation_memory):
|
||||
obs = ErrorObservation('Error message')
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -207,10 +388,13 @@ def test_process_events_with_error_observation(conversation_memory):
|
||||
def test_process_events_with_unknown_observation(conversation_memory):
|
||||
# Create a mock that inherits from Event but not Action or Observation
|
||||
obs = Mock(spec=Event)
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
|
||||
with pytest.raises(ValueError, match='Unknown event type'):
|
||||
conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
@@ -226,14 +410,17 @@ def test_process_events_with_file_edit_observation(conversation_memory):
|
||||
impl_source=FileEditSource.LLM_BASED_EDIT,
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -247,18 +434,21 @@ def test_process_events_with_file_read_observation(conversation_memory):
|
||||
impl_source=FileReadSource.DEFAULT,
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == 'File content'
|
||||
assert result.content[0].text == '\n\nFile content'
|
||||
|
||||
|
||||
def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
@@ -270,14 +460,17 @@ def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
error=False,
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -287,14 +480,17 @@ def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
def test_process_events_with_user_reject_observation(conversation_memory):
|
||||
obs = UserRejectObservation('Action rejected')
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -317,14 +513,17 @@ def test_process_events_with_empty_environment_info(conversation_memory):
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[empty_obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Should only contain no messages except system message
|
||||
assert len(messages) == 1
|
||||
# Should only contain system message and initial user message
|
||||
assert len(messages) == 2
|
||||
|
||||
# Verify that build_workspace_context was NOT called since all input values were empty
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
@@ -348,14 +547,20 @@ def test_process_events_with_function_calling_observation(conversation_memory):
|
||||
model_response=mock_response,
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# No direct message when using function calling
|
||||
assert len(messages) == 1 # should be no messages except system message
|
||||
assert (
|
||||
len(messages) == 2
|
||||
) # should be no messages except system message and initial user message
|
||||
|
||||
|
||||
def test_process_events_with_message_action_with_image(conversation_memory):
|
||||
@@ -365,14 +570,18 @@ def test_process_events_with_message_action_with_image(conversation_memory):
|
||||
)
|
||||
action._source = EventSource.AGENT
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=True,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3
|
||||
result = messages[2]
|
||||
assert result.role == 'assistant'
|
||||
assert len(result.content) == 2
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -385,14 +594,18 @@ def test_process_events_with_user_cmd_action(conversation_memory):
|
||||
action = CmdRunAction(command='ls -l')
|
||||
action._source = EventSource.USER
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3
|
||||
result = messages[2]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -418,14 +631,18 @@ def test_process_events_with_agent_finish_action_with_tool_metadata(
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3
|
||||
result = messages[2]
|
||||
assert result.role == 'assistant'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -461,18 +678,22 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3
|
||||
result = messages[2]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == 'Formatted repository and runtime info'
|
||||
assert result.content[0].text == '\n\nFormatted repository and runtime info'
|
||||
|
||||
# Verify the prompt_manager was called with the correct parameters
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_called_once()
|
||||
@@ -516,14 +737,18 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + Initial User + Result
|
||||
result = messages[2] # Result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -559,14 +784,18 @@ def test_process_events_with_microagent_observation_extensions_disabled(
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# When prompt extensions are disabled, the RecallObservation should be ignored
|
||||
assert len(messages) == 1 # should be no messages except system message
|
||||
assert len(messages) == 2 # System + Initial User
|
||||
|
||||
# Verify the prompt_manager was not called
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
@@ -581,14 +810,18 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory):
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# The implementation returns an empty string and doesn't create a message
|
||||
assert len(messages) == 1 # should be no messages except system message
|
||||
assert len(messages) == 2 # System + Initial User
|
||||
|
||||
# When there are no triggered agents, build_microagent_info is not called
|
||||
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
|
||||
@@ -793,19 +1026,23 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
|
||||
content='Third retrieval',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs1, obs2, obs3],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that only the first occurrence of content for each agent is included
|
||||
assert len(messages) == 2 # with system message
|
||||
|
||||
assert len(messages) == 3 # System + Initial User + Result
|
||||
# Result is now at index 2
|
||||
# First microagent should include all agents since they appear here first
|
||||
assert 'Image best practices v1' in messages[1].content[0].text
|
||||
assert 'Git best practices v1' in messages[1].content[0].text
|
||||
assert 'Python best practices v1' in messages[1].content[0].text
|
||||
assert 'Image best practices v1' in messages[2].content[0].text
|
||||
assert 'Git best practices v1' in messages[2].content[0].text
|
||||
assert 'Python best practices v1' in messages[2].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
|
||||
@@ -842,18 +1079,22 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
|
||||
content='Second retrieval',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs1, obs2],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included
|
||||
assert len(messages) == 2
|
||||
|
||||
assert len(messages) == 3 # System + Initial User + Result
|
||||
# Result is now at index 2
|
||||
# First microagent should include enabled_agent but not disabled_agent
|
||||
assert 'Disabled agent content' not in messages[1].content[0].text
|
||||
assert 'Enabled agent content v1' in messages[1].content[0].text
|
||||
assert 'Disabled agent content' not in messages[2].content[0].text
|
||||
assert 'Enabled agent content v1' in messages[2].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
@@ -866,17 +1107,22 @@ def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
content='Empty retrieval',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that empty RecallObservations are handled gracefully
|
||||
assert (
|
||||
len(messages) == 1
|
||||
) # an empty microagent is not added to Messages, only system message is found
|
||||
len(messages) == 2 # System + Initial User
|
||||
) # an empty microagent is not added to Messages
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[1].role == 'user' # Initial user message
|
||||
|
||||
|
||||
def test_has_agent_in_earlier_events(conversation_memory):
|
||||
@@ -1088,13 +1334,183 @@ def test_system_message_in_events(conversation_memory):
|
||||
system_message._source = EventSource.AGENT
|
||||
|
||||
# Process events with the system message in condensed_history
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[system_message],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Check that the system message was processed correctly
|
||||
assert len(messages) == 1
|
||||
assert len(messages) == 2 # System + Initial User
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System message'
|
||||
assert messages[1].role == 'user' # Initial user message
|
||||
|
||||
|
||||
# Helper function to create mock tool call metadata
|
||||
def _create_mock_tool_call_metadata(
    tool_call_id: str, function_name: str, response_id: str = 'mock_response_id'
) -> ToolCallMetadata:
    """Build a ToolCallMetadata backed by a dict-shaped mock model response.

    The response is a plain dictionary laid out like a ModelResponse (so that
    Pydantic accepts it) and contains a single assistant message carrying
    exactly one function tool call.

    Args:
        tool_call_id: Id stored both on the metadata and on the tool call
            inside the mock response.
        function_name: Name stored both on the metadata and on the tool call's
            function entry.
        response_id: Value used for the mock response's 'id' field.

    Returns:
        A ToolCallMetadata with total_calls_in_response set to 1.
    """
    # The arguments payload is an empty JSON object; its value is irrelevant here.
    tool_call = {
        'id': tool_call_id,
        'type': 'function',
        'function': {'name': function_name, 'arguments': '{}'},
    }
    # Content is None because this assistant turn consists only of a tool call.
    assistant_message = {
        'role': 'assistant',
        'content': None,
        'tool_calls': [tool_call],
    }
    token_usage = {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}
    response_payload = {
        'id': response_id,
        'choices': [{'message': assistant_message}],
        'created': 0,
        'model': 'mock_model',
        'object': 'chat.completion',
        'usage': token_usage,
    }
    return ToolCallMetadata(
        tool_call_id=tool_call_id,
        function_name=function_name,
        model_response=response_payload,
        total_calls_in_response=1,
    )
|
||||
|
||||
|
||||
def test_process_events_partial_history(conversation_memory):
    """
    Run process_events on one complete history and two truncated ones.

    Checks that a system message and the initial user message are present at
    the front of the result in every case (_ensure_system_message,
    _ensure_initial_user_message), and that a tool response is only kept when
    the assistant tool call it answers is part of the processed history.
    """
    # --- Shared events ---
    system_message = SystemMessageAction(content='System message')
    system_message._source = EventSource.AGENT

    # Also handed to every process_events call as initial_user_action.
    user_message = MessageAction(content='Initial user query')
    user_message._source = EventSource.USER

    recall_obs = RecallObservation(
        recall_type=RecallType.WORKSPACE_CONTEXT,
        repo_name='test-repo',
        repo_directory='/path/to/repo',
        content='Retrieved environment info',
    )
    recall_obs._source = EventSource.AGENT

    # The action and its observation carry the same tool call id so they pair up.
    ls_call = {
        'tool_call_id': 'call_ls_1',
        'function_name': 'execute_bash',
        'response_id': 'resp_ls_1',
    }

    cmd_action = CmdRunAction(command='ls', thought='Running ls')
    cmd_action._source = EventSource.AGENT
    cmd_action.tool_call_metadata = _create_mock_tool_call_metadata(**ls_call)

    cmd_obs = CmdOutputObservation(
        command_id=1, command='ls', content='file1.txt\nfile2.py', exit_code=0
    )
    cmd_obs._source = EventSource.AGENT
    cmd_obs.tool_call_metadata = _create_mock_tool_call_metadata(**ls_call)

    def run(history: list[Event]):
        # Hand over a copy so the list defined by the scenario is left untouched.
        return conversation_memory.process_events(
            condensed_history=list(history),
            initial_user_action=user_message,
            max_message_chars=None,
            vision_is_active=False,
        )

    # --- Scenario 1: complete history ---
    # user_message sits at index 1, i.e. it is the correct initial user message.
    full_history: list[Event] = [
        system_message,
        user_message,
        recall_obs,
        cmd_action,
        cmd_obs,
    ]
    full_msgs = run(full_history)

    # Expected: System, User, Recall (formatted), Assistant (tool call), Tool Response
    assert len(full_msgs) == 5
    sys_msg, first_user_msg, recall_msg, assistant_msg, tool_msg = full_msgs
    assert sys_msg.role == 'system'
    assert sys_msg.content[0].text == 'System message'
    assert first_user_msg.role == 'user'
    assert first_user_msg.content[0].text == 'Initial user query'
    # The recall observation is rendered as a user message (text from the fixture mock).
    assert recall_msg.role == 'user'
    assert 'Formatted repository and runtime info' in recall_msg.content[0].text
    assert assistant_msg.role == 'assistant'
    assert assistant_msg.tool_calls is not None
    assert len(assistant_msg.tool_calls) == 1
    assert assistant_msg.tool_calls[0].id == 'call_ls_1'
    assert tool_msg.role == 'tool'
    assert tool_msg.tool_call_id == 'call_ls_1'
    assert 'file1.txt' in tool_msg.content[0].text

    # --- Scenario 2: only the last action/observation pair ---
    pair_history: list[Event] = [cmd_action, cmd_obs]
    pair_msgs = run(pair_history)

    # Expected: System (added), Initial User (added), Assistant (tool call), Tool Response
    assert len(pair_msgs) == 4
    sys_msg, first_user_msg, assistant_msg, tool_msg = pair_msgs
    # Added by _ensure_system_message
    assert sys_msg.role == 'system'
    assert sys_msg.content[0].text == 'System message'
    # Added by _ensure_initial_user_message
    assert first_user_msg.role == 'user'
    assert first_user_msg.content[0].text == 'Initial user query'
    assert assistant_msg.role == 'assistant'
    assert assistant_msg.tool_calls is not None
    assert len(assistant_msg.tool_calls) == 1
    assert assistant_msg.tool_calls[0].id == 'call_ls_1'
    assert tool_msg.role == 'tool'
    assert tool_msg.tool_call_id == 'call_ls_1'
    assert 'file1.txt' in tool_msg.content[0].text

    # --- Scenario 3: only the last observation ---
    obs_only_history: list[Event] = [cmd_obs]
    obs_only_msgs = run(obs_only_history)

    # Expected: System (added), Initial User (added).
    # The observation has tool_call_metadata, but the input history holds no
    # assistant message (from CmdRunAction) with the matching tool call id, so
    # _filter_unmatched_tool_calls should drop the tool response message.
    assert len(obs_only_msgs) == 2
    sys_msg, first_user_msg = obs_only_msgs
    # Added by _ensure_system_message
    assert sys_msg.role == 'system'
    assert sys_msg.content[0].text == 'System message'
    # Added by _ensure_initial_user_message
    assert first_user_msg.role == 'user'
    assert first_user_msg.content[0].text == 'Initial user query'
|
||||
|
||||
@@ -76,7 +76,7 @@ def test_get_messages(codeact_agent: CodeActAgent):
|
||||
history.append(message_action_5)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(history)
|
||||
messages = codeact_agent._get_messages(history, message_action_1)
|
||||
|
||||
assert (
|
||||
len(messages) == 6
|
||||
@@ -106,16 +106,19 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
history.append(system_message_action)
|
||||
|
||||
# Add multiple user and agent messages
|
||||
initial_user_message = None # Keep track of the first user message
|
||||
for i in range(15):
|
||||
message_action_user = MessageAction(f'User message {i}')
|
||||
message_action_user._source = 'user'
|
||||
if initial_user_message is None:
|
||||
initial_user_message = message_action_user # Store the first one
|
||||
history.append(message_action_user)
|
||||
message_action_agent = MessageAction(f'Agent message {i}')
|
||||
message_action_agent._source = 'agent'
|
||||
history.append(message_action_agent)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(history)
|
||||
messages = codeact_agent._get_messages(history, initial_user_message)
|
||||
|
||||
# Check that only the last two user messages have cache_prompt=True
|
||||
cached_user_messages = [
|
||||
|
||||
Reference in New Issue
Block a user