mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
migrate-to
...
openhands-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48cad4117a |
@@ -29,6 +29,7 @@ import OpenHands from "#/api/open-hands";
|
||||
import { downloadWorkspace } from "#/utils/download-workspace";
|
||||
import { SuggestionItem } from "./suggestion-item";
|
||||
import { useAuth } from "#/context/auth-context";
|
||||
import { CostDisplay } from "./cost-display";
|
||||
|
||||
const isErrorMessage = (
|
||||
message: Message | ErrorMessage,
|
||||
@@ -238,6 +239,7 @@ export function ChatInterface() {
|
||||
onClose={() => setFeedbackModalIsOpen(false)}
|
||||
polarity={feedbackPolarity}
|
||||
/>
|
||||
<CostDisplay />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
31
frontend/src/components/cost-display.tsx
Normal file
31
frontend/src/components/cost-display.tsx
Normal file
@@ -0,0 +1,31 @@
|
||||
import React from "react";
|
||||
import { useSelector } from "react-redux";
|
||||
import { RootState } from "#/store";
|
||||
|
||||
export function CostDisplay() {
|
||||
const { totalCost, lastStepCosts } = useSelector((state: RootState) => state.cost);
|
||||
|
||||
return (
|
||||
<div className="fixed bottom-24 right-4 bg-neutral-700 border border-neutral-600 rounded-lg p-3 text-sm">
|
||||
<div className="mb-2">
|
||||
<span className="text-neutral-400">Total Cost:</span>{" "}
|
||||
<span className="font-semibold">${totalCost.toFixed(4)}</span>
|
||||
</div>
|
||||
{lastStepCosts.length > 0 && (
|
||||
<div>
|
||||
<span className="text-neutral-400">Last Steps:</span>
|
||||
<div className="mt-1 space-y-1">
|
||||
{lastStepCosts.map((step, i) => (
|
||||
<div key={i} className="flex justify-between">
|
||||
<span className="text-neutral-300 truncate mr-4" title={step.description}>
|
||||
{step.description}
|
||||
</span>
|
||||
<span className="text-neutral-300">${step.cost.toFixed(4)}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
appendSecurityAnalyzerInput,
|
||||
} from "#/state/security-analyzer-slice";
|
||||
import { setCurStatusMessage } from "#/state/status-slice";
|
||||
import { addStepCost } from "#/state/cost-slice";
|
||||
import store from "#/store";
|
||||
import ActionType from "#/types/action-type";
|
||||
import {
|
||||
@@ -148,6 +149,12 @@ export function handleAssistantMessage(message: Record<string, unknown>) {
|
||||
handleObservationMessage(message as unknown as ObservationMessage);
|
||||
} else if (message.status_update) {
|
||||
handleStatusMessage(message as unknown as StatusMessage);
|
||||
} else if (message.event === "cost") {
|
||||
store.dispatch(addStepCost({
|
||||
stepCost: message.step_cost as number,
|
||||
totalCost: message.total_cost as number,
|
||||
description: message.description as string,
|
||||
}));
|
||||
} else {
|
||||
console.error("Unknown message type", message);
|
||||
}
|
||||
|
||||
39
frontend/src/state/cost-slice.ts
Normal file
39
frontend/src/state/cost-slice.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
|
||||
|
||||
type CostState = {
|
||||
totalCost: number;
|
||||
lastStepCosts: { cost: number; description: string }[];
|
||||
};
|
||||
|
||||
const initialState: CostState = {
|
||||
totalCost: 0,
|
||||
lastStepCosts: [],
|
||||
};
|
||||
|
||||
export const costSlice = createSlice({
|
||||
name: "cost",
|
||||
initialState,
|
||||
reducers: {
|
||||
addStepCost(
|
||||
state,
|
||||
action: PayloadAction<{ stepCost: number; totalCost: number; description: string }>,
|
||||
) {
|
||||
state.totalCost = action.payload.totalCost;
|
||||
state.lastStepCosts.push({
|
||||
cost: action.payload.stepCost,
|
||||
description: action.payload.description,
|
||||
});
|
||||
// Keep only last 3 step costs
|
||||
if (state.lastStepCosts.length > 3) {
|
||||
state.lastStepCosts.shift();
|
||||
}
|
||||
},
|
||||
clearCosts(state) {
|
||||
state.totalCost = 0;
|
||||
state.lastStepCosts = [];
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { addStepCost, clearCosts } = costSlice.actions;
|
||||
export default costSlice.reducer;
|
||||
@@ -3,6 +3,7 @@ import agentReducer from "./state/agent-slice";
|
||||
import browserReducer from "./state/browser-slice";
|
||||
import chatReducer from "./state/chat-slice";
|
||||
import codeReducer from "./state/code-slice";
|
||||
import costReducer from "./state/cost-slice";
|
||||
import fileStateReducer from "./state/file-state-slice";
|
||||
import initialQueryReducer from "./state/initial-query-slice";
|
||||
import commandReducer from "./state/command-slice";
|
||||
@@ -21,6 +22,7 @@ export const rootReducer = combineReducers({
|
||||
jupyter: jupyterReducer,
|
||||
securityAnalyzer: securityAnalyzerReducer,
|
||||
status: statusReducer,
|
||||
cost: costReducer,
|
||||
});
|
||||
|
||||
const store = configureStore({
|
||||
|
||||
@@ -36,6 +36,7 @@ from openhands.events.action import (
|
||||
ModifyTaskAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.cost import CostEvent
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentDelegateObservation,
|
||||
@@ -181,6 +182,16 @@ class AgentController:
|
||||
async def update_state_after_step(self):
|
||||
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
|
||||
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
|
||||
# Emit cost event
|
||||
if self.state.local_metrics.accumulated_cost is not None:
|
||||
self.event_stream.add_event(
|
||||
CostEvent(
|
||||
step_cost=self.state.local_metrics.accumulated_cost - (self.state.metrics.accumulated_cost or 0),
|
||||
total_cost=self.state.local_metrics.accumulated_cost,
|
||||
description=f"Step {self.state.iteration}"
|
||||
),
|
||||
EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
async def _react_to_exception(
|
||||
self,
|
||||
|
||||
21
openhands/events/cost.py
Normal file
21
openhands/events/cost.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from openhands.events.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class CostEvent(Observation):
|
||||
"""Event emitted when a cost is incurred by the LLM."""
|
||||
step_cost: float
|
||||
total_cost: float
|
||||
description: str
|
||||
|
||||
def __init__(self, step_cost: float, total_cost: float, description: str):
|
||||
super().__init__(content="") # Content will be set in post_init
|
||||
self.step_cost = step_cost
|
||||
self.total_cost = total_cost
|
||||
self.description = description
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.observation = "cost"
|
||||
self.content = f"Cost: ${self.step_cost:.4f} (Total: ${self.total_cost:.4f}) - {self.description}"
|
||||
4
poetry.lock
generated
4
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aenum"
|
||||
@@ -10252,4 +10252,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "f55423a96fb0640333c8f66552ba8443b0d61764558c763e8aae104359fd02c1"
|
||||
content-hash = "a9e84d284a9b711ed9844966fdc44d7455f7dac1112325e3d9990d7b527a71f7"
|
||||
|
||||
@@ -64,6 +64,7 @@ modal = "^0.66.26"
|
||||
runloop-api-client = "0.10.0"
|
||||
pygithub = "^2.5.0"
|
||||
openhands-aci = "^0.1.1"
|
||||
pytest = "^8.3.3"
|
||||
|
||||
[tool.poetry.group.llama-index.dependencies]
|
||||
llama-index = "*"
|
||||
@@ -96,6 +97,7 @@ reportlab = "*"
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@@ -127,6 +129,7 @@ ignore = ["D1"]
|
||||
convention = "google"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
47
tests/unit/test_cost_event.py
Normal file
47
tests/unit/test_cost_event.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
from openhands.events.cost import CostEvent
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class MockFileStore(FileStore):
|
||||
def __init__(self):
|
||||
self.files = {}
|
||||
|
||||
def write(self, path: str, contents: str) -> None:
|
||||
self.files[path] = contents
|
||||
|
||||
def read(self, path: str) -> str:
|
||||
if path not in self.files:
|
||||
raise FileNotFoundError(path)
|
||||
return self.files[path]
|
||||
|
||||
def list(self, path: str) -> list[str]:
|
||||
return [k for k in self.files.keys() if k.startswith(path)]
|
||||
|
||||
def delete(self, path: str) -> None:
|
||||
self.files = {k: v for k, v in self.files.items() if not k.startswith(path)}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_stream():
|
||||
file_store = MockFileStore()
|
||||
return EventStream("test", file_store)
|
||||
|
||||
|
||||
def test_cost_event(event_stream):
|
||||
# Create a cost event
|
||||
cost_event = CostEvent(step_cost=0.1, total_cost=0.5, description="Test step")
|
||||
|
||||
# Add the event to the stream
|
||||
event_stream.add_event(cost_event, EventSource.ENVIRONMENT)
|
||||
|
||||
# Get the latest event
|
||||
latest_event = event_stream.get_latest_event()
|
||||
|
||||
# Verify the event was added correctly
|
||||
assert isinstance(latest_event, CostEvent)
|
||||
assert latest_event.step_cost == 0.1
|
||||
assert latest_event.total_cost == 0.5
|
||||
assert latest_event.description == "Test step"
|
||||
Reference in New Issue
Block a user