mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d67c4c5848 | |||
| 36b378c561 | |||
| ec763f8105 | |||
| 165c0cc42e | |||
| 1b4f15235e | |||
| 303b7ab180 | |||
| 78d185b102 | |||
| 300bfbdf2d | |||
| e2f414bf26 | |||
| 3b955dd9d5 | |||
| f1eb1f59c3 | |||
| e1f6929d98 | |||
| 2a7f926591 | |||
| b8daab721d |
@@ -10,6 +10,9 @@ We also support "remote" runtimes, which are typically managed by third-parties.
|
||||
They can make setup a bit simpler and more scalable, especially
|
||||
if you're running many OpenHands conversations in parallel (e.g. to do evaluation).
|
||||
|
||||
Additionally, we provide a "local" runtime that runs directly on your machine without Docker,
|
||||
which can be useful in controlled environments like CI pipelines.
|
||||
|
||||
## Docker Runtime
|
||||
This is the default Runtime that's used when you start OpenHands. You might notice
|
||||
some flags being passed to `docker run` that make this possible:
|
||||
@@ -56,11 +59,12 @@ any files that are mounted into its workspace.
|
||||
This setup can cause some issues with file permissions (hence the `SANDBOX_USER_ID` variable)
|
||||
but seems to work well on most systems.
|
||||
|
||||
## All Hands Runtime
|
||||
The All Hands Runtime is currently in beta. You can request access by joining
|
||||
the #remote-runtime-limited-beta channel on Slack ([see the README](https://github.com/All-Hands-AI/OpenHands?tab=readme-ov-file#-how-to-join-the-community) for an invite).
|
||||
## OpenHands Remote Runtime
|
||||
|
||||
To use the All Hands Runtime, set the following environment variables when
|
||||
OpenHands Remote Runtime is currently in beta (read [here](https://runtime.all-hands.dev/) for more details), it allows you to launch runtimes in parallel in the cloud.
|
||||
Fill out [this form](https://docs.google.com/forms/d/e/1FAIpQLSckVz_JFwg2_mOxNZjCtr7aoBFI2Mwdan3f75J_TrdMS1JV2g/viewform) to apply if you want to try this out!
|
||||
|
||||
To use the OpenHands Remote Runtime, set the following environment variables when
|
||||
starting OpenHands:
|
||||
|
||||
```bash
|
||||
@@ -117,3 +121,66 @@ bash -i <(curl -sL https://get.daytona.io/openhands)
|
||||
Once executed, OpenHands should be running locally and ready for use.
|
||||
|
||||
For more details and manual initialization, view the entire [README.md](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/impl/daytona/README.md)
|
||||
|
||||
## Local Runtime
|
||||
|
||||
The Local Runtime allows the OpenHands agent to execute actions directly on your local machine without using Docker. This runtime is primarily intended for controlled environments like CI pipelines or testing scenarios where Docker is not available.
|
||||
|
||||
:::caution
|
||||
**Security Warning**: The Local Runtime runs without any sandbox isolation. The agent can directly access and modify files on your machine. Only use this runtime in controlled environments or when you fully understand the security implications.
|
||||
:::
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before using the Local Runtime, ensure you have the following dependencies installed:
|
||||
|
||||
1. You have followed the [Development setup instructions](https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md).
|
||||
2. tmux is available on your system.
|
||||
|
||||
### Configuration
|
||||
|
||||
To use the Local Runtime, besides required configurations like the model, API key, you'll need to set the following options via environment variables or the [config.toml file](https://github.com/All-Hands-AI/OpenHands/blob/main/config.template.toml) when starting OpenHands:
|
||||
|
||||
- Via environment variables:
|
||||
|
||||
```bash
|
||||
# Required
|
||||
export RUNTIME=local
|
||||
|
||||
# Optional but recommended
|
||||
export WORKSPACE_BASE=/path/to/your/workspace
|
||||
```
|
||||
|
||||
- Via `config.toml`:
|
||||
|
||||
```toml
|
||||
[core]
|
||||
runtime = "local"
|
||||
workspace_base = "/path/to/your/workspace"
|
||||
```
|
||||
|
||||
If `WORKSPACE_BASE` is not set, the runtime will create a temporary directory for the agent to work in.
|
||||
|
||||
### Example Usage
|
||||
|
||||
Here's an example of how to start OpenHands with the Local Runtime in Headless Mode:
|
||||
|
||||
```bash
|
||||
# Set the runtime type to local
|
||||
export RUNTIME=local
|
||||
|
||||
# Optionally set a workspace directory
|
||||
export WORKSPACE_BASE=/path/to/your/project
|
||||
|
||||
# Start OpenHands
|
||||
poetry run python -m openhands.core.main -t "write a bash script that prints hi"
|
||||
```
|
||||
|
||||
### Use Cases
|
||||
|
||||
The Local Runtime is particularly useful for:
|
||||
|
||||
- CI/CD pipelines where Docker is not available.
|
||||
- Testing and development of OpenHands itself.
|
||||
- Environments where container usage is restricted.
|
||||
- Scenarios where direct file system access is required.
|
||||
|
||||
@@ -3,7 +3,9 @@ import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
@@ -91,12 +93,22 @@ def get_config(metadata: EvalMetadata, instance: pd.Series) -> AppConfig:
|
||||
return config
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditionalImports:
|
||||
"""We instantiate the values in this dataclass differently if we're evaluating SWE-bench or SWE-Gym."""
|
||||
|
||||
get_eval_report: Callable
|
||||
APPLY_PATCH_FAIL: str
|
||||
APPLY_PATCH_PASS: str
|
||||
|
||||
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
log_dir: str | None = None,
|
||||
runtime_failure_count: int = 0,
|
||||
conditional_imports: ConditionalImports | None = None,
|
||||
) -> EvalOutput:
|
||||
"""
|
||||
Evaluate agent performance on a SWE-bench problem instance.
|
||||
@@ -108,9 +120,18 @@ def process_instance(
|
||||
log_dir (str | None, default=None): Path to directory where log files will be written. Must
|
||||
be provided if `reset_logger` is set.
|
||||
|
||||
conditional_imports: A dataclass containing values that are imported differently based on
|
||||
whether we're evaluating SWE-bench or SWE-Gym.
|
||||
|
||||
Raises:
|
||||
AssertionError: if the `reset_logger` flag is set without a provided log directory.
|
||||
|
||||
AssertionError: if `conditional_imports` is not provided.
|
||||
"""
|
||||
assert (
|
||||
conditional_imports is not None
|
||||
), 'conditional_imports must be provided to run process_instance using multiprocessing'
|
||||
|
||||
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
|
||||
if reset_logger:
|
||||
assert (
|
||||
@@ -124,7 +145,7 @@ def process_instance(
|
||||
config = get_config(metadata, instance)
|
||||
instance_id = instance.instance_id
|
||||
model_patch = instance['model_patch']
|
||||
test_spec: TestSpec = instance['test_spec']
|
||||
test_spec = instance['test_spec']
|
||||
logger.info(f'Starting evaluation for instance {instance_id}.')
|
||||
|
||||
if 'test_result' not in instance.keys():
|
||||
@@ -196,7 +217,9 @@ def process_instance(
|
||||
instance['test_result']['apply_patch_output'] = apply_patch_output
|
||||
|
||||
if 'APPLY_PATCH_FAIL' in apply_patch_output:
|
||||
logger.info(f'[{instance_id}] {APPLY_PATCH_FAIL}:\n{apply_patch_output}')
|
||||
logger.info(
|
||||
f'[{instance_id}] {conditional_imports.APPLY_PATCH_FAIL}:\n{apply_patch_output}'
|
||||
)
|
||||
instance['test_result']['report']['failed_apply_patch'] = True
|
||||
|
||||
return EvalOutput(
|
||||
@@ -205,7 +228,9 @@ def process_instance(
|
||||
metadata=metadata,
|
||||
)
|
||||
elif 'APPLY_PATCH_PASS' in apply_patch_output:
|
||||
logger.info(f'[{instance_id}] {APPLY_PATCH_PASS}:\n{apply_patch_output}')
|
||||
logger.info(
|
||||
f'[{instance_id}] {conditional_imports.APPLY_PATCH_PASS}:\n{apply_patch_output}'
|
||||
)
|
||||
|
||||
# Run eval script in background and save output to log file
|
||||
log_file = '/tmp/eval_output.log'
|
||||
@@ -271,7 +296,7 @@ def process_instance(
|
||||
with open(test_output_path, 'w') as f:
|
||||
f.write(test_output)
|
||||
try:
|
||||
_report = get_eval_report(
|
||||
_report = conditional_imports.get_eval_report(
|
||||
test_spec=test_spec,
|
||||
prediction={
|
||||
'model_patch': model_patch,
|
||||
@@ -345,7 +370,6 @@ if __name__ == '__main__':
|
||||
)
|
||||
from swegym.harness.test_spec import (
|
||||
SWEbenchInstance,
|
||||
TestSpec,
|
||||
make_test_spec,
|
||||
)
|
||||
from swegym.harness.utils import load_swebench_dataset
|
||||
@@ -357,7 +381,6 @@ if __name__ == '__main__':
|
||||
)
|
||||
from swebench.harness.test_spec.test_spec import (
|
||||
SWEbenchInstance,
|
||||
TestSpec,
|
||||
make_test_spec,
|
||||
)
|
||||
from swebench.harness.utils import load_swebench_dataset
|
||||
@@ -445,7 +468,15 @@ if __name__ == '__main__':
|
||||
# The evaluation harness constrains the signature of `process_instance_func` but we need to
|
||||
# pass extra information. Build a new function object to avoid issues with multiprocessing.
|
||||
process_instance_func = partial(
|
||||
process_instance, log_dir=output_file.replace('.jsonl', '.logs')
|
||||
process_instance,
|
||||
log_dir=output_file.replace('.jsonl', '.logs'),
|
||||
# We have to explicitly pass these imports to the process_instance function, otherwise
|
||||
# they won't be available in the multiprocessing context.
|
||||
conditional_imports=ConditionalImports(
|
||||
get_eval_report=get_eval_report,
|
||||
APPLY_PATCH_FAIL=APPLY_PATCH_FAIL,
|
||||
APPLY_PATCH_PASS=APPLY_PATCH_PASS,
|
||||
),
|
||||
)
|
||||
|
||||
run_evaluation(
|
||||
|
||||
@@ -4,8 +4,10 @@ import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, beforeEach, describe, expect, it, test, vi } from "vitest";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { PaymentForm } from "#/components/features/payment/payment-form";
|
||||
import * as featureFlags from "#/utils/feature-flags";
|
||||
|
||||
describe("PaymentForm", () => {
|
||||
const billingSettingsSpy = vi.spyOn(featureFlags, "BILLING_SETTINGS");
|
||||
const getBalanceSpy = vi.spyOn(OpenHands, "getBalance");
|
||||
const createCheckoutSessionSpy = vi.spyOn(OpenHands, "createCheckoutSession");
|
||||
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
|
||||
@@ -26,6 +28,7 @@ describe("PaymentForm", () => {
|
||||
GITHUB_CLIENT_ID: "123",
|
||||
POSTHOG_CLIENT_KEY: "456",
|
||||
});
|
||||
billingSettingsSpy.mockReturnValue(true);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -72,8 +72,10 @@ describe("useTerminal", () => {
|
||||
wrapper: Wrapper,
|
||||
});
|
||||
|
||||
expect(mockTerminal.writeln).toHaveBeenNthCalledWith(1, "echo hello");
|
||||
expect(mockTerminal.writeln).toHaveBeenNthCalledWith(2, "hello");
|
||||
// Input commands should be displayed
|
||||
expect(mockTerminal.writeln).toHaveBeenCalledWith("echo hello");
|
||||
// Output commands should be displayed
|
||||
expect(mockTerminal.writeln).toHaveBeenCalledWith("hello");
|
||||
});
|
||||
|
||||
it("should hide secrets in the terminal", () => {
|
||||
@@ -97,13 +99,31 @@ describe("useTerminal", () => {
|
||||
},
|
||||
);
|
||||
|
||||
// BUG: `vi.clearAllMocks()` does not clear the number of calls
|
||||
// therefore, we need to assume the order of the calls based
|
||||
// on the test order
|
||||
expect(mockTerminal.writeln).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
// Input command should be displayed with secrets masked
|
||||
expect(mockTerminal.writeln).toHaveBeenCalledWith(
|
||||
`export GITHUB_TOKEN=${"*".repeat(10)},${"*".repeat(10)},${"*".repeat(10)}`,
|
||||
);
|
||||
expect(mockTerminal.writeln).toHaveBeenNthCalledWith(4, "*".repeat(10));
|
||||
|
||||
// Output command should be displayed with secrets masked
|
||||
expect(mockTerminal.writeln).toHaveBeenCalledWith("*".repeat(10));
|
||||
});
|
||||
|
||||
it("should prevent duplicate command display", () => {
|
||||
const inputCommand = "ls -la";
|
||||
const commands: Command[] = [
|
||||
{ content: inputCommand, type: "input" },
|
||||
{ content: `${inputCommand}\nfile1.txt\nfile2.txt`, type: "output" },
|
||||
];
|
||||
|
||||
render(<TestTerminalComponent commands={commands} secrets={[]} />, {
|
||||
wrapper: Wrapper,
|
||||
});
|
||||
|
||||
// Input command should be displayed
|
||||
expect(mockTerminal.writeln).toHaveBeenCalledWith(inputCommand);
|
||||
|
||||
// Output should not be displayed since it starts with the input command
|
||||
// This prevents the duplicate display of the command
|
||||
expect(mockTerminal.writeln).not.toHaveBeenCalledWith(`${inputCommand}\nfile1.txt\nfile2.txt`);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -26,33 +26,32 @@ const createAxiosNotFoundErrorObject = () =>
|
||||
},
|
||||
);
|
||||
|
||||
const getSettingsSpy = vi.spyOn(OpenHands, "getSettings");
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
// layout route
|
||||
Component: MainApp,
|
||||
path: "/",
|
||||
children: [
|
||||
{
|
||||
// home route
|
||||
Component: Home,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
Component: SettingsScreen,
|
||||
path: "/settings",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Home Screen", () => {
|
||||
const getSettingsSpy = vi.spyOn(OpenHands, "getSettings");
|
||||
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
// layout route
|
||||
Component: MainApp,
|
||||
path: "/",
|
||||
children: [
|
||||
{
|
||||
// home route
|
||||
Component: Home,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
Component: SettingsScreen,
|
||||
path: "/settings",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render the home screen", () => {
|
||||
renderWithProviders(<RouterStub initialEntries={["/"]} />);
|
||||
});
|
||||
@@ -79,57 +78,82 @@ describe("Home Screen", () => {
|
||||
const settingsScreen = await screen.findByTestId("settings-screen");
|
||||
expect(settingsScreen).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Settings 404", () => {
|
||||
it("should open the settings modal if GET /settings fails with a 404", async () => {
|
||||
const error = createAxiosNotFoundErrorObject();
|
||||
getSettingsSpy.mockRejectedValue(error);
|
||||
describe("Settings 404", () => {
|
||||
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
|
||||
|
||||
renderWithProviders(<RouterStub initialEntries={["/"]} />);
|
||||
it("should open the settings modal if GET /settings fails with a 404", async () => {
|
||||
const error = createAxiosNotFoundErrorObject();
|
||||
getSettingsSpy.mockRejectedValue(error);
|
||||
|
||||
const settingsModal = await screen.findByTestId("ai-config-modal");
|
||||
expect(settingsModal).toBeInTheDocument();
|
||||
});
|
||||
renderWithProviders(<RouterStub initialEntries={["/"]} />);
|
||||
|
||||
it("should navigate to the settings screen when clicking the advanced settings button", async () => {
|
||||
const error = createAxiosNotFoundErrorObject();
|
||||
getSettingsSpy.mockRejectedValue(error);
|
||||
const settingsModal = await screen.findByTestId("ai-config-modal");
|
||||
expect(settingsModal).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const user = userEvent.setup();
|
||||
renderWithProviders(<RouterStub initialEntries={["/"]} />);
|
||||
it("should navigate to the settings screen when clicking the advanced settings button", async () => {
|
||||
const error = createAxiosNotFoundErrorObject();
|
||||
getSettingsSpy.mockRejectedValue(error);
|
||||
|
||||
const settingsScreen = screen.queryByTestId("settings-screen");
|
||||
expect(settingsScreen).not.toBeInTheDocument();
|
||||
const user = userEvent.setup();
|
||||
renderWithProviders(<RouterStub initialEntries={["/"]} />);
|
||||
|
||||
const settingsModal = await screen.findByTestId("ai-config-modal");
|
||||
expect(settingsModal).toBeInTheDocument();
|
||||
const settingsScreen = screen.queryByTestId("settings-screen");
|
||||
expect(settingsScreen).not.toBeInTheDocument();
|
||||
|
||||
const advancedSettingsButton = await screen.findByTestId(
|
||||
"advanced-settings-link",
|
||||
);
|
||||
await user.click(advancedSettingsButton);
|
||||
const settingsModal = await screen.findByTestId("ai-config-modal");
|
||||
expect(settingsModal).toBeInTheDocument();
|
||||
|
||||
const settingsScreenAfter = await screen.findByTestId("settings-screen");
|
||||
expect(settingsScreenAfter).toBeInTheDocument();
|
||||
const advancedSettingsButton = await screen.findByTestId(
|
||||
"advanced-settings-link",
|
||||
);
|
||||
await user.click(advancedSettingsButton);
|
||||
|
||||
const settingsModalAfter = screen.queryByTestId("ai-config-modal");
|
||||
expect(settingsModalAfter).not.toBeInTheDocument();
|
||||
});
|
||||
const settingsScreenAfter = await screen.findByTestId("settings-screen");
|
||||
expect(settingsScreenAfter).toBeInTheDocument();
|
||||
|
||||
it("should not open the settings modal if GET /settings fails but is SaaS mode", async () => {
|
||||
// TODO: Remove HIDE_LLM_SETTINGS check once released
|
||||
vi.spyOn(FeatureFlags, "HIDE_LLM_SETTINGS").mockReturnValue(true);
|
||||
// @ts-expect-error - we only need APP_MODE for this test
|
||||
getConfigSpy.mockResolvedValue({ APP_MODE: "saas" });
|
||||
const error = createAxiosNotFoundErrorObject();
|
||||
getSettingsSpy.mockRejectedValue(error);
|
||||
const settingsModalAfter = screen.queryByTestId("ai-config-modal");
|
||||
expect(settingsModalAfter).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
renderWithProviders(<RouterStub initialEntries={["/"]} />);
|
||||
it("should not open the settings modal if GET /settings fails but is SaaS mode", async () => {
|
||||
// TODO: Remove HIDE_LLM_SETTINGS check once released
|
||||
vi.spyOn(FeatureFlags, "HIDE_LLM_SETTINGS").mockReturnValue(true);
|
||||
// @ts-expect-error - we only need APP_MODE for this test
|
||||
getConfigSpy.mockResolvedValue({ APP_MODE: "saas" });
|
||||
const error = createAxiosNotFoundErrorObject();
|
||||
getSettingsSpy.mockRejectedValue(error);
|
||||
|
||||
// small hack to wait for the modal to not appear
|
||||
await expect(
|
||||
screen.findByTestId("ai-config-modal", {}, { timeout: 1000 }),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
renderWithProviders(<RouterStub initialEntries={["/"]} />);
|
||||
|
||||
// small hack to wait for the modal to not appear
|
||||
await expect(
|
||||
screen.findByTestId("ai-config-modal", {}, { timeout: 1000 }),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Setup Payment modal", () => {
|
||||
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
it("should only render if SaaS mode and is new user", async () => {
|
||||
// @ts-expect-error - we only need the APP_MODE for this test
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
});
|
||||
vi.spyOn(FeatureFlags, "BILLING_SETTINGS").mockReturnValue(true);
|
||||
const error = createAxiosNotFoundErrorObject();
|
||||
getSettingsSpy.mockRejectedValue(error);
|
||||
|
||||
renderWithProviders(<RouterStub initialEntries={["/"]} />);
|
||||
|
||||
const setupPaymentModal = await screen.findByTestId("proceed-to-stripe-button");
|
||||
expect(setupPaymentModal).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -721,7 +721,7 @@ describe("Settings Screen", () => {
|
||||
expect(saveSettingsSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
llm_api_key: "", // empty because it's not set previously
|
||||
github_token: undefined,
|
||||
provider_tokens: undefined,
|
||||
language: "no",
|
||||
}),
|
||||
);
|
||||
@@ -758,7 +758,7 @@ describe("Settings Screen", () => {
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
github_token: undefined,
|
||||
provider_tokens: undefined,
|
||||
llm_api_key: "", // empty because it's not set previously
|
||||
llm_model: "openai/gpt-4o",
|
||||
}),
|
||||
@@ -801,7 +801,7 @@ describe("Settings Screen", () => {
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalledWith({
|
||||
...mockCopy,
|
||||
github_token: undefined, // not set
|
||||
provider_tokens: undefined, // not set
|
||||
llm_api_key: "", // reset as well
|
||||
});
|
||||
expect(screen.queryByTestId("reset-modal")).not.toBeInTheDocument();
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { handleStatusMessage } from "#/services/actions";
|
||||
import { handleStatusMessage, handleActionMessage } from "#/services/actions";
|
||||
import store from "#/store";
|
||||
import { trackError } from "#/utils/error-handler";
|
||||
import ActionType from "#/types/action-type";
|
||||
import { ActionMessage } from "#/types/message";
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock("#/utils/error-handler", () => ({
|
||||
@@ -56,4 +58,89 @@ describe("Actions Service", () => {
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
describe("handleActionMessage", () => {
|
||||
it("should use first-person perspective for task completion messages", () => {
|
||||
// Test partial completion
|
||||
const messagePartial: ActionMessage = {
|
||||
id: 1,
|
||||
action: ActionType.FINISH,
|
||||
source: "agent",
|
||||
message: "",
|
||||
timestamp: new Date().toISOString(),
|
||||
args: {
|
||||
final_thought: "",
|
||||
task_completed: "partial",
|
||||
outputs: "",
|
||||
thought: ""
|
||||
}
|
||||
};
|
||||
|
||||
// Mock implementation to capture the message
|
||||
let capturedPartialMessage = "";
|
||||
(store.dispatch as any).mockImplementation((action: any) => {
|
||||
if (action.type === "chat/addAssistantMessage" &&
|
||||
action.payload.includes("believe that the task was **completed partially**")) {
|
||||
capturedPartialMessage = action.payload;
|
||||
}
|
||||
});
|
||||
|
||||
handleActionMessage(messagePartial);
|
||||
expect(capturedPartialMessage).toContain("I believe that the task was **completed partially**");
|
||||
|
||||
// Test not completed
|
||||
const messageNotCompleted: ActionMessage = {
|
||||
id: 2,
|
||||
action: ActionType.FINISH,
|
||||
source: "agent",
|
||||
message: "",
|
||||
timestamp: new Date().toISOString(),
|
||||
args: {
|
||||
final_thought: "",
|
||||
task_completed: "false",
|
||||
outputs: "",
|
||||
thought: ""
|
||||
}
|
||||
};
|
||||
|
||||
// Mock implementation to capture the message
|
||||
let capturedNotCompletedMessage = "";
|
||||
(store.dispatch as any).mockImplementation((action: any) => {
|
||||
if (action.type === "chat/addAssistantMessage" &&
|
||||
action.payload.includes("believe that the task was **not completed**")) {
|
||||
capturedNotCompletedMessage = action.payload;
|
||||
}
|
||||
});
|
||||
|
||||
handleActionMessage(messageNotCompleted);
|
||||
expect(capturedNotCompletedMessage).toContain("I believe that the task was **not completed**");
|
||||
|
||||
// Test completed successfully
|
||||
const messageCompleted: ActionMessage = {
|
||||
id: 3,
|
||||
action: ActionType.FINISH,
|
||||
source: "agent",
|
||||
message: "",
|
||||
timestamp: new Date().toISOString(),
|
||||
args: {
|
||||
final_thought: "",
|
||||
task_completed: "true",
|
||||
outputs: "",
|
||||
thought: ""
|
||||
}
|
||||
};
|
||||
|
||||
// Mock implementation to capture the message
|
||||
let capturedCompletedMessage = "";
|
||||
(store.dispatch as any).mockImplementation((action: any) => {
|
||||
if (action.type === "chat/addAssistantMessage" &&
|
||||
action.payload.includes("believe that the task was **completed successfully**")) {
|
||||
capturedCompletedMessage = action.payload;
|
||||
}
|
||||
});
|
||||
|
||||
handleActionMessage(messageCompleted);
|
||||
expect(capturedCompletedMessage).toContain("I believe that the task was **completed successfully**");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
* - Please do NOT serve this file on production.
|
||||
*/
|
||||
|
||||
const PACKAGE_VERSION = '2.7.0'
|
||||
const PACKAGE_VERSION = '2.7.3'
|
||||
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
|
||||
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
|
||||
const activeClientIds = new Set()
|
||||
|
||||
@@ -281,6 +281,13 @@ class OpenHands {
|
||||
return data.redirect_url;
|
||||
}
|
||||
|
||||
static async createBillingSessionResponse(): Promise<string> {
|
||||
const { data } = await openHands.post(
|
||||
"/api/billing/create-customer-setup-session",
|
||||
);
|
||||
return data.redirect_url;
|
||||
}
|
||||
|
||||
static async getBalance(): Promise<string> {
|
||||
const { data } = await openHands.get<{ credits: string }>(
|
||||
"/api/billing/credits",
|
||||
|
||||
@@ -48,6 +48,7 @@ export interface GetConfigResponse {
|
||||
APP_SLUG?: string;
|
||||
GITHUB_CLIENT_ID: string;
|
||||
POSTHOG_CLIENT_KEY: string;
|
||||
STRIPE_PUBLISHABLE_KEY?: string;
|
||||
}
|
||||
|
||||
export interface GetVSCodeUrlResponse {
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import { useMutation } from "@tanstack/react-query";
|
||||
import { Trans, useTranslation } from "react-i18next";
|
||||
import AllHandsLogo from "#/assets/branding/all-hands-logo.svg?react";
|
||||
import { ModalBackdrop } from "#/components/shared/modals/modal-backdrop";
|
||||
import { ModalBody } from "#/components/shared/modals/modal-body";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { BrandButton } from "../settings/brand-button";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
|
||||
export function SetupPaymentModal() {
|
||||
const { t } = useTranslation();
|
||||
const { mutate, isPending } = useMutation({
|
||||
mutationFn: OpenHands.createBillingSessionResponse,
|
||||
onSuccess: (data) => {
|
||||
window.location.href = data;
|
||||
},
|
||||
onError: () => {
|
||||
displayErrorToast(t("BILLING$ERROR_WHILE_CREATING_SESSION"));
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalBackdrop>
|
||||
<ModalBody className="border border-tertiary">
|
||||
<AllHandsLogo width={68} height={46} />
|
||||
<div className="flex flex-col gap-2 w-full items-center text-center">
|
||||
<h1 className="text-2xl font-bold">{t("BILLING$YOUVE_GOT_50")}</h1>
|
||||
<p>
|
||||
<Trans
|
||||
i18nKey="BILLING$CLAIM_YOUR_50"
|
||||
components={{ b: <strong /> }}
|
||||
/>
|
||||
</p>
|
||||
</div>
|
||||
<BrandButton
|
||||
testId="proceed-to-stripe-button"
|
||||
type="submit"
|
||||
variant="primary"
|
||||
className="w-full"
|
||||
isDisabled={isPending}
|
||||
onClick={mutate}
|
||||
>
|
||||
{t("BILLING$PROCEED_TO_STRIPE")}
|
||||
</BrandButton>
|
||||
</ModalBody>
|
||||
</ModalBackdrop>
|
||||
);
|
||||
}
|
||||
@@ -61,7 +61,7 @@ export function Sidebar() {
|
||||
displayErrorToast(
|
||||
"Something went wrong while fetching settings. Please reload the page.",
|
||||
);
|
||||
} else if (settingsError?.status === 404) {
|
||||
} else if (config?.APP_MODE === "oss" && settingsError?.status === 404) {
|
||||
setSettingsModalIsOpen(true);
|
||||
}
|
||||
}, [
|
||||
|
||||
@@ -17,7 +17,7 @@ const saveSettingsMutationFn = async (settings: Partial<PostSettings>) => {
|
||||
? ""
|
||||
: settings.LLM_API_KEY?.trim() || undefined,
|
||||
remote_runtime_resource_factor: settings.REMOTE_RUNTIME_RESOURCE_FACTOR,
|
||||
github_token: settings.github_token,
|
||||
provider_tokens: settings.provider_tokens,
|
||||
unset_github_token: settings.unset_github_token,
|
||||
enable_default_condenser: settings.ENABLE_DEFAULT_CONDENSER,
|
||||
enable_sound_notifications: settings.ENABLE_SOUND_NOTIFICATIONS,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useConfig } from "./use-config";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { BILLING_SETTINGS } from "#/utils/feature-flags";
|
||||
|
||||
export const useBalance = () => {
|
||||
const { data: config } = useConfig();
|
||||
@@ -8,6 +9,6 @@ export const useBalance = () => {
|
||||
return useQuery({
|
||||
queryKey: ["user", "balance"],
|
||||
queryFn: OpenHands.getBalance,
|
||||
enabled: config?.APP_MODE === "saas",
|
||||
enabled: config?.APP_MODE === "saas" && BILLING_SETTINGS(),
|
||||
});
|
||||
};
|
||||
|
||||
@@ -21,6 +21,8 @@ const getSettingsQueryFn = async () => {
|
||||
ENABLE_DEFAULT_CONDENSER: apiSettings.enable_default_condenser,
|
||||
ENABLE_SOUND_NOTIFICATIONS: apiSettings.enable_sound_notifications,
|
||||
USER_CONSENTS_TO_ANALYTICS: apiSettings.user_consents_to_analytics,
|
||||
PROVIDER_TOKENS: apiSettings.provider_tokens,
|
||||
IS_NEW_USER: false,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -86,6 +86,8 @@ export const useTerminal = ({
|
||||
|
||||
const handleEnter = (command: string) => {
|
||||
terminal.current?.write("\r\n");
|
||||
// Send the command to the backend but don't echo it back in the terminal
|
||||
// The backend will include the command in its response
|
||||
send(getTerminalCommand(command));
|
||||
};
|
||||
|
||||
@@ -131,10 +133,26 @@ export const useTerminal = ({
|
||||
content = content.replaceAll(secret, "*".repeat(10));
|
||||
});
|
||||
|
||||
terminal.current?.writeln(
|
||||
parseTerminalOutput(content.replaceAll("\n", "\r\n").trim()),
|
||||
);
|
||||
// Check if this is an output that starts with the previous input command
|
||||
// This happens when the backend echoes back the command in the output
|
||||
let shouldDisplayContent = true;
|
||||
if (type === "output" && i > 0 && commands[i - 1].type === "input") {
|
||||
const prevInputCommand = commands[i - 1].content.trim();
|
||||
// If the output starts with the input command, remove it to avoid duplication
|
||||
if (content.trim().startsWith(prevInputCommand)) {
|
||||
// Skip displaying this part as it's a duplicate of the user's input
|
||||
// that's already shown in the terminal
|
||||
shouldDisplayContent = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldDisplayContent) {
|
||||
terminal.current?.writeln(
|
||||
parseTerminalOutput(content.replaceAll("\n", "\r\n").trim()),
|
||||
);
|
||||
}
|
||||
|
||||
// Add a new prompt after the output
|
||||
if (type === "output") {
|
||||
terminal.current.write(`\n$ `);
|
||||
}
|
||||
|
||||
@@ -312,4 +312,9 @@ export enum I18nKey {
|
||||
BUTTON$MARK_NOT_HELPFUL = "BUTTON$MARK_NOT_HELPFUL",
|
||||
BUTTON$EXPORT_CONVERSATION = "BUTTON$EXPORT_CONVERSATION",
|
||||
BILLING$CLICK_TO_TOP_UP = "BILLING$CLICK_TO_TOP_UP",
|
||||
BILLING$YOUVE_GOT_50 = "BILLING$YOUVE_GOT_50",
|
||||
BILLING$ERROR_WHILE_CREATING_SESSION = "BILLING$ERROR_WHILE_CREATING_SESSION",
|
||||
BILLING$CLAIM_YOUR_50 = "BILLING$CLAIM_YOUR_50",
|
||||
BILLING$PROCEED_TO_STRIPE = "BILLING$PROCEED_TO_STRIPE",
|
||||
BILLING$YOURE_IN = "BILLING$YOURE_IN",
|
||||
}
|
||||
|
||||
@@ -4647,5 +4647,80 @@
|
||||
"fr": "Ajouter des fonds à votre compte",
|
||||
"tr": "Hesabınıza bakiye ekleyin",
|
||||
"de": "Guthaben zu Ihrem Konto hinzufügen"
|
||||
},
|
||||
"BILLING$YOUVE_GOT_50": {
|
||||
"en": "You've got $50 in free OpenHands credits",
|
||||
"ja": "OpenHandsの無料クレジット$50を獲得しました",
|
||||
"zh-CN": "您获得了 $50 的 OpenHands 免费额度",
|
||||
"zh-TW": "您獲得了 $50 的 OpenHands 免費額度",
|
||||
"ko-KR": "OpenHands 무료 크레딧 $50를 받았습니다",
|
||||
"no": "Du har fått $50 i gratis OpenHands-kreditter",
|
||||
"it": "Hai ottenuto $50 in crediti gratuiti OpenHands",
|
||||
"pt": "Você ganhou $50 em créditos gratuitos OpenHands",
|
||||
"es": "Has recibido $50 en créditos gratuitos de OpenHands",
|
||||
"ar": "لديك 50$ من رصيد OpenHands المجاني",
|
||||
"fr": "Vous avez reçu $50 de crédits OpenHands gratuits",
|
||||
"tr": "OpenHands'de $50 ücretsiz kredi kazandınız",
|
||||
"de": "Sie haben $50 in kostenlosen OpenHands-Guthaben erhalten"
|
||||
},
|
||||
"BILLING$ERROR_WHILE_CREATING_SESSION": {
|
||||
"en": "Error occurred while setting up your payment session. Please try again later.",
|
||||
"ja": "お支払いセッションの設定中にエラーが発生しました。後ほど再度お試しください。",
|
||||
"zh-CN": "设置支付会话时发生错误。请稍后再试。",
|
||||
"zh-TW": "設置支付會話時發生錯誤。請稍後再試。",
|
||||
"ko-KR": "결제 세션 설정 중 오류가 발생했습니다. 나중에 다시 시도해 주세요.",
|
||||
"no": "Det oppstod en feil under oppsett av betalingsøkten. Vennligst prøv igjen senere.",
|
||||
"it": "Si è verificato un errore durante la configurazione della sessione di pagamento. Si prega di riprovare più tardi.",
|
||||
"pt": "Ocorreu um erro ao configurar sua sessão de pagamento. Por favor, tente novamente mais tarde.",
|
||||
"es": "Se produjo un error al configurar tu sesión de pago. Por favor, inténtalo de nuevo más tarde.",
|
||||
"ar": "حدث خطأ أثناء إعداد جلسة الدفع الخاصة بك. يرجى المحاولة مرة أخرى لاحقًا.",
|
||||
"fr": "Une erreur s'est produite lors de la configuration de votre session de paiement. Veuillez réessayer plus tard.",
|
||||
"tr": "Ödeme oturumunuz kurulurken bir hata oluştu. Lütfen daha sonra tekrar deneyin.",
|
||||
"de": "Beim Einrichten Ihrer Zahlungssitzung ist ein Fehler aufgetreten. Bitte versuchen Sie es später erneut."
|
||||
},
|
||||
"BILLING$CLAIM_YOUR_50": {
|
||||
"en": "Add a credit card with Stripe to claim your $50. <b>We won't charge you without asking first!</b>",
|
||||
"ja": "Stripeでクレジットカードを追加して$50を獲得。<b>事前の確認なしで請求することはありません!</b>",
|
||||
"zh-CN": "添加 Stripe 信用卡以领取 $50。<b>我们不会在未经您同意的情况下收费!</b>",
|
||||
"zh-TW": "添加 Stripe 信用卡以領取 $50。<b>我們不會在未經您同意的情況下收費!</b>",
|
||||
"ko-KR": "Stripe에 신용카드를 추가하여 $50를 받으세요. <b>사전 동의 없이 요금이 청구되지 않습니다!</b>",
|
||||
"no": "Legg til et kredittkort med Stripe for å få $50. <b>Vi belaster deg ikke uten å spørre først!</b>",
|
||||
"it": "Aggiungi una carta di credito con Stripe per ottenere $50. <b>Non ti addebiteremo nulla senza chiedere prima!</b>",
|
||||
"pt": "Adicione um cartão de crédito com Stripe para receber $50. <b>Não cobraremos sem perguntar primeiro!</b>",
|
||||
"es": "Añade una tarjeta de crédito con Stripe para reclamar tus $50. <b>¡No te cobraremos sin preguntarte primero!</b>",
|
||||
"ar": "أضف بطاقة ائتمان مع Stripe للحصول على 50$. <b>لن نقوم بالخصم دون إذن مسبق!</b>",
|
||||
"fr": "Ajoutez une carte de crédit avec Stripe pour obtenir 50$. <b>Nous ne vous facturerons pas sans vous demander d'abord !</b>",
|
||||
"tr": "50$ almak için Stripe ile kredi kartı ekleyin. <b>Önce sormadan ücret almayacağız!</b>",
|
||||
"de": "Fügen Sie eine Kreditkarte mit Stripe hinzu, um $50 zu erhalten. <b>Wir belasten Sie nicht ohne vorherige Zustimmung!</b>"
|
||||
},
|
||||
"BILLING$PROCEED_TO_STRIPE": {
|
||||
"en": "Add Billing Info",
|
||||
"ja": "請求情報を追加",
|
||||
"zh-CN": "添加账单信息",
|
||||
"zh-TW": "添加帳單資訊",
|
||||
"ko-KR": "결제 정보 추가",
|
||||
"no": "Legg til betalingsinformasjon",
|
||||
"it": "Aggiungi informazioni di fatturazione",
|
||||
"pt": "Adicionar informações de pagamento",
|
||||
"es": "Añadir información de facturación",
|
||||
"ar": "إضافة معلومات الفواتير",
|
||||
"fr": "Ajouter les informations de facturation",
|
||||
"tr": "Fatura Bilgisi Ekle",
|
||||
"de": "Zahlungsinformationen hinzufügen"
|
||||
},
|
||||
"BILLING$YOURE_IN": {
|
||||
"en": "You're in! You can start using your $50 in free credits now.",
|
||||
"ja": "登録完了!$50分の無料クレジットを今すぐご利用いただけます。",
|
||||
"zh-CN": "您已加入!现在可以开始使用$50的免费额度了。",
|
||||
"zh-TW": "您已加入!現在可以開始使用$50的免費額度了。",
|
||||
"ko-KR": "가입 완료! 지금 바로 $50 상당의 무료 크레딧을 사용하실 수 있습니다.",
|
||||
"no": "Du er med! Du kan begynne å bruke dine $50 i gratis kreditter nå.",
|
||||
"it": "Ci sei! Puoi iniziare a utilizzare i tuoi $50 in crediti gratuiti ora.",
|
||||
"pt": "Você está dentro! Você pode começar a usar seus $50 em créditos gratuitos agora.",
|
||||
"es": "¡Ya estás dentro! Puedes empezar a usar tus $50 en créditos gratuitos ahora.",
|
||||
"ar": "أنت معنا! يمكنك البدء في استخدام رصيدك المجاني البالغ 50 دولارًا الآن.",
|
||||
"fr": "C'est fait ! Vous pouvez commencer à utiliser vos 50 $ de crédits gratuits maintenant.",
|
||||
"tr": "Başardın! Şimdi $50 değerindeki ücretsiz kredilerini kullanmaya başlayabilirsin.",
|
||||
"de": "Du bist dabei! Du kannst jetzt deine $50 an kostenlosen Guthaben nutzen."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
import { delay, http, HttpResponse } from "msw";
|
||||
import Stripe from "stripe";
|
||||
|
||||
const TEST_STRIPE_SECRET_KEY = "";
|
||||
const PRICE_ID = "";
|
||||
|
||||
export const STRIPE_BILLING_HANDLERS = [
|
||||
http.get("/api/billing/credits", async () => {
|
||||
@@ -10,27 +6,17 @@ export const STRIPE_BILLING_HANDLERS = [
|
||||
return HttpResponse.json({ credits: "100" });
|
||||
}),
|
||||
|
||||
http.post("/api/billing/create-checkout-session", async ({ request }) => {
|
||||
http.post("/api/billing/create-checkout-session", async () => {
|
||||
await delay();
|
||||
const body = await request.json();
|
||||
return HttpResponse.json({
|
||||
redirect_url: "https://stripe.com/some-checkout",
|
||||
});
|
||||
}),
|
||||
|
||||
if (body && typeof body === "object" && body.amount) {
|
||||
const stripe = new Stripe(TEST_STRIPE_SECRET_KEY);
|
||||
const session = await stripe.checkout.sessions.create({
|
||||
line_items: [
|
||||
{
|
||||
price: PRICE_ID,
|
||||
quantity: body.amount,
|
||||
},
|
||||
],
|
||||
mode: "payment",
|
||||
success_url: "http://localhost:3001/settings/billing/?checkout=success",
|
||||
cancel_url: "http://localhost:3001/settings/billing/?checkout=cancel",
|
||||
});
|
||||
|
||||
if (session.url) return HttpResponse.json({ redirect_url: session.url });
|
||||
}
|
||||
|
||||
return HttpResponse.json({ message: "Invalid request" }, { status: 400 });
|
||||
http.post("/api/billing/create-customer-setup-session", async () => {
|
||||
await delay();
|
||||
return HttpResponse.json({
|
||||
redirect_url: "https://stripe.com/some-customer-setup",
|
||||
});
|
||||
}),
|
||||
];
|
||||
|
||||
@@ -22,12 +22,13 @@ export const MOCK_DEFAULT_USER_SETTINGS: ApiSettings | PostApiSettings = {
|
||||
enable_default_condenser: DEFAULT_SETTINGS.ENABLE_DEFAULT_CONDENSER,
|
||||
enable_sound_notifications: DEFAULT_SETTINGS.ENABLE_SOUND_NOTIFICATIONS,
|
||||
user_consents_to_analytics: DEFAULT_SETTINGS.USER_CONSENTS_TO_ANALYTICS,
|
||||
provider_tokens: DEFAULT_SETTINGS.PROVIDER_TOKENS,
|
||||
};
|
||||
|
||||
const MOCK_USER_PREFERENCES: {
|
||||
settings: ApiSettings | PostApiSettings;
|
||||
settings: ApiSettings | PostApiSettings | null;
|
||||
} = {
|
||||
settings: MOCK_DEFAULT_USER_SETTINGS,
|
||||
settings: null,
|
||||
};
|
||||
|
||||
const conversations: Conversation[] = [
|
||||
@@ -174,22 +175,24 @@ export const handlers = [
|
||||
),
|
||||
http.get("/api/options/config", () => {
|
||||
const mockSaas = import.meta.env.VITE_MOCK_SAAS === "true";
|
||||
|
||||
const config: GetConfigResponse = {
|
||||
APP_MODE: mockSaas ? "saas" : "oss",
|
||||
GITHUB_CLIENT_ID: "fake-github-client-id",
|
||||
POSTHOG_CLIENT_KEY: "fake-posthog-client-key",
|
||||
STRIPE_PUBLISHABLE_KEY: "",
|
||||
};
|
||||
|
||||
return HttpResponse.json(config);
|
||||
}),
|
||||
http.get("/api/settings", async () => {
|
||||
await delay();
|
||||
const settings: ApiSettings = {
|
||||
...MOCK_USER_PREFERENCES.settings,
|
||||
language: "no",
|
||||
};
|
||||
// @ts-expect-error - mock types
|
||||
if (settings.github_token) settings.github_token_is_set = true;
|
||||
const { settings } = MOCK_USER_PREFERENCES;
|
||||
|
||||
if (!settings) return HttpResponse.json(null, { status: 404 });
|
||||
|
||||
if (Object.keys(settings.provider_tokens).length > 0)
|
||||
settings.github_token_is_set = true;
|
||||
|
||||
return HttpResponse.json(settings);
|
||||
}),
|
||||
@@ -201,17 +204,19 @@ export const handlers = [
|
||||
if (typeof body === "object") {
|
||||
newSettings = { ...body };
|
||||
if (newSettings.unset_github_token) {
|
||||
newSettings.github_token = undefined;
|
||||
newSettings.provider_tokens = { github: "", gitlab: "" };
|
||||
newSettings.github_token_is_set = false;
|
||||
delete newSettings.unset_github_token;
|
||||
}
|
||||
}
|
||||
|
||||
MOCK_USER_PREFERENCES.settings = {
|
||||
const fullSettings = {
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
...MOCK_USER_PREFERENCES.settings,
|
||||
...newSettings,
|
||||
};
|
||||
|
||||
MOCK_USER_PREFERENCES.settings = fullSettings;
|
||||
return HttpResponse.json(null, { status: 200 });
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,10 @@ function Home() {
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="bg-base-secondary h-full rounded-xl flex flex-col items-center justify-center relative overflow-y-auto px-2">
|
||||
<div
|
||||
data-testid="home-screen"
|
||||
className="bg-base-secondary h-full rounded-xl flex flex-col items-center justify-center relative overflow-y-auto px-2"
|
||||
>
|
||||
<HeroHeading />
|
||||
<div className="flex flex-col gap-8 w-full md:w-[600px] items-center">
|
||||
<div className="flex flex-col gap-2 w-full">
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
import React from "react";
|
||||
import { useRouteError, isRouteErrorResponse, Outlet } from "react-router";
|
||||
import {
|
||||
useRouteError,
|
||||
isRouteErrorResponse,
|
||||
Outlet,
|
||||
useNavigate,
|
||||
useLocation,
|
||||
useSearchParams,
|
||||
} from "react-router";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import i18n from "#/i18n";
|
||||
import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url";
|
||||
import { useIsAuthed } from "#/hooks/query/use-is-authed";
|
||||
@@ -10,6 +18,10 @@ import { AnalyticsConsentFormModal } from "#/components/features/analytics/analy
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { useAuth } from "#/context/auth-context";
|
||||
import { useMigrateUserConsent } from "#/hooks/use-migrate-user-consent";
|
||||
import { useBalance } from "#/hooks/query/use-balance";
|
||||
import { SetupPaymentModal } from "#/components/features/payment/setup-payment-modal";
|
||||
import { BILLING_SETTINGS } from "#/utils/feature-flags";
|
||||
import { displaySuccessToast } from "#/utils/custom-toast-handlers";
|
||||
|
||||
export function ErrorBoundary() {
|
||||
const error = useRouteError();
|
||||
@@ -44,11 +56,14 @@ export function ErrorBoundary() {
|
||||
}
|
||||
|
||||
export default function MainApp() {
|
||||
const navigate = useNavigate();
|
||||
const { pathname } = useLocation();
|
||||
const [searchParams] = useSearchParams();
|
||||
const { githubTokenIsSet } = useAuth();
|
||||
const { data: settings } = useSettings();
|
||||
const { error, isFetching } = useBalance();
|
||||
const { migrateUserConsent } = useMigrateUserConsent();
|
||||
|
||||
const [consentFormIsOpen, setConsentFormIsOpen] = React.useState(false);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const config = useConfig();
|
||||
const {
|
||||
@@ -62,6 +77,8 @@ export default function MainApp() {
|
||||
gitHubClientId: config.data?.GITHUB_CLIENT_ID || null,
|
||||
});
|
||||
|
||||
const [consentFormIsOpen, setConsentFormIsOpen] = React.useState(false);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (settings?.LANGUAGE) {
|
||||
i18n.changeLanguage(settings.LANGUAGE);
|
||||
@@ -84,6 +101,17 @@ export default function MainApp() {
|
||||
});
|
||||
}, []);
|
||||
|
||||
React.useEffect(() => {
|
||||
// Don't allow users to use the app if it 402s
|
||||
if (error?.status === 402 && pathname !== "/") {
|
||||
navigate("/");
|
||||
} else if (!isFetching && searchParams.get("free_credits") === "success") {
|
||||
displaySuccessToast(t("BILLING$YOURE_IN"));
|
||||
searchParams.delete("free_credits");
|
||||
navigate("/");
|
||||
}
|
||||
}, [error?.status, pathname, isFetching]);
|
||||
|
||||
const userIsAuthed = !!isAuthed && !authError;
|
||||
const renderWaitlistModal =
|
||||
!isFetchingAuth && !userIsAuthed && config.data?.APP_MODE === "saas";
|
||||
@@ -116,6 +144,10 @@ export default function MainApp() {
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{BILLING_SETTINGS() &&
|
||||
config.data?.APP_MODE === "saas" &&
|
||||
settings?.IS_NEW_USER && <SetupPaymentModal />}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -61,7 +61,10 @@ function AccountSettings() {
|
||||
if (isSuccess) {
|
||||
return (
|
||||
isCustomModel(resources.models, settings.LLM_MODEL) ||
|
||||
hasAdvancedSettingsSet(settings)
|
||||
hasAdvancedSettingsSet({
|
||||
...settings,
|
||||
PROVIDER_TOKENS: settings.PROVIDER_TOKENS || {},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@@ -128,37 +131,42 @@ function AccountSettings() {
|
||||
: llmBaseUrl;
|
||||
const finalLlmApiKey = shouldHandleSpecialSaasCase ? undefined : llmApiKey;
|
||||
|
||||
saveSettings(
|
||||
{
|
||||
github_token:
|
||||
formData.get("github-token-input")?.toString() || undefined,
|
||||
LANGUAGE: languageValue,
|
||||
user_consents_to_analytics: userConsentsToAnalytics,
|
||||
ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser,
|
||||
ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications,
|
||||
LLM_MODEL: finalLlmModel,
|
||||
LLM_BASE_URL: finalLlmBaseUrl,
|
||||
LLM_API_KEY: finalLlmApiKey,
|
||||
AGENT: formData.get("agent-input")?.toString(),
|
||||
SECURITY_ANALYZER:
|
||||
formData.get("security-analyzer-input")?.toString() || "",
|
||||
REMOTE_RUNTIME_RESOURCE_FACTOR:
|
||||
remoteRuntimeResourceFactor ||
|
||||
DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR,
|
||||
CONFIRMATION_MODE: confirmationModeIsEnabled,
|
||||
const githubToken = formData.get("github-token-input")?.toString();
|
||||
const newSettings = {
|
||||
github_token: githubToken,
|
||||
provider_tokens: githubToken
|
||||
? {
|
||||
github: githubToken,
|
||||
gitlab: "",
|
||||
}
|
||||
: undefined,
|
||||
LANGUAGE: languageValue,
|
||||
user_consents_to_analytics: userConsentsToAnalytics,
|
||||
ENABLE_DEFAULT_CONDENSER: enableMemoryCondenser,
|
||||
ENABLE_SOUND_NOTIFICATIONS: enableSoundNotifications,
|
||||
LLM_MODEL: finalLlmModel,
|
||||
LLM_BASE_URL: finalLlmBaseUrl,
|
||||
LLM_API_KEY: finalLlmApiKey,
|
||||
AGENT: formData.get("agent-input")?.toString(),
|
||||
SECURITY_ANALYZER:
|
||||
formData.get("security-analyzer-input")?.toString() || "",
|
||||
REMOTE_RUNTIME_RESOURCE_FACTOR:
|
||||
remoteRuntimeResourceFactor ||
|
||||
DEFAULT_SETTINGS.REMOTE_RUNTIME_RESOURCE_FACTOR,
|
||||
CONFIRMATION_MODE: confirmationModeIsEnabled,
|
||||
};
|
||||
|
||||
saveSettings(newSettings, {
|
||||
onSuccess: () => {
|
||||
handleCaptureConsent(userConsentsToAnalytics);
|
||||
displaySuccessToast("Settings saved");
|
||||
setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic");
|
||||
},
|
||||
{
|
||||
onSuccess: () => {
|
||||
handleCaptureConsent(userConsentsToAnalytics);
|
||||
displaySuccessToast("Settings saved");
|
||||
setLlmConfigMode(isAdvancedSettingsSet ? "advanced" : "basic");
|
||||
},
|
||||
onError: (error) => {
|
||||
const errorMessage = retrieveAxiosErrorMessage(error);
|
||||
displayErrorToast(errorMessage);
|
||||
},
|
||||
onError: (error) => {
|
||||
const errorMessage = retrieveAxiosErrorMessage(error);
|
||||
displayErrorToast(errorMessage);
|
||||
},
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
const handleReset = () => {
|
||||
|
||||
@@ -62,13 +62,12 @@ const messageActions = {
|
||||
let successPrediction = "";
|
||||
if (message.args.task_completed === "partial") {
|
||||
successPrediction =
|
||||
"The agent thinks that the task was **completed partially**.";
|
||||
"I believe that the task was **completed partially**.";
|
||||
} else if (message.args.task_completed === "false") {
|
||||
successPrediction =
|
||||
"The agent thinks that the task was **not completed**.";
|
||||
successPrediction = "I believe that the task was **not completed**.";
|
||||
} else if (message.args.task_completed === "true") {
|
||||
successPrediction =
|
||||
"The agent thinks that the task was **completed successfully**.";
|
||||
"I believe that the task was **completed successfully**.";
|
||||
}
|
||||
if (successPrediction) {
|
||||
// if final_thought is not empty, add a new line before the success prediction
|
||||
|
||||
@@ -15,6 +15,11 @@ export const DEFAULT_SETTINGS: Settings = {
|
||||
ENABLE_DEFAULT_CONDENSER: true,
|
||||
ENABLE_SOUND_NOTIFICATIONS: false,
|
||||
USER_CONSENTS_TO_ANALYTICS: false,
|
||||
PROVIDER_TOKENS: {
|
||||
github: "",
|
||||
gitlab: "",
|
||||
},
|
||||
IS_NEW_USER: true,
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -21,7 +21,6 @@ export interface InitConfig {
|
||||
LLM_MODEL: string;
|
||||
};
|
||||
token?: string;
|
||||
github_token?: string;
|
||||
latest_event_id?: unknown; // Not sure what this is
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
export type Provider = "github" | "gitlab";
|
||||
|
||||
export type Settings = {
|
||||
LLM_MODEL: string;
|
||||
LLM_BASE_URL: string;
|
||||
@@ -11,6 +13,8 @@ export type Settings = {
|
||||
ENABLE_DEFAULT_CONDENSER: boolean;
|
||||
ENABLE_SOUND_NOTIFICATIONS: boolean;
|
||||
USER_CONSENTS_TO_ANALYTICS: boolean | null;
|
||||
PROVIDER_TOKENS: Record<Provider, string>;
|
||||
IS_NEW_USER?: boolean;
|
||||
};
|
||||
|
||||
export type ApiSettings = {
|
||||
@@ -26,16 +30,17 @@ export type ApiSettings = {
|
||||
enable_default_condenser: boolean;
|
||||
enable_sound_notifications: boolean;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
};
|
||||
|
||||
export type PostSettings = Settings & {
|
||||
github_token: string;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
unset_github_token: boolean;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
};
|
||||
|
||||
export type PostApiSettings = ApiSettings & {
|
||||
github_token: string;
|
||||
provider_tokens: Record<Provider, string>;
|
||||
unset_github_token: boolean;
|
||||
user_consents_to_analytics: boolean | null;
|
||||
};
|
||||
|
||||
@@ -59,6 +59,18 @@ export const extractSettings = (formData: FormData): Partial<Settings> => {
|
||||
ENABLE_DEFAULT_CONDENSER,
|
||||
} = extractAdvancedFormData(formData);
|
||||
|
||||
// Extract provider tokens
|
||||
const githubToken = formData.get("github-token")?.toString();
|
||||
const gitlabToken = formData.get("gitlab-token")?.toString();
|
||||
const providerTokens: Record<string, string> = {};
|
||||
|
||||
if (githubToken) {
|
||||
providerTokens.github = githubToken;
|
||||
}
|
||||
if (gitlabToken) {
|
||||
providerTokens.gitlab = gitlabToken;
|
||||
}
|
||||
|
||||
return {
|
||||
LLM_MODEL: CUSTOM_LLM_MODEL || LLM_MODEL,
|
||||
LLM_API_KEY,
|
||||
@@ -68,5 +80,6 @@ export const extractSettings = (formData: FormData): Partial<Settings> => {
|
||||
CONFIRMATION_MODE,
|
||||
SECURITY_ANALYZER,
|
||||
ENABLE_DEFAULT_CONDENSER,
|
||||
PROVIDER_TOKENS: providerTokens,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -17,45 +17,7 @@ export default {
|
||||
tertiary: "#454545", // gray, used for inputs
|
||||
"tertiary-light": "#B7BDC2", // lighter gray, used for borders and placeholder text
|
||||
content: "#ECEDEE", // light gray, used mostly for text
|
||||
},
|
||||
},
|
||||
animation: {
|
||||
enter: "toastIn 400ms cubic-bezier(0.21, 1.02, 0.73, 1)",
|
||||
leave: "toastOut 100ms ease-in forwards",
|
||||
},
|
||||
keyframes: {
|
||||
toastIn: {
|
||||
"0%": {
|
||||
opacity: "0",
|
||||
transform: "translateY(-100%) scale(0.8)",
|
||||
},
|
||||
"80%": {
|
||||
opacity: "1",
|
||||
transform: "translateY(0) scale(1.02)",
|
||||
},
|
||||
"100%": {
|
||||
opacity: "1",
|
||||
transform: "translateY(0) scale(1)",
|
||||
},
|
||||
},
|
||||
toastOut: {
|
||||
"0%": {
|
||||
opacity: "1",
|
||||
transform: "translateY(0) scale(1)",
|
||||
},
|
||||
"100%": {
|
||||
opacity: "0",
|
||||
transform: "translateY(-100%) scale(0.9)",
|
||||
},
|
||||
},
|
||||
colors: {
|
||||
primary: "#C9B974", // nice yellow
|
||||
base: "#171717", // dark background (neutral-900)
|
||||
"base-secondary": "#262626", // lighter background (neutral-800); also used for tooltips
|
||||
danger: "#E76A5E",
|
||||
success: "#A5E75E",
|
||||
tertiary: "#454545", // gray, used for inputs
|
||||
"tertiary-light": "#B7BDC2", // lighter gray, used for borders and placeholder text
|
||||
"content-2": "#F9FBFE",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -103,8 +103,9 @@ class StuckDetector:
|
||||
return True
|
||||
|
||||
# scenario 5: context window error loop
|
||||
if self._is_stuck_context_window_error(filtered_history):
|
||||
return True
|
||||
if len(filtered_history) >= 10:
|
||||
if self._is_stuck_context_window_error(filtered_history):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -333,12 +334,12 @@ class StuckDetector:
|
||||
if isinstance(event, AgentCondensationObservation)
|
||||
]
|
||||
|
||||
# Need at least 3 condensation events to detect a loop
|
||||
if len(condensation_events) < 3:
|
||||
# Need at least 10 condensation events to detect a loop
|
||||
if len(condensation_events) < 10:
|
||||
return False
|
||||
|
||||
# Get the last 3 condensation events
|
||||
last_condensation_events = condensation_events[-3:]
|
||||
# Get the last 10 condensation events
|
||||
last_condensation_events = condensation_events[-10:]
|
||||
|
||||
# Check if there are any non-condensation events between them
|
||||
for i in range(len(last_condensation_events) - 1):
|
||||
|
||||
@@ -5,43 +5,43 @@ from typing import Any
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_types import (
|
||||
GhAuthenticationError,
|
||||
GHUnknownException,
|
||||
GitHubRepository,
|
||||
GitHubUser,
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
GitService,
|
||||
Repository,
|
||||
SuggestedTask,
|
||||
TaskType,
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class GitHubService:
|
||||
class GitHubService(GitService):
|
||||
BASE_URL = 'https://api.github.com'
|
||||
github_token: SecretStr = SecretStr('')
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
github_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.external_token_manager = external_token_manager
|
||||
|
||||
if github_token:
|
||||
self.github_token = github_token
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
async def _get_github_headers(self) -> dict:
|
||||
"""Retrieve the GH Token from settings store to construct the headers."""
|
||||
if self.user_id and not self.github_token:
|
||||
self.github_token = await self.get_latest_token()
|
||||
if self.user_id and not self.token:
|
||||
self.token = await self.get_latest_token()
|
||||
|
||||
return {
|
||||
'Authorization': f'Bearer {self.github_token.get_secret_value() if self.github_token else ""}',
|
||||
'Authorization': f'Bearer {self.token.get_secret_value() if self.token else ""}',
|
||||
'Accept': 'application/vnd.github.v3+json',
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ class GitHubService:
|
||||
return status_code == 401
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
return self.github_token
|
||||
return self.token
|
||||
|
||||
async def _fetch_data(
|
||||
self, url: str, params: dict | None = None
|
||||
@@ -74,20 +74,20 @@ class GitHubService:
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise GhAuthenticationError('Invalid Github token')
|
||||
raise AuthenticationError('Invalid Github token')
|
||||
|
||||
logger.warning(f'Status error on GH API: {e}')
|
||||
raise GHUnknownException('Unknown error')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f'HTTP error on GH API: {e}')
|
||||
raise GHUnknownException('Unknown error')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
async def get_user(self) -> GitHubUser:
|
||||
async def get_user(self) -> User:
|
||||
url = f'{self.BASE_URL}/user'
|
||||
response, _ = await self._fetch_data(url)
|
||||
|
||||
return GitHubUser(
|
||||
return User(
|
||||
id=response.get('id'),
|
||||
login=response.get('login'),
|
||||
avatar_url=response.get('avatar_url'),
|
||||
@@ -98,7 +98,7 @@ class GitHubService:
|
||||
|
||||
async def get_repositories(
|
||||
self, page: int, per_page: int, sort: str, installation_id: int | None
|
||||
) -> list[GitHubRepository]:
|
||||
) -> list[Repository]:
|
||||
params = {'page': str(page), 'per_page': str(per_page)}
|
||||
if installation_id:
|
||||
url = f'{self.BASE_URL}/user/installations/{installation_id}/repositories'
|
||||
@@ -111,7 +111,7 @@ class GitHubService:
|
||||
|
||||
next_link: str = headers.get('Link', '')
|
||||
repos = [
|
||||
GitHubRepository(
|
||||
Repository(
|
||||
id=repo.get('id'),
|
||||
full_name=repo.get('full_name'),
|
||||
stargazers_count=repo.get('stargazers_count'),
|
||||
@@ -129,7 +129,7 @@ class GitHubService:
|
||||
|
||||
async def search_repositories(
|
||||
self, query: str, per_page: int, sort: str, order: str
|
||||
) -> list[GitHubRepository]:
|
||||
) -> list[Repository]:
|
||||
url = f'{self.BASE_URL}/search/repositories'
|
||||
params = {'q': query, 'per_page': per_page, 'sort': sort, 'order': order}
|
||||
|
||||
@@ -137,7 +137,7 @@ class GitHubService:
|
||||
repos = response.get('items', [])
|
||||
|
||||
repos = [
|
||||
GitHubRepository(
|
||||
Repository(
|
||||
id=repo.get('id'),
|
||||
full_name=repo.get('full_name'),
|
||||
stargazers_count=repo.get('stargazers_count'),
|
||||
@@ -163,7 +163,7 @@ class GitHubService:
|
||||
|
||||
result = response.json()
|
||||
if 'errors' in result:
|
||||
raise GHUnknownException(
|
||||
raise UnknownException(
|
||||
f"GraphQL query error: {json.dumps(result['errors'])}"
|
||||
)
|
||||
|
||||
@@ -171,14 +171,14 @@ class GitHubService:
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise GhAuthenticationError('Invalid Github token')
|
||||
raise AuthenticationError('Invalid Github token')
|
||||
|
||||
logger.warning(f'Status error on GH API: {e}')
|
||||
raise GHUnknownException('Unknown error')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f'HTTP error on GH API: {e}')
|
||||
raise GHUnknownException('Unknown error')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
async def get_suggested_tasks(self) -> list[SuggestedTask]:
|
||||
"""Get suggested tasks for the authenticated user across all repositories.
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
MERGE_CONFLICTS = 'MERGE_CONFLICTS'
|
||||
FAILING_CHECKS = 'FAILING_CHECKS'
|
||||
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
|
||||
OPEN_ISSUE = 'OPEN_ISSUE'
|
||||
OPEN_PR = 'OPEN_PR'
|
||||
|
||||
|
||||
class SuggestedTask(BaseModel):
|
||||
task_type: TaskType
|
||||
repo: str
|
||||
issue_number: int
|
||||
title: str
|
||||
|
||||
|
||||
class GitHubUser(BaseModel):
|
||||
id: int
|
||||
login: str
|
||||
avatar_url: str
|
||||
company: str | None = None
|
||||
name: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class GitHubRepository(BaseModel):
|
||||
id: int
|
||||
full_name: str
|
||||
stargazers_count: int | None = None
|
||||
link_header: str | None = None
|
||||
|
||||
|
||||
class GhAuthenticationError(ValueError):
|
||||
"""Raised when there is an issue with GitHub authentication."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GHUnknownException(ValueError):
|
||||
"""Raised when there is an issue with GitHub communcation."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,119 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
GitService,
|
||||
Repository,
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class GitLabService(GitService):
|
||||
BASE_URL = 'https://gitlab.com/api/v4'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.external_token_manager = external_token_manager
|
||||
|
||||
if token:
|
||||
self.token = token
|
||||
|
||||
async def _get_gitlab_headers(self) -> dict:
|
||||
"""
|
||||
Retrieve the GitLab Token to construct the headers
|
||||
"""
|
||||
if self.user_id and not self.token:
|
||||
self.token = await self.get_latest_token()
|
||||
|
||||
return {
|
||||
'Authorization': f'Bearer {self.token.get_secret_value()}',
|
||||
}
|
||||
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
async def get_latest_token(self) -> SecretStr:
|
||||
return self.token
|
||||
|
||||
async def _fetch_data(
|
||||
self, url: str, params: dict | None = None
|
||||
) -> tuple[Any, dict]:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
response = await client.get(url, headers=gitlab_headers, params=params)
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
response = await client.get(
|
||||
url, headers=gitlab_headers, params=params
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
headers = {}
|
||||
if 'Link' in response.headers:
|
||||
headers['Link'] = response.headers['Link']
|
||||
|
||||
return response.json(), headers
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError('Invalid GitLab token')
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
except httpx.HTTPError:
|
||||
raise UnknownException('Unknown error')
|
||||
|
||||
async def get_user(self) -> User:
|
||||
url = f'{self.BASE_URL}/user'
|
||||
response, _ = await self._fetch_data(url)
|
||||
|
||||
return User(
|
||||
id=response.get('id'),
|
||||
username=response.get('username'),
|
||||
avatar_url=response.get('avatar_url'),
|
||||
name=response.get('name'),
|
||||
email=response.get('email'),
|
||||
company=response.get('organization'),
|
||||
login=response.get('username'),
|
||||
)
|
||||
|
||||
async def search_repositories(
|
||||
self, query: str, per_page: int = 30, sort: str = 'updated', order: str = 'desc'
|
||||
):
|
||||
url = f'{self.BASE_URL}/search'
|
||||
params = {
|
||||
'scope': 'projects',
|
||||
'search': query,
|
||||
'per_page': per_page,
|
||||
'order_by': sort,
|
||||
'sort': order,
|
||||
}
|
||||
response, headers = await self._fetch_data(url, params)
|
||||
return response, headers
|
||||
|
||||
async def get_repositories(
|
||||
self, page: int, per_page: int, sort: str, installation_id: int | None
|
||||
) -> list[Repository]:
|
||||
return []
|
||||
|
||||
|
||||
gitlab_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS',
|
||||
'openhands.integrations.gitlab.gitlab_service.GitLabService',
|
||||
)
|
||||
GitLabServiceImpl = get_impl(GitLabService, gitlab_service_cls)
|
||||
@@ -0,0 +1,143 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
GitService,
|
||||
Repository,
|
||||
User,
|
||||
)
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
GITHUB = 'github'
|
||||
GITLAB = 'gitlab'
|
||||
|
||||
|
||||
class ProviderToken(BaseModel):
|
||||
token: SecretStr | None
|
||||
user_id: str | None
|
||||
|
||||
|
||||
PROVIDER_TOKEN_TYPE = dict[ProviderType, ProviderToken]
|
||||
CUSTOM_SECRETS_TYPE = dict[str, SecretStr]
|
||||
|
||||
|
||||
class SecretStore(BaseModel):
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = {}
|
||||
|
||||
@classmethod
|
||||
def _convert_token(
|
||||
cls, token_value: str | ProviderToken | SecretStr
|
||||
) -> ProviderToken:
|
||||
if isinstance(token_value, ProviderToken):
|
||||
return token_value
|
||||
elif isinstance(token_value, str):
|
||||
return ProviderToken(token=SecretStr(token_value), user_id=None)
|
||||
elif isinstance(token_value, SecretStr):
|
||||
return ProviderToken(token=token_value, user_id=None)
|
||||
else:
|
||||
raise ValueError(f'Invalid token type: {type(token_value)}')
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
# Convert any string tokens to ProviderToken objects
|
||||
converted_tokens = {}
|
||||
for token_type, token_value in self.provider_tokens.items():
|
||||
if token_value: # Only convert non-empty tokens
|
||||
try:
|
||||
if isinstance(token_type, str):
|
||||
token_type = ProviderType(token_type)
|
||||
converted_tokens[token_type] = self._convert_token(token_value)
|
||||
except ValueError:
|
||||
# Skip invalid provider types or tokens
|
||||
continue
|
||||
self.provider_tokens = converted_tokens
|
||||
|
||||
@field_serializer('provider_tokens')
|
||||
def provider_tokens_serializer(
|
||||
self, provider_tokens: PROVIDER_TOKEN_TYPE, info: SerializationInfo
|
||||
):
|
||||
tokens = {}
|
||||
expose_secrets = info.context and info.context.get('expose_secrets', False)
|
||||
|
||||
for token_type, provider_token in provider_tokens.items():
|
||||
if not provider_token or not provider_token.token:
|
||||
continue
|
||||
|
||||
token_type_str = (
|
||||
token_type.value
|
||||
if isinstance(token_type, ProviderType)
|
||||
else str(token_type)
|
||||
)
|
||||
tokens[token_type_str] = {
|
||||
'token': provider_token.token.get_secret_value()
|
||||
if expose_secrets
|
||||
else pydantic_encoder(provider_token.token),
|
||||
'user_id': provider_token.user_id,
|
||||
}
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
class ProviderHandler:
|
||||
def __init__(
|
||||
self,
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
):
|
||||
self.service_class_map: dict[ProviderType, type[GitService]] = {
|
||||
ProviderType.GITHUB: GithubServiceImpl,
|
||||
ProviderType.GITLAB: GitLabServiceImpl,
|
||||
}
|
||||
|
||||
self.provider_tokens = provider_tokens
|
||||
self.external_auth_token = external_auth_token
|
||||
|
||||
def _get_service(self, provider: ProviderType) -> GitService:
|
||||
"""Helper method to instantiate a service for a given provider"""
|
||||
token = self.provider_tokens[provider]
|
||||
service_class = self.service_class_map[provider]
|
||||
return service_class(
|
||||
user_id=token.user_id,
|
||||
external_auth_token=self.external_auth_token,
|
||||
token=token.token,
|
||||
)
|
||||
|
||||
async def get_user(self) -> User:
|
||||
"""Get user information from the first available provider"""
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self._get_service(provider)
|
||||
return await service.get_user()
|
||||
except Exception:
|
||||
continue
|
||||
raise AuthenticationError('Need valid provider token')
|
||||
|
||||
async def get_latest_provider_tokens(self) -> dict[ProviderType, SecretStr]:
|
||||
"""Get latest token from services"""
|
||||
tokens = {}
|
||||
for provider in self.provider_tokens:
|
||||
service = self._get_service(provider)
|
||||
tokens[provider] = await service.get_latest_token()
|
||||
|
||||
return tokens
|
||||
|
||||
async def get_repositories(
|
||||
self, page: int, per_page: int, sort: str, installation_id: int | None
|
||||
) -> list[Repository]:
|
||||
"""Get repositories from all available providers"""
|
||||
all_repos = []
|
||||
for provider in self.provider_tokens:
|
||||
try:
|
||||
service = self._get_service(provider)
|
||||
repos = await service.get_repositories(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
all_repos.extend(repos)
|
||||
except Exception:
|
||||
continue
|
||||
return all_repos
|
||||
@@ -0,0 +1,89 @@
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
MERGE_CONFLICTS = 'MERGE_CONFLICTS'
|
||||
FAILING_CHECKS = 'FAILING_CHECKS'
|
||||
UNRESOLVED_COMMENTS = 'UNRESOLVED_COMMENTS'
|
||||
OPEN_ISSUE = 'OPEN_ISSUE'
|
||||
OPEN_PR = 'OPEN_PR'
|
||||
|
||||
|
||||
class SuggestedTask(BaseModel):
|
||||
task_type: TaskType
|
||||
repo: str
|
||||
issue_number: int
|
||||
title: str
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: int
|
||||
login: str
|
||||
avatar_url: str
|
||||
company: str | None = None
|
||||
name: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class Repository(BaseModel):
|
||||
id: int
|
||||
full_name: str
|
||||
stargazers_count: int | None = None
|
||||
link_header: str | None = None
|
||||
|
||||
|
||||
class AuthenticationError(ValueError):
|
||||
"""Raised when there is an issue with GitHub authentication."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnknownException(ValueError):
|
||||
"""Raised when there is an issue with GitHub communcation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GitService(Protocol):
|
||||
"""Protocol defining the interface for Git service providers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None,
|
||||
token: SecretStr | None,
|
||||
external_auth_token: SecretStr | None,
|
||||
external_token_manager: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the service with authentication details"""
|
||||
...
|
||||
|
||||
async def get_latest_token(self) -> SecretStr:
|
||||
"""Get latest working token of the users"""
|
||||
...
|
||||
|
||||
async def get_user(self) -> User:
|
||||
"""Get the authenticated user's information"""
|
||||
...
|
||||
|
||||
async def search_repositories(
|
||||
self,
|
||||
query: str,
|
||||
per_page: int,
|
||||
sort: str,
|
||||
order: str,
|
||||
) -> list[Repository]:
|
||||
"""Search for repositories"""
|
||||
...
|
||||
|
||||
async def get_repositories(
|
||||
self,
|
||||
page: int,
|
||||
per_page: int,
|
||||
sort: str,
|
||||
installation_id: int | None,
|
||||
) -> list[Repository]:
|
||||
"""Get repositories for the authenticated user"""
|
||||
...
|
||||
@@ -0,0 +1,37 @@
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.github.github_service import GitHubService
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabService
|
||||
from openhands.integrations.provider import ProviderType
|
||||
|
||||
|
||||
async def validate_provider_token(token: SecretStr) -> ProviderType | None:
|
||||
"""
|
||||
Determine whether a token is for GitHub or GitLab by attempting to get user info
|
||||
from both services.
|
||||
|
||||
Args:
|
||||
token: The token to check
|
||||
|
||||
Returns:
|
||||
'github' if it's a GitHub token
|
||||
'gitlab' if it's a GitLab token
|
||||
None if the token is invalid for both services
|
||||
"""
|
||||
# Try GitHub first
|
||||
try:
|
||||
github_service = GitHubService(token=token)
|
||||
await github_service.get_user()
|
||||
return ProviderType.GITHUB
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try GitLab next
|
||||
try:
|
||||
gitlab_service = GitLabService(token=token)
|
||||
await gitlab_service.get_user()
|
||||
return ProviderType.GITLAB
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
from openhands.llm import LLM
|
||||
@@ -90,13 +91,10 @@ INTENT: Fix precision while maintaining FITS compliance"""
|
||||
for forgotten_event in forgotten_events:
|
||||
prompt += str(forgotten_event) + '\n\n'
|
||||
|
||||
messages = [Message(role='user', content=[TextContent(text=prompt)])]
|
||||
|
||||
response = self.llm.completion(
|
||||
messages=[
|
||||
{
|
||||
'content': prompt,
|
||||
'role': 'user',
|
||||
},
|
||||
],
|
||||
messages=self.llm.format_messages_for_llm(messages),
|
||||
)
|
||||
summary = response.choices[0].message.content
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ class ActionExecutor:
|
||||
self.bash_session: BashSession | None = None
|
||||
self.lock = asyncio.Lock()
|
||||
self.plugins: dict[str, Plugin] = {}
|
||||
self.file_editor = OHEditor()
|
||||
self.file_editor = OHEditor(workspace_root=self._initial_cwd)
|
||||
self.browser = BrowserEnv(browsergym_eval_env)
|
||||
self.start_time = time.time()
|
||||
self.last_execution_time = self.start_time
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
|
||||
|
||||
def get_provider_tokens(request: Request) -> PROVIDER_TOKEN_TYPE | None:
|
||||
"""Get GitHub token from request state. For backward compatibility."""
|
||||
return getattr(request.state, 'provider_tokens', None)
|
||||
|
||||
|
||||
def get_access_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'access_token', None)
|
||||
@@ -11,8 +18,18 @@ def get_user_id(request: Request) -> str | None:
|
||||
|
||||
|
||||
def get_github_token(request: Request) -> SecretStr | None:
|
||||
return getattr(request.state, 'github_token', None)
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_github_user_id(request: Request) -> str | None:
|
||||
return getattr(request.state, 'github_user_id', None)
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
return provider_tokens[ProviderType.GITHUB].user_id
|
||||
|
||||
return None
|
||||
|
||||
@@ -194,10 +194,14 @@ class GitHubTokenMiddleware(SessionMiddlewareInterface):
|
||||
settings = await settings_store.load()
|
||||
|
||||
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
|
||||
if getattr(request.state, 'github_token', None) is None:
|
||||
if settings and settings.github_token:
|
||||
request.state.github_token = settings.github_token
|
||||
if getattr(request.state, 'provider_tokens', None) is None:
|
||||
if (
|
||||
settings
|
||||
and settings.secrets_store
|
||||
and settings.secrets_store.provider_tokens
|
||||
):
|
||||
request.state.provider_tokens = settings.secrets_store.provider_tokens
|
||||
else:
|
||||
request.state.github_token = None
|
||||
request.state.provider_tokens = None
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
+148
-120
@@ -3,147 +3,168 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.github.github_types import (
|
||||
GhAuthenticationError,
|
||||
GHUnknownException,
|
||||
GitHubRepository,
|
||||
GitHubUser,
|
||||
SuggestedTask,
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
ProviderType,
|
||||
)
|
||||
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
Repository,
|
||||
SuggestedTask,
|
||||
UnknownException,
|
||||
User,
|
||||
)
|
||||
from openhands.server.auth import get_access_token, get_provider_tokens
|
||||
|
||||
app = APIRouter(prefix='/api/github')
|
||||
|
||||
|
||||
@app.get('/repositories', response_model=list[GitHubRepository])
|
||||
@app.get('/repositories', response_model=list[Repository])
|
||||
async def get_github_repositories(
|
||||
page: int = 1,
|
||||
per_page: int = 10,
|
||||
sort: str = 'pushed',
|
||||
installation_id: int | None = None,
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
client = GithubServiceImpl(
|
||||
user_id=token.user_id, external_auth_token=access_token, token=token.token
|
||||
)
|
||||
|
||||
try:
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
return repos
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.get_repositories(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
return repos
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/user', response_model=GitHubUser)
|
||||
@app.get('/user', response_model=User)
|
||||
async def get_github_user(
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
if provider_tokens:
|
||||
client = ProviderHandler(provider_tokens=provider_tokens, external_auth_token=access_token)
|
||||
|
||||
try:
|
||||
user: User = await client.get_user()
|
||||
return user
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
user: GitHubUser = await client.get_user()
|
||||
return user
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/installations', response_model=list[int])
|
||||
async def get_github_installation_ids(
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
|
||||
client = GithubServiceImpl(
|
||||
user_id=token.user_id, external_auth_token=access_token, token=token.token
|
||||
)
|
||||
try:
|
||||
installations_ids: list[int] = await client.get_installation_ids()
|
||||
return installations_ids
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
installations_ids: list[int] = await client.get_installation_ids()
|
||||
return installations_ids
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/search/repositories', response_model=list[GitHubRepository])
|
||||
@app.get('/search/repositories', response_model=list[Repository])
|
||||
async def search_github_repositories(
|
||||
query: str,
|
||||
per_page: int = 5,
|
||||
sort: str = 'stars',
|
||||
order: str = 'desc',
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
):
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
|
||||
client = GithubServiceImpl(
|
||||
user_id=token.user_id, external_auth_token=access_token, token=token.token
|
||||
)
|
||||
try:
|
||||
repos: list[Repository] = await client.search_repositories(
|
||||
query, per_page, sort, order
|
||||
)
|
||||
return repos
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
repos: list[GitHubRepository] = await client.search_repositories(
|
||||
query, per_page, sort, order
|
||||
)
|
||||
return repos
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@app.get('/suggested-tasks', response_model=list[SuggestedTask])
|
||||
async def get_suggested_tasks(
|
||||
github_user_id: str | None = Depends(get_github_user_id),
|
||||
github_user_token: SecretStr | None = Depends(get_github_token),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token)
|
||||
):
|
||||
"""Get suggested tasks for the authenticated user across their most recently pushed repositories.
|
||||
|
||||
@@ -151,23 +172,30 @@ async def get_suggested_tasks(
|
||||
- PRs owned by the user
|
||||
- Issues assigned to the user.
|
||||
"""
|
||||
client = GithubServiceImpl(
|
||||
user_id=github_user_id,
|
||||
external_auth_token=access_token,
|
||||
github_token=github_user_token,
|
||||
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
|
||||
client = GithubServiceImpl(
|
||||
user_id=token.user_id, external_auth_token=access_token, token=token.token
|
||||
)
|
||||
try:
|
||||
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
|
||||
return tasks
|
||||
|
||||
except AuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
except UnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content='GitHub token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
try:
|
||||
tasks: list[SuggestedTask] = await client.get_suggested_tasks()
|
||||
return tasks
|
||||
|
||||
except GhAuthenticationError as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
except GHUnknownException as e:
|
||||
return JSONResponse(
|
||||
content=str(e),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
@@ -8,8 +8,9 @@ from pydantic import BaseModel, SecretStr
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_access_token, get_github_token, get_github_user_id
|
||||
from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@@ -136,13 +137,18 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
using the returned conversation ID.
|
||||
"""
|
||||
logger.info('Initializing new conversation')
|
||||
user_id = get_github_user_id(request)
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=user_id,
|
||||
external_auth_token=get_access_token(request),
|
||||
github_token=get_github_token(request),
|
||||
)
|
||||
github_token = await gh_client.get_latest_token()
|
||||
user_id = None
|
||||
github_token = None
|
||||
provider_tokens = get_provider_tokens(request)
|
||||
if provider_tokens and ProviderType.GITHUB in provider_tokens:
|
||||
token = provider_tokens[ProviderType.GITHUB]
|
||||
user_id = token.user_id
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=user_id,
|
||||
external_auth_token=get_access_token(request),
|
||||
token=token.token,
|
||||
)
|
||||
github_token = await gh_client.get_latest_token()
|
||||
|
||||
selected_repository = data.selected_repository
|
||||
selected_branch = data.selected_branch
|
||||
|
||||
@@ -3,8 +3,9 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.server.auth import get_github_token, get_user_id
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.utils import validate_provider_token
|
||||
from openhands.server.auth import get_provider_tokens, get_user_id
|
||||
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
|
||||
from openhands.server.shared import SettingsStoreImpl, config
|
||||
|
||||
@@ -23,14 +24,14 @@ async def load_settings(request: Request) -> GETSettingsModel | JSONResponse:
|
||||
content={'error': 'Settings not found'},
|
||||
)
|
||||
|
||||
token_is_set = bool(user_id) or bool(get_github_token(request))
|
||||
github_token_is_set = bool(user_id) or bool(get_provider_tokens(request))
|
||||
settings_with_token_data = GETSettingsModel(
|
||||
**settings.model_dump(),
|
||||
github_token_is_set=token_is_set,
|
||||
github_token_is_set=github_token_is_set,
|
||||
)
|
||||
settings_with_token_data.llm_api_key = settings.llm_api_key
|
||||
|
||||
del settings_with_token_data.github_token
|
||||
del settings_with_token_data.secrets_store
|
||||
return settings_with_token_data
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid token: {e}')
|
||||
@@ -45,26 +46,27 @@ async def store_settings(
|
||||
request: Request,
|
||||
settings: POSTSettingsModel,
|
||||
) -> JSONResponse:
|
||||
# Check if token is valid
|
||||
if settings.github_token:
|
||||
try:
|
||||
# We check if the token is valid by getting the user
|
||||
# If the token is invalid, this will raise an exception
|
||||
github = GithubServiceImpl(
|
||||
user_id=None,
|
||||
external_auth_token=None,
|
||||
github_token=SecretStr(settings.github_token),
|
||||
)
|
||||
await github.get_user()
|
||||
# Check provider tokens are valid
|
||||
if settings.provider_tokens:
|
||||
# Remove extraneous token types
|
||||
provider_types = [provider.value for provider in ProviderType]
|
||||
settings.provider_tokens = {
|
||||
k: v for k, v in settings.provider_tokens.items() if k in provider_types
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid GitHub token: {e}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': 'Invalid GitHub token. Please make sure it is valid.'
|
||||
},
|
||||
)
|
||||
# Determine whether tokens are valid
|
||||
for token_type, token_value in settings.provider_tokens.items():
|
||||
if token_value:
|
||||
confirmed_token_type = await validate_provider_token(
|
||||
SecretStr(token_value)
|
||||
)
|
||||
if not confirmed_token_type or confirmed_token_type.value != token_type:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Invalid token. Please make sure it is a valid {token_type} token.'
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
@@ -72,32 +74,46 @@ async def store_settings(
|
||||
)
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
# LLM key isn't on the frontend, so we need to keep it if unset
|
||||
# Keep existing LLM settings if not provided
|
||||
if settings.llm_api_key is None:
|
||||
settings.llm_api_key = existing_settings.llm_api_key
|
||||
if settings.llm_model is None:
|
||||
settings.llm_model = existing_settings.llm_model
|
||||
if settings.llm_base_url is None:
|
||||
settings.llm_base_url = existing_settings.llm_base_url
|
||||
|
||||
if settings.github_token is None:
|
||||
settings.github_token = existing_settings.github_token
|
||||
|
||||
# Keep existing analytics consent if not provided
|
||||
if settings.user_consents_to_analytics is None:
|
||||
settings.user_consents_to_analytics = (
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
if settings.llm_model is None:
|
||||
settings.llm_model = existing_settings.llm_model
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
if settings.llm_base_url is None:
|
||||
settings.llm_base_url = existing_settings.llm_base_url
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in settings.provider_tokens.items():
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Settings stored'},
|
||||
)
|
||||
|
||||
if settings.unset_github_token:
|
||||
settings.github_token = None
|
||||
# Merge provider tokens with existing ones
|
||||
if settings.unset_github_token: # Only merge if not unsetting tokens
|
||||
settings.secrets_store.provider_tokens = {}
|
||||
settings.provider_tokens = {}
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
@@ -106,9 +122,11 @@ async def store_settings(
|
||||
)
|
||||
|
||||
settings = convert_to_settings(settings)
|
||||
|
||||
await settings_store.store(settings)
|
||||
return response
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Settings stored'},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Something went wrong storing settings: {e}')
|
||||
return JSONResponse(
|
||||
@@ -127,8 +145,19 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
|
||||
if key in Settings.model_fields # Ensures only `Settings` fields are included
|
||||
}
|
||||
|
||||
# Convert the `llm_api_key` and `github_token` to a `SecretStr` instance
|
||||
# Convert the `llm_api_key` to a `SecretStr` instance
|
||||
filtered_settings_data['llm_api_key'] = settings_with_token_data.llm_api_key
|
||||
filtered_settings_data['github_token'] = settings_with_token_data.github_token
|
||||
|
||||
return Settings(**filtered_settings_data)
|
||||
# Create a new Settings instance without provider tokens
|
||||
settings = Settings(**filtered_settings_data)
|
||||
|
||||
# Update provider tokens if any are provided
|
||||
if settings_with_token_data.provider_tokens:
|
||||
for token_type, token_value in settings_with_token_data.provider_tokens.items():
|
||||
if token_value:
|
||||
provider = ProviderType(token_type)
|
||||
settings.secrets_store.provider_tokens[provider] = ProviderToken(
|
||||
token=SecretStr(token_value), user_id=None
|
||||
)
|
||||
|
||||
return settings
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
SecretStr,
|
||||
SerializationInfo,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.utils import load_app_config
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType, SecretStore
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
@@ -21,7 +28,7 @@ class Settings(BaseModel):
|
||||
llm_api_key: SecretStr | None = None
|
||||
llm_base_url: str | None = None
|
||||
remote_runtime_resource_factor: int | None = None
|
||||
github_token: SecretStr | None = None
|
||||
secrets_store: SecretStore = SecretStore()
|
||||
enable_default_condenser: bool = False
|
||||
enable_sound_notifications: bool = False
|
||||
user_consents_to_analytics: bool | None = None
|
||||
@@ -38,22 +45,63 @@ class Settings(BaseModel):
|
||||
|
||||
return pydantic_encoder(llm_api_key)
|
||||
|
||||
@field_serializer('github_token')
|
||||
def github_token_serializer(
|
||||
self, github_token: SecretStr | None, info: SerializationInfo
|
||||
):
|
||||
"""Custom serializer for the GitHub token.
|
||||
@staticmethod
|
||||
def _convert_token_value(
|
||||
token_type: ProviderType, token_value: str | dict
|
||||
) -> ProviderToken | None:
|
||||
"""Convert a token value to a ProviderToken object."""
|
||||
if isinstance(token_value, dict):
|
||||
token_str = token_value.get('token')
|
||||
if not token_str:
|
||||
return None
|
||||
return ProviderToken(
|
||||
token=SecretStr(token_str),
|
||||
user_id=token_value.get('user_id'),
|
||||
)
|
||||
if isinstance(token_value, str) and token_value:
|
||||
return ProviderToken(token=SecretStr(token_value), user_id=None)
|
||||
return None
|
||||
|
||||
To serialize the token instead of ********, set expose_secrets to True in the serialization context.
|
||||
"""
|
||||
if github_token is None:
|
||||
return None
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def convert_provider_tokens(cls, data: dict | object) -> dict | object:
|
||||
"""Convert provider tokens from JSON format to SecretStore format."""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
context = info.context
|
||||
if context and context.get('expose_secrets', False):
|
||||
return github_token.get_secret_value()
|
||||
secrets_store = data.get('secrets_store')
|
||||
if not isinstance(secrets_store, dict):
|
||||
return data
|
||||
|
||||
return pydantic_encoder(github_token)
|
||||
tokens = secrets_store.get('provider_tokens')
|
||||
if not isinstance(tokens, dict):
|
||||
return data
|
||||
|
||||
converted_tokens = {}
|
||||
for token_type_str, token_value in tokens.items():
|
||||
if not token_value:
|
||||
continue
|
||||
|
||||
try:
|
||||
token_type = ProviderType(token_type_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
provider_token = cls._convert_token_value(token_type, token_value)
|
||||
if provider_token:
|
||||
converted_tokens[token_type] = provider_token
|
||||
|
||||
data['secrets_store'] = SecretStore(provider_tokens=converted_tokens)
|
||||
return data
|
||||
|
||||
@field_serializer('secrets_store')
|
||||
def secrets_store_serializer(self, secrets: SecretStore, info: SerializationInfo):
|
||||
"""Custom serializer for secrets store."""
|
||||
return {
|
||||
'provider_tokens': secrets.provider_tokens_serializer(
|
||||
secrets.provider_tokens, info
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_config() -> Settings | None:
|
||||
@@ -73,7 +121,7 @@ class Settings(BaseModel):
|
||||
llm_api_key=llm_config.api_key,
|
||||
llm_base_url=llm_config.base_url,
|
||||
remote_runtime_resource_factor=app_config.sandbox.remote_runtime_resource_factor,
|
||||
github_token=None,
|
||||
provider_tokens={},
|
||||
)
|
||||
return settings
|
||||
|
||||
@@ -84,14 +132,12 @@ class POSTSettingsModel(Settings):
|
||||
"""
|
||||
|
||||
unset_github_token: bool | None = None
|
||||
github_token: str | None = (
|
||||
None # This is a string because it's coming from the frontend
|
||||
)
|
||||
# Override provider_tokens to accept string tokens from frontend
|
||||
provider_tokens: dict[str, str] = {}
|
||||
|
||||
# Override the serializer for the GitHub token to handle the string input
|
||||
@field_serializer('github_token')
|
||||
def github_token_serializer(self, github_token: str | None):
|
||||
return github_token
|
||||
@field_serializer('provider_tokens')
|
||||
def provider_tokens_serializer(self, provider_tokens: dict[str, str]):
|
||||
return provider_tokens
|
||||
|
||||
|
||||
class GETSettingsModel(Settings):
|
||||
|
||||
Generated
+36
-30
@@ -706,14 +706,14 @@ virtualenv = ["virtualenv (>=20.0.35)"]
|
||||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "5.5.1"
|
||||
version = "5.5.2"
|
||||
description = "Extensible memoizing collections and decorators"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "evaluation"]
|
||||
files = [
|
||||
{file = "cachetools-5.5.1-py3-none-any.whl", hash = "sha256:b76651fdc3b24ead3c648bbdeeb940c1b04d365b38b4af66788f9ec4a81d42bb"},
|
||||
{file = "cachetools-5.5.1.tar.gz", hash = "sha256:70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95"},
|
||||
{file = "cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a"},
|
||||
{file = "cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2199,14 +2199,14 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
|
||||
|
||||
[[package]]
|
||||
name = "google-api-python-client"
|
||||
version = "2.163.0"
|
||||
version = "2.164.0"
|
||||
description = "Google API Client Library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "google_api_python_client-2.163.0-py2.py3-none-any.whl", hash = "sha256:080e8bc0669cb4c1fb8efb8da2f5b91a2625d8f0e7796cfad978f33f7016c6c4"},
|
||||
{file = "google_api_python_client-2.163.0.tar.gz", hash = "sha256:88dee87553a2d82176e2224648bf89272d536c8f04dcdda37ef0a71473886dd7"},
|
||||
{file = "google_api_python_client-2.164.0-py2.py3-none-any.whl", hash = "sha256:b2037c3d280793c8d5180b04317b16be4acd5f77af5dfa7213ace32d140a9ffe"},
|
||||
{file = "google_api_python_client-2.164.0.tar.gz", hash = "sha256:116f5a05dfb95ed7f7ea0d0f561fc5464146709c583226cc814690f9bb221492"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3808,14 +3808,14 @@ types-tqdm = "*"
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.63.6"
|
||||
version = "1.63.8"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
optional = false
|
||||
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "litellm-1.63.6-py3-none-any.whl", hash = "sha256:68d3d1f6c062851702e3f7ad99a204cf2b67f8b79d8608f22b72664e7de641c6"},
|
||||
{file = "litellm-1.63.6.tar.gz", hash = "sha256:b8fd6eca1f17f7d1101f38a90689b5f1a2f42a828c70299e0cb570297e5fb9ae"},
|
||||
{file = "litellm-1.63.8-py3-none-any.whl", hash = "sha256:12615acf16d34b444e13cb9faab89466f63a22330e72e30c7d35e12ebd526188"},
|
||||
{file = "litellm-1.63.8.tar.gz", hash = "sha256:ae7324fb93a0da2dfd05f8fa301c3ac20dfce05d4651bdb005aeb64c88a76672"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3825,7 +3825,7 @@ httpx = ">=0.23.0"
|
||||
importlib-metadata = ">=6.8.0"
|
||||
jinja2 = ">=3.1.2,<4.0.0"
|
||||
jsonschema = ">=4.22.0,<5.0.0"
|
||||
openai = ">=1.61.0"
|
||||
openai = ">=1.66.1"
|
||||
pydantic = ">=2.0.0,<3.0.0"
|
||||
python-dotenv = ">=0.2.0"
|
||||
tiktoken = ">=0.7.0"
|
||||
@@ -4251,14 +4251,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "modal"
|
||||
version = "0.73.93"
|
||||
version = "0.73.98"
|
||||
description = "Python client library for Modal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main", "evaluation"]
|
||||
files = [
|
||||
{file = "modal-0.73.93-py3-none-any.whl", hash = "sha256:ceec0b456c8332b956c607adf67c0aee7df53f5d7985acafede1616928422a69"},
|
||||
{file = "modal-0.73.93.tar.gz", hash = "sha256:4b8dc338172edbedb85b94c5588169f1a1d44ec7e6765b5843fea405e599dcdd"},
|
||||
{file = "modal-0.73.98-py3-none-any.whl", hash = "sha256:a49cd5f5b46d1a6c6a0d528618d3cbb73ac2908e199716590ec3a5275d79ed98"},
|
||||
{file = "modal-0.73.98.tar.gz", hash = "sha256:817f73c222fa39a16d6888a92eb7a6847ecae574e44ef04e2dce5e534bdd2df9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4794,14 +4794,14 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.66.2"
|
||||
version = "1.66.3"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "evaluation", "test"]
|
||||
files = [
|
||||
{file = "openai-1.66.2-py3-none-any.whl", hash = "sha256:75194057ee6bb8b732526387b6041327a05656d976fc21c064e21c8ac6b07999"},
|
||||
{file = "openai-1.66.2.tar.gz", hash = "sha256:9b3a843c25f81ee09b6469d483d9fba779d5c6ea41861180772f043481b0598d"},
|
||||
{file = "openai-1.66.3-py3-none-any.whl", hash = "sha256:a427c920f727711877ab17c11b95f1230b27767ba7a01e5b66102945141ceca9"},
|
||||
{file = "openai-1.66.3.tar.gz", hash = "sha256:8dde3aebe2d081258d4159c4cb27bdc13b5bb3f7ea2201d9bd940b9a89faf0c9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4820,18 +4820,18 @@ realtime = ["websockets (>=13,<15)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-aci"
|
||||
version = "0.2.5"
|
||||
version = "0.2.6"
|
||||
description = "An Agent-Computer Interface (ACI) designed for software development agents OpenHands."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.12"
|
||||
python-versions = "^3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_aci-0.2.5-py3-none-any.whl", hash = "sha256:775a3ea9eacf090ff6fa6819dcc449a359a770f2d25232890441a799b0bd3c2e"},
|
||||
{file = "openhands_aci-0.2.5.tar.gz", hash = "sha256:cfa51834771fb7f35cc754f04ee3b6d8d985df79a6fa4bdd0f57a8a20e9f0883"},
|
||||
]
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
binaryornot = ">=0.4.4,<0.5.0"
|
||||
binaryornot = "^0.4.4"
|
||||
cachetools = "^5.5.2"
|
||||
chardet = "^5.0.0"
|
||||
flake8 = "*"
|
||||
gitpython = "*"
|
||||
grep-ast = "0.3.3"
|
||||
@@ -4840,12 +4840,18 @@ networkx = "*"
|
||||
numpy = "*"
|
||||
pandas = "*"
|
||||
scipy = "*"
|
||||
tree-sitter = ">=0.24.0,<0.25.0"
|
||||
tree-sitter-javascript = ">=0.23.1,<0.24.0"
|
||||
tree-sitter-python = ">=0.23.6,<0.24.0"
|
||||
tree-sitter-ruby = ">=0.23.1,<0.24.0"
|
||||
tree-sitter-typescript = ">=0.23.2,<0.24.0"
|
||||
whatthepatch = ">=1.0.6,<2.0.0"
|
||||
tree-sitter = "^0.24.0"
|
||||
tree-sitter-javascript = "^0.23.1"
|
||||
tree-sitter-python = "^0.23.6"
|
||||
tree-sitter-ruby = "^0.23.1"
|
||||
tree-sitter-typescript = "^0.23.2"
|
||||
whatthepatch = "^1.0.6"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/All-Hands-AI/openhands-aci.git"
|
||||
reference = "add-encoding-detection"
|
||||
resolved_reference = "040d9578d90894409f51ecca877b120fe696fe0b"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
@@ -9050,4 +9056,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "3a03748738e9eaf8e82b1a672cbf64a611d1ae7ecf57ebc71eb1a5113c3f7440"
|
||||
content-hash = "6a644bc65782a717a49718496bd279ecb888807ec625d992af4448cc5d9271c1"
|
||||
|
||||
+1
-1
@@ -67,7 +67,7 @@ runloop-api-client = "0.26.0"
|
||||
libtmux = ">=0.37,<0.40"
|
||||
pygithub = "^2.5.0"
|
||||
joblib = "*"
|
||||
openhands-aci = "^0.2.5"
|
||||
openhands-aci = "^0.2.6"
|
||||
python-socketio = "^5.11.4"
|
||||
redis = "^5.2.0"
|
||||
sse-starlette = "^2.1.3"
|
||||
|
||||
@@ -15,6 +15,7 @@ from openhands.core.config.condenser_config import (
|
||||
RecentEventsCondenserConfig,
|
||||
)
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation import BrowserOutputObservation
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
@@ -80,6 +81,10 @@ def mock_llm() -> LLM:
|
||||
# Attach helper methods to the mock object
|
||||
mock_llm.set_mock_response_content = set_mock_response_content
|
||||
|
||||
mock_llm.format_messages_for_llm = lambda events: [
|
||||
Message(role='user', content=[TextContent(text=str(event))]) for event in events
|
||||
]
|
||||
|
||||
return mock_llm
|
||||
|
||||
|
||||
|
||||
@@ -5,16 +5,16 @@ import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.github.github_service import GitHubService
|
||||
from openhands.integrations.github.github_types import GhAuthenticationError
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_service_token_handling():
|
||||
# Test initialization with SecretStr token
|
||||
token = SecretStr('test-token')
|
||||
service = GitHubService(user_id=None, github_token=token)
|
||||
assert service.github_token == token
|
||||
assert service.github_token.get_secret_value() == 'test-token'
|
||||
service = GitHubService(user_id=None, token=token)
|
||||
assert service.token == token
|
||||
assert service.token.get_secret_value() == 'test-token'
|
||||
|
||||
# Test headers contain the token correctly
|
||||
headers = await service._get_github_headers()
|
||||
@@ -23,14 +23,14 @@ async def test_github_service_token_handling():
|
||||
|
||||
# Test initialization without token
|
||||
service = GitHubService(user_id='test-user')
|
||||
assert service.github_token == SecretStr('')
|
||||
assert service.token == SecretStr('')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_service_token_refresh():
|
||||
# Test that token refresh is only attempted when refresh=True
|
||||
token = SecretStr('test-token')
|
||||
service = GitHubService(user_id=None, github_token=token)
|
||||
service = GitHubService(user_id=None, token=token)
|
||||
assert not service.refresh
|
||||
|
||||
# Test token expiry detection
|
||||
@@ -58,7 +58,7 @@ async def test_github_service_fetch_data():
|
||||
mock_client.__aexit__.return_value = None
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
service = GitHubService(user_id=None, github_token=SecretStr('test-token'))
|
||||
service = GitHubService(user_id=None, token=SecretStr('test-token'))
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
|
||||
# Verify the request was made with correct headers
|
||||
@@ -77,5 +77,5 @@ async def test_github_service_fetch_data():
|
||||
mock_client.get.reset_mock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with pytest.raises(GhAuthenticationError):
|
||||
with pytest.raises(AuthenticationError):
|
||||
_ = await service._fetch_data('https://api.github.com/user')
|
||||
|
||||
+48
-63
@@ -614,8 +614,8 @@ class TestStuckDetector:
|
||||
message_observation = NullObservation(content='')
|
||||
state.history.append(message_observation)
|
||||
|
||||
# Add three consecutive condensation events (should detect as stuck)
|
||||
for _ in range(3):
|
||||
# Add ten consecutive condensation events (should detect as stuck)
|
||||
for _ in range(10):
|
||||
condensation = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
@@ -638,42 +638,39 @@ class TestStuckDetector:
|
||||
message_observation = NullObservation(content='')
|
||||
state.history.append(message_observation)
|
||||
|
||||
# Add condensation events with other events between them
|
||||
condensation1 = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
state.history.append(condensation1)
|
||||
# Add 10 condensation events with other events between them
|
||||
for i in range(10):
|
||||
# Add a condensation event
|
||||
condensation = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
state.history.append(condensation)
|
||||
|
||||
# Add some other events between condensation events
|
||||
cmd_action = CmdRunAction(command='ls')
|
||||
state.history.append(cmd_action)
|
||||
cmd_observation = CmdOutputObservation(
|
||||
command='ls', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
state.history.append(cmd_observation)
|
||||
# Add some other events between condensation events (except after the last one)
|
||||
if i < 9:
|
||||
# Add a command action and observation
|
||||
cmd_action = CmdRunAction(command=f'ls {i}')
|
||||
state.history.append(cmd_action)
|
||||
cmd_observation = CmdOutputObservation(
|
||||
command=f'ls {i}', content='file1.txt\nfile2.txt'
|
||||
)
|
||||
state.history.append(cmd_observation)
|
||||
|
||||
condensation2 = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
state.history.append(condensation2)
|
||||
|
||||
# Add more other events
|
||||
read_action = FileReadAction(path='file1.txt')
|
||||
state.history.append(read_action)
|
||||
read_observation = FileReadObservation(content='File content', path='file1.txt')
|
||||
state.history.append(read_observation)
|
||||
|
||||
condensation3 = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
state.history.append(condensation3)
|
||||
# Add a file read action and observation for even iterations
|
||||
if i % 2 == 0:
|
||||
read_action = FileReadAction(path=f'file{i}.txt')
|
||||
state.history.append(read_action)
|
||||
read_observation = FileReadObservation(
|
||||
content=f'File content {i}', path=f'file{i}.txt'
|
||||
)
|
||||
state.history.append(read_observation)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck(headless_mode=True) is False
|
||||
mock_warning.assert_not_called()
|
||||
|
||||
def test_is_not_stuck_context_window_error_less_than_three(self, stuck_detector):
|
||||
"""Test that we don't detect a loop with less than three condensation events."""
|
||||
def test_is_not_stuck_context_window_error_less_than_ten(self, stuck_detector):
|
||||
"""Test that we don't detect a loop with less than ten condensation events."""
|
||||
state = stuck_detector.state
|
||||
|
||||
# Add some initial events
|
||||
@@ -683,8 +680,8 @@ class TestStuckDetector:
|
||||
message_observation = NullObservation(content='')
|
||||
state.history.append(message_observation)
|
||||
|
||||
# Add only two condensation events (should not detect as stuck)
|
||||
for _ in range(2):
|
||||
# Add only nine condensation events (should not detect as stuck)
|
||||
for _ in range(9):
|
||||
condensation = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
@@ -695,7 +692,7 @@ class TestStuckDetector:
|
||||
mock_warning.assert_not_called()
|
||||
|
||||
def test_is_stuck_context_window_error_with_user_messages(self, stuck_detector):
|
||||
"""Test that we still detect a loop even with user messages between condensation events.
|
||||
"""Test that we still detect a loop even with user messages between condensation events in headless mode.
|
||||
|
||||
User messages are filtered out in the stuck detection logic, so they shouldn't
|
||||
prevent us from detecting a loop of condensation events.
|
||||
@@ -709,35 +706,23 @@ class TestStuckDetector:
|
||||
message_observation = NullObservation(content='')
|
||||
state.history.append(message_observation)
|
||||
|
||||
# Add condensation events with user messages between them
|
||||
condensation1 = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
state.history.append(condensation1)
|
||||
# Add condensation events with user messages between them (total of 10)
|
||||
for i in range(10):
|
||||
# Add a condensation event
|
||||
condensation = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
state.history.append(condensation)
|
||||
|
||||
# Add user message between condensation events
|
||||
user_message = MessageAction(content='Please continue', wait_for_response=False)
|
||||
user_message._source = EventSource.USER
|
||||
state.history.append(user_message)
|
||||
user_observation = NullObservation(content='')
|
||||
state.history.append(user_observation)
|
||||
|
||||
condensation2 = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
state.history.append(condensation2)
|
||||
|
||||
# Add another user message
|
||||
user_message2 = MessageAction(content='Keep going', wait_for_response=False)
|
||||
user_message2._source = EventSource.USER
|
||||
state.history.append(user_message2)
|
||||
user_observation2 = NullObservation(content='')
|
||||
state.history.append(user_observation2)
|
||||
|
||||
condensation3 = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
state.history.append(condensation3)
|
||||
# Add user message between condensation events (except after the last one)
|
||||
if i < 9:
|
||||
user_message = MessageAction(
|
||||
content=f'Please continue {i}', wait_for_response=False
|
||||
)
|
||||
user_message._source = EventSource.USER
|
||||
state.history.append(user_message)
|
||||
user_observation = NullObservation(content='')
|
||||
state.history.append(user_observation)
|
||||
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck(headless_mode=True) is True
|
||||
@@ -754,7 +739,7 @@ class TestStuckDetector:
|
||||
state = stuck_detector.state
|
||||
|
||||
# Add condensation events first
|
||||
for _ in range(3):
|
||||
for _ in range(10):
|
||||
condensation = AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
)
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
"""Test the retry mechanism in RemoteRuntime."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import requests
|
||||
import tenacity
|
||||
from requests.exceptions import ConnectionError
|
||||
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
class TestRemoteRuntimeRetry(unittest.TestCase):
|
||||
"""Test the retry mechanism in RemoteRuntime."""
|
||||
|
||||
def test_retry_decorator_exists(self):
|
||||
"""Test that the retry decorator is used when remote_runtime_enable_retries=True."""
|
||||
# This test verifies that the code path for retry exists in the RemoteRuntime class
|
||||
|
||||
# Check if the methods exist
|
||||
self.assertTrue(hasattr(RemoteRuntime, '_send_action_server_request'))
|
||||
self.assertTrue(hasattr(RemoteRuntime, '_send_action_server_request_impl'))
|
||||
|
||||
# Check the source code of the class to verify it contains retry logic
|
||||
import inspect
|
||||
source = inspect.getsource(RemoteRuntime)
|
||||
self.assertIn('remote_runtime_enable_retries', source)
|
||||
self.assertIn('retry_decorator', source)
|
||||
self.assertIn('retry_if_exception_type', source)
|
||||
|
||||
@patch('tenacity.retry')
|
||||
def test_retry_decorator_called_with_correct_params(self, mock_retry):
|
||||
"""Test that the retry decorator is called with correct parameters."""
|
||||
# Setup
|
||||
mock_retry.return_value = lambda f: f # Make retry a pass-through decorator
|
||||
|
||||
# Create a runtime instance with remote_runtime_enable_retries=True
|
||||
runtime = MagicMock()
|
||||
runtime._runtime_closed = False
|
||||
runtime._stop_if_closed = lambda x: False
|
||||
|
||||
config = MagicMock()
|
||||
sandbox_config = SandboxConfig()
|
||||
sandbox_config.remote_runtime_enable_retries = True
|
||||
config.sandbox = sandbox_config
|
||||
runtime.config = config
|
||||
|
||||
# Mock super() to return a simple object
|
||||
with patch('openhands.runtime.impl.remote.remote_runtime.super') as mock_super:
|
||||
mock_super.return_value._send_action_server_request = lambda *args, **kwargs: "mocked response"
|
||||
|
||||
# Call the method
|
||||
RemoteRuntime._send_action_server_request(runtime, "GET", "http://example.com")
|
||||
|
||||
# Verify retry was called with ConnectionError
|
||||
mock_retry.assert_called()
|
||||
# Get the first positional argument of the first call
|
||||
retry_args = mock_retry.call_args[1]
|
||||
|
||||
# Check that retry is configured for ConnectionError
|
||||
self.assertIn('retry', retry_args)
|
||||
|
||||
# Check that stop conditions include stop_if_should_exit
|
||||
self.assertIn('stop', retry_args)
|
||||
|
||||
def test_connection_error_not_retried_when_disabled(self):
|
||||
"""Test that ConnectionError is not retried when remote_runtime_enable_retries=False."""
|
||||
# Create a runtime instance with remote_runtime_enable_retries=False
|
||||
runtime = MagicMock()
|
||||
|
||||
config = MagicMock()
|
||||
sandbox_config = SandboxConfig()
|
||||
sandbox_config.remote_runtime_enable_retries = False # Disable retries
|
||||
config.sandbox = sandbox_config
|
||||
runtime.config = config
|
||||
|
||||
# Mock _send_action_server_request_impl to raise ConnectionError
|
||||
runtime._send_action_server_request_impl = MagicMock(side_effect=ConnectionError())
|
||||
|
||||
# Call the method - should raise ConnectionError without retrying
|
||||
with self.assertRaises(ConnectionError):
|
||||
RemoteRuntime._send_action_server_request(runtime, "GET", "http://example.com")
|
||||
|
||||
# Verify _send_action_server_request_impl was called exactly once (no retry)
|
||||
self.assertEqual(runtime._send_action_server_request_impl.call_count, 1)
|
||||
|
||||
@patch('tenacity.retry')
|
||||
def test_connection_error_retried_when_enabled(self, mock_retry):
|
||||
"""Test that ConnectionError is retried when remote_runtime_enable_retries=True."""
|
||||
# Setup a mock retry decorator that will call the function with retries
|
||||
def mock_retry_decorator(retry_func):
|
||||
def wrapper(*args, **kwargs):
|
||||
# Simulate retry behavior by calling the function twice
|
||||
try:
|
||||
return retry_func(*args, **kwargs)
|
||||
except ConnectionError:
|
||||
# On first ConnectionError, try again and return success
|
||||
return "success after retry"
|
||||
return wrapper
|
||||
|
||||
mock_retry.return_value = mock_retry_decorator
|
||||
|
||||
# Create a runtime instance with remote_runtime_enable_retries=True
|
||||
runtime = MagicMock()
|
||||
runtime._runtime_closed = False
|
||||
runtime._stop_if_closed = lambda x: False
|
||||
|
||||
config = MagicMock()
|
||||
sandbox_config = SandboxConfig()
|
||||
sandbox_config.remote_runtime_enable_retries = True # Enable retries
|
||||
config.sandbox = sandbox_config
|
||||
runtime.config = config
|
||||
|
||||
# Mock _send_action_server_request_impl to raise ConnectionError on first call
|
||||
impl_mock = MagicMock()
|
||||
impl_mock.side_effect = [ConnectionError(), "success"]
|
||||
runtime._send_action_server_request_impl = impl_mock
|
||||
|
||||
# Call the method - should retry and succeed
|
||||
result = RemoteRuntime._send_action_server_request(runtime, "GET", "http://example.com")
|
||||
|
||||
# Verify retry was called
|
||||
mock_retry.assert_called()
|
||||
|
||||
# The result should be "success after retry" from our mock decorator
|
||||
self.assertEqual(result, "success after retry")
|
||||
|
||||
|
||||
def test_tenacity_retry_with_connection_error(self):
|
||||
"""Test that tenacity retry works with ConnectionError."""
|
||||
# Create a function that will raise ConnectionError on first call
|
||||
call_count = [0]
|
||||
|
||||
@tenacity.retry(
|
||||
retry=tenacity.retry_if_exception_type(ConnectionError),
|
||||
stop=tenacity.stop_after_attempt(3),
|
||||
wait=tenacity.wait_exponential(multiplier=0.1, min=0.1, max=1),
|
||||
)
|
||||
def function_with_retry():
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
raise ConnectionError("Connection refused")
|
||||
return "success"
|
||||
|
||||
# Call the function - should retry and succeed
|
||||
result = function_with_retry()
|
||||
|
||||
# Verify the function was called twice (retry happened)
|
||||
self.assertEqual(call_count[0], 2)
|
||||
|
||||
# Verify the result is correct
|
||||
self.assertEqual(result, "success")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -6,6 +6,7 @@ from openhands.core.config.app_config import AppConfig
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.core.config.security_config import SecurityConfig
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.routes.settings import convert_to_settings
|
||||
from openhands.server.settings import POSTSettingsModel, Settings
|
||||
|
||||
@@ -43,7 +44,7 @@ def test_settings_from_config():
|
||||
assert settings.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert settings.llm_base_url == 'https://test.example.com'
|
||||
assert settings.remote_runtime_resource_factor == 2
|
||||
assert settings.github_token is None
|
||||
assert not settings.secrets_store.provider_tokens
|
||||
|
||||
|
||||
def test_settings_from_config_no_api_key():
|
||||
@@ -80,23 +81,41 @@ def test_settings_handles_sensitive_data():
|
||||
llm_api_key='test-key',
|
||||
llm_base_url='https://test.example.com',
|
||||
remote_runtime_resource_factor=2,
|
||||
github_token='test-token',
|
||||
)
|
||||
settings.secrets_store.provider_tokens[ProviderType.GITHUB] = ProviderToken(
|
||||
token=SecretStr('test-token'),
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
assert str(settings.llm_api_key) == '**********'
|
||||
assert str(settings.github_token) == '**********'
|
||||
assert (
|
||||
str(settings.secrets_store.provider_tokens[ProviderType.GITHUB].token)
|
||||
== '**********'
|
||||
)
|
||||
|
||||
assert settings.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert settings.github_token.get_secret_value() == 'test-token'
|
||||
assert (
|
||||
settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test-token'
|
||||
)
|
||||
|
||||
|
||||
def test_convert_to_settings():
|
||||
settings_with_token_data = POSTSettingsModel(
|
||||
llm_api_key='test-key',
|
||||
github_token='test-token',
|
||||
provider_tokens={
|
||||
'github': 'test-token',
|
||||
},
|
||||
)
|
||||
|
||||
settings = convert_to_settings(settings_with_token_data)
|
||||
|
||||
assert settings.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert settings.github_token.get_secret_value() == 'test-token'
|
||||
assert (
|
||||
settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test-token'
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi.testclient import TestClient
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.config.sandbox_config import SandboxConfig
|
||||
from openhands.integrations.provider import ProviderType, SecretStore
|
||||
from openhands.server.app import app
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
@@ -19,6 +20,24 @@ def mock_settings_store():
|
||||
yield store_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_user_id():
|
||||
with patch('openhands.server.routes.settings.get_user_id') as mock:
|
||||
mock.return_value = 'test-user'
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_validate_provider_token():
|
||||
with patch('openhands.server.routes.settings.validate_provider_token') as mock:
|
||||
|
||||
async def mock_determine(*args, **kwargs):
|
||||
return ProviderType.GITHUB
|
||||
|
||||
mock.side_effect = mock_determine
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client(mock_settings_store):
|
||||
# Mock the middleware that adds github_token
|
||||
@@ -28,9 +47,15 @@ def test_client(mock_settings_store):
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
settings = mock_settings_store.load.return_value
|
||||
token = settings.github_token if settings else None
|
||||
token = None
|
||||
if settings and settings.secrets_store.provider_tokens.get(
|
||||
ProviderType.GITHUB
|
||||
):
|
||||
token = settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token
|
||||
if scope['type'] == 'http':
|
||||
scope['state'] = {'github_token': token}
|
||||
scope['state'] = {'token': token}
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
# Replace the middleware
|
||||
@@ -47,7 +72,9 @@ def mock_github_service():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_api_runtime_factor(test_client, mock_settings_store):
|
||||
async def test_settings_api_runtime_factor(
|
||||
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
|
||||
):
|
||||
# Mock the settings store to return None initially (no existing settings)
|
||||
mock_settings_store.load.return_value = None
|
||||
|
||||
@@ -62,6 +89,7 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
|
||||
'llm_api_key': 'test-key',
|
||||
'llm_base_url': 'https://test.com',
|
||||
'remote_runtime_resource_factor': 2,
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# The test_client fixture already handles authentication
|
||||
@@ -98,12 +126,17 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_llm_api_key(test_client, mock_settings_store):
|
||||
async def test_settings_llm_api_key(
|
||||
test_client, mock_settings_store, mock_get_user_id, mock_validate_provider_token
|
||||
):
|
||||
# Mock the settings store to return None initially (no existing settings)
|
||||
mock_settings_store.load.return_value = None
|
||||
|
||||
# Test data with remote_runtime_resource_factor
|
||||
settings_data = {'llm_api_key': 'test-key'}
|
||||
settings_data = {
|
||||
'llm_api_key': 'test-key',
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# The test_client fixture already handles authentication
|
||||
|
||||
@@ -132,9 +165,13 @@ async def test_settings_llm_api_key(test_client, mock_settings_store):
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_api_set_github_token(
|
||||
mock_github_service, test_client, mock_settings_store
|
||||
mock_github_service,
|
||||
test_client,
|
||||
mock_settings_store,
|
||||
mock_get_user_id,
|
||||
mock_validate_provider_token,
|
||||
):
|
||||
# Test data with github_token set
|
||||
# Test data with provider token set
|
||||
settings_data = {
|
||||
'language': 'en',
|
||||
'agent': 'test-agent',
|
||||
@@ -144,16 +181,21 @@ async def test_settings_api_set_github_token(
|
||||
'llm_model': 'test-model',
|
||||
'llm_api_key': 'test-key',
|
||||
'llm_base_url': 'https://test.com',
|
||||
'github_token': 'test-token',
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# Make the POST request to store settings
|
||||
response = test_client.post('/api/settings', json=settings_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the settings were stored with the github_token
|
||||
# Verify the settings were stored with the provider token
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.github_token == 'test-token'
|
||||
assert (
|
||||
stored_settings.secrets_store.provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
].token.get_secret_value()
|
||||
== 'test-token'
|
||||
)
|
||||
|
||||
# Mock settings store to return our settings for the GET request
|
||||
mock_settings_store.load.return_value = Settings(**settings_data)
|
||||
@@ -163,17 +205,21 @@ async def test_settings_api_set_github_token(
|
||||
data = response.json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data.get('github_token') is None
|
||||
assert data['github_token_is_set'] is True
|
||||
assert data.get('token') is None
|
||||
assert data['token_is_set'] is True
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason='Mock middleware does not seem to properly set the github_token'
|
||||
)
|
||||
async def test_settings_unset_github_token(
|
||||
mock_github_service, test_client, mock_settings_store
|
||||
mock_github_service,
|
||||
test_client,
|
||||
mock_settings_store,
|
||||
mock_get_user_id,
|
||||
mock_validate_provider_token,
|
||||
):
|
||||
# Test data with unset_github_token set to True
|
||||
# Test data with unset_token set to True
|
||||
settings_data = {
|
||||
'language': 'en',
|
||||
'agent': 'test-agent',
|
||||
@@ -183,7 +229,7 @@ async def test_settings_unset_github_token(
|
||||
'llm_model': 'test-model',
|
||||
'llm_api_key': 'test-key',
|
||||
'llm_base_url': 'https://test.com',
|
||||
'github_token': 'test-token',
|
||||
'provider_tokens': {'github': 'test-token'},
|
||||
}
|
||||
|
||||
# Mock settings store to return our settings for the GET request
|
||||
@@ -191,23 +237,23 @@ async def test_settings_unset_github_token(
|
||||
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
assert response.json()['github_token_is_set'] is True
|
||||
assert response.json()['token_is_set'] is True
|
||||
|
||||
settings_data['unset_github_token'] = True
|
||||
settings_data['unset_token'] = True
|
||||
|
||||
# Make the POST request to store settings
|
||||
response = test_client.post('/api/settings', json=settings_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the settings were stored with the github_token unset
|
||||
# Verify the settings were stored with the provider token unset
|
||||
stored_settings = mock_settings_store.store.call_args[0][0]
|
||||
assert stored_settings.github_token is None
|
||||
assert not stored_settings.secrets_store.provider_tokens
|
||||
mock_settings_store.load.return_value = Settings(**stored_settings.dict())
|
||||
|
||||
# Make a GET request to retrieve settings
|
||||
response = test_client.get('/api/settings')
|
||||
assert response.status_code == 200
|
||||
assert response.json()['github_token_is_set'] is False
|
||||
assert response.json()['token_is_set'] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -222,6 +268,7 @@ async def test_settings_preserve_llm_fields_when_none(test_client, mock_settings
|
||||
llm_model='existing-model',
|
||||
llm_api_key=SecretStr('existing-key'),
|
||||
llm_base_url='https://existing.com',
|
||||
secrets_store=SecretStore(),
|
||||
)
|
||||
|
||||
# Mock the settings store to return our initial settings
|
||||
|
||||
@@ -3,13 +3,13 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
|
||||
from openhands.integrations.github.github_service import GitHubService
|
||||
from openhands.integrations.github.github_types import GitHubUser, TaskType
|
||||
from openhands.integrations.service_types import TaskType, User
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_suggested_tasks():
|
||||
# Mock responses
|
||||
mock_user = GitHubUser(
|
||||
mock_user = User(
|
||||
id=1,
|
||||
login='test-user',
|
||||
avatar_url='https://example.com/avatar.jpg',
|
||||
|
||||
Reference in New Issue
Block a user