Compare commits

...

1 Commits

Author SHA1 Message Date
openhands
48cad4117a Fix issue #5257: Display API costs in frontend 2024-11-25 14:25:10 +00:00
10 changed files with 165 additions and 2 deletions

View File

@@ -29,6 +29,7 @@ import OpenHands from "#/api/open-hands";
import { downloadWorkspace } from "#/utils/download-workspace";
import { SuggestionItem } from "./suggestion-item";
import { useAuth } from "#/context/auth-context";
import { CostDisplay } from "./cost-display";
const isErrorMessage = (
message: Message | ErrorMessage,
@@ -238,6 +239,7 @@ export function ChatInterface() {
onClose={() => setFeedbackModalIsOpen(false)}
polarity={feedbackPolarity}
/>
<CostDisplay />
</div>
);
}

View File

@@ -0,0 +1,31 @@
import React from "react";
import { useSelector } from "react-redux";
import { RootState } from "#/store";
export function CostDisplay() {
const { totalCost, lastStepCosts } = useSelector((state: RootState) => state.cost);
return (
<div className="fixed bottom-24 right-4 bg-neutral-700 border border-neutral-600 rounded-lg p-3 text-sm">
<div className="mb-2">
<span className="text-neutral-400">Total Cost:</span>{" "}
<span className="font-semibold">${totalCost.toFixed(4)}</span>
</div>
{lastStepCosts.length > 0 && (
<div>
<span className="text-neutral-400">Last Steps:</span>
<div className="mt-1 space-y-1">
{lastStepCosts.map((step, i) => (
<div key={i} className="flex justify-between">
<span className="text-neutral-300 truncate mr-4" title={step.description}>
{step.description}
</span>
<span className="text-neutral-300">${step.cost.toFixed(4)}</span>
</div>
))}
</div>
</div>
)}
</div>
);
}

View File

@@ -10,6 +10,7 @@ import {
appendSecurityAnalyzerInput,
} from "#/state/security-analyzer-slice";
import { setCurStatusMessage } from "#/state/status-slice";
import { addStepCost } from "#/state/cost-slice";
import store from "#/store";
import ActionType from "#/types/action-type";
import {
@@ -148,6 +149,12 @@ export function handleAssistantMessage(message: Record<string, unknown>) {
handleObservationMessage(message as unknown as ObservationMessage);
} else if (message.status_update) {
handleStatusMessage(message as unknown as StatusMessage);
} else if (message.event === "cost") {
store.dispatch(addStepCost({
stepCost: message.step_cost as number,
totalCost: message.total_cost as number,
description: message.description as string,
}));
} else {
console.error("Unknown message type", message);
}

View File

@@ -0,0 +1,39 @@
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
type CostState = {
totalCost: number;
lastStepCosts: { cost: number; description: string }[];
};
const initialState: CostState = {
totalCost: 0,
lastStepCosts: [],
};
export const costSlice = createSlice({
name: "cost",
initialState,
reducers: {
addStepCost(
state,
action: PayloadAction<{ stepCost: number; totalCost: number; description: string }>,
) {
state.totalCost = action.payload.totalCost;
state.lastStepCosts.push({
cost: action.payload.stepCost,
description: action.payload.description,
});
// Keep only last 3 step costs
if (state.lastStepCosts.length > 3) {
state.lastStepCosts.shift();
}
},
clearCosts(state) {
state.totalCost = 0;
state.lastStepCosts = [];
},
},
});
export const { addStepCost, clearCosts } = costSlice.actions;
export default costSlice.reducer;

View File

@@ -3,6 +3,7 @@ import agentReducer from "./state/agent-slice";
import browserReducer from "./state/browser-slice";
import chatReducer from "./state/chat-slice";
import codeReducer from "./state/code-slice";
import costReducer from "./state/cost-slice";
import fileStateReducer from "./state/file-state-slice";
import initialQueryReducer from "./state/initial-query-slice";
import commandReducer from "./state/command-slice";
@@ -21,6 +22,7 @@ export const rootReducer = combineReducers({
jupyter: jupyterReducer,
securityAnalyzer: securityAnalyzerReducer,
status: statusReducer,
cost: costReducer,
});
const store = configureStore({

View File

@@ -36,6 +36,7 @@ from openhands.events.action import (
ModifyTaskAction,
NullAction,
)
from openhands.events.cost import CostEvent
from openhands.events.event import Event
from openhands.events.observation import (
AgentDelegateObservation,
@@ -181,6 +182,16 @@ class AgentController:
async def update_state_after_step(self):
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
# Emit cost event
if self.state.local_metrics.accumulated_cost is not None:
self.event_stream.add_event(
CostEvent(
step_cost=self.state.local_metrics.accumulated_cost - (self.state.metrics.accumulated_cost or 0),
total_cost=self.state.local_metrics.accumulated_cost,
description=f"Step {self.state.iteration}"
),
EventSource.ENVIRONMENT
)
async def _react_to_exception(
self,

21
openhands/events/cost.py Normal file
View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from openhands.events.observation import Observation
@dataclass
class CostEvent(Observation):
"""Event emitted when a cost is incurred by the LLM."""
step_cost: float
total_cost: float
description: str
def __init__(self, step_cost: float, total_cost: float, description: str):
super().__init__(content="") # Content will be set in post_init
self.step_cost = step_cost
self.total_cost = total_cost
self.description = description
def __post_init__(self):
super().__post_init__()
self.observation = "cost"
self.content = f"Cost: ${self.step_cost:.4f} (Total: ${self.total_cost:.4f}) - {self.description}"

4
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]]
name = "aenum"
@@ -10252,4 +10252,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
content-hash = "f55423a96fb0640333c8f66552ba8443b0d61764558c763e8aae104359fd02c1"
content-hash = "a9e84d284a9b711ed9844966fdc44d7455f7dac1112325e3d9990d7b527a71f7"

View File

@@ -64,6 +64,7 @@ modal = "^0.66.26"
runloop-api-client = "0.10.0"
pygithub = "^2.5.0"
openhands-aci = "^0.1.1"
pytest = "^8.3.3"
[tool.poetry.group.llama-index.dependencies]
llama-index = "*"
@@ -96,6 +97,7 @@ reportlab = "*"
concurrency = ["gevent"]
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -127,6 +129,7 @@ ignore = ["D1"]
convention = "google"
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"

View File

@@ -0,0 +1,47 @@
import pytest
from openhands.events.cost import CostEvent
from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.storage.files import FileStore
class MockFileStore(FileStore):
def __init__(self):
self.files = {}
def write(self, path: str, contents: str) -> None:
self.files[path] = contents
def read(self, path: str) -> str:
if path not in self.files:
raise FileNotFoundError(path)
return self.files[path]
def list(self, path: str) -> list[str]:
return [k for k in self.files.keys() if k.startswith(path)]
def delete(self, path: str) -> None:
self.files = {k: v for k, v in self.files.items() if not k.startswith(path)}
@pytest.fixture
def event_stream():
file_store = MockFileStore()
return EventStream("test", file_store)
def test_cost_event(event_stream):
# Create a cost event
cost_event = CostEvent(step_cost=0.1, total_cost=0.5, description="Test step")
# Add the event to the stream
event_stream.add_event(cost_event, EventSource.ENVIRONMENT)
# Get the latest event
latest_event = event_stream.get_latest_event()
# Verify the event was added correctly
assert isinstance(latest_event, CostEvent)
assert latest_event.step_cost == 0.1
assert latest_event.total_cost == 0.5
assert latest_event.description == "Test step"