Compare commits

..

11 Commits

Author SHA1 Message Date
rohitvinodmalhotra@gmail.com 2399174e89 fix settings load 2025-04-28 19:09:27 -04:00
rohitvinodmalhotra@gmail.com 7c3f4891f8 render for saas 2025-04-28 18:02:56 -04:00
rohitvinodmalhotra@gmail.com 49bb7bbaba add base_domain to fe 2025-04-28 17:50:30 -04:00
openhands 10c1252cfe Add base domain fields for GitHub and GitLab in git settings 2025-04-28 20:14:55 +00:00
rohitvinodmalhotra@gmail.com 911867492c fix providertype comp 2025-04-28 16:07:03 -04:00
rohitvinodmalhotra@gmail.com 85a1b47c8d modify new git page 2025-04-28 16:03:58 -04:00
rohitvinodmalhotra@gmail.com d6011829a3 merge main 2025-04-28 16:02:54 -04:00
rohitvinodmalhotra@gmail.com 9200e1dbd8 fix providertype comparisions 2025-04-28 15:21:03 -04:00
rohitvinodmalhotra@gmail.com d1343539ba update save for provider token 2025-04-28 14:30:44 -04:00
rohitvinodmalhotra@gmail.com 8bc206833a restructure fe type 2025-04-28 14:08:18 -04:00
rohitvinodmalhotra@gmail.com 7cf61d8c0e add base domain param to provider token 2025-04-28 13:29:06 -04:00
46 changed files with 846 additions and 2120 deletions
@@ -91,13 +91,6 @@ describe("HomeScreen", () => {
screen.getByTestId("task-suggestions");
});
it("should have responsive layout for mobile and desktop screens", async () => {
renderHomeScreen();
const mainContainer = screen.getByTestId("home-screen").querySelector("main");
expect(mainContainer).toHaveClass("flex", "flex-col", "md:flex-row");
});
it("should filter the suggested tasks based on the selected repository", async () => {
const retrieveUserGitRepositoriesSpy = vi.spyOn(
GitService,
@@ -6,38 +6,55 @@ import { KeyStatusIcon } from "../key-status-icon";
interface GitHubTokenInputProps {
onChange: (value: string) => void;
onBaseDomainChange?: (value: string) => void;
isGitHubTokenSet: boolean;
name: string;
baseDomainSet?: string | null;
isSaas: boolean;
}
export function GitHubTokenInput({
onChange,
onBaseDomainChange,
isGitHubTokenSet,
name,
baseDomainSet,
isSaas,
}: GitHubTokenInputProps) {
const { t } = useTranslation();
return (
<div className="flex flex-col gap-6">
{!isSaas && (
<SettingsInput
testId={name}
name={name}
onChange={onChange}
label={t(I18nKey.GITHUB$TOKEN_LABEL)}
type="password"
className="w-[680px]"
placeholder={isGitHubTokenSet ? "<hidden>" : ""}
startContent={
isGitHubTokenSet && (
<KeyStatusIcon
testId="gh-set-token-indicator"
isSet={isGitHubTokenSet}
/>
)
}
/>
)}
<SettingsInput
testId={name}
name={name}
onChange={onChange}
label={t(I18nKey.GITHUB$TOKEN_LABEL)}
type="password"
onChange={onBaseDomainChange || (() => {})}
label={t(I18nKey.GITHUB$BASE_DOMAIN_LABEL)}
type="text"
className="w-[680px]"
placeholder={isGitHubTokenSet ? "<hidden>" : ""}
startContent={
isGitHubTokenSet && (
<KeyStatusIcon
testId="gh-set-token-indicator"
isSet={isGitHubTokenSet}
/>
)
}
placeholder={"github.com"}
defaultValue={baseDomainSet ? baseDomainSet : undefined}
/>
<GitHubTokenHelpAnchor />
{!isSaas && <GitHubTokenHelpAnchor />}
</div>
);
}
@@ -6,38 +6,55 @@ import { KeyStatusIcon } from "../key-status-icon";
interface GitLabTokenInputProps {
onChange: (value: string) => void;
onBaseDomainChange?: (value: string) => void;
isGitLabTokenSet: boolean;
name: string;
baseDomainSet?: string | null;
isSaas: boolean;
}
export function GitLabTokenInput({
onChange,
onBaseDomainChange,
isGitLabTokenSet,
name,
baseDomainSet,
isSaas,
}: GitLabTokenInputProps) {
const { t } = useTranslation();
return (
<div className="flex flex-col gap-6">
{!isSaas && (
<SettingsInput
testId={name}
name={name}
onChange={onChange}
label={t(I18nKey.GITLAB$TOKEN_LABEL)}
type="password"
className="w-[680px]"
placeholder={isGitLabTokenSet ? "<hidden>" : ""}
startContent={
isGitLabTokenSet && (
<KeyStatusIcon
testId="gl-set-token-indicator"
isSet={isGitLabTokenSet}
/>
)
}
/>
)}
<SettingsInput
testId={name}
name={name}
onChange={onChange}
label={t(I18nKey.GITLAB$TOKEN_LABEL)}
type="password"
onChange={onBaseDomainChange || (() => {})}
label={t(I18nKey.GITLAB$BASE_DOMAIN_LABEL)}
type="text"
className="w-[680px]"
placeholder={isGitLabTokenSet ? "<hidden>" : ""}
startContent={
isGitLabTokenSet && (
<KeyStatusIcon
testId="gl-set-token-indicator"
isSet={isGitLabTokenSet}
/>
)
}
placeholder={"gitlab.com"}
defaultValue={baseDomainSet ? baseDomainSet : undefined}
/>
<GitLabTokenHelpAnchor />
{!isSaas && <GitLabTokenHelpAnchor />}
</div>
);
}
+4 -6
View File
@@ -54,13 +54,11 @@ export const useSettings = () => {
React.useEffect(() => {
if (query.data?.PROVIDER_TOKENS_SET) {
const providers = query.data.PROVIDER_TOKENS_SET;
const setProviders = (
Object.keys(providers) as Array<keyof typeof providers>
).filter((key) => providers[key]);
const setProviders = Object.keys(providers) as Array<
keyof typeof providers
>;
setProviderTokensSet(setProviders);
const atLeastOneSet = Object.values(query.data.PROVIDER_TOKENS_SET).some(
(value) => value,
);
const atLeastOneSet = setProviders.length > 0;
setProvidersAreSet(atLeastOneSet);
}
}, [query.data?.PROVIDER_TOKENS_SET, query.isFetched]);
+2
View File
@@ -104,6 +104,7 @@ export enum I18nKey {
EXIT_PROJECT$TITLE = "EXIT_PROJECT$TITLE",
LANGUAGE$LABEL = "LANGUAGE$LABEL",
GITHUB$TOKEN_LABEL = "GITHUB$TOKEN_LABEL",
GITHUB$BASE_DOMAIN_LABEL = "GITHUB$BASE_DOMAIN_LABEL",
GITHUB$TOKEN_OPTIONAL = "GITHUB$TOKEN_OPTIONAL",
GITHUB$GET_TOKEN = "GITHUB$GET_TOKEN",
GITHUB$TOKEN_HELP_TEXT = "GITHUB$TOKEN_HELP_TEXT",
@@ -450,6 +451,7 @@ export enum I18nKey {
MODEL_SELECTOR$VERIFIED = "MODEL_SELECTOR$VERIFIED",
MODEL_SELECTOR$OTHERS = "MODEL_SELECTOR$OTHERS",
GITLAB$TOKEN_LABEL = "GITLAB$TOKEN_LABEL",
GITLAB$BASE_DOMAIN_LABEL = "GITLAB$BASE_DOMAIN_LABEL",
GITLAB$GET_TOKEN = "GITLAB$GET_TOKEN",
GITLAB$TOKEN_HELP_TEXT = "GITLAB$TOKEN_HELP_TEXT",
GITLAB$TOKEN_LINK_TEXT = "GITLAB$TOKEN_LINK_TEXT",
+30
View File
@@ -1569,6 +1569,21 @@
"tr": "GitHub Jetonu",
"de": "GitHub-Token"
},
"GITHUB$BASE_DOMAIN_LABEL": {
"en": "GitHub Base Domain",
"ja": "GitHub ベースドメイン",
"zh-CN": "GitHub 基础域名",
"zh-TW": "GitHub 基礎網域",
"ko-KR": "GitHub 기본 도메인",
"no": "GitHub Base Domain",
"it": "Dominio Base GitHub",
"pt": "Domínio Base do GitHub",
"es": "Dominio Base de GitHub",
"ar": "نطاق GitHub الأساسي",
"fr": "Domaine de Base GitHub",
"tr": "GitHub Temel Alan Adı",
"de": "GitHub Basis-Domain"
},
"GITHUB$TOKEN_OPTIONAL": {
"en": "GitHub Token (Optional)",
"ja": "GitHubトークン(任意)",
@@ -6469,6 +6484,21 @@
"tr": "GitLab Jetonu",
"de": "GitLab-Token"
},
"GITLAB$BASE_DOMAIN_LABEL": {
"en": "GitLab Base Domain",
"ja": "GitLab ベースドメイン",
"zh-CN": "GitLab 基础域名",
"zh-TW": "GitLab 基礎網域",
"ko-KR": "GitLab 기본 도메인",
"no": "GitLab Base Domain",
"it": "Dominio Base GitLab",
"pt": "Domínio Base do GitLab",
"es": "Dominio Base de GitLab",
"ar": "نطاق GitLab الأساسي",
"fr": "Domaine de Base GitLab",
"tr": "GitLab Temel Alan Adı",
"de": "GitLab Basis-Domain"
},
"GITLAB$GET_TOKEN": {
"en": "Generate a token on",
"ja": "トークンを生成する",
+42 -6
View File
@@ -15,11 +15,13 @@ import {
} from "#/utils/custom-toast-handlers";
import { retrieveAxiosErrorMessage } from "#/utils/retrieve-axios-error-message";
import { GitSettingInputsSkeleton } from "#/components/features/settings/git-settings/github-settings-inputs-skeleton";
import { useAuth } from "#/context/auth-context";
function GitSettingsScreen() {
const { t } = useTranslation();
const { mutate: saveSettings, isPending } = useSaveSettings();
const { providerTokensSet } = useAuth();
const { mutate: disconnectGitTokens } = useLogout();
const { data: settings, isLoading } = useSettings();
@@ -29,10 +31,17 @@ function GitSettingsScreen() {
React.useState(false);
const [gitlabTokenInputHasValue, setGitlabTokenInputHasValue] =
React.useState(false);
const [githubBaseDomainInputHasValue, setGithubBaseDomainInputHasValue] =
React.useState(false);
const [gitlabBaseDomainInputHasValue, setGitlabBaseDomainInputHasValue] =
React.useState(false);
const isSaas = config?.APP_MODE === "saas";
const isGitHubTokenSet = !!settings?.PROVIDER_TOKENS_SET.github;
const isGitLabTokenSet = !!settings?.PROVIDER_TOKENS_SET.gitlab;
const isGitHubTokenSet = providerTokensSet.includes("github");
const isGitLabTokenSet = providerTokensSet.includes("gitlab");
const existingGithubBaseDomain = settings?.PROVIDER_TOKENS_SET["github"];
const existingGitlabBaseDomain = settings?.PROVIDER_TOKENS_SET["gitlab"];
const formAction = async (formData: FormData) => {
const disconnectButtonClicked =
@@ -45,12 +54,22 @@ function GitSettingsScreen() {
const githubToken = formData.get("github-token-input")?.toString() || "";
const gitlabToken = formData.get("gitlab-token-input")?.toString() || "";
const githubBaseDomain =
formData.get("github-base-domain-input")?.toString() || "";
const gitlabBaseDomain =
formData.get("gitlab-base-domain-input")?.toString() || "";
saveSettings(
{
provider_tokens: {
github: githubToken,
gitlab: gitlabToken,
github: {
token: githubToken,
base_domain: githubBaseDomain || null,
},
gitlab: {
token: gitlabToken,
base_domain: gitlabBaseDomain || null,
},
},
},
{
@@ -64,12 +83,19 @@ function GitSettingsScreen() {
onSettled: () => {
setGithubTokenInputHasValue(false);
setGitlabTokenInputHasValue(false);
setGithubBaseDomainInputHasValue(false);
setGitlabBaseDomainInputHasValue(false);
},
},
);
};
const formIsClean = !githubTokenInputHasValue && !gitlabTokenInputHasValue;
const formIsClean =
!githubTokenInputHasValue &&
!gitlabTokenInputHasValue &&
!githubBaseDomainInputHasValue &&
!gitlabBaseDomainInputHasValue;
const shouldRenderExternalConfigureButtons = isSaas && config.APP_SLUG;
return (
@@ -84,22 +110,32 @@ function GitSettingsScreen() {
<ConfigureGitHubRepositoriesAnchor slug={config.APP_SLUG!} />
)}
{!isSaas && !isLoading && (
{!isLoading && (
<div className="p-9 flex flex-col gap-12">
<GitHubTokenInput
name="github-token-input"
baseDomainSet={existingGithubBaseDomain}
isGitHubTokenSet={isGitHubTokenSet}
onChange={(value) => {
setGithubTokenInputHasValue(!!value);
}}
onBaseDomainChange={(value) => {
setGithubBaseDomainInputHasValue(!!value);
}}
isSaas={isSaas}
/>
<GitLabTokenInput
name="gitlab-token-input"
baseDomainSet={existingGitlabBaseDomain}
isGitLabTokenSet={isGitLabTokenSet}
onChange={(value) => {
setGitlabTokenInputHasValue(!!value);
}}
onBaseDomainChange={(value) => {
setGitlabBaseDomainInputHasValue(!!value);
}}
isSaas={isSaas}
/>
</div>
)}
+1 -1
View File
@@ -22,7 +22,7 @@ function HomeScreen() {
<hr className="border-[#717888]" />
<main className="flex flex-col md:flex-row justify-between gap-4">
<main className="flex justify-between gap-4">
<RepoConnector
onRepoSelection={(title) => setSelectedRepoTitle(title)}
/>
+3 -3
View File
@@ -11,13 +11,13 @@ export const DEFAULT_SETTINGS: Settings = {
CONFIRMATION_MODE: false,
SECURITY_ANALYZER: "",
REMOTE_RUNTIME_RESOURCE_FACTOR: 1,
PROVIDER_TOKENS_SET: { github: false, gitlab: false },
PROVIDER_TOKENS_SET: { github: null, gitlab: null },
ENABLE_DEFAULT_CONDENSER: true,
ENABLE_SOUND_NOTIFICATIONS: false,
USER_CONSENTS_TO_ANALYTICS: false,
PROVIDER_TOKENS: {
github: "",
gitlab: "",
github: { token: "", base_domain: null },
gitlab: { token: "", base_domain: null },
},
IS_NEW_USER: true,
};
+11 -6
View File
@@ -5,6 +5,11 @@ export const ProviderOptions = {
export type Provider = keyof typeof ProviderOptions;
export type ProviderToken = {
token: string;
base_domain: string | null;
};
export type Settings = {
LLM_MODEL: string;
LLM_BASE_URL: string;
@@ -14,11 +19,11 @@ export type Settings = {
CONFIRMATION_MODE: boolean;
SECURITY_ANALYZER: string;
REMOTE_RUNTIME_RESOURCE_FACTOR: number | null;
PROVIDER_TOKENS_SET: Record<Provider, boolean>;
PROVIDER_TOKENS_SET: Record<Provider, string | null>;
ENABLE_DEFAULT_CONDENSER: boolean;
ENABLE_SOUND_NOTIFICATIONS: boolean;
USER_CONSENTS_TO_ANALYTICS: boolean | null;
PROVIDER_TOKENS: Record<Provider, string>;
PROVIDER_TOKENS: Record<Provider, ProviderToken>;
IS_NEW_USER?: boolean;
};
@@ -35,17 +40,17 @@ export type ApiSettings = {
enable_default_condenser: boolean;
enable_sound_notifications: boolean;
user_consents_to_analytics: boolean | null;
provider_tokens: Record<Provider, string>;
provider_tokens_set: Record<Provider, boolean>;
provider_tokens: Record<Provider, ProviderToken>;
provider_tokens_set: Record<Provider, string | null>;
};
export type PostSettings = Settings & {
provider_tokens: Record<Provider, string>;
provider_tokens: Record<Provider, ProviderToken>;
user_consents_to_analytics: boolean | null;
llm_api_key?: string | null;
};
export type PostApiSettings = ApiSettings & {
provider_tokens: Record<Provider, string>;
provider_tokens: Record<Provider, ProviderToken>;
user_consents_to_analytics: boolean | null;
};
@@ -0,0 +1,64 @@
---
name: add_openhands_repo_instruction
version: 1.0.0
author: openhands
agent: CodeActAgent
inputs:
- name: REPO_FOLDER_NAME
description: "Branch for the agent to work on"
required: false
---
Please browse the current repository under /workspace/{{ REPO_FOLDER_NAME }}, look at the documentation and relevant code, and understand the purpose of this repository.
Specifically, I want you to create a `.openhands/microagents/repo.md` file. This file should contain succinct information that summarizes (1) the purpose of this repository, (2) the general setup of this repo, and (3) a brief description of the structure of this repo.
Here's an example:
```markdown
---
name: repo
type: repo
agent: CodeActAgent
---
This repository contains the code for OpenHands, an automated AI software engineer. It has a Python backend
(in the `openhands` directory) and React frontend (in the `frontend` directory).
## General Setup:
To set up the entire repo, including frontend and backend, run `make build`.
You don't need to do this unless the user asks you to, or if you're trying to run the entire application.
Before pushing any changes, you should ensure that any lint errors or simple test errors have been fixed.
* If you've made changes to the backend, you should run `pre-commit run --all-files --config ./dev_config/python/.pre-commit-config.yaml`
* If you've made changes to the frontend, you should run `cd frontend && npm run lint:fix && npm run build ; cd ..`
If either command fails, it may have automatically fixed some issues. You should fix any issues that weren't automatically fixed,
then re-run the command to ensure it passes.
## Repository Structure
Backend:
- Located in the `openhands` directory
- Testing:
- All tests are in `tests/unit/test_*.py`
- To test new code, run `poetry run pytest tests/unit/test_xxx.py` where `xxx` is the appropriate file for the current functionality
- Write all tests with pytest
Frontend:
- Located in the `frontend` directory
- Prerequisites: A recent version of NodeJS / NPM
- Setup: Run `npm install` in the frontend directory
- Testing:
- Run tests: `npm run test`
- To run specific tests: `npm run test -- -t "TestName"`
- Building:
- Build for production: `npm run build`
- Environment Variables:
- Set in `frontend/.env` or as environment variables
- Available variables: VITE_BACKEND_HOST, VITE_USE_TLS, VITE_INSECURE_SKIP_VERIFY, VITE_FRONTEND_PORT
- Internationalization:
- Generate i18n declaration file: `npm run make-i18n`
```
Now, please write a similar markdown for the current repository.
Read all the GitHub workflows under .github/ of the repository (if this folder exists) to understand the CI checks (e.g., linter, pre-commit), and include those in the repo.md file.
+19
View File
@@ -0,0 +1,19 @@
---
name: address_pr_comments
version: 1.0.0
author: openhands
agent: CodeActAgent
inputs:
- name: PR_URL
description: "URL of the pull request"
required: true
- name: BRANCH_NAME
description: "Branch name corresponds to the pull request"
required: true
---
First, check the branch {{ BRANCH_NAME }} and read the diff against the main branch to understand the purpose.
This branch corresponds to this PR {{ PR_URL }}
Next, you should use the GitHub API to read the reviews and comments on this PR and address them.
-32
View File
@@ -1,32 +0,0 @@
---
name: agent_memory
type: knowledge
version: 1.0.0
agent: CodeActAgent
triggers:
- /remember
---
* Repository memory: Use .openhands/microagents/repo.md under each repository root to store and access important information.
- If this file exists, it will be added to your context automatically.
- If missing, you should create it unless the user has explicitly asked you to not do so.
* Store and maintain **general knowledge** that will be helpful for most future tasks:
1. Repository structure
2. Common commands (build, lint, test, pre-commit, etc.)
3. Code style preferences
4. Workflows and best practices
5. Any other repository-specific knowledge you learn
* IMPORTANT: ONLY LOG the information that would be helpful for different future tasks, for example, how to configure the settings, how to setup the repository. Do NOT add issue-specific information (e.g., what specific error you have ran into and how you fix it).
* When adding new information:
- ALWAYS ask for user confirmation first by listing the exact items (numbered 1, 2, 3, etc.) you plan to save to repo.md
- Only save the items the user approves (they may ask you to save a subset)
- Ensure it integrates nicely with existing knowledge in repo.md
- Reorganize the content if needed to maintain clarity and organization
- Group related information together under appropriate sections or headings
- If you've only explored a portion of the codebase, clearly note this limitation in the repository structure documentation
- If you don't know the essential commands for working with the repository, such as lint or typecheck, ask the user and suggest adding them to repo.md for future reference (with permission)
When you receive this message, please review and summarize your recent actions and observations, then present a list of valuable information that should be saved in repo.md to the user.
+27
View File
@@ -0,0 +1,27 @@
---
name: get_test_to_pass
version: 1.0.0
author: openhands
agent: CodeActAgent
inputs:
- name: BRANCH_NAME
description: "Branch for the agent to work on"
required: true
- name: TEST_COMMAND_TO_RUN
description: "The test command you want the agent to work on. For example, `pytest tests/unit/test_bash_parsing.py`"
required: true
- name: FUNCTION_TO_FIX
description: "The name of function to fix"
required: false
- name: FILE_FOR_FUNCTION
description: "The path of the file that contains the function"
required: false
---
Can you check out branch "{{ BRANCH_NAME }}", and run {{ TEST_COMMAND_TO_RUN }}.
{%- if FUNCTION_TO_FIX and FILE_FOR_FUNCTION %}
Help me fix these tests to pass by fixing the {{ FUNCTION_TO_FIX }} function in file {{ FILE_FOR_FUNCTION }}.
{%- endif %}
PLEASE DO NOT modify the tests by yourselves -- Let me know if you think some of the tests are incorrect.
+21
View File
@@ -0,0 +1,21 @@
---
name: update_pr_description
version: 1.0.0
author: openhands
agent: CodeActAgent
inputs:
- name: PR_URL
description: "URL of the pull request"
type: string
required: true
validation:
pattern: "^https://github.com/.+/.+/pull/[0-9]+$"
- name: BRANCH_NAME
description: "Branch name corresponds to the pull request"
type: string
required: true
---
Please check the branch "{{ BRANCH_NAME }}" and look at the diff against the main branch. This branch belongs to this PR "{{ PR_URL }}".
Once you understand the purpose of the diff, please use Github API to read the existing PR description, and update it to be more reflective of the changes we've made when necessary.
@@ -0,0 +1,21 @@
---
name: update_test_for_new_implementation
version: 1.0.0
author: openhands
agent: CodeActAgent
inputs:
- name: BRANCH_NAME
description: "Branch for the agent to work on"
required: true
- name: TEST_COMMAND_TO_RUN
description: "The test command you want the agent to work on. For example, `pytest tests/unit/test_bash_parsing.py`"
required: true
---
Can you check out branch "{{ BRANCH_NAME }}", and run {{ TEST_COMMAND_TO_RUN }}.
{%- if FUNCTION_TO_FIX and FILE_FOR_FUNCTION %}
Help me fix these tests to pass by fixing the {{ FUNCTION_TO_FIX }} function in file {{ FILE_FOR_FUNCTION }}.
{%- endif %}
PLEASE DO NOT modify the tests by yourselves -- Let me know if you think some of the tests are incorrect.
@@ -20,7 +20,10 @@ from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.events.action import Action, AgentFinishAction, MessageAction
from openhands.events.action import (
Action,
AgentFinishAction,
)
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.memory.condenser import Condenser
@@ -170,8 +173,7 @@ class CodeActAgent(Agent):
f'Processing {len(condensed_history)} events from a total of {len(state.history)} events'
)
initial_user_message = self._get_initial_user_message(state.history)
messages = self._get_messages(condensed_history, initial_user_message)
messages = self._get_messages(condensed_history)
params: dict = {
'messages': self.llm.format_messages_for_llm(messages),
}
@@ -214,29 +216,7 @@ class CodeActAgent(Agent):
self.pending_actions.append(action)
return self.pending_actions.popleft()
def _get_initial_user_message(self, history: list[Event]) -> MessageAction:
"""Finds the initial user message action from the full history."""
initial_user_message: MessageAction | None = None
for event in history:
if isinstance(event, MessageAction) and event.source == 'user':
initial_user_message = event
break
if initial_user_message is None:
# This should not happen in a valid conversation
logger.error(
f'CRITICAL: Could not find the initial user MessageAction in the full {len(history)} events history.'
)
# Depending on desired robustness, could raise error or create a dummy action
# and log the error
raise ValueError(
'Initial user message not found in history. Please report this issue.'
)
return initial_user_message
def _get_messages(
self, events: list[Event], initial_user_message: MessageAction
) -> list[Message]:
def _get_messages(self, events: list[Event]) -> list[Message]:
"""Constructs the message history for the LLM conversation.
This method builds a structured conversation history by processing events from the state
@@ -273,7 +253,6 @@ class CodeActAgent(Agent):
# Use ConversationMemory to process events (including SystemMessageAction)
messages = self.conversation_memory.process_events(
condensed_history=events,
initial_user_action=initial_user_message,
max_message_chars=self.llm.config.max_message_chars,
vision_is_active=self.llm.vision_is_active(),
)
+71 -174
View File
@@ -190,7 +190,7 @@ class AgentController:
logger.debug(f'System message got from agent: {system_message}')
if system_message:
self.event_stream.add_event(system_message, EventSource.AGENT)
logger.info(f'System message added to event stream: {system_message}')
logger.debug(f'System message added to event stream: {system_message}')
async def close(self, set_stop_state: bool = True) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
@@ -260,43 +260,39 @@ class AgentController:
# Store the error reason before setting the agent state
self.state.last_error = f'{type(e).__name__}: {str(e)}'
if isinstance(e, RateLimitError):
await self.set_agent_state_to(AgentState.RATE_LIMITED)
return
err_id = ''
err_details = type(e).__name__
if isinstance(e, AuthenticationError):
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
elif isinstance(
e,
(
ServiceUnavailableError,
APIConnectionError,
APIError,
),
):
err_id = 'STATUS$ERROR_LLM_SERVICE_UNAVAILABLE'
elif isinstance(e, InternalServerError):
err_id = 'STATUS$ERROR_LLM_INTERNAL_SERVER_ERROR'
elif isinstance(e, BadRequestError) and 'ExceededBudget' in str(e):
err_id = 'STATUS$ERROR_LLM_OUT_OF_CREDITS'
elif isinstance(e, ContentPolicyViolationError) or (
isinstance(e, BadRequestError) and 'ContentPolicyViolationError' in str(e)
):
err_id = 'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION'
if err_id:
# These err_details will end up on the frontend. We only plumb through known errors
# listed above to avoid exposing sensitive information
err_details = type(e).__name__ + ': ' + str(e)
self.state.last_error = err_details
else:
self.state.last_error = f'{type(e).__name__}: {str(e)}'
if self.status_callback is not None:
self.status_callback('error', err_id, err_details)
err_id = ''
if isinstance(e, AuthenticationError):
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
self.state.last_error = err_id
elif isinstance(
e,
(
ServiceUnavailableError,
APIConnectionError,
APIError,
),
):
err_id = 'STATUS$ERROR_LLM_SERVICE_UNAVAILABLE'
self.state.last_error = err_id
elif isinstance(e, InternalServerError):
err_id = 'STATUS$ERROR_LLM_INTERNAL_SERVER_ERROR'
self.state.last_error = err_id
elif isinstance(e, BadRequestError) and 'ExceededBudget' in str(e):
err_id = 'STATUS$ERROR_LLM_OUT_OF_CREDITS'
self.state.last_error = err_id
elif isinstance(e, ContentPolicyViolationError) or (
isinstance(e, BadRequestError)
and 'ContentPolicyViolationError' in str(e)
):
err_id = 'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION'
self.state.last_error = err_id
elif isinstance(e, RateLimitError):
await self.set_agent_state_to(AgentState.RATE_LIMITED)
return
self.status_callback('error', err_id, self.state.last_error)
# Set the agent state to ERROR after storing the reason
await self.set_agent_state_to(AgentState.ERROR)
def step(self) -> None:
@@ -1024,7 +1020,7 @@ class AgentController:
self.state.start_id = 0
self.log(
'info',
'debug',
f'AgentController {self.id} - created new state. start_id: {self.state.start_id}',
)
else:
@@ -1034,7 +1030,7 @@ class AgentController:
self.state.start_id = 0
self.log(
'info',
'debug',
f'AgentController {self.id} initializing history from event {self.state.start_id}',
)
@@ -1147,169 +1143,70 @@ class AgentController:
def _handle_long_context_error(self) -> None:
# When context window is exceeded, keep roughly half of agent interactions
kept_events = self._apply_conversation_window()
kept_event_ids = {e.id for e in kept_events}
self.log(
'info',
f'Context window exceeded. Keeping events with IDs: {kept_event_ids}',
)
# The events to forget are those that are not in the kept set
kept_event_ids = {
e.id for e in self._apply_conversation_window(self.state.history)
}
forgotten_event_ids = {e.id for e in self.state.history} - kept_event_ids
if len(kept_event_ids) == 0:
self.log(
'warning',
'No events kept after applying conversation window. This should not happen.',
)
# verify that the first event id in kept_event_ids is the same as the start_id
if len(kept_event_ids) > 0 and self.state.history[0].id not in kept_event_ids:
self.log(
'warning',
f'First event after applying conversation window was not kept: {self.state.history[0].id} not in {kept_event_ids}',
)
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Add an error event to trigger another step by the agent
self.event_stream.add_event(
CondensationAction(
forgotten_events_start_id=min(forgotten_event_ids)
if forgotten_event_ids
else 0,
forgotten_events_end_id=max(forgotten_event_ids)
if forgotten_event_ids
else 0,
forgotten_events_start_id=min(forgotten_event_ids),
forgotten_events_end_id=max(forgotten_event_ids),
),
EventSource.AGENT,
)
def _apply_conversation_window(self) -> list[Event]:
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
"""Cuts history roughly in half when context window is exceeded.
It preserves action-observation pairs and ensures that the system message,
the first user message, and its associated recall observation are always included
at the beginning of the context window.
It preserves action-observation pairs and ensures that the first user message is always included.
The algorithm:
1. Identify essential initial events: System Message, First User Message, Recall Observation.
2. Determine the slice of recent events to potentially keep.
3. Validate the start of the recent slice for dangling observations.
4. Combine essential events and validated recent events, ensuring essentials come first.
1. Cut history in half
2. Check first event in new history:
- If Observation: find and include its Action
- If MessageAction: ensure its related Action-Observation pair isn't split
3. Always include the first user message
Args:
events: List of events to filter
Returns:
Filtered list of events keeping newest half while preserving pairs and essential initial events.
Filtered list of events keeping newest half while preserving pairs
"""
if not self.state.history:
return []
if not events:
return events
history = self.state.history
# 1. Identify essential initial events
system_message: SystemMessageAction | None = None
first_user_msg: MessageAction | None = None
recall_action: RecallAction | None = None
recall_observation: Observation | None = None
# Find System Message (should be the first event, if it exists)
system_message = next(
(e for e in history if isinstance(e, SystemMessageAction)), None
)
assert (
system_message is None
or isinstance(system_message, SystemMessageAction)
and system_message.id == history[0].id
# Find first user message - we'll need to ensure it's included
first_user_msg = next(
(
e
for e in events
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
# Find First User Message, which MUST exist
first_user_msg = self._first_user_message()
if first_user_msg is None:
raise RuntimeError('No first user message found in the event stream.')
# cut in half
mid_point = max(1, len(events) // 2)
kept_events = events[mid_point:]
if len(kept_events) > 0 and isinstance(kept_events[0], Observation):
kept_events = kept_events[1:]
first_user_msg_index = -1
for i, event in enumerate(history):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
first_user_msg = event
first_user_msg_index = i
break
# Ensure first user message is included
if first_user_msg and first_user_msg not in kept_events:
kept_events = [first_user_msg] + kept_events
# Find Recall Action and Observation related to the First User Message
if first_user_msg is not None and first_user_msg_index != -1:
# Look for RecallAction after the first user message
for i in range(first_user_msg_index + 1, len(history)):
event = history[i]
if (
isinstance(event, RecallAction)
and event.query == first_user_msg.content
):
# Found RecallAction, now look for its Observation
recall_action = event
for j in range(i + 1, len(history)):
obs_event = history[j]
# Check for Observation caused by this RecallAction
if (
isinstance(obs_event, Observation)
and obs_event.cause == recall_action.id
):
recall_observation = obs_event
break # Found the observation, stop inner loop
break # Found the recall action (and maybe obs), stop outer loop
essential_events: list[Event] = []
if system_message:
essential_events.append(system_message)
# start_id points to first user message
if first_user_msg:
essential_events.append(first_user_msg)
# Also keep the RecallAction that triggered the essential RecallObservation
if recall_action:
essential_events.append(recall_action)
if recall_observation:
essential_events.append(recall_observation)
self.state.start_id = first_user_msg.id
# 2. Determine the slice of recent events to potentially keep
num_non_essential_events = len(history) - len(essential_events)
# Keep roughly half of the non-essential events, minimum 1
num_recent_to_keep = max(1, num_non_essential_events // 2)
# Calculate the starting index for the recent slice
slice_start_index = len(history) - num_recent_to_keep
slice_start_index = max(0, slice_start_index) # Ensure index is not negative
recent_events_slice = history[slice_start_index:]
# 3. Validate the start of the recent slice for dangling observations
# IMPORTANT: Most observations in history are tool call results, which cannot be without their action, or we get an LLM API error
first_valid_event_index = 0
for i, event in enumerate(recent_events_slice):
if isinstance(event, Observation):
first_valid_event_index += 1
else:
break
# If all events in the slice are dangling observations, we need to keep at least one
if first_valid_event_index == len(recent_events_slice):
self.log(
'warning',
'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.',
)
# Adjust the recent_events_slice if dangling observations were found at the start
if first_valid_event_index < len(recent_events_slice):
validated_recent_events = recent_events_slice[first_valid_event_index:]
if first_valid_event_index > 0:
self.log(
'debug',
f'Removed {first_valid_event_index} dangling observation(s) from the start of recent event slice.',
)
else:
validated_recent_events = []
# 4. Combine essential events and validated recent events
events_to_keep: list[Event] = essential_events + validated_recent_events
self.log('debug', f'History truncated. Kept {len(events_to_keep)} events.')
return events_to_keep
return kept_events
def _is_stuck(self) -> bool:
"""Checks if the agent or its delegate is stuck in a loop.
+17 -39
View File
@@ -14,14 +14,12 @@ from openhands.core.cli_commands import (
)
from openhands.core.cli_tui import (
UsageMetrics,
display_agent_running_message,
display_banner,
display_event,
display_initial_user_prompt,
display_initialization_animation,
display_runtime_initialization_message,
display_welcome_message,
process_agent_pause,
read_confirmation_input,
read_prompt_input,
)
@@ -101,7 +99,6 @@ async def run_session(
sid = str(uuid4())
is_loaded = asyncio.Event()
is_paused = asyncio.Event()
# Show runtime initialization message
display_runtime_initialization_message(config.runtime)
@@ -127,12 +124,10 @@ async def run_session(
usage_metrics = UsageMetrics()
async def prompt_for_next_task(agent_state: str):
async def prompt_for_next_task():
nonlocal reload_microagents, new_session_requested
while True:
next_message = await read_prompt_input(
agent_state, multiline=config.cli_multiline_input
)
next_message = await read_prompt_input(config.cli_multiline_input)
if not next_message.strip():
continue
@@ -155,23 +150,14 @@ async def run_session(
return
async def on_event_async(event: Event) -> None:
nonlocal reload_microagents, is_paused
nonlocal reload_microagents
display_event(event, config)
update_usage_metrics(event, usage_metrics)
# Pause the agent if the pause event is set (if Ctrl-P is pressed)
if is_paused.is_set():
event_stream.add_event(
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
is_paused.clear()
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
AgentState.PAUSED,
]:
# Reload microagents after initialization of repo.md
if reload_microagents:
@@ -180,28 +166,20 @@ async def run_session(
)
memory.load_user_workspace_microagents(microagents)
reload_microagents = False
await prompt_for_next_task(event.agent_state)
await prompt_for_next_task()
if event.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
# Only display the confirmation prompt if the agent is not paused
if not is_paused.is_set():
user_confirmed = await read_confirmation_input()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
EventSource.USER,
)
else:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_REJECTED),
EventSource.USER,
)
if event.agent_state == AgentState.RUNNING:
# Enable pause/resume functionality only if the confirmation mode is disabled
if not config.security.confirmation_mode:
display_agent_running_message()
loop.create_task(process_agent_pause(is_paused))
user_confirmed = await read_confirmation_input()
if user_confirmed:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_CONFIRMED),
EventSource.USER,
)
else:
event_stream.add_event(
ChangeAgentStateAction(AgentState.USER_REJECTED),
EventSource.USER,
)
def on_event(event: Event) -> None:
loop.create_task(on_event_async(event))
@@ -234,7 +212,7 @@ async def run_session(
clear()
# Show OpenHands banner and session ID
display_banner(session_id=sid)
display_banner(session_id=sid, is_loaded=is_loaded)
# Show OpenHands welcome
display_welcome_message()
@@ -247,7 +225,7 @@ async def run_session(
)
else:
# Otherwise prompt for the user's first message right away
asyncio.create_task(prompt_for_next_task(''))
asyncio.create_task(prompt_for_next_task())
await run_agent_until_done(
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
-24
View File
@@ -70,8 +70,6 @@ async def handle_commands(
)
elif command == '/settings':
await handle_settings_command(config, settings_store)
elif command == '/resume':
close_repl, new_session_requested = await handle_resume_command(event_stream)
else:
close_repl = True
action = MessageAction(content=command)
@@ -185,28 +183,6 @@ async def handle_settings_command(
await modify_llm_settings_advanced(config, settings_store)
# FIXME: Currently there's an issue with the actual 'resume' behavior.
# Setting the agent state to RUNNING will currently freeze the agent without continuing with the rest of the task.
# This is a workaround to handle the resume command for the time being. Replace user message with the state change event once the issue is fixed.
async def handle_resume_command(
event_stream: EventStream,
) -> tuple[bool, bool]:
close_repl = True
new_session_requested = False
event_stream.add_event(
MessageAction(content='continue'),
EventSource.USER,
)
# event_stream.add_event(
# ChangeAgentStateAction(AgentState.RUNNING),
# EventSource.ENVIRONMENT,
# )
return close_repl, new_session_requested
async def init_repository(current_dir: str) -> bool:
repo_file_path = Path(current_dir) / '.openhands' / 'microagents' / 'repo.md'
init_repo = False
+75 -76
View File
@@ -10,9 +10,7 @@ from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.application import Application
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.formatted_text import HTML, FormattedText, StyleAndTextTuples
from prompt_toolkit.input import create_input
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.keys import Keys
from prompt_toolkit.layout.containers import HSplit, Window
from prompt_toolkit.layout.controls import FormattedTextControl
from prompt_toolkit.layout.layout import Layout
@@ -24,7 +22,6 @@ from prompt_toolkit.widgets import Frame, TextArea
from openhands import __version__
from openhands.core.config import AppConfig
from openhands.core.schema import AgentState
from openhands.events import EventSource
from openhands.events.action import (
Action,
@@ -35,7 +32,6 @@ from openhands.events.action import (
)
from openhands.events.event import Event
from openhands.events.observation import (
AgentStateChangedObservation,
CmdOutputObservation,
FileEditObservation,
FileReadObservation,
@@ -60,7 +56,6 @@ COMMANDS = {
'/status': 'Display session details and usage metrics',
'/new': 'Create a new session',
'/settings': 'Display and modify current settings',
'/resume': 'Resume the agent',
}
@@ -119,7 +114,7 @@ def display_initialization_animation(text, is_loaded: asyncio.Event):
sys.stdout.flush()
def display_banner(session_id: str):
def display_banner(session_id: str, is_loaded: asyncio.Event):
print_formatted_text(
HTML(r"""<gold>
___ _ _ _
@@ -134,8 +129,11 @@ def display_banner(session_id: str):
print_formatted_text(HTML(f'<grey>OpenHands CLI v{__version__}</grey>'))
banner_text = (
'Initialized session' if is_loaded.is_set() else 'Initializing session'
)
print_formatted_text('')
print_formatted_text(HTML(f'<grey>Initialized session {session_id}</grey>'))
print_formatted_text(HTML(f'<grey>{banner_text} {session_id}</grey>'))
print_formatted_text('')
@@ -179,8 +177,6 @@ def display_event(event: Event, config: AppConfig) -> None:
display_file_edit(event)
if isinstance(event, FileReadObservation):
display_file_read(event)
if isinstance(event, AgentStateChangedObservation):
display_agent_paused_message(event.agent_state)
def display_message(message: str):
@@ -393,58 +389,77 @@ def display_status(usage_metrics: UsageMetrics, session_id: str):
display_usage_metrics(usage_metrics)
def display_agent_running_message():
print_formatted_text('')
print_formatted_text(
HTML('<gold>Agent running...</gold> <grey>(Ctrl-P to pause)</grey>')
)
def display_agent_paused_message(agent_state: str):
if agent_state != AgentState.PAUSED:
return
print_formatted_text('')
print_formatted_text(
HTML('<gold>Agent paused</gold> <grey>(type /resume to resume)</grey>')
)
# Common input functions
class CommandCompleter(Completer):
"""Custom completer for commands."""
def __init__(self, agent_state: str):
super().__init__()
self.agent_state = agent_state
def get_completions(self, document, complete_event):
text = document.text_before_cursor.lstrip()
text = document.text
# Only show completions if the user has typed '/'
if text.startswith('/'):
available_commands = dict(COMMANDS)
if self.agent_state != AgentState.PAUSED:
available_commands.pop('/resume', None)
for command, description in available_commands.items():
if command.startswith(text):
# If just '/' is typed, show all commands
if text == '/':
for command, description in COMMANDS.items():
yield Completion(
command,
start_position=-len(text),
display_meta=description,
style='bg:ansidarkgray fg:ansiwhite',
command[1:], # Remove the leading '/' as it's already typed
start_position=0,
display=f'{command} - {description}',
)
# Otherwise show matching commands
else:
for command, description in COMMANDS.items():
if command.startswith(text):
yield Completion(
command[len(text) :], # Complete the remaining part
start_position=0,
display=f'{command} - {description}',
)
def create_prompt_session():
return PromptSession(style=DEFAULT_STYLE)
prompt_session = PromptSession(style=DEFAULT_STYLE)
# RPrompt animation related variables
SPINNER_FRAMES = [
'[ ■□□□ ]',
'[ □■□□ ]',
'[ □□■□ ]',
'[ □□□■ ]',
'[ □□■□ ]',
'[ □■□□ ]',
]
ANIMATION_INTERVAL = 0.2 # seconds
current_frame_index = 0
last_update_time = time.monotonic()
async def read_prompt_input(agent_state: str, multiline=False):
# RPrompt function for the user confirmation
def get_rprompt() -> FormattedText:
"""
Returns the current animation frame for the rprompt.
This function is called by prompt_toolkit during rendering.
"""
global current_frame_index, last_update_time
# Only update the frame if enough time has passed
# This prevents excessive recalculation during rendering
now = time.monotonic()
if now - last_update_time > ANIMATION_INTERVAL:
current_frame_index = (current_frame_index + 1) % len(SPINNER_FRAMES)
last_update_time = now
# Return the frame wrapped in FormattedText
return FormattedText(
[
('', ' '), # Add a space before the spinner
(COLOR_GOLD, SPINNER_FRAMES[current_frame_index]),
]
)
async def read_prompt_input(multiline=False):
try:
prompt_session = create_prompt_session()
prompt_session.completer = (
CommandCompleter(agent_state) if not multiline else None
)
if multiline:
kb = KeyBindings()
@@ -455,54 +470,38 @@ async def read_prompt_input(agent_state: str, multiline=False):
with patch_stdout():
print_formatted_text('')
message = await prompt_session.prompt_async(
HTML(
'<gold>Enter your message and press Ctrl-D to finish:</gold>\n'
),
'Enter your message and press Ctrl+D to finish:\n',
multiline=True,
key_bindings=kb,
)
else:
with patch_stdout():
print_formatted_text('')
prompt_session.completer = CommandCompleter()
message = await prompt_session.prompt_async(
HTML('<gold>> </gold>'),
'> ',
)
return message if message is not None else ''
return message
except (KeyboardInterrupt, EOFError):
return '/exit'
async def read_confirmation_input() -> bool:
async def read_confirmation_input():
try:
prompt_session = create_prompt_session()
with patch_stdout():
print_formatted_text('')
confirmation: str = await prompt_session.prompt_async(
HTML('<gold>Proceed with action? (y)es/(n)o > </gold>'),
prompt_session.completer = None
confirmation = await prompt_session.prompt_async(
'Proceed with action? (y)es/(n)o > ',
rprompt=get_rprompt,
refresh_interval=ANIMATION_INTERVAL / 2,
)
confirmation = '' if confirmation is None else confirmation.strip().lower()
prompt_session.rprompt = None
confirmation = confirmation.strip().lower()
return confirmation in ['y', 'yes']
except (KeyboardInterrupt, EOFError):
return False
async def process_agent_pause(done: asyncio.Event) -> None:
input = create_input()
def keys_ready():
for key_press in input.read_keys():
if key_press.key == Keys.ControlP:
print_formatted_text('')
print_formatted_text(HTML('<gold>Pausing the agent...</gold>'))
done.set()
with input.raw_mode():
with input.attach(keys_ready):
await done.wait()
def cli_confirm(
question: str = 'Are you sure?', choices: list[str] | None = None
) -> int:
@@ -41,7 +41,7 @@ class GitHubService(BaseGitService, GitService):
if token:
self.token = token
if base_domain:
if base_domain and base_domain != "github.com":
self.BASE_URL = f'https://{base_domain}/api/v3'
@property
+10 -3
View File
@@ -34,6 +34,7 @@ from openhands.server.types import AppMode
class ProviderToken(BaseModel):
token: SecretStr | None = Field(default=None)
user_id: str | None = Field(default=None)
base_domain: str | None = Field(default=None)
model_config = {
'frozen': True, # Makes the entire model immutable
@@ -43,15 +44,20 @@ class ProviderToken(BaseModel):
@classmethod
def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
"""Factory method to create a ProviderToken from various input types"""
if isinstance(token_value, ProviderToken):
if isinstance(token_value, cls):
return token_value
elif isinstance(token_value, dict):
token_str = token_value.get('token')
user_id = token_value.get('user_id')
return cls(token=SecretStr(token_str), user_id=user_id)
base_domain = token_value.get('base_domain')
return cls(
token=SecretStr(token_str) if token_str is not None else None,
user_id=user_id,
base_domain=base_domain,
)
else:
raise ValueError('Unsupport Provider token type')
raise ValueError('Unsupported Provider token type')
PROVIDER_TOKEN_TYPE = MappingProxyType[ProviderType, ProviderToken]
@@ -98,6 +104,7 @@ class SecretStore(BaseModel):
if expose_secrets
else pydantic_encoder(provider_token.token),
'user_id': provider_token.user_id,
'base_domain': provider_token.base_domain,
}
return tokens
@@ -3,4 +3,4 @@ Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retriev
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then use the {{ apiName }} to look at the {{ ciSystem }} that are failing on the most recent commit. Try and reproduce the failure locally.
Get things working locally, then push your changes. Sleep for 30 seconds at a time until the {{ ciProvider }} {{ ciSystem.lower() }} have run again.
If they are still failing, repeat the process.
If they are still failing, repeat the process.
@@ -1,4 +1,4 @@
You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }}. You need to fix the merge conflicts.
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then resolve the merge conflicts. If you aren't sure what the right solution is, look back through the commit history at the commits that introduced the conflict and resolve them accordingly.
Then resolve the merge conflicts. If you aren't sure what the right solution is, look back through the commit history at the commits that introduced the conflict and resolve them accordingly.
@@ -1,4 +1,4 @@
You are working on Issue #{{ issue_number }} in repository {{ repo }}. Your goal is to fix the issue.
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the issue details and any comments on the issue.
Then check out a new branch and investigate what changes will need to be made.
Finally, make the required changes and open up a {{ requestVerb }}. Be sure to reference the issue in the {{ requestTypeShort }} description.
Finally, make the required changes and open up a {{ requestVerb }}. Be sure to reference the issue in the {{ requestTypeShort }} description.
@@ -2,4 +2,4 @@ You are working on {{ requestType }} #{{ issue_number }} in repository {{ repo }
Use the {{ apiName }} with the {{ tokenEnvVar }} environment variable to retrieve the {{ requestTypeShort }} details.
Check out the branch from that {{ requestVerb }} and look at the diff versus the base branch of the {{ requestTypeShort }} to understand the {{ requestTypeShort }}'s intention.
Then use the {{ apiName }} to retrieve all the feedback on the {{ requestTypeShort }} so far.
If anything hasn't been addressed, address it and commit your changes back to the same branch.
If anything hasn't been addressed, address it and commit your changes back to the same branch.
+2 -42
View File
@@ -54,7 +54,6 @@ class ConversationMemory:
def process_events(
self,
condensed_history: list[Event],
initial_user_action: MessageAction,
max_message_chars: int | None = None,
vision_is_active: bool = False,
) -> list[Message]:
@@ -67,14 +66,12 @@ class ConversationMemory:
max_message_chars: The maximum number of characters in the content of an event included
in the prompt to the LLM. Larger observations are truncated.
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
initial_user_action: The initial user message action, if available. Used to ensure the conversation starts correctly.
"""
events = condensed_history
# Ensure the event list starts with SystemMessageAction, then MessageAction(source='user')
# Ensure the system message exists (handles legacy cases)
self._ensure_system_message(events)
self._ensure_initial_user_message(events, initial_user_action)
# log visual browsing status
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
@@ -702,43 +699,6 @@ class ConversationMemory:
system_message = SystemMessageAction(content=system_prompt)
# Insert the system message directly at the beginning of the events list
events.insert(0, system_message)
logger.info(
logger.debug(
'[ConversationMemory] Added SystemMessageAction for backward compatibility'
)
def _ensure_initial_user_message(
self, events: list[Event], initial_user_action: MessageAction
) -> None:
"""Checks if the second event is a user MessageAction and inserts the provided one if needed."""
if (
not events
): # Should have system message from previous step, but safety check
logger.error('Cannot ensure initial user message: event list is empty.')
# Or raise? Let's log for now, _ensure_system_message should handle this.
return
# We expect events[0] to be SystemMessageAction after _ensure_system_message
if len(events) == 1:
# Only system message exists
logger.info(
'Initial user message action was missing. Inserting the initial user message.'
)
events.insert(1, initial_user_action)
elif not isinstance(events[1], MessageAction) or events[1].source != 'user':
# The second event exists but is not the correct initial user message action.
# We will insert the correct one provided.
logger.info(
'Second event was not the initial user message action. Inserting correct one at index 1.'
)
# Insert the user message event at index 1. This will be the second message as LLM APIs expect
# but something was wrong with the history, so log all we can.
events.insert(1, initial_user_action)
# Else: events[1] is already a user MessageAction.
# Check if it matches the one provided (if any discrepancy, log warning but proceed).
elif events[1] != initial_user_action:
logger.debug(
'The user MessageAction at index 1 does not match the provided initial_user_action. '
'Proceeding with the one found in condensed history.'
)
-3
View File
@@ -4,7 +4,6 @@ from typing import overload
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.agent import CondensationAction
from openhands.events.event import Event
from openhands.events.observation.agent import AgentCondensationObservation
@@ -66,8 +65,6 @@ class View(BaseModel):
break
if summary is not None and summary_offset is not None:
logger.info(f'Inserting summary at offset {summary_offset}')
kept_events.insert(
summary_offset, AgentCondensationObservation(content=summary)
)
@@ -10,8 +10,8 @@ from openhands.events.event_store import EventStore
from openhands.server.config.server_config import ServerConfig
from openhands.server.monitoring import MonitoringListener
from openhands.server.session.conversation import Conversation
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.settings import Settings
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.files import FileStore
@@ -18,9 +18,9 @@ from openhands.server.monitoring import MonitoringListener
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.storage.data_models.settings import Settings
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync, wait_all
from openhands.utils.import_utils import get_impl
@@ -42,6 +42,7 @@ from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.utils.async_utils import wait_all
from openhands.utils.conversation_summary import generate_conversation_title
app = APIRouter(prefix='/api')
@@ -53,7 +54,7 @@ class InitSessionRequest(BaseModel):
image_urls: list[str] | None = None
replay_json: str | None = None
suggested_task: SuggestedTask | None = None
async def _create_new_conversation(
user_id: str | None,
@@ -66,14 +67,10 @@ async def _create_new_conversation(
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
attach_convo_id: bool = False,
):
print('trigger', conversation_trigger)
print("trigger", conversation_trigger)
logger.info(
'Creating conversation',
extra={
'signal': 'create_conversation',
'user_id': user_id,
'trigger': conversation_trigger.value,
},
extra={'signal': 'create_conversation', 'user_id': user_id, 'trigger': conversation_trigger.value},
)
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
@@ -193,7 +190,7 @@ async def new_conversation(
initial_user_msg=initial_user_msg,
image_urls=image_urls,
replay_json=replay_json,
conversation_trigger=conversation_trigger,
conversation_trigger=conversation_trigger
)
return JSONResponse(
+2 -1
View File
@@ -2,8 +2,9 @@ from typing import Any
from fastapi import APIRouter
from openhands.controller.agent import Agent
from openhands.security.options import SecurityAnalyzers
from openhands.controller.agent import Agent
from openhands.server.shared import config, server_config
from openhands.utils.llm import get_supported_llm_models
+59 -64
View File
@@ -17,13 +17,12 @@ from openhands.server.settings import (
POSTSettingsModel,
)
from openhands.server.shared import config
from openhands.storage.data_models.settings import Settings
from openhands.server.user_auth import (
get_provider_tokens,
get_user_id,
get_user_settings,
get_user_settings_store,
)
from openhands.storage.data_models.settings import Settings
from openhands.storage.settings.settings_store import SettingsStore
app = APIRouter(prefix='/api')
@@ -31,7 +30,6 @@ app = APIRouter(prefix='/api')
@app.get('/settings', response_model=GETSettingsModel)
async def load_settings(
user_id: str | None = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
settings: Settings | None = Depends(get_user_settings),
) -> GETSettingsModel | JSONResponse:
@@ -43,18 +41,10 @@ async def load_settings(
)
provider_tokens_set = {}
if bool(user_id):
provider_tokens_set[ProviderType.GITHUB.value] = True
if provider_tokens:
all_provider_types = [provider.value for provider in ProviderType]
provider_tokens_types = [provider.value for provider in provider_tokens]
for provider_type in all_provider_types:
if provider_type in provider_tokens_types:
provider_tokens_set[provider_type] = True
else:
provider_tokens_set[provider_type] = False
for provider_type, provider_token in provider_tokens.items():
if provider_token.token or provider_token.user_id:
provider_tokens_set[provider_type] = provider_token.base_domain
settings_with_token_data = GETSettingsModel(
**settings.model_dump(exclude='secrets_store'),
@@ -218,66 +208,80 @@ async def reset_settings() -> JSONResponse:
)
async def check_provider_tokens(settings: POSTSettingsModel) -> str:
async def check_provider_tokens(settings: POSTSettingsModel, existing_settings: Settings | None) -> str:
if settings.provider_tokens:
# Remove extraneous token types
provider_types = [provider.value for provider in ProviderType]
provider_types = [provider for provider in ProviderType]
settings.provider_tokens = {
k: v for k, v in settings.provider_tokens.items() if k in provider_types
}
# Determine whether tokens are valid
for token_type, token_value in settings.provider_tokens.items():
if token_value:
confirmed_token_type = await validate_provider_token(
SecretStr(token_value)
)
if not confirmed_token_type or confirmed_token_type.value != token_type:
return f'Invalid token. Please make sure it is a valid {token_type} token.'
for provider_type, provider_token in settings.provider_tokens.items():
token_value = provider_token
existing_token = existing_settings.secrets_store.provider_tokens.get(provider_type, None) if existing_settings else None
# Use incoming value otherwise default to existing value
token = SecretStr("")
if token_value.token:
token = token_value.token
elif existing_token and existing_token.token:
token = existing_token.token
if not token:
continue
base_domain = provider_token.base_domain # FE should always send latest base_domain param
confirmed_token_type = await validate_provider_token(
token,
base_domain
)
if not confirmed_token_type or confirmed_token_type != provider_type:
return f'Invalid {provider_type.value} token or base domain.'
return ''
async def store_provider_tokens(
settings: POSTSettingsModel, settings_store: SettingsStore
settings: POSTSettingsModel, existing_settings: Settings
):
existing_settings = await settings_store.load()
if existing_settings:
if settings.provider_tokens:
if existing_settings.secrets_store:
existing_providers = [
provider.value
for provider in existing_settings.secrets_store.provider_tokens
]
existing_providers = [
provider
for provider in existing_settings.secrets_store.provider_tokens
]
# Merge incoming settings store with the existing one
for provider, token_value in list(settings.provider_tokens.items()):
if provider in existing_providers and not token_value:
provider_type = ProviderType(provider)
existing_token = (
existing_settings.secrets_store.provider_tokens.get(
provider_type
)
# Merge incoming settings store with the existing one
for provider, token_value in settings.provider_tokens.items():
if provider in existing_providers and not token_value.token:
provider_type = ProviderType(provider)
existing_token = (
existing_settings.secrets_store.provider_tokens.get(
provider_type
)
if existing_token and existing_token.token:
settings.provider_tokens[provider] = (
existing_token.token.get_secret_value()
)
)
if existing_token:
updated_token = ProviderToken(
token=existing_token.token,
user_id=existing_token.user_id,
base_domain=token_value.base_domain
)
settings.provider_tokens[provider] = updated_token
else: # nothing passed in means keep current settings
provider_tokens = existing_settings.secrets_store.provider_tokens
settings.provider_tokens = {
provider.value: data.token.get_secret_value() if data.token else None
for provider, data in provider_tokens.items()
}
settings.provider_tokens = dict(existing_settings.secrets_store.provider_tokens)
return settings
async def store_llm_settings(
settings: POSTSettingsModel, settings_store: SettingsStore
settings: POSTSettingsModel, existing_settings: Settings
) -> POSTSettingsModel:
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
if existing_settings:
# Keep existing LLM settings if not provided
@@ -295,9 +299,10 @@ async def store_llm_settings(
async def store_settings(
settings: POSTSettingsModel,
settings_store: SettingsStore = Depends(get_user_settings_store),
existing_settings: Settings | None = Depends(get_user_settings),
) -> JSONResponse:
# Check provider tokens are valid
provider_err_msg = await check_provider_tokens(settings)
provider_err_msg = await check_provider_tokens(settings, existing_settings)
if provider_err_msg:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -305,11 +310,9 @@ async def store_settings(
)
try:
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
if existing_settings:
settings = await store_llm_settings(settings, settings_store)
settings = await store_llm_settings(settings, existing_settings)
# Keep existing analytics consent if not provided
if settings.user_consents_to_analytics is None:
@@ -317,7 +320,7 @@ async def store_settings(
existing_settings.user_consents_to_analytics
)
settings = await store_provider_tokens(settings, settings_store)
settings = await store_provider_tokens(settings, existing_settings)
# Update sandbox config with new settings
if settings.remote_runtime_resource_factor is not None:
@@ -357,17 +360,9 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
# Create new provider tokens immutably
if settings_with_token_data.provider_tokens:
tokens = {}
for token_type, token_value in settings_with_token_data.provider_tokens.items():
if token_value:
provider = ProviderType(token_type)
tokens[provider] = ProviderToken(
token=SecretStr(token_value), user_id=None
)
# Create new SecretStore with tokens
settings = settings.model_copy(
update={'secrets_store': SecretStore(provider_tokens=tokens)}
update={'secrets_store': SecretStore(provider_tokens=settings_with_token_data.provider_tokens)}
)
return settings
+9 -2
View File
@@ -5,6 +5,8 @@ from pydantic import (
SecretStr,
)
from openhands.integrations.provider import ProviderToken
from openhands.integrations.service_types import ProviderType
from openhands.storage.data_models.settings import Settings
@@ -13,7 +15,7 @@ class POSTSettingsModel(Settings):
Settings for POST requests
"""
provider_tokens: dict[str, str] = {}
provider_tokens: dict[ProviderType, ProviderToken] = {}
class POSTSettingsCustomSecrets(BaseModel):
@@ -29,9 +31,14 @@ class GETSettingsModel(Settings):
Settings with additional token data for the frontend
"""
provider_tokens_set: dict[str, bool] | None = None
provider_tokens_set: dict[ProviderType, str | None] | None = (
None # Provider Type and base domain key-value pair
)
llm_api_key_set: bool
class Config:
use_enum_values = True
class GETSettingsCustomSecrets(BaseModel):
"""
@@ -4,8 +4,8 @@ import json
from dataclasses import dataclass
from openhands.core.config.app_config import AppConfig
from openhands.storage import get_file_store
from openhands.storage.data_models.settings import Settings
from openhands.storage import get_file_store
from openhands.storage.files import FileStore
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import call_sync_from_async
@@ -26,7 +26,7 @@ from openhands.resolver.resolver_output import ResolverOutput
@pytest.fixture
def default_mock_args():
"""Fixture that provides a default mock args object with common values.
Tests can override specific attributes as needed.
"""
mock_args = MagicMock()
@@ -53,13 +53,10 @@ def default_mock_args():
@pytest.fixture
def mock_github_token():
"""Fixture that patches the identify_token function to return GitHub provider type.
This eliminates the need for repeated patching in each test function.
"""
with patch(
'openhands.resolver.resolve_issue.identify_token',
return_value=ProviderType.GITHUB,
) as patched:
with patch('openhands.resolver.resolve_issue.identify_token', return_value=ProviderType.GITHUB) as patched:
yield patched
@@ -155,9 +152,7 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_github_toke
# Verify that the handler was correctly configured and called
resolver.issue_handler_factory.assert_called_once()
mock_handler.get_converted_issues.assert_called_once_with(
issue_numbers=[5432], comment_id=None
)
mock_handler.get_converted_issues.assert_called_once_with(issue_numbers=[5432], comment_id=None)
def test_download_issues_from_github():
@@ -353,7 +348,9 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
# Create resolver with mocked token identification
resolver = IssueResolver(default_mock_args)
result = await resolver.complete_runtime(mock_runtime, 'base_commit_hash')
result = await resolver.complete_runtime(
mock_runtime, 'base_commit_hash'
)
assert result == {'git_patch': 'git diff content'}
assert mock_runtime.run_action.call_count == 5
@@ -361,7 +358,7 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'test_case',
"test_case",
[
{
'name': 'successful_run',
@@ -413,20 +410,11 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
'expected_error': None,
'expected_explanation': 'Non-JSON explanation',
'is_pr': True,
'comment_success': [
True,
False,
], # To trigger the PR success logging code path
'comment_success': [True, False], # To trigger the PR success logging code path
},
],
)
async def test_process_issue(
default_mock_args,
mock_github_token,
mock_output_dir,
mock_prompt_template,
test_case,
):
async def test_process_issue(default_mock_args, mock_github_token, mock_output_dir, mock_prompt_template, test_case):
"""Test the process_issue method with different scenarios."""
# Set up test data
@@ -438,7 +426,7 @@ async def test_process_issue(
body='This is a test issue',
)
base_commit = 'abcdef1234567890'
# Customize the mock args for this test
default_mock_args.output_dir = mock_output_dir
default_mock_args.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
@@ -469,7 +457,7 @@ async def test_process_issue(
# Mock the create_runtime function
mock_create_runtime = MagicMock(return_value=mock_runtime)
# Mock the run_controller function
mock_run_controller = AsyncMock()
if test_case['run_controller_raises']:
@@ -478,15 +466,14 @@ async def test_process_issue(
mock_run_controller.return_value = test_case['run_controller_return']
# Patch the necessary functions and methods
with patch(
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
), patch(
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
), patch.object(
resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}
), patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime:
with patch('openhands.resolver.resolve_issue.create_runtime', mock_create_runtime), \
patch('openhands.resolver.resolve_issue.run_controller', mock_run_controller), \
patch.object(resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}), \
patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime:
# Call the process_issue method
result = await resolver.process_issue(issue, base_commit, handler_instance)
# Assert the result matches our expectations
assert isinstance(result, ResolverOutput)
@@ -503,17 +490,16 @@ async def test_process_issue(
mock_initialize_runtime.assert_called_once()
mock_run_controller.assert_called_once()
resolver.complete_runtime.assert_awaited_once_with(mock_runtime, base_commit)
# Assert run_controller was called with the right parameters
if not test_case['run_controller_raises']:
# Check that the first positional argument is a config
assert 'config' in mock_run_controller.call_args[1]
# Check that initial_user_action is a MessageAction with the right content
assert isinstance(
mock_run_controller.call_args[1]['initial_user_action'], MessageAction
)
assert isinstance(mock_run_controller.call_args[1]['initial_user_action'], MessageAction)
assert mock_run_controller.call_args[1]['runtime'] == mock_runtime
# Assert that guess_success was called only for successful runs
if test_case['expected_success']:
handler_instance.guess_success.assert_called_once()
@@ -19,16 +19,14 @@ from openhands.resolver.interfaces.issue_definitions import (
ServiceContextIssue,
ServiceContextPR,
)
from openhands.resolver.resolve_issue import (
IssueResolver,
)
from openhands.resolver.resolve_issue import IssueResolver, SandboxConfig, AppConfig, AgentConfig
from openhands.resolver.resolver_output import ResolverOutput
@pytest.fixture
def default_mock_args():
"""Fixture that provides a default mock args object with common values.
Tests can override specific attributes as needed.
"""
mock_args = MagicMock()
@@ -54,13 +52,10 @@ def default_mock_args():
@pytest.fixture
def mock_gitlab_token():
"""Fixture that patches the identify_token function to return GitLab provider type.
This eliminates the need for repeated patching in each test function.
"""
with patch(
'openhands.resolver.resolve_issue.identify_token',
return_value=ProviderType.GITLAB,
) as patched:
with patch('openhands.resolver.resolve_issue.identify_token', return_value=ProviderType.GITLAB) as patched:
yield patched
@@ -129,10 +124,10 @@ def test_initialize_runtime(default_mock_args, mock_gitlab_token):
exit_code=0, content='', command='git config --global core.pager ""'
),
]
# Create resolver with mocked token identification
resolver = IssueResolver(default_mock_args)
resolver.initialize_runtime(mock_runtime)
if os.getenv('GITLAB_CI') == 'true':
@@ -159,26 +154,24 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_gitlab_toke
# Customize the mock args for this test
default_mock_args.issue_number = 5432
# Create a resolver instance with mocked token identification
resolver = IssueResolver(default_mock_args)
# Mock the issue_handler_factory method
resolver.issue_handler_factory = MagicMock(return_value=mock_handler)
# Test that the correct exception is raised
with pytest.raises(ValueError) as exc_info:
await resolver.resolve_issue()
# Verify the error message
assert 'No issues found for issue number 5432' in str(exc_info.value)
assert 'test-owner/test-repo' in str(exc_info.value)
# Verify that the handler was correctly configured and called
resolver.issue_handler_factory.assert_called_once()
mock_handler.get_converted_issues.assert_called_once_with(
issue_numbers=[5432], comment_id=None
)
mock_handler.get_converted_issues.assert_called_once_with(issue_numbers=[5432], comment_id=None)
def test_download_issues_from_gitlab():
@@ -384,14 +377,12 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
content='',
command='git config --global --add safe.directory /workspace',
),
create_cmd_output(exit_code=0, content='', command='git add -A'),
create_cmd_output(
exit_code=0,
content='git diff content',
command='git diff --no-color --cached base_commit_hash',
exit_code=0, content='', command='git add -A'
),
create_cmd_output(exit_code=0, content='git diff content', command='git diff --no-color --cached base_commit_hash'),
]
# Create a resolver instance with mocked token identification
resolver = IssueResolver(default_mock_args)
@@ -403,7 +394,7 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
@pytest.mark.asyncio
@pytest.mark.parametrize(
'test_case',
"test_case",
[
{
'name': 'successful_run',
@@ -457,13 +448,7 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
},
],
)
async def test_process_issue(
default_mock_args,
mock_gitlab_token,
mock_output_dir,
mock_prompt_template,
test_case,
):
async def test_process_issue(default_mock_args, mock_gitlab_token, mock_output_dir, mock_prompt_template, test_case):
"""Test the process_issue method with different scenarios."""
# Set up test data
issue = Issue(
@@ -497,7 +482,7 @@ async def test_process_issue(
mock_runtime = MagicMock()
mock_runtime.connect = AsyncMock()
mock_create_runtime = MagicMock(return_value=mock_runtime)
# Configure run_controller mock based on test case
mock_run_controller = AsyncMock()
if test_case.get('run_controller_raises'):
@@ -506,18 +491,16 @@ async def test_process_issue(
mock_run_controller.return_value = test_case['run_controller_return']
# Patch the necessary functions and methods
with patch(
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
), patch(
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
), patch.object(
resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}
), patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime, patch(
'openhands.resolver.resolve_issue.SandboxConfig', return_value=MagicMock()
), patch('openhands.resolver.resolve_issue.AppConfig', return_value=MagicMock()):
with patch('openhands.resolver.resolve_issue.create_runtime', mock_create_runtime), \
patch('openhands.resolver.resolve_issue.run_controller', mock_run_controller), \
patch.object(resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}), \
patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime, \
patch('openhands.resolver.resolve_issue.SandboxConfig', return_value=MagicMock()), \
patch('openhands.resolver.resolve_issue.AppConfig', return_value=MagicMock()):
# Call the process_issue method
result = await resolver.process_issue(issue, base_commit, handler_instance)
mock_create_runtime.assert_called_once()
mock_runtime.connect.assert_called_once()
mock_initialize_runtime.assert_called_once()
@@ -538,7 +521,6 @@ async def test_process_issue(
else:
handler_instance.guess_success.assert_not_called()
def test_get_instruction(mock_prompt_template, mock_followup_prompt_template):
issue = Issue(
owner='test_owner',
@@ -941,4 +923,4 @@ def test_download_issue_with_specific_comment():
if __name__ == '__main__':
pytest.main()
pytest.main()
+152 -23
View File
@@ -22,6 +22,7 @@ from openhands.events.observation import (
ErrorObservation,
)
from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.commands import CmdOutputObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
@@ -204,14 +205,11 @@ async def test_react_to_content_policy_violation(
mock_status_callback.assert_called_once_with(
'error',
'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION',
'ContentPolicyViolationError: litellm.BadRequestError: litellm.ContentPolicyViolationError: Output blocked by content filtering policy',
'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION',
)
# Verify the state was updated correctly
assert (
controller.state.last_error
== 'ContentPolicyViolationError: litellm.BadRequestError: litellm.ContentPolicyViolationError: Output blocked by content filtering policy'
)
assert controller.state.last_error == 'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION'
assert controller.state.agent_state == AgentState.ERROR
await controller.close()
@@ -275,8 +273,10 @@ async def test_run_controller_with_fatal_error(
error_observation = error_observations[0]
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'AgentStuckInLoopError'
assert error_observation.reason == 'AgentStuckInLoopError'
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
assert (
error_observation.reason == 'AgentStuckInLoopError: Agent got stuck in a loop'
)
assert len(events) == 12
@@ -356,7 +356,7 @@ async def test_run_controller_stop_with_stuck(
assert last_event['observation'] == 'agent_state_changed'
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'AgentStuckInLoopError'
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
@pytest.mark.asyncio
@@ -689,14 +689,20 @@ async def test_run_controller_max_iterations_has_metrics(
)
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
assert 'RuntimeError' in state.last_error
assert (
state.last_error
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
)
error_observations = test_event_stream.get_matching_events(
reverse=True, limit=1, event_types=(AgentStateChangedObservation)
)
assert len(error_observations) == 1
error_observation = error_observations[0]
assert 'RuntimeError' in error_observation.reason
assert (
error_observation.reason
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
)
assert (
state.metrics.accumulated_cost == 10.0 * 3
@@ -759,7 +765,7 @@ async def test_context_window_exceeded_error_handling(
# We do that by playing the role of the recall module -- subscribe to the
# event stream and respond to recall actions by inserting fake recall
# observations.
# obesrvations.
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = RecallObservation(
@@ -801,19 +807,13 @@ async def test_context_window_exceeded_error_handling(
# size (because we return a message action, which triggers a recall, which
# triggers a recall response). But if the pre/post-views are on the turn
# when we throw the context window exceeded error, we should see the
# post-step view compressed (or rather, a CondensationAction added).
# post-step view compressed.
for index, (first_view, second_view) in enumerate(
zip(step_state.views[:-1], step_state.views[1:])
):
if index == error_after:
# Verify that the CondensationAction is present in the second view (after error)
# but not in the first view (before error)
assert not any(isinstance(e, CondensationAction) for e in first_view.events)
assert any(isinstance(e, CondensationAction) for e in second_view.events)
# The length might not strictly decrease due to CondensationAction being added
assert len(first_view) == len(second_view)
assert len(first_view) > len(second_view)
else:
# Before the error, the view length should increase
assert len(first_view) < len(second_view)
# The final state's history should contain:
@@ -886,7 +886,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
def step(self, state: State):
# If the state has more than one message and we haven't errored yet,
# throw the context window exceeded error
if len(state.history) > 5 and not self.has_errored:
if len(state.history) > 3 and not self.has_errored:
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
model='',
@@ -940,7 +940,10 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
# expected reason
assert state.iteration == 5
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'RuntimeError'
assert (
state.last_error
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 5, max iteration: 5'
)
# Check that the context window exceeded error was raised during the run
assert step_state.has_errored
@@ -1014,14 +1017,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
# With the refactored system message handling, the iteration count is different
assert state.iteration == 1
assert state.agent_state == AgentState.ERROR
assert state.last_error == 'LLMContextWindowExceedError'
assert (
state.last_error
== 'LLMContextWindowExceedError: Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error'
)
error_observations = test_event_stream.get_matching_events(
reverse=True, limit=1, event_types=(AgentStateChangedObservation)
)
assert len(error_observations) == 1
error_observation = error_observations[0]
assert 'LLMContextWindowExceedError' in error_observation.reason
assert (
error_observation.reason
== 'LLMContextWindowExceedError: Conversation history longer than LLM context window limit. Consider turning on enable_history_truncation config to avoid this error'
)
# Check that the context window exceeded error was raised during the run
assert step_state.has_errored
@@ -1458,6 +1467,126 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
), 'should_step should return False for NullObservation with cause = 0'
def test_apply_conversation_window_basic(mock_event_stream, mock_agent):
"""Test that the _apply_conversation_window method correctly prunes a list of events."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_apply_conversation_window_basic',
confirmation_mode=False,
headless_mode=True,
)
# Create a sequence of events with IDs
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1
# Add agent question
agent_msg = MessageAction(
content='What task would you like me to perform?', wait_for_response=True
)
agent_msg._source = EventSource.AGENT
agent_msg._id = 2
# Add user response
user_response = MessageAction(
content='Please list all files and show me current directory',
wait_for_response=False,
)
user_response._source = EventSource.USER
user_response._id = 3
cmd1 = CmdRunAction(command='ls')
cmd1._id = 4
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
obs1._id = 5
obs1._cause = 4
cmd2 = CmdRunAction(command='pwd')
cmd2._id = 6
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=6)
obs2._id = 7
obs2._cause = 6
events = [first_msg, agent_msg, user_response, cmd1, obs1, cmd2, obs2]
# Apply truncation
truncated = controller._apply_conversation_window(events)
# Verify truncation occured
# Should keep first user message and roughly half of other events
assert (
3 <= len(truncated) < len(events)
) # First message + at least one action-observation pair
assert truncated[0] == first_msg # First message always preserved
assert controller.state.start_id == first_msg._id
# Verify pairs aren't split
for i, event in enumerate(truncated[1:]):
if isinstance(event, CmdOutputObservation):
assert any(e._id == event._cause for e in truncated[: i + 1])
def test_history_restoration_after_truncation(mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
# Create events with IDs
first_msg = MessageAction(content='Start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1
events = [first_msg]
for i in range(5):
cmd = CmdRunAction(command=f'cmd{i}')
cmd._id = i + 2
obs = CmdOutputObservation(
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
)
obs._cause = cmd._id
events.extend([cmd, obs])
# Set up initial history
controller.state.history = events.copy()
# Force truncation
controller.state.history = controller._apply_conversation_window(
controller.state.history
)
# Save state
saved_start_id = controller.state.start_id
saved_history_len = len(controller.state.history)
# Set up mock event stream for new controller
mock_event_stream.get_events.return_value = controller.state.history
# Create new controller with saved state
new_controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
new_controller.state.start_id = saved_start_id
new_controller.state.history = mock_event_stream.get_events()
# Verify restoration
assert len(new_controller.state.history) == saved_history_len
assert new_controller.state.history[0] == first_msg
assert new_controller.state.start_id == saved_start_id
def test_system_message_in_event_stream(mock_agent, test_event_stream):
"""Test that SystemMessageAction is added to event stream in AgentController."""
_ = AgentController(
-569
View File
@@ -1,569 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import AppConfig
from openhands.events import EventSource
from openhands.events.action import CmdRunAction, MessageAction, RecallAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import RecallType
from openhands.events.observation import (
CmdOutputObservation,
Observation,
RecallObservation,
)
from openhands.events.stream import EventStream
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.storage.memory import InMemoryFileStore
# Helper function to create events with sequential IDs and causes
def create_events(event_data):
events = []
# Import necessary types here to avoid repeated imports inside the loop
from openhands.events.action import CmdRunAction, RecallAction
from openhands.events.observation import CmdOutputObservation, RecallObservation
for i, data in enumerate(event_data):
event_type = data['type']
source = data.get('source', EventSource.AGENT)
kwargs = {} # Arguments for the event constructor
# Determine arguments based on event type
if event_type == RecallAction:
kwargs['query'] = data.get('query', '')
kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE)
elif event_type == RecallObservation:
kwargs['content'] = data.get('content', '')
kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE)
elif event_type == CmdRunAction:
kwargs['command'] = data.get('command', '')
elif event_type == CmdOutputObservation:
# Required args for CmdOutputObservation
kwargs['content'] = data.get('content', '')
kwargs['command'] = data.get('command', '')
# Pass command_id via kwargs if present in data
if 'command_id' in data:
kwargs['command_id'] = data['command_id']
# Pass metadata if present
if 'metadata' in data:
kwargs['metadata'] = data['metadata']
else: # Default for MessageAction, SystemMessageAction, etc.
kwargs['content'] = data.get('content', '')
# Instantiate the event
event = event_type(**kwargs)
# Assign internal attributes AFTER instantiation
event._id = i + 1 # Assign sequential IDs starting from 1
event._source = source
# Assign _cause using cause_id from data, AFTER event._id is set
if 'cause_id' in data:
event._cause = data['cause_id']
# If command_id was NOT passed via kwargs but cause_id exists,
# pass cause_id as command_id to __init__ via kwargs for legacy handling
# This needs to happen *before* instantiation if we want __init__ to handle it
# Let's adjust the logic slightly:
if event_type == CmdOutputObservation:
if 'command_id' not in kwargs and 'cause_id' in data:
kwargs['command_id'] = data['cause_id'] # Let __init__ handle this
# Re-instantiate if we added command_id
if 'command_id' in kwargs and event.command_id != kwargs['command_id']:
event = event_type(**kwargs)
event._id = i + 1
event._source = source
# Now assign _cause if it exists in data, after potential re-instantiation
if 'cause_id' in data:
event._cause = data['cause_id']
events.append(event)
return events
@pytest.fixture
def controller_fixture():
mock_agent = MagicMock(spec=Agent)
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = AppConfig().get_llm_config()
mock_agent.config = AppConfig().get_agent_config('CodeActAgent')
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.sid = 'test_sid'
mock_event_stream.file_store = InMemoryFileStore({})
# Ensure get_latest_event_id returns an integer
mock_event_stream.get_latest_event_id.return_value = -1
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_sid',
)
controller.state = State(session_id='test_sid')
# Mock _first_user_message directly on the instance
mock_first_user_message = MagicMock(spec=MessageAction)
controller._first_user_message = MagicMock(return_value=mock_first_user_message)
return controller, mock_first_user_message
# =============================================
# Test Cases for _apply_conversation_window
# =============================================
def test_basic_truncation(controller_fixture):
controller, mock_first_user_message = controller_fixture
controller.state.history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
{'type': CmdRunAction, 'command': 'ls'}, # 5
{
'type': CmdOutputObservation,
'content': 'file1',
'command': 'ls',
'cause_id': 5,
}, # 6
{'type': CmdRunAction, 'command': 'pwd'}, # 7
{
'type': CmdOutputObservation,
'content': '/dir',
'command': 'pwd',
'cause_id': 7,
}, # 8
{'type': CmdRunAction, 'command': 'cat file1'}, # 9
{
'type': CmdOutputObservation,
'content': 'content',
'command': 'cat file1',
'cause_id': 9,
}, # 10
]
)
mock_first_user_message.id = 2 # Set the ID of the mocked first user message
# Calculation (RecallAction now essential):
# History len = 10
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 10 - 4 = 6
# num_recent_to_keep = max(1, 6 // 2) = 3
# slice_start_index = 10 - 3 = 7
# recent_events_slice = history[7:] = [obs2(8), cmd3(9), obs3(10)]
# Validation: remove leading obs2(8). validated_slice = [cmd3(9), obs3(10)]
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd3(9), obs3(10)]
# Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 6
expected_ids = [1, 2, 3, 4, 9, 10]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
# Check no dangling observations at the start of the recent slice part
# The first event of the validated slice is cmd3(9)
assert not isinstance(truncated_events[4], Observation) # Index adjusted
def test_no_system_message(controller_fixture):
controller, mock_first_user_message = controller_fixture
controller.state.history = create_events(
[
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 1
{'type': RecallAction, 'query': 'User Task 1'}, # 2
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 2}, # 3
{'type': CmdRunAction, 'command': 'ls'}, # 4
{
'type': CmdOutputObservation,
'content': 'file1',
'command': 'ls',
'cause_id': 4,
}, # 5
{'type': CmdRunAction, 'command': 'pwd'}, # 6
{
'type': CmdOutputObservation,
'content': '/dir',
'command': 'pwd',
'cause_id': 6,
}, # 7
{'type': CmdRunAction, 'command': 'cat file1'}, # 8
{
'type': CmdOutputObservation,
'content': 'content',
'command': 'cat file1',
'cause_id': 8,
}, # 9
]
)
mock_first_user_message.id = 1
# Calculation (RecallAction now essential):
# History len = 9
# Essentials = [user(1), recall_act(2), recall_obs(3)] (len=3)
# Non-essential count = 9 - 3 = 6
# num_recent_to_keep = max(1, 6 // 2) = 3
# slice_start_index = 9 - 3 = 6
# recent_events_slice = history[6:] = [obs2(7), cmd3(8), obs3(9)]
# Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)]
# Final = essentials + validated_slice = [user(1), recall_act(2), recall_obs(3), cmd3(8), obs3(9)]
# Expected IDs: [1, 2, 3, 8, 9]. Length 5.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 5
expected_ids = [1, 2, 3, 8, 9]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
def test_no_recall_observation(controller_fixture):
controller, mock_first_user_message = controller_fixture
controller.state.history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3 (Recall Action exists)
# Recall Observation is missing
{'type': CmdRunAction, 'command': 'ls'}, # 4
{
'type': CmdOutputObservation,
'content': 'file1',
'command': 'ls',
'cause_id': 4,
}, # 5
{'type': CmdRunAction, 'command': 'pwd'}, # 6
{
'type': CmdOutputObservation,
'content': '/dir',
'command': 'pwd',
'cause_id': 6,
}, # 7
{'type': CmdRunAction, 'command': 'cat file1'}, # 8
{
'type': CmdOutputObservation,
'content': 'content',
'command': 'cat file1',
'cause_id': 8,
}, # 9
]
)
mock_first_user_message.id = 2
# Calculation (RecallAction essential only if RecallObs exists):
# History len = 9
# Essentials = [sys(1), user(2)] (len=2) - RecallObs missing, so RecallAction not essential here
# Non-essential count = 9 - 2 = 7
# num_recent_to_keep = max(1, 7 // 2) = 3
# slice_start_index = 9 - 3 = 6
# recent_events_slice = history[6:] = [obs2(7), cmd3(8), obs3(9)]
# Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)]
# Final = essentials + validated_slice = [sys(1), user(2), recall_action(3), cmd_cat(8), obs_cat(9)]
# Expected IDs: [1, 2, 3, 8, 9]. Length 5.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 5
expected_ids = [1, 2, 3, 8, 9]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
def test_short_history_no_truncation(controller_fixture):
controller, mock_first_user_message = controller_fixture
history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
{'type': CmdRunAction, 'command': 'ls'}, # 5
{
'type': CmdOutputObservation,
'content': 'file1',
'command': 'ls',
'cause_id': 5,
}, # 6
]
)
controller.state.history = history
mock_first_user_message.id = 2
# Calculation (RecallAction now essential):
# History len = 6
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 6 - 4 = 2
# num_recent_to_keep = max(1, 2 // 2) = 1
# slice_start_index = 6 - 1 = 5
# recent_events_slice = history[5:] = [obs1(6)]
# Validation: remove leading obs1(6). validated_slice = []
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
# Expected IDs: [1, 2, 3, 4]. Length 4.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 4
expected_ids = [1, 2, 3, 4]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
def test_only_essential_events(controller_fixture):
controller, mock_first_user_message = controller_fixture
history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
]
)
controller.state.history = history
mock_first_user_message.id = 2
# Calculation (RecallAction now essential):
# History len = 4
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 4 - 4 = 0
# num_recent_to_keep = max(1, 0 // 2) = 1
# slice_start_index = 4 - 1 = 3
# recent_events_slice = history[3:] = [recall_obs(4)]
# Validation: remove leading recall_obs(4). validated_slice = []
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
# Expected IDs: [1, 2, 3, 4]. Length 4.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 4
expected_ids = [1, 2, 3, 4]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
def test_dangling_observations_at_cut_point(controller_fixture):
controller, mock_first_user_message = controller_fixture
history_forced_dangle = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
# --- Slice calculation should start here ---
{
'type': CmdOutputObservation,
'content': 'dangle1',
'command': 'cmd_unknown',
}, # 5 (Dangling)
{
'type': CmdOutputObservation,
'content': 'dangle2',
'command': 'cmd_unknown',
}, # 6 (Dangling)
{'type': CmdRunAction, 'command': 'cmd1'}, # 7
{
'type': CmdOutputObservation,
'content': 'obs1',
'command': 'cmd1',
'cause_id': 7,
}, # 8
{'type': CmdRunAction, 'command': 'cmd2'}, # 9
{
'type': CmdOutputObservation,
'content': 'obs2',
'command': 'cmd2',
'cause_id': 9,
}, # 10
]
) # 10 events total
controller.state.history = history_forced_dangle
mock_first_user_message.id = 2
# Calculation (RecallAction now essential):
# History len = 10
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 10 - 4 = 6
# num_recent_to_keep = max(1, 6 // 2) = 3
# slice_start_index = 10 - 3 = 7
# recent_events_slice = history[7:] = [obs1(8), cmd2(9), obs2(10)]
# Validation: remove leading obs1(8). validated_slice = [cmd2(9), obs2(10)]
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd2(9), obs2(10)]
# Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 6
expected_ids = [1, 2, 3, 4, 9, 10]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
# Verify dangling observations 5 and 6 were removed (implicitly by slice start and validation)
def test_only_dangling_observations_in_recent_slice(controller_fixture):
controller, mock_first_user_message = controller_fixture
history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
# --- Slice calculation should start here ---
{
'type': CmdOutputObservation,
'content': 'dangle1',
'command': 'cmd_unknown',
}, # 5 (Dangling)
{
'type': CmdOutputObservation,
'content': 'dangle2',
'command': 'cmd_unknown',
}, # 6 (Dangling)
]
) # 6 events total
controller.state.history = history
mock_first_user_message.id = 2
# Calculation (RecallAction now essential):
# History len = 6
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 6 - 4 = 2
# num_recent_to_keep = max(1, 2 // 2) = 1
# slice_start_index = 6 - 1 = 5
# recent_events_slice = history[5:] = [dangle2(6)]
# Validation: remove leading dangle2(6). validated_slice = [] (Corrected based on user feedback/bugfix)
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
# Expected IDs: [1, 2, 3, 4]. Length 4.
with patch(
'openhands.controller.agent_controller.logger.warning'
) as mock_log_warning:
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 4
expected_ids = [1, 2, 3, 4]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
# Verify dangling observations 5 and 6 were removed
# Check that the specific warning was logged exactly once
assert mock_log_warning.call_count == 1
# Check the essential parts of the arguments, allowing for variations like stacklevel
call_args, call_kwargs = mock_log_warning.call_args
expected_message_substring = 'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.'
assert expected_message_substring in call_args[0]
assert 'extra' in call_kwargs
assert call_kwargs['extra'].get('session_id') == 'test_sid'
def test_empty_history(controller_fixture):
controller, _ = controller_fixture
controller.state.history = []
truncated_events = controller._apply_conversation_window()
assert truncated_events == []
def test_multiple_user_messages(controller_fixture):
controller, mock_first_user_message = controller_fixture
history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2 (First)
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{
'type': RecallObservation,
'content': 'Recall result 1',
'cause_id': 3,
}, # 4
{'type': CmdRunAction, 'command': 'cmd1'}, # 5
{
'type': CmdOutputObservation,
'content': 'obs1',
'command': 'cmd1',
'cause_id': 5,
}, # 6
{
'type': MessageAction,
'content': 'User Task 2',
'source': EventSource.USER,
}, # 7 (Second)
{'type': RecallAction, 'query': 'User Task 2'}, # 8
{
'type': RecallObservation,
'content': 'Recall result 2',
'cause_id': 8,
}, # 9
{'type': CmdRunAction, 'command': 'cmd2'}, # 10
{
'type': CmdOutputObservation,
'content': 'obs2',
'command': 'cmd2',
'cause_id': 10,
}, # 11
]
) # 11 events total
controller.state.history = history
mock_first_user_message.id = 2 # Explicitly set the first user message ID
# Calculation (RecallAction now essential):
# History len = 11
# Essentials = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] (len=4)
# Non-essential count = 11 - 4 = 7
# num_recent_to_keep = max(1, 7 // 2) = 3
# slice_start_index = 11 - 3 = 8
# recent_events_slice = history[8:] = [recall_obs2(9), cmd2(10), obs2(11)]
# Validation: remove leading recall_obs2(9). validated_slice = [cmd2(10), obs2(11)]
# Final = essentials + validated_slice = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] + [cmd2(10), obs2(11)]
# Expected IDs: [1, 2, 3, 4, 10, 11]. Length 6.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 6
expected_ids = [1, 2, 3, 4, 10, 11]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
# Verify the second user message (ID 7) was NOT kept
assert not any(event.id == 7 for event in truncated_events)
# Verify the first user message (ID 2) is present
assert any(event.id == 2 for event in truncated_events)
+17 -65
View File
@@ -44,7 +44,6 @@ from openhands.events.observation.commands import (
)
from openhands.events.tool import ToolCallMetadata
from openhands.llm.llm import LLM
from openhands.memory.condenser import View
@pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
@@ -98,12 +97,6 @@ def test_reset(agent):
action._source = EventSource.AGENT
agent.pending_actions.append(action)
# Create a mock state with initial user message
mock_state = Mock(spec=State)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
mock_state.history = [initial_user_message]
# Reset
agent.reset()
@@ -117,14 +110,8 @@ def test_step_with_pending_actions(agent):
pending_action._source = EventSource.AGENT
agent.pending_actions.append(pending_action)
# Create a mock state with initial user message
mock_state = Mock(spec=State)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
mock_state.history = [initial_user_message]
# Step should return the pending action
result = agent.step(mock_state)
result = agent.step(Mock())
assert result == pending_action
assert len(agent.pending_actions) == 0
@@ -273,11 +260,6 @@ def test_step_with_no_pending_actions(mock_state: State):
mock_state.latest_user_message_llm_metrics = None
mock_state.latest_user_message_tool_call_metadata = None
# Add initial user message to history
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
mock_state.history = [initial_user_message]
action = agent.step(mock_state)
assert isinstance(action, MessageAction)
assert action.content == 'Task completed'
@@ -348,56 +330,42 @@ def test_mismatched_tool_call_events_and_auto_add_system_message(
)
action = CmdRunAction('foo')
action._source = EventSource.AGENT
action._source = 'agent'
action.tool_call_metadata = tool_call_metadata
observation = CmdOutputObservation(content='', command_id=0, command='foo')
observation.tool_call_metadata = tool_call_metadata
# Add initial user message
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
# When both events are provided, the agent should get three messages:
# 1. The system message (added automatically for backward compatibility)
# 2. The action message
# 3. The observation message
mock_state.history = [initial_user_message, action, observation]
messages = agent._get_messages(mock_state.history, initial_user_message)
assert len(messages) == 4 # System + initial user + action + observation
mock_state.history = [action, observation]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 3
assert messages[0].role == 'system' # First message should be the system message
assert (
messages[1].role == 'user'
) # Second message should be the initial user message
assert messages[2].role == 'assistant' # Third message should be the action
assert messages[3].role == 'tool' # Fourth message should be the observation
assert messages[1].role == 'assistant' # Second message should be the action
assert messages[2].role == 'tool' # Third message should be the observation
# The same should hold if the events are presented out-of-order
mock_state.history = [initial_user_message, observation, action]
messages = agent._get_messages(mock_state.history, initial_user_message)
assert len(messages) == 4
mock_state.history = [observation, action]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 3
assert messages[0].role == 'system' # First message should be the system message
assert (
messages[1].role == 'user'
) # Second message should be the initial user message
# If only one of the two events is present, then we should just get the system message
# plus any valid message from the event
mock_state.history = [initial_user_message, action]
messages = agent._get_messages(mock_state.history, initial_user_message)
mock_state.history = [action]
messages = agent._get_messages(mock_state.history)
assert (
len(messages) == 2
) # System + initial user message, action is waiting for its observation
len(messages) == 1
) # Only system message, action is waiting for its observation
assert messages[0].role == 'system'
assert messages[1].role == 'user'
mock_state.history = [initial_user_message, observation]
messages = agent._get_messages(mock_state.history, initial_user_message)
assert (
len(messages) == 2
) # System + initial user message, observation has no matching action
mock_state.history = [observation]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 1 # Only system message, observation has no matching action
assert messages[0].role == 'system'
assert messages[1].role == 'user'
def test_grep_tool():
@@ -502,19 +470,3 @@ def test_get_system_message():
assert len(result.tools) > 0
assert any(tool['function']['name'] == 'execute_bash' for tool in result.tools)
assert result._source == EventSource.AGENT
def test_step_raises_error_if_no_initial_user_message(
agent: CodeActAgent, mock_state: State
):
"""Tests that step raises ValueError if the initial user message is not found."""
# Ensure history does NOT contain a user MessageAction
assistant_message = MessageAction(content='Assistant message')
assistant_message._source = EventSource.AGENT
mock_state.history = [assistant_message]
# Mock the condenser to return the history as is
agent.condenser = Mock()
agent.condenser.condensed_history.return_value = View(events=mock_state.history)
with pytest.raises(ValueError, match='Initial user message not found'):
agent.step(mock_state)
-25
View File
@@ -8,7 +8,6 @@ from openhands.core.cli_commands import (
handle_help_command,
handle_init_command,
handle_new_command,
handle_resume_command,
handle_settings_command,
handle_status_command,
)
@@ -462,27 +461,3 @@ class TestHandleSettingsCommand:
# Verify correct behavior
mock_display_settings.assert_called_once_with(config)
mock_cli_confirm.assert_called_once()
class TestHandleResumeCommand:
@pytest.mark.asyncio
async def test_handle_resume_command(self):
"""Test that handle_resume_command adds the 'continue' message to the event stream."""
# Create a mock event stream
event_stream = MagicMock(spec=EventStream)
# Call the function
close_repl, new_session_requested = await handle_resume_command(event_stream)
# Check that the event stream add_event was called with the correct message action
event_stream.add_event.assert_called_once()
args, kwargs = event_stream.add_event.call_args
message_action, source = args
assert isinstance(message_action, MessageAction)
assert message_action.content == 'continue'
assert source == EventSource.USER
# Check the return values
assert close_repl is True
assert new_session_requested is False
-325
View File
@@ -1,325 +0,0 @@
import asyncio
from unittest.mock import MagicMock, call, patch
import pytest
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.keys import Keys
from openhands.core.cli_tui import process_agent_pause
from openhands.core.schema import AgentState
from openhands.events import EventSource
from openhands.events.action import ChangeAgentStateAction
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.stream import EventStream
class TestProcessAgentPause:
@pytest.mark.asyncio
@patch('openhands.core.cli_tui.create_input')
@patch('openhands.core.cli_tui.print_formatted_text')
async def test_process_agent_pause_ctrl_p(self, mock_print, mock_create_input):
"""Test that process_agent_pause sets the done event when Ctrl+P is pressed."""
# Create the done event
done = asyncio.Event()
# Set up the mock input
mock_input = MagicMock()
mock_create_input.return_value = mock_input
# Mock the context managers
mock_raw_mode = MagicMock()
mock_input.raw_mode.return_value = mock_raw_mode
mock_raw_mode.__enter__ = MagicMock()
mock_raw_mode.__exit__ = MagicMock()
mock_attach = MagicMock()
mock_input.attach.return_value = mock_attach
mock_attach.__enter__ = MagicMock()
mock_attach.__exit__ = MagicMock()
# Capture the keys_ready function
keys_ready_func = None
def fake_attach(callback):
nonlocal keys_ready_func
keys_ready_func = callback
return mock_attach
mock_input.attach.side_effect = fake_attach
# Create a task to run process_agent_pause
task = asyncio.create_task(process_agent_pause(done))
# Give it a moment to start and capture the callback
await asyncio.sleep(0.1)
# Make sure we captured the callback
assert keys_ready_func is not None
# Create a key press that simulates Ctrl+P
key_press = MagicMock()
key_press.key = Keys.ControlP
mock_input.read_keys.return_value = [key_press]
# Manually call the callback to simulate key press
keys_ready_func()
# Verify done was set
assert done.is_set()
# Verify print was called with the pause message
assert mock_print.call_count == 2
assert mock_print.call_args_list[0] == call('')
# Check that the second call contains the pause message HTML
second_call = mock_print.call_args_list[1][0][0]
assert isinstance(second_call, HTML)
assert 'Pausing the agent' in str(second_call)
# Cancel the task
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
class TestCliPauseResumeInRunSession:
@pytest.mark.asyncio
async def test_on_event_async_pause_processing(self):
"""Test that on_event_async processes the pause event when is_paused is set."""
# Create a mock event
event = MagicMock()
# Create mock dependencies
event_stream = MagicMock()
is_paused = asyncio.Event()
reload_microagents = False
config = MagicMock()
# Patch the display_event function
with patch('openhands.core.cli.display_event') as mock_display_event, patch(
'openhands.core.cli.update_usage_metrics'
) as mock_update_metrics:
# Create a closure to capture the current context
async def test_func():
# Set the pause event
is_paused.set()
# Create a context similar to run_session to call on_event_async
# We're creating a function that mimics the environment of on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents, is_paused
mock_display_event(event, config)
mock_update_metrics(event, usage_metrics=MagicMock())
# Pause the agent if the pause event is set (through Ctrl-P)
if is_paused.is_set():
event_stream.add_event(
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
is_paused.clear()
# Call our test function
await on_event_async_test(event)
# Check that the event_stream.add_event was called with the correct action
event_stream.add_event.assert_called_once()
args, kwargs = event_stream.add_event.call_args
action, source = args
assert isinstance(action, ChangeAgentStateAction)
assert action.agent_state == AgentState.PAUSED
assert source == EventSource.USER
# Check that is_paused was cleared
assert not is_paused.is_set()
# Run the test function
await test_func()
class TestCliCommandsPauseResume:
@pytest.mark.asyncio
@patch('openhands.core.cli_commands.handle_resume_command')
async def test_handle_commands_resume(self, mock_handle_resume):
"""Test that the handle_commands function properly calls handle_resume_command."""
# Import the handle_commands function
from openhands.core.cli_commands import handle_commands
# Set up mocks
event_stream = MagicMock(spec=EventStream)
usage_metrics = MagicMock()
sid = 'test-session-id'
config = MagicMock()
current_dir = '/test/dir'
settings_store = MagicMock()
# Set the return value for handle_resume_command
mock_handle_resume.return_value = (False, False)
# Call handle_commands with the resume command
close_repl, reload_microagents, new_session_requested = await handle_commands(
'/resume',
event_stream,
usage_metrics,
sid,
config,
current_dir,
settings_store,
)
# Check that handle_resume_command was called with the correct arguments
mock_handle_resume.assert_called_once_with(event_stream)
# Check the return values
assert close_repl is False
assert reload_microagents is False
assert new_session_requested is False
class TestAgentStatePauseResume:
@pytest.mark.asyncio
@patch('openhands.core.cli.display_agent_running_message')
@patch('openhands.core.cli.process_agent_pause')
async def test_agent_running_enables_pause(
self, mock_process_agent_pause, mock_display_message
):
"""Test that when the agent is running, pause functionality is enabled."""
# Create mock dependencies
event = MagicMock()
# AgentStateChangedObservation requires a content parameter
event.observation = AgentStateChangedObservation(
agent_state=AgentState.RUNNING, content='Agent state changed to RUNNING'
)
# Create a context similar to run_session to call on_event_async
loop = MagicMock()
is_paused = asyncio.Event()
config = MagicMock()
config.security.confirmation_mode = False
# Create a closure to capture the current context
async def test_func():
# Call our simplified on_event_async
async def on_event_async_test(event):
if isinstance(event.observation, AgentStateChangedObservation):
if event.observation.agent_state == AgentState.RUNNING:
# Enable pause/resume functionality only if the confirmation mode is disabled
if not config.security.confirmation_mode:
mock_display_message()
loop.create_task(mock_process_agent_pause(is_paused))
# Call the function
await on_event_async_test(event)
# Check that the message was displayed
mock_display_message.assert_called_once()
# Check that process_agent_pause was called with the right arguments
mock_process_agent_pause.assert_called_once_with(is_paused)
# Check that loop.create_task was called
loop.create_task.assert_called_once()
# Run the test function
await test_func()
@pytest.mark.asyncio
@patch('openhands.core.cli.display_event')
@patch('openhands.core.cli.update_usage_metrics')
async def test_pause_event_changes_agent_state(
self, mock_update_metrics, mock_display_event
):
"""Test that when is_paused is set, a PAUSED state change event is added to the stream."""
# Create mock dependencies
event = MagicMock()
event_stream = MagicMock()
is_paused = asyncio.Event()
config = MagicMock()
reload_microagents = False
# Set the pause event
is_paused.set()
# Create a closure to capture the current context
async def test_func():
# Create a context similar to run_session to call on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents
mock_display_event(event, config)
mock_update_metrics(event, MagicMock())
# Pause the agent if the pause event is set (through Ctrl-P)
if is_paused.is_set():
event_stream.add_event(
ChangeAgentStateAction(AgentState.PAUSED),
EventSource.USER,
)
is_paused.clear()
# Call the function
await on_event_async_test(event)
# Check that the event_stream.add_event was called with the correct action
event_stream.add_event.assert_called_once()
args, kwargs = event_stream.add_event.call_args
action, source = args
assert isinstance(action, ChangeAgentStateAction)
assert action.agent_state == AgentState.PAUSED
assert source == EventSource.USER
# Check that is_paused was cleared
assert not is_paused.is_set()
# Run the test
await test_func()
@pytest.mark.asyncio
async def test_paused_agent_awaits_input(self):
"""Test that when the agent is paused, it awaits user input."""
# Create mock dependencies
event = MagicMock()
# AgentStateChangedObservation requires a content parameter
event.observation = AgentStateChangedObservation(
agent_state=AgentState.PAUSED, content='Agent state changed to PAUSED'
)
reload_microagents = False
memory = MagicMock()
runtime = MagicMock()
prompt_task = MagicMock()
# Create a closure to capture the current context
async def test_func():
# Create a simplified version of on_event_async
async def on_event_async_test(event):
nonlocal reload_microagents, prompt_task
if isinstance(event.observation, AgentStateChangedObservation):
if event.observation.agent_state in [
AgentState.AWAITING_USER_INPUT,
AgentState.FINISHED,
AgentState.PAUSED,
]:
# Reload microagents after initialization of repo.md
if reload_microagents:
microagents = runtime.get_microagents_from_selected_repo(
None
)
memory.load_user_workspace_microagents(microagents)
reload_microagents = False
# Since prompt_for_next_task is a nested function in cli.py,
# we'll just check that we've reached this code path
prompt_task = 'Prompt for next task would be called here'
# Call the function
await on_event_async_test(event)
# Check that we reached the code path where prompt_for_next_task would be called
assert prompt_task == 'Prompt for next task would be called here'
# Run the test
await test_func()
+5 -1
View File
@@ -1,3 +1,4 @@
import asyncio
from unittest.mock import MagicMock, Mock, patch
from openhands.core.cli_tui import (
@@ -51,9 +52,12 @@ class TestDisplayFunctions:
@patch('openhands.core.cli_tui.print_formatted_text')
def test_display_banner(self, mock_print):
# Create a mock loaded event
is_loaded = asyncio.Event()
is_loaded.set()
session_id = 'test-session-id'
display_banner(session_id)
display_banner(session_id, is_loaded)
# Verify banner calls
assert mock_print.call_count >= 3
+48 -464
View File
@@ -100,7 +100,6 @@ def test_process_events_with_message_action(conversation_memory):
# Process events
messages = conversation_memory.process_events(
condensed_history=[system_message, user_message, assistant_message],
initial_user_action=user_message,
max_message_chars=None,
vision_is_active=False,
)
@@ -109,178 +108,10 @@ def test_process_events_with_message_action(conversation_memory):
assert len(messages) == 3
assert messages[0].role == 'system'
assert messages[0].content[0].text == 'System message'
# Test cases for _ensure_system_message
def test_ensure_system_message_adds_if_missing(conversation_memory):
"""Test that _ensure_system_message adds a system message if none exists."""
user_message = MessageAction(content='User message')
user_message._source = EventSource.USER
events = [user_message]
conversation_memory._ensure_system_message(events)
assert len(events) == 2
assert isinstance(events[0], SystemMessageAction)
assert events[0].content == 'System message' # From fixture
assert isinstance(events[1], MessageAction) # Original event is still there
def test_ensure_system_message_does_nothing_if_present(conversation_memory):
"""Test that _ensure_system_message does nothing if a system message is already present."""
original_system_message = SystemMessageAction(content='Existing system message')
user_message = MessageAction(content='User message')
user_message._source = EventSource.USER
events = [
original_system_message,
user_message,
]
original_events = list(events) # Copy before modification
conversation_memory._ensure_system_message(events)
assert events == original_events # List should be unchanged
# Test cases for _ensure_initial_user_message
@pytest.fixture
def initial_user_action():
msg = MessageAction(content='Initial User Message')
msg._source = EventSource.USER
return msg
def test_ensure_initial_user_message_adds_if_only_system(
conversation_memory, initial_user_action
):
"""Test adding the initial user message when only the system message exists."""
system_message = SystemMessageAction(content='System')
system_message._source = EventSource.AGENT
events = [system_message]
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert len(events) == 2
assert events[0] == system_message
assert events[1] == initial_user_action
def test_ensure_initial_user_message_correct_already_present(
conversation_memory, initial_user_action
):
"""Test that nothing changes if the correct initial user message is at index 1."""
system_message = SystemMessageAction(content='System')
agent_message = MessageAction(content='Assistant')
agent_message._source = EventSource.USER
events = [
system_message,
initial_user_action,
agent_message,
]
original_events = list(events)
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert events == original_events
def test_ensure_initial_user_message_incorrect_at_index_1(
conversation_memory, initial_user_action
):
"""Test inserting the correct initial user message when an incorrect message is at index 1."""
system_message = SystemMessageAction(content='System')
incorrect_second_message = MessageAction(content='Assistant')
incorrect_second_message._source = EventSource.AGENT
events = [system_message, incorrect_second_message]
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert len(events) == 3
assert events[0] == system_message
assert events[1] == initial_user_action # Correct one inserted
assert events[2] == incorrect_second_message # Original second message shifted
def test_ensure_initial_user_message_correct_present_later(
conversation_memory, initial_user_action
):
"""Test inserting the correct initial user message at index 1 even if it exists later."""
system_message = SystemMessageAction(content='System')
incorrect_second_message = MessageAction(content='Assistant')
incorrect_second_message._source = EventSource.AGENT
# Correct initial message is present, but later in the list
events = [system_message, incorrect_second_message]
conversation_memory._ensure_system_message(events)
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert len(events) == 3 # Should still insert at index 1, not remove the later one
assert events[0] == system_message
assert events[1] == initial_user_action # Correct one inserted at index 1
assert events[2] == incorrect_second_message # Original second message shifted
# The duplicate initial_user_action originally at index 2 is now at index 3 (implicitly tested by length and content)
def test_ensure_initial_user_message_different_user_msg_at_index_1(
conversation_memory, initial_user_action
):
"""Test inserting the correct initial user message when a *different* user message is at index 1."""
system_message = SystemMessageAction(content='System')
different_user_message = MessageAction(content='Different User Message')
different_user_message._source = EventSource.USER
events = [system_message, different_user_message]
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert len(events) == 2
assert events[0] == system_message
assert events[1] == different_user_message # Original second message remains
def test_ensure_initial_user_message_different_user_msg_at_index_1_and_orphaned_obs(
conversation_memory, initial_user_action
):
"""
Test process_events when an incorrect user message is at index 1 AND
an orphaned observation (with tool_call_metadata but no matching action) exists.
Expect: System msg, CORRECT initial user msg, the incorrect user msg (shifted).
The orphaned observation should be filtered out.
"""
system_message = SystemMessageAction(content='System')
different_user_message = MessageAction(content='Different User Message')
different_user_message._source = EventSource.USER
# Create an orphaned observation (no matching action/tool call request will exist)
# Use a dictionary that mimics ModelResponse structure to satisfy Pydantic
mock_response = {
'id': 'mock_response_id',
'choices': [{'message': {'content': None, 'tool_calls': []}}],
'created': 0,
'model': '',
'object': '',
'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0},
}
orphaned_obs = CmdOutputObservation(
command='orphan_cmd',
content='Orphaned output',
command_id=99,
exit_code=0,
)
orphaned_obs.tool_call_metadata = ToolCallMetadata(
tool_call_id='orphan_call_id',
function_name='execute_bash',
model_response=mock_response,
total_calls_in_response=1,
)
# Initial events list: system, wrong user message, orphaned observation
events = [system_message, different_user_message, orphaned_obs]
# Call the main process_events method
messages = conversation_memory.process_events(
condensed_history=events,
initial_user_action=initial_user_action, # Provide the *correct* initial action
max_message_chars=None,
vision_is_active=False,
)
# Assertions on the final messages list
assert len(messages) == 2
# 1. System message should be first
assert messages[0].role == 'system'
assert messages[0].content[0].text == 'System'
# 2. The different user message should be left at index 1
assert messages[1].role == 'user'
assert messages[1].content[0].text == different_user_message.content
# Implicitly assert that the orphaned_obs was filtered out by checking the length (2)
assert messages[1].content[0].text == 'Hello'
assert messages[2].role == 'assistant'
assert messages[2].content[0].text == 'Hi there'
def test_process_events_with_cmd_output_observation(conversation_memory):
@@ -294,17 +125,14 @@ def test_process_events_with_cmd_output_observation(conversation_memory):
),
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -320,17 +148,14 @@ def test_process_events_with_ipython_run_cell_observation(conversation_memory):
content='IPython output\n![image](data:image/png;base64,ABC123)',
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -347,17 +172,14 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
content='Content', outputs={'content': 'Delegated agent output'}
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -367,17 +189,14 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
def test_process_events_with_error_observation(conversation_memory):
obs = ErrorObservation('Error message')
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -388,13 +207,10 @@ def test_process_events_with_error_observation(conversation_memory):
def test_process_events_with_unknown_observation(conversation_memory):
# Create a mock that inherits from Event but not Action or Observation
obs = Mock(spec=Event)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
with pytest.raises(ValueError, match='Unknown event type'):
conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
@@ -410,17 +226,14 @@ def test_process_events_with_file_edit_observation(conversation_memory):
impl_source=FileEditSource.LLM_BASED_EDIT,
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -434,21 +247,18 @@ def test_process_events_with_file_read_observation(conversation_memory):
impl_source=FileReadSource.DEFAULT,
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == '\n\nFile content'
assert result.content[0].text == 'File content'
def test_process_events_with_browser_output_observation(conversation_memory):
@@ -460,17 +270,14 @@ def test_process_events_with_browser_output_observation(conversation_memory):
error=False,
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -480,17 +287,14 @@ def test_process_events_with_browser_output_observation(conversation_memory):
def test_process_events_with_user_reject_observation(conversation_memory):
obs = UserRejectObservation('Action rejected')
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -513,17 +317,14 @@ def test_process_events_with_empty_environment_info(conversation_memory):
content='Retrieved environment info',
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[empty_obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
# Should only contain system message and initial user message
assert len(messages) == 2
# Should only contain no messages except system message
assert len(messages) == 1
# Verify that build_workspace_context was NOT called since all input values were empty
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
@@ -547,20 +348,14 @@ def test_process_events_with_function_calling_observation(conversation_memory):
model_response=mock_response,
total_calls_in_response=1,
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# No direct message when using function calling
assert (
len(messages) == 2
) # should be no messages except system message and initial user message
assert len(messages) == 1 # should be no messages except system message
def test_process_events_with_message_action_with_image(conversation_memory):
@@ -570,18 +365,14 @@ def test_process_events_with_message_action_with_image(conversation_memory):
)
action._source = EventSource.AGENT
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[action],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=True,
)
assert len(messages) == 3
result = messages[2]
assert len(messages) == 2
result = messages[1]
assert result.role == 'assistant'
assert len(result.content) == 2
assert isinstance(result.content[0], TextContent)
@@ -594,18 +385,14 @@ def test_process_events_with_user_cmd_action(conversation_memory):
action = CmdRunAction(command='ls -l')
action._source = EventSource.USER
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[action],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3
result = messages[2]
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -631,18 +418,14 @@ def test_process_events_with_agent_finish_action_with_tool_metadata(
total_calls_in_response=1,
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[action],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3
result = messages[2]
assert len(messages) == 2
result = messages[1]
assert result.role == 'assistant'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -678,22 +461,18 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
content='Retrieved environment info',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3
result = messages[2]
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == '\n\nFormatted repository and runtime info'
assert result.content[0].text == 'Formatted repository and runtime info'
# Verify the prompt_manager was called with the correct parameters
conversation_memory.prompt_manager.build_workspace_context.assert_called_once()
@@ -737,18 +516,14 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
content='Retrieved knowledge from microagents',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 3 # System + Initial User + Result
result = messages[2] # Result is now at index 2
assert len(messages) == 2
result = messages[1]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -784,18 +559,14 @@ def test_process_events_with_microagent_observation_extensions_disabled(
content='Retrieved environment info',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# When prompt extensions are disabled, the RecallObservation should be ignored
assert len(messages) == 2 # System + Initial User
assert len(messages) == 1 # should be no messages except system message
# Verify the prompt_manager was not called
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
@@ -810,18 +581,14 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory):
content='Retrieved knowledge from microagents',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# The implementation returns an empty string and it doesn't creates a message
assert len(messages) == 2 # System + Initial User
assert len(messages) == 1 # should be no messages except system message
# When there are no triggered agents, build_microagent_info is not called
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
@@ -1026,23 +793,19 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
content='Third retrieval',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs1, obs2, obs3],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# Verify that only the first occurrence of content for each agent is included
assert len(messages) == 3 # System + Initial User + Result
# Result is now at index 2
assert len(messages) == 2 # with system message
# First microagent should include all agents since they appear here first
assert 'Image best practices v1' in messages[2].content[0].text
assert 'Git best practices v1' in messages[2].content[0].text
assert 'Python best practices v1' in messages[2].content[0].text
assert 'Image best practices v1' in messages[1].content[0].text
assert 'Git best practices v1' in messages[1].content[0].text
assert 'Python best practices v1' in messages[1].content[0].text
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
@@ -1079,22 +842,18 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
content='Second retrieval',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs1, obs2],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included
assert len(messages) == 3 # System + Initial User + Result
# Result is now at index 2
assert len(messages) == 2
# First microagent should include enabled_agent but not disabled_agent
assert 'Disabled agent content' not in messages[2].content[0].text
assert 'Enabled agent content v1' in messages[2].content[0].text
assert 'Disabled agent content' not in messages[1].content[0].text
assert 'Enabled agent content v1' in messages[1].content[0].text
def test_process_events_with_microagent_observation_deduplication_empty(
@@ -1107,22 +866,17 @@ def test_process_events_with_microagent_observation_deduplication_empty(
content='Empty retrieval',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# Verify that empty RecallObservations are handled gracefully
assert (
len(messages) == 2 # System + Initial User
) # an empty microagent is not added to Messages
len(messages) == 1
) # an empty microagent is not added to Messages, only system message is found
assert messages[0].role == 'system'
assert messages[1].role == 'user' # Initial user message
def test_has_agent_in_earlier_events(conversation_memory):
@@ -1334,183 +1088,13 @@ def test_system_message_in_events(conversation_memory):
system_message._source = EventSource.AGENT
# Process events with the system message in condensed_history
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[system_message],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# Check that the system message was processed correctly
assert len(messages) == 2 # System + Initial User
assert len(messages) == 1
assert messages[0].role == 'system'
assert messages[0].content[0].text == 'System message'
assert messages[1].role == 'user' # Initial user message
# Helper function to create mock tool call metadata
def _create_mock_tool_call_metadata(
tool_call_id: str, function_name: str, response_id: str = 'mock_response_id'
) -> ToolCallMetadata:
# Use a dictionary that mimics ModelResponse structure to satisfy Pydantic
mock_response = {
'id': response_id,
'choices': [
{
'message': {
'role': 'assistant',
'content': None, # Content is None for tool calls
'tool_calls': [
{
'id': tool_call_id,
'type': 'function',
'function': {
'name': function_name,
'arguments': '{}',
}, # Args don't matter for this test
}
],
}
}
],
'created': 0,
'model': 'mock_model',
'object': 'chat.completion',
'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0},
}
return ToolCallMetadata(
tool_call_id=tool_call_id,
function_name=function_name,
model_response=mock_response,
total_calls_in_response=1,
)
def test_process_events_partial_history(conversation_memory):
"""
Tests process_events with full and partial histories to verify
_ensure_system_message, _ensure_initial_user_message, and tool call matching logic.
"""
# --- Define Common Events ---
system_message = SystemMessageAction(content='System message')
system_message._source = EventSource.AGENT
user_message = MessageAction(
content='Initial user query'
) # This is the crucial initial_user_action
user_message._source = EventSource.USER
recall_obs = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name='test-repo',
repo_directory='/path/to/repo',
content='Retrieved environment info',
)
recall_obs._source = EventSource.AGENT
cmd_action = CmdRunAction(command='ls', thought='Running ls')
cmd_action._source = EventSource.AGENT
cmd_action.tool_call_metadata = _create_mock_tool_call_metadata(
tool_call_id='call_ls_1', function_name='execute_bash', response_id='resp_ls_1'
)
cmd_obs = CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.py', exit_code=0
)
cmd_obs._source = EventSource.AGENT
cmd_obs.tool_call_metadata = _create_mock_tool_call_metadata(
tool_call_id='call_ls_1', function_name='execute_bash', response_id='resp_ls_1'
)
# --- Scenario 1: Full History ---
full_history: list[Event] = [
system_message,
user_message, # Correct initial user message at index 1
recall_obs,
cmd_action,
cmd_obs,
]
messages_full = conversation_memory.process_events(
condensed_history=list(full_history), # Pass a copy
initial_user_action=user_message, # Provide the initial action
max_message_chars=None,
vision_is_active=False,
)
# Expected: System, User, Recall (formatted), Assistant (tool call), Tool Response
assert len(messages_full) == 5
assert messages_full[0].role == 'system'
assert messages_full[0].content[0].text == 'System message'
assert messages_full[1].role == 'user'
assert messages_full[1].content[0].text == 'Initial user query'
assert messages_full[2].role == 'user' # Recall obs becomes user message
assert (
'Formatted repository and runtime info' in messages_full[2].content[0].text
) # From fixture mock
assert messages_full[3].role == 'assistant'
assert messages_full[3].tool_calls is not None
assert len(messages_full[3].tool_calls) == 1
assert messages_full[3].tool_calls[0].id == 'call_ls_1'
assert messages_full[4].role == 'tool'
assert messages_full[4].tool_call_id == 'call_ls_1'
assert 'file1.txt' in messages_full[4].content[0].text
# --- Scenario 2: Partial History (Action + Observation) ---
# Simulates processing only the last action/observation pair
partial_history_action_obs: list[Event] = [
cmd_action,
cmd_obs,
]
messages_partial_action_obs = conversation_memory.process_events(
condensed_history=list(partial_history_action_obs), # Pass a copy
initial_user_action=user_message, # Provide the initial action
max_message_chars=None,
vision_is_active=False,
)
# Expected: System (added), Initial User (added), Assistant (tool call), Tool Response
assert len(messages_partial_action_obs) == 4
assert (
messages_partial_action_obs[0].role == 'system'
) # Added by _ensure_system_message
assert messages_partial_action_obs[0].content[0].text == 'System message'
assert (
messages_partial_action_obs[1].role == 'user'
) # Added by _ensure_initial_user_message
assert messages_partial_action_obs[1].content[0].text == 'Initial user query'
assert messages_partial_action_obs[2].role == 'assistant'
assert messages_partial_action_obs[2].tool_calls is not None
assert len(messages_partial_action_obs[2].tool_calls) == 1
assert messages_partial_action_obs[2].tool_calls[0].id == 'call_ls_1'
assert messages_partial_action_obs[3].role == 'tool'
assert messages_partial_action_obs[3].tool_call_id == 'call_ls_1'
assert 'file1.txt' in messages_partial_action_obs[3].content[0].text
# --- Scenario 3: Partial History (Observation Only) ---
# Simulates processing only the last observation
partial_history_obs_only: list[Event] = [
cmd_obs,
]
messages_partial_obs_only = conversation_memory.process_events(
condensed_history=list(partial_history_obs_only), # Pass a copy
initial_user_action=user_message, # Provide the initial action
max_message_chars=None,
vision_is_active=False,
)
# Expected: System (added), Initial User (added).
# The CmdOutputObservation has tool_call_metadata, but there's no corresponding
# assistant message (from CmdRunAction) with the matching tool_call.id in the input history.
# Therefore, _filter_unmatched_tool_calls should remove the tool response message.
assert len(messages_partial_obs_only) == 2
assert (
messages_partial_obs_only[0].role == 'system'
) # Added by _ensure_system_message
assert messages_partial_obs_only[0].content[0].text == 'System message'
assert (
messages_partial_obs_only[1].role == 'user'
) # Added by _ensure_initial_user_message
assert messages_partial_obs_only[1].content[0].text == 'Initial user query'
+2 -5
View File
@@ -76,7 +76,7 @@ def test_get_messages(codeact_agent: CodeActAgent):
history.append(message_action_5)
codeact_agent.reset()
messages = codeact_agent._get_messages(history, message_action_1)
messages = codeact_agent._get_messages(history)
assert (
len(messages) == 6
@@ -106,19 +106,16 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
history.append(system_message_action)
# Add multiple user and agent messages
initial_user_message = None # Keep track of the first user message
for i in range(15):
message_action_user = MessageAction(f'User message {i}')
message_action_user._source = 'user'
if initial_user_message is None:
initial_user_message = message_action_user # Store the first one
history.append(message_action_user)
message_action_agent = MessageAction(f'Agent message {i}')
message_action_agent._source = 'agent'
history.append(message_action_agent)
codeact_agent.reset()
messages = codeact_agent._get_messages(history, initial_user_message)
messages = codeact_agent._get_messages(history)
# Check that only the last two user messages have cache_prompt=True
cached_user_messages = [