Mirror of https://github.com/All-Hands-AI/OpenHands.git (synced 2026-04-29 03:00:45 -04:00)

Compare commits: 0.44.0...add-cli-ll (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 4cbe46b56c | |
| | f777029546 | |
@@ -136,7 +136,7 @@ poetry run pytest ./tests/unit/test_*.py
 To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
 container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.

-Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/all-hands-ai/runtime:0.44-nikolaik`
+Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/all-hands-ai/runtime:0.43-nikolaik`

 ## Develop inside Docker container

@@ -62,17 +62,17 @@ system requirements and more information.


 ```bash
-docker pull docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik
+docker pull docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik

 docker run -it --rm --pull=always \
-    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik \
+    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik \
     -e LOG_ALL_EVENTS=true \
     -v /var/run/docker.sock:/var/run/docker.sock \
     -v ~/.openhands-state:/.openhands-state \
     -p 3000:3000 \
     --add-host host.docker.internal:host-gateway \
     --name openhands-app \
-    docker.all-hands.dev/all-hands-ai/openhands:0.44
+    docker.all-hands.dev/all-hands-ai/openhands:0.43
 ```

 You'll find OpenHands running at [http://localhost:3000](http://localhost:3000)!

@@ -51,17 +51,17 @@ OpenHands can also be run on your local system using Docker.


 ```bash
-docker pull docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik
+docker pull docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik

 docker run -it --rm --pull=always \
-    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik \
+    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik \
     -e LOG_ALL_EVENTS=true \
     -v /var/run/docker.sock:/var/run/docker.sock \
     -v ~/.openhands-state:/.openhands-state \
     -p 3000:3000 \
     --add-host host.docker.internal:host-gateway \
     --name openhands-app \
-    docker.all-hands.dev/all-hands-ai/openhands:0.44
+    docker.all-hands.dev/all-hands-ai/openhands:0.43
 ```

 You'll find OpenHands running at [http://localhost:3000](http://localhost:3000)!

@@ -12,7 +12,7 @@ services:
       - SANDBOX_API_HOSTNAME=host.docker.internal
       - DOCKER_HOST_ADDR=host.docker.internal
       #
-      - SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/all-hands-ai/runtime:0.44-nikolaik}
+      - SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/all-hands-ai/runtime:0.43-nikolaik}
       - SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
       - WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
     ports:
@@ -7,7 +7,7 @@ services:
     image: openhands:latest
     container_name: openhands-app-${DATE:-}
     environment:
-      - SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik}
+      - SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik}
       #- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of openhands-state for this user
       - WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
     ports:
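Review note: both compose hunks rely on `${SANDBOX_RUNTIME_CONTAINER_IMAGE:-…}` interpolation, where the default is used when the variable is unset or empty. The sketch below mirrors that rule in TypeScript purely for illustration. `resolveWithDefault` and `runtimeImage` are hypothetical names that do not appear in the repository, and the default string is copied from the 0.43 line of the hunk above.

```typescript
// Hypothetical helper (not part of the diff): mirrors the `${VAR:-default}`
// substitution, where the default applies if the variable is unset OR empty.
function resolveWithDefault(
  env: Record<string, string | undefined>,
  name: string,
  fallback: string,
): string {
  const value = env[name];
  return value === undefined || value === "" ? fallback : value;
}

// Default copied from the 0.43 line of the compose hunk above.
const DEFAULT_RUNTIME_IMAGE =
  "docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik";

// Illustrative wrapper: which runtime image would be selected for a given env.
function runtimeImage(env: Record<string, string | undefined>): string {
  return resolveWithDefault(
    env,
    "SANDBOX_RUNTIME_CONTAINER_IMAGE",
    DEFAULT_RUNTIME_IMAGE,
  );
}
```

Calling `runtimeImage({})` yields the default image, while passing an explicit value such as `{ SANDBOX_RUNTIME_CONTAINER_IMAGE: "custom:tag" }` returns that value unchanged.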
@@ -5,38 +5,15 @@ description: This guide walks you through installing the OpenHands Slack app.

 ## Prerequisites

+- You are a slack workspace admin
 - Access to OpenHands Cloud

 ## Installation Steps

-<AccordionGroup>
-<Accordion title="Install Slack App (only for Slack admins/owners)">
-
-**This step is for Slack admins/owners**
-
-1. Make sure you have permissions to install Apps to your workspace.
-2. Click the button below to install OpenHands Slack App <a target="_blank" href="https://slack.com/oauth/v2/authorize?client_id=7477886716822.8729519890534&scope=app_mentions:read,chat:write,users:read,channels:history,groups:history,mpim:history,im:history&user_scope=channels:history,groups:history,im:history,mpim:history"><img alt="Add to Slack" height="40" width="139" src="https://platform.slack-edge.com/img/add_to_slack.png" srcSet="https://platform.slack-edge.com/img/add_to_slack.png 1x, https://platform.slack-edge.com/img/add_to_slack@2x.png 2x" /></a>
-3. In the top right corner, select the workspace to install the OpenHands Slack app.
-4. Review permissions and click allow.
-
-</Accordion>
-
-<Accordion title="Authorize Slack App (for all Slack workspace members)">
-
-**Make sure your Slack workspace admin/owner has installed OpenHands Slack App first**
-
-Every user in the slack workspace (including admins/owners) must link their Cloud OpenHands account to the OpenHands Slack App. To do this
-1. Visit [integrations settings](https://app.all-hands.dev/settings/integrations) in OpenHands Cloud.
-2. Click the button "Install Slack App".
-3. In the top right corner, select the workspace to install the OpenHands Slack app.
-4. Review permissions and click allow.
-
-Depending on the workspace settings, you may need approval from your slack admin to authorize the Slack App.
-
-</Accordion>
-
-</AccordionGroup>
-
+1. Log in to [OpenHands Cloud](https://app.all-hands.dev)
+2. Click the button below to OpenHands Slack App <a target="_blank" href="https://slack.com/oauth/v2/authorize?client_id=7477886716822.8729519890534&scope=app_mentions:read,chat:write,users:read,channels:history,groups:history,mpim:history,im:history&user_scope=channels:history,groups:history,im:history,mpim:history"><img alt="Add to Slack" height="40" width="139" src="https://platform.slack-edge.com/img/add_to_slack.png" srcSet="https://platform.slack-edge.com/img/add_to_slack.png 1x, https://platform.slack-edge.com/img/add_to_slack@2x.png 2x" /></a>
+3. In the top right corner, select the workspace to install the OpenHands Slack app.
+4. Review permissions and click allow

 ## Working With the Slack App

@@ -47,7 +47,7 @@ poetry run python -m openhands.cli.main
 ```bash
 docker run -it \
     --pull=always \
-    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik \
+    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik \
     -e SANDBOX_USER_ID=$(id -u) \
     -e SANDBOX_VOLUMES=$SANDBOX_VOLUMES \
     -e LLM_API_KEY=$LLM_API_KEY \
@@ -56,7 +56,7 @@ docker run -it \
     -v ~/.openhands-state:/.openhands-state \
     --add-host host.docker.internal:host-gateway \
     --name openhands-app-$(date +%Y%m%d%H%M%S) \
-    docker.all-hands.dev/all-hands-ai/openhands:0.44 \
+    docker.all-hands.dev/all-hands-ai/openhands:0.43 \
     python -m openhands.cli.main --override-cli-mode true
 ```

@@ -32,7 +32,7 @@ To run OpenHands in Headless mode with Docker:
 ```bash
 docker run -it \
     --pull=always \
-    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik \
+    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik \
     -e SANDBOX_USER_ID=$(id -u) \
     -e SANDBOX_VOLUMES=$SANDBOX_VOLUMES \
     -e LLM_API_KEY=$LLM_API_KEY \
@@ -42,7 +42,7 @@ docker run -it \
     -v ~/.openhands-state:/.openhands-state \
     --add-host host.docker.internal:host-gateway \
     --name openhands-app-$(date +%Y%m%d%H%M%S) \
-    docker.all-hands.dev/all-hands-ai/openhands:0.44 \
+    docker.all-hands.dev/all-hands-ai/openhands:0.43 \
     python -m openhands.core.main -t "write a bash script that prints hi"
 ```

@@ -54,25 +54,25 @@ Check [the installation guide](/usage/local-setup) to make sure you have all the
 export LMSTUDIO_MODEL_NAME="imported-models/uncategorized/devstralq4_k_m.gguf" # <- Replace this with the model name you copied from LMStudio
 export LMSTUDIO_URL="http://host.docker.internal:1234" # <- Replace this with the port from LMStudio

-docker pull docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik
+docker pull docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik

 mkdir -p ~/.openhands-state && echo '{"language":"en","agent":"CodeActAgent","max_iterations":null,"security_analyzer":null,"confirmation_mode":false,"llm_model":"lm_studio/'$LMSTUDIO_MODEL_NAME'","llm_api_key":"dummy","llm_base_url":"'$LMSTUDIO_URL/v1'","remote_runtime_resource_factor":null,"github_token":null,"enable_default_condenser":true,"user_consents_to_analytics":true}' > ~/.openhands-state/settings.json

 docker run -it --rm --pull=always \
-    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik \
+    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik \
     -e LOG_ALL_EVENTS=true \
     -v /var/run/docker.sock:/var/run/docker.sock \
     -v ~/.openhands-state:/.openhands-state \
     -p 3000:3000 \
     --add-host host.docker.internal:host-gateway \
     --name openhands-app \
-    docker.all-hands.dev/all-hands-ai/openhands:0.44
+    docker.all-hands.dev/all-hands-ai/openhands:0.43
 ```

 Once your server is running -- you can visit `http://localhost:3000` in your browser to use OpenHands with local Devstral model:
 ```
 Digest: sha256:e72f9baecb458aedb9afc2cd5bc935118d1868719e55d50da73190d3a85c674f
-Status: Image is up to date for docker.all-hands.dev/all-hands-ai/openhands:0.44
+Status: Image is up to date for docker.all-hands.dev/all-hands-ai/openhands:0.43
 Starting OpenHands...
 Running OpenHands as root
 14:22:13 - openhands:INFO: server_config.py:50 - Using config class None
@@ -67,17 +67,17 @@ A system with a modern processor and a minimum of **4GB RAM** is recommended to
 ### Start the App

 ```bash
-docker pull docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik
+docker pull docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik

 docker run -it --rm --pull=always \
-    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.44-nikolaik \
+    -e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.43-nikolaik \
     -e LOG_ALL_EVENTS=true \
     -v /var/run/docker.sock:/var/run/docker.sock \
     -v ~/.openhands-state:/.openhands-state \
     -p 3000:3000 \
     --add-host host.docker.internal:host-gateway \
     --name openhands-app \
-    docker.all-hands.dev/all-hands-ai/openhands:0.44
+    docker.all-hands.dev/all-hands-ai/openhands:0.43
 ```

 You'll find OpenHands running at http://localhost:3000!
@@ -31,7 +31,7 @@ const renderRepoConnector = () => {
         },
         {
           Component: () => <div data-testid="git-settings-screen" />,
-          path: "/settings/integrations",
+          path: "/settings/git",
         },
       ],
     },
@@ -35,13 +35,13 @@ const queryClient = new QueryClient();
 const GitSettingsRouterStub = createRoutesStub([
   {
     Component: GitSettingsScreen,
-    path: "/settings/integrations",
+    path: "/settings/github",
   },
 ]);

 const renderGitSettingsScreen = () => {
   const { rerender, ...rest } = render(
-    <GitSettingsRouterStub initialEntries={["/settings/integrations"]} />,
+    <GitSettingsRouterStub initialEntries={["/settings/github"]} />,
     {
       wrapper: ({ children }) => (
         <QueryClientProvider client={queryClient}>
@@ -54,7 +54,7 @@ const renderGitSettingsScreen = () => {
   const rerenderGitSettingsScreen = () =>
     rerender(
       <QueryClientProvider client={queryClient}>
-        <GitSettingsRouterStub initialEntries={["/settings/integrations"]} />
+        <GitSettingsRouterStub initialEntries={["/settings/github"]} />
       </QueryClientProvider>,
     );

@@ -31,7 +31,7 @@ const RouterStub = createRoutesStub([
       },
       {
         Component: () => <div data-testid="git-settings-screen" />,
-        path: "/settings/integrations",
+        path: "/settings/git",
       },
     ],
   },
@@ -30,7 +30,7 @@ vi.mock("react-i18next", async () => {
     useTranslation: () => ({
       t: (key: string) => {
         const translations: Record<string, string> = {
-          "SETTINGS$NAV_INTEGRATIONS": "Integrations",
+          "SETTINGS$NAV_GIT": "Git",
           "SETTINGS$NAV_APPLICATION": "Application",
           "SETTINGS$NAV_CREDITS": "Credits",
           "SETTINGS$NAV_API_KEYS": "API Keys",
@@ -61,7 +61,7 @@ describe("Settings Billing", () => {
         },
         {
           Component: () => <div data-testid="git-settings-screen" />,
-          path: "/settings/integrations",
+          path: "/settings/git",
         },
         {
           Component: () => <div data-testid="user-settings-screen" />,
@@ -14,7 +14,7 @@ vi.mock("react-i18next", async () => {
     useTranslation: () => ({
       t: (key: string) => {
         const translations: Record<string, string> = {
-          SETTINGS$NAV_INTEGRATIONS: "Integrations",
+          SETTINGS$NAV_GIT: "Git",
           SETTINGS$NAV_APPLICATION: "Application",
           SETTINGS$NAV_CREDITS: "Credits",
           SETTINGS$NAV_API_KEYS: "API Keys",
@@ -49,7 +49,7 @@ describe("Settings Screen", () => {
         },
         {
           Component: () => <div data-testid="git-settings-screen" />,
-          path: "/settings/integrations",
+          path: "/settings/git",
         },
         {
           Component: () => <div data-testid="application-settings-screen" />,
@@ -79,7 +79,7 @@ describe("Settings Screen", () => {
   };

   it("should render the navbar", async () => {
-    const sectionsToInclude = ["llm", "integrations", "application", "secrets"];
+    const sectionsToInclude = ["llm", "git", "application", "secrets"];
     const sectionsToExclude = ["api keys", "credits"];
     const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
     // @ts-expect-error - only return app mode
@@ -111,7 +111,7 @@ describe("Settings Screen", () => {
       APP_MODE: "saas",
     });
     const sectionsToInclude = [
-      "integrations",
+      "git",
       "application",
       "credits",
       "secrets",
4
frontend/package-lock.json
generated
@@ -1,12 +1,12 @@
 {
   "name": "openhands-frontend",
-  "version": "0.44.0",
+  "version": "0.43.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "openhands-frontend",
-      "version": "0.44.0",
+      "version": "0.43.0",
       "dependencies": {
         "@heroui/react": "^2.8.0-beta.7",
         "@microlink/react-json-view": "^1.26.2",
@@ -1,6 +1,6 @@
 {
   "name": "openhands-frontend",
-  "version": "0.44.0",
+  "version": "0.43.0",
   "private": true,
   "type": "module",
   "engines": {
@@ -111,59 +111,6 @@ class OpenHands {
     return data;
   }

-  /**
-   * Submit conversation feedback with rating
-   * @param conversationId The conversation ID
-   * @param rating The rating (1-5)
-   * @param eventId Optional event ID this feedback corresponds to
-   * @param reason Optional reason for the rating
-   * @returns Response from the feedback endpoint
-   */
-  static async submitConversationFeedback(
-    conversationId: string,
-    rating: number,
-    eventId?: number,
-    reason?: string,
-  ): Promise<{ status: string; message: string }> {
-    const url = `/feedback/conversation`;
-    const payload = {
-      conversation_id: conversationId,
-      event_id: eventId,
-      rating,
-      reason,
-      metadata: { source: "likert-scale" },
-    };
-    const { data } = await openHands.post<{ status: string; message: string }>(
-      url,
-      payload,
-    );
-    return data;
-  }
-
-  /**
-   * Check if feedback exists for a specific conversation and event
-   * @param conversationId The conversation ID
-   * @param eventId The event ID to check
-   * @returns Feedback data including existence, rating, and reason
-   */
-  static async checkFeedbackExists(
-    conversationId: string,
-    eventId: number,
-  ): Promise<{ exists: boolean; rating?: number; reason?: string }> {
-    try {
-      const url = `/feedback/conversation/${conversationId}/${eventId}`;
-      const { data } = await openHands.get<{
-        exists: boolean;
-        rating?: number;
-        reason?: string;
-      }>(url);
-      return data;
-    } catch (error) {
-      // Error checking if feedback exists
-      return { exists: false };
-    }
-  }
-
   /**
    * Authenticate with GitHub token
    * @returns Response with authentication status and user info if successful
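Review note: the `submitConversationFeedback` method in this hunk assembles its request body inline. As a minimal sketch, the same payload shape can be expressed as a standalone pure function, which is easier to unit-test without a network client. `buildFeedbackPayload` and `ConversationFeedbackPayload` are hypothetical names, and the 1-5 range check is an assumption drawn only from the doc comment ("The rating (1-5)"); the original method performs no such validation.

```typescript
// Payload shape taken from the hunk above; the interface name is hypothetical.
interface ConversationFeedbackPayload {
  conversation_id: string;
  event_id: number | undefined;
  rating: number;
  reason: string | undefined;
  metadata: { source: string };
}

// Hypothetical helper (not in the repository): builds the request body that
// the removed method posted to `/feedback/conversation`.
function buildFeedbackPayload(
  conversationId: string,
  rating: number,
  eventId?: number,
  reason?: string,
): ConversationFeedbackPayload {
  // Assumption: enforce the 1-5 range mentioned in the doc comment.
  if (!Number.isInteger(rating) || rating < 1 || rating > 5) {
    throw new RangeError(`rating must be an integer from 1 to 5, got ${rating}`);
  }
  return {
    conversation_id: conversationId,
    event_id: eventId,
    rating,
    reason,
    metadata: { source: "likert-scale" },
  };
}
```

For example, `buildFeedbackPayload("abc", 4, 12)` produces an object with `conversation_id: "abc"`, `event_id: 12`, `rating: 4`, and an undefined `reason`.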
@@ -18,7 +18,6 @@ import { useWsClient } from "#/context/ws-client-provider";
 import { Messages } from "./messages";
 import { ChatSuggestions } from "./chat-suggestions";
 import { ActionSuggestions } from "./action-suggestions";
-import { ScrollProvider } from "#/context/scroll-context";

 import { ScrollToBottomButton } from "#/components/shared/buttons/scroll-to-bottom-button";
 import { LoadingSpinner } from "#/components/shared/loading-spinner";
@@ -29,7 +28,6 @@ import { useOptimisticUserMessage } from "#/hooks/use-optimistic-user-message";
 import { useWSErrorMessage } from "#/hooks/use-ws-error-message";
 import { ErrorMessageBanner } from "./error-message-banner";
 import { shouldRenderEvent } from "./event-content-helpers/should-render-event";
-import { useConfig } from "#/hooks/query/use-config";

 function getEntryPoint(
   hasRepository: boolean | null,
@@ -47,15 +45,8 @@ export function ChatInterface() {
     useOptimisticUserMessage();
   const { t } = useTranslation();
   const scrollRef = React.useRef<HTMLDivElement>(null);
-  const {
-    scrollDomToBottom,
-    onChatBodyScroll,
-    hitBottom,
-    autoScroll,
-    setAutoScroll,
-    setHitBottom,
-  } = useScrollToBottom(scrollRef);
-  const { data: config } = useConfig();
+  const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
+    useScrollToBottom(scrollRef);

   const { curAgentState } = useSelector((state: RootState) => state.agent);

@@ -135,97 +126,80 @@ export function ChatInterface() {
     curAgentState === AgentState.AWAITING_USER_INPUT ||
     curAgentState === AgentState.FINISHED;

-  // Create a ScrollProvider with the scroll hook values
-  const scrollProviderValue = {
-    scrollRef,
-    autoScroll,
-    setAutoScroll,
-    scrollDomToBottom,
-    hitBottom,
-    setHitBottom,
-    onChatBodyScroll,
-  };
-
   return (
-    <ScrollProvider value={scrollProviderValue}>
-      <div className="h-full flex flex-col justify-between">
-        {events.length === 0 && !optimisticUserMessage && (
-          <ChatSuggestions onSuggestionsClick={setMessageToSend} />
+    <div className="h-full flex flex-col justify-between">
+      {events.length === 0 && !optimisticUserMessage && (
+        <ChatSuggestions onSuggestionsClick={setMessageToSend} />
+      )}
+
+      <div
+        ref={scrollRef}
+        onScroll={(e) => onChatBodyScroll(e.currentTarget)}
+        className="scrollbar scrollbar-thin scrollbar-thumb-gray-400 scrollbar-thumb-rounded-full scrollbar-track-gray-800 hover:scrollbar-thumb-gray-300 flex flex-col grow overflow-y-auto overflow-x-hidden px-4 pt-4 gap-2 fast-smooth-scroll"
+      >
+        {isLoadingMessages && (
+          <div className="flex justify-center">
+            <LoadingSpinner size="small" />
+          </div>
         )}

-        <div
-          ref={scrollRef}
-          onScroll={(e) => onChatBodyScroll(e.currentTarget)}
-          className="scrollbar scrollbar-thin scrollbar-thumb-gray-400 scrollbar-thumb-rounded-full scrollbar-track-gray-800 hover:scrollbar-thumb-gray-300 flex flex-col grow overflow-y-auto overflow-x-hidden px-4 pt-4 gap-2 fast-smooth-scroll"
-        >
-          {isLoadingMessages && (
-            <div className="flex justify-center">
-              <LoadingSpinner size="small" />
-            </div>
-          )}
-
-          {!isLoadingMessages && (
-            <Messages
-              messages={events}
-              isAwaitingUserConfirmation={
-                curAgentState === AgentState.AWAITING_USER_CONFIRMATION
-              }
-            />
-          )}
-
-          {isWaitingForUserInput &&
-            events.length > 0 &&
-            !optimisticUserMessage && (
-              <ActionSuggestions
-                onSuggestionsClick={(value) => handleSendMessage(value, [])}
-              />
-            )}
-        </div>
-
-        <div className="flex flex-col gap-[6px] px-4 pb-4">
-          <div className="flex justify-between relative">
-            {config?.APP_MODE !== "saas" && (
-              <TrajectoryActions
-                onPositiveFeedback={() =>
-                  onClickShareFeedbackActionButton("positive")
-                }
-                onNegativeFeedback={() =>
-                  onClickShareFeedbackActionButton("negative")
-                }
-                onExportTrajectory={() => onClickExportTrajectoryButton()}
-              />
-            )}
-
-            <div className="absolute left-1/2 transform -translate-x-1/2 bottom-0">
-              {curAgentState === AgentState.RUNNING && <TypingIndicator />}
-            </div>
-
-            {!hitBottom && <ScrollToBottomButton onClick={scrollDomToBottom} />}
-          </div>
-
-          {errorMessage && <ErrorMessageBanner message={errorMessage} />}
-
-          <InteractiveChatBox
-            onSubmit={handleSendMessage}
-            onStop={handleStop}
-            isDisabled={
-              curAgentState === AgentState.LOADING ||
+        {!isLoadingMessages && (
+          <Messages
+            messages={events}
+            isAwaitingUserConfirmation={
               curAgentState === AgentState.AWAITING_USER_CONFIRMATION
             }
-            mode={curAgentState === AgentState.RUNNING ? "stop" : "submit"}
-            value={messageToSend ?? undefined}
-            onChange={setMessageToSend}
           />
-        </div>
-
-        {config?.APP_MODE !== "saas" && (
-          <FeedbackModal
-            isOpen={feedbackModalIsOpen}
-            onClose={() => setFeedbackModalIsOpen(false)}
-            polarity={feedbackPolarity}
-          />
         )}
+
+        {isWaitingForUserInput &&
+          events.length > 0 &&
+          !optimisticUserMessage && (
+            <ActionSuggestions
+              onSuggestionsClick={(value) => handleSendMessage(value, [])}
+            />
+          )}
       </div>
-    </ScrollProvider>
+
+      <div className="flex flex-col gap-[6px] px-4 pb-4">
+        <div className="flex justify-between relative">
+          <TrajectoryActions
+            onPositiveFeedback={() =>
+              onClickShareFeedbackActionButton("positive")
+            }
+            onNegativeFeedback={() =>
+              onClickShareFeedbackActionButton("negative")
+            }
+            onExportTrajectory={() => onClickExportTrajectoryButton()}
+          />
+
+          <div className="absolute left-1/2 transform -translate-x-1/2 bottom-0">
+            {curAgentState === AgentState.RUNNING && <TypingIndicator />}
+          </div>
+
+          {!hitBottom && <ScrollToBottomButton onClick={scrollDomToBottom} />}
+        </div>
+
+        {errorMessage && <ErrorMessageBanner message={errorMessage} />}
+
+        <InteractiveChatBox
+          onSubmit={handleSendMessage}
+          onStop={handleStop}
+          isDisabled={
+            curAgentState === AgentState.LOADING ||
+            curAgentState === AgentState.AWAITING_USER_CONFIRMATION
+          }
+          mode={curAgentState === AgentState.RUNNING ? "stop" : "submit"}
+          value={messageToSend ?? undefined}
+          onChange={setMessageToSend}
+        />
+      </div>
+
+      <FeedbackModal
+        isOpen={feedbackModalIsOpen}
+        onClose={() => setFeedbackModalIsOpen(false)}
+        polarity={feedbackPolarity}
+      />
+    </div>
   );
 }
@@ -1,4 +1,3 @@
-import React from "react";
 import { ConfirmationButtons } from "#/components/shared/buttons/confirmation-buttons";
 import { OpenHandsAction } from "#/types/core/actions";
 import {
@@ -19,10 +18,6 @@ import { MCPObservationContent } from "./mcp-observation-content";
 import { getObservationResult } from "./event-content-helpers/get-observation-result";
 import { getEventContent } from "./event-content-helpers/get-event-content";
 import { GenericEventMessage } from "./generic-event-message";
-import { LikertScale } from "../feedback/likert-scale";
-
-import { useConfig } from "#/hooks/query/use-config";
-import { useFeedbackExists } from "#/hooks/query/use-feedback-exists";

 const hasThoughtProperty = (
   obj: Record<string, unknown>,
@@ -44,14 +39,6 @@ export function EventMessage({
   const shouldShowConfirmationButtons =
     isLastMessage && event.source === "agent" && isAwaitingUserConfirmation;

-  const { data: config } = useConfig();
-
-  // Use our query hook to check if feedback exists and get rating/reason
-  const {
-    data: feedbackData = { exists: false },
-    isLoading: isCheckingFeedback,
-  } = useFeedbackExists(isFinishAction(event) ? event.id : undefined);
-
   if (isErrorObservation(event)) {
     return (
       <ErrorMessage
@@ -68,25 +55,9 @@ export function EventMessage({
     return null;
   }

-  const showLikertScale =
-    config?.APP_MODE === "saas" &&
-    isFinishAction(event) &&
-    isLastMessage &&
-    !isCheckingFeedback;
-
   if (isFinishAction(event)) {
     return (
-      <>
-        <ChatMessage type="agent" message={getEventContent(event).details} />
-        {showLikertScale && (
-          <LikertScale
-            eventId={event.id}
-            initiallySubmitted={feedbackData.exists}
-            initialRating={feedbackData.rating}
-            initialReason={feedbackData.reason}
-          />
-        )}
-      </>
+      <ChatMessage type="agent" message={getEventContent(event).details} />
     );
   }

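Review note: the `showLikertScale` condition in the hunk above combines four checks (SaaS mode, finish action, last message, feedback lookup complete). A minimal sketch of the same predicate as a pure function, so the visibility rule can be tested in isolation. `shouldShowLikertScale` and `LikertVisibilityInput` are hypothetical names, and the boolean inputs stand in for the component's `isFinishAction(event)`, `isLastMessage`, and `isCheckingFeedback` values.

```typescript
// Hypothetical input shape; each field stands in for a value the component
// computes from hooks or props in the hunk above.
interface LikertVisibilityInput {
  appMode: string | undefined; // config?.APP_MODE
  isFinishAction: boolean; // isFinishAction(event)
  isLastMessage: boolean;
  isCheckingFeedback: boolean; // still loading existing feedback
}

// Mirrors the removed `showLikertScale` expression: show the rating widget
// only in SaaS mode, on the final finish action, once the lookup has settled.
function shouldShowLikertScale(input: LikertVisibilityInput): boolean {
  return (
    input.appMode === "saas" &&
    input.isFinishAction &&
    input.isLastMessage &&
    !input.isCheckingFeedback
  );
}
```

Any single failing condition hides the widget; for instance, a finish action that is not the last message yields `false`.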
@@ -1,248 +0,0 @@
-import React, { useState, useEffect, useContext } from "react";
-import { cn } from "#/utils/utils";
-import i18n from "#/i18n";
-import { useSubmitConversationFeedback } from "#/hooks/mutation/use-submit-conversation-feedback";
-import { ScrollContext } from "#/context/scroll-context";
-
-// Global timeout duration in milliseconds
-const AUTO_SUBMIT_TIMEOUT = 10000;
-
-interface LikertScaleProps {
-  eventId?: number;
-  initiallySubmitted?: boolean;
-  initialRating?: number;
-  initialReason?: string;
-}
-
-const FEEDBACK_REASONS = [
-  i18n.t("FEEDBACK$REASON_MISUNDERSTOOD_INSTRUCTION"),
-  i18n.t("FEEDBACK$REASON_FORGOT_CONTEXT"),
-  i18n.t("FEEDBACK$REASON_UNNECESSARY_CHANGES"),
-  i18n.t("FEEDBACK$REASON_OTHER"),
-];
-
-export function LikertScale({
-  eventId,
-  initiallySubmitted = false,
-  initialRating,
-  initialReason,
-}: LikertScaleProps) {
-  const [selectedRating, setSelectedRating] = useState<number | null>(
-    initialRating || null,
-  );
-  const [selectedReason, setSelectedReason] = useState<string | null>(
-    initialReason || null,
-  );
-  const [showReasons, setShowReasons] = useState(false);
-  const [reasonTimeout, setReasonTimeout] = useState<NodeJS.Timeout | null>(
-    null,
-  );
-  const [isSubmitted, setIsSubmitted] = useState(initiallySubmitted);
-  const [countdown, setCountdown] = useState<number>(0);
-
-  // Get scroll context
-  const scrollContext = useContext(ScrollContext);
-
-  // If scrollContext is undefined, we're not inside a ScrollProvider
-  const scrollToBottom = scrollContext?.scrollDomToBottom;
-  const autoScroll = scrollContext?.autoScroll;
-
-  // Use our mutation hook
-  const { mutate: submitConversationFeedback } =
-    useSubmitConversationFeedback();
-
-  // Update isSubmitted if initiallySubmitted changes
-  useEffect(() => {
-    setIsSubmitted(initiallySubmitted);
-  }, [initiallySubmitted]);
-
-  // Update selectedRating if initialRating changes
-  useEffect(() => {
-    if (initialRating) {
-      setSelectedRating(initialRating);
-    }
-  }, [initialRating]);
-
-  // Update selectedReason if initialReason changes
-  useEffect(() => {
-    if (initialReason) {
-      setSelectedReason(initialReason);
-    }
-  }, [initialReason]);
-
-  // Submit feedback and disable the component
-  const submitFeedback = (rating: number, reason?: string) => {
-    submitConversationFeedback(
-      {
-        rating,
-        eventId,
-        reason,
-      },
-      {
-        onSuccess: () => {
-          setSelectedReason(reason || null);
-          setShowReasons(false);
-          setIsSubmitted(true);
-        },
-      },
-    );
-  };
-
-  // Handle star rating selection
-  const handleRatingClick = (rating: number) => {
-    if (isSubmitted) return; // Prevent changes after submission
-
-    setSelectedRating(rating);
-
-    // Only show reasons if rating is 3 or less (1, 2, or 3 stars)
-    // For ratings > 3 (4 or 5 stars), submit immediately without showing reasons
-    if (rating <= 3) {
-      setShowReasons(true);
-      setCountdown(Math.ceil(AUTO_SUBMIT_TIMEOUT / 1000));
-
-      // Set a timeout to auto-submit if no reason is selected
-      const timeout = setTimeout(() => {
-        submitFeedback(rating);
-      }, AUTO_SUBMIT_TIMEOUT);
-
-      setReasonTimeout(timeout);
-
-      // Only scroll to bottom if the user is already at the bottom (autoScroll is true)
-      if (scrollToBottom && autoScroll) {
-        // Small delay to ensure the reasons are fully rendered
-        setTimeout(() => {
-          scrollToBottom();
-        }, 100);
-      }
-    } else {
-      // For ratings > 3 (4 or 5 stars), submit immediately without showing reasons
-      setShowReasons(false);
-      submitFeedback(rating);
-    }
-  };
-
-  // Handle reason selection
-  const handleReasonClick = (reason: string) => {
-    if (selectedRating && reasonTimeout && !isSubmitted) {
-      clearTimeout(reasonTimeout);
-      setCountdown(0);
-      submitFeedback(selectedRating, reason);
-    }
-  };
-
-  // Countdown effect
-  useEffect(() => {
-    if (countdown > 0 && showReasons && !isSubmitted) {
-      const timer = setTimeout(() => {
-        setCountdown(countdown - 1);
-      }, 1000);
-      return () => clearTimeout(timer);
-    }
-    return () => {};
-  }, [countdown, showReasons, isSubmitted]);
-
-  // Clean up timeout on unmount
-  useEffect(
-    () => () => {
-      if (reasonTimeout) {
-        clearTimeout(reasonTimeout);
-      }
-    },
-    [reasonTimeout],
-  );
-
-  // Scroll to bottom when component mounts, but only if user is already at the bottom
-  useEffect(() => {
-    if (scrollToBottom && autoScroll && !isSubmitted) {
-      // Small delay to ensure the component is fully rendered
-      setTimeout(() => {
-        scrollToBottom();
-      }, 100);
-    }
-  }, [scrollToBottom, autoScroll, isSubmitted]);
-
-  // Scroll to bottom when reasons are shown, but only if user is already at the bottom
-  useEffect(() => {
-    if (scrollToBottom && autoScroll && showReasons) {
-      // Small delay to ensure the reasons are fully rendered
-      setTimeout(() => {
-        scrollToBottom();
-      }, 100);
-    }
-  }, [scrollToBottom, autoScroll, showReasons]);
-
-  // Helper function to get button class based on state
-  const getButtonClass = (rating: number) => {
-    if (isSubmitted) {
-      return selectedRating && selectedRating >= rating
-        ? "text-yellow-400 cursor-not-allowed"
-        : "text-gray-300 opacity-50 cursor-not-allowed";
-    }
-
-    return selectedRating && selectedRating >= rating
-      ? "text-yellow-400"
-      : "text-gray-300 hover:text-yellow-200";
-  };
-
-  return (
-    <div className="mt-3 flex flex-col gap-1">
-      <div className="text-sm text-gray-500 mb-1">
-        {isSubmitted
-          ? i18n.t("FEEDBACK$THANK_YOU_FOR_FEEDBACK")
-          : i18n.t("FEEDBACK$RATE_AGENT_PERFORMANCE")}
-      </div>
-      <div className="flex flex-col gap-1">
-        <span className="flex gap-2 items-center flex-wrap">
-          {[1, 2, 3, 4, 5].map((rating) => (
-            <button
-              type="button"
-              key={rating}
-              onClick={() => handleRatingClick(rating)}
-              disabled={isSubmitted}
-              className={cn("text-xl transition-all", getButtonClass(rating))}
-              aria-label={`Rate ${rating} stars`}
-            >
-              ★
-            </button>
-          ))}
|
||||
{/* Show selected reason inline with stars when submitted (only for ratings <= 3) */}
|
||||
{isSubmitted &&
|
||||
selectedReason &&
|
||||
selectedRating &&
|
||||
selectedRating <= 3 && (
|
||||
<span className="text-sm text-gray-500 italic">
|
||||
{selectedReason}
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{showReasons && !isSubmitted && (
|
||||
<div className="mt-1 flex flex-col gap-1">
|
||||
<div className="text-xs text-gray-500 mb-1">
|
||||
{i18n.t("FEEDBACK$SELECT_REASON")}
|
||||
</div>
|
||||
{countdown > 0 && (
|
||||
<div className="text-xs text-gray-400 mb-1 italic">
|
||||
{i18n.t("FEEDBACK$SELECT_REASON_COUNTDOWN", {
|
||||
countdown,
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
{FEEDBACK_REASONS.map((reason) => (
|
||||
<button
|
||||
type="button"
|
||||
key={reason}
|
||||
onClick={() => handleReasonClick(reason)}
|
||||
className="text-sm text-left py-1 px-2 rounded hover:bg-gray-700 transition-colors"
|
||||
>
|
||||
{reason}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
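The branching in `handleRatingClick` above can be summarized as a pure function. This is a minimal sketch, not code from the repository: `decideRatingAction`, `RatingDecision`, and the `autoSubmitTimeoutMs` parameter are hypothetical names, and the value of `AUTO_SUBMIT_TIMEOUT` is not shown in this diff, so it is passed in as an argument. Only the `rating <= 3` threshold and the `Math.ceil(timeout / 1000)` countdown are taken from the component.

```typescript
// Hypothetical helper: mirrors the decision made in handleRatingClick,
// without any React state or timers.
type RatingDecision =
  | { kind: "ignored" } // feedback was already submitted
  | { kind: "ask-reason"; countdownSeconds: number } // 1-3 stars
  | { kind: "submit-now" }; // 4-5 stars

function decideRatingAction(
  rating: number,
  isSubmitted: boolean,
  autoSubmitTimeoutMs: number, // assumption: stands in for AUTO_SUBMIT_TIMEOUT
): RatingDecision {
  // No changes are accepted after submission.
  if (isSubmitted) return { kind: "ignored" };

  // Low ratings show the reason list and start a countdown in whole seconds.
  if (rating <= 3) {
    return {
      kind: "ask-reason",
      countdownSeconds: Math.ceil(autoSubmitTimeoutMs / 1000),
    };
  }

  // High ratings are submitted immediately, with no reason.
  return { kind: "submit-now" };
}
```

Keeping the decision separate from the timers makes the threshold easy to unit test; the component would still own `setTimeout` and the state setters.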
@@ -10,10 +10,7 @@ export function ConnectToProviderMessage() {
   return (
     <div className="flex flex-col gap-4">
       <p>{t("HOME$CONNECT_PROVIDER_MESSAGE")}</p>
-      <Link
-        data-testid="navigate-to-settings-button"
-        to="/settings/integrations"
-      >
+      <Link data-testid="navigate-to-settings-button" to="/settings/git">
         <BrandButton type="button" variant="primary" isDisabled={isLoading}>
           {!isLoading && t("SETTINGS$TITLE")}
           {isLoading && t("HOME$LOADING")}
|
||||
@@ -1,21 +0,0 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { BrandButton } from "../brand-button";
|
||||
|
||||
export function InstallSlackAppAnchor() {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<a
|
||||
data-testid="install-slack-app-button"
|
||||
href="https://slack.com/oauth/v2/authorize?client_id=7477886716822.8729519890534&scope=app_mentions:read,chat:write,users:read,channels:history,groups:history,mpim:history,im:history&user_scope=channels:history,groups:history,im:history,mpim:history"
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
className="py-9"
|
||||
>
|
||||
<BrandButton type="button" variant="secondary">
|
||||
{t(I18nKey.SLACK$INSTALL_APP)}
|
||||
</BrandButton>
|
||||
</a>
|
||||
);
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
import React, { createContext, useContext, ReactNode, RefObject } from "react";
|
||||
import { useScrollToBottom } from "#/hooks/use-scroll-to-bottom";
|
||||
|
||||
interface ScrollContextType {
|
||||
scrollRef: RefObject<HTMLDivElement | null>;
|
||||
autoScroll: boolean;
|
||||
setAutoScroll: (value: boolean) => void;
|
||||
scrollDomToBottom: () => void;
|
||||
hitBottom: boolean;
|
||||
setHitBottom: (value: boolean) => void;
|
||||
onChatBodyScroll: (e: HTMLElement) => void;
|
||||
}
|
||||
|
||||
export const ScrollContext = createContext<ScrollContextType | undefined>(
|
||||
undefined,
|
||||
);
|
||||
|
||||
interface ScrollProviderProps {
|
||||
children: ReactNode;
|
||||
value?: ScrollContextType;
|
||||
}
|
||||
|
||||
export function ScrollProvider({ children, value }: ScrollProviderProps) {
|
||||
const scrollHook = useScrollToBottom(React.useRef<HTMLDivElement>(null));
|
||||
|
||||
// Use provided value or default to the hook
|
||||
const contextValue = value || scrollHook;
|
||||
|
||||
return (
|
||||
<ScrollContext.Provider value={contextValue}>
|
||||
{children}
|
||||
</ScrollContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useScrollContext() {
|
||||
const context = useContext(ScrollContext);
|
||||
if (context === undefined) {
|
||||
throw new Error("useScrollContext must be used within a ScrollProvider");
|
||||
}
|
||||
return context;
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { useConversationId } from "#/hooks/use-conversation-id";
|
||||
|
||||
type SubmitConversationFeedbackArgs = {
|
||||
rating: number;
|
||||
eventId?: number;
|
||||
reason?: string;
|
||||
};
|
||||
|
||||
export const useSubmitConversationFeedback = () => {
|
||||
const { conversationId } = useConversationId();
|
||||
const queryClient = useQueryClient();
|
||||
const { t } = useTranslation();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: ({ rating, eventId, reason }: SubmitConversationFeedbackArgs) =>
|
||||
OpenHands.submitConversationFeedback(
|
||||
conversationId,
|
||||
rating,
|
||||
eventId,
|
||||
reason,
|
||||
),
|
||||
onSuccess: (_, { eventId }) => {
|
||||
// Invalidate the feedback existence query to trigger a refetch
|
||||
if (eventId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["feedback", "exists", conversationId, eventId],
|
||||
});
|
||||
}
|
||||
},
|
||||
onError: (error) => {
|
||||
// Log error but don't show toast - user will just see the UI stay in unsubmitted state
|
||||
// eslint-disable-next-line no-console
|
||||
console.error(t("FEEDBACK$FAILED_TO_SUBMIT"), error);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,24 +0,0 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { useConversationId } from "#/hooks/use-conversation-id";
|
||||
|
||||
export interface FeedbackData {
|
||||
exists: boolean;
|
||||
rating?: number;
|
||||
reason?: string;
|
||||
}
|
||||
|
||||
export const useFeedbackExists = (eventId?: number) => {
|
||||
const { conversationId } = useConversationId();
|
||||
|
||||
return useQuery<FeedbackData>({
|
||||
queryKey: ["feedback", "exists", conversationId, eventId],
|
||||
queryFn: () => {
|
||||
if (!eventId) return { exists: false };
|
||||
return OpenHands.checkFeedbackExists(conversationId, eventId);
|
||||
},
|
||||
enabled: !!eventId,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes
|
||||
});
|
||||
};
|
||||
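The two hooks above cooperate through a shared query key: `useFeedbackExists` reads under `["feedback", "exists", conversationId, eventId]`, and the mutation's `onSuccess` invalidates the same key. The sketch below illustrates that contract in isolation. `feedbackExistsKey` and `isFeedbackQueryEnabled` are hypothetical helpers, not functions from the repository or from TanStack Query.

```typescript
// Hypothetical helper: builds the key used by both the query and the
// invalidation, so the two call sites cannot drift apart.
type FeedbackExistsKey = readonly [
  "feedback",
  "exists",
  string,
  number | undefined,
];

function feedbackExistsKey(
  conversationId: string,
  eventId?: number,
): FeedbackExistsKey {
  return ["feedback", "exists", conversationId, eventId] as const;
}

// Mirrors `enabled: !!eventId` and `if (eventId)` from the hooks.
// Note that a truthiness check treats an event id of 0 like a missing id.
function isFeedbackQueryEnabled(eventId?: number): boolean {
  return !!eventId;
}
```

If event ids can start at 0, a check such as `eventId !== undefined` would be the safer guard; whether that applies here is not visible in this diff.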
@@ -80,7 +80,7 @@ export enum I18nKey {
   ANALYTICS$CONFIRM_PREFERENCES = "ANALYTICS$CONFIRM_PREFERENCES",
   SETTINGS$SAVING = "SETTINGS$SAVING",
   SETTINGS$SAVE_CHANGES = "SETTINGS$SAVE_CHANGES",
-  SETTINGS$NAV_INTEGRATIONS = "SETTINGS$NAV_INTEGRATIONS",
+  SETTINGS$NAV_GIT = "SETTINGS$NAV_GIT",
   SETTINGS$NAV_APPLICATION = "SETTINGS$NAV_APPLICATION",
   SETTINGS$NAV_CREDITS = "SETTINGS$NAV_CREDITS",
   SETTINGS$NAV_SECRETS = "SETTINGS$NAV_SECRETS",
@@ -174,7 +174,6 @@ export enum I18nKey {
   GITHUB$TOKEN_INVALID = "GITHUB$TOKEN_INVALID",
   BUTTON$DISCONNECT = "BUTTON$DISCONNECT",
   GITHUB$CONFIGURE_REPOS = "GITHUB$CONFIGURE_REPOS",
-  SLACK$INSTALL_APP = "SLACK$INSTALL_APP",
   COMMON$CLICK_FOR_INSTRUCTIONS = "COMMON$CLICK_FOR_INSTRUCTIONS",
   LLM$SELECT_MODEL_PLACEHOLDER = "LLM$SELECT_MODEL_PLACEHOLDER",
   LLM$MODEL = "LLM$MODEL",
@@ -584,13 +583,4 @@ export enum I18nKey {
   SETTINGS$EMAIL_VERIFICATION_RESTRICTION_MESSAGE = "SETTINGS$EMAIL_VERIFICATION_RESTRICTION_MESSAGE",
   SETTINGS$RESEND_VERIFICATION = "SETTINGS$RESEND_VERIFICATION",
   SETTINGS$FAILED_TO_RESEND_VERIFICATION = "SETTINGS$FAILED_TO_RESEND_VERIFICATION",
-  FEEDBACK$RATE_AGENT_PERFORMANCE = "FEEDBACK$RATE_AGENT_PERFORMANCE",
-  FEEDBACK$SELECT_REASON = "FEEDBACK$SELECT_REASON",
-  FEEDBACK$SELECT_REASON_COUNTDOWN = "FEEDBACK$SELECT_REASON_COUNTDOWN",
-  FEEDBACK$REASON_MISUNDERSTOOD_INSTRUCTION = "FEEDBACK$REASON_MISUNDERSTOOD_INSTRUCTION",
-  FEEDBACK$REASON_FORGOT_CONTEXT = "FEEDBACK$REASON_FORGOT_CONTEXT",
-  FEEDBACK$REASON_UNNECESSARY_CHANGES = "FEEDBACK$REASON_UNNECESSARY_CHANGES",
-  FEEDBACK$REASON_OTHER = "FEEDBACK$REASON_OTHER",
-  FEEDBACK$THANK_YOU_FOR_FEEDBACK = "FEEDBACK$THANK_YOU_FOR_FEEDBACK",
-  FEEDBACK$FAILED_TO_SUBMIT = "FEEDBACK$FAILED_TO_SUBMIT",
 }
|
||||
@@ -1279,21 +1279,21 @@
     "de": "Änderungen speichern",
     "uk": "Зберегти зміни"
   },
-  "SETTINGS$NAV_INTEGRATIONS": {
-    "en": "Integrations",
-    "ja": "統合",
-    "zh-CN": "集成",
-    "zh-TW": "整合",
-    "ko-KR": "통합",
-    "no": "Integrasjoner",
-    "it": "Integrazioni",
-    "pt": "Integrações",
-    "es": "Integraciones",
-    "ar": "التكامل",
-    "fr": "Intégrations",
-    "tr": "Entegrasyonlar",
-    "de": "Integrationen",
-    "uk": "Інтеграції"
+  "SETTINGS$NAV_GIT": {
+    "en": "Git",
+    "ja": "Git",
+    "zh-CN": "Git",
+    "zh-TW": "Git",
+    "ko-KR": "Git",
+    "no": "Git",
+    "it": "Git",
+    "pt": "Git",
+    "es": "Git",
+    "ar": "Git",
+    "fr": "Git",
+    "tr": "Git",
+    "de": "Git",
+    "uk": "Git"
   },
   "SETTINGS$NAV_APPLICATION": {
     "en": "Application",
@@ -2783,22 +2783,6 @@
     "de": "GitHub-Repositories konfigurieren",
     "uk": "Налаштування репозиторіїв Github"
   },
-  "SLACK$INSTALL_APP": {
-    "en": "Install OpenHands Slack App",
-    "ja": "OpenHands Slackアプリをインストール",
-    "zh-CN": "安装 OpenHands Slack 应用",
-    "zh-TW": "安裝 OpenHands Slack 應用程式",
-    "ko-KR": "OpenHands Slack 앱 설치",
-    "no": "Installer OpenHands Slack-app",
-    "it": "Installa l'app Slack di OpenHands",
-    "pt": "Instalar aplicativo Slack do OpenHands",
-    "es": "Instalar aplicación Slack de OpenHands",
-    "ar": "تثبيت تطبيق OpenHands Slack",
-    "fr": "Installer l'application Slack OpenHands",
-    "tr": "OpenHands Slack uygulamasını yükle",
-    "de": "OpenHands Slack-App installieren",
-    "uk": "Встановити додаток OpenHands Slack"
-  },
   "COMMON$CLICK_FOR_INSTRUCTIONS": {
     "en": "Click here for instructions",
     "ja": "手順はこちらをクリック",
@@ -9342,149 +9326,5 @@
|
||||
"tr": "Doğrulama e-postası yeniden gönderilemedi",
|
||||
"de": "Bestätigungs-E-Mail konnte nicht erneut gesendet werden",
|
||||
"uk": "Не вдалося повторно надіслати лист підтвердження"
|
||||
},
|
||||
"FEEDBACK$RATE_AGENT_PERFORMANCE": {
|
||||
"en": "Rate the agent's performance:",
|
||||
"ja": "エージェントのパフォーマンスを評価してください:",
|
||||
"zh-CN": "评价代理的表现:",
|
||||
"zh-TW": "評價代理的表現:",
|
||||
"ko-KR": "에이전트의 성능을 평가하세요:",
|
||||
"no": "Vurder agentens ytelse:",
|
||||
"it": "Valuta le prestazioni dell'agente:",
|
||||
"pt": "Avalie o desempenho do agente:",
|
||||
"es": "Evalúe el rendimiento del agente:",
|
||||
"ar": "قيم أداء الوكيل:",
|
||||
"fr": "Évaluez la performance de l'agent :",
|
||||
"tr": "Ajanın performansını değerlendirin:",
|
||||
"de": "Bewerten Sie die Leistung des Agenten:",
|
||||
"uk": "Оцініть продуктивність агента:"
|
||||
},
|
||||
"FEEDBACK$SELECT_REASON": {
|
||||
"en": "Select a reason (optional):",
|
||||
"ja": "理由を選択してください(任意):",
|
||||
"zh-CN": "选择原因(可选):",
|
||||
"zh-TW": "選擇原因(可選):",
|
||||
"ko-KR": "이유 선택 (선택 사항):",
|
||||
"no": "Velg en grunn (valgfritt):",
|
||||
"it": "Seleziona un motivo (opzionale):",
|
||||
"pt": "Selecione um motivo (opcional):",
|
||||
"es": "Seleccione un motivo (opcional):",
|
||||
"ar": "حدد سببًا (اختياري):",
|
||||
"fr": "Sélectionnez une raison (facultatif) :",
|
||||
"tr": "Bir neden seçin (isteğe bağlı):",
|
||||
"de": "Wählen Sie einen Grund (optional):",
|
||||
"uk": "Виберіть причину (необов'язково):"
|
||||
},
|
||||
"FEEDBACK$SELECT_REASON_COUNTDOWN": {
|
||||
"en": "Auto-submitting in {{countdown}} seconds...",
|
||||
"ja": "{{countdown}}秒後に自動送信されます...",
|
||||
"zh-CN": "{{countdown}}秒后自动提交...",
|
||||
"zh-TW": "{{countdown}}秒後自動提交...",
|
||||
"ko-KR": "{{countdown}}초 후 자동 제출...",
|
||||
"no": "Sender automatisk om {{countdown}} sekunder...",
|
||||
"it": "Invio automatico tra {{countdown}} secondi...",
|
||||
"pt": "Enviando automaticamente em {{countdown}} segundos...",
|
||||
"es": "Enviando automáticamente en {{countdown}} segundos...",
|
||||
"ar": "الإرسال التلقائي خلال {{countdown}} ثانية...",
|
||||
"fr": "Envoi automatique dans {{countdown}} secondes...",
|
||||
"tr": "{{countdown}} saniye içinde otomatik gönderilecek...",
|
||||
"de": "Automatische Übermittlung in {{countdown}} Sekunden...",
|
||||
"uk": "Автоматична відправка через {{countdown}} секунд..."
|
||||
},
|
||||
"FEEDBACK$REASON_MISUNDERSTOOD_INSTRUCTION": {
|
||||
"en": "The agent misunderstood my instruction",
|
||||
"ja": "エージェントは私の指示を誤解しました",
|
||||
"zh-CN": "代理误解了我的指示",
|
||||
"zh-TW": "代理誤解了我的指示",
|
||||
"ko-KR": "에이전트가 내 지시를 잘못 이해했습니다",
|
||||
"no": "Agenten misforsto instruksjonene mine",
|
||||
"it": "L'agente ha frainteso le mie istruzioni",
|
||||
"pt": "O agente não entendeu minhas instruções",
|
||||
"es": "El agente malinterpretó mis instrucciones",
|
||||
"ar": "أساء الوكيل فهم تعليماتي",
|
||||
"fr": "L'agent a mal compris mes instructions",
|
||||
"tr": "Ajan talimatlarımı yanlış anladı",
|
||||
"de": "Der Agent hat meine Anweisungen missverstanden",
|
||||
"uk": "Агент неправильно зрозумів мої інструкції"
|
||||
},
|
||||
"FEEDBACK$REASON_FORGOT_CONTEXT": {
|
||||
"en": "The agent forgot about the earlier context",
|
||||
"ja": "エージェントは以前のコンテキストを忘れました",
|
||||
"zh-CN": "代理忘记了之前的上下文",
|
||||
"zh-TW": "代理忘記了之前的上下文",
|
||||
"ko-KR": "에이전트가 이전 컨텍스트를 잊었습니다",
|
||||
"no": "Agenten glemte den tidligere konteksten",
|
||||
"it": "L'agente ha dimenticato il contesto precedente",
|
||||
"pt": "O agente esqueceu o contexto anterior",
|
||||
"es": "El agente olvidó el contexto anterior",
|
||||
"ar": "نسي الوكيل السياق السابق",
|
||||
"fr": "L'agent a oublié le contexte précédent",
|
||||
"tr": "Ajan önceki bağlamı unuttu",
|
||||
"de": "Der Agent hat den früheren Kontext vergessen",
|
||||
"uk": "Агент забув про попередній контекст"
|
||||
},
|
||||
"FEEDBACK$REASON_UNNECESSARY_CHANGES": {
|
||||
"en": "The agent made unnecessary changes",
|
||||
"ja": "エージェントは不要な変更を行いました",
|
||||
"zh-CN": "代理进行了不必要的更改",
|
||||
"zh-TW": "代理進行了不必要的更改",
|
||||
"ko-KR": "에이전트가 불필요한 변경을 했습니다",
|
||||
"no": "Agenten gjorde unødvendige endringer",
|
||||
"it": "L'agente ha apportato modifiche non necessarie",
|
||||
"pt": "O agente fez alterações desnecessárias",
|
||||
"es": "El agente hizo cambios innecesarios",
|
||||
"ar": "قام الوكيل بتغييرات غير ضرورية",
|
||||
"fr": "L'agent a apporté des modifications inutiles",
|
||||
"tr": "Ajan gereksiz değişiklikler yaptı",
|
||||
"de": "Der Agent hat unnötige Änderungen vorgenommen",
|
||||
"uk": "Агент зробив непотрібні зміни"
|
||||
},
|
||||
"FEEDBACK$REASON_OTHER": {
|
||||
"en": "Other",
|
||||
"ja": "その他",
|
||||
"zh-CN": "其他",
|
||||
"zh-TW": "其他",
|
||||
"ko-KR": "기타",
|
||||
"no": "Annet",
|
||||
"it": "Altro",
|
||||
"pt": "Outro",
|
||||
"es": "Otro",
|
||||
"ar": "أخرى",
|
||||
"fr": "Autre",
|
||||
"tr": "Diğer",
|
||||
"de": "Andere",
|
||||
"uk": "Інше"
|
||||
},
|
||||
"FEEDBACK$THANK_YOU_FOR_FEEDBACK": {
|
||||
"en": "Thank you for your feedback! This will help us improve OpenHands going forward.",
|
||||
"ja": "フィードバックをありがとうございます!これにより、今後OpenHandsを改善していくことができます。",
|
||||
"zh-CN": "感谢您的反馈!这将帮助我们改进OpenHands。",
|
||||
"zh-TW": "感謝您的反饋!這將幫助我們改進OpenHands。",
|
||||
"ko-KR": "피드백 감사합니다! 이를 통해 OpenHands를 개선해 나가겠습니다.",
|
||||
"no": "Takk for tilbakemeldingen! Dette vil hjelpe oss med å forbedre OpenHands fremover.",
|
||||
"it": "Grazie per il tuo feedback! Questo ci aiuterà a migliorare OpenHands in futuro.",
|
||||
"pt": "Obrigado pelo seu feedback! Isso nos ajudará a melhorar o OpenHands no futuro.",
|
||||
"es": "¡Gracias por su comentario! Esto nos ayudará a mejorar OpenHands en el futuro.",
|
||||
"ar": "شكرا على ملاحظاتك! سيساعدنا هذا في تحسين OpenHands في المستقبل.",
|
||||
"fr": "Merci pour votre retour ! Cela nous aidera à améliorer OpenHands à l'avenir.",
|
||||
"tr": "Geri bildiriminiz için teşekkürler! Bu, OpenHands'i ileride geliştirmemize yardımcı olacak.",
|
||||
"de": "Vielen Dank für Ihr Feedback! Das hilft uns, OpenHands in Zukunft zu verbessern.",
|
||||
"uk": "Дякуємо за ваш відгук! Це допоможе нам покращити OpenHands у майбутньому."
|
||||
},
|
||||
"FEEDBACK$FAILED_TO_SUBMIT": {
|
||||
"en": "Failed to submit feedback",
|
||||
"ja": "フィードバックの送信に失敗しました",
|
||||
"zh-CN": "提交反馈失败",
|
||||
"zh-TW": "提交反饋失敗",
|
||||
"ko-KR": "피드백 제출 실패",
|
||||
"no": "Kunne ikke sende tilbakemelding",
|
||||
"it": "Impossibile inviare feedback",
|
||||
"pt": "Falha ao enviar feedback",
|
||||
"es": "Error al enviar comentarios",
|
||||
"ar": "فشل في تقديم التعليقات",
|
||||
"fr": "Échec de l'envoi des commentaires",
|
||||
"tr": "Geri bildirim gönderilemedi",
|
||||
"de": "Feedback konnte nicht gesendet werden",
|
||||
"uk": "Не вдалося надіслати відгук"
|
||||
}
|
||||
}
|
||||
|
||||
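The removed `FEEDBACK$SELECT_REASON_COUNTDOWN` strings use a `{{countdown}}` placeholder, which the component filled by passing `{ countdown }` to `i18n.t`. The sketch below shows the substitution idea only. `interpolate` is a hypothetical stand-in written for illustration; it is not i18next's implementation and ignores features such as escaping, nesting, and plurals.

```typescript
// Hypothetical, simplified placeholder substitution for "{{name}}" tokens.
function interpolate(
  template: string,
  values: Record<string, string | number>,
): string {
  return template.replace(/\{\{\s*(\w+)\s*\}\}/g, (match, key: string) =>
    // Unknown placeholders are left untouched rather than blanked out.
    key in values ? String(values[key]) : match,
  );
}

// Example with the English string from the translation file above.
const message = interpolate("Auto-submitting in {{countdown}} seconds...", {
  countdown: 7,
});
```

Because the placeholder name is part of every translation, each locale can place the number wherever its word order requires, as the Japanese and Turkish entries above do.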
@@ -13,7 +13,7 @@ export default [
       index("routes/llm-settings.tsx"),
       route("mcp", "routes/mcp-settings.tsx"),
       route("user", "routes/user-settings.tsx"),
-      route("integrations", "routes/git-settings.tsx"),
+      route("git", "routes/git-settings.tsx"),
       route("app", "routes/app-settings.tsx"),
       route("billing", "routes/billing.tsx"),
       route("secrets", "routes/secrets-settings.tsx"),
|
||||
@@ -7,7 +7,6 @@ import { useLogout } from "#/hooks/mutation/use-logout";
 import { GitHubTokenInput } from "#/components/features/settings/git-settings/github-token-input";
 import { GitLabTokenInput } from "#/components/features/settings/git-settings/gitlab-token-input";
 import { ConfigureGitHubRepositoriesAnchor } from "#/components/features/settings/git-settings/configure-github-repositories-anchor";
-import { InstallSlackAppAnchor } from "#/components/features/settings/git-settings/install-slack-app-anchor";
 import { I18nKey } from "#/i18n/declaration";
 import {
   displayErrorToast,
@@ -104,10 +103,6 @@ function GitSettingsScreen() {
             <ConfigureGitHubRepositoriesAnchor slug={config.APP_SLUG!} />
           )}
 
-          {shouldRenderExternalConfigureButtons && !isLoading && (
-            <InstallSlackAppAnchor />
-          )}
-
           {!isSaas && (
             <GitHubTokenInput
               name="github-token-input"
|
||||
@@ -84,11 +84,7 @@ function SecretsSettingsScreen() {
       )}
 
       {shouldRenderConnectToGitButton && (
-        <Link
-          to="/settings/integrations"
-          data-testid="connect-git-button"
-          type="button"
-        >
+        <Link to="/settings/git" data-testid="connect-git-button" type="button">
           <BrandButton type="button" variant="secondary">
             Connect a Git provider to manage secrets
           </BrandButton>
|
||||
@@ -16,7 +16,7 @@ function SettingsScreen() {
 
   const saasNavItems = [
     { to: "/settings/user", text: t("SETTINGS$NAV_USER") },
-    { to: "/settings/integrations", text: t("SETTINGS$NAV_INTEGRATIONS") },
+    { to: "/settings/git", text: t("SETTINGS$NAV_GIT") },
     { to: "/settings/app", text: t("SETTINGS$NAV_APPLICATION") },
     { to: "/settings/billing", text: t("SETTINGS$NAV_CREDITS") },
     { to: "/settings/secrets", text: t("SETTINGS$NAV_SECRETS") },
@@ -26,7 +26,7 @@ function SettingsScreen() {
   const ossNavItems = [
     { to: "/settings", text: t("SETTINGS$NAV_LLM") },
     { to: "/settings/mcp", text: t("SETTINGS$NAV_MCP") },
-    { to: "/settings/integrations", text: t("SETTINGS$NAV_INTEGRATIONS") },
+    { to: "/settings/git", text: t("SETTINGS$NAV_GIT") },
     { to: "/settings/app", text: t("SETTINGS$NAV_APPLICATION") },
     { to: "/settings/secrets", text: t("SETTINGS$NAV_SECRETS") },
   ];
|
||||
@@ -125,9 +125,9 @@ class BrowsingAgent(Agent):
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the Browsing Agent's internal state."""
|
||||
"""Resets the Browsing Agent."""
|
||||
super().reset()
|
||||
# Reset agent-specific counters but not LLM metrics
|
||||
self.cost_accumulator = 0
|
||||
self.error_accumulator = 0
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
|
||||
@@ -136,9 +136,8 @@ class CodeActAgent(Agent):
|
||||
return tools
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the CodeAct Agent's internal state."""
|
||||
"""Resets the CodeAct Agent."""
|
||||
super().reset()
|
||||
# Only clear pending actions, not LLM metrics
|
||||
self.pending_actions.clear()
|
||||
|
||||
def step(self, state: State) -> 'Action':
|
||||
|
||||
@@ -119,14 +119,14 @@ class DummyAgent(Agent):
|
||||
]
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
if state.iteration_flag.current_value >= len(self.steps):
|
||||
if state.iteration >= len(self.steps):
|
||||
return AgentFinishAction()
|
||||
|
||||
current_step = self.steps[state.iteration_flag.current_value]
|
||||
current_step = self.steps[state.iteration]
|
||||
action = current_step['action']
|
||||
|
||||
if state.iteration_flag.current_value > 0:
|
||||
prev_step = self.steps[state.iteration_flag.current_value - 1]
|
||||
if state.iteration > 0:
|
||||
prev_step = self.steps[state.iteration - 1]
|
||||
|
||||
if 'observations' in prev_step and prev_step['observations']:
|
||||
expected_observations = prev_step['observations']
|
||||
|
||||
@@ -176,9 +176,9 @@ Note:
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the VisualBrowsingAgent's internal state."""
|
||||
"""Resets the VisualBrowsingAgent."""
|
||||
super().reset()
|
||||
# Reset agent-specific counters but not LLM metrics
|
||||
self.cost_accumulator = 0
|
||||
self.error_accumulator = 0
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
|
||||
@@ -103,10 +103,16 @@ class Agent(ABC):
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the agent's execution status."""
|
||||
# Only reset the completion status, not the LLM metrics
|
||||
"""Resets the agent's execution status and clears the history. This method can be used
|
||||
to prepare the agent for restarting the instruction or cleaning up before destruction.
|
||||
|
||||
"""
|
||||
# TODO clear history
|
||||
self._complete = False
|
||||
|
||||
if self.llm:
|
||||
self.llm.reset()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
import traceback
|
||||
from typing import Callable
|
||||
|
||||
import litellm # noqa
|
||||
from litellm.exceptions import ( # noqa
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
@@ -24,8 +25,7 @@ from litellm.exceptions import ( # noqa
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.replay import ReplayManager
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.controller.state.state_tracker import StateTracker
|
||||
from openhands.controller.state.state import State, TrafficControlState
|
||||
from openhands.controller.stuck import StuckDetector
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.core.exceptions import (
|
||||
@@ -61,6 +61,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.agent import CondensationAction, RecallAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.observation import (
|
||||
AgentDelegateObservation,
|
||||
AgentStateChangedObservation,
|
||||
@@ -68,11 +69,10 @@ from openhands.events.observation import (
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.events.serialization.event import event_to_trajectory, truncate_content
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics, TokenUsage
|
||||
from openhands.memory.view import View
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
# note: RESUME is only available on web GUI
|
||||
TRAFFIC_CONTROL_REMINDER = (
|
||||
@@ -101,13 +101,11 @@ class AgentController:
|
||||
self,
|
||||
agent: Agent,
|
||||
event_stream: EventStream,
|
||||
iteration_delta: int,
|
||||
budget_per_task_delta: float | None = None,
|
||||
max_iterations: int,
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
sid: str | None = None,
|
||||
file_store: FileStore | None = None,
|
||||
user_id: str | None = None,
|
||||
confirmation_mode: bool = False,
|
||||
initial_state: State | None = None,
|
||||
is_delegate: bool = False,
|
||||
@@ -134,10 +132,7 @@ class AgentController:
|
||||
status_callback: Optional callback function to handle status updates.
|
||||
replay_events: A list of logs to replay.
|
||||
"""
|
||||
|
||||
self.id = sid or event_stream.sid
|
||||
self.user_id = user_id
|
||||
self.file_store = file_store
|
||||
self.agent = agent
|
||||
self.headless_mode = headless_mode
|
||||
self.is_delegate = is_delegate
|
||||
@@ -151,22 +146,29 @@ class AgentController:
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, self.id
|
||||
)
|
||||
|
||||
self.state_tracker = StateTracker(sid, file_store, user_id)
|
||||
# filter out events that are not relevant to the agent
|
||||
# so they will not be included in the agent history
|
||||
self.agent_history_filter = EventFilter(
|
||||
exclude_types=(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
),
|
||||
exclude_hidden=True,
|
||||
)
|
||||
|
||||
# state from the previous session, state from a parent agent, or a fresh state
|
||||
self.set_initial_state(
|
||||
state=initial_state,
|
||||
max_iterations=iteration_delta,
|
||||
max_budget_per_task=budget_per_task_delta,
|
||||
max_iterations=max_iterations,
|
||||
confirmation_mode=confirmation_mode,
|
||||
)
|
||||
|
||||
self.state = self.state_tracker.state # TODO: share between manager and controller for backward compatibility; we should ideally move all state-related logic to the state manager
|
||||
|
||||
self.max_budget_per_task = max_budget_per_task
|
||||
self.agent_to_llm_config = agent_to_llm_config if agent_to_llm_config else {}
|
||||
self.agent_configs = agent_configs if agent_configs else {}
|
||||
self._initial_max_iterations = iteration_delta
|
||||
self._initial_max_budget_per_task = budget_per_task_delta
|
||||
self._initial_max_iterations = max_iterations
|
||||
self._initial_max_budget_per_task = max_budget_per_task
|
||||
|
||||
# stuck helper
|
||||
self._stuck_detector = StuckDetector(self.state)
|
||||
@@ -212,7 +214,26 @@ class AgentController:
|
||||
if set_stop_state:
|
||||
await self.set_agent_state_to(AgentState.STOPPED)
|
||||
|
||||
self.state_tracker.close(self.event_stream)
|
||||
# we made history, now is the time to rewrite it!
|
||||
# the final state.history will be used by external scripts like evals, tests, etc.
|
||||
# history will need to be complete WITH delegates events
|
||||
# like the regular agent history, it does not include:
|
||||
# - 'hidden' events, events with hidden=True
|
||||
# - backend events (the default 'filtered out' types, types in self.filter_out)
|
||||
start_id = self.state.start_id if self.state.start_id >= 0 else 0
|
||||
end_id = (
|
||||
self.state.end_id
|
||||
if self.state.end_id >= 0
|
||||
else self.event_stream.get_latest_event_id()
|
||||
)
|
||||
self.state.history = list(
|
||||
self.event_stream.search_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter=self.agent_history_filter,
|
||||
)
|
||||
)
|
||||
|
||||
# unsubscribe from the event stream
|
||||
# only the root parent controller subscribes to the event stream
|
||||
@@ -236,6 +257,14 @@ class AgentController:
|
||||
extra_merged = {'session_id': self.id, **extra}
|
||||
getattr(logger, level)(message, extra=extra_merged, stacklevel=2)
|
||||
|
||||
def update_state_before_step(self) -> None:
|
||||
self.state.iteration += 1
|
||||
self.state.local_iteration += 1
|
||||
|
||||
async def update_state_after_step(self) -> None:
|
||||
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent._reset()
|
||||
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
|
||||
|
||||
async def _react_to_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
@@ -361,17 +390,10 @@ class AgentController:
|
||||
# If we have a delegate that is not finished or errored, forward events to it
|
||||
if self.delegate is not None:
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
if (
|
||||
delegate_state
|
||||
not in (
|
||||
AgentState.FINISHED,
|
||||
AgentState.ERROR,
|
||||
AgentState.REJECTED,
|
||||
)
|
||||
or 'RuntimeError: Agent reached maximum iteration.'
|
||||
in self.delegate.state.last_error
|
||||
or 'RuntimeError:Agent reached maximum budget for conversation'
|
||||
in self.delegate.state.last_error
|
||||
if delegate_state not in (
|
||||
AgentState.FINISHED,
|
||||
AgentState.ERROR,
|
||||
AgentState.REJECTED,
|
||||
):
|
||||
# Forward the event to delegate and skip parent processing
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
@@ -390,7 +412,9 @@ class AgentController:
|
||||
if hasattr(event, 'hidden') and event.hidden:
|
||||
return
|
||||
|
||||
self.state_tracker.add_history(event)
|
||||
# if the event is not filtered out, add it to the history
|
||||
if self.agent_history_filter.include(event):
|
||||
self.state.history.append(event)
|
||||
|
||||
if isinstance(event, Action):
|
||||
await self._handle_action(event)
|
||||
@@ -433,9 +457,11 @@ class AgentController:
|
||||
|
||||
elif isinstance(action, AgentFinishAction):
|
||||
self.state.outputs = action.outputs
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.FINISHED)
|
||||
elif isinstance(action, AgentRejectAction):
|
||||
self.state.outputs = action.outputs
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
await self.set_agent_state_to(AgentState.REJECTED)
|
||||
|
||||
async def _handle_observation(self, observation: Observation) -> None:
|
||||
@@ -455,10 +481,8 @@ class AgentController:
|
||||
log_level, str(observation_to_print), extra={'msg_type': 'OBSERVATION'}
|
||||
)
|
||||
|
||||
# TODO: these metrics come from the draft editor, and they get accumulated into controller's state metrics and the agent's llm metrics
|
||||
# In the future, we should have a more principled way of sharing metrics across all LLM instances for a given conversation
|
||||
if observation.llm_metrics is not None:
|
||||
self.state_tracker.merge_metrics(observation.llm_metrics)
|
||||
self.agent.llm.metrics.merge(observation.llm_metrics)
|
||||
|
||||
# this happens for runnable actions and microagent actions
|
||||
if self._pending_action and self._pending_action.id == observation.cause:
|
||||
@@ -472,6 +496,9 @@ class AgentController:
|
||||
if self.state.agent_state == AgentState.USER_REJECTED:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
return
|
||||
elif isinstance(observation, ErrorObservation):
|
||||
if self.state.agent_state == AgentState.ERROR:
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
|
||||
async def _handle_message_action(self, action: MessageAction) -> None:
|
||||
"""Handles message actions from the event stream.
|
||||
@@ -489,6 +516,22 @@ class AgentController:
|
||||
str(action),
|
||||
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
|
||||
)
|
||||
# Extend max iterations when the user sends a message (only in non-headless mode)
|
||||
if self._initial_max_iterations is not None and not self.headless_mode:
|
||||
self.state.max_iterations = (
|
||||
self.state.iteration + self._initial_max_iterations
|
||||
)
|
||||
if (
|
||||
self.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
or self.state.traffic_control_state == TrafficControlState.PAUSED
|
||||
):
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
self.log(
|
||||
'debug',
|
||||
f'Extended max iterations to {self.state.max_iterations} after user message',
|
||||
)
|
||||
# try to retrieve microagents relevant to the user message
|
||||
# set pending_action while we search for information
|
||||
|
||||
# if this is the first user message for this agent, matters for the microagent info type
|
||||
first_user_message = self._first_user_message()
|
||||
@@ -562,16 +605,36 @@ class AgentController:
|
||||
return
|
||||
|
||||
if new_state in (AgentState.STOPPED, AgentState.ERROR):
|
||||
# sync existing metrics BEFORE resetting the agent
|
||||
await self.update_state_after_step()
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
self._reset()
|
||||
|
||||
# The user is allowing us to check control limits and expand them if applicable
|
||||
if (
|
||||
self.state.agent_state == AgentState.ERROR
|
||||
and new_state == AgentState.RUNNING
|
||||
elif (
|
||||
new_state == AgentState.RUNNING
|
||||
and self.state.agent_state == AgentState.PAUSED
|
||||
# TODO: do we really need both THROTTLING and PAUSED states, or can we clean up one of them completely?
|
||||
and self.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
):
|
||||
self.state_tracker.maybe_increase_control_flags_limits(self.headless_mode)
|
||||
# user intends to interrupt traffic control and let the task resume temporarily
|
||||
self.state.traffic_control_state = TrafficControlState.PAUSED
|
||||
# User has chosen to deliberately continue - let's double the max iterations
|
||||
if (
|
||||
self.state.iteration is not None
|
||||
and self.state.max_iterations is not None
|
||||
and self._initial_max_iterations is not None
|
||||
and not self.headless_mode
|
||||
):
|
||||
if self.state.iteration >= self.state.max_iterations:
|
||||
self.state.max_iterations += self._initial_max_iterations
|
||||
|
||||
if self._pending_action is not None and (
|
||||
if (
|
||||
self.state.metrics.accumulated_cost is not None
|
||||
and self.max_budget_per_task is not None
|
||||
and self._initial_max_budget_per_task is not None
|
||||
):
|
||||
if self.state.metrics.accumulated_cost >= self.max_budget_per_task:
|
||||
self.max_budget_per_task += self._initial_max_budget_per_task
|
||||
elif self._pending_action is not None and (
|
||||
new_state in (AgentState.USER_CONFIRMED, AgentState.USER_REJECTED)
|
||||
):
|
||||
if hasattr(self._pending_action, 'thought'):
|
||||
@@ -596,10 +659,6 @@ class AgentController:
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
|
||||
# Save state whenever agent state changes to ensure we don't lose state
|
||||
# in case of crashes or unexpected circumstances
|
||||
self.save_state()
|
||||
|
||||
def get_agent_state(self) -> AgentState:
|
||||
"""Returns the current state of the agent.
|
||||
|
||||
@@ -627,27 +686,19 @@ class AgentController:
|
||||
agent_cls: type[Agent] = Agent.get_cls(action.agent)
|
||||
agent_config = self.agent_configs.get(action.agent, self.agent.config)
|
||||
llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
|
||||
# Make sure metrics are shared between parent and child for global accumulation
|
||||
llm = LLM(
|
||||
config=llm_config,
|
||||
retry_listener=self.agent.llm.retry_listener,
|
||||
metrics=self.state.metrics,
|
||||
)
|
||||
llm = LLM(config=llm_config, retry_listener=self._notify_on_llm_retry)
|
||||
delegate_agent = agent_cls(llm=llm, config=agent_config)
|
||||
|
||||
# Take a snapshot of the current metrics before starting the delegate
|
||||
state = State(
|
||||
session_id=self.id.removesuffix('-delegate'),
|
||||
inputs=action.inputs or {},
|
||||
iteration_flag=self.state.iteration_flag,
|
||||
budget_flag=self.state.budget_flag,
|
||||
local_iteration=0,
|
||||
iteration=self.state.iteration,
|
||||
max_iterations=self.state.max_iterations,
|
||||
delegate_level=self.state.delegate_level + 1,
|
||||
# global metrics should be shared between parent and child
|
||||
metrics=self.state.metrics,
|
||||
# start on top of the stream
|
||||
start_id=self.event_stream.get_latest_event_id() + 1,
|
||||
parent_metrics_snapshot=self.state_tracker.get_metrics_snapshot(),
|
||||
parent_iteration=self.state.iteration_flag.current_value,
|
||||
)
|
||||
self.log(
|
||||
'debug',
|
||||
@@ -657,12 +708,10 @@ class AgentController:
|
||||
# Create the delegate with is_delegate=True so it does NOT subscribe directly
|
||||
self.delegate = AgentController(
|
||||
sid=self.id + '-delegate',
|
||||
file_store=self.file_store,
|
||||
user_id=self.user_id,
|
||||
agent=delegate_agent,
|
||||
event_stream=self.event_stream,
|
||||
iteration_delta=self._initial_max_iterations,
|
||||
budget_per_task_delta=self._initial_max_budget_per_task,
|
||||
max_iterations=self.state.max_iterations,
|
||||
max_budget_per_task=self.max_budget_per_task,
|
||||
agent_to_llm_config=self.agent_to_llm_config,
|
||||
agent_configs=self.agent_configs,
|
||||
initial_state=state,
|
||||
@@ -681,13 +730,7 @@ class AgentController:
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
|
||||
# update iteration that is shared across agents
|
||||
self.state.iteration_flag.current_value = (
|
||||
self.delegate.state.iteration_flag.current_value
|
||||
)
|
||||
|
||||
# Calculate delegate-specific metrics before closing the delegate
|
||||
delegate_metrics = self.state.get_local_metrics()
|
||||
logger.info(f'Local metrics for delegate: {delegate_metrics}')
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close the delegate controller before adding new events
|
||||
asyncio.get_event_loop().run_until_complete(self.delegate.close())
|
||||
@@ -700,12 +743,8 @@ class AgentController:
|
||||
|
||||
# prepare delegate result observation
|
||||
# TODO: replace this with AI-generated summary (#2395)
|
||||
# Filter out metrics from the formatted output to avoid clutter
|
||||
display_outputs = {
|
||||
k: v for k, v in delegate_outputs.items() if k != 'metrics'
|
||||
}
|
||||
formatted_output = ', '.join(
|
||||
f'{key}: {value}' for key, value in display_outputs.items()
|
||||
f'{key}: {value}' for key, value in delegate_outputs.items()
|
||||
)
|
||||
content = (
|
||||
f'{self.delegate.agent.name} finishes task with {formatted_output}'
|
||||
@@ -759,16 +798,24 @@ class AgentController:
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
f'LEVEL {self.state.delegate_level} LOCAL STEP {self.state.get_local_step()} GLOBAL STEP {self.state.iteration_flag.current_value}',
|
||||
f'LEVEL {self.state.delegate_level} LOCAL STEP {self.state.local_iteration} GLOBAL STEP {self.state.iteration}',
|
||||
extra={'msg_type': 'STEP'},
|
||||
)
|
||||
|
||||
# Ensure budget control flag is synchronized with the latest metrics.
|
||||
# In the future, we should centralize the use of one LLM object per conversation.
|
||||
# This will help us unify the cost for auto generating titles, running the condenser, etc.
|
||||
# Because many microservices will touch the same LLM cost field, we should sync with the budget flag for the controller
|
||||
# and check that we haven't exceeded budget BEFORE executing an agent step.
|
||||
self.state_tracker.sync_budget_flag_with_metrics()
|
||||
stop_step = False
|
||||
if self.state.iteration >= self.state.max_iterations:
|
||||
stop_step = await self._handle_traffic_control(
|
||||
'iteration', self.state.iteration, self.state.max_iterations
|
||||
)
|
||||
if self.max_budget_per_task is not None:
|
||||
current_cost = self.state.metrics.accumulated_cost
|
||||
if current_cost > self.max_budget_per_task:
|
||||
stop_step = await self._handle_traffic_control(
|
||||
'budget', current_cost, self.max_budget_per_task
|
||||
)
|
||||
if stop_step:
|
||||
logger.warning('Stopping agent due to traffic control')
|
||||
return
|
||||
|
||||
if self._is_stuck():
|
||||
await self._react_to_exception(
|
||||
@@ -776,13 +823,7 @@ class AgentController:
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
self.state_tracker.run_control_flags()
|
||||
except Exception as e:
|
||||
logger.warning('Control flag limits hit')
|
||||
await self._react_to_exception(e)
|
||||
return
|
||||
|
||||
self.update_state_before_step()
|
||||
action: Action = NullAction()
|
||||
|
||||
if self._replay_manager.should_replay():
|
||||
@@ -853,9 +894,60 @@ class AgentController:
|
||||
|
||||
self.event_stream.add_event(action, action._source) # type: ignore [attr-defined]
|
||||
|
||||
await self.update_state_after_step()
|
||||
|
||||
log_level = 'info' if LOG_ALL_EVENTS else 'debug'
|
||||
self.log(log_level, str(action), extra={'msg_type': 'ACTION'})
|
||||
|
||||
def _notify_on_llm_retry(self, retries: int, max: int) -> None:
|
||||
if self.status_callback is not None:
|
||||
msg_id = 'STATUS$LLM_RETRY'
|
||||
self.status_callback(
|
||||
'info', msg_id, f'Retrying LLM request, {retries} / {max}'
|
||||
)
|
||||
|
||||
async def _handle_traffic_control(
|
||||
self, limit_type: str, current_value: float, max_value: float
|
||||
) -> bool:
|
||||
"""Handles agent state after hitting the traffic control limit.
|
||||
|
||||
Args:
|
||||
limit_type (str): The type of limit that was hit.
|
||||
current_value (float): The current value of the limit.
|
||||
max_value (float): The maximum value of the limit.
|
||||
"""
|
||||
stop_step = False
|
||||
if self.state.traffic_control_state == TrafficControlState.PAUSED:
|
||||
self.log(
|
||||
'debug', 'Hitting traffic control, temporarily resume upon user request'
|
||||
)
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
else:
|
||||
self.state.traffic_control_state = TrafficControlState.THROTTLING
|
||||
# Format values as integers for iterations, keep decimals for budget
|
||||
if limit_type == 'iteration':
|
||||
current_str = str(int(current_value))
|
||||
max_str = str(int(max_value))
|
||||
else:
|
||||
current_str = f'{current_value:.2f}'
|
||||
max_str = f'{max_value:.2f}'
|
||||
|
||||
if self.headless_mode:
|
||||
e = RuntimeError(
|
||||
f'Agent reached maximum {limit_type} in headless mode. '
|
||||
f'Current {limit_type}: {current_str}, max {limit_type}: {max_str}'
|
||||
)
|
||||
await self._react_to_exception(e)
|
||||
else:
|
||||
e = RuntimeError(
|
||||
f'Agent reached maximum {limit_type}. '
|
||||
f'Current {limit_type}: {current_str}, max {limit_type}: {max_str}. '
|
||||
)
|
||||
# FIXME: this isn't really an exception--we should have a different path
|
||||
await self._react_to_exception(e)
|
||||
stop_step = True
|
||||
return stop_step
|
||||
|
||||
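The transition performed by `_handle_traffic_control` above can be modelled in isolation. The following is a minimal sketch, not the controller's actual code: `format_limit` and `next_traffic_state` are hypothetical helper names, and the string value given to `THROTTLING` is an assumption, since the diff only shows the values of `NORMAL` and `PAUSED`.

```python
from enum import Enum


class TrafficControlState(str, Enum):
    NORMAL = 'normal'
    THROTTLING = 'throttling'  # assumed value; not visible in the diff
    PAUSED = 'paused'


def format_limit(
    limit_type: str, current_value: float, max_value: float
) -> tuple[str, str]:
    """Format values as integers for iterations, two decimals for budget."""
    if limit_type == 'iteration':
        return str(int(current_value)), str(int(max_value))
    return f'{current_value:.2f}', f'{max_value:.2f}'


def next_traffic_state(
    state: TrafficControlState,
) -> tuple[TrafficControlState, bool]:
    """Return (new_state, stop_step) after a limit has been hit."""
    if state == TrafficControlState.PAUSED:
        # The user asked to resume: let this step through, reset to NORMAL.
        return TrafficControlState.NORMAL, False
    # Otherwise start throttling and stop the current step.
    return TrafficControlState.THROTTLING, True


iteration_strs = format_limit('iteration', 100.0, 100.0)
budget_strs = format_limit('budget', 1.5, 2)
resumed = next_traffic_state(TrafficControlState.PAUSED)
throttled = next_traffic_state(TrafficControlState.NORMAL)
```

Separating the pure state transition from the side effects (logging, raising through `_react_to_exception`) is only a convenience for illustration; the method in the diff does both in one place.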
@property
|
||||
def _pending_action(self) -> Action | None:
|
||||
"""Get the current pending action with time tracking.
|
||||
@@ -923,26 +1015,150 @@ class AgentController:
|
||||
self,
|
||||
state: State | None,
|
||||
max_iterations: int,
|
||||
max_budget_per_task: float | None,
|
||||
confirmation_mode: bool = False,
|
||||
):
|
||||
self.state_tracker.set_initial_state(
|
||||
self.id,
|
||||
self.agent,
|
||||
state,
|
||||
max_iterations,
|
||||
max_budget_per_task,
|
||||
confirmation_mode,
|
||||
)
|
||||
) -> None:
|
||||
"""Sets the initial state for the agent, either from the previous session, or from a parent agent, or by creating a new one.
|
||||
|
||||
Args:
|
||||
state: The state to initialize with, or None to create a new state.
|
||||
max_iterations: The maximum number of iterations allowed for the task.
|
||||
confirmation_mode: Whether to enable confirmation mode.
|
||||
"""
|
||||
# state can come from:
|
||||
# - the previous session, in which case it has history
|
||||
# - from a parent agent, in which case it has no history
|
||||
# - None / a new state
|
||||
|
||||
# If state is None, we create a brand new state and still load the event stream so we can restore the history
|
||||
if state is None:
|
||||
self.state = State(
|
||||
session_id=self.id.removesuffix('-delegate'),
|
||||
inputs={},
|
||||
max_iterations=max_iterations,
|
||||
confirmation_mode=confirmation_mode,
|
||||
)
|
||||
self.state.start_id = 0
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'AgentController {self.id} - created new state. start_id: {self.state.start_id}',
|
||||
)
|
||||
else:
|
||||
self.state = state
|
||||
|
||||
if self.state.start_id <= -1:
|
||||
self.state.start_id = 0
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'AgentController {self.id} initializing history from event {self.state.start_id}',
|
||||
)
|
||||
|
||||
# Always load from the event stream to avoid losing history
|
||||
self.state_tracker._init_history(
|
||||
self.event_stream,
|
||||
)
|
||||
self._init_history()
|
||||
|
||||
def get_trajectory(self, include_screenshots: bool = False) -> list[dict]:
|
||||
# state history could be partially hidden/truncated before controller is closed
|
||||
assert self._closed
|
||||
return self.state_tracker.get_trajectory(include_screenshots)
|
||||
return [
|
||||
event_to_trajectory(event, include_screenshots)
|
||||
for event in self.state.history
|
||||
]
|
||||
|
||||
def _init_history(self) -> None:
|
||||
"""Initializes the agent's history from the event stream.
|
||||
|
||||
The history is a list of events that:
|
||||
- Excludes events of types listed in self.filter_out
|
||||
- Excludes events with hidden=True attribute
|
||||
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
||||
- Excludes all events between the action and observation
|
||||
- Includes the delegate action and observation themselves
|
||||
"""
|
||||
# define range of events to fetch
|
||||
# delegates start with a start_id and initially won't find any events
|
||||
# otherwise we're restoring a previous session
|
||||
start_id = self.state.start_id if self.state.start_id >= 0 else 0
|
||||
end_id = (
|
||||
self.state.end_id
|
||||
if self.state.end_id >= 0
|
||||
else self.event_stream.get_latest_event_id()
|
||||
)
|
||||
|
||||
# sanity check
|
||||
if start_id > end_id + 1:
|
||||
self.log(
|
||||
'warning',
|
||||
f'start_id {start_id} is greater than end_id + 1 ({end_id + 1}). History will be empty.',
|
||||
)
|
||||
self.state.history = []
|
||||
return
|
||||
|
||||
events: list[Event] = []
|
||||
|
||||
# Get rest of history
|
||||
events_to_add = list(
|
||||
self.event_stream.search_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter=self.agent_history_filter,
|
||||
)
|
||||
)
|
||||
events.extend(events_to_add)
|
||||
|
||||
# Find all delegate action/observation pairs
|
||||
delegate_ranges: list[tuple[int, int]] = []
|
||||
delegate_action_ids: list[int] = [] # stack of unmatched delegate action IDs
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, AgentDelegateAction):
|
||||
delegate_action_ids.append(event.id)
|
||||
# Note: we can get agent=event.agent and task=event.inputs.get('task','')
|
||||
# if we need to track these in the future
|
||||
|
||||
elif isinstance(event, AgentDelegateObservation):
|
||||
# Match with most recent unmatched delegate action
|
||||
if not delegate_action_ids:
|
||||
self.log(
|
||||
'warning',
|
||||
f'Found AgentDelegateObservation without matching action at id={event.id}',
|
||||
)
|
||||
continue
|
||||
|
||||
action_id = delegate_action_ids.pop()
|
||||
delegate_ranges.append((action_id, event.id))
|
||||
|
||||
# Filter out events between delegate action/observation pairs
|
||||
if delegate_ranges:
|
||||
filtered_events: list[Event] = []
|
||||
current_idx = 0
|
||||
|
||||
for start_id, end_id in sorted(delegate_ranges):
|
||||
# Add events before delegate range
|
||||
filtered_events.extend(
|
||||
event for event in events[current_idx:] if event.id < start_id
|
||||
)
|
||||
|
||||
# Add delegate action and observation
|
||||
filtered_events.extend(
|
||||
event for event in events if event.id in (start_id, end_id)
|
||||
)
|
||||
|
||||
# Update index to after delegate range
|
||||
current_idx = next(
|
||||
(i for i, e in enumerate(events) if e.id > end_id), len(events)
|
||||
)
|
||||
|
||||
# Add any remaining events after last delegate range
|
||||
filtered_events.extend(events[current_idx:])
|
||||
|
||||
self.state.history = filtered_events
|
||||
else:
|
||||
self.state.history = events
|
||||
|
||||
# make sure history is in sync
|
||||
self.state.start_id = start_id
|
||||
|
||||
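The delegate filtering in `_init_history` above pairs each delegate observation with the most recent unmatched delegate action, then drops everything strictly between the two. Below is a simplified sketch of that idea. The `Ev` class and `filter_delegate_events` function are hypothetical stand-ins, not OpenHands types, and the sketch uses a single pass with a range check rather than the index walk in the diff. For non-nested pairs the result should be the same; for nested pairs this variant hides the inner pair entirely, which is a deliberate simplification.

```python
from dataclasses import dataclass


@dataclass
class Ev:
    """Hypothetical stand-in for an event with an id and a coarse kind."""

    id: int
    kind: str = 'other'  # 'delegate_action', 'delegate_observation' or 'other'


def filter_delegate_events(events: list[Ev]) -> list[Ev]:
    ranges: list[tuple[int, int]] = []
    unmatched_actions: list[int] = []  # stack of unmatched delegate action ids

    for ev in events:
        if ev.kind == 'delegate_action':
            unmatched_actions.append(ev.id)
        elif ev.kind == 'delegate_observation' and unmatched_actions:
            # Match with the most recent unmatched delegate action.
            ranges.append((unmatched_actions.pop(), ev.id))

    def inside_delegate(ev: Ev) -> bool:
        # Strict inequalities keep the action and observation themselves.
        return any(start < ev.id < end for start, end in ranges)

    return [ev for ev in events if not inside_delegate(ev)]


events = [
    Ev(0),
    Ev(1, 'delegate_action'),
    Ev(2),
    Ev(3),
    Ev(4, 'delegate_observation'),
    Ev(5),
]
kept_ids = [ev.id for ev in filter_delegate_events(events)]
```

Here events 2 and 3 belong to the delegate and are removed, while the bracketing action (1) and observation (4) stay in the parent's history.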
def _handle_long_context_error(self) -> None:
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
@@ -1143,7 +1359,7 @@ class AgentController:
|
||||
action: The action to attach metrics to
|
||||
"""
|
||||
# Get metrics from agent LLM
|
||||
agent_metrics = self.state.metrics
|
||||
agent_metrics = self.agent.llm.metrics
|
||||
|
||||
# Get metrics from condenser LLM if it exists
|
||||
condenser_metrics: TokenUsage | None = None
|
||||
@@ -1174,10 +1390,10 @@ class AgentController:
|
||||
# Log the metrics information for debugging
|
||||
# Get the latest usage directly from the agent's metrics
|
||||
latest_usage = None
|
||||
if self.state.metrics.token_usages:
|
||||
latest_usage = self.state.metrics.token_usages[-1]
|
||||
if self.agent.llm.metrics.token_usages:
|
||||
latest_usage = self.agent.llm.metrics.token_usages[-1]
|
||||
|
||||
accumulated_usage = self.state.metrics.accumulated_token_usage
|
||||
accumulated_usage = self.agent.llm.metrics.accumulated_token_usage
|
||||
self.log(
|
||||
'debug',
|
||||
f'Action metrics - accumulated_cost: {metrics.accumulated_cost}, '
|
||||
@@ -1265,6 +1481,3 @@ class AgentController:
|
||||
None,
|
||||
)
|
||||
return self._cached_first_user_message
|
||||
|
||||
def save_state(self):
|
||||
self.state_tracker.save_state()
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar(
|
||||
'T', int, float
|
||||
) # Type for the value (int for iterations, float for budget)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlFlag(Generic[T]):
|
||||
"""Base class for control flags that manage limits and state transitions."""
|
||||
|
||||
limit_increase_amount: T
|
||||
current_value: T
|
||||
max_value: T
|
||||
headless_mode: bool = False
|
||||
_hit_limit: bool = False
|
||||
|
||||
def reached_limit(self) -> bool:
|
||||
"""Check if the limit has been reached.
|
||||
|
||||
Returns:
|
||||
bool: True if the limit has been reached, False otherwise.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def increase_limit(self, headless_mode: bool) -> None:
|
||||
"""Expand the limit when needed."""
|
||||
raise NotImplementedError
|
||||
|
||||
def step(self):
|
||||
"""Determine the next state based on the current state and mode.
|
||||
|
||||
Returns:
|
||||
ControlFlagState: The next state.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class IterationControlFlag(ControlFlag[int]):
|
||||
"""Control flag for managing iteration limits."""
|
||||
|
||||
def reached_limit(self) -> bool:
|
||||
"""Check if the iteration limit has been reached."""
|
||||
self._hit_limit = self.current_value >= self.max_value
|
||||
return self._hit_limit
|
||||
|
||||
def increase_limit(self, headless_mode: bool) -> None:
|
||||
"""Expand the iteration limit by adding the initial value."""
|
||||
if not headless_mode and self._hit_limit:
|
||||
self.max_value += self.limit_increase_amount
|
||||
self._hit_limit = False
|
||||
|
||||
def step(self):
|
||||
if self.reached_limit():
|
||||
raise RuntimeError(
|
||||
f'Agent reached maximum iteration. '
|
||||
f'Current iteration: {self.current_value}, max iteration: {self.max_value}'
|
||||
)
|
||||
|
||||
# Increment the current value
|
||||
self.current_value += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class BudgetControlFlag(ControlFlag[float]):
|
||||
"""Control flag for managing budget limits."""
|
||||
|
||||
def reached_limit(self) -> bool:
|
||||
"""Check if the budget limit has been reached."""
|
||||
self._hit_limit = self.current_value >= self.max_value
|
||||
return self._hit_limit
|
||||
|
||||
def increase_limit(self, headless_mode) -> None:
|
||||
"""Expand the budget limit by adding the initial value to the current value."""
|
||||
if self._hit_limit:
|
||||
self.max_value = self.current_value + self.limit_increase_amount
|
||||
self._hit_limit = False
|
||||
|
||||
def step(self):
|
||||
"""Check if we've reached the limit and update state accordingly.
|
||||
|
||||
Note: Unlike IterationControlFlag, this doesn't increment the value
|
||||
as the budget is updated externally.
|
||||
"""
|
||||
if self.reached_limit():
|
||||
current_str = f'{self.current_value:.2f}'
|
||||
max_str = f'{self.max_value:.2f}'
|
||||
raise RuntimeError(
|
||||
f'Agent reached maximum budget for conversation. '
|
||||
f'Current budget: {current_str}, max budget: {max_str}'
|
||||
)
|
||||
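To illustrate how the iteration flag in the file above behaves over a few steps, here is a condensed, self-contained sketch. `IterationFlagSketch` is a hypothetical name that mirrors the fields and methods of `IterationControlFlag` shown in the diff; it does not import anything from OpenHands, and the small limit of 2 is chosen only to keep the walk-through short.

```python
from dataclasses import dataclass


@dataclass
class IterationFlagSketch:
    limit_increase_amount: int
    current_value: int
    max_value: int
    _hit_limit: bool = False

    def reached_limit(self) -> bool:
        # Remember whether the limit was hit so increase_limit can react to it.
        self._hit_limit = self.current_value >= self.max_value
        return self._hit_limit

    def increase_limit(self, headless_mode: bool) -> None:
        # Only interactive (non-headless) sessions may extend the limit.
        if not headless_mode and self._hit_limit:
            self.max_value += self.limit_increase_amount
            self._hit_limit = False

    def step(self) -> None:
        if self.reached_limit():
            raise RuntimeError(
                'Agent reached maximum iteration. '
                f'Current iteration: {self.current_value}, '
                f'max iteration: {self.max_value}'
            )
        self.current_value += 1


flag = IterationFlagSketch(limit_increase_amount=2, current_value=0, max_value=2)
flag.step()
flag.step()

try:
    flag.step()  # third step exceeds max_value=2
    raised = False
except RuntimeError:
    raised = True

# The user chooses to continue: the limit grows by the initial amount.
flag.increase_limit(headless_mode=False)
flag.step()
```

Note that `increase_limit` is a no-op unless a previous `reached_limit` call recorded a hit, so calling it early does not inflate the limit.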
@@ -8,10 +8,6 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import openhands
|
||||
from openhands.controller.state.control_flags import (
|
||||
BudgetControlFlag,
|
||||
IterationControlFlag,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events.action import (
|
||||
@@ -24,15 +20,7 @@ from openhands.memory.view import View
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import get_conversation_agent_state_filename
|
||||
|
||||
RESUMABLE_STATES = [
|
||||
AgentState.RUNNING,
|
||||
AgentState.PAUSED,
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
]
|
||||
|
||||
|
||||
# NOTE: this is deprecated
|
||||
class TrafficControlState(str, Enum):
|
||||
# default state, no rate limiting
|
||||
NORMAL = 'normal'
|
||||
@@ -44,6 +32,14 @@ class TrafficControlState(str, Enum):
|
||||
PAUSED = 'paused'
|
||||
|
||||
|
||||
RESUMABLE_STATES = [
|
||||
AgentState.RUNNING,
|
||||
AgentState.PAUSED,
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
AgentState.FINISHED,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""
|
||||
@@ -79,43 +75,35 @@ class State:
|
||||
"""
|
||||
|
||||
session_id: str = ''
|
||||
iteration_flag: IterationControlFlag = field(
|
||||
default_factory=lambda: IterationControlFlag(
|
||||
limit_increase_amount=100, current_value=0, max_value=100
|
||||
)
|
||||
)
|
||||
budget_flag: BudgetControlFlag | None = None
|
||||
# global iteration for the current task
|
||||
iteration: int = 0
|
||||
# local iteration for the current subtask
|
||||
local_iteration: int = 0
|
||||
# max number of iterations for the current task
|
||||
max_iterations: int = 100
|
||||
confirmation_mode: bool = False
|
||||
history: list[Event] = field(default_factory=list)
|
||||
inputs: dict = field(default_factory=dict)
|
||||
outputs: dict = field(default_factory=dict)
|
||||
agent_state: AgentState = AgentState.LOADING
|
||||
resume_state: AgentState | None = None
|
||||
traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
|
||||
# global metrics for the current task
|
||||
metrics: Metrics = field(default_factory=Metrics)
|
||||
# local metrics for the current subtask
|
||||
local_metrics: Metrics = field(default_factory=Metrics)
|
||||
# root agent has level 0, and every delegate increases the level by one
|
||||
delegate_level: int = 0
|
||||
# start_id and end_id track the range of events in history
|
||||
start_id: int = -1
|
||||
end_id: int = -1
|
||||
|
||||
parent_metrics_snapshot: Metrics | None = None
|
||||
parent_iteration: int = 100
|
||||
|
||||
# NOTE: this is used by the controller to track parent's metrics snapshot before delegation
|
||||
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
|
||||
# NOTE: This will never be used by the controller, but it can be used by different
|
||||
# evaluation tasks to store extra data needed to track the progress/state of the task.
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
last_error: str = ''
|
||||
|
||||
# NOTE: deprecated args, kept here temporarily for backwards compatibility
|
||||
# Will be removed in 30 days
|
||||
iteration: int | None = None
|
||||
local_iteration: int | None = None
|
||||
max_iterations: int | None = None
|
||||
traffic_control_state: TrafficControlState | None = None
|
||||
local_metrics: Metrics | None = None
|
||||
delegates: dict[tuple[int, int], tuple[str, str]] | None = None
|
||||
|
||||
def save_to_session(
|
||||
self, sid: str, file_store: FileStore, user_id: str | None
|
||||
) -> None:
|
||||
@@ -177,10 +165,6 @@ class State:
|
||||
|
||||
# first state after restore
|
||||
state.agent_state = AgentState.LOADING
|
||||
|
||||
# We don't need to clean up deprecated fields here
|
||||
# They will be handled by __getstate__ when the state is saved again
|
||||
|
||||
return state
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
@@ -193,52 +177,15 @@ class State:
|
||||
state.pop('_history_checksum', None)
|
||||
state.pop('_view', None)
|
||||
|
||||
# Remove deprecated fields before pickling
|
||||
state.pop('iteration', None)
|
||||
state.pop('local_iteration', None)
|
||||
state.pop('max_iterations', None)
|
||||
state.pop('traffic_control_state', None)
|
||||
state.pop('local_metrics', None)
|
||||
state.pop('delegates', None)
|
||||
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
# Check if we're restoring from an older version (before control flags)
|
||||
is_old_version = 'iteration' in state
|
||||
|
||||
# Convert old iteration tracking to new iteration_flag if needed
|
||||
if is_old_version:
|
||||
# Create iteration_flag from old values
|
||||
max_iterations = state.get('max_iterations', 100)
|
||||
current_iteration = state.get('iteration', 0)
|
||||
|
||||
# Add the iteration_flag to the state
|
||||
state['iteration_flag'] = IterationControlFlag(
|
||||
limit_increase_amount=max_iterations,
|
||||
current_value=current_iteration,
|
||||
max_value=max_iterations,
|
||||
)
|
||||
|
||||
# Update the state
|
||||
self.__dict__.update(state)
|
||||
|
||||
# We keep the deprecated fields for backward compatibility
|
||||
# They will be removed by __getstate__ when the state is saved again
|
||||
|
||||
# make sure we always have the attribute history
|
||||
if not hasattr(self, 'history'):
|
||||
self.history = []
|
||||
|
||||
# Ensure we have default values for new fields if they're missing
|
||||
if not hasattr(self, 'iteration_flag'):
|
||||
self.iteration_flag = IterationControlFlag(
|
||||
limit_increase_amount=100, current_value=0, max_value=100
|
||||
)
|
||||
|
||||
if not hasattr(self, 'budget_flag'):
|
||||
self.budget_flag = None
|
||||
|
||||
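The `__getstate__` / `__setstate__` pair above migrates older saved states: legacy keys are converted into a control flag on load and dropped again on the next save. The sketch below reproduces that pattern on a toy class. `Flag` and `StateSketch` are hypothetical names, only two of the deprecated fields are modelled, and the methods are called directly rather than through `pickle` to keep the example independent of how it is executed.

```python
from dataclasses import dataclass


@dataclass
class Flag:
    current_value: int
    max_value: int


class StateSketch:
    def __init__(self) -> None:
        self.flag = Flag(current_value=0, max_value=100)

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        # Drop deprecated fields so they are not written out again.
        state.pop('iteration', None)
        state.pop('max_iterations', None)
        return state

    def __setstate__(self, state: dict) -> None:
        # An 'iteration' key marks a payload saved by the older layout.
        if 'iteration' in state:
            state['flag'] = Flag(
                current_value=state.get('iteration', 0),
                max_value=state.get('max_iterations', 100),
            )
        self.__dict__.update(state)
        # Provide a default if the payload had neither layout's fields.
        if not hasattr(self, 'flag'):
            self.flag = Flag(current_value=0, max_value=100)


# Simulate loading an old payload, then saving and loading it again.
old = StateSketch.__new__(StateSketch)
old.__setstate__({'iteration': 7, 'max_iterations': 50})

resaved = old.__getstate__()
restored = StateSketch.__new__(StateSketch)
restored.__setstate__(resaved)
```

After one load and save cycle the legacy keys are gone from the stored dictionary, while the migrated flag carries the old values forward.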
def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
|
||||
"""Returns the latest user message and image (if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
|
||||
last_user_message = None
|
||||
@@ -276,17 +223,6 @@ class State:
|
||||
],
|
||||
}
|
||||
|
||||
def get_local_step(self):
|
||||
if not self.parent_iteration:
|
||||
return self.iteration_flag.current_value
|
||||
|
||||
return self.iteration_flag.current_value - self.parent_iteration
|
||||
|
||||
def get_local_metrics(self):
|
||||
if not self.parent_metrics_snapshot:
|
||||
return self.metrics
|
||||
return self.metrics.diff(self.parent_metrics_snapshot)
|
||||
|
||||
@property
|
||||
def view(self) -> View:
|
||||
# Compute a simple checksum from the history to see if we can re-use any
|
||||
|
||||
@@ -1,290 +0,0 @@
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.control_flags import (
|
||||
BudgetControlFlag,
|
||||
IterationControlFlag,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import AgentDelegateAction, ChangeAgentStateAction
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.serialization.event import event_to_trajectory
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class StateTracker:
|
||||
"""Manages and synchronizes the state of an agent throughout its lifecycle.
|
||||
|
||||
It is responsible for:
|
||||
1. Maintaining agent state persistence across sessions
|
||||
2. Managing agent history by filtering and tracking relevant events (previously done in the agent controller)
|
||||
3. Synchronizing metrics between the controller and LLM components
|
||||
4. Updating control flags for budget and iteration limits
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, sid: str | None, file_store: FileStore | None, user_id: str | None
|
||||
):
|
||||
self.sid = sid
|
||||
self.file_store = file_store
|
||||
self.user_id = user_id
|
||||
|
||||
# filter out events that are not relevant to the agent
|
||||
# so they will not be included in the agent history
|
||||
self.agent_history_filter = EventFilter(
|
||||
exclude_types=(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
),
|
||||
exclude_hidden=True,
|
||||
)
|
||||
|
||||
def set_initial_state(
|
||||
self,
|
||||
id: str,
|
||||
agent: Agent,
|
||||
state: State | None,
|
||||
max_iterations: int,
|
||||
max_budget_per_task: float | None,
|
||||
confirmation_mode: bool = False,
|
||||
) -> None:
|
||||
"""Sets the initial state for the agent, either from the previous session, or from a parent agent, or by creating a new one.
|
||||
|
||||
Args:
|
||||
state: The state to initialize with, or None to create a new state.
|
||||
max_iterations: The maximum number of iterations allowed for the task.
|
||||
confirmation_mode: Whether to enable confirmation mode.
|
||||
"""
|
||||
# state can come from:
|
||||
# - the previous session, in which case it has history
|
||||
# - from a parent agent, in which case it has no history
|
||||
# - None / a new state
|
||||
|
||||
# If state is None, we create a brand new state and still load the event stream so we can restore the history
|
||||
if state is None:
|
||||
self.state = State(
|
||||
session_id=id.removesuffix('-delegate'),
|
||||
inputs={},
|
||||
iteration_flag=IterationControlFlag(
|
||||
limit_increase_amount=max_iterations,
|
||||
current_value=0,
|
||||
max_value=max_iterations,
|
||||
),
|
||||
budget_flag=None
|
||||
if not max_budget_per_task
|
||||
else BudgetControlFlag(
|
||||
limit_increase_amount=max_budget_per_task,
|
||||
current_value=0,
|
||||
max_value=max_budget_per_task,
|
||||
),
|
||||
confirmation_mode=confirmation_mode,
|
||||
)
|
||||
self.state.start_id = 0
|
||||
|
||||
logger.info(
|
||||
f'AgentController {id} - created new state. start_id: {self.state.start_id}'
|
||||
)
|
||||
else:
|
||||
self.state = state
|
||||
if self.state.start_id <= -1:
|
||||
self.state.start_id = 0
|
||||
|
||||
logger.info(
|
||||
f'AgentController {id} initializing history from event {self.state.start_id}',
|
||||
)
|
||||
|
||||
# Share the state metrics with the agent's LLM metrics
|
||||
# This ensures that all accumulated metrics are always in sync between controller and llm
|
||||
agent.llm.metrics = self.state.metrics
|
||||
|
||||
def _init_history(self, event_stream: EventStream) -> None:
|
||||
"""Initializes the agent's history from the event stream.
|
||||
|
||||
The history is a list of events that:
|
||||
- Excludes events of the types listed in self.agent_history_filter
|
||||
- Excludes events with hidden=True attribute
|
||||
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
||||
- Excludes all events between the action and observation
|
||||
- Includes the delegate action and observation themselves
|
||||
"""
|
||||
# define range of events to fetch
|
||||
# delegates start with a start_id and initially won't find any events
|
||||
# otherwise we're restoring a previous session
|
||||
start_id = self.state.start_id if self.state.start_id >= 0 else 0
|
||||
end_id = (
|
||||
self.state.end_id
|
||||
if self.state.end_id >= 0
|
||||
else event_stream.get_latest_event_id()
|
||||
)
|
||||
|
||||
# sanity check
|
||||
if start_id > end_id + 1:
|
||||
logger.warning(
|
||||
f'start_id {start_id} is greater than end_id + 1 ({end_id + 1}). History will be empty.',
|
||||
)
|
||||
self.state.history = []
|
||||
return
|
||||
|
||||
events: list[Event] = []
|
||||
|
||||
# Get rest of history
|
||||
events_to_add = list(
|
||||
event_stream.search_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter=self.agent_history_filter,
|
||||
)
|
||||
)
|
||||
events.extend(events_to_add)
|
||||
|
||||
# Find all delegate action/observation pairs
|
||||
delegate_ranges: list[tuple[int, int]] = []
|
||||
delegate_action_ids: list[int] = [] # stack of unmatched delegate action IDs
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, AgentDelegateAction):
|
||||
delegate_action_ids.append(event.id)
|
||||
# Note: we can get agent=event.agent and task=event.inputs.get('task','')
|
||||
# if we need to track these in the future
|
||||
|
||||
elif isinstance(event, AgentDelegateObservation):
|
||||
# Match with most recent unmatched delegate action
|
||||
if not delegate_action_ids:
|
||||
logger.warning(
|
||||
f'Found AgentDelegateObservation without matching action at id={event.id}',
|
||||
)
|
||||
continue
|
||||
|
||||
action_id = delegate_action_ids.pop()
|
||||
delegate_ranges.append((action_id, event.id))
|
||||
|
||||
# Filter out events between delegate action/observation pairs
|
||||
if delegate_ranges:
|
||||
filtered_events: list[Event] = []
|
||||
current_idx = 0
|
||||
|
||||
for start_id, end_id in sorted(delegate_ranges):
|
||||
# Add events before delegate range
|
||||
filtered_events.extend(
|
||||
event for event in events[current_idx:] if event.id < start_id
|
||||
)
|
||||
|
||||
# Add delegate action and observation
|
||||
filtered_events.extend(
|
||||
event for event in events if event.id in (start_id, end_id)
|
||||
)
|
||||
|
||||
# Update index to after delegate range
|
||||
current_idx = next(
|
||||
(i for i, e in enumerate(events) if e.id > end_id), len(events)
|
||||
)
|
||||
|
||||
# Add any remaining events after last delegate range
|
||||
filtered_events.extend(events[current_idx:])
|
||||
|
||||
self.state.history = filtered_events
|
||||
else:
|
||||
self.state.history = events
|
||||
|
||||
# make sure history is in sync
|
||||
self.state.start_id = start_id
|
||||
|
||||
def close(self, event_stream: EventStream):
|
||||
# we made history, now is the time to rewrite it!
|
||||
# the final state.history will be used by external scripts like evals, tests, etc.
|
||||
# history will need to be complete WITH delegates events
|
||||
# like the regular agent history, it does not include:
|
||||
# - 'hidden' events, events with hidden=True
|
||||
# - backend events (the default 'filtered out' types, i.e. the types excluded by self.agent_history_filter)
|
||||
start_id = self.state.start_id if self.state.start_id >= 0 else 0
|
||||
end_id = (
|
||||
self.state.end_id
|
||||
if self.state.end_id >= 0
|
||||
else event_stream.get_latest_event_id()
|
||||
)
|
||||
|
||||
self.state.history = list(
|
||||
event_stream.search_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter=self.agent_history_filter,
|
||||
)
|
||||
)
|
||||
|
||||
def add_history(self, event: Event):
|
||||
# if the event is not filtered out, add it to the history
|
||||
if self.agent_history_filter.include(event):
|
||||
self.state.history.append(event)
|
||||
|
||||
def get_trajectory(self, include_screenshots: bool = False) -> list[dict]:
|
||||
return [
|
||||
event_to_trajectory(event, include_screenshots)
|
||||
for event in self.state.history
|
||||
]
|
||||
|
||||
def maybe_increase_control_flags_limits(self, headless_mode: bool):
|
||||
# Iteration and budget extensions are independent of each other
|
||||
# An error will be thrown if any one of the control flags has reached or exceeded its limit
|
||||
self.state.iteration_flag.increase_limit(headless_mode)
|
||||
if self.state.budget_flag:
|
||||
self.state.budget_flag.increase_limit(headless_mode)
|
||||
|
||||
def get_metrics_snapshot(self):
|
||||
"""
|
||||
Deep copy of metrics
|
||||
This serves as a snapshot for the parent's metrics at the time a delegate is created
|
||||
It will be stored and used to compute local metrics for the delegate
|
||||
(since a delegate now accumulates metrics from where its parent left off)
|
||||
"""
|
||||
|
||||
return self.state.metrics.copy()
|
||||
|
||||
def save_state(self):
|
||||
"""
|
||||
Saves the current state to the persistent store
|
||||
"""
|
||||
if self.sid and self.file_store:
|
||||
self.state.save_to_session(self.sid, self.file_store, self.user_id)
|
||||
|
||||
def run_control_flags(self):
|
||||
"""
|
||||
Performs one step of the control flags
|
||||
"""
|
||||
self.state.iteration_flag.step()
|
||||
if self.state.budget_flag:
|
||||
self.state.budget_flag.step()
|
||||
|
||||
def sync_budget_flag_with_metrics(self):
|
||||
"""
|
||||
Ensures that budget flag is up to date with accumulated costs from llm completions
|
||||
Budget flag will monitor for when budget is exceeded
|
||||
"""
|
||||
if self.state.budget_flag:
|
||||
self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
|
||||
|
||||
def merge_metrics(self, metrics: Metrics):
|
||||
"""
|
||||
Merges metrics with the state metrics
|
||||
|
||||
NOTE: this should be refactored in the future. We should have services (draft llm, title autocomplete, condenser, etc)
|
||||
use their own LLMs, but the metrics object should be shared. This way we have one source of truth for accumulated costs from
|
||||
all services
|
||||
|
||||
This would prevent having fragmented stores for metrics, and we don't have the burden of deciding where and how to store them
|
||||
if we decide to introduce more specialized services that require llm completions
|
||||
|
||||
"""
|
||||
self.state.metrics.merge(metrics)
|
||||
if self.state.budget_flag:
|
||||
self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
|
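The delegate handling in `_init_history` above can be hard to follow in diff form, so here is a minimal, self-contained sketch of the same idea: pair each delegate action with its observation using a stack, then drop everything strictly between the two. The `Event` dataclass, its `kind` field, and both function names are hypothetical stand-ins, not OpenHands types, and the filter is a simpler formulation than the index-walking loop in the original.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Event:
    """Hypothetical stand-in for an event; only `id` and `kind` matter here."""

    id: int
    kind: str  # 'delegate_action', 'delegate_observation', or anything else


def find_delegate_ranges(events: list[Event]) -> list[tuple[int, int]]:
    """Pair each delegate observation with the most recent unmatched delegate action."""
    ranges: list[tuple[int, int]] = []
    open_actions: list[int] = []  # stack of unmatched delegate action ids
    for event in events:
        if event.kind == 'delegate_action':
            open_actions.append(event.id)
        elif event.kind == 'delegate_observation' and open_actions:
            ranges.append((open_actions.pop(), event.id))
    return ranges


def drop_delegate_interiors(events: list[Event]) -> list[Event]:
    """Keep the delegate action and observation, drop events strictly between them."""
    ranges = find_delegate_ranges(events)
    return [
        e for e in events if not any(start < e.id < end for start, end in ranges)
    ]


history = [
    Event(0, 'message'),
    Event(1, 'delegate_action'),
    Event(2, 'cmd_run'),
    Event(3, 'cmd_output'),
    Event(4, 'delegate_observation'),
    Event(5, 'message'),
]
kept_ids = [e.id for e in drop_delegate_interiors(history)]
print(kept_ids)
```

In this sketch an observation without a matching action is silently ignored, whereas the original logs a warning in that case.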
||||
@@ -744,6 +744,27 @@ def get_parser() -> argparse.ArgumentParser:
|
||||
type=bool,
|
||||
default=False,
|
||||
)
|
||||
|
||||
# LLM configuration arguments for local models
|
||||
parser.add_argument(
|
||||
'--llm-model',
|
||||
help='LLM model to use (e.g., "lm_studio/devstral", "openai/gpt-4")',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--llm-base-url',
|
||||
help='Base URL for LLM API (required for local models, e.g., "http://localhost:1234/v1")',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--llm-api-key',
|
||||
help='API key for LLM (use "dummy" for local models)',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -821,6 +842,21 @@ def setup_config_from_args(args: argparse.Namespace) -> OpenHandsConfig:
|
||||
raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
|
||||
config.set_llm_config(llm_config)
|
||||
|
||||
# Override LLM settings with direct CLI arguments
|
||||
if args.llm_model or args.llm_base_url or args.llm_api_key:
|
||||
from pydantic import SecretStr
|
||||
|
||||
llm_config = config.get_llm_config()
|
||||
|
||||
if args.llm_model:
|
||||
llm_config.model = args.llm_model
|
||||
if args.llm_base_url:
|
||||
llm_config.base_url = args.llm_base_url
|
||||
if args.llm_api_key:
|
||||
llm_config.api_key = SecretStr(args.llm_api_key)
|
||||
|
||||
config.set_llm_config(llm_config)
|
||||
|
||||
# Override default agent if provided
|
||||
if args.agent_cls:
|
||||
config.default_agent = args.agent_cls
|
||||
|
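As an illustration of the override order in the hunk above (config file first, then individual CLI flags), the following sketch reproduces the pattern with the standard library only. `LLMSettings`, `build_parser`, and `apply_overrides` are hypothetical names introduced for this example; the real code works on the project's own config object and wraps the key in `SecretStr`.

```python
from __future__ import annotations

import argparse
from dataclasses import dataclass


@dataclass
class LLMSettings:
    """Hypothetical stand-in for an LLM config object; not the OpenHands class."""

    model: str = 'default-model'
    base_url: str | None = None
    api_key: str | None = None


def build_parser() -> argparse.ArgumentParser:
    """Declare the three flags with None defaults so 'not supplied' is detectable."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--llm-model', type=str, default=None)
    parser.add_argument('--llm-base-url', type=str, default=None)
    parser.add_argument('--llm-api-key', type=str, default=None)
    return parser


def apply_overrides(settings: LLMSettings, args: argparse.Namespace) -> LLMSettings:
    """Overwrite only the fields for which a flag was actually supplied."""
    if args.llm_model:
        settings.model = args.llm_model
    if args.llm_base_url:
        settings.base_url = args.llm_base_url
    if args.llm_api_key:
        settings.api_key = args.llm_api_key
    return settings


args = build_parser().parse_args(
    ['--llm-model', 'lm_studio/devstral', '--llm-base-url', 'http://localhost:1234/v1']
)
settings = apply_overrides(LLMSettings(api_key='from-config'), args)
print(settings)
```

Because the flags default to `None`, a value loaded earlier (here `api_key='from-config'`) survives unless the matching flag is passed explicitly.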
||||
@@ -206,8 +206,8 @@ def create_controller(
|
||||
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
iteration_delta=config.max_iterations,
|
||||
budget_per_task_delta=config.max_budget_per_task,
|
||||
max_iterations=config.max_iterations,
|
||||
max_budget_per_task=config.max_budget_per_task,
|
||||
agent_to_llm_config=config.get_agent_to_llm_config_map(),
|
||||
event_stream=event_stream,
|
||||
initial_state=initial_state,
|
||||
|
||||
@@ -773,6 +773,9 @@ class LLM(RetryMixin, DebugMixin):
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.metrics.reset()
|
||||
|
||||
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
|
||||
if isinstance(messages, Message):
|
||||
messages = [messages]
|
||||
|
||||
@@ -193,6 +193,22 @@ class Metrics:
|
||||
'token_usages': [usage.model_dump() for usage in self._token_usages],
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
self._accumulated_cost = 0.0
|
||||
self._costs = []
|
||||
self._response_latencies = []
|
||||
self._token_usages = []
|
||||
# Reset accumulated token usage with a new instance
|
||||
self._accumulated_token_usage = TokenUsage(
|
||||
model=self.model_name,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
def log(self) -> str:
|
||||
"""Log the metrics."""
|
||||
metrics = self.get()
|
||||
@@ -205,58 +221,5 @@ class Metrics:
|
||||
"""Create a deep copy of the Metrics object."""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def diff(self, baseline: 'Metrics') -> 'Metrics':
|
||||
"""Calculate the difference between current metrics and a baseline.
|
||||
|
||||
This is useful for tracking metrics for specific operations like delegates.
|
||||
|
||||
Args:
|
||||
baseline: A metrics object representing the baseline state
|
||||
|
||||
Returns:
|
||||
A new Metrics object containing only the differences since the baseline
|
||||
"""
|
||||
result = Metrics(self.model_name)
|
||||
|
||||
# Calculate cost difference
|
||||
result._accumulated_cost = self._accumulated_cost - baseline._accumulated_cost
|
||||
|
||||
# Include only costs that were added after the baseline
|
||||
if baseline._costs:
|
||||
last_baseline_timestamp = baseline._costs[-1].timestamp
|
||||
result._costs = [
|
||||
cost for cost in self._costs if cost.timestamp > last_baseline_timestamp
|
||||
]
|
||||
else:
|
||||
result._costs = self._costs.copy()
|
||||
|
||||
# Include only response latencies that were added after the baseline
|
||||
result._response_latencies = self._response_latencies[
|
||||
len(baseline._response_latencies) :
|
||||
]
|
||||
|
||||
# Include only token usages that were added after the baseline
|
||||
result._token_usages = self._token_usages[len(baseline._token_usages) :]
|
||||
|
||||
# Calculate accumulated token usage difference
|
||||
base_usage = baseline.accumulated_token_usage
|
||||
current_usage = self.accumulated_token_usage
|
||||
|
||||
result._accumulated_token_usage = TokenUsage(
|
||||
model=self.model_name,
|
||||
prompt_tokens=current_usage.prompt_tokens - base_usage.prompt_tokens,
|
||||
completion_tokens=current_usage.completion_tokens
|
||||
- base_usage.completion_tokens,
|
||||
cache_read_tokens=current_usage.cache_read_tokens
|
||||
- base_usage.cache_read_tokens,
|
||||
cache_write_tokens=current_usage.cache_write_tokens
|
||||
- base_usage.cache_write_tokens,
|
||||
context_window=current_usage.context_window,
|
||||
per_turn_token=0,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Metrics({self.get()})'
|
||||
|
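The `copy()` and `diff()` methods in this hunk support a snapshot-and-subtract pattern: take a deep copy of the shared metrics when a delegate starts, then subtract that baseline later to isolate what the delegate spent. The sketch below shows the pattern on a reduced, hypothetical `SimpleMetrics` class. It slices the cost list by length, whereas the `diff` in the hunk filters costs by timestamp.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class SimpleMetrics:
    """Hypothetical, reduced metrics object: a running total plus per-call costs."""

    accumulated_cost: float = 0.0
    costs: list = field(default_factory=list)

    def add_cost(self, value: float) -> None:
        self.accumulated_cost += value
        self.costs.append(value)

    def copy(self) -> 'SimpleMetrics':
        # Deep copy so later updates to the shared object do not alter the snapshot
        return copy.deepcopy(self)

    def diff(self, baseline: 'SimpleMetrics') -> 'SimpleMetrics':
        """Return only what was recorded after `baseline` was taken."""
        return SimpleMetrics(
            accumulated_cost=self.accumulated_cost - baseline.accumulated_cost,
            costs=self.costs[len(baseline.costs):],
        )


shared = SimpleMetrics()
shared.add_cost(0.25)      # parent spends before delegating
snapshot = shared.copy()   # taken when the delegate starts
shared.add_cost(0.5)       # delegate's first call
shared.add_cost(0.125)     # delegate's second call
delegate_only = shared.diff(snapshot)
print(delegate_only.accumulated_cost, delegate_only.costs)
```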
||||
@@ -706,9 +706,7 @@ fi
|
||||
# Get authenticated URL and do a shallow clone (--depth 1) for efficiency
|
||||
remote_url = self._get_authenticated_git_url(org_openhands_repo)
|
||||
|
||||
clone_cmd = (
|
||||
f'GIT_TERMINAL_PROMPT=0 git clone --depth 1 {remote_url} {org_repo_dir}'
|
||||
)
|
||||
clone_cmd = f'git clone --depth 1 {remote_url} {org_repo_dir}'
|
||||
|
||||
action = CmdRunAction(command=clone_cmd)
|
||||
obs = self.run_action(action)
|
||||
|
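One of the two `clone_cmd` variants in this hunk prefixes the command with `GIT_TERMINAL_PROMPT=0`, which tells git not to prompt on the terminal for credentials, so a failed authentication fails instead of waiting for input. A minimal sketch of building such a command follows. The helper name `build_clone_command`, its `allow_prompt` parameter, and the example URL and path are assumptions for illustration; the shell quoting is an addition that the hunk itself does not perform.

```python
import shlex


def build_clone_command(remote_url: str, target_dir: str, allow_prompt: bool = False) -> str:
    """Build a shallow-clone shell command; disable credential prompts unless asked."""
    parts = ['git', 'clone', '--depth', '1', remote_url, target_dir]
    # Quote every argument so spaces or shell metacharacters cannot split the command
    command = ' '.join(shlex.quote(p) for p in parts)
    if not allow_prompt:
        command = f'GIT_TERMINAL_PROMPT=0 {command}'
    return command


cmd = build_clone_command('https://example.com/org/.openhands.git', '/workspace/org_openhands')
print(cmd)
```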
||||
@@ -13,7 +13,6 @@ from daytona_sdk import (
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
@@ -43,8 +42,6 @@ class DaytonaRuntime(ActionExecutionClient):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
):
|
||||
assert config.daytona_api_key, 'Daytona API key is required'
|
||||
|
||||
@@ -77,8 +74,6 @@ class DaytonaRuntime(ActionExecutionClient):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
git_provider_tokens,
|
||||
)
|
||||
|
||||
def _get_workspace(self) -> Workspace | None:
|
||||
|
||||
@@ -17,7 +17,6 @@ from openhands.core.exceptions import (
|
||||
from openhands.core.logger import DEBUG, DEBUG_RUNTIME
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.runtime.builder import DockerRuntimeBuilder
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
@@ -87,8 +86,6 @@ class DockerRuntime(ActionExecutionClient):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
main_module: str = DEFAULT_MAIN_MODULE,
|
||||
):
|
||||
if not DockerRuntime._shutdown_listener_id:
|
||||
@@ -135,8 +132,6 @@ class DockerRuntime(ActionExecutionClient):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
git_provider_tokens,
|
||||
)
|
||||
|
||||
# Log runtime_extra_deps after base class initialization so self.sid is available
|
||||
|
||||
@@ -12,42 +12,29 @@ from openhands.events.observation import (
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.impl.e2b.filestore import E2BFileStore
|
||||
from openhands.runtime.impl.e2b.sandbox import E2BSandbox
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.files import insert_lines, read_lines
|
||||
|
||||
|
||||
class E2BRuntime(ActionExecutionClient):
|
||||
class E2BRuntime(Runtime):
|
||||
def __init__(
|
||||
self,
|
||||
config: OpenHandsConfig,
|
||||
event_stream: EventStream,
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
sandbox: E2BSandbox | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
git_provider_tokens,
|
||||
status_callback=status_callback,
|
||||
)
|
||||
if sandbox is None:
|
||||
self.sandbox = E2BSandbox()
|
||||
|
||||
@@ -25,7 +25,6 @@ from openhands.events.observation import (
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
@@ -146,8 +145,6 @@ class LocalRuntime(ActionExecutionClient):
|
||||
status_callback: Callable[[str, str, str], None] | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
) -> None:
|
||||
self.is_windows = sys.platform == 'win32'
|
||||
if self.is_windows:
|
||||
@@ -197,8 +194,6 @@ class LocalRuntime(ActionExecutionClient):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
git_provider_tokens,
|
||||
)
|
||||
|
||||
# If there is an API key in the environment we use this in requests to the runtime
|
||||
|
||||
@@ -9,7 +9,6 @@ import tenacity
|
||||
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.events import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
@@ -54,8 +53,6 @@ class ModalRuntime(ActionExecutionClient):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
):
|
||||
assert config.modal_api_token_id, 'Modal API token id is required'
|
||||
assert config.modal_api_token_secret, 'Modal API token secret is required'
|
||||
@@ -103,8 +100,6 @@ class ModalRuntime(ActionExecutionClient):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
git_provider_tokens,
|
||||
)
|
||||
|
||||
async def connect(self):
|
||||
|
||||
@@ -9,7 +9,6 @@ from runloop_api_client.types.shared_params import LaunchParameters
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import EventStream
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
@@ -37,8 +36,6 @@ class RunloopRuntime(ActionExecutionClient):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
user_id: str | None = None,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
):
|
||||
assert config.runloop_api_key is not None, 'Runloop API key is required'
|
||||
self.devbox: DevboxView | None = None
|
||||
@@ -56,8 +53,6 @@ class RunloopRuntime(ActionExecutionClient):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
user_id,
|
||||
git_provider_tokens,
|
||||
)
|
||||
# Buffer for container logs
|
||||
self._vscode_url: str | None = None
|
||||
|
||||
@@ -305,6 +305,7 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
|
||||
return ErrorObservation(error_msg)
|
||||
|
||||
content_to_edit = '\n'.join(old_file_lines[start_idx:end_idx])
|
||||
self.draft_editor_llm.reset()
|
||||
_edited_content = get_new_file_contents(
|
||||
self.draft_editor_llm, content_to_edit, action.content
|
||||
)
|
||||
|
||||
@@ -232,7 +232,8 @@ class AgentSession:
|
||||
if self.event_stream is not None:
|
||||
self.event_stream.close()
|
||||
if self.controller is not None:
|
||||
self.controller.save_state()
|
||||
end_state = self.controller.get_state()
|
||||
end_state.save_to_session(self.sid, self.file_store, self.user_id)
|
||||
await self.controller.close()
|
||||
if self.runtime is not None:
|
||||
EXECUTOR.submit(self.runtime.close)
|
||||
@@ -365,7 +366,6 @@ class AgentSession:
|
||||
headless_mode=False,
|
||||
attach_to_existing=False,
|
||||
env_vars=env_vars,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
)
|
||||
|
||||
# FIXME: this sleep is a terrible hack.
|
||||
@@ -438,12 +438,10 @@ class AgentSession:
|
||||
initial_state = self._maybe_restore_state()
|
||||
controller = AgentController(
|
||||
sid=self.sid,
|
||||
user_id=self.user_id,
|
||||
file_store=self.file_store,
|
||||
event_stream=self.event_stream,
|
||||
agent=agent,
|
||||
iteration_delta=int(max_iterations),
|
||||
budget_per_task_delta=max_budget_per_task,
|
||||
max_iterations=int(max_iterations),
|
||||
max_budget_per_task=max_budget_per_task,
|
||||
agent_to_llm_config=agent_to_llm_config,
|
||||
agent_configs=agent_configs,
|
||||
confirmation_mode=confirmation_mode,
|
||||
|
||||
@@ -127,5 +127,5 @@ class PromptManager:
|
||||
None,
|
||||
)
|
||||
if latest_user_message:
|
||||
reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.iteration_flag.max_value - state.iteration_flag.current_value} turns left to complete the task. When finished reply with <finish></finish>.'
|
||||
reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'
|
||||
latest_user_message.content.append(TextContent(text=reminder_text))
|
||||
|
||||
@@ -6,7 +6,7 @@ requires = [
|
||||
|
||||
[tool.poetry]
|
||||
name = "openhands-ai"
|
||||
version = "0.44.0"
|
||||
version = "0.43.0"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
authors = [ "OpenHands" ]
|
||||
license = "MIT"
|
||||
|
||||
@@ -11,10 +11,7 @@ from litellm import (
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.control_flags import (
|
||||
BudgetControlFlag,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.controller.state.state import State, TrafficControlState
|
||||
from openhands.core.config import OpenHandsConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.core.main import run_controller
|
||||
@@ -131,7 +128,7 @@ async def test_set_agent_state(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -149,7 +146,7 @@ async def test_on_event_message_action(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -166,7 +163,7 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream)
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -184,7 +181,7 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -204,7 +201,7 @@ async def test_react_to_content_policy_violation(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -290,7 +287,7 @@ async def test_run_controller_with_fatal_error(
|
||||
)
|
||||
assert len(error_observations) == 1
|
||||
error_observation = error_observations[0]
|
||||
assert state.iteration_flag.current_value == 3
|
||||
assert state.iteration == 3
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop'
|
||||
assert (
|
||||
@@ -354,7 +351,7 @@ async def test_run_controller_stop_with_stuck(
|
||||
for i, event in enumerate(events):
|
||||
print(f'event {i}: {event_to_dict(event)}')
|
||||
|
||||
assert state.iteration_flag.current_value == 3
|
||||
assert state.iteration == 3
|
||||
assert len(events) == 12
|
||||
# check that the event stream has 4 pairs of repeated actions and observations
|
||||
# With the refactored system message handling, we need to adjust the range
|
||||
@@ -381,19 +378,24 @@ async def test_run_controller_stop_with_stuck(
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
# Test with headless_mode=False - should extend max_iterations
|
||||
initial_state = State(max_iterations=10)
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=False,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.state.iteration_flag.current_value = 10
|
||||
controller.state.iteration = 10
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
|
||||
# Trigger throttling by calling _step() when we hit max_iterations
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
|
||||
# Simulate a new user message
|
||||
@@ -403,24 +405,28 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
|
||||
# Max iterations should be extended to current iteration + initial max_iterations
|
||||
assert (
|
||||
controller.state.iteration_flag.max_value == 20
|
||||
controller.state.max_iterations == 20
|
||||
) # Current iteration (10 initial because _step() should not have been executed) + initial max_iterations (10)
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
assert controller.state.agent_state == AgentState.RUNNING
|
||||
|
||||
# Close the controller to clean up
|
||||
await controller.close()
|
||||
|
||||
# Test with headless_mode=True - should NOT extend max_iterations
|
||||
initial_state = State(max_iterations=10)
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.state.iteration_flag.current_value = 10
|
||||
controller.state.iteration = 10
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
|
||||
# Simulate a new user message
|
||||
message_action = MessageAction(content='Test message')
|
||||
@@ -428,143 +434,64 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
|
||||
await send_event_to_controller(controller, message_action)
|
||||
|
||||
# Max iterations should NOT be extended in headless mode
|
||||
assert controller.state.iteration_flag.max_value == 10 # Original value unchanged
|
||||
assert controller.state.max_iterations == 10 # Original value unchanged
|
||||
|
||||
# Trigger throttling by calling _step() when we hit max_iterations
|
||||
await controller._step()
|
||||
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
# Metrics are always synced with budget flag before
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 10.1
|
||||
budget_flag = BudgetControlFlag(
|
||||
limit_increase_amount=10, current_value=10.1, max_value=10
|
||||
)
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=10,
|
||||
max_iterations=10,
|
||||
max_budget_per_task=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=False,
|
||||
initial_state=State(budget_flag=budget_flag, metrics=metrics),
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.state.metrics.accumulated_cost = 10.1
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
# Metrics are always synced with budget flag before
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 10.1
|
||||
budget_flag = BudgetControlFlag(
|
||||
limit_increase_amount=10, current_value=10.1, max_value=10
|
||||
)
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=10,
|
||||
max_iterations=10,
|
||||
max_budget_per_task=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
initial_state=State(budget_flag=budget_flag, metrics=metrics),
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.state.metrics.accumulated_cost = 10.1
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
# In headless mode, throttling results in an error
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
|
||||
"""Test that when a user continues after hitting the budget limit:
|
||||
1. Error is thrown when budget cap is exceeded
|
||||
2. LLM budget does not reset when user continues
|
||||
3. Budget is extended by adding the initial budget cap to the current accumulated cost
|
||||
"""
|
||||
|
||||
# Create a real Metrics instance shared between controller state and llm
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 6.0
|
||||
|
||||
initial_budget = 5.0
|
||||
|
||||
initial_state = State(
|
||||
metrics=metrics,
|
||||
budget_flag=BudgetControlFlag(
|
||||
limit_increase_amount=initial_budget,
|
||||
current_value=6.0,
|
||||
max_value=initial_budget,
|
||||
),
|
||||
)
|
||||
|
||||
# Create controller with budget cap
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
budget_per_task_delta=initial_budget,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=False,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
|
||||
# Set up initial state
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
|
||||
# Set up metrics to simulate having spent more than the budget
|
||||
assert controller.state.budget_flag.current_value == 6.0
|
||||
assert controller.agent.llm.metrics.accumulated_cost == 6.0
|
||||
|
||||
# Trigger budget limit
|
||||
await controller._step()
|
||||
|
||||
# Verify budget limit was hit and error was thrown
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
assert 'budget' in controller.state.last_error.lower()
|
||||
|
||||
# Now set the agent state to RUNNING (simulating user clicking "continue")
|
||||
await controller.set_agent_state_to(AgentState.RUNNING)
|
||||
|
||||
# Now simulate user sending a message
|
||||
message_action = MessageAction(content='Please continue')
|
||||
message_action._source = EventSource.USER
|
||||
await controller._on_event(message_action)
|
||||
|
||||
# Verify budget cap was extended by adding initial budget to current accumulated cost
|
||||
# accumulated cost (6.0) + initial budget (5.0) = 11.0
|
||||
assert controller.state.budget_flag.max_value == 11.0
|
||||
|
||||
# Verify LLM metrics were NOT reset - they should still be 6.0
|
||||
assert controller.agent.llm.metrics.accumulated_cost == 6.0
|
||||
|
||||
# The controller state metrics are same as llm metrics
|
||||
assert controller.state.metrics.accumulated_cost == 6.0
|
||||
|
||||
# Verify traffic control state was reset
|
||||
await controller.close()
|
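The budget arithmetic checked by `test_budget_reset_on_continue` can be sketched in isolation. This is a minimal, hypothetical model: the `BudgetFlagSketch` class, its method names, and the `>=` comparison are illustrative assumptions, not the OpenHands `BudgetControlFlag` implementation. It only reproduces the numbers asserted in the test: a cap of 5.0, an accumulated cost of 6.0, and an extended cap of 11.0.

```python
from dataclasses import dataclass


@dataclass
class BudgetFlagSketch:
    """Hypothetical stand-in for a budget control flag (illustration only)."""

    limit_increase_amount: float
    current_value: float
    max_value: float

    def reached_limit(self) -> bool:
        # Assumption: the cap counts as hit once spending meets or exceeds it.
        return self.current_value >= self.max_value

    def increase_limit(self) -> None:
        # Extend the cap relative to what has already been spent,
        # instead of resetting the spend itself.
        if self.reached_limit():
            self.max_value = self.current_value + self.limit_increase_amount


flag = BudgetFlagSketch(limit_increase_amount=5.0, current_value=6.0, max_value=5.0)
hit_before = flag.reached_limit()
flag.increase_limit()
```

Under these assumptions the spend stays at 6.0 while the cap moves to 11.0, which matches the test's expectation that the LLM metrics are not reset when the user continues.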
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
|
||||
"""Test reset() when there's a pending action with tool call metadata but no observation."""
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -613,7 +540,7 @@ async def test_reset_with_pending_action_existing_observation(
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -655,7 +582,7 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -686,7 +613,7 @@ async def test_reset_with_pending_action_no_metadata(
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -735,8 +662,6 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = config.get_llm_config()
|
||||
|
||||
step_count = 0
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
# Mock the cost of the LLM
|
||||
@@ -744,9 +669,7 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
print(
|
||||
f'mock_agent.llm.metrics.accumulated_cost: {mock_agent.llm.metrics.accumulated_cost}'
|
||||
)
|
||||
nonlocal step_count
|
||||
step_count += 1
|
||||
return CmdRunAction(command=f'ls {step_count}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
mock_agent.step = agent_step_fn
|
||||
|
||||
@@ -783,13 +706,11 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
memory=mock_memory,
|
||||
)
|
||||
|
||||
state.metrics = mock_agent.llm.metrics
|
||||
assert state.iteration_flag.current_value == 3
|
||||
assert state.iteration == 3
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert (
|
||||
state.last_error
|
||||
== 'RuntimeError: Agent reached maximum iteration. Current iteration: 3, max iteration: 3'
|
||||
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
|
||||
)
|
||||
error_observations = test_event_stream.get_matching_events(
|
||||
reverse=True, limit=1, event_types=(AgentStateChangedObservation)
|
||||
@@ -799,7 +720,7 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
|
||||
assert (
|
||||
error_observation.reason
|
||||
== 'RuntimeError: Agent reached maximum iteration. Current iteration: 3, max iteration: 3'
|
||||
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
|
||||
)
|
||||
|
||||
assert state.metrics.accumulated_cost == 10.0 * 3, (
|
||||
@@ -813,19 +734,12 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
status_callback=mock_status_callback,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
def notify_on_llm_retry(attempt, max_attempts):
|
||||
controller.status_callback('info', 'STATUS$LLM_RETRY', ANY)
|
||||
|
||||
# Attach the retry listener to the agent's LLM
|
||||
controller.agent.llm.retry_listener = notify_on_llm_retry
|
||||
|
||||
controller.agent.llm.retry_listener(1, 2)
|
||||
controller._notify_on_llm_retry(1, 2)
|
||||
controller.status_callback.assert_called_once_with('info', 'STATUS$LLM_RETRY', ANY)
|
||||
await controller.close()
|
||||
|
||||
@@ -1051,11 +965,11 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
|
||||
# Hitting the iteration limit indicates the controller is failing for the
|
||||
# expected reason
|
||||
assert state.iteration_flag.current_value == 5
|
||||
assert state.iteration == 5
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert (
|
||||
state.last_error
|
||||
== 'RuntimeError: Agent reached maximum iteration. Current iteration: 5, max iteration: 5'
|
||||
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 5, max iteration: 5'
|
||||
)
|
||||
|
||||
# Check that the context window exceeded error was raised during the run
|
||||
@@ -1128,7 +1042,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
# Hitting the iteration limit indicates the controller is failing for the
|
||||
# expected reason
|
||||
# With the refactored system message handling, the iteration count is different
|
||||
assert state.iteration_flag.current_value == 1
|
||||
assert state.iteration == 1
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert (
|
||||
state.last_error
|
||||
@@ -1188,7 +1102,7 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
assert state.iteration_flag.current_value == 0
|
||||
assert state.iteration == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Error: RuntimeError'
|
||||
|
||||
@@ -1199,13 +1113,10 @@ async def test_action_metrics_copy(mock_agent):
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
metrics = Metrics(model_name='test-model')
|
||||
metrics.accumulated_cost = 0.05
|
||||
|
||||
initial_state = State(metrics=metrics, budget_flag=None)
|
||||
|
||||
# Create agent with metrics
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
metrics = Metrics(model_name='test-model')
|
||||
metrics.accumulated_cost = 0.05
|
||||
|
||||
# Add multiple token usages - we should get the last one in the action
|
||||
usage1 = TokenUsage(
|
||||
@@ -1259,11 +1170,10 @@ async def test_action_metrics_copy(mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
|
||||
# Execute one step
|
||||
@@ -1330,7 +1240,7 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
cache_write_tokens=10,
|
||||
response_id='agent-accumulated',
|
||||
)
|
||||
# mock_agent.llm.metrics = agent_metrics
|
||||
mock_agent.llm.metrics = agent_metrics
|
||||
mock_agent.name = 'TestAgent'
|
||||
|
||||
# Create condenser with its own metrics
|
||||
@@ -1369,11 +1279,10 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
initial_state=State(metrics=agent_metrics, budget_flag=None),
|
||||
)
|
||||
|
||||
# Execute one step
|
||||
@@ -1428,7 +1337,7 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -1500,7 +1409,7 @@ async def test_agent_controller_processes_null_observation_with_cause():
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test-session',
|
||||
)
|
||||
|
||||
@@ -1571,7 +1480,7 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=event_stream,
|
||||
iteration_delta=10,
|
||||
max_iterations=10,
|
||||
sid='test-session',
|
||||
)
|
||||
|
||||
@@ -1592,7 +1501,7 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
|
||||
def test_system_message_in_event_stream(mock_agent, test_event_stream):
|
||||
"""Test that SystemMessageAction is added to event stream in AgentController."""
|
||||
_ = AgentController(
|
||||
agent=mock_agent, event_stream=test_event_stream, iteration_delta=10
|
||||
agent=mock_agent, event_stream=test_event_stream, max_iterations=10
|
||||
)
|
||||
|
||||
# Get events from the event stream
|
||||
@@ -1644,7 +1553,7 @@ async def test_openrouter_context_window_exceeded_error(
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=test_event_stream,
|
||||
iteration_delta=max_iterations,
|
||||
max_iterations=max_iterations,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
|
||||
@@ -7,10 +7,6 @@ import pytest
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.control_flags import (
|
||||
BudgetControlFlag,
|
||||
IterationControlFlag,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
@@ -22,8 +18,6 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.action.commands import CmdRunAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
@@ -49,14 +43,16 @@ def mock_parent_agent():
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = LLMConfig()
|
||||
agent.llm.retry_listener = None # Add retry_listener attribute
|
||||
agent.config = AgentConfig()
|
||||
|
||||
# Add a proper system message mock
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@@ -68,54 +64,34 @@ def mock_child_agent():
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = LLMConfig()
|
||||
agent.llm.retry_listener = None # Add retry_listener attribute
|
||||
agent.config = AgentConfig()
|
||||
|
||||
# Add a proper system message mock
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
|
||||
"""
|
||||
Test that when the parent agent delegates to a child
|
||||
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
|
||||
2. metrics are accumulated globally (the delegate adds to the parent's metrics)
|
||||
3. local metrics for the delegate are still accessible
|
||||
Test that when the parent agent delegates to a child, the parent's delegate
|
||||
is set, and once the child finishes, the parent is cleaned up properly.
|
||||
"""
|
||||
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent
|
||||
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
|
||||
|
||||
step_count = 0
|
||||
|
||||
def agent_step_fn(state):
|
||||
nonlocal step_count
|
||||
step_count += 1
|
||||
return CmdRunAction(command=f'ls {step_count}')
|
||||
|
||||
mock_child_agent.step = agent_step_fn
|
||||
|
||||
parent_metrics = Metrics()
|
||||
parent_metrics.accumulated_cost = 2
|
||||
# Create parent controller
|
||||
parent_state = State(
|
||||
inputs={},
|
||||
metrics=parent_metrics,
|
||||
budget_flag=BudgetControlFlag(
|
||||
current_value=2, limit_increase_amount=10, max_value=10
|
||||
),
|
||||
iteration_flag=IterationControlFlag(
|
||||
current_value=1, limit_increase_amount=10, max_value=10
|
||||
),
|
||||
)
|
||||
|
||||
parent_state = State(max_iterations=10)
|
||||
parent_controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
max_iterations=10,
|
||||
sid='parent',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
@@ -156,9 +132,8 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
# Verify that a RecallObservation was added to the event stream
|
||||
events = list(mock_event_stream.get_events())
|
||||
|
||||
# The exact number of events might vary depending on implementation details
|
||||
# Just verify that we have at least a few events
|
||||
assert mock_event_stream.get_latest_event_id() >= 3
|
||||
# SystemMessageAction, RecallAction, AgentChangeState, AgentDelegateAction, SystemMessageAction (for child)
|
||||
assert mock_event_stream.get_latest_event_id() == 5
|
||||
|
||||
# a RecallObservation and an AgentDelegateAction should be in the list
|
||||
assert any(isinstance(event, RecallObservation) for event in events)
|
||||
@@ -170,33 +145,13 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
)
|
||||
|
||||
# The parent's iteration should have incremented
|
||||
assert parent_controller.state.iteration_flag.current_value == 2, (
|
||||
assert parent_controller.state.iteration == 1, (
|
||||
'Parent iteration should be incremented after step.'
|
||||
)
|
||||
|
||||
# Now simulate that the child increments local iteration and finishes its subtask
|
||||
delegate_controller = parent_controller.delegate
|
||||
|
||||
# Take four delegate steps; mock cost per step
|
||||
for i in range(4):
|
||||
delegate_controller.state.iteration_flag.step()
|
||||
delegate_controller.agent.step(delegate_controller.state)
|
||||
delegate_controller.agent.llm.metrics.add_cost(1.0)
|
||||
|
||||
assert (
|
||||
delegate_controller.state.get_local_step() == 4
|
||||
) # verify local metrics are accessible via snapshot
|
||||
|
||||
assert (
|
||||
delegate_controller.state.metrics.accumulated_cost
|
||||
== 6 # Make sure delegate tracks global cost
|
||||
)
|
||||
|
||||
assert (
|
||||
delegate_controller.state.get_local_metrics().accumulated_cost
|
||||
== 4 # Delegate spent one dollar per step
|
||||
)
|
||||
|
||||
delegate_controller.state.iteration = 5 # child had some steps
|
||||
delegate_controller.state.outputs = {'delegate_result': 'done'}
|
||||
|
||||
# The child is done, so we simulate it finishing:
|
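The global and local figures asserted earlier in this test (global cost 6, local cost 4, parent cost 2 at delegation time) are consistent with a snapshot-and-subtract scheme. The sketch below illustrates that idea only. `MetricsSketch`, `DelegateStateSketch`, and `local_cost` are hypothetical names, and nothing here is taken from the actual `State.get_local_metrics()` implementation.

```python
from dataclasses import dataclass, field


@dataclass
class MetricsSketch:
    """Hypothetical shared cost accumulator."""

    accumulated_cost: float = 0.0

    def add_cost(self, value: float) -> None:
        self.accumulated_cost += value


@dataclass
class DelegateStateSketch:
    """Hypothetical delegate state that shares the parent's metrics object."""

    metrics: MetricsSketch
    # Cost already accumulated when the delegate started (the snapshot).
    cost_at_start: float = field(init=False)

    def __post_init__(self) -> None:
        self.cost_at_start = self.metrics.accumulated_cost

    def local_cost(self) -> float:
        # Local spend = global spend minus the snapshot taken at delegation.
        return self.metrics.accumulated_cost - self.cost_at_start


parent_metrics = MetricsSketch(accumulated_cost=2.0)
delegate = DelegateStateSketch(metrics=parent_metrics)
for _ in range(4):
    delegate.metrics.add_cost(1.0)  # one unit of cost per delegate step
```

Because the delegate writes into the parent's metrics object, the global total includes the delegate's spend, while the snapshot keeps the delegate's own share recoverable.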
||||
@@ -210,7 +165,7 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
)
|
||||
|
||||
# Parent's global iteration is updated from the child
|
||||
assert parent_controller.state.iteration_flag.current_value == 7, (
|
||||
assert parent_controller.state.iteration == 6, (
|
||||
"Parent iteration should be the child's iteration + 1 after child is done."
|
||||
)
|
||||
|
||||
@@ -232,24 +187,19 @@ async def test_delegate_step_different_states(
|
||||
mock_parent_agent, mock_event_stream, delegate_state
|
||||
):
|
||||
"""Ensure that delegate is closed or remains open based on the delegate's state."""
|
||||
# Create a state with iteration_flag.max_value set to 10
|
||||
state = State(inputs={})
|
||||
state.iteration_flag.max_value = 10
|
||||
controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
initial_state=state,
|
||||
)
|
||||
|
||||
mock_delegate = AsyncMock()
|
||||
controller.delegate = mock_delegate
|
||||
|
||||
mock_delegate.state.iteration_flag = MagicMock()
|
||||
mock_delegate.state.iteration_flag.current_value = 5
|
||||
mock_delegate.state.iteration = 5
|
||||
mock_delegate.state.outputs = {'result': 'test'}
|
||||
mock_delegate.agent.name = 'TestDelegate'
|
||||
|
||||
@@ -257,7 +207,7 @@ async def test_delegate_step_different_states(
|
||||
mock_delegate._step = AsyncMock()
|
||||
mock_delegate.close = AsyncMock()
|
||||
|
||||
async def call_on_event_with_new_loop():
|
||||
def call_on_event_with_new_loop():
|
||||
"""
|
||||
In this thread, create and set a fresh event loop, so that the run_until_complete()
|
||||
calls inside controller.on_event(...) find a valid loop.
|
||||
@@ -276,135 +226,14 @@ async def test_delegate_step_different_states(
|
||||
future = loop.run_in_executor(executor, call_on_event_with_new_loop)
|
||||
await future
|
||||
|
||||
# Give time for the event loop to process events
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if delegate_state == AgentState.RUNNING:
|
||||
assert controller.delegate is not None
|
||||
assert controller.state.iteration_flag.current_value == 0
|
||||
assert controller.state.iteration == 0
|
||||
mock_delegate.close.assert_not_called()
|
||||
else:
|
||||
assert controller.delegate is None
|
||||
assert controller.state.iteration_flag.current_value == 5
|
||||
assert controller.state.iteration == 5
|
||||
# The close method is called once in end_delegate
|
||||
assert mock_delegate.close.call_count == 1
|
||||
|
||||
await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegate_hits_global_limits(
|
||||
mock_child_agent, mock_event_stream, mock_parent_agent
|
||||
):
|
||||
"""
|
||||
Global limits from control flags should apply to delegates
|
||||
"""
|
||||
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent
|
||||
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
|
||||
|
||||
parent_metrics = Metrics()
|
||||
parent_metrics.accumulated_cost = 2
|
||||
# Create parent controller
|
||||
parent_state = State(
|
||||
inputs={},
|
||||
metrics=parent_metrics,
|
||||
budget_flag=BudgetControlFlag(
|
||||
current_value=2, limit_increase_amount=10, max_value=10
|
||||
),
|
||||
iteration_flag=IterationControlFlag(
|
||||
current_value=2, limit_increase_amount=3, max_value=3
|
||||
),
|
||||
)
|
||||
|
||||
parent_controller = AgentController(
|
||||
agent=mock_parent_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
sid='parent',
|
||||
confirmation_mode=False,
|
||||
headless_mode=False,
|
||||
initial_state=parent_state,
|
||||
)
|
||||
|
||||
# Setup Memory to catch RecallActions
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_memory.event_stream = mock_event_stream
|
||||
|
||||
def on_event(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
# create a RecallObservation
|
||||
microagent_observation = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
content='Found info',
|
||||
)
|
||||
microagent_observation._cause = event.id # ignore attr-defined warning
|
||||
mock_event_stream.add_event(microagent_observation, EventSource.ENVIRONMENT)
|
||||
|
||||
mock_memory.on_event = on_event
|
||||
mock_event_stream.subscribe(
|
||||
EventStreamSubscriber.MEMORY, mock_memory.on_event, mock_memory
|
||||
)
|
||||
|
||||
# Setup a delegate action from the parent
|
||||
delegate_action = AgentDelegateAction(agent='ChildAgent', inputs={'test': True})
|
||||
mock_parent_agent.step.return_value = delegate_action
|
||||
|
||||
# Simulate a user message event to cause parent.step() to run
|
||||
message_action = MessageAction(content='please delegate now')
|
||||
message_action._source = EventSource.USER
|
||||
await parent_controller._on_event(message_action)
|
||||
|
||||
# Give time for the async step() to execute
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Verify that a RecallObservation was added to the event stream
|
||||
events = list(mock_event_stream.get_events())
|
||||
|
||||
# The exact number of events might vary depending on implementation details
|
||||
# Just verify that we have at least a few events
|
||||
assert mock_event_stream.get_latest_event_id() >= 3
|
||||
|
||||
# a RecallObservation and an AgentDelegateAction should be in the list
|
||||
assert any(isinstance(event, RecallObservation) for event in events)
|
||||
assert any(isinstance(event, AgentDelegateAction) for event in events)
|
||||
|
||||
# Verify that a delegate agent controller is created
|
||||
assert parent_controller.delegate is not None, (
|
||||
"Parent's delegate controller was not set."
|
||||
)
|
||||
|
||||
delegate_controller = parent_controller.delegate
|
||||
await delegate_controller.set_agent_state_to(AgentState.RUNNING)
|
||||
|
||||
# Step should hit the max iteration limit
|
||||
message_action = MessageAction(content='Test message')
|
||||
message_action._source = EventSource.USER
|
||||
|
||||
await delegate_controller._on_event(message_action)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert delegate_controller.state.agent_state == AgentState.ERROR
|
||||
assert (
|
||||
delegate_controller.state.last_error
|
||||
== 'RuntimeError: Agent reached maximum iteration. Current iteration: 3, max iteration: 3'
|
||||
)
|
||||
|
||||
await delegate_controller.set_agent_state_to(AgentState.RUNNING)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert delegate_controller.state.iteration_flag.max_value == 6
|
||||
assert (
|
||||
delegate_controller.state.iteration_flag.max_value
|
||||
== parent_controller.state.iteration_flag.max_value
|
||||
)
|
||||
|
||||
message_action = MessageAction(content='Test message 2')
|
||||
message_action._source = EventSource.USER
|
||||
await delegate_controller._on_event(message_action)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert delegate_controller.state.iteration_flag.current_value == 4
|
||||
assert (
|
||||
delegate_controller.state.iteration_flag.current_value
|
||||
== parent_controller.state.iteration_flag.current_value
|
||||
)
|
||||
|
||||
@@ -99,17 +99,13 @@ def controller_fixture():
|
||||
# Ensure get_latest_event_id returns an integer
|
||||
mock_event_stream.get_latest_event_id.return_value = -1
|
||||
|
||||
# Create a state with iteration_flag.max_value set to 10
|
||||
state = State(inputs={}, session_id='test_sid')
|
||||
state.iteration_flag.max_value = 10
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
iteration_delta=1, # Add the required iteration_delta parameter
|
||||
max_iterations=10,
|
||||
sid='test_sid',
|
||||
initial_state=state,
|
||||
)
|
||||
controller.state = State(session_id='test_sid')
|
||||
|
||||
# Don't mock _first_user_message anymore since we need it to work with history
|
||||
return controller
|
||||
|
||||
@@ -17,8 +17,6 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
# We'll use the DeprecatedState class from the main codebase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
@@ -133,7 +131,7 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
# Verify set_initial_state was called once with None as state
|
||||
assert session.controller.set_initial_state_call_count == 1
|
||||
assert session.controller.test_initial_state is None
|
||||
assert session.controller.state.iteration_flag.max_value == 10
|
||||
assert session.controller.state.max_iterations == 10
|
||||
assert session.controller.agent.name == 'test-agent'
|
||||
assert session.controller.state.start_id == 0
|
||||
assert session.controller.state.end_id == -1
|
||||
@@ -173,11 +171,7 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
mock_restored_state = MagicMock(spec=State)
|
||||
mock_restored_state.start_id = -1
|
||||
mock_restored_state.end_id = -1
|
||||
# Use iteration_flag instead of max_iterations
|
||||
mock_restored_state.iteration_flag = MagicMock()
|
||||
mock_restored_state.iteration_flag.max_value = 5
|
||||
# Add metrics attribute
|
||||
mock_restored_state.metrics = MagicMock(spec=Metrics)
|
||||
mock_restored_state.max_iterations = 5
|
||||
|
||||
# Create a spy on set_initial_state by subclassing AgentController
|
||||
class SpyAgentController(AgentController):
|
||||
@@ -225,180 +219,6 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
)
|
||||
assert session.controller.test_initial_state is mock_restored_state
|
||||
assert session.controller.state is mock_restored_state
|
||||
assert session.controller.state.iteration_flag.max_value == 5
|
||||
assert session.controller.state.max_iterations == 5
|
||||
assert session.controller.state.start_id == 0
|
||||
assert session.controller.state.end_id == -1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
"""Test that metrics are centralized and shared between controller and agent."""
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
mock_runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
# Mock the runtime creation to set up the runtime attribute
|
||||
async def mock_create_runtime(*args, **kwargs):
|
||||
session.runtime = mock_runtime
|
||||
return True
|
||||
|
||||
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
|
||||
|
||||
# Create a mock EventStream with no events
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
mock_event_stream.get_events.return_value = []
|
||||
mock_event_stream.subscribe = MagicMock()
|
||||
mock_event_stream.get_latest_event_id.return_value = 0
|
||||
|
||||
# Inject the mock event stream into the session
|
||||
session.event_stream = mock_event_stream
|
||||
|
||||
# Create a real Memory instance with the mock event stream
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# Patch necessary components
|
||||
with (
|
||||
patch(
|
||||
'openhands.server.session.agent_session.EventStream',
|
||||
return_value=mock_event_stream,
|
||||
),
|
||||
patch(
|
||||
'openhands.controller.state.state.State.restore_from_session',
|
||||
side_effect=Exception('No state found'),
|
||||
),
|
||||
patch('openhands.server.session.agent_session.Memory', return_value=memory),
|
||||
):
|
||||
await session.start(
|
||||
runtime_name='test-runtime',
|
||||
config=OpenHandsConfig(),
|
||||
agent=mock_agent,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
# Verify that the agent's LLM metrics and controller's state metrics are the same object
|
||||
assert session.controller.agent.llm.metrics is session.controller.state.metrics
|
||||
|
||||
# Add some metrics to the agent's LLM
|
||||
test_cost = 0.05
|
||||
session.controller.agent.llm.metrics.add_cost(test_cost)
|
||||
|
||||
# Verify that the cost is reflected in the controller's state metrics
|
||||
assert session.controller.state.metrics.accumulated_cost == test_cost
|
||||
|
||||
# Create a test metrics object to simulate an observation with metrics
|
||||
test_observation_metrics = Metrics()
|
||||
test_observation_metrics.add_cost(0.1)
|
||||
|
||||
# Get the current accumulated cost before merging
|
||||
current_cost = session.controller.state.metrics.accumulated_cost
|
||||
|
||||
# Simulate merging metrics from an observation
|
||||
session.controller.state_tracker.merge_metrics(test_observation_metrics)
|
||||
|
||||
# Verify that the merged metrics are reflected in both agent and controller
|
||||
assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1
|
||||
assert (
|
||||
session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1
|
||||
)
|
||||
|
||||
# Reset the agent and verify that metrics are not reset
|
||||
session.controller.agent.reset()
|
||||
|
||||
# Metrics should still be the same after reset
|
||||
assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1
|
||||
assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1
|
||||
assert session.controller.agent.llm.metrics is session.controller.state.metrics
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
"""Test that BudgetControlFlag's current value matches the accumulated costs."""
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
mock_runtime = MagicMock(spec=ActionExecutionClient)
|
||||
|
||||
# Mock the runtime creation to set up the runtime attribute
|
||||
async def mock_create_runtime(*args, **kwargs):
|
||||
session.runtime = mock_runtime
|
||||
return True
|
||||
|
||||
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
|
||||
|
||||
# Create a mock EventStream with no events
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
mock_event_stream.get_events.return_value = []
|
||||
mock_event_stream.subscribe = MagicMock()
|
||||
mock_event_stream.get_latest_event_id.return_value = 0
|
||||
|
||||
# Inject the mock event stream into the session
|
||||
session.event_stream = mock_event_stream
|
||||
|
||||
# Create a real Memory instance with the mock event stream
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# Patch necessary components
|
||||
with (
|
||||
patch(
|
||||
'openhands.server.session.agent_session.EventStream',
|
||||
return_value=mock_event_stream,
|
||||
),
|
||||
patch(
|
||||
'openhands.controller.state.state.State.restore_from_session',
|
||||
side_effect=Exception('No state found'),
|
||||
),
|
||||
patch('openhands.server.session.agent_session.Memory', return_value=memory),
|
||||
):
|
||||
# Start the session with a budget limit
|
||||
await session.start(
|
||||
runtime_name='test-runtime',
|
||||
config=OpenHandsConfig(),
|
||||
agent=mock_agent,
|
||||
max_iterations=10,
|
||||
max_budget_per_task=1.0, # Set a budget limit
|
||||
)
|
||||
|
||||
# Verify that the budget control flag was created
|
||||
assert session.controller.state.budget_flag is not None
|
||||
assert session.controller.state.budget_flag.max_value == 1.0
|
||||
assert session.controller.state.budget_flag.current_value == 0.0
|
||||
|
||||
# Add some metrics to the agent's LLM
|
||||
test_cost = 0.05
|
||||
session.controller.agent.llm.metrics.add_cost(test_cost)
|
||||
|
||||
# Verify that the budget control flag's current value is updated
|
||||
# This happens through the state_tracker.sync_budget_flag_with_metrics method
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert session.controller.state.budget_flag.current_value == test_cost
|
||||
|
||||
# Create a test metrics object to simulate an observation with metrics
|
||||
test_observation_metrics = Metrics()
|
||||
test_observation_metrics.add_cost(0.1)
|
||||
|
||||
# Simulate merging metrics from an observation
|
||||
session.controller.state_tracker.merge_metrics(test_observation_metrics)
|
||||
|
||||
# Verify that the budget control flag's current value is updated to match the new accumulated cost
|
||||
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
|
||||
|
||||
# Reset the agent and verify that metrics and budget flag are not reset
|
||||
session.controller.agent.reset()
|
||||
|
||||
# Budget control flag should still reflect the accumulated cost after reset
|
||||
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
|
||||
|
||||
@@ -21,6 +21,9 @@ def test_parser_default_values():
|
||||
assert args.name == ''
|
||||
assert not args.no_auto_continue
|
||||
assert args.selected_repo is None
|
||||
assert args.llm_model is None
|
||||
assert args.llm_base_url is None
|
||||
assert args.llm_api_key is None
|
||||
|
||||
|
||||
def test_parser_custom_values():
|
||||
@@ -55,6 +58,12 @@ def test_parser_custom_values():
|
||||
'--no-auto-continue',
|
||||
'--selected-repo',
|
||||
'owner/repo',
|
||||
'--llm-model',
|
||||
'openai/gpt-4',
|
||||
'--llm-base-url',
|
||||
'http://localhost:1234/v1',
|
||||
'--llm-api-key',
|
||||
'test-api-key',
|
||||
]
|
||||
)
|
||||
|
||||
@@ -73,6 +82,9 @@ def test_parser_custom_values():
|
||||
assert args.no_auto_continue
|
||||
assert args.version
|
||||
assert args.selected_repo == 'owner/repo'
|
||||
assert args.llm_model == 'openai/gpt-4'
|
||||
assert args.llm_base_url == 'http://localhost:1234/v1'
|
||||
assert args.llm_api_key == 'test-api-key'
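The three options exercised above can be illustrated with a small standalone parser. This is a minimal sketch, assuming plain `argparse`; `build_sketch_parser` is a hypothetical helper and not the OpenHands parser, and only the flag names and the `None` defaults are taken from the tests.

```python
import argparse


def build_sketch_parser() -> argparse.ArgumentParser:
    """Hypothetical parser exposing only the three LLM override flags."""
    parser = argparse.ArgumentParser(description='Sketch of the LLM CLI overrides')
    # Each flag defaults to None so callers can tell "not given" from an empty value.
    parser.add_argument('--llm-model', type=str, default=None)
    parser.add_argument('--llm-base-url', type=str, default=None)
    parser.add_argument('--llm-api-key', type=str, default=None)
    return parser


args = build_sketch_parser().parse_args(
    ['--llm-model', 'openai/gpt-4', '--llm-base-url', 'http://localhost:1234/v1']
)
print(args.llm_model, args.llm_base_url, args.llm_api_key)
```

argparse maps `--llm-base-url` to the attribute `llm_base_url`, which is why the assertions read `args.llm_base_url` rather than the dashed name.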
|
||||
|
||||
|
||||
def test_parser_file_overrides_task():
|
||||
@@ -138,13 +150,16 @@ def test_help_message(capsys):
|
||||
'--no-auto-continue',
|
||||
'--selected-repo SELECTED_REPO',
|
||||
'--override-cli-mode OVERRIDE_CLI_MODE',
|
||||
'--llm-model LLM_MODEL',
|
||||
'--llm-base-url LLM_BASE_URL',
|
||||
'--llm-api-key LLM_API_KEY',
|
||||
]
|
||||
|
||||
for element in expected_elements:
|
||||
assert element in help_output, f"Expected '{element}' to be in the help message"
|
||||
|
||||
option_count = help_output.count(' -')
|
||||
-assert option_count == 20, f'Expected 20 options, found {option_count}'
+assert option_count == 23, f'Expected 23 options, found {option_count}'
|
||||
|
||||
|
||||
def test_selected_repo_format():
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from openhands.controller.state.control_flags import (
|
||||
BudgetControlFlag,
|
||||
IterationControlFlag,
|
||||
)
|
||||
|
||||
|
||||
def test_iteration_control_flag_reaches_limit_and_increases():
|
||||
flag = IterationControlFlag(limit_increase_amount=5, current_value=5, max_value=5)
|
||||
|
||||
# Should be at limit
|
||||
assert flag.reached_limit() is True
|
||||
assert flag._hit_limit is True
|
||||
|
||||
# Increase limit in non-headless mode
|
||||
flag.increase_limit(headless_mode=False)
|
||||
assert flag.max_value == 10 # increased by limit_increase_amount
|
||||
|
||||
# After increase, we should no longer be at limit
|
||||
flag._hit_limit = False # simulate reset
|
||||
assert flag.reached_limit() is False
|
||||
|
||||
|
||||
def test_iteration_control_flag_does_not_increase_in_headless():
|
||||
flag = IterationControlFlag(limit_increase_amount=5, current_value=5, max_value=5)
|
||||
|
||||
assert flag.reached_limit() is True
|
||||
assert flag._hit_limit is True
|
||||
|
||||
# Should NOT increase max_value in headless mode
|
||||
flag.increase_limit(headless_mode=True)
|
||||
assert flag.max_value == 5
|
||||
|
||||
|
||||
def test_iteration_control_flag_step_behavior():
|
||||
flag = IterationControlFlag(limit_increase_amount=2, current_value=0, max_value=2)
|
||||
|
||||
# First step
|
||||
flag.step()
|
||||
assert flag.current_value == 1
|
||||
assert not flag.reached_limit()
|
||||
|
||||
# Second step
|
||||
flag.step()
|
||||
assert flag.current_value == 2
|
||||
assert flag.reached_limit()
|
||||
|
||||
# Stepping again should raise error
|
||||
with pytest.raises(RuntimeError, match='Agent reached maximum iteration'):
|
||||
flag.step()
|
||||
|
||||
|
||||
# ----- BudgetControlFlag Tests -----
|
||||
|
||||
|
||||
def test_budget_control_flag_reaches_limit_and_increases():
|
||||
flag = BudgetControlFlag(
|
||||
limit_increase_amount=10.0, current_value=50.0, max_value=50.0
|
||||
)
|
||||
|
||||
# Should be at limit
|
||||
assert flag.reached_limit() is True
|
||||
assert flag._hit_limit is True
|
||||
|
||||
# Increase budget — allowed only if _hit_limit == True
|
||||
flag.increase_limit(headless_mode=False)
|
||||
assert flag.max_value == 60.0 # current_value + limit_increase_amount
|
||||
|
||||
# After increasing, _hit_limit should be reset manually in your logic
|
||||
flag._hit_limit = False
|
||||
flag.current_value = 55.0
|
||||
assert flag.reached_limit() is False
|
||||
|
||||
|
||||
def test_budget_control_flag_does_not_increase_if_not_hit_limit():
|
||||
flag = BudgetControlFlag(
|
||||
limit_increase_amount=10.0, current_value=40.0, max_value=50.0
|
||||
)
|
||||
|
||||
# Not at limit yet
|
||||
assert flag.reached_limit() is False
|
||||
assert flag._hit_limit is False
|
||||
|
||||
# Try to increase — should do nothing
|
||||
old_max_value = flag.max_value
|
||||
flag.increase_limit(headless_mode=False)
|
||||
assert flag.max_value == old_max_value
|
||||
|
||||
|
||||
def test_budget_control_flag_does_not_increase_in_headless():
|
||||
flag = BudgetControlFlag(
|
||||
limit_increase_amount=10.0, current_value=50.0, max_value=50.0
|
||||
)
|
||||
|
||||
assert flag.reached_limit() is True
|
||||
assert flag._hit_limit is True
|
||||
|
||||
# Increase limit in headless mode — should still increase since BudgetControlFlag ignores headless param
|
||||
flag.increase_limit(headless_mode=True)
|
||||
assert flag.max_value == 60.0
|
||||
|
||||
|
||||
def test_budget_control_flag_step_raises_on_limit():
|
||||
flag = BudgetControlFlag(
|
||||
limit_increase_amount=5.0, current_value=55.0, max_value=50.0
|
||||
)
|
||||
|
||||
# Should raise RuntimeError
|
||||
with pytest.raises(RuntimeError, match='Agent reached maximum budget'):
|
||||
flag.step()
|
||||
|
||||
# After increasing limit, step should not raise
|
||||
flag.max_value = 60.0
|
||||
flag._hit_limit = False
|
||||
flag.step() # Should not raise
|
||||
|
||||
|
||||
def test_budget_control_flag_hit_limit_resets_after_increase():
|
||||
flag = BudgetControlFlag(
|
||||
limit_increase_amount=10.0, current_value=50.0, max_value=50.0
|
||||
)
|
||||
|
||||
# Initially should hit limit
|
||||
assert flag.reached_limit() is True
|
||||
assert flag._hit_limit is True
|
||||
|
||||
# Increase limit
|
||||
flag.increase_limit(headless_mode=False)
|
||||
|
||||
# After increasing, _hit_limit should be reset
|
||||
assert flag._hit_limit is False
|
||||
|
||||
# Should no longer report reaching limit unless value exceeds new max
|
||||
assert flag.reached_limit() is False
|
||||
|
||||
# If we push current_value over new max_value:
|
||||
flag.current_value = flag.max_value + 1.0
|
||||
assert flag.reached_limit() is True
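The behaviour these removed tests describe for the budget flag can be summarised in a few lines. The following is only a sketch inferred from the assertions above, not the OpenHands implementation; `SketchBudgetFlag` is a hypothetical name, and the `>=` comparison is an assumption consistent with the tests.

```python
from dataclasses import dataclass


@dataclass
class SketchBudgetFlag:
    """Hypothetical stand-in for a budget control flag; not the OpenHands class."""

    limit_increase_amount: float
    current_value: float
    max_value: float
    _hit_limit: bool = False

    def reached_limit(self) -> bool:
        # Remember whether the limit was hit so a later increase is permitted.
        self._hit_limit = self.current_value >= self.max_value
        return self._hit_limit

    def increase_limit(self, headless_mode: bool) -> None:
        # Per the tests, the budget flag ignores headless_mode; only _hit_limit matters.
        if self._hit_limit:
            self.max_value = self.current_value + self.limit_increase_amount
            self._hit_limit = False

    def step(self) -> None:
        # Refuse to continue once the budget is exhausted.
        if self.reached_limit():
            raise RuntimeError('Agent reached maximum budget')


flag = SketchBudgetFlag(limit_increase_amount=10.0, current_value=50.0, max_value=50.0)
flag.reached_limit()
flag.increase_limit(headless_mode=True)
print(flag.max_value)
```

Deriving the new maximum from `current_value` rather than from the old `max_value` matches the `# current_value + limit_increase_amount` comment in the first budget test.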
|
||||
@@ -55,9 +55,7 @@ def event_stream(temp_dir):
|
||||
class TestStuckDetector:
|
||||
@pytest.fixture
|
||||
def stuck_detector(self):
|
||||
-state = State(inputs={})
-# Set the iteration flag's max_value to 50 (equivalent to the old max_iterations)
-state.iteration_flag.max_value = 50
+state = State(inputs={}, max_iterations=50)
|
||||
state.history = [] # Initialize history as an empty list
|
||||
return StuckDetector(state)
|
||||
|
||||
|
||||
76
tests/unit/test_iteration_limit.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.llm.metrics import Metrics
|
||||
|
||||
|
||||
class DummyAgent:
|
||||
def __init__(self):
|
||||
self.name = 'dummy'
|
||||
self.llm = type(
|
||||
'DummyLLM',
|
||||
(),
|
||||
{
|
||||
'metrics': Metrics(),
|
||||
'config': type('DummyConfig', (), {'max_message_chars': 10000})(),
|
||||
},
|
||||
)()
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_system_message(self):
|
||||
# Return a proper SystemMessageAction for the refactored system message handling
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import EventSource
|
||||
|
||||
system_message = SystemMessageAction(content='This is a dummy system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
return system_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_iteration_limit_extends_on_user_message():
|
||||
# Initialize test components
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
file_store = InMemoryFileStore()
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
agent = DummyAgent()
|
||||
initial_max_iterations = 100
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
event_stream=event_stream,
|
||||
max_iterations=initial_max_iterations,
|
||||
sid='test',
|
||||
headless_mode=False,
|
||||
)
|
||||
|
||||
# Set initial state
|
||||
await controller.set_agent_state_to(AgentState.RUNNING)
|
||||
controller.state.iteration = 90 # Close to the limit
|
||||
assert controller.state.max_iterations == initial_max_iterations
|
||||
|
||||
# Simulate user message
|
||||
user_message = MessageAction('test message', EventSource.USER)
|
||||
event_stream.add_event(user_message, EventSource.USER)
|
||||
await asyncio.sleep(0.1) # Give time for event to be processed
|
||||
|
||||
# Verify max_iterations was extended
|
||||
assert controller.state.max_iterations == 90 + initial_max_iterations
|
||||
|
||||
# Simulate more iterations and another user message
|
||||
controller.state.iteration = 180 # Close to new limit
|
||||
user_message2 = MessageAction('another message', EventSource.USER)
|
||||
event_stream.add_event(user_message2, EventSource.USER)
|
||||
await asyncio.sleep(0.1) # Give time for event to be processed
|
||||
|
||||
# Verify max_iterations was extended again
|
||||
assert controller.state.max_iterations == 180 + initial_max_iterations
|
||||
@@ -250,6 +250,28 @@ def test_response_latency_tracking(mock_time, mock_litellm_completion):
|
||||
assert latency_record.latency == 0.0 # Should be lifted to 0 instead of being -1!
|
||||
|
||||
|
||||
def test_llm_reset():
|
||||
llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key'))
|
||||
initial_metrics = copy.deepcopy(llm.metrics)
|
||||
initial_metrics.add_cost(1.0)
|
||||
initial_metrics.add_response_latency(0.5, 'test-id')
|
||||
initial_metrics.add_token_usage(10, 5, 3, 2, 1000, 'test-id')
|
||||
llm.reset()
|
||||
assert llm.metrics.accumulated_cost != initial_metrics.accumulated_cost
|
||||
assert llm.metrics.costs != initial_metrics.costs
|
||||
assert llm.metrics.response_latencies != initial_metrics.response_latencies
|
||||
assert llm.metrics.token_usages != initial_metrics.token_usages
|
||||
assert isinstance(llm.metrics, Metrics)
|
||||
|
||||
# Check that accumulated token usage is reset
|
||||
metrics_data = llm.metrics.get()
|
||||
accumulated_usage = metrics_data['accumulated_token_usage']
|
||||
assert accumulated_usage['prompt_tokens'] == 0
|
||||
assert accumulated_usage['completion_tokens'] == 0
|
||||
assert accumulated_usage['cache_read_tokens'] == 0
|
||||
assert accumulated_usage['cache_write_tokens'] == 0
|
||||
|
||||
|
||||
@patch('openhands.llm.llm.litellm.get_model_info')
|
||||
def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
|
||||
default_config.model = 'openrouter:gpt-4o-mini'
|
||||
|
||||
@@ -111,7 +111,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream, mock_age
|
||||
)
|
||||
|
||||
# Verify that the controller's last error was set
|
||||
-assert state.iteration_flag.current_value == 0
+assert state.iteration == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Error: Exception'
|
||||
|
||||
@@ -142,7 +142,7 @@ async def test_memory_on_workspace_context_recall_exception_handling(
|
||||
)
|
||||
|
||||
# Verify that the controller's last error was set
|
||||
-assert state.iteration_flag.current_value == 0
+assert state.iteration == 0
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert state.last_error == 'Error: Exception'
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
-from openhands.controller.state.control_flags import IterationControlFlag
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
@@ -162,11 +161,9 @@ def test_add_turns_left_reminder(prompt_dir):
|
||||
manager = PromptManager(prompt_dir=prompt_dir)
|
||||
|
||||
# Create a State object with specific iteration values
|
||||
-state = State(
-iteration_flag=IterationControlFlag(
-current_value=3, max_value=10, limit_increase_amount=10
-)
-)
+state = State()
+state.iteration = 3
+state.max_iterations = 10
|
||||
|
||||
# Create a list of messages with a user message
|
||||
user_message = Message(role='user', content=[TextContent(text='User content')])
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
-from openhands.controller.state.state import State, TrafficControlState
-from openhands.core.schema import AgentState
+from openhands.controller.state.state import State
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@@ -60,66 +56,3 @@ def test_state_view_cache_not_serialized():
|
||||
# be structurally identical but _not_ the same object.
|
||||
assert id(restored_view) != id(view)
|
||||
assert restored_view.events == view.events
|
||||
|
||||
|
||||
def test_restore_older_state_version():
|
||||
"""Test that we can restore from an older state version (before control flags)."""
|
||||
# Create a dictionary that mimics the old state format (before control flags)
|
||||
state = State(
|
||||
session_id='test_old_session',
|
||||
iteration=42,
|
||||
local_iteration=42,
|
||||
max_iterations=100,
|
||||
agent_state=AgentState.RUNNING,
|
||||
traffic_control_state=TrafficControlState.NORMAL,
|
||||
metrics=Metrics(),
|
||||
confirmation_mode=False,
|
||||
)
|
||||
|
||||
def no_op_getstate(self):
|
||||
return self.__dict__
|
||||
|
||||
store = InMemoryFileStore()
|
||||
|
||||
with patch.object(State, '__getstate__', no_op_getstate):
|
||||
state.save_to_session('test_old_session', store, None)
|
||||
|
||||
# Now restore it
|
||||
restored_state = State.restore_from_session('test_old_session', store, None)
|
||||
|
||||
# Verify that when we store the active fields are populated with the values from the deprecated fields
|
||||
assert restored_state.session_id == 'test_old_session'
|
||||
assert restored_state.agent_state == AgentState.LOADING
|
||||
assert restored_state.resume_state == AgentState.RUNNING
|
||||
assert restored_state.iteration_flag.current_value == 42
|
||||
assert restored_state.iteration_flag.max_value == 100
|
||||
|
||||
|
||||
def test_save_without_deprecated_fields():
|
||||
"""Test that we can save state without deprecated fields"""
|
||||
# Create a dictionary that mimics the old state format (before control flags)
|
||||
state = State(
|
||||
session_id='test_old_session',
|
||||
iteration=42,
|
||||
local_iteration=42,
|
||||
max_iterations=100,
|
||||
agent_state=AgentState.RUNNING,
|
||||
traffic_control_state=TrafficControlState.NORMAL,
|
||||
metrics=Metrics(),
|
||||
confirmation_mode=False,
|
||||
)
|
||||
|
||||
store = InMemoryFileStore()
|
||||
|
||||
state.save_to_session('test_state', store, None)
|
||||
restored_state = State.restore_from_session('test_state', store, None)
|
||||
|
||||
# Verify that when we save and restore, the deprecated fields are removed
|
||||
# but the new fields maintain the correct values
|
||||
assert restored_state.session_id == 'test_old_session'
|
||||
assert restored_state.agent_state == AgentState.LOADING
|
||||
assert restored_state.resume_state == AgentState.RUNNING
|
||||
assert (
|
||||
restored_state.iteration_flag.current_value == 0
|
||||
) # The depreciated attrib was not stored, so it did not override existing values on restore
|
||||
assert restored_state.iteration_flag.max_value == 100
|
||||
|
||||
91
tests/unit/test_traffic_control.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.events import EventStream
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.storage import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_controller():
|
||||
llm = LLM(config=LLMConfig())
|
||||
agent = MagicMock()
|
||||
agent.name = 'test_agent'
|
||||
agent.llm = llm
|
||||
agent.config = AgentConfig()
|
||||
|
||||
# Add a proper system message mock
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
|
||||
event_stream = EventStream(sid='test', file_store=InMemoryFileStore())
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
event_stream=event_stream,
|
||||
max_iterations=100,
|
||||
max_budget_per_task=10.0,
|
||||
sid='test',
|
||||
headless_mode=False,
|
||||
)
|
||||
return controller
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_traffic_control_iteration_message(agent_controller):
|
||||
"""Test that iteration messages are formatted as integers."""
|
||||
# Mock _react_to_exception to capture the error
|
||||
error = None
|
||||
|
||||
async def mock_react_to_exception(e):
|
||||
nonlocal error
|
||||
error = e
|
||||
|
||||
agent_controller._react_to_exception = mock_react_to_exception
|
||||
|
||||
await agent_controller._handle_traffic_control('iteration', 200.0, 100.0)
|
||||
assert error is not None
|
||||
assert 'Current iteration: 200, max iteration: 100' in str(error)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_traffic_control_budget_message(agent_controller):
|
||||
"""Test that budget messages keep decimal points."""
|
||||
# Mock _react_to_exception to capture the error
|
||||
error = None
|
||||
|
||||
async def mock_react_to_exception(e):
|
||||
nonlocal error
|
||||
error = e
|
||||
|
||||
agent_controller._react_to_exception = mock_react_to_exception
|
||||
|
||||
await agent_controller._handle_traffic_control('budget', 15.75, 10.0)
|
||||
assert error is not None
|
||||
assert 'Current budget: 15.75, max budget: 10.00' in str(error)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_traffic_control_headless_mode(agent_controller):
|
||||
"""Test that headless mode messages are formatted correctly."""
|
||||
# Mock _react_to_exception to capture the error
|
||||
error = None
|
||||
|
||||
async def mock_react_to_exception(e):
|
||||
nonlocal error
|
||||
error = e
|
||||
|
||||
agent_controller._react_to_exception = mock_react_to_exception
|
||||
|
||||
agent_controller.headless_mode = True
|
||||
await agent_controller._handle_traffic_control('iteration', 200.0, 100.0)
|
||||
assert error is not None
|
||||
assert 'in headless mode' in str(error)
|
||||
assert 'Current iteration: 200, max iteration: 100' in str(error)
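The formatting rule these three tests check (integers for iterations, two decimals for budget, an extra phrase in headless mode) can be sketched as a standalone function. This is an illustration under assumptions: `format_limit_message` is a hypothetical helper, and the overall sentence wording is invented; only the substrings asserted in the tests are taken from the source.

```python
def format_limit_message(
    limit_type: str, current: float, maximum: float, headless_mode: bool = False
) -> str:
    """Hypothetical formatter; only the asserted substrings come from the tests."""
    if limit_type == 'iteration':
        # Iteration counts are whole numbers, so drop the fractional part.
        current_str, max_str = str(int(current)), str(int(maximum))
    else:
        # Budgets are monetary amounts, so keep two decimal places.
        current_str, max_str = f'{current:.2f}', f'{maximum:.2f}'
    suffix = ' in headless mode' if headless_mode else ''
    return (
        f'Agent reached maximum {limit_type}{suffix}. '
        f'Current {limit_type}: {current_str}, max {limit_type}: {max_str}'
    )


print(format_limit_message('iteration', 200.0, 100.0))
print(format_limit_message('budget', 15.75, 10.0))
```

Branching on `limit_type` keeps a single call site for both kinds of limit, which mirrors how the tests pass `'iteration'` or `'budget'` as the first argument.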