mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
102 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cf877b5628 | |||
| fb9958aff8 | |||
| c1f5861eaf | |||
| fa7f58b7c5 | |||
| a691bec7fc | |||
| 7eb77c131d | |||
| 858870a095 | |||
| d65e5b5e46 | |||
| 2b0816f53a | |||
| ab9536dc6b | |||
| 9f5888315a | |||
| fcefb872b6 | |||
| b91cd0570e | |||
| 94b45c6c36 | |||
| cbc380fe49 | |||
| fb776ef650 | |||
| a75b576f1c | |||
| 63956c3292 | |||
| f75141af3e | |||
| e18775b391 | |||
| 495de35139 | |||
| be29d89b3c | |||
| e4515b21eb | |||
| a8f6a35341 | |||
| f706a217d0 | |||
| 2890b8c6ff | |||
| 39a846ccc3 | |||
| 0137201903 | |||
| 49a98885ab | |||
| 38648bddb3 | |||
| b44774d2be | |||
| 04330898b6 | |||
| 120fd7516a | |||
| 2224127ac3 | |||
| 2d1e9fa35b | |||
| 0ec962e96b | |||
| 3a9f00aa37 | |||
| e02dbb8974 | |||
| 8039807c3f | |||
| cae7b6e72f | |||
| 7ca41486be | |||
| 81c02623a1 | |||
| a96760eea7 | |||
| dcb2e21b87 | |||
| 7edebcbc0c | |||
| abd1f9948f | |||
| 2879e58781 | |||
| 1d1ffc2be0 | |||
| db41148396 | |||
| 39a4ca422f | |||
| 6d86803f41 | |||
| 8e0386c416 | |||
| 48cd85e47e | |||
| c62b47dcb1 | |||
| eb9a822d4c | |||
| fb7333aa62 | |||
| fb23418803 | |||
| 991585c05d | |||
| 35a40ddee8 | |||
| 5d1f9f815a | |||
| d3bf989e77 | |||
| 6589e592e3 | |||
| fe4c0569f7 | |||
| 38dcf959bc | |||
| ef3acf726c | |||
| 017d758a76 | |||
| 3ed37e18ac | |||
| 28ecf06404 | |||
| 26fa1185a4 | |||
| d3a8b037f2 | |||
| af1fa8961a | |||
| 3b215c4ad1 | |||
| 7516b53f5a | |||
| 855ef7ba5f | |||
| 09ca1b882f | |||
| 1322f944be | |||
| 5925483f6b | |||
| 0144424c8e | |||
| f07ce85b45 | |||
| bc5a46dcee | |||
| 9990870060 | |||
| bab9a45590 | |||
| b4107ff9dc | |||
| 3e04713097 | |||
| 77f868081c | |||
| 3a12924bc8 | |||
| cfa7def554 | |||
| 33d6a11abf | |||
| 71d5aa5aa8 | |||
| 90d2681e34 | |||
| 565a5702c3 | |||
| 4b9097068d | |||
| c9a5834164 | |||
| 19a089aa4b | |||
| 918c44d164 | |||
| e06e20a5ba | |||
| 430ee1c9fd | |||
| a03377698c | |||
| 9dab5b1bbf | |||
| 135d5fbd38 | |||
| ad615ebc8b | |||
| 424f6b30d1 |
@@ -0,0 +1,202 @@
|
||||
---
|
||||
name: cross-repo-testing
|
||||
description: This skill should be used when the user asks to "test a cross-repo feature", "deploy a feature branch to staging", "test SDK against OH Cloud", "e2e test a cloud workspace feature", "test provider tokens", "test secrets inheritance", or when changes span the SDK and OpenHands server repos and need end-to-end validation against a staging deployment.
|
||||
triggers:
|
||||
- cross-repo
|
||||
- staging deployment
|
||||
- feature branch deploy
|
||||
- test against cloud
|
||||
- e2e cloud
|
||||
---
|
||||
|
||||
# Cross-Repo Testing: SDK ↔ OpenHands Cloud
|
||||
|
||||
How to end-to-end test features that span `OpenHands/software-agent-sdk` and `OpenHands/OpenHands` (the Cloud backend).
|
||||
|
||||
## Repository Map
|
||||
|
||||
| Repo | Role | What lives here |
|
||||
|------|------|-----------------|
|
||||
| [`software-agent-sdk`](https://github.com/OpenHands/software-agent-sdk) | Agent core | `openhands-sdk`, `openhands-workspace`, `openhands-tools` packages. `OpenHandsCloudWorkspace` lives here. |
|
||||
| [`OpenHands`](https://github.com/OpenHands/OpenHands) | Cloud backend | FastAPI server (`openhands/app_server/`), sandbox management, auth, enterprise integrations. Deployed as OH Cloud. |
|
||||
| [`deploy`](https://github.com/OpenHands/deploy) | Infrastructure | Helm charts + GitHub Actions that build the enterprise Docker image and deploy to staging/production. |
|
||||
|
||||
**Data flow:** SDK client → OH Cloud API (`/api/v1/...`) → sandbox agent-server (inside runtime container)
|
||||
|
||||
## When You Need This
|
||||
|
||||
There are **two flows** depending on which direction the dependency goes:
|
||||
|
||||
| Flow | When | Example |
|
||||
|------|------|---------|
|
||||
| **A — SDK client → new Cloud API** | The SDK calls an API that doesn't exist yet on production | `workspace.get_llm()` calling `GET /api/v1/users/me?expose_secrets=true` |
|
||||
| **B — OH server → new SDK code** | The Cloud server needs unreleased SDK packages or a new agent-server image | Server consumes a new tool, agent behavior, or workspace method from the SDK |
|
||||
|
||||
Flow A only requires deploying the server PR. Flow B requires pinning the SDK to an unreleased commit in the server PR **and** using the SDK PR's agent-server image. Both flows may apply simultaneously.
|
||||
|
||||
---
|
||||
|
||||
## Flow A: SDK Client Tests Against New Cloud API
|
||||
|
||||
Use this when the SDK calls an endpoint that only exists on the server PR branch.
|
||||
|
||||
### A1. Write and test the server-side changes
|
||||
|
||||
In the `OpenHands` repo, implement the new API endpoint(s). Run unit tests:
|
||||
|
||||
```bash
|
||||
cd OpenHands
|
||||
poetry run pytest tests/unit/app_server/test_<relevant>.py -v
|
||||
```
|
||||
|
||||
Push a PR. Wait for the **"Push Enterprise Image" (Docker) CI job** to succeed — this builds `ghcr.io/openhands/enterprise-server:sha-<COMMIT>`.
|
||||
|
||||
### A2. Write the SDK-side changes
|
||||
|
||||
In `software-agent-sdk`, implement the client code (e.g., new methods on `OpenHandsCloudWorkspace`). Run SDK unit tests:
|
||||
|
||||
```bash
|
||||
cd software-agent-sdk
|
||||
pip install -e openhands-sdk -e openhands-workspace
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
Push a PR. SDK CI is independent — it doesn't need the server changes to pass unit tests.
|
||||
|
||||
### A3. Deploy the server PR to staging
|
||||
|
||||
See [Deploying to a Staging Feature Environment](#deploying-to-a-staging-feature-environment) below.
|
||||
|
||||
### A4. Run the SDK e2e test against staging
|
||||
|
||||
See [Running E2E Tests Against Staging](#running-e2e-tests-against-staging) below.
|
||||
|
||||
---
|
||||
|
||||
## Flow B: OH Server Needs Unreleased SDK Code
|
||||
|
||||
Use this when the Cloud server depends on SDK changes that haven't been released to PyPI yet. The server's runtime containers run the `agent-server` image built from the SDK repo, so the server PR must be configured to use the SDK PR's image and packages.
|
||||
|
||||
### B1. Get the SDK PR merged (or identify the commit)
|
||||
|
||||
The SDK PR must have CI pass so its agent-server Docker image is built. The image is tagged with the **merge-commit SHA** from GitHub Actions — NOT the head-commit SHA shown in the PR.
|
||||
|
||||
Find the correct image tag:
|
||||
- Check the SDK PR description for an `AGENT_SERVER_IMAGES` section
|
||||
- Or check the "Consolidate Build Information" CI job for `"short_sha": "<tag>"`
|
||||
|
||||
### B2. Pin SDK packages to the commit in the OpenHands PR
|
||||
|
||||
In the `OpenHands` repo PR, pin all 3 SDK packages (`openhands-sdk`, `openhands-agent-server`, `openhands-tools`) to the unreleased commit and update the agent-server image tag. This involves editing 3 files and regenerating 3 lock files.
|
||||
|
||||
Follow the **`update-sdk` skill** → "Development: Pin SDK to an Unreleased Commit" section for the full procedure and file-by-file instructions.
|
||||
|
||||
### B3. Wait for the OpenHands enterprise image to build
|
||||
|
||||
Push the pinned changes. The OpenHands CI will build a new enterprise Docker image (`ghcr.io/openhands/enterprise-server:sha-<OH_COMMIT>`) that bundles the unreleased SDK. Wait for the "Push Enterprise Image" job to succeed.
|
||||
|
||||
### B4. Deploy and test
|
||||
|
||||
Follow [Deploying to a Staging Feature Environment](#deploying-to-a-staging-feature-environment) using the new OpenHands commit SHA.
|
||||
|
||||
### B5. Before merging: remove the pin
|
||||
|
||||
**CI guard:** `check-package-versions.yml` blocks merge to `main` if `[tool.poetry.dependencies]` contains `rev` fields. Before the OpenHands PR can merge, the SDK PR must be merged and released to PyPI, then the pin must be replaced with the released version number.
|
||||
|
||||
---
|
||||
|
||||
## Deploying to a Staging Feature Environment
|
||||
|
||||
The `deploy` repo creates preview environments from OpenHands PRs.
|
||||
|
||||
**Option A — GitHub Actions UI (preferred):**
|
||||
Go to `OpenHands/deploy` → Actions → "Create OpenHands preview PR" → enter the OpenHands PR number. This creates a branch `ohpr-<PR>-<random>` and opens a deploy PR.
|
||||
|
||||
**Option B — Update an existing feature branch:**
|
||||
```bash
|
||||
cd deploy
|
||||
git checkout ohpr-<PR>-<random>
|
||||
# In .github/workflows/deploy.yaml, update BOTH:
|
||||
# OPENHANDS_SHA: "<full-40-char-commit>"
|
||||
# OPENHANDS_RUNTIME_IMAGE_TAG: "<same-commit>-nikolaik"
|
||||
git commit -am "Update OPENHANDS_SHA to <commit>" && git push
|
||||
```
|
||||
|
||||
**Before updating the SHA**, verify the enterprise Docker image exists:
|
||||
```bash
|
||||
gh api repos/OpenHands/OpenHands/actions/runs \
|
||||
--jq '.workflow_runs[] | select(.head_sha=="<COMMIT>") | "\(.name): \(.conclusion)"' \
|
||||
| grep Docker
|
||||
# Must show: "Docker: success"
|
||||
```
|
||||
|
||||
The deploy CI auto-triggers and creates the environment at:
|
||||
```
|
||||
https://ohpr-<PR>-<random>.staging.all-hands.dev
|
||||
```
|
||||
|
||||
**Wait for it to be live:**
|
||||
```bash
|
||||
curl -s -o /dev/null -w "%{http_code}" https://ohpr-<PR>-<random>.staging.all-hands.dev/api/v1/health
|
||||
# 401 = server is up (auth required). DNS may take 1-2 min on first deploy.
|
||||
```
|
||||
|
||||
## Running E2E Tests Against Staging
|
||||
|
||||
**Critical: Feature deployments have their own Keycloak instance.** API keys from `app.all-hands.dev` or `$OPENHANDS_API_KEY` will NOT work. You need a test API key issued by the specific feature deployment's Keycloak.
|
||||
|
||||
**You (the agent) cannot obtain this key yourself** — the feature environment requires interactive browser login with credentials you do not have. You must **ask the user** to:
|
||||
1. Log in to the feature deployment at `https://ohpr-<PR>-<random>.staging.all-hands.dev` in their browser
|
||||
2. Generate a test API key from the UI
|
||||
3. Provide the key to you so you can proceed with e2e testing
|
||||
|
||||
Do **not** attempt to log in via the browser or guess credentials. Wait for the user to supply the key before running any e2e tests.
|
||||
|
||||
```python
|
||||
from openhands.workspace import OpenHandsCloudWorkspace
|
||||
|
||||
STAGING = "https://ohpr-<PR>-<random>.staging.all-hands.dev"
|
||||
|
||||
with OpenHandsCloudWorkspace(
|
||||
cloud_api_url=STAGING,
|
||||
cloud_api_key="<test-api-key-for-this-deployment>",
|
||||
) as workspace:
|
||||
# Test the new feature
|
||||
llm = workspace.get_llm()
|
||||
secrets = workspace.get_secrets()
|
||||
print(f"LLM: {llm.model}, secrets: {list(secrets.keys())}")
|
||||
```
|
||||
|
||||
Or run an example script:
|
||||
```bash
|
||||
OPENHANDS_CLOUD_API_KEY="<key>" \
|
||||
OPENHANDS_CLOUD_API_URL="https://ohpr-<PR>-<random>.staging.all-hands.dev" \
|
||||
python examples/02_remote_agent_server/10_cloud_workspace_saas_credentials.py
|
||||
```
|
||||
|
||||
### Recording results
|
||||
|
||||
Both repos support a `.pr/` directory for temporary PR artifacts (design docs, test logs, scripts). These files are automatically removed when the PR is approved — see `.github/workflows/pr-artifacts.yml` and the "PR-Specific Artifacts" section in each repo's `AGENTS.md`.
|
||||
|
||||
Push test output to the `.pr/logs/` directory of whichever repo you're working in:
|
||||
```bash
|
||||
mkdir -p .pr/logs
|
||||
python test_script.py 2>&1 | tee .pr/logs/<test_name>.log
|
||||
git add -f .pr/logs/
|
||||
git commit -m "docs: add e2e test results" && git push
|
||||
```
|
||||
|
||||
Comment on **both PRs** with pass/fail summary and link to logs.
|
||||
|
||||
## Key Gotchas
|
||||
|
||||
| Gotcha | Details |
|
||||
|--------|---------|
|
||||
| **Feature env auth is isolated** | Each `ohpr-*` deployment has its own Keycloak. Production API keys don't work. Agents cannot log in — you must ask the user to provide a test API key from the feature deployment's UI. |
|
||||
| **Two SHAs in deploy.yaml** | `OPENHANDS_SHA` and `OPENHANDS_RUNTIME_IMAGE_TAG` must both be updated. The runtime tag is `<sha>-nikolaik`. |
|
||||
| **Enterprise image must exist** | The Docker CI job on the OpenHands PR must succeed before you can deploy. If it hasn't run, push an empty commit to trigger it. |
|
||||
| **DNS propagation** | First deployment of a new branch takes 1-2 min for DNS. Subsequent deploys are instant. |
|
||||
| **Merge-commit SHA ≠ head SHA** | SDK CI tags Docker images with GitHub Actions' merge-commit SHA, not the PR head SHA. Check the SDK PR description or CI logs for the correct tag. |
|
||||
| **SDK pin blocks merge** | `check-package-versions.yml` prevents merging an OpenHands PR that has `rev` fields in `[tool.poetry.dependencies]`. The SDK must be released to PyPI first. |
|
||||
| **Flow A: stock agent-server is fine** | When only the Cloud API changes, `OpenHandsCloudWorkspace` talks to the Cloud server, not the agent-server. No custom image needed. |
|
||||
| **Flow B: agent-server image is required** | When the server needs new SDK code inside runtime containers, you must pin to the SDK PR's agent-server image. |
|
||||
@@ -219,11 +219,9 @@ jobs:
|
||||
- name: Determine app image tag
|
||||
shell: bash
|
||||
run: |
|
||||
# Duplicated with build.sh
|
||||
sanitized_ref_name=$(echo "$GITHUB_REF_NAME" | sed 's/[^a-zA-Z0-9.-]\+/-/g')
|
||||
OPENHANDS_BUILD_VERSION=$sanitized_ref_name
|
||||
sanitized_ref_name=$(echo "$sanitized_ref_name" | tr '[:upper:]' '[:lower:]') # lower case is required in tagging
|
||||
echo "OPENHANDS_DOCKER_TAG=${sanitized_ref_name}" >> $GITHUB_ENV
|
||||
# Use the commit SHA to pin the exact app image built by ghcr_build_app,
|
||||
# rather than a mutable branch tag like "main" which can serve stale cached layers.
|
||||
echo "OPENHANDS_DOCKER_TAG=${RELEVANT_SHA}" >> $GITHUB_ENV
|
||||
- name: Build and push Docker image
|
||||
uses: useblacksmith/build-push-action@v1
|
||||
with:
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
---
|
||||
name: PR Artifacts
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Manual trigger for testing
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
branches: [main]
|
||||
pull_request_review:
|
||||
types: [submitted]
|
||||
|
||||
jobs:
|
||||
# Auto-remove .pr/ directory when a reviewer approves
|
||||
cleanup-on-approval:
|
||||
concurrency:
|
||||
group: cleanup-pr-artifacts-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: false
|
||||
if: github.event_name == 'pull_request_review' && github.event.review.state == 'approved'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Check if fork PR
|
||||
id: check-fork
|
||||
run: |
|
||||
if [ "${{ github.event.pull_request.head.repo.full_name }}" != "${{ github.event.pull_request.base.repo.full_name }}" ]; then
|
||||
echo "is_fork=true" >> $GITHUB_OUTPUT
|
||||
echo "::notice::Fork PR detected - skipping auto-cleanup (manual removal required)"
|
||||
else
|
||||
echo "is_fork=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
if: steps.check-fork.outputs.is_fork == 'false'
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
|
||||
- name: Remove .pr/ directory
|
||||
id: remove
|
||||
if: steps.check-fork.outputs.is_fork == 'false'
|
||||
run: |
|
||||
if [ -d ".pr" ]; then
|
||||
git config user.name "allhands-bot"
|
||||
git config user.email "allhands-bot@users.noreply.github.com"
|
||||
git rm -rf .pr/
|
||||
git commit -m "chore: Remove PR-only artifacts [automated]"
|
||||
git push || {
|
||||
echo "::error::Failed to push cleanup commit. Check branch protection rules."
|
||||
exit 1
|
||||
}
|
||||
echo "removed=true" >> $GITHUB_OUTPUT
|
||||
echo "::notice::Removed .pr/ directory"
|
||||
else
|
||||
echo "removed=false" >> $GITHUB_OUTPUT
|
||||
echo "::notice::No .pr/ directory to remove"
|
||||
fi
|
||||
|
||||
- name: Update PR comment after cleanup
|
||||
if: steps.check-fork.outputs.is_fork == 'false' && steps.remove.outputs.removed == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const marker = '<!-- pr-artifacts-notice -->';
|
||||
const body = `${marker}
|
||||
✅ **PR Artifacts Cleaned Up**
|
||||
|
||||
The \`.pr/\` directory has been automatically removed.
|
||||
`;
|
||||
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
});
|
||||
|
||||
const existing = comments.find(c => c.body.includes(marker));
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: existing.id,
|
||||
body: body,
|
||||
});
|
||||
}
|
||||
|
||||
# Warn if .pr/ directory exists (will be auto-removed on approval)
|
||||
check-pr-artifacts:
|
||||
if: github.event_name == 'pull_request'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- name: Check for .pr/ directory
|
||||
id: check
|
||||
run: |
|
||||
if [ -d ".pr" ]; then
|
||||
echo "exists=true" >> $GITHUB_OUTPUT
|
||||
echo "::warning::.pr/ directory exists and will be automatically removed when the PR is approved. For fork PRs, manual removal is required before merging."
|
||||
else
|
||||
echo "exists=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Post or update PR comment
|
||||
if: steps.check.outputs.exists == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const marker = '<!-- pr-artifacts-notice -->';
|
||||
const body = `${marker}
|
||||
📁 **PR Artifacts Notice**
|
||||
|
||||
This PR contains a \`.pr/\` directory with PR-specific documents. This directory will be **automatically removed** when the PR is approved.
|
||||
|
||||
> For fork PRs: Manual removal is required before merging.
|
||||
`;
|
||||
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
});
|
||||
|
||||
const existing = comments.find(c => c.body.includes(marker));
|
||||
if (!existing) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: body,
|
||||
});
|
||||
}
|
||||
@@ -1,6 +1,21 @@
|
||||
This repository contains the code for OpenHands, an automated AI software engineer. It has a Python backend
|
||||
(in the `openhands` directory) and React frontend (in the `frontend` directory).
|
||||
|
||||
|
||||
## Repository Memory
|
||||
- Legacy `/api/settings` responses can bridge to the SDK by returning `sdk_settings_schema` from `openhands.sdk.settings` when that package is available. Use this as the compatibility handoff while V1 settings work moves into the SDK and newer clients.
|
||||
- The legacy LLM settings screen now renders SDK-backed sections from `sdk_settings_schema` and reads/writes values through the generic settings blob. The canonical backend field is `agent_settings`; `sdk_settings_values` is a compatibility alias for older callers.
|
||||
- In enterprise mode, persist the generic SDK settings blob in `agent_settings` on `enterprise/storage/org_member.py` and `enterprise/storage/user_settings.py`. Keep a `sdk_settings_values` alias only for compatibility with older tests/callers.
|
||||
- Persisted SaaS `agent_settings` should carry a `schema_version` and canonical dotted keys, but should not duplicate secret SDK values like `llm.api_key` in plaintext JSON. Reconstruct those from encrypted legacy columns on load, and backfill/migrate rows on read/write.
|
||||
- The frontend settings query still normalizes canonical backend fields (`agent_settings`, `agent_settings_schema`) back into legacy `sdk_settings_values` / `sdk_settings_schema` for existing settings screens. Strip both canonical and legacy schema/value blobs from save payloads so redacted GET metadata is never POSTed back.
|
||||
|
||||
- The SDK settings schema now uses neutral metadata (`value_type`, `prominence`, `choices`, `depends_on`) instead of legacy UI-only fields like `widget`, `advanced`, or `placeholder`. Frontend helpers should derive control types from `value_type`/`choices`, and dotted `sdk_settings_values` may include structured JSON objects/arrays.
|
||||
- When constructing runtime `LLM`s for `openhands/*` models, keep explicit user-provided `llm.base_url` overrides, but prefer the app's `openhands_provider_base_url` when the user did not set one. Newer SDK defaults may populate an OpenHands proxy URL automatically, so check persisted user settings rather than `AgentSettings.llm.base_url` alone.
|
||||
- SDK `AgentSettings` sections are: `llm`, `condenser`, `verification`. The `verification` section merges former `critic` + `security` settings into one `VerificationSettings` model. Backward-compat property accessors (`.critic`, `.security`, `.enabled`, `.mode`, `.threshold`) and type aliases (`CriticSettings`, `SecuritySettings`) are preserved. Do NOT subclass `AgentSettings` in OpenHands — use it directly.
|
||||
|
||||
|
||||
|
||||
|
||||
## General Setup:
|
||||
To set up the entire repo, including frontend and backend, run `make build`.
|
||||
You don't need to do this unless the user asks you to, or if you're trying to run the entire application.
|
||||
@@ -36,6 +51,40 @@ then re-run the command to ensure it passes. Common issues include:
|
||||
- Be especially careful with `git reset --hard` after staging files, as it will remove accidentally staged files
|
||||
- When remote has new changes, use `git fetch upstream && git rebase upstream/<branch>` on the same branch
|
||||
|
||||
## PR-Specific Artifacts (`.pr/` directory)
|
||||
|
||||
When working on a PR that requires design documents, scripts meant for development-only, or other temporary artifacts that should NOT be merged to main, store them in a `.pr/` directory at the repository root.
|
||||
|
||||
### Usage
|
||||
|
||||
```
|
||||
.pr/
|
||||
├── design.md # Design decisions and architecture notes
|
||||
├── analysis.md # Investigation or debugging notes
|
||||
├── logs/ # Test output or CI logs for reviewer reference
|
||||
└── notes.md # Any other PR-specific content
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Notification**: When `.pr/` exists, a comment is posted to the PR conversation alerting reviewers
|
||||
2. **Auto-cleanup**: When the PR is approved, the `.pr/` directory is automatically removed via `.github/workflows/pr-artifacts.yml`
|
||||
3. **Fork PRs**: Auto-cleanup cannot push to forks, so manual removal is required before merging
|
||||
|
||||
### Important Notes
|
||||
|
||||
- Do NOT put anything in `.pr/` that needs to be preserved after merge
|
||||
- The `.pr/` check passes (green ✅) during development — it only posts a notification, not a blocking error
|
||||
- For fork PRs: You must manually remove `.pr/` before the PR can be merged
|
||||
|
||||
### When to Use
|
||||
|
||||
- Complex refactoring that benefits from written design rationale
|
||||
- Debugging sessions where you want to document your investigation
|
||||
- E2E test results or logs that demonstrate a cross-repo feature works
|
||||
- Feature implementations that need temporary planning docs
|
||||
- Any analysis that helps reviewers understand the PR but isn't needed long-term
|
||||
|
||||
## Repository Structure
|
||||
Backend:
|
||||
- Located in the `openhands` directory
|
||||
|
||||
@@ -125,6 +125,17 @@ For example, a PR title could be:
|
||||
- If your changes are user-facing (e.g. a new feature in the UI, a change in behavior, or a bugfix),
|
||||
please include a short message that we can add to our changelog
|
||||
|
||||
## Becoming a Maintainer
|
||||
|
||||
For contributors who have made significant and sustained contributions to the project, there is a possibility of joining the maintainer team.
|
||||
The process for this is as follows:
|
||||
|
||||
1. Any contributor who has made sustained and high-quality contributions to the codebase can be nominated by any maintainer. If you feel that you may qualify you can reach out to any of the maintainers that have reviewed your PRs and ask if you can be nominated.
|
||||
2. Once a maintainer nominates a new maintainer, there will be a discussion period among the maintainers for at least 3 days.
|
||||
3. If no concerns are raised the nomination will be accepted by acclamation, and if concerns are raised there will be a discussion and possible vote.
|
||||
|
||||
Note that just making many PRs does not immediately imply that you will become a maintainer. We will be looking at sustained high-quality contributions over a period of time, as well as good teamwork and adherence to our [Code of Conduct](./CODE_OF_CONDUCT.md).
|
||||
|
||||
## Need Help?
|
||||
|
||||
- **Slack**: [Join our community](https://openhands.dev/joinslack)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG OPENHANDS_BUILD_VERSION=dev
|
||||
FROM node:25.2-trixie-slim AS frontend-builder
|
||||
FROM node:25.8-trixie-slim AS frontend-builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -22,7 +22,7 @@ ENV POETRY_NO_INTERACTION=1 \
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y curl make git build-essential jq gettext \
|
||||
&& python3 -m pip install poetry --break-system-packages
|
||||
&& python3 -m pip install "poetry>=2.3.0" --break-system-packages
|
||||
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
RUN touch README.md
|
||||
|
||||
@@ -10,7 +10,7 @@ LABEL com.datadoghq.tags.env="${DD_ENV}"
|
||||
# Apply security updates to fix CVEs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_24.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
apt-get install -y jq gettext && \
|
||||
# Apply security updates for packages with available fixes
|
||||
@@ -33,7 +33,8 @@ RUN cd /tmp/enterprise && \
|
||||
# Export only main dependencies with hashes for supply chain security
|
||||
/app/.venv/bin/poetry export --only main -o requirements.txt && \
|
||||
# Remove the local path dependency (openhands-ai is already in base image)
|
||||
sed -i '/^-e /d; /openhands-ai/d' requirements.txt && \
|
||||
# and git-based SDK dependencies (already installed via the base app image)
|
||||
sed -i '/^-e /d; /openhands-ai/d; /^openhands-.*@ git+/d' requirements.txt && \
|
||||
# Install pinned dependencies from lock file
|
||||
/app/.venv/bin/pip install -r requirements.txt && \
|
||||
# Cleanup - return to /app before removing /tmp/enterprise
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
"""Add agent_settings columns to enterprise settings tables.
|
||||
|
||||
Revision ID: 102
|
||||
Revises: 101
|
||||
Create Date: 2026-03-22 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '102'
|
||||
down_revision: Union[str, None] = '101'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
_EMPTY_JSON = sa.text("'{}'::json")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'agent_settings', sa.JSON(), nullable=False, server_default=_EMPTY_JSON
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
'org_member',
|
||||
sa.Column(
|
||||
'agent_settings', sa.JSON(), nullable=False, server_default=_EMPTY_JSON
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE user_settings
|
||||
SET agent_settings = jsonb_strip_nulls(
|
||||
jsonb_build_object(
|
||||
'schema_version', 1,
|
||||
'agent', agent,
|
||||
'llm.model', llm_model,
|
||||
'llm.base_url', llm_base_url,
|
||||
'verification.confirmation_mode', confirmation_mode,
|
||||
'verification.security_analyzer', security_analyzer,
|
||||
'condenser.enabled', enable_default_condenser,
|
||||
'condenser.max_size', condenser_max_size,
|
||||
'max_iterations', max_iterations
|
||||
) || COALESCE(agent_settings::jsonb, '{}'::jsonb)
|
||||
)::json
|
||||
"""
|
||||
)
|
||||
)
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE org_member
|
||||
SET agent_settings = jsonb_strip_nulls(
|
||||
jsonb_build_object(
|
||||
'schema_version', 1,
|
||||
'llm.model', llm_model,
|
||||
'llm.base_url', llm_base_url,
|
||||
'max_iterations', max_iterations
|
||||
) || COALESCE(agent_settings::jsonb, '{}'::jsonb)
|
||||
)::json
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
op.alter_column('user_settings', 'agent_settings', server_default=None)
|
||||
op.alter_column('org_member', 'agent_settings', server_default=None)
|
||||
op.drop_column('user_settings', 'agent')
|
||||
op.drop_column('user_settings', 'max_iterations')
|
||||
op.drop_column('user_settings', 'security_analyzer')
|
||||
op.drop_column('user_settings', 'confirmation_mode')
|
||||
op.drop_column('user_settings', 'llm_model')
|
||||
op.drop_column('user_settings', 'llm_base_url')
|
||||
op.drop_column('user_settings', 'enable_default_condenser')
|
||||
op.drop_column('user_settings', 'condenser_max_size')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column('user_settings', sa.Column('agent', sa.String(), nullable=True))
|
||||
op.add_column(
|
||||
'user_settings', sa.Column('max_iterations', sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
'user_settings', sa.Column('security_analyzer', sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
'user_settings', sa.Column('confirmation_mode', sa.Boolean(), nullable=True)
|
||||
)
|
||||
op.add_column('user_settings', sa.Column('llm_model', sa.String(), nullable=True))
|
||||
op.add_column(
|
||||
'user_settings', sa.Column('llm_base_url', sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'enable_default_condenser',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
'user_settings', sa.Column('condenser_max_size', sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE user_settings
|
||||
SET
|
||||
agent = agent_settings ->> 'agent',
|
||||
max_iterations = NULLIF(agent_settings ->> 'max_iterations', '')::integer,
|
||||
security_analyzer =
|
||||
agent_settings ->> 'verification.security_analyzer',
|
||||
confirmation_mode = CASE
|
||||
WHEN agent_settings::jsonb ? 'verification.confirmation_mode'
|
||||
THEN (agent_settings ->> 'verification.confirmation_mode')::boolean
|
||||
ELSE NULL
|
||||
END,
|
||||
llm_model = agent_settings ->> 'llm.model',
|
||||
llm_base_url = agent_settings ->> 'llm.base_url',
|
||||
enable_default_condenser = CASE
|
||||
WHEN agent_settings::jsonb ? 'condenser.enabled'
|
||||
THEN (agent_settings ->> 'condenser.enabled')::boolean
|
||||
ELSE TRUE
|
||||
END,
|
||||
condenser_max_size =
|
||||
NULLIF(agent_settings ->> 'condenser.max_size', '')::integer
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
op.drop_column('org_member', 'agent_settings')
|
||||
op.drop_column('user_settings', 'agent_settings')
|
||||
Generated
+1933
-1820
File diff suppressed because it is too large
Load Diff
@@ -46,6 +46,7 @@ from server.routes.org_invitations import ( # noqa: E402
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.service import service_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
@@ -112,6 +113,7 @@ if GITLAB_APP_CLIENT_ID:
|
||||
base_app.include_router(gitlab_integration_router)
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(service_router) # Add routes for internal service API
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(
|
||||
verified_models_router
|
||||
|
||||
@@ -35,7 +35,7 @@ Usage:
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
@@ -214,6 +214,19 @@ def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
return permission in permissions
|
||||
|
||||
|
||||
async def get_api_key_org_id_from_request(request: Request) -> UUID | None:
|
||||
"""Get the org_id bound to the API key used for authentication.
|
||||
|
||||
Returns None if:
|
||||
- Not authenticated via API key (cookie auth)
|
||||
- API key is a legacy key without org binding
|
||||
"""
|
||||
user_auth = getattr(request.state, 'user_auth', None)
|
||||
if user_auth and hasattr(user_auth, 'get_api_key_org_id'):
|
||||
return user_auth.get_api_key_org_id()
|
||||
return None
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
@@ -221,8 +234,9 @@ def require_permission(permission: Permission):
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
3. Validates API key org binding (if using API key auth)
|
||||
4. Checks if the user has the required permission in the organization
|
||||
5. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
@@ -240,6 +254,7 @@ def require_permission(permission: Permission):
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
request: Request,
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
@@ -249,6 +264,23 @@ def require_permission(permission: Permission):
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
# Validate API key organization binding
|
||||
api_key_org_id = await get_api_key_org_id_from_request(request)
|
||||
if api_key_org_id is not None and org_id is not None:
|
||||
if api_key_org_id != org_id:
|
||||
logger.warning(
|
||||
'API key organization mismatch',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'api_key_org_id': str(api_key_org_id),
|
||||
'target_org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='API key is not authorized for this organization',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
@@ -59,6 +60,19 @@ class SaasUserAuth(UserAuth):
|
||||
_secrets: Secrets | None = None
|
||||
accepted_tos: bool | None = None
|
||||
auth_type: AuthType = AuthType.COOKIE
|
||||
# API key context fields - populated when authenticated via API key
|
||||
api_key_org_id: UUID | None = None # Org bound to the API key used for auth
|
||||
api_key_id: int | None = None
|
||||
api_key_name: str | None = None
|
||||
|
||||
def get_api_key_org_id(self) -> UUID | None:
|
||||
"""Get the organization ID bound to the API key used for authentication.
|
||||
|
||||
Returns:
|
||||
The org_id if authenticated via API key with org binding, None otherwise
|
||||
(cookie auth or legacy API keys without org binding).
|
||||
"""
|
||||
return self.api_key_org_id
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return self.user_id
|
||||
@@ -283,14 +297,19 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
return None
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
if not user_id:
|
||||
validation_result = await api_key_store.validate_api_key(api_key)
|
||||
if not validation_result:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
offline_token = await token_manager.load_offline_token(
|
||||
validation_result.user_id
|
||||
)
|
||||
saas_user_auth = SaasUserAuth(
|
||||
user_id=user_id,
|
||||
user_id=validation_result.user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=validation_result.org_id,
|
||||
api_key_id=validation_result.key_id,
|
||||
api_key_name=validation_result.key_name,
|
||||
)
|
||||
await saas_user_auth.refresh()
|
||||
return saas_user_auth
|
||||
|
||||
@@ -182,6 +182,10 @@ class SetAuthCookieMiddleware:
|
||||
if path.startswith('/api/v1/webhooks/'):
|
||||
return False
|
||||
|
||||
# Service API uses its own authentication (X-Service-API-Key header)
|
||||
if path.startswith('/api/service/'):
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
@@ -11,7 +13,8 @@ from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.user_auth import get_user_auth, get_user_id
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
@@ -150,6 +153,16 @@ class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class CurrentApiKeyResponse(BaseModel):
|
||||
"""Response model for the current API key endpoint."""
|
||||
|
||||
id: int
|
||||
name: str | None
|
||||
org_id: str
|
||||
user_id: str
|
||||
auth_type: str
|
||||
|
||||
|
||||
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
|
||||
"""Convert an ApiKey model to an ApiKeyResponse."""
|
||||
return ApiKeyResponse(
|
||||
@@ -262,6 +275,46 @@ async def delete_api_key(
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/current', tags=['Keys'])
|
||||
async def get_current_api_key(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CurrentApiKeyResponse:
|
||||
"""Get information about the currently authenticated API key.
|
||||
|
||||
This endpoint returns metadata about the API key used for the current request,
|
||||
including the org_id associated with the key. This is useful for API key
|
||||
callers who need to know which organization context their key operates in.
|
||||
|
||||
Returns 400 if not authenticated via API key (e.g., using cookie auth).
|
||||
"""
|
||||
user_auth = await get_user_auth(request)
|
||||
|
||||
# Check if authenticated via API key
|
||||
if user_auth.get_auth_type() != AuthType.BEARER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This endpoint requires API key authentication. Not available for cookie-based auth.',
|
||||
)
|
||||
|
||||
# In SaaS context, bearer auth always produces SaasUserAuth
|
||||
saas_user_auth = cast(SaasUserAuth, user_auth)
|
||||
|
||||
if saas_user_auth.api_key_org_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This API key was created before organization support. Please regenerate your API key to use this endpoint.',
|
||||
)
|
||||
|
||||
return CurrentApiKeyResponse(
|
||||
id=saas_user_auth.api_key_id,
|
||||
name=saas_user_auth.api_key_name,
|
||||
org_id=str(saas_user_auth.api_key_org_id),
|
||||
user_id=user_id,
|
||||
auth_type=saas_user_auth.auth_type.value,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor', tags=['Keys'])
|
||||
async def get_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
|
||||
@@ -68,7 +68,7 @@ async def list_user_orgs(
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
Query(title='The max number of results in the page', gt=0, le=100),
|
||||
] = 100,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgPage:
|
||||
@@ -734,7 +734,7 @@ async def get_org_members(
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
le=100,
|
||||
),
|
||||
] = 10,
|
||||
email: Annotated[
|
||||
|
||||
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Service API routes for internal service-to-service communication.
|
||||
|
||||
This module provides endpoints for trusted internal services (e.g., automations service)
|
||||
to perform privileged operations like creating API keys on behalf of users.
|
||||
|
||||
Authentication is via a shared secret (X-Service-API-Key header) configured
|
||||
through the AUTOMATIONS_SERVICE_API_KEY environment variable.
|
||||
"""
|
||||
|
||||
import os
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Environment variable for the service API key
|
||||
AUTOMATIONS_SERVICE_API_KEY = os.getenv('AUTOMATIONS_SERVICE_API_KEY', '').strip()
|
||||
|
||||
service_router = APIRouter(prefix='/api/service', tags=['Service'])
|
||||
|
||||
|
||||
class CreateUserApiKeyRequest(BaseModel):
|
||||
"""Request model for creating an API key on behalf of a user."""
|
||||
|
||||
name: str # Required - used to identify the key
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError('name is required and cannot be empty')
|
||||
return v.strip()
|
||||
|
||||
|
||||
class CreateUserApiKeyResponse(BaseModel):
|
||||
"""Response model for created API key."""
|
||||
|
||||
key: str
|
||||
user_id: str
|
||||
org_id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ServiceInfoResponse(BaseModel):
|
||||
"""Response model for service info endpoint."""
|
||||
|
||||
service: str
|
||||
authenticated: bool
|
||||
|
||||
|
||||
async def validate_service_api_key(
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> str:
|
||||
"""
|
||||
Validate the service API key from the request header.
|
||||
|
||||
Args:
|
||||
x_service_api_key: The service API key from the X-Service-API-Key header
|
||||
|
||||
Returns:
|
||||
str: Service identifier for audit logging
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if key is missing or invalid
|
||||
HTTPException: 503 if service auth is not configured
|
||||
"""
|
||||
if not AUTOMATIONS_SERVICE_API_KEY:
|
||||
logger.warning(
|
||||
'Service authentication not configured (AUTOMATIONS_SERVICE_API_KEY not set)'
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail='Service authentication not configured',
|
||||
)
|
||||
|
||||
if not x_service_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='X-Service-API-Key header is required',
|
||||
)
|
||||
|
||||
if x_service_api_key != AUTOMATIONS_SERVICE_API_KEY:
|
||||
logger.warning('Invalid service API key attempted')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Invalid service API key',
|
||||
)
|
||||
|
||||
return 'automations-service'
|
||||
|
||||
|
||||
@service_router.get('/health')
|
||||
async def service_health() -> dict:
|
||||
"""Health check endpoint for the service API.
|
||||
|
||||
This endpoint does not require authentication and can be used
|
||||
to verify the service routes are accessible.
|
||||
"""
|
||||
return {
|
||||
'status': 'ok',
|
||||
'service_auth_configured': bool(AUTOMATIONS_SERVICE_API_KEY),
|
||||
}
|
||||
|
||||
|
||||
@service_router.post('/users/{user_id}/orgs/{org_id}/api-keys')
|
||||
async def get_or_create_api_key_for_user(
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
request: CreateUserApiKeyRequest,
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> CreateUserApiKeyResponse:
|
||||
"""
|
||||
Get or create an API key for a user on behalf of the automations service.
|
||||
|
||||
If a key with the given name already exists for the user/org and is not expired,
|
||||
returns the existing key. Otherwise, creates a new key.
|
||||
|
||||
The created/returned keys are system keys and are:
|
||||
- Not visible to the user in their API keys list
|
||||
- Not deletable by the user
|
||||
- Never expire
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
org_id: The organization ID
|
||||
request: Request body containing name (required)
|
||||
x_service_api_key: Service API key header for authentication
|
||||
|
||||
Returns:
|
||||
CreateUserApiKeyResponse: The API key and metadata
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if service key is invalid
|
||||
HTTPException: 404 if user not found
|
||||
HTTPException: 403 if user is not a member of the specified org
|
||||
"""
|
||||
# Validate service API key
|
||||
service_id = await validate_service_api_key(x_service_api_key)
|
||||
|
||||
# Verify user exists
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
logger.warning(
|
||||
'Service attempted to create key for non-existent user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'User {user_id} not found',
|
||||
)
|
||||
|
||||
# Verify user is a member of the specified org
|
||||
org_member = await OrgMemberStore.get_org_member(org_id, UUID(user_id))
|
||||
if not org_member:
|
||||
logger.warning(
|
||||
'Service attempted to create key for user not in org',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'User {user_id} is not a member of org {org_id}',
|
||||
)
|
||||
|
||||
# Get or create the system API key
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
|
||||
try:
|
||||
api_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=request.name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Failed to get or create system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to get or create API key',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Service created API key for user',
|
||||
extra={
|
||||
'service_id': service_id,
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': request.name,
|
||||
},
|
||||
)
|
||||
|
||||
return CreateUserApiKeyResponse(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=str(org_id),
|
||||
name=request.name,
|
||||
)
|
||||
|
||||
|
||||
@service_router.delete('/users/{user_id}/orgs/{org_id}/api-keys/{key_name}')
|
||||
async def delete_user_api_key(
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
key_name: str,
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete a system API key created by the service.
|
||||
|
||||
This endpoint allows the automations service to clean up API keys
|
||||
it previously created for users.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
org_id: The organization ID
|
||||
key_name: The name of the key to delete (without __SYSTEM__: prefix)
|
||||
x_service_api_key: Service API key header for authentication
|
||||
|
||||
Returns:
|
||||
dict: Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if service key is invalid
|
||||
HTTPException: 404 if key not found
|
||||
"""
|
||||
# Validate service API key
|
||||
service_id = await validate_service_api_key(x_service_api_key)
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
|
||||
# Delete the key by name (wrap with system key prefix since service creates system keys)
|
||||
system_key_name = api_key_store.make_system_key_name(key_name)
|
||||
success = await api_key_store.delete_api_key_by_name(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
allow_system=True,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'API key with name "{key_name}" not found for user {user_id} in org {org_id}',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Service deleted API key for user',
|
||||
extra={
|
||||
'service_id': service_id,
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': key_name,
|
||||
},
|
||||
)
|
||||
|
||||
return {'message': 'API key deleted successfully'}
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
@@ -60,7 +60,7 @@ async def search_shared_conversations(
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
le=100,
|
||||
),
|
||||
] = 100,
|
||||
include_sub_conversations: Annotated[
|
||||
@@ -72,8 +72,6 @@ async def search_shared_conversations(
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> SharedConversationPage:
|
||||
"""Search / List shared conversations."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_conversation_service.search_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
@@ -127,7 +125,11 @@ async def batch_get_shared_conversations(
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
|
||||
assert len(ids) <= 100
|
||||
if len(ids) > 100:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Cannot request more than 100 conversations at once, got {len(ids)}',
|
||||
)
|
||||
uuids = [UUID(id_) for id_ in ids]
|
||||
shared_conversation_info = (
|
||||
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
@@ -77,13 +77,11 @@ async def search_shared_events(
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
Query(title='The max number of results in the page', gt=0, le=100),
|
||||
] = 100,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> EventPage:
|
||||
"""Search / List events for a shared conversation."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_event_service.search_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
@@ -134,7 +132,11 @@ async def batch_get_shared_events(
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> list[Event | None]:
|
||||
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
|
||||
assert len(id) <= 100
|
||||
if len(id) > 100:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Cannot request more than 100 events at once, got {len(id)}',
|
||||
)
|
||||
event_ids = [UUID(id_) for id_ in id]
|
||||
events = await shared_event_service.batch_get_shared_events(
|
||||
UUID(conversation_id), event_ids
|
||||
|
||||
@@ -4,6 +4,7 @@ import secrets
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from storage.api_key import ApiKey
|
||||
@@ -13,9 +14,22 @@ from storage.user_store import UserStore
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKeyValidationResult:
|
||||
"""Result of API key validation containing user and organization info."""
|
||||
|
||||
user_id: str
|
||||
org_id: UUID | None # None for legacy API keys without org binding
|
||||
key_id: int
|
||||
key_name: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKeyStore:
|
||||
API_KEY_PREFIX = 'sk-oh-'
|
||||
# Prefix for system keys created by internal services (e.g., automations)
|
||||
# Keys with this prefix are hidden from users and cannot be deleted by users
|
||||
SYSTEM_KEY_NAME_PREFIX = '__SYSTEM__:'
|
||||
|
||||
def generate_api_key(self, length: int = 32) -> str:
|
||||
"""Generate a random API key with the sk-oh- prefix."""
|
||||
@@ -23,6 +37,19 @@ class ApiKeyStore:
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{self.API_KEY_PREFIX}{random_part}'
|
||||
|
||||
@classmethod
|
||||
def is_system_key_name(cls, name: str | None) -> bool:
|
||||
"""Check if a key name indicates a system key."""
|
||||
return name is not None and name.startswith(cls.SYSTEM_KEY_NAME_PREFIX)
|
||||
|
||||
@classmethod
|
||||
def make_system_key_name(cls, name: str) -> str:
|
||||
"""Create a system key name with the appropriate prefix.
|
||||
|
||||
Format: __SYSTEM__:<name>
|
||||
"""
|
||||
return f'{cls.SYSTEM_KEY_NAME_PREFIX}{name}'
|
||||
|
||||
async def create_api_key(
|
||||
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
|
||||
) -> str:
|
||||
@@ -60,8 +87,120 @@ class ApiKeyStore:
|
||||
|
||||
return api_key
|
||||
|
||||
async def validate_api_key(self, api_key: str) -> str | None:
|
||||
"""Validate an API key and return the associated user_id if valid."""
|
||||
async def get_or_create_system_api_key(
|
||||
self,
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
name: str,
|
||||
) -> str:
|
||||
"""Get or create a system API key for a user on behalf of an internal service.
|
||||
|
||||
If a key with the given name already exists for this user/org and is not expired,
|
||||
returns the existing key. Otherwise, creates a new key (and deletes any expired one).
|
||||
|
||||
System keys are:
|
||||
- Not visible to users in their API keys list (filtered by name prefix)
|
||||
- Not deletable by users (protected by name prefix check)
|
||||
- Associated with a specific org (not the user's current org)
|
||||
- Never expire (no expiration date)
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to create the key for
|
||||
org_id: The organization ID to associate the key with
|
||||
name: Required name for the key (will be prefixed with __SYSTEM__:)
|
||||
|
||||
Returns:
|
||||
The API key (existing or newly created)
|
||||
"""
|
||||
# Create system key name with prefix
|
||||
system_key_name = self.make_system_key_name(name)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
# Check if key already exists for this user/org/name
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id,
|
||||
ApiKey.org_id == org_id,
|
||||
ApiKey.name == system_key_name,
|
||||
)
|
||||
)
|
||||
existing_key = result.scalars().first()
|
||||
|
||||
if existing_key:
|
||||
# Check if expired
|
||||
if existing_key.expires_at:
|
||||
now = datetime.now(UTC)
|
||||
expires_at = existing_key.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
# Key is expired, delete it and create new one
|
||||
logger.info(
|
||||
'System API key expired, re-issuing',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
await session.delete(existing_key)
|
||||
await session.commit()
|
||||
else:
|
||||
# Key exists and is not expired, return it
|
||||
logger.debug(
|
||||
'Returning existing system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
return existing_key.key
|
||||
else:
|
||||
# Key exists and has no expiration, return it
|
||||
logger.debug(
|
||||
'Returning existing system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
return existing_key.key
|
||||
|
||||
# Create new key (no expiration)
|
||||
api_key = self.generate_api_key()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
expires_at=None, # System keys never expire
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'Created system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None:
|
||||
"""Validate an API key and return the associated user_id and org_id if valid.
|
||||
|
||||
Returns:
|
||||
ApiKeyValidationResult if the key is valid, None otherwise.
|
||||
The org_id may be None for legacy API keys that weren't bound to an organization.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
@@ -89,7 +228,12 @@ class ApiKeyStore:
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return key_record.user_id
|
||||
return ApiKeyValidationResult(
|
||||
user_id=key_record.user_id,
|
||||
org_id=key_record.org_id,
|
||||
key_id=key_record.id,
|
||||
key_name=key_record.name,
|
||||
)
|
||||
|
||||
async def delete_api_key(self, api_key: str) -> bool:
|
||||
"""Delete an API key by the key value."""
|
||||
@@ -105,8 +249,18 @@ class ApiKeyStore:
|
||||
|
||||
return True
|
||||
|
||||
async def delete_api_key_by_id(self, key_id: int) -> bool:
|
||||
"""Delete an API key by its ID."""
|
||||
async def delete_api_key_by_id(
|
||||
self, key_id: int, allow_system: bool = False
|
||||
) -> bool:
|
||||
"""Delete an API key by its ID.
|
||||
|
||||
Args:
|
||||
key_id: The ID of the key to delete
|
||||
allow_system: If False (default), system keys cannot be deleted
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found or is a protected system key
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
@@ -114,13 +268,26 @@ class ApiKeyStore:
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
# Protect system keys from deletion unless explicitly allowed
|
||||
if self.is_system_key_name(key_record.name) and not allow_system:
|
||||
logger.warning(
|
||||
'Attempted to delete system API key',
|
||||
extra={'key_id': key_id, 'user_id': key_record.user_id},
|
||||
)
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
|
||||
"""List all API keys for a user."""
|
||||
"""List all user-visible API keys for a user.
|
||||
|
||||
This excludes:
|
||||
- System keys (name starts with __SYSTEM__:) - created by internal services
|
||||
- MCP_API_KEY - internal MCP key
|
||||
"""
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f'User not found: {user_id}')
|
||||
@@ -129,11 +296,17 @@ class ApiKeyStore:
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id, ApiKey.org_id == org_id
|
||||
ApiKey.user_id == user_id,
|
||||
ApiKey.org_id == org_id,
|
||||
)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
# Filter out system keys and MCP_API_KEY
|
||||
return [
|
||||
key
|
||||
for key in keys
|
||||
if key.name != 'MCP_API_KEY' and not self.is_system_key_name(key.name)
|
||||
]
|
||||
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
@@ -163,17 +336,44 @@ class ApiKeyStore:
|
||||
key_record = result.scalars().first()
|
||||
return key_record.key if key_record else None
|
||||
|
||||
async def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
"""Delete an API key by name for a specific user."""
|
||||
async def delete_api_key_by_name(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
org_id: UUID | None = None,
|
||||
allow_system: bool = False,
|
||||
) -> bool:
|
||||
"""Delete an API key by name for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user whose key to delete
|
||||
name: The name of the key to delete
|
||||
org_id: Optional organization ID to filter by (required for system keys)
|
||||
allow_system: If False (default), system keys cannot be deleted
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found or is a protected system key
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
)
|
||||
# Build the query filters
|
||||
filters = [ApiKey.user_id == user_id, ApiKey.name == name]
|
||||
if org_id is not None:
|
||||
filters.append(ApiKey.org_id == org_id)
|
||||
|
||||
result = await session.execute(select(ApiKey).filter(*filters))
|
||||
key_record = result.scalars().first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
# Protect system keys from deletion unless explicitly allowed
|
||||
if self.is_system_key_name(key_record.name) and not allow_system:
|
||||
logger.warning(
|
||||
'Attempted to delete system API key',
|
||||
extra={'user_id': user_id, 'key_name': name},
|
||||
)
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@@ -29,14 +29,37 @@ KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
|
||||
UNLIMITED_BUDGET_SETTING = 1000000000.0
|
||||
|
||||
try:
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', 0.0))
|
||||
if DEFAULT_INITIAL_BUDGET < 0:
|
||||
# Check if billing is enabled (defaults to false for enterprise deployments)
|
||||
ENABLE_BILLING = os.environ.get('ENABLE_BILLING', 'false').lower() == 'true'
|
||||
|
||||
|
||||
def _get_default_initial_budget() -> float | None:
|
||||
"""Get the default initial budget for new teams.
|
||||
|
||||
When billing is disabled (ENABLE_BILLING=false), returns None to disable
|
||||
budget enforcement in LiteLLM. When billing is enabled, returns the
|
||||
DEFAULT_INITIAL_BUDGET environment variable value (default 0.0).
|
||||
|
||||
Returns:
|
||||
float | None: The default budget, or None to disable budget enforcement.
|
||||
"""
|
||||
if not ENABLE_BILLING:
|
||||
return None
|
||||
|
||||
try:
|
||||
budget = float(os.environ.get('DEFAULT_INITIAL_BUDGET', 0.0))
|
||||
if budget < 0:
|
||||
raise ValueError(
|
||||
f'DEFAULT_INITIAL_BUDGET must be non-negative, got {budget}'
|
||||
)
|
||||
return budget
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f'DEFAULT_INITIAL_BUDGET must be non-negative, got {DEFAULT_INITIAL_BUDGET}'
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(f'Invalid DEFAULT_INITIAL_BUDGET environment variable: {e}') from e
|
||||
f'Invalid DEFAULT_INITIAL_BUDGET environment variable: {e}'
|
||||
) from e
|
||||
|
||||
|
||||
DEFAULT_INITIAL_BUDGET: float | None = _get_default_initial_budget()
|
||||
|
||||
|
||||
def get_openhands_cloud_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
@@ -110,12 +133,15 @@ class LiteLlmManager:
|
||||
) as client:
|
||||
# Check if team already exists and get its budget
|
||||
# New users joining existing orgs should inherit the team's budget
|
||||
team_budget: float = DEFAULT_INITIAL_BUDGET
|
||||
# When billing is disabled, DEFAULT_INITIAL_BUDGET is None
|
||||
team_budget: float | None = DEFAULT_INITIAL_BUDGET
|
||||
try:
|
||||
existing_team = await LiteLlmManager._get_team(client, org_id)
|
||||
if existing_team:
|
||||
team_info = existing_team.get('team_info', {})
|
||||
team_budget = team_info.get('max_budget', 0.0) or 0.0
|
||||
# Preserve None from existing team (no budget enforcement)
|
||||
existing_budget = team_info.get('max_budget')
|
||||
team_budget = existing_budget
|
||||
logger.info(
|
||||
'LiteLlmManager:create_entries:existing_team_budget',
|
||||
extra={
|
||||
@@ -138,9 +164,33 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
if create_user:
|
||||
await LiteLlmManager._create_user(
|
||||
user_created = await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
if not user_created:
|
||||
logger.error(
|
||||
'create_entries_failed_user_creation',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Verify user exists before proceeding with key generation
|
||||
user_exists = await LiteLlmManager._user_exists(
|
||||
client, keycloak_user_id
|
||||
)
|
||||
if not user_exists:
|
||||
logger.error(
|
||||
'create_entries_user_not_found_before_key_generation',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'create_user_flag': create_user,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
@@ -304,10 +354,12 @@ class LiteLlmManager:
|
||||
# Check if the database key exists in LiteLLM
|
||||
# If not, generate a new key to prevent verification failures later
|
||||
db_key = None
|
||||
legacy_settings = user_settings.to_settings() if user_settings else None
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.llm_api_key
|
||||
and user_settings.llm_base_url == LITE_LLM_API_URL
|
||||
and legacy_settings
|
||||
and legacy_settings.llm_base_url == LITE_LLM_API_URL
|
||||
):
|
||||
db_key = user_settings.llm_api_key
|
||||
if hasattr(db_key, 'get_secret_value'):
|
||||
@@ -525,25 +577,40 @@ class LiteLlmManager:
|
||||
client: httpx.AsyncClient,
|
||||
team_alias: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
max_budget: float | None,
|
||||
):
|
||||
"""Create a new team in LiteLLM.
|
||||
|
||||
Args:
|
||||
client: The HTTP client to use.
|
||||
team_alias: The alias for the team.
|
||||
team_id: The ID for the team.
|
||||
max_budget: The maximum budget for the team. When None, budget
|
||||
enforcement is disabled (unlimited usage).
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/new',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Team failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
@@ -620,15 +687,48 @@ class LiteLlmManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _user_exists(
|
||||
client: httpx.AsyncClient,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""Check if a user exists in LiteLLM.
|
||||
|
||||
Returns True if the user exists, False otherwise.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return False
|
||||
try:
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
)
|
||||
if response.is_success:
|
||||
user_data = response.json()
|
||||
# Check that user_info exists and has the user_id
|
||||
user_info = user_data.get('user_info', {})
|
||||
return user_info.get('user_id') == user_id
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'litellm_user_exists_check_failed',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _create_user(
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
) -> bool:
|
||||
"""Create a user in LiteLLM.
|
||||
|
||||
Returns True if the user was created or already exists and is verified,
|
||||
False if creation failed and user does not exist.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
return False
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
@@ -681,17 +781,33 @@ class LiteLlmManager:
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
return
|
||||
# Verify the user actually exists before returning success
|
||||
user_exists = await LiteLlmManager._user_exists(
|
||||
client, keycloak_user_id
|
||||
)
|
||||
if not user_exists:
|
||||
logger.error(
|
||||
'litellm_user_claimed_exists_but_not_found',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
return False
|
||||
return True
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'user_id': keycloak_user_id,
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
return False
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
|
||||
@@ -918,19 +1034,34 @@ class LiteLlmManager:
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
max_budget: float | None,
|
||||
):
|
||||
"""Add a user to a team in LiteLLM.
|
||||
|
||||
Args:
|
||||
client: The HTTP client to use.
|
||||
keycloak_user_id: The user's Keycloak ID.
|
||||
team_id: The team ID.
|
||||
max_budget: The maximum budget for the user in the team. When None,
|
||||
budget enforcement is disabled (unlimited usage).
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget_in_team'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_add',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Failed to add user to team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
@@ -998,19 +1129,34 @@ class LiteLlmManager:
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
max_budget: float | None,
|
||||
):
|
||||
"""Update a user's budget in a team.
|
||||
|
||||
Args:
|
||||
client: The HTTP client to use.
|
||||
keycloak_user_id: The user's Keycloak ID.
|
||||
team_id: The team ID.
|
||||
max_budget: The maximum budget for the user in the team. When None,
|
||||
budget enforcement is disabled (unlimited usage).
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget_in_team'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_update',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Failed to update user in team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
@@ -1397,6 +1543,7 @@ class LiteLlmManager:
|
||||
create_team = staticmethod(with_http_client(_create_team))
|
||||
get_team = staticmethod(with_http_client(_get_team))
|
||||
update_team = staticmethod(with_http_client(_update_team))
|
||||
user_exists = staticmethod(with_http_client(_user_exists))
|
||||
create_user = staticmethod(with_http_client(_create_user))
|
||||
get_user = staticmethod(with_http_client(_get_user))
|
||||
update_user = staticmethod(with_http_client(_update_user))
|
||||
|
||||
@@ -3,7 +3,7 @@ SQLAlchemy model for Organization-Member relationship.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy import JSON, UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
@@ -22,6 +22,8 @@ class OrgMember(Base): # type: ignore
|
||||
llm_model = Column(String, nullable=True)
|
||||
_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
agent_settings = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
status = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
|
||||
@@ -17,6 +17,14 @@ from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
# Only these agent_settings keys are stored per member; org-wide settings live on Org.
|
||||
_MEMBER_SCOPED_AGENT_SETTINGS_KEYS = {
|
||||
'schema_version',
|
||||
'llm.model',
|
||||
'llm.base_url',
|
||||
'max_iterations',
|
||||
}
|
||||
|
||||
|
||||
class OrgMemberStore:
|
||||
"""Store for managing organization-member relationships."""
|
||||
@@ -159,12 +167,21 @@ class OrgMemberStore:
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
settings = user_settings.to_settings()
|
||||
return {
|
||||
'llm_api_key': user_settings.llm_api_key,
|
||||
'llm_model': settings.llm_model,
|
||||
'llm_api_key_for_byor': user_settings.llm_api_key_for_byor,
|
||||
'llm_base_url': settings.llm_base_url,
|
||||
'max_iterations': settings.max_iterations,
|
||||
'agent_settings': {
|
||||
key: value
|
||||
for key, value in settings.normalized_agent_settings(
|
||||
strip_secret_values=True
|
||||
).items()
|
||||
if key in _MEMBER_SCOPED_AGENT_SETTINGS_KEYS
|
||||
},
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_count(
|
||||
|
||||
@@ -212,26 +212,30 @@ class OrgStore:
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(user_settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(user_settings, normalized)
|
||||
|
||||
kwargs['org_version'] = user_settings.user_version
|
||||
return kwargs
|
||||
settings = user_settings.to_settings()
|
||||
return {
|
||||
'agent': settings.agent,
|
||||
'default_max_iterations': settings.max_iterations,
|
||||
'security_analyzer': settings.security_analyzer,
|
||||
'confirmation_mode': settings.confirmation_mode,
|
||||
'default_llm_model': settings.llm_model,
|
||||
'default_llm_base_url': settings.llm_base_url,
|
||||
'remote_runtime_resource_factor': user_settings.remote_runtime_resource_factor,
|
||||
'enable_default_condenser': settings.enable_default_condenser,
|
||||
'billing_margin': user_settings.billing_margin,
|
||||
'enable_proactive_conversation_starters': user_settings.enable_proactive_conversation_starters,
|
||||
'sandbox_base_container_image': user_settings.sandbox_base_container_image,
|
||||
'sandbox_runtime_container_image': user_settings.sandbox_runtime_container_image,
|
||||
'org_version': user_settings.user_version,
|
||||
'mcp_config': user_settings.mcp_config,
|
||||
'search_api_key': user_settings.search_api_key,
|
||||
'sandbox_api_key': user_settings.sandbox_api_key,
|
||||
'max_budget_per_task': user_settings.max_budget_per_task,
|
||||
'enable_solvability_analysis': user_settings.enable_solvability_analysis,
|
||||
'v1_enabled': user_settings.v1_enabled,
|
||||
'condenser_max_size': settings.condenser_max_size,
|
||||
'sandbox_grouping_strategy': user_settings.sandbox_grouping_strategy,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def persist_org_with_owner(
|
||||
|
||||
@@ -15,25 +15,27 @@ class SaasConversationValidator(ConversationValidator):
|
||||
|
||||
async def _validate_api_key(self, api_key: str) -> str | None:
|
||||
"""
|
||||
Validate an API key and return the user_id and github_user_id if valid.
|
||||
Validate an API key and return the user_id if valid.
|
||||
|
||||
Args:
|
||||
api_key: The API key to validate
|
||||
|
||||
Returns:
|
||||
A tuple of (user_id, github_user_id) if the API key is valid, None otherwise
|
||||
The user_id if the API key is valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
token_manager = TokenManager()
|
||||
|
||||
# Validate the API key and get the user_id
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
validation_result = await api_key_store.validate_api_key(api_key)
|
||||
|
||||
if not user_id:
|
||||
if not validation_result:
|
||||
logger.warning('Invalid API key')
|
||||
return None
|
||||
|
||||
user_id = validation_result.user_id
|
||||
|
||||
# Get the offline token for the user
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
if not offline_token:
|
||||
|
||||
@@ -59,12 +59,15 @@ class SaasSecretsStore(SecretsStore):
|
||||
|
||||
async with a_session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
await session.execute(
|
||||
delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
# Delete existing records for this user AND organization only
|
||||
delete_query = delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
if org_id is not None:
|
||||
delete_query = delete_query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
else:
|
||||
delete_query = delete_query.filter(StoredCustomSecrets.org_id.is_(None))
|
||||
await session.execute(delete_query)
|
||||
|
||||
# Prepare the new secrets data
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
|
||||
@@ -28,6 +28,14 @@ from openhands.server.settings import Settings
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.llm import is_openhands_model
|
||||
|
||||
# Only these agent_settings keys are persisted on org_member; org-wide values live on Org.
|
||||
_MEMBER_SCOPED_AGENT_SETTINGS_KEYS = {
|
||||
'schema_version',
|
||||
'llm.model',
|
||||
'llm.base_url',
|
||||
'max_iterations',
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaasSettingsStore(SettingsStore):
|
||||
@@ -69,6 +77,29 @@ class SaasSettingsStore(SettingsStore):
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def _member_scoped_agent_settings(agent_settings: dict) -> dict:
|
||||
return {
|
||||
key: value
|
||||
for key, value in agent_settings.items()
|
||||
if key in _MEMBER_SCOPED_AGENT_SETTINGS_KEYS
|
||||
}
|
||||
|
||||
async def _persist_agent_settings_async(
|
||||
self, org_id: uuid.UUID, agent_settings: dict
|
||||
) -> None:
|
||||
async with a_session_maker() as session:
|
||||
stmt = (
|
||||
update(OrgMember)
|
||||
.where(
|
||||
OrgMember.org_id == org_id,
|
||||
OrgMember.user_id == uuid.UUID(self.user_id),
|
||||
)
|
||||
.values(agent_settings=agent_settings)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
user = await UserStore.get_user_by_id(self.user_id)
|
||||
if not user:
|
||||
@@ -115,6 +146,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
kwargs['agent_settings'] = org_member.agent_settings or {}
|
||||
if org.v1_enabled is None:
|
||||
kwargs['v1_enabled'] = True
|
||||
# Apply default if sandbox_grouping_strategy is None in the database
|
||||
@@ -122,6 +154,11 @@ class SaasSettingsStore(SettingsStore):
|
||||
kwargs.pop('sandbox_grouping_strategy', None)
|
||||
|
||||
settings = Settings(**kwargs)
|
||||
persisted_agent_settings = self._member_scoped_agent_settings(
|
||||
settings.normalized_agent_settings(strip_secret_values=True)
|
||||
)
|
||||
if persisted_agent_settings != (org_member.agent_settings or {}):
|
||||
await self._persist_agent_settings_async(org_id, persisted_agent_settings)
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
@@ -185,22 +222,23 @@ class SaasSettingsStore(SettingsStore):
|
||||
)
|
||||
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
kwargs['agent_settings'] = self._member_scoped_agent_settings(
|
||||
item.normalized_agent_settings(strip_secret_values=True)
|
||||
)
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
# Map Settings fields to Org fields with 'default_' prefix
|
||||
# The generic loop above doesn't update these because Org uses
|
||||
# 'default_llm_model' not 'llm_model', etc.
|
||||
# Use exclude_unset to only update explicitly-set fields (allows clearing with null)
|
||||
settings_data = item.model_dump(exclude_unset=True)
|
||||
if 'llm_model' in settings_data:
|
||||
org.default_llm_model = settings_data['llm_model']
|
||||
if 'llm_base_url' in settings_data:
|
||||
org.default_llm_base_url = settings_data['llm_base_url']
|
||||
if 'max_iterations' in settings_data:
|
||||
org.default_max_iterations = settings_data['max_iterations']
|
||||
# Map explicitly provided SDK-managed settings onto Org defaults.
|
||||
# These values now live in item.agent_settings, so inspect the
|
||||
# dotted keys directly instead of relying on model_dump().
|
||||
if 'llm.model' in item.agent_settings:
|
||||
org.default_llm_model = item.llm_model
|
||||
if 'llm.base_url' in item.agent_settings:
|
||||
org.default_llm_base_url = item.llm_base_url
|
||||
if 'max_iterations' in item.agent_settings:
|
||||
org.default_max_iterations = item.max_iterations
|
||||
|
||||
# Propagate LLM settings to all org members
|
||||
# This ensures all members see the same LLM configuration when an admin saves
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from server.constants import DEFAULT_BILLING_MARGIN
|
||||
from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
@@ -8,17 +10,9 @@ class UserSettings(Base): # type: ignore
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=True, index=True)
|
||||
language = Column(String, nullable=True)
|
||||
agent = Column(String, nullable=True)
|
||||
max_iterations = Column(Integer, nullable=True)
|
||||
security_analyzer = Column(String, nullable=True)
|
||||
confirmation_mode = Column(Boolean, nullable=True, default=False)
|
||||
llm_model = Column(String, nullable=True)
|
||||
llm_api_key = Column(String, nullable=True)
|
||||
llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
remote_runtime_resource_factor = Column(Integer, nullable=True)
|
||||
enable_default_condenser = Column(Boolean, nullable=False, default=True)
|
||||
condenser_max_size = Column(Integer, nullable=True)
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
|
||||
enable_sound_notifications = Column(Boolean, nullable=True, default=False)
|
||||
@@ -40,6 +34,16 @@ class UserSettings(Base): # type: ignore
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
agent_settings = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
already_migrated = Column(
|
||||
Boolean, nullable=True, default=False
|
||||
) # False = not migrated, True = migrated
|
||||
|
||||
def to_settings(self):
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
return Settings(
|
||||
agent_settings=dict(self.agent_settings or {}),
|
||||
llm_api_key=self.llm_api_key,
|
||||
)
|
||||
|
||||
@@ -235,7 +235,7 @@ class UserStore:
|
||||
# if user has custom settings, set org defaults to current version
|
||||
if custom_settings:
|
||||
org_kwargs['default_llm_model'] = get_default_litellm_model()
|
||||
org_kwargs['llm_base_url'] = LITE_LLM_API_URL
|
||||
org_kwargs['default_llm_base_url'] = LITE_LLM_API_URL
|
||||
org_kwargs['org_version'] = ORG_SETTINGS_VERSION
|
||||
|
||||
for key, value in org_kwargs.items():
|
||||
@@ -975,19 +975,31 @@ class UserStore:
|
||||
'max_iterations', org_member.max_iterations
|
||||
)
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
agent_settings = Settings(
|
||||
agent=org.agent,
|
||||
llm_model=llm_model,
|
||||
llm_api_key=org_member.llm_api_key.get_secret_value()
|
||||
if org_member.llm_api_key
|
||||
else None,
|
||||
llm_base_url=llm_base_url,
|
||||
max_iterations=max_iterations,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
security_analyzer=org.security_analyzer,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
condenser_max_size=org.condenser_max_size,
|
||||
agent_settings=org_member.agent_settings or {},
|
||||
).normalized_agent_settings(strip_secret_values=True)
|
||||
|
||||
return UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
# OrgMember fields
|
||||
llm_api_key=org_member.llm_api_key.get_secret_value()
|
||||
if org_member.llm_api_key
|
||||
else None,
|
||||
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
|
||||
if org_member.llm_api_key_for_byor
|
||||
else None,
|
||||
llm_model=llm_model,
|
||||
llm_base_url=llm_base_url,
|
||||
max_iterations=max_iterations,
|
||||
# User fields
|
||||
accepted_tos=user.accepted_tos,
|
||||
enable_sound_notifications=user.enable_sound_notifications,
|
||||
language=user.language,
|
||||
@@ -996,12 +1008,7 @@ class UserStore:
|
||||
email_verified=user.email_verified,
|
||||
git_user_name=user.git_user_name,
|
||||
git_user_email=user.git_user_email,
|
||||
# Org fields
|
||||
agent=org.agent,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
@@ -1017,7 +1024,8 @@ class UserStore:
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
condenser_max_size=org.condenser_max_size,
|
||||
sandbox_grouping_strategy=org.sandbox_grouping_strategy,
|
||||
agent_settings=agent_settings,
|
||||
already_migrated=False,
|
||||
)
|
||||
|
||||
@@ -1035,15 +1043,12 @@ class UserStore:
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
user_settings.llm_model.strip() or None if user_settings.llm_model else None
|
||||
)
|
||||
settings = user_settings.to_settings()
|
||||
|
||||
user_model = settings.llm_model.strip() or None if settings.llm_model else None
|
||||
user_base_url = (
|
||||
user_settings.llm_base_url.strip() or None
|
||||
if user_settings.llm_base_url
|
||||
else None
|
||||
)
|
||||
settings.llm_base_url.strip() if settings.llm_base_url else None
|
||||
) or None
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
|
||||
@@ -0,0 +1,331 @@
|
||||
"""Unit tests for service API routes."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.service import (
|
||||
CreateUserApiKeyRequest,
|
||||
delete_user_api_key,
|
||||
get_or_create_api_key_for_user,
|
||||
validate_service_api_key,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateServiceApiKey:
|
||||
"""Test cases for validate_service_api_key."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_service_key(self):
|
||||
"""Test validation with valid service API key."""
|
||||
with patch(
|
||||
'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key'
|
||||
):
|
||||
result = await validate_service_api_key('test-service-key')
|
||||
assert result == 'automations-service'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_service_key(self):
|
||||
"""Test validation with missing service API key header."""
|
||||
with patch(
|
||||
'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key'
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'X-Service-API-Key header is required' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_service_key(self):
|
||||
"""Test validation with invalid service API key."""
|
||||
with patch(
|
||||
'server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-service-key'
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key('wrong-key')
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'Invalid service API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_auth_not_configured(self):
|
||||
"""Test validation when service auth is not configured."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', ''):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key('any-key')
|
||||
assert exc_info.value.status_code == 503
|
||||
assert 'Service authentication not configured' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestCreateUserApiKeyRequest:
|
||||
"""Test cases for CreateUserApiKeyRequest validation."""
|
||||
|
||||
def test_valid_request(self):
|
||||
"""Test valid request with all fields."""
|
||||
request = CreateUserApiKeyRequest(
|
||||
name='automation',
|
||||
)
|
||||
assert request.name == 'automation'
|
||||
|
||||
def test_name_is_required(self):
|
||||
"""Test that name field is required."""
|
||||
with pytest.raises(ValueError):
|
||||
CreateUserApiKeyRequest(
|
||||
name='', # Empty name should fail
|
||||
)
|
||||
|
||||
def test_name_is_stripped(self):
|
||||
"""Test that name field is stripped of whitespace."""
|
||||
request = CreateUserApiKeyRequest(
|
||||
name=' automation ',
|
||||
)
|
||||
assert request.name == 'automation'
|
||||
|
||||
def test_whitespace_only_name_fails(self):
|
||||
"""Test that whitespace-only name fails validation."""
|
||||
with pytest.raises(ValueError):
|
||||
CreateUserApiKeyRequest(
|
||||
name=' ',
|
||||
)
|
||||
|
||||
|
||||
class TestGetOrCreateApiKeyForUser:
|
||||
"""Test cases for get_or_create_api_key_for_user endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def valid_user_id(self):
|
||||
"""Return a valid user ID."""
|
||||
return '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
|
||||
@pytest.fixture
|
||||
def valid_org_id(self):
|
||||
"""Return a valid org ID."""
|
||||
return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
@pytest.fixture
|
||||
def valid_request(self):
|
||||
"""Create a valid request object."""
|
||||
return CreateUserApiKeyRequest(
|
||||
name='automation',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_not_found(self, valid_user_id, valid_org_id, valid_request):
|
||||
"""Test error when user doesn't exist."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = None
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'not found' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_not_in_org(self, valid_user_id, valid_org_id, valid_request):
|
||||
"""Test error when user is not a member of the org."""
|
||||
mock_user = MagicMock()
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = None
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member of org' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_key_creation(
|
||||
self, valid_user_id, valid_org_id, valid_request
|
||||
):
|
||||
"""Test successful API key creation."""
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.get_or_create_system_api_key = AsyncMock(
|
||||
return_value='sk-oh-test-key-12345678901234567890'
|
||||
)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = mock_org_member
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
response = await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert response.key == 'sk-oh-test-key-12345678901234567890'
|
||||
assert response.user_id == valid_user_id
|
||||
assert response.org_id == str(valid_org_id)
|
||||
assert response.name == 'automation'
|
||||
|
||||
# Verify the store was called with correct arguments
|
||||
mock_api_key_store.get_or_create_system_api_key.assert_called_once_with(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
name='automation',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_exception_handling(
|
||||
self, valid_user_id, valid_org_id, valid_request
|
||||
):
|
||||
"""Test error handling when store raises exception."""
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.get_or_create_system_api_key = AsyncMock(
|
||||
side_effect=Exception('Database error')
|
||||
)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = mock_org_member
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to get or create API key' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestDeleteUserApiKey:
|
||||
"""Test cases for delete_user_api_key endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def valid_org_id(self):
|
||||
"""Return a valid org ID."""
|
||||
return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_delete(self, valid_org_id):
|
||||
"""Test successful deletion of a system API key."""
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:automation'
|
||||
mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=True)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
response = await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert response == {'message': 'API key deleted successfully'}
|
||||
|
||||
# Verify the store was called with correct arguments
|
||||
mock_api_key_store.make_system_key_name.assert_called_once_with('automation')
|
||||
mock_api_key_store.delete_api_key_by_name.assert_called_once_with(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
allow_system=True,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_not_found(self, valid_org_id):
|
||||
"""Test error when key to delete is not found."""
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:nonexistent'
|
||||
mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=False)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='nonexistent',
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'not found' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_invalid_service_key(self, valid_org_id):
|
||||
"""Test error when service API key is invalid."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key='wrong-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'Invalid service API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_missing_service_key(self, valid_org_id):
|
||||
"""Test error when service API key header is missing."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_API_KEY', 'test-key'):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key=None,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'X-Service-API-Key header is required' in exc_info.value.detail
|
||||
@@ -1,19 +1,26 @@
|
||||
"""Unit tests for API keys routes, focusing on BYOR key validation and retrieval."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.api_keys import (
|
||||
ByorPermittedResponse,
|
||||
CurrentApiKeyResponse,
|
||||
LlmApiKeyResponse,
|
||||
check_byor_permitted,
|
||||
delete_byor_key_from_litellm,
|
||||
get_current_api_key,
|
||||
get_llm_api_key_for_byor,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
|
||||
|
||||
class TestVerifyByorKeyInLitellm:
|
||||
"""Test the verify_byor_key_in_litellm function."""
|
||||
@@ -512,3 +519,81 @@ class TestCheckByorPermitted:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to check BYOR export permission' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestGetCurrentApiKey:
|
||||
"""Test the get_current_api_key endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_api_key_info_for_bearer_auth(self, mock_get_user_auth):
|
||||
"""Test that API key metadata including org_id is returned for bearer token auth."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
org_id = uuid.uuid4()
|
||||
mock_request = MagicMock()
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
refresh_token=SecretStr('mock-token'),
|
||||
user_id=user_id,
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=org_id,
|
||||
api_key_id=42,
|
||||
api_key_name='My Production Key',
|
||||
)
|
||||
mock_get_user_auth.return_value = user_auth
|
||||
|
||||
# Act
|
||||
result = await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, CurrentApiKeyResponse)
|
||||
assert result.org_id == str(org_id)
|
||||
assert result.id == 42
|
||||
assert result.name == 'My Production Key'
|
||||
assert result.user_id == user_id
|
||||
assert result.auth_type == 'bearer'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_400_for_cookie_auth(self, mock_get_user_auth):
|
||||
"""Test that 400 Bad Request is returned when using cookie authentication."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_request = MagicMock()
|
||||
|
||||
mock_user_auth = MagicMock()
|
||||
mock_user_auth.get_auth_type.return_value = AuthType.COOKIE
|
||||
mock_get_user_auth.return_value = mock_user_auth
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'API key authentication' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_400_when_api_key_org_id_is_none(self, mock_get_user_auth):
|
||||
"""Test that 400 is returned when API key has no org_id (legacy key)."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_request = MagicMock()
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
refresh_token=SecretStr('mock-token'),
|
||||
user_id=user_id,
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=None, # No org_id - legacy key
|
||||
api_key_id=42,
|
||||
api_key_name='Legacy Key',
|
||||
)
|
||||
mock_get_user_auth.return_value = user_auth
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'created before organization support' in exc_info.value.detail
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
"""Unit tests for ApiKeyStore system key functionality."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store():
|
||||
"""Create ApiKeyStore instance."""
|
||||
return ApiKeyStore()
|
||||
|
||||
|
||||
class TestApiKeyStoreSystemKeys:
|
||||
"""Test cases for system API key functionality."""
|
||||
|
||||
def test_is_system_key_name_with_prefix(self, api_key_store):
|
||||
"""Test that names with __SYSTEM__: prefix are identified as system keys."""
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:automation') is True
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:test-key') is True
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:') is True
|
||||
|
||||
def test_is_system_key_name_without_prefix(self, api_key_store):
|
||||
"""Test that names without __SYSTEM__: prefix are not system keys."""
|
||||
assert api_key_store.is_system_key_name('my-key') is False
|
||||
assert api_key_store.is_system_key_name('automation') is False
|
||||
assert api_key_store.is_system_key_name('MCP_API_KEY') is False
|
||||
assert api_key_store.is_system_key_name('') is False
|
||||
|
||||
def test_is_system_key_name_none(self, api_key_store):
|
||||
"""Test that None is not a system key."""
|
||||
assert api_key_store.is_system_key_name(None) is False
|
||||
|
||||
def test_make_system_key_name(self, api_key_store):
|
||||
"""Test system key name generation."""
|
||||
assert (
|
||||
api_key_store.make_system_key_name('automation') == '__SYSTEM__:automation'
|
||||
)
|
||||
assert api_key_store.make_system_key_name('test-key') == '__SYSTEM__:test-key'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_creates_new(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test creating a new system API key when none exists."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
api_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert api_key.startswith('sk-oh-')
|
||||
assert len(api_key) == len('sk-oh-') + 32
|
||||
|
||||
# Verify the key was created in the database
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is not None
|
||||
assert key_record.user_id == user_id
|
||||
assert key_record.org_id == org_id
|
||||
assert key_record.name == '__SYSTEM__:automation'
|
||||
assert key_record.expires_at is None # System keys never expire
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_returns_existing(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that existing valid system key is returned."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Create the first key
|
||||
first_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
# Request again - should return the same key
|
||||
second_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert first_key == second_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_different_names(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that different names create different keys."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
key1 = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='automation-1',
|
||||
)
|
||||
|
||||
key2 = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='automation-2',
|
||||
)
|
||||
|
||||
assert key1 != key2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_reissues_expired(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that expired system key is replaced with a new one."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
system_key_name = '__SYSTEM__:automation'
|
||||
|
||||
# First, manually create an expired key
|
||||
expired_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
async with async_session_maker() as session:
|
||||
expired_key = ApiKey(
|
||||
key='sk-oh-expired-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
expires_at=expired_time.replace(tzinfo=None),
|
||||
)
|
||||
session.add(expired_key)
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Request the key - should create a new one
|
||||
new_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert new_key != 'sk-oh-expired-key-12345678901234567890'
|
||||
assert new_key.startswith('sk-oh-')
|
||||
|
||||
# Verify old key was deleted and new key exists
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.name == system_key_name)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
assert len(keys) == 1
|
||||
assert keys[0].key == new_key
|
||||
assert keys[0].expires_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_api_keys_excludes_system_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that list_api_keys excludes system keys."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a user key and a system key
|
||||
async with async_session_maker() as session:
|
||||
user_key = ApiKey(
|
||||
key='sk-oh-user-key-123456789012345678901',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='my-user-key',
|
||||
)
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
mcp_key = ApiKey(
|
||||
key='sk-oh-mcp-key-1234567890123456789012',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='MCP_API_KEY',
|
||||
)
|
||||
session.add(user_key)
|
||||
session.add(system_key)
|
||||
session.add(mcp_key)
|
||||
await session.commit()
|
||||
|
||||
# Mock UserStore.get_user_by_id to return a user with the correct org
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
with patch(
|
||||
'storage.api_key_store.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = mock_user
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Should only return the user key
|
||||
assert len(keys) == 1
|
||||
assert keys[0].name == 'my-user-key'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_protects_system_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that system keys cannot be deleted by users."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a system key
|
||||
async with async_session_maker() as session:
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
session.add(system_key)
|
||||
await session.commit()
|
||||
key_id = system_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Attempt to delete without allow_system flag
|
||||
result = await api_key_store.delete_api_key_by_id(
|
||||
key_id, allow_system=False
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
# Verify the key still exists
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_allows_system_with_flag(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that system keys can be deleted with allow_system=True."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a system key
|
||||
async with async_session_maker() as session:
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
session.add(system_key)
|
||||
await session.commit()
|
||||
key_id = system_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Delete with allow_system=True
|
||||
result = await api_key_store.delete_api_key_by_id(key_id, allow_system=True)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the key was deleted
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_allows_regular_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that regular keys can be deleted normally."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a regular key
|
||||
async with async_session_maker() as session:
|
||||
regular_key = ApiKey(
|
||||
key='sk-oh-regular-key-1234567890123456789',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='my-regular-key',
|
||||
)
|
||||
session.add(regular_key)
|
||||
await session.commit()
|
||||
key_id = regular_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Delete without allow_system flag - should work for regular keys
|
||||
result = await api_key_store.delete_api_key_by_id(
|
||||
key_id, allow_system=False
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the key was deleted
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is None
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.api_key_store import ApiKeyStore, ApiKeyValidationResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -110,8 +110,8 @@ async def test_create_api_key(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||
"""Test validating a valid API key."""
|
||||
# Setup - create an API key in the database
|
||||
"""Test validating a valid API key returns user_id and org_id."""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-api-key'
|
||||
@@ -126,13 +126,19 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
key_id = key_record.id
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
# Act
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result is not None
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id == org_id
|
||||
assert result.key_id == key_id
|
||||
assert result.key_name == 'Test Key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -197,7 +203,7 @@ async def test_validate_api_key_valid_timezone_naive(
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Setup - create a valid API key with timezone-naive datetime (future date)
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-valid-naive-key'
|
||||
@@ -214,12 +220,44 @@ async def test_validate_api_key_valid_timezone_naive(
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
# Act
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id == org_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_legacy_without_org_id(
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating a legacy API key without org_id returns None for org_id."""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
api_key_value = 'test-legacy-key-no-org'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key_value,
|
||||
user_id=user_id,
|
||||
org_id=None, # Legacy key without org binding
|
||||
name='Legacy Key',
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result is not None
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -13,6 +13,7 @@ from server.auth.authorization import (
|
||||
ROLE_PERMISSIONS,
|
||||
Permission,
|
||||
RoleName,
|
||||
get_api_key_org_id_from_request,
|
||||
get_role_permissions,
|
||||
get_user_org_role,
|
||||
has_permission,
|
||||
@@ -444,6 +445,15 @@ class TestGetUserOrgRole:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _create_mock_request(api_key_org_id=None):
|
||||
"""Helper to create a mock request with optional api_key_org_id."""
|
||||
mock_request = MagicMock()
|
||||
mock_user_auth = MagicMock()
|
||||
mock_user_auth.get_api_key_org_id.return_value = api_key_org_id
|
||||
mock_request.state.user_auth = mock_user_auth
|
||||
return mock_request
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
"""Tests for require_permission dependency factory."""
|
||||
|
||||
@@ -456,6 +466,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -465,7 +476,9 @@ class TestRequirePermission:
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -476,10 +489,11 @@ class TestRequirePermission:
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=None)
|
||||
await permission_checker(request=mock_request, org_id=org_id, user_id=None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
@@ -493,6 +507,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
@@ -500,7 +515,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail.lower()
|
||||
@@ -514,6 +531,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
@@ -524,7 +542,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'delete_organization' in exc_info.value.detail.lower()
|
||||
@@ -538,6 +558,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
@@ -547,7 +568,9 @@ class TestRequirePermission:
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -559,6 +582,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -569,7 +593,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@@ -582,6 +608,7 @@ class TestRequirePermission:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
@@ -595,7 +622,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException):
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
call_args = mock_logger.warning.call_args
|
||||
@@ -611,6 +640,7 @@ class TestRequirePermission:
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -620,7 +650,9 @@ class TestRequirePermission:
|
||||
AsyncMock(return_value=mock_role),
|
||||
) as mock_get_role:
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=None, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=None, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
mock_get_role.assert_called_once_with(user_id, None)
|
||||
|
||||
@@ -632,6 +664,7 @@ class TestRequirePermission:
|
||||
THEN: HTTPException with 403 status is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
@@ -639,7 +672,9 @@ class TestRequirePermission:
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=None, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=None, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail
|
||||
@@ -662,6 +697,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
@@ -671,7 +707,9 @@ class TestPermissionScenarios:
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.MANAGE_SECRETS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -683,6 +721,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
@@ -695,7 +734,9 @@ class TestPermissionScenarios:
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@@ -708,6 +749,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -719,7 +761,9 @@ class TestPermissionScenarios:
|
||||
permission_checker = require_permission(
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -731,6 +775,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
@@ -741,7 +786,9 @@ class TestPermissionScenarios:
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@@ -754,6 +801,7 @@ class TestPermissionScenarios:
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
@@ -763,5 +811,200 @@ class TestPermissionScenarios:
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for API key organization validation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestApiKeyOrgValidation:
|
||||
"""Tests for API key organization binding validation in require_permission."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_access_when_api_key_org_matches_target_org(self):
|
||||
"""
|
||||
GIVEN: API key with org_id that matches the target org_id in the request
|
||||
WHEN: Permission checker is called
|
||||
THEN: User ID is returned (access allowed)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=org_id)
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
# Act & Assert
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_denies_access_when_api_key_org_mismatches_target_org(self):
|
||||
"""
|
||||
GIVEN: API key created for Org A, but user tries to access Org B
|
||||
WHEN: Permission checker is called
|
||||
THEN: 403 Forbidden is raised with org mismatch message
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
api_key_org_id = uuid4() # Org A - where API key was created
|
||||
target_org_id = uuid4() # Org B - where user is trying to access
|
||||
mock_request = _create_mock_request(api_key_org_id=api_key_org_id)
|
||||
|
||||
# Act & Assert
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=target_org_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert (
|
||||
'API key is not authorized for this organization' in exc_info.value.detail
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_access_for_legacy_api_key_without_org_binding(self):
|
||||
"""
|
||||
GIVEN: Legacy API key without org_id binding (org_id is None)
|
||||
WHEN: Permission checker is called
|
||||
THEN: Falls through to normal permission check (backward compatible)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=None)
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
# Act & Assert
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_access_for_cookie_auth_without_api_key_org_id(self):
|
||||
"""
|
||||
GIVEN: Cookie-based authentication (no api_key_org_id in user_auth)
|
||||
WHEN: Permission checker is called
|
||||
THEN: Falls through to normal permission check
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=None)
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
# Act & Assert
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(
|
||||
request=mock_request, org_id=org_id, user_id=user_id
|
||||
)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_warning_on_api_key_org_mismatch(self):
|
||||
"""
|
||||
GIVEN: API key org_id doesn't match target org_id
|
||||
WHEN: Permission checker is called
|
||||
THEN: Warning is logged with org mismatch details
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid4())
|
||||
api_key_org_id = uuid4()
|
||||
target_org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=api_key_org_id)
|
||||
|
||||
# Act & Assert
|
||||
with patch('server.auth.authorization.logger') as mock_logger:
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException):
|
||||
await permission_checker(
|
||||
request=mock_request, org_id=target_org_id, user_id=user_id
|
||||
)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
call_args = mock_logger.warning.call_args
|
||||
assert call_args[1]['extra']['user_id'] == user_id
|
||||
assert call_args[1]['extra']['api_key_org_id'] == str(api_key_org_id)
|
||||
assert call_args[1]['extra']['target_org_id'] == str(target_org_id)
|
||||
|
||||
|
||||
class TestGetApiKeyOrgIdFromRequest:
|
||||
"""Tests for get_api_key_org_id_from_request helper function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_org_id_when_user_auth_has_api_key_org_id(self):
|
||||
"""
|
||||
GIVEN: Request with user_auth that has api_key_org_id
|
||||
WHEN: get_api_key_org_id_from_request is called
|
||||
THEN: Returns the api_key_org_id
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid4()
|
||||
mock_request = _create_mock_request(api_key_org_id=org_id)
|
||||
|
||||
# Act
|
||||
result = await get_api_key_org_id_from_request(mock_request)
|
||||
|
||||
# Assert
|
||||
assert result == org_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_user_auth_has_no_api_key_org_id(self):
|
||||
"""
|
||||
GIVEN: Request with user_auth that has no api_key_org_id (cookie auth)
|
||||
WHEN: get_api_key_org_id_from_request is called
|
||||
THEN: Returns None
|
||||
"""
|
||||
# Arrange
|
||||
mock_request = _create_mock_request(api_key_org_id=None)
|
||||
|
||||
# Act
|
||||
result = await get_api_key_org_id_from_request(mock_request)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_user_auth_in_request(self):
|
||||
"""
|
||||
GIVEN: Request without user_auth in state
|
||||
WHEN: get_api_key_org_id_from_request is called
|
||||
THEN: Returns None
|
||||
"""
|
||||
# Arrange
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.user_auth = None
|
||||
|
||||
# Act
|
||||
result = await get_api_key_org_id_from_request(mock_request)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@@ -38,8 +38,9 @@ class TestDefaultInitialBudget:
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
# Clear the env var
|
||||
# Clear the env vars
|
||||
os.environ.pop('DEFAULT_INITIAL_BUDGET', None)
|
||||
os.environ.pop('ENABLE_BILLING', None)
|
||||
|
||||
# Restore original module or reimport fresh
|
||||
if original_module is not None:
|
||||
@@ -47,31 +48,56 @@ class TestDefaultInitialBudget:
|
||||
else:
|
||||
importlib.import_module('storage.lite_llm_manager')
|
||||
|
||||
def test_default_initial_budget_defaults_to_zero(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET defaults to 0.0 when env var not set."""
|
||||
def test_default_initial_budget_none_when_billing_disabled(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET is None when billing is disabled."""
|
||||
# Temporarily remove the module so we can reimport with different env vars
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
# Clear the env var and reimport
|
||||
# Ensure billing is disabled (default) and reimport
|
||||
os.environ.pop('ENABLE_BILLING', None)
|
||||
os.environ.pop('DEFAULT_INITIAL_BUDGET', None)
|
||||
module = importlib.import_module('storage.lite_llm_manager')
|
||||
assert module.DEFAULT_INITIAL_BUDGET is None
|
||||
|
||||
def test_default_initial_budget_defaults_to_zero_when_billing_enabled(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET defaults to 0.0 when billing is enabled."""
|
||||
# Temporarily remove the module so we can reimport with different env vars
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
# Enable billing and reimport
|
||||
os.environ['ENABLE_BILLING'] = 'true'
|
||||
os.environ.pop('DEFAULT_INITIAL_BUDGET', None)
|
||||
module = importlib.import_module('storage.lite_llm_manager')
|
||||
assert module.DEFAULT_INITIAL_BUDGET == 0.0
|
||||
|
||||
def test_default_initial_budget_uses_env_var(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET uses value from environment variable."""
|
||||
def test_default_initial_budget_uses_env_var_when_billing_enabled(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET uses value from environment variable when billing enabled."""
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
os.environ['ENABLE_BILLING'] = 'true'
|
||||
os.environ['DEFAULT_INITIAL_BUDGET'] = '100.0'
|
||||
module = importlib.import_module('storage.lite_llm_manager')
|
||||
assert module.DEFAULT_INITIAL_BUDGET == 100.0
|
||||
|
||||
def test_default_initial_budget_ignores_env_var_when_billing_disabled(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET returns None when billing disabled, ignoring env var."""
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
os.environ.pop('ENABLE_BILLING', None) # billing disabled by default
|
||||
os.environ['DEFAULT_INITIAL_BUDGET'] = '100.0'
|
||||
module = importlib.import_module('storage.lite_llm_manager')
|
||||
assert module.DEFAULT_INITIAL_BUDGET is None
|
||||
|
||||
def test_default_initial_budget_rejects_invalid_value(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET raises ValueError for invalid values."""
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
os.environ['ENABLE_BILLING'] = 'true'
|
||||
os.environ['DEFAULT_INITIAL_BUDGET'] = 'abc'
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
importlib.import_module('storage.lite_llm_manager')
|
||||
@@ -82,6 +108,7 @@ class TestDefaultInitialBudget:
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
os.environ['ENABLE_BILLING'] = 'true'
|
||||
os.environ['DEFAULT_INITIAL_BUDGET'] = '-10.0'
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
importlib.import_module('storage.lite_llm_manager')
|
||||
@@ -105,10 +132,12 @@ class TestLiteLlmManager:
|
||||
def mock_user_settings(self):
|
||||
"""Create a mock UserSettings object."""
|
||||
user_settings = UserSettings()
|
||||
user_settings.agent = 'TestAgent'
|
||||
user_settings.llm_model = 'test-model'
|
||||
user_settings.agent_settings = {
|
||||
'agent': 'TestAgent',
|
||||
'llm.model': 'test-model',
|
||||
'llm.base_url': 'http://test.com',
|
||||
}
|
||||
user_settings.llm_api_key = SecretStr('test-key')
|
||||
user_settings.llm_base_url = 'http://test.com'
|
||||
user_settings.user_version = 4 # Set version to avoid None comparison
|
||||
return user_settings
|
||||
|
||||
@@ -212,6 +241,16 @@ class TestLiteLlmManager:
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
|
||||
# Mock user exists check response
|
||||
mock_user_exists_response = MagicMock()
|
||||
mock_user_exists_response.is_success = True
|
||||
mock_user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
@@ -219,12 +258,8 @@ class TestLiteLlmManager:
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
# First GET is for _get_team (404), second GET is for _user_exists (success)
|
||||
mock_client.get.side_effect = [mock_404_response, mock_user_exists_response]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
@@ -247,8 +282,8 @@ class TestLiteLlmManager:
|
||||
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify API calls were made (get_team + 4 posts)
|
||||
assert mock_client.get.call_count == 1 # get_team
|
||||
# Verify API calls were made (get_team + user_exists + 4 posts)
|
||||
assert mock_client.get.call_count == 2 # get_team + user_exists
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, add_user_to_team, delete_key_by_alias, generate_key
|
||||
@@ -267,13 +302,21 @@ class TestLiteLlmManager:
|
||||
}
|
||||
mock_team_response.raise_for_status = MagicMock()
|
||||
|
||||
# Mock user exists check response
|
||||
mock_user_exists_response = MagicMock()
|
||||
mock_user_exists_response.is_success = True
|
||||
mock_user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_team_response
|
||||
# First GET is for _get_team (success), second GET is for _user_exists (success)
|
||||
mock_client.get.side_effect = [mock_team_response, mock_user_exists_response]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
@@ -293,8 +336,8 @@ class TestLiteLlmManager:
|
||||
assert result is not None
|
||||
|
||||
# Verify _get_team was called first
|
||||
mock_client.get.assert_called_once()
|
||||
get_call_url = mock_client.get.call_args[0][0]
|
||||
assert mock_client.get.call_count == 2 # get_team + user_exists
|
||||
get_call_url = mock_client.get.call_args_list[0][0][0]
|
||||
assert 'team/info' in get_call_url
|
||||
assert 'test-org-id' in get_call_url
|
||||
|
||||
@@ -316,19 +359,25 @@ class TestLiteLlmManager:
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
# Mock user exists check response
|
||||
mock_user_exists_response = MagicMock()
|
||||
mock_user_exists_response.is_success = True
|
||||
mock_user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
# First GET is for _get_team (404), second GET is for _user_exists (success)
|
||||
mock_client.get.side_effect = [mock_404_response, mock_user_exists_response]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
@@ -366,6 +415,16 @@ class TestLiteLlmManager:
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
mock_404_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
|
||||
# Mock user exists check response
|
||||
mock_user_exists_response = MagicMock()
|
||||
mock_user_exists_response.is_success = True
|
||||
mock_user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
@@ -373,12 +432,8 @@ class TestLiteLlmManager:
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
# First GET is for _get_team (404), second GET is for _user_exists (success)
|
||||
mock_client.get.side_effect = [mock_404_response, mock_user_exists_response]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
@@ -477,10 +532,11 @@ class TestLiteLlmManager:
|
||||
|
||||
# migrate_entries returns the user_settings unchanged
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
assert result.llm_model == 'test-model'
|
||||
effective_settings = result.to_settings()
|
||||
assert effective_settings.agent == 'TestAgent'
|
||||
assert effective_settings.llm_model == 'test-model'
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
assert effective_settings.llm_base_url == 'http://test.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_no_user_found(self, mock_user_settings):
|
||||
@@ -601,10 +657,11 @@ class TestLiteLlmManager:
|
||||
|
||||
# migrate_entries returns the user_settings unchanged
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
assert result.llm_model == 'test-model'
|
||||
effective_settings = result.to_settings()
|
||||
assert effective_settings.agent == 'TestAgent'
|
||||
assert effective_settings.llm_model == 'test-model'
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
assert effective_settings.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify migration steps were called:
|
||||
# - 2 GET requests: _get_user, _get_user_keys
|
||||
@@ -806,15 +863,16 @@ class TestLiteLlmManager:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _create_user operation."""
|
||||
"""Test successful _create_user operation returns True."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/user/new' in call_args[0]
|
||||
@@ -823,7 +881,7 @@ class TestLiteLlmManager:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_duplicate_email(self, mock_http_client, mock_response):
|
||||
"""Test _create_user with duplicate email handling."""
|
||||
"""Test _create_user with duplicate email handling returns True."""
|
||||
# First call fails with duplicate email
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
@@ -835,23 +893,81 @@ class TestLiteLlmManager:
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert mock_http_client.post.call_count == 2
|
||||
# Second call should have None email
|
||||
second_call_args = mock_http_client.post.call_args_list[1]
|
||||
assert second_call_args[1]['json']['user_email'] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_user_exists_returns_true(self, mock_http_client):
|
||||
"""Test _user_exists returns True when user exists in LiteLLM."""
|
||||
# Arrange
|
||||
user_response = MagicMock()
|
||||
user_response.is_success = True
|
||||
user_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id', 'email': 'test@example.com'}
|
||||
}
|
||||
mock_http_client.get.return_value = user_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_http_client.get.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_user_exists_returns_false_when_not_found(self, mock_http_client):
|
||||
"""Test _user_exists returns False when user not found."""
|
||||
# Arrange
|
||||
user_response = MagicMock()
|
||||
user_response.is_success = False
|
||||
mock_http_client.get.return_value = user_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_user_exists_returns_false_on_mismatched_user_id(
|
||||
self, mock_http_client
|
||||
):
|
||||
"""Test _user_exists returns False when returned user_id doesn't match."""
|
||||
# Arrange
|
||||
user_response = MagicMock()
|
||||
user_response.is_success = True
|
||||
user_response.json.return_value = {
|
||||
'user_info': {'user_id': 'different-user-id'}
|
||||
}
|
||||
mock_http_client.get.return_value = user_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._user_exists(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.logger')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_create_user_already_exists_with_409_status_code(
|
||||
async def test_create_user_already_exists_and_verified(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user handles 409 Conflict when user already exists."""
|
||||
"""Test _create_user returns True when user already exists and is verified."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
@@ -863,14 +979,141 @@ class TestLiteLlmManager:
|
||||
second_response.status_code = 409
|
||||
second_response.text = 'User with id test-user-id already exists'
|
||||
|
||||
user_exists_response = MagicMock()
|
||||
user_exists_response.is_success = True
|
||||
user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
mock_http_client.get.return_value = user_exists_response
|
||||
|
||||
# Act
|
||||
await LiteLlmManager._create_user(
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_logger.warning.assert_any_call(
|
||||
'litellm_user_already_exists',
|
||||
extra={'user_id': 'test-user-id'},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.logger')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_create_user_already_exists_but_not_found_returns_false(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user returns False when LiteLLM claims user exists but verification fails."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
first_response.status_code = 400
|
||||
first_response.text = 'duplicate email'
|
||||
|
||||
second_response = MagicMock()
|
||||
second_response.is_success = False
|
||||
second_response.status_code = 409
|
||||
second_response.text = 'User with id test-user-id already exists'
|
||||
|
||||
user_not_exists_response = MagicMock()
|
||||
user_not_exists_response.is_success = False
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
mock_http_client.get.return_value = user_not_exists_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_logger.error.assert_any_call(
|
||||
'litellm_user_claimed_exists_but_not_found',
|
||||
extra={
|
||||
'user_id': 'test-user-id',
|
||||
'status_code': 409,
|
||||
'text': 'User with id test-user-id already exists',
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.logger')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_create_user_failure_returns_false(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user returns False when creation fails with non-'already exists' error."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
first_response.status_code = 400
|
||||
first_response.text = 'duplicate email'
|
||||
|
||||
second_response = MagicMock()
|
||||
second_response.is_success = False
|
||||
second_response.status_code = 500
|
||||
second_response.text = 'Internal server error'
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_logger.error.assert_any_call(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': 500,
|
||||
'text': 'Internal server error',
|
||||
'user_id': 'test-user-id',
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.logger')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_create_user_already_exists_with_409_status_code(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user handles 409 Conflict when user already exists and verifies."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
first_response.status_code = 400
|
||||
first_response.text = 'duplicate email'
|
||||
|
||||
second_response = MagicMock()
|
||||
second_response.is_success = False
|
||||
second_response.status_code = 409
|
||||
second_response.text = 'User with id test-user-id already exists'
|
||||
|
||||
user_exists_response = MagicMock()
|
||||
user_exists_response.is_success = True
|
||||
user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
mock_http_client.get.return_value = user_exists_response
|
||||
|
||||
# Act
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_logger.warning.assert_any_call(
|
||||
'litellm_user_already_exists',
|
||||
extra={'user_id': 'test-user-id'},
|
||||
@@ -883,7 +1126,7 @@ class TestLiteLlmManager:
|
||||
async def test_create_user_already_exists_with_400_status_code(
|
||||
self, mock_logger, mock_http_client
|
||||
):
|
||||
"""Test _create_user handles 400 Bad Request when user already exists."""
|
||||
"""Test _create_user handles 400 Bad Request when user already exists and verifies."""
|
||||
# Arrange
|
||||
first_response = MagicMock()
|
||||
first_response.is_success = False
|
||||
@@ -895,14 +1138,22 @@ class TestLiteLlmManager:
|
||||
second_response.status_code = 400
|
||||
second_response.text = 'User already exists'
|
||||
|
||||
user_exists_response = MagicMock()
|
||||
user_exists_response.is_success = True
|
||||
user_exists_response.json.return_value = {
|
||||
'user_info': {'user_id': 'test-user-id'}
|
||||
}
|
||||
|
||||
mock_http_client.post.side_effect = [first_response, second_response]
|
||||
mock_http_client.get.return_value = user_exists_response
|
||||
|
||||
# Act
|
||||
await LiteLlmManager._create_user(
|
||||
result = await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_logger.warning.assert_any_call(
|
||||
'litellm_user_already_exists',
|
||||
extra={'user_id': 'test-user-id'},
|
||||
@@ -1784,7 +2035,7 @@ class TestLiteLlmManager:
|
||||
|
||||
# downgrade_entries returns the user_settings
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
assert result.to_settings().agent == 'TestAgent'
|
||||
|
||||
# Verify downgrade steps were called:
|
||||
# GET requests:
|
||||
@@ -1818,7 +2069,7 @@ class TestLiteLlmManager:
|
||||
# In local deployment, should return user_settings without
|
||||
# making any LiteLLM calls
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
assert result.to_settings().agent == 'TestAgent'
|
||||
|
||||
|
||||
class TestGetAllKeysForUser:
|
||||
@@ -2137,3 +2388,195 @@ class TestVerifyExistingKey:
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestBudgetPayloadHandling:
|
||||
"""Test cases for budget field handling in API payloads.
|
||||
|
||||
These tests verify that when max_budget is None, the budget field is NOT
|
||||
included in the JSON payload (which tells LiteLLM to disable budget
|
||||
enforcement), and when max_budget has a value, it IS included.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_excludes_max_budget_when_none(self):
|
||||
"""Test that _create_team does NOT include max_budget when it is None."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_client,
|
||||
team_alias='test-team',
|
||||
team_id='test-team-id',
|
||||
max_budget=None, # None = no budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify URL
|
||||
assert call_args[0][0] == 'http://test.com/team/new'
|
||||
|
||||
# Verify that max_budget is NOT in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert 'max_budget' not in json_payload, (
|
||||
'max_budget should NOT be in payload when None '
|
||||
'(omitting it tells LiteLLM to disable budget enforcement)'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_includes_max_budget_when_set(self):
|
||||
"""Test that _create_team includes max_budget when it has a value."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_client,
|
||||
team_alias='test-team',
|
||||
team_id='test-team-id',
|
||||
max_budget=100.0, # Explicit budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify that max_budget IS in the JSON payload with the correct value
|
||||
json_payload = call_args[1]['json']
|
||||
assert (
|
||||
'max_budget' in json_payload
|
||||
), 'max_budget should be in payload when set to a value'
|
||||
assert json_payload['max_budget'] == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_to_team_excludes_max_budget_when_none(self):
|
||||
"""Test that _add_user_to_team does NOT include max_budget_in_team when None."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
mock_client,
|
||||
keycloak_user_id='test-user-id',
|
||||
team_id='test-team-id',
|
||||
max_budget=None, # None = no budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify URL
|
||||
assert call_args[0][0] == 'http://test.com/team/member_add'
|
||||
|
||||
# Verify that max_budget_in_team is NOT in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert 'max_budget_in_team' not in json_payload, (
|
||||
'max_budget_in_team should NOT be in payload when None '
|
||||
'(omitting it tells LiteLLM to disable budget enforcement)'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_to_team_includes_max_budget_when_set(self):
|
||||
"""Test that _add_user_to_team includes max_budget_in_team when set."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
mock_client,
|
||||
keycloak_user_id='test-user-id',
|
||||
team_id='test-team-id',
|
||||
max_budget=50.0, # Explicit budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify that max_budget_in_team IS in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert (
|
||||
'max_budget_in_team' in json_payload
|
||||
), 'max_budget_in_team should be in payload when set to a value'
|
||||
assert json_payload['max_budget_in_team'] == 50.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_in_team_excludes_max_budget_when_none(self):
|
||||
"""Test that _update_user_in_team does NOT include max_budget_in_team when None."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
mock_client,
|
||||
keycloak_user_id='test-user-id',
|
||||
team_id='test-team-id',
|
||||
max_budget=None, # None = no budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify URL
|
||||
assert call_args[0][0] == 'http://test.com/team/member_update'
|
||||
|
||||
# Verify that max_budget_in_team is NOT in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert 'max_budget_in_team' not in json_payload, (
|
||||
'max_budget_in_team should NOT be in payload when None '
|
||||
'(omitting it tells LiteLLM to disable budget enforcement)'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_in_team_includes_max_budget_when_set(self):
|
||||
"""Test that _update_user_in_team includes max_budget_in_team when set."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
mock_client,
|
||||
keycloak_user_id='test-user-id',
|
||||
team_id='test-team-id',
|
||||
max_budget=75.0, # Explicit budget limit
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
|
||||
# Verify that max_budget_in_team IS in the JSON payload
|
||||
json_payload = call_args[1]['json']
|
||||
assert (
|
||||
'max_budget_in_team' in json_payload
|
||||
), 'max_budget_in_team should be in payload when set to a value'
|
||||
assert json_payload['max_budget_in_team'] == 75.0
|
||||
|
||||
@@ -10,6 +10,37 @@ from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
|
||||
def test_get_kwargs_from_user_settings_uses_agent_settings_as_source_of_truth():
|
||||
user_settings = UserSettings(
|
||||
llm_api_key='legacy-secret',
|
||||
agent_settings={
|
||||
'schema_version': 1,
|
||||
'agent': 'CodeActAgent',
|
||||
'verification.confirmation_mode': True,
|
||||
'verification.security_analyzer': 'llm',
|
||||
'condenser.enabled': False,
|
||||
'condenser.max_size': 128,
|
||||
'llm.model': 'anthropic/claude-sonnet-4-5-20250929',
|
||||
'llm.base_url': 'https://api.example.com',
|
||||
'max_iterations': 42,
|
||||
},
|
||||
)
|
||||
|
||||
kwargs = OrgMemberStore.get_kwargs_from_user_settings(user_settings)
|
||||
|
||||
assert kwargs['llm_api_key'] == 'legacy-secret'
|
||||
assert kwargs['llm_model'] == 'anthropic/claude-sonnet-4-5-20250929'
|
||||
assert kwargs['llm_base_url'] == 'https://api.example.com'
|
||||
assert kwargs['max_iterations'] == 42
|
||||
assert kwargs['agent_settings'] == {
|
||||
'schema_version': 1,
|
||||
'llm.model': 'anthropic/claude-sonnet-4-5-20250929',
|
||||
'llm.base_url': 'https://api.example.com',
|
||||
'max_iterations': 42,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -246,3 +246,82 @@ class TestSaasSecretsStore:
|
||||
assert isinstance(store, SaasSecretsStore)
|
||||
assert store.user_id == 'test-user-id'
|
||||
assert store.config == mock_config
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_secrets_isolation_between_organizations(
|
||||
self, mock_get_user, secrets_store, mock_user
|
||||
):
|
||||
"""Test that secrets from one organization are not deleted when storing
|
||||
secrets in another organization. This reproduces a bug where switching
|
||||
organizations and creating a secret would delete all secrets from the
|
||||
user's personal workspace."""
|
||||
org1_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
org2_id = UUID('b2222222-2222-2222-2222-222222222222')
|
||||
|
||||
# Store secrets in org1 (personal workspace)
|
||||
mock_user.current_org_id = org1_id
|
||||
mock_get_user.return_value = mock_user
|
||||
org1_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
{
|
||||
'personal_secret': CustomSecret.from_value(
|
||||
{
|
||||
'secret': 'personal_secret_value',
|
||||
'description': 'My personal secret',
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
await secrets_store.store(org1_secrets)
|
||||
|
||||
# Verify org1 secrets are stored
|
||||
loaded_org1 = await secrets_store.load()
|
||||
assert loaded_org1 is not None
|
||||
assert 'personal_secret' in loaded_org1.custom_secrets
|
||||
assert (
|
||||
loaded_org1.custom_secrets['personal_secret'].secret.get_secret_value()
|
||||
== 'personal_secret_value'
|
||||
)
|
||||
|
||||
# Switch to org2 and store secrets there
|
||||
mock_user.current_org_id = org2_id
|
||||
mock_get_user.return_value = mock_user
|
||||
org2_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
{
|
||||
'org2_secret': CustomSecret.from_value(
|
||||
{'secret': 'org2_secret_value', 'description': 'Org2 secret'}
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
await secrets_store.store(org2_secrets)
|
||||
|
||||
# Verify org2 secrets are stored
|
||||
loaded_org2 = await secrets_store.load()
|
||||
assert loaded_org2 is not None
|
||||
assert 'org2_secret' in loaded_org2.custom_secrets
|
||||
assert (
|
||||
loaded_org2.custom_secrets['org2_secret'].secret.get_secret_value()
|
||||
== 'org2_secret_value'
|
||||
)
|
||||
|
||||
# Switch back to org1 and verify secrets are still there
|
||||
mock_user.current_org_id = org1_id
|
||||
mock_get_user.return_value = mock_user
|
||||
loaded_org1_again = await secrets_store.load()
|
||||
assert loaded_org1_again is not None
|
||||
assert 'personal_secret' in loaded_org1_again.custom_secrets
|
||||
assert (
|
||||
loaded_org1_again.custom_secrets[
|
||||
'personal_secret'
|
||||
].secret.get_secret_value()
|
||||
== 'personal_secret_value'
|
||||
)
|
||||
# Verify org2 secrets are NOT visible in org1
|
||||
assert 'org2_secret' not in loaded_org1_again.custom_secrets
|
||||
|
||||
@@ -26,6 +26,29 @@ def mock_config():
|
||||
return config
|
||||
|
||||
|
||||
def test_member_scoped_agent_settings_filters_effective_settings(mock_config):
|
||||
store = SaasSettingsStore('test-user-id', mock_config)
|
||||
effective_settings = Settings(
|
||||
agent='CodeActAgent',
|
||||
llm_model='anthropic/claude-sonnet-4-5-20250929',
|
||||
llm_base_url='https://api.example.com',
|
||||
max_iterations=42,
|
||||
confirmation_mode=True,
|
||||
security_analyzer='llm',
|
||||
enable_default_condenser=False,
|
||||
condenser_max_size=128,
|
||||
)
|
||||
|
||||
assert store._member_scoped_agent_settings(
|
||||
effective_settings.normalized_agent_settings(strip_secret_values=True)
|
||||
) == {
|
||||
'schema_version': 1,
|
||||
'llm.model': 'anthropic/claude-sonnet-4-5-20250929',
|
||||
'llm.base_url': 'https://api.example.com',
|
||||
'max_iterations': 42,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings_store(async_session_maker, mock_config):
|
||||
store = SaasSettingsStore('5594c7b6-f959-4b81-92e9-b09c206f5081', mock_config)
|
||||
@@ -79,6 +102,7 @@ def settings_store(async_session_maker, mock_config):
|
||||
|
||||
# Encrypt the data before storing
|
||||
store._encrypt_kwargs(item_dict)
|
||||
item_dict['agent_settings'] = item.agent_settings
|
||||
|
||||
# Continue with the original implementation
|
||||
from sqlalchemy import select
|
||||
@@ -119,6 +143,10 @@ async def test_store_and_load_keycloak_user(settings_store):
|
||||
agent='smith',
|
||||
email='test@example.com',
|
||||
email_verified=True,
|
||||
agent_settings={
|
||||
'critic_mode': 'all_actions',
|
||||
'enable_critic': True,
|
||||
},
|
||||
)
|
||||
|
||||
await settings_store.store(settings)
|
||||
@@ -126,6 +154,8 @@ async def test_store_and_load_keycloak_user(settings_store):
|
||||
# Load and verify settings
|
||||
loaded_settings = await settings_store.load()
|
||||
assert loaded_settings is not None
|
||||
assert loaded_settings.agent_settings['critic_mode'] == 'all_actions'
|
||||
assert loaded_settings.agent_settings['enable_critic'] is True
|
||||
assert loaded_settings.llm_api_key.get_secret_value() == 'secret_key'
|
||||
assert loaded_settings.agent == 'smith'
|
||||
|
||||
@@ -140,7 +170,7 @@ async def test_store_and_load_keycloak_user(settings_store):
|
||||
)
|
||||
stored = result.scalars().first()
|
||||
assert stored is not None
|
||||
assert stored.agent == 'smith'
|
||||
assert stored.agent_settings['agent'] == 'smith'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import jwt
|
||||
@@ -18,6 +19,7 @@ from server.auth.saas_user_auth import (
|
||||
saas_user_auth_from_cookie,
|
||||
saas_user_auth_from_signed_token,
|
||||
)
|
||||
from storage.api_key_store import ApiKeyValidationResult
|
||||
from storage.user_authorization import UserAuthorizationType
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
@@ -457,7 +459,8 @@ async def test_get_instance_no_auth(mock_request):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_bearer_success():
|
||||
"""Test successful authentication from bearer token."""
|
||||
"""Test successful authentication from bearer token sets user_id and api_key_org_id."""
|
||||
# Arrange
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {'Authorization': 'Bearer test_api_key'}
|
||||
|
||||
@@ -468,12 +471,22 @@ async def test_saas_user_auth_from_bearer_success():
|
||||
algorithm='HS256',
|
||||
)
|
||||
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_validation_result = ApiKeyValidationResult(
|
||||
user_id='test_user_id',
|
||||
org_id=mock_org_id,
|
||||
key_id=42,
|
||||
key_name='Test Key',
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.auth.saas_user_auth.ApiKeyStore') as mock_api_key_store_cls,
|
||||
patch('server.auth.saas_user_auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.validate_api_key = AsyncMock(return_value='test_user_id')
|
||||
mock_api_key_store.validate_api_key = AsyncMock(
|
||||
return_value=mock_validation_result
|
||||
)
|
||||
mock_api_key_store_cls.get_instance.return_value = mock_api_key_store
|
||||
|
||||
mock_token_manager.load_offline_token = AsyncMock(return_value=offline_token)
|
||||
@@ -485,6 +498,9 @@ async def test_saas_user_auth_from_bearer_success():
|
||||
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.api_key_org_id == mock_org_id
|
||||
assert result.api_key_id == 42
|
||||
assert result.api_key_name == 'Test Key'
|
||||
mock_api_key_store.validate_api_key.assert_called_once_with('test_api_key')
|
||||
mock_token_manager.load_offline_token.assert_called_once_with('test_user_id')
|
||||
mock_token_manager.refresh.assert_called_once_with(offline_token)
|
||||
|
||||
@@ -688,8 +688,10 @@ def test_has_custom_settings_custom_base_url():
|
||||
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id='test',
|
||||
llm_base_url='https://custom.api.example.com',
|
||||
llm_model='some-model',
|
||||
agent_settings={
|
||||
'llm.base_url': 'https://custom.api.example.com',
|
||||
'llm.model': 'some-model',
|
||||
},
|
||||
)
|
||||
|
||||
result = UserStore._has_custom_settings(user_settings, old_user_version=1)
|
||||
@@ -701,11 +703,7 @@ def test_has_custom_settings_no_model():
|
||||
"""Test that no model set means using defaults."""
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id='test',
|
||||
llm_base_url=None,
|
||||
llm_model=None,
|
||||
)
|
||||
user_settings = UserSettings(keycloak_user_id='test', agent_settings={})
|
||||
|
||||
result = UserStore._has_custom_settings(user_settings, old_user_version=1)
|
||||
|
||||
@@ -718,8 +716,7 @@ def test_has_custom_settings_empty_model():
|
||||
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id='test',
|
||||
llm_base_url=None,
|
||||
llm_model=' ', # whitespace only
|
||||
agent_settings={'llm.model': ' '},
|
||||
)
|
||||
|
||||
result = UserStore._has_custom_settings(user_settings, old_user_version=1)
|
||||
@@ -780,7 +777,10 @@ def test_create_user_settings_from_entities():
|
||||
|
||||
assert result.keycloak_user_id == user_id
|
||||
assert result.llm_api_key == 'test-api-key'
|
||||
assert result.llm_model == 'claude-3-5-sonnet'
|
||||
assert result.agent_settings['llm.model'] == 'claude-3-5-sonnet'
|
||||
assert result.agent_settings['llm.base_url'] == 'https://api.example.com'
|
||||
assert result.agent_settings['max_iterations'] == 50
|
||||
assert result.agent_settings['agent'] == 'CodeActAgent'
|
||||
assert result.language == 'en'
|
||||
assert result.email == 'test@example.com'
|
||||
|
||||
@@ -835,9 +835,10 @@ def test_create_user_settings_from_entities_with_org_fallback():
|
||||
)
|
||||
|
||||
# Should have fallen back to org defaults
|
||||
assert result.llm_model == 'default-model'
|
||||
assert result.llm_base_url == 'https://default.api.com'
|
||||
assert result.max_iterations == 100
|
||||
assert result.agent_settings['llm.model'] == 'default-model'
|
||||
assert result.agent_settings['llm.base_url'] == 'https://default.api.com'
|
||||
assert result.agent_settings['max_iterations'] == 100
|
||||
assert result.agent_settings['agent'] == 'CodeActAgent'
|
||||
assert result.language == 'es'
|
||||
assert result.search_api_key == 'search-key'
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { CopyableContentWrapper } from "#/components/shared/buttons/copyable-content-wrapper";
|
||||
|
||||
describe("CopyableContentWrapper", () => {
|
||||
it("should hide the copy button by default", () => {
|
||||
render(
|
||||
<CopyableContentWrapper text="hello">
|
||||
<p>content</p>
|
||||
</CopyableContentWrapper>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("copy-to-clipboard")).not.toBeVisible();
|
||||
});
|
||||
|
||||
it("should show the copy button on hover", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(
|
||||
<CopyableContentWrapper text="hello">
|
||||
<p>content</p>
|
||||
</CopyableContentWrapper>,
|
||||
);
|
||||
|
||||
await user.hover(screen.getByText("content"));
|
||||
|
||||
expect(screen.getByTestId("copy-to-clipboard")).toBeVisible();
|
||||
});
|
||||
|
||||
it("should copy text to clipboard on click", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(
|
||||
<CopyableContentWrapper text="copy me">
|
||||
<p>content</p>
|
||||
</CopyableContentWrapper>,
|
||||
);
|
||||
|
||||
await user.click(screen.getByTestId("copy-to-clipboard"));
|
||||
|
||||
await waitFor(() =>
|
||||
expect(navigator.clipboard.readText()).resolves.toBe("copy me"),
|
||||
);
|
||||
});
|
||||
|
||||
it("should show copied state after clicking", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(
|
||||
<CopyableContentWrapper text="hello">
|
||||
<p>content</p>
|
||||
</CopyableContentWrapper>,
|
||||
);
|
||||
|
||||
await user.click(screen.getByTestId("copy-to-clipboard"));
|
||||
|
||||
expect(screen.getByTestId("copy-to-clipboard")).toHaveAttribute(
|
||||
"aria-label",
|
||||
"BUTTON$COPIED",
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,78 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { ContextMenuContainer } from "#/components/features/context-menu/context-menu-container";
|
||||
|
||||
describe("ContextMenuContainer", () => {
|
||||
const user = userEvent.setup();
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
it("should render children", () => {
|
||||
render(
|
||||
<ContextMenuContainer onClose={onCloseMock}>
|
||||
<div data-testid="child-1">Child 1</div>
|
||||
<div data-testid="child-2">Child 2</div>
|
||||
</ContextMenuContainer>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("child-1")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("child-2")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should apply consistent base styling", () => {
|
||||
render(
|
||||
<ContextMenuContainer onClose={onCloseMock} testId="test-container">
|
||||
<div>Content</div>
|
||||
</ContextMenuContainer>,
|
||||
);
|
||||
|
||||
const container = screen.getByTestId("test-container");
|
||||
expect(container).toHaveClass("bg-[#050505]");
|
||||
expect(container).toHaveClass("border");
|
||||
expect(container).toHaveClass("border-[#242424]");
|
||||
expect(container).toHaveClass("rounded-[12px]");
|
||||
expect(container).toHaveClass("p-[25px]");
|
||||
expect(container).toHaveClass("context-menu-box-shadow");
|
||||
});
|
||||
|
||||
it("should call onClose when clicking outside", async () => {
|
||||
render(
|
||||
<ContextMenuContainer onClose={onCloseMock} testId="test-container">
|
||||
<div>Content</div>
|
||||
</ContextMenuContainer>,
|
||||
);
|
||||
|
||||
await user.click(document.body);
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should render children in a flex row layout", () => {
|
||||
render(
|
||||
<ContextMenuContainer onClose={onCloseMock} testId="test-container">
|
||||
<div data-testid="child-1">Child 1</div>
|
||||
<div data-testid="child-2">Child 2</div>
|
||||
</ContextMenuContainer>,
|
||||
);
|
||||
|
||||
const container = screen.getByTestId("test-container");
|
||||
const innerDiv = container.firstChild as HTMLElement;
|
||||
expect(innerDiv).toHaveClass("flex");
|
||||
expect(innerDiv).toHaveClass("flex-row");
|
||||
expect(innerDiv).toHaveClass("gap-4");
|
||||
});
|
||||
|
||||
it("should apply additional className when provided", () => {
|
||||
render(
|
||||
<ContextMenuContainer
|
||||
onClose={onCloseMock}
|
||||
testId="test-container"
|
||||
className="custom-class"
|
||||
>
|
||||
<div>Content</div>
|
||||
</ContextMenuContainer>,
|
||||
);
|
||||
|
||||
const container = screen.getByTestId("test-container");
|
||||
expect(container).toHaveClass("custom-class");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,60 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { ContextMenuCTA } from "#/components/features/context-menu/context-menu-cta";
|
||||
|
||||
// Mock useTracking hook
|
||||
const mockTrackSaasSelfhostedInquiry = vi.fn();
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackSaasSelfhostedInquiry: mockTrackSaasSelfhostedInquiry,
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("ContextMenuCTA", () => {
|
||||
it("should render the CTA component", () => {
|
||||
render(<ContextMenuCTA />);
|
||||
|
||||
expect(screen.getByText("CTA$ENTERPRISE_TITLE")).toBeInTheDocument();
|
||||
expect(screen.getByText("CTA$ENTERPRISE_DESCRIPTION")).toBeInTheDocument();
|
||||
expect(screen.getByText("CTA$LEARN_MORE")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call trackSaasSelfhostedInquiry with location 'context_menu' when Learn More is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<ContextMenuCTA />);
|
||||
|
||||
const learnMoreLink = screen.getByRole("link", {
|
||||
name: "CTA$LEARN_MORE",
|
||||
});
|
||||
await user.click(learnMoreLink);
|
||||
|
||||
expect(mockTrackSaasSelfhostedInquiry).toHaveBeenCalledWith({
|
||||
location: "context_menu",
|
||||
});
|
||||
});
|
||||
|
||||
it("should render Learn More as a link with correct href and target", () => {
|
||||
render(<ContextMenuCTA />);
|
||||
|
||||
const learnMoreLink = screen.getByRole("link", {
|
||||
name: "CTA$LEARN_MORE",
|
||||
});
|
||||
expect(learnMoreLink).toHaveAttribute(
|
||||
"href",
|
||||
"https://openhands.dev/enterprise/",
|
||||
);
|
||||
expect(learnMoreLink).toHaveAttribute("target", "_blank");
|
||||
expect(learnMoreLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
|
||||
it("should render the stacked icon", () => {
|
||||
render(<ContextMenuCTA />);
|
||||
|
||||
const contentContainer = screen.getByTestId("context-menu-cta-content");
|
||||
const icon = contentContainer.querySelector("svg");
|
||||
expect(icon).toBeInTheDocument();
|
||||
expect(icon).toHaveAttribute("width", "40");
|
||||
expect(icon).toHaveAttribute("height", "40");
|
||||
});
|
||||
});
|
||||
+6
-1
@@ -1,11 +1,16 @@
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { AnalyticsConsentFormModal } from "#/components/features/analytics/analytics-consent-form-modal";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
describe("AnalyticsConsentFormModal", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
it("should call saveUserSettings with consent", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
@@ -49,9 +49,17 @@ vi.mock("#/utils/custom-toast-handlers", () => ({
|
||||
displayErrorToast: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock feature flags - we'll control the return value in each test
|
||||
const mockEnableProjUserJourney = vi.fn(() => true);
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_PROJ_USER_JOURNEY: () => mockEnableProjUserJourney(),
|
||||
}));
|
||||
|
||||
describe("LoginContent", () => {
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal("location", { href: "" });
|
||||
// Reset mock to return true by default
|
||||
mockEnableProjUserJourney.mockReturnValue(true);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -274,6 +282,65 @@ describe("LoginContent", () => {
|
||||
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display the enterprise LoginCTA component when appMode is saas and feature flag enabled", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("login-cta")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display the enterprise LoginCTA component when appMode is oss even with feature flag enabled", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="oss"
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display the enterprise LoginCTA component when appMode is null", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode={null}
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display the enterprise LoginCTA component when feature flag is disabled", () => {
|
||||
// Disable the feature flag
|
||||
mockEnableProjUserJourney.mockReturnValue(false);
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display invitation pending message when hasInvitation is true", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { LoginCTA } from "#/components/features/auth/login-cta";
|
||||
|
||||
// Mock useTracking hook
|
||||
const mockTrackSaasSelfhostedInquiry = vi.fn();
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackSaasSelfhostedInquiry: mockTrackSaasSelfhostedInquiry,
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("LoginCTA", () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render enterprise CTA with title and description", () => {
|
||||
render(<LoginCTA />);
|
||||
|
||||
expect(screen.getByTestId("login-cta")).toBeInTheDocument();
|
||||
expect(screen.getByText("CTA$ENTERPRISE")).toBeInTheDocument();
|
||||
expect(screen.getByText("CTA$ENTERPRISE_DEPLOY")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render all enterprise feature list items", () => {
|
||||
render(<LoginCTA />);
|
||||
|
||||
expect(screen.getByText("CTA$FEATURE_ON_PREMISES")).toBeInTheDocument();
|
||||
expect(screen.getByText("CTA$FEATURE_DATA_CONTROL")).toBeInTheDocument();
|
||||
expect(screen.getByText("CTA$FEATURE_COMPLIANCE")).toBeInTheDocument();
|
||||
expect(screen.getByText("CTA$FEATURE_SUPPORT")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render Learn More as a link with correct href and target", () => {
|
||||
render(<LoginCTA />);
|
||||
|
||||
const learnMoreLink = screen.getByRole("link", {
|
||||
name: "CTA$LEARN_MORE",
|
||||
});
|
||||
expect(learnMoreLink).toHaveAttribute(
|
||||
"href",
|
||||
"https://openhands.dev/enterprise/",
|
||||
);
|
||||
expect(learnMoreLink).toHaveAttribute("target", "_blank");
|
||||
expect(learnMoreLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
|
||||
it("should call trackSaasSelfhostedInquiry with location 'login_page' when Learn More is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<LoginCTA />);
|
||||
|
||||
const learnMoreLink = screen.getByRole("link", {
|
||||
name: "CTA$LEARN_MORE",
|
||||
});
|
||||
await user.click(learnMoreLink);
|
||||
|
||||
expect(mockTrackSaasSelfhostedInquiry).toHaveBeenCalledWith({
|
||||
location: "login_page",
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -10,9 +10,12 @@ import {
|
||||
import { OpenHandsObservation } from "#/types/core/observations";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import { Conversation } from "#/api/open-hands.types";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
vi.mock("react-router", () => ({
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useParams: () => ({ conversationId: "123" }),
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
let queryClient: QueryClient;
|
||||
@@ -47,6 +50,7 @@ const renderMessages = ({
|
||||
describe("Messages", () => {
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient();
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
const assistantMessage: AssistantMessageAction = {
|
||||
|
||||
@@ -11,23 +11,23 @@ vi.mock("posthog-js/react", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
const { PROJ_USER_JOURNEY_MOCK } = vi.hoisted(() => ({
|
||||
PROJ_USER_JOURNEY_MOCK: vi.fn(() => true),
|
||||
const { ENABLE_PROJ_USER_JOURNEY_MOCK } = vi.hoisted(() => ({
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
PROJ_USER_JOURNEY: () => PROJ_USER_JOURNEY_MOCK(),
|
||||
ENABLE_PROJ_USER_JOURNEY: () => ENABLE_PROJ_USER_JOURNEY_MOCK(),
|
||||
}));
|
||||
|
||||
describe("EnterpriseBanner", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
});
|
||||
|
||||
describe("Feature Flag", () => {
|
||||
it("should not render when proj_user_journey feature flag is disabled", () => {
|
||||
PROJ_USER_JOURNEY_MOCK.mockReturnValue(false);
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(false);
|
||||
|
||||
const { container } = renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
@@ -36,7 +36,7 @@ describe("EnterpriseBanner", () => {
|
||||
});
|
||||
|
||||
it("should render when proj_user_journey feature flag is enabled", () => {
|
||||
PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { HomepageCTA } from "#/components/features/home/homepage-cta";
|
||||
|
||||
// Mock the translation function
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual = await vi.importActual("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
"CTA$ENTERPRISE_TITLE": "Get OpenHands for Enterprise",
|
||||
"CTA$ENTERPRISE_DESCRIPTION":
|
||||
"Cloud allows you to access OpenHands anywhere and coordinate with your team like never before",
|
||||
"CTA$LEARN_MORE": "Learn More",
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
i18n: { language: "en" },
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock local storage
|
||||
vi.mock("#/utils/local-storage", () => ({
|
||||
setCTADismissed: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock useTracking hook
|
||||
const mockTrackSaasSelfhostedInquiry = vi.fn();
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackSaasSelfhostedInquiry: mockTrackSaasSelfhostedInquiry,
|
||||
}),
|
||||
}));
|
||||
|
||||
import { setCTADismissed } from "#/utils/local-storage";
|
||||
|
||||
describe("HomepageCTA", () => {
|
||||
const mockSetShouldShowCTA = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
const renderHomepageCTA = () => {
|
||||
return render(<HomepageCTA setShouldShowCTA={mockSetShouldShowCTA} />);
|
||||
};
|
||||
|
||||
describe("rendering", () => {
|
||||
it("renders the enterprise title", () => {
|
||||
renderHomepageCTA();
|
||||
expect(
|
||||
screen.getByText("Get OpenHands for Enterprise"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders the enterprise description", () => {
|
||||
renderHomepageCTA();
|
||||
expect(
|
||||
screen.getByText(/Cloud allows you to access OpenHands anywhere/),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders the Learn More link", () => {
|
||||
renderHomepageCTA();
|
||||
const link = screen.getByRole("link", { name: "Learn More" });
|
||||
expect(link).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders the close button with correct aria-label", () => {
|
||||
renderHomepageCTA();
|
||||
expect(screen.getByRole("button", { name: "Close" })).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("close button behavior", () => {
|
||||
it("calls setCTADismissed with 'homepage' when close button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderHomepageCTA();
|
||||
|
||||
const closeButton = screen.getByRole("button", { name: "Close" });
|
||||
await user.click(closeButton);
|
||||
|
||||
expect(setCTADismissed).toHaveBeenCalledWith("homepage");
|
||||
});
|
||||
|
||||
it("calls setShouldShowCTA with false when close button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderHomepageCTA();
|
||||
|
||||
const closeButton = screen.getByRole("button", { name: "Close" });
|
||||
await user.click(closeButton);
|
||||
|
||||
expect(mockSetShouldShowCTA).toHaveBeenCalledWith(false);
|
||||
});
|
||||
|
||||
it("calls both setCTADismissed and setShouldShowCTA in order", async () => {
|
||||
const user = userEvent.setup();
|
||||
const callOrder: string[] = [];
|
||||
|
||||
vi.mocked(setCTADismissed).mockImplementation(() => {
|
||||
callOrder.push("setCTADismissed");
|
||||
});
|
||||
mockSetShouldShowCTA.mockImplementation(() => {
|
||||
callOrder.push("setShouldShowCTA");
|
||||
});
|
||||
|
||||
renderHomepageCTA();
|
||||
|
||||
const closeButton = screen.getByRole("button", { name: "Close" });
|
||||
await user.click(closeButton);
|
||||
|
||||
expect(callOrder).toEqual(["setCTADismissed", "setShouldShowCTA"]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Learn More link behavior", () => {
|
||||
it("calls trackSaasSelfhostedInquiry with location 'home_page' when clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderHomepageCTA();
|
||||
|
||||
const learnMoreLink = screen.getByRole("link", { name: "Learn More" });
|
||||
await user.click(learnMoreLink);
|
||||
|
||||
expect(mockTrackSaasSelfhostedInquiry).toHaveBeenCalledWith({
|
||||
location: "home_page",
|
||||
});
|
||||
});
|
||||
|
||||
it("has correct href and target attributes", () => {
|
||||
renderHomepageCTA();
|
||||
|
||||
const learnMoreLink = screen.getByRole("link", { name: "Learn More" });
|
||||
expect(learnMoreLink).toHaveAttribute(
|
||||
"href",
|
||||
"https://openhands.dev/enterprise/",
|
||||
);
|
||||
expect(learnMoreLink).toHaveAttribute("target", "_blank");
|
||||
expect(learnMoreLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
});
|
||||
|
||||
describe("accessibility", () => {
|
||||
it("close button is focusable", () => {
|
||||
renderHomepageCTA();
|
||||
const closeButton = screen.getByRole("button", { name: "Close" });
|
||||
expect(closeButton).not.toHaveAttribute("tabindex", "-1");
|
||||
});
|
||||
|
||||
it("Learn More link is focusable", () => {
|
||||
renderHomepageCTA();
|
||||
const learnMoreLink = screen.getByRole("link", { name: "Learn More" });
|
||||
expect(learnMoreLink).not.toHaveAttribute("tabindex", "-1");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -10,6 +10,7 @@ import OptionService from "#/api/option-service/option-service.api";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import { RepoConnector } from "#/components/features/home/repo-connector";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
const renderRepoConnector = () => {
|
||||
const mockRepoSelection = vi.fn();
|
||||
@@ -65,6 +66,7 @@ const MOCK_RESPOSITORIES: GitRepository[] = [
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
@@ -493,6 +495,6 @@ describe("RepoConnector", () => {
|
||||
expect(goToSettingsButton).toBeInTheDocument();
|
||||
|
||||
await userEvent.click(goToSettingsButton);
|
||||
await screen.findByTestId("git-settings-screen");
|
||||
await screen.findByTestId("settings-screen");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { code as Code } from "#/components/features/markdown/code";
|
||||
|
||||
describe("code (markdown)", () => {
|
||||
it("should render inline code without a copy button", () => {
|
||||
render(<Code>inline snippet</Code>);
|
||||
|
||||
expect(screen.getByText("inline snippet")).toBeInTheDocument();
|
||||
expect(screen.queryByTestId("copy-to-clipboard")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render a multiline code block with a copy button", () => {
|
||||
render(<Code>{"line1\nline2"}</Code>);
|
||||
|
||||
expect(screen.getByText("line1 line2")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("copy-to-clipboard")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render a syntax-highlighted block with a copy button", () => {
|
||||
render(<Code className="language-js">{"console.log('hi')"}</Code>);
|
||||
|
||||
expect(screen.getByTestId("copy-to-clipboard")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should copy code block content to clipboard", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Code>{"line1\nline2"}</Code>);
|
||||
|
||||
await user.click(screen.getByTestId("copy-to-clipboard"));
|
||||
|
||||
await waitFor(() =>
|
||||
expect(navigator.clipboard.readText()).resolves.toBe("line1\nline2"),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -7,6 +7,8 @@ import OnboardingForm from "#/routes/onboarding-form";
|
||||
|
||||
const mockMutate = vi.fn();
|
||||
const mockNavigate = vi.fn();
|
||||
const mockUseConfig = vi.fn();
|
||||
const mockTrackOnboardingCompleted = vi.fn();
|
||||
|
||||
vi.mock("react-router", async (importOriginal) => {
|
||||
const original = await importOriginal<typeof import("react-router")>();
|
||||
@@ -22,6 +24,16 @@ vi.mock("#/hooks/mutation/use-submit-onboarding", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => mockUseConfig(),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackOnboardingCompleted: mockTrackOnboardingCompleted,
|
||||
}),
|
||||
}));
|
||||
|
||||
const renderOnboardingForm = () => {
|
||||
return renderWithProviders(
|
||||
<MemoryRouter>
|
||||
@@ -30,10 +42,15 @@ const renderOnboardingForm = () => {
|
||||
);
|
||||
};
|
||||
|
||||
describe("OnboardingForm", () => {
|
||||
describe("OnboardingForm - SaaS Mode", () => {
|
||||
beforeEach(() => {
|
||||
mockMutate.mockClear();
|
||||
mockNavigate.mockClear();
|
||||
mockTrackOnboardingCompleted.mockClear();
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
});
|
||||
|
||||
it("should render with the correct test id", () => {
|
||||
@@ -50,7 +67,7 @@ describe("OnboardingForm", () => {
|
||||
expect(screen.getByTestId("step-actions")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display step progress indicator with 3 bars", () => {
|
||||
it("should display step progress indicator with 3 bars for saas mode", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
const stepHeader = screen.getByTestId("step-header");
|
||||
@@ -69,7 +86,7 @@ describe("OnboardingForm", () => {
|
||||
const user = userEvent.setup();
|
||||
renderOnboardingForm();
|
||||
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
@@ -84,7 +101,7 @@ describe("OnboardingForm", () => {
|
||||
let progressBars = stepHeader.querySelectorAll(".bg-white");
|
||||
expect(progressBars).toHaveLength(1);
|
||||
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// On step 2, first two progress bars should be filled
|
||||
@@ -96,7 +113,7 @@ describe("OnboardingForm", () => {
|
||||
const user = userEvent.setup();
|
||||
renderOnboardingForm();
|
||||
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
@@ -107,29 +124,51 @@ describe("OnboardingForm", () => {
|
||||
const user = userEvent.setup();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Step 1 - select role
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 2 - select org size
|
||||
// Step 1 - select org size (first step in saas mode - single select)
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 3 - select use case
|
||||
// Step 2 - select use case (multi-select)
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 3 - select role (last step in saas mode - single select)
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
expect(mockMutate).toHaveBeenCalledTimes(1);
|
||||
expect(mockMutate).toHaveBeenCalledWith({
|
||||
selections: {
|
||||
step1: "software_engineer",
|
||||
step2: "org_2_10",
|
||||
step3: "new_features",
|
||||
org_size: "org_2_10",
|
||||
use_case: ["new_features"],
|
||||
role: "software_engineer",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should render 6 options on step 1", () => {
|
||||
it("should track onboarding completion to PostHog in SaaS mode", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Complete the full SaaS onboarding flow
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
expect(mockTrackOnboardingCompleted).toHaveBeenCalledTimes(1);
|
||||
expect(mockTrackOnboardingCompleted).toHaveBeenCalledWith({
|
||||
role: "software_engineer",
|
||||
orgSize: "org_2_10",
|
||||
useCase: ["new_features"],
|
||||
});
|
||||
});
|
||||
|
||||
it("should render 5 options on step 1 (org size question)", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
const options = screen
|
||||
@@ -137,31 +176,86 @@ describe("OnboardingForm", () => {
|
||||
.filter((btn) =>
|
||||
btn.getAttribute("data-testid")?.startsWith("step-option-"),
|
||||
);
|
||||
expect(options).toHaveLength(6);
|
||||
expect(options).toHaveLength(5);
|
||||
});
|
||||
|
||||
it("should preserve selections when navigating through steps", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Select role on step 1
|
||||
await user.click(screen.getByTestId("step-option-cto_founder"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Select org size on step 2
|
||||
// Select org size on step 1 (single select)
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Select use case on step 3
|
||||
// Select use case on step 2 (multi-select)
|
||||
await user.click(screen.getByTestId("step-option-fixing_bugs"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Select role on step 3 (single select)
|
||||
await user.click(screen.getByTestId("step-option-cto_founder"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
// Verify all selections were preserved
|
||||
expect(mockMutate).toHaveBeenCalledWith({
|
||||
selections: {
|
||||
step1: "cto_founder",
|
||||
step2: "solo",
|
||||
step3: "fixing_bugs",
|
||||
org_size: "solo",
|
||||
use_case: ["fixing_bugs"],
|
||||
role: "cto_founder",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow selecting multiple options on multi-select steps", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Step 1 - select org size (single select)
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 2 - select multiple use cases (multi-select)
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByTestId("step-option-fixing_bugs"));
|
||||
await user.click(screen.getByTestId("step-option-refactoring"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 3 - select role (single select)
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
expect(mockMutate).toHaveBeenCalledWith({
|
||||
selections: {
|
||||
org_size: "solo",
|
||||
use_case: ["new_features", "fixing_bugs", "refactoring"],
|
||||
role: "software_engineer",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow deselecting options on multi-select steps", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Step 1 - select org size
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 2 - select and deselect use cases
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByTestId("step-option-fixing_bugs"));
|
||||
await user.click(screen.getByTestId("step-option-new_features")); // Deselect
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 3 - select role
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
expect(mockMutate).toHaveBeenCalledWith({
|
||||
selections: {
|
||||
org_size: "solo",
|
||||
use_case: ["fixing_bugs"],
|
||||
role: "software_engineer",
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -171,10 +265,10 @@ describe("OnboardingForm", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
// Navigate to step 3
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// On step 3, all three progress bars should be filled
|
||||
@@ -194,7 +288,7 @@ describe("OnboardingForm", () => {
|
||||
const user = userEvent.setup();
|
||||
renderOnboardingForm();
|
||||
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
const backButton = screen.getByRole("button", { name: /back/i });
|
||||
@@ -206,7 +300,7 @@ describe("OnboardingForm", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
// Navigate to step 2
|
||||
await user.click(screen.getByTestId("step-option-software_engineer"));
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Verify we're on step 2 (2 progress bars filled)
|
||||
|
||||
@@ -12,7 +12,7 @@ describe("StepContent", () => {
|
||||
|
||||
const defaultProps = {
|
||||
options: mockOptions,
|
||||
selectedOptionId: null,
|
||||
selectedOptionIds: [],
|
||||
onSelectOption: vi.fn(),
|
||||
};
|
||||
|
||||
@@ -44,7 +44,7 @@ describe("StepContent", () => {
|
||||
});
|
||||
|
||||
it("should mark the selected option as selected", () => {
|
||||
render(<StepContent {...defaultProps} selectedOptionId="option1" />);
|
||||
render(<StepContent {...defaultProps} selectedOptionIds={["option1"]} />);
|
||||
|
||||
const selectedOption = screen.getByTestId("step-option-option1");
|
||||
const unselectedOption = screen.getByTestId("step-option-option2");
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { StepInput } from "#/components/features/onboarding/step-input";
|
||||
|
||||
describe("StepInput", () => {
|
||||
const defaultProps = {
|
||||
id: "test-input",
|
||||
label: "Test Label",
|
||||
value: "",
|
||||
onChange: vi.fn(),
|
||||
};
|
||||
|
||||
it("should render with correct test id", () => {
|
||||
render(<StepInput {...defaultProps} />);
|
||||
|
||||
expect(screen.getByTestId("step-input-test-input")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the label", () => {
|
||||
render(<StepInput {...defaultProps} />);
|
||||
|
||||
expect(screen.getByText("Test Label")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display the provided value", () => {
|
||||
render(<StepInput {...defaultProps} value="Hello World" />);
|
||||
|
||||
const input = screen.getByTestId("step-input-test-input");
|
||||
expect(input).toHaveValue("Hello World");
|
||||
});
|
||||
|
||||
it("should call onChange when user types", async () => {
|
||||
const mockOnChange = vi.fn();
|
||||
const user = userEvent.setup();
|
||||
|
||||
render(<StepInput {...defaultProps} onChange={mockOnChange} />);
|
||||
|
||||
const input = screen.getByTestId("step-input-test-input");
|
||||
await user.type(input, "a");
|
||||
|
||||
expect(mockOnChange).toHaveBeenCalledWith("a");
|
||||
});
|
||||
|
||||
it("should call onChange with the full input value on each keystroke", async () => {
|
||||
const mockOnChange = vi.fn();
|
||||
const user = userEvent.setup();
|
||||
|
||||
render(<StepInput {...defaultProps} onChange={mockOnChange} />);
|
||||
|
||||
const input = screen.getByTestId("step-input-test-input");
|
||||
await user.type(input, "abc");
|
||||
|
||||
expect(mockOnChange).toHaveBeenCalledTimes(3);
|
||||
expect(mockOnChange).toHaveBeenNthCalledWith(1, "a");
|
||||
expect(mockOnChange).toHaveBeenNthCalledWith(2, "b");
|
||||
expect(mockOnChange).toHaveBeenNthCalledWith(3, "c");
|
||||
});
|
||||
|
||||
it("should use the id prop for data-testid", () => {
|
||||
render(<StepInput {...defaultProps} id="org_name" />);
|
||||
|
||||
expect(screen.getByTestId("step-input-org_name")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render as a text input", () => {
|
||||
render(<StepInput {...defaultProps} />);
|
||||
|
||||
const input = screen.getByTestId("step-input-test-input");
|
||||
expect(input).toHaveAttribute("type", "text");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,351 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { AddCreditsModal } from "#/components/features/org/add-credits-modal";
|
||||
import BillingService from "#/api/billing-service/billing-service.api";
|
||||
|
||||
vi.mock("react-i18next", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-i18next")>()),
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
i18n: {
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("AddCreditsModal", () => {
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
const renderModal = () => {
|
||||
const user = userEvent.setup();
|
||||
renderWithProviders(<AddCreditsModal onClose={onCloseMock} />);
|
||||
return { user };
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("Rendering", () => {
|
||||
it("should render the form with correct elements", () => {
|
||||
renderModal();
|
||||
|
||||
expect(screen.getByTestId("add-credits-form")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("amount-input")).toBeInTheDocument();
|
||||
expect(screen.getByRole("button", { name: /ORG\$NEXT/i })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display the title", () => {
|
||||
renderModal();
|
||||
|
||||
expect(screen.getByText("ORG$ADD_CREDITS")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Button State Management", () => {
|
||||
it("should enable submit button initially when modal opens", () => {
|
||||
renderModal();
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains invalid value", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains valid value", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "100");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button after validation error is shown", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Input Attributes & Placeholder", () => {
|
||||
it("should have min attribute set to 10", () => {
|
||||
renderModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("min", "10");
|
||||
});
|
||||
|
||||
it("should have max attribute set to 25000", () => {
|
||||
renderModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("max", "25000");
|
||||
});
|
||||
|
||||
it("should have step attribute set to 1", () => {
|
||||
renderModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("step", "1");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Message Display", () => {
|
||||
it("should not display error message initially when modal opens", () => {
|
||||
renderModal();
|
||||
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display error message after submitting amount above maximum", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MAXIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting decimal value", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "50.5");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER");
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting amount below minimum", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting negative amount", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_NEGATIVE_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should replace error message when submitting different invalid value", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT");
|
||||
});
|
||||
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MAXIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form Submission Behavior", () => {
|
||||
it("should prevent submission when amount is invalid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should call createCheckoutSession with correct amount when valid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not call createCheckoutSession when validation fails", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_NEGATIVE_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should close modal on successful submission", async () => {
|
||||
vi.spyOn(BillingService, "createCheckoutSession").mockResolvedValue(
|
||||
"https://checkout.stripe.com/test-session",
|
||||
);
|
||||
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onCloseMock).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow API call when validation passes and clear any previous errors", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
// First submit invalid value
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Then submit valid value
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "100");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(100);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edge Cases", () => {
|
||||
it("should handle zero value correctly", async () => {
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
await user.type(amountInput, "0");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent("PAYMENT$ERROR_MINIMUM_AMOUNT");
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle whitespace-only input correctly", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = renderModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /ORG\$NEXT/i });
|
||||
|
||||
// Number inputs typically don't accept spaces, but test the behavior
|
||||
await user.type(amountInput, " ");
|
||||
await user.click(nextButton);
|
||||
|
||||
// Should not call API (empty/invalid input)
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Modal Interaction", () => {
|
||||
it("should call onClose when cancel button is clicked", async () => {
|
||||
const { user } = renderModal();
|
||||
|
||||
const cancelButton = screen.getByRole("button", { name: /close/i });
|
||||
await user.click(cancelButton);
|
||||
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,8 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { ApiKeysManager } from "#/components/features/settings/api-keys-manager";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Mock the react-i18next
|
||||
vi.mock("react-i18next", async () => {
|
||||
@@ -37,6 +38,10 @@ vi.mock("#/hooks/query/use-api-keys", () => ({
|
||||
}));
|
||||
|
||||
describe("ApiKeysManager", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
const renderComponent = () => {
|
||||
const queryClient = new QueryClient();
|
||||
return render(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
renderWithProviders,
|
||||
createAxiosNotFoundErrorObject,
|
||||
@@ -10,6 +10,7 @@ import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Helper to create mock config with sensible defaults
|
||||
const createMockConfig = (
|
||||
@@ -76,6 +77,10 @@ describe("Sidebar", () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
@@ -18,6 +18,27 @@ import { OrganizationMember } from "#/types/org";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import { createMockWebClientConfig } from "#/mocks/settings-handlers";
|
||||
|
||||
// Mock useBreakpoint hook
|
||||
vi.mock("#/hooks/use-breakpoint", () => ({
|
||||
useBreakpoint: vi.fn(() => false), // Default to desktop (not mobile)
|
||||
}));
|
||||
|
||||
// Mock feature flags
|
||||
const mockEnableProjUserJourney = vi.fn(() => true);
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_PROJ_USER_JOURNEY: () => mockEnableProjUserJourney(),
|
||||
}));
|
||||
|
||||
// Mock useTracking hook for CTA
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackSaasSelfhostedInquiry: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
// Import the mocked modules
|
||||
import * as breakpoint from "#/hooks/use-breakpoint";
|
||||
|
||||
type UserContextMenuProps = GetComponentPropTypes<typeof UserContextMenu>;
|
||||
|
||||
function UserContextMenuWithRootOutlet({
|
||||
@@ -123,6 +144,9 @@ describe("UserContextMenu", () => {
|
||||
// Ensure clean state at the start of each test
|
||||
vi.restoreAllMocks();
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
// Reset feature flag and breakpoint mocks to defaults
|
||||
mockEnableProjUserJourney.mockReturnValue(true);
|
||||
vi.mocked(breakpoint.useBreakpoint).mockReturnValue(false); // Desktop by default
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -132,11 +156,19 @@ describe("UserContextMenu", () => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
it("should render the default context items for a user", () => {
|
||||
it("should render the default context items for a user", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
screen.getByTestId("org-selector");
|
||||
screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
|
||||
// Wait for config to load so logout button appears
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("ACCOUNT_SETTINGS$LOGOUT")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(
|
||||
screen.queryByText("ORG$INVITE_ORG_MEMBERS"),
|
||||
@@ -280,6 +312,20 @@ describe("UserContextMenu", () => {
|
||||
screen.queryByText("Organization Members"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display logout button in OSS mode", async () => {
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for the config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("SETTINGS$NAV_LLM")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify logout button is NOT rendered in OSS mode
|
||||
expect(
|
||||
screen.queryByText("ACCOUNT_SETTINGS$LOGOUT"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("HIDE_LLM_SETTINGS feature flag", () => {
|
||||
@@ -358,10 +404,15 @@ describe("UserContextMenu", () => {
|
||||
});
|
||||
|
||||
it("should call the logout handler when Logout is clicked", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
|
||||
const logoutSpy = vi.spyOn(AuthService, "logout");
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
const logoutButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
// Wait for config to load so logout button appears
|
||||
const logoutButton = await screen.findByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await userEvent.click(logoutButton);
|
||||
|
||||
expect(logoutSpy).toHaveBeenCalledOnce();
|
||||
@@ -464,6 +515,10 @@ describe("UserContextMenu", () => {
|
||||
});
|
||||
|
||||
it("should call the onClose handler after each action", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
|
||||
// Mock a team org so org management buttons are visible
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
@@ -473,7 +528,8 @@ describe("UserContextMenu", () => {
|
||||
const onCloseMock = vi.fn();
|
||||
renderUserContextMenu({ type: "owner", onClose: onCloseMock, onOpenInviteModal: vi.fn });
|
||||
|
||||
const logoutButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
// Wait for config to load so logout button appears
|
||||
const logoutButton = await screen.findByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await userEvent.click(logoutButton);
|
||||
expect(onCloseMock).toHaveBeenCalledTimes(1);
|
||||
|
||||
@@ -630,4 +686,77 @@ describe("UserContextMenu", () => {
|
||||
// Verify that the dropdown shows the selected organization
|
||||
expect(screen.getByRole("combobox")).toHaveValue(INITIAL_MOCK_ORGS[1].name);
|
||||
});
|
||||
|
||||
describe("Context Menu CTA", () => {
|
||||
it("should render the CTA component in SaaS mode on desktop with feature flag enabled", async () => {
|
||||
// Set SaaS mode
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("context-menu-cta")).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.getByText("CTA$ENTERPRISE_TITLE")).toBeInTheDocument();
|
||||
expect(screen.getByText("CTA$LEARN_MORE")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not render the CTA component in OSS mode even with feature flag enabled", async () => {
|
||||
// Set OSS mode
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "oss" }),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not render the CTA component on mobile even in SaaS mode with feature flag enabled", async () => {
|
||||
// Set SaaS mode
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
// Set mobile mode
|
||||
vi.mocked(breakpoint.useBreakpoint).mockReturnValue(true);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not render the CTA component when feature flag is disabled in SaaS mode", async () => {
|
||||
// Set SaaS mode
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
// Disable the feature flag
|
||||
mockEnableProjUserJourney.mockReturnValue(false);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,26 +1,25 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, beforeAll, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { InteractiveChatBox } from "#/components/features/chat/interactive-chat-box";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock React Router hooks
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual,
|
||||
useNavigate: () => vi.fn(),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
};
|
||||
});
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useNavigate: () => vi.fn(),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
// Mock the useActiveConversation hook
|
||||
vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
@@ -52,6 +51,10 @@ vi.mock("#/hooks/use-conversation-name-context-menu", () => ({
|
||||
describe("InteractiveChatBox", () => {
|
||||
const onSubmitMock = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
const mockStores = (agentState: AgentState = AgentState.INIT) => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: agentState,
|
||||
@@ -213,6 +216,36 @@ describe("InteractiveChatBox", () => {
|
||||
expect(onSubmitMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should lock the text input field when disabled prop is true (isNewConversationPending)", () => {
|
||||
mockStores(AgentState.INIT);
|
||||
|
||||
renderInteractiveChatBox({
|
||||
onSubmit: onSubmitMock,
|
||||
disabled: true,
|
||||
});
|
||||
|
||||
const chatInput = screen.getByTestId("chat-input");
|
||||
// When disabled=true, the text field should not be editable
|
||||
expect(chatInput).toHaveAttribute("contenteditable", "false");
|
||||
// Should show visual disabled state
|
||||
expect(chatInput.className).toContain("cursor-not-allowed");
|
||||
expect(chatInput.className).toContain("opacity-50");
|
||||
});
|
||||
|
||||
it("should keep the text input field editable when disabled prop is false", () => {
|
||||
mockStores(AgentState.INIT);
|
||||
|
||||
renderInteractiveChatBox({
|
||||
onSubmit: onSubmitMock,
|
||||
disabled: false,
|
||||
});
|
||||
|
||||
const chatInput = screen.getByTestId("chat-input");
|
||||
expect(chatInput).toHaveAttribute("contenteditable", "true");
|
||||
expect(chatInput.className).not.toContain("cursor-not-allowed");
|
||||
expect(chatInput.className).not.toContain("opacity-50");
|
||||
});
|
||||
|
||||
it("should handle image upload and message submission correctly", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onSubmit = vi.fn();
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { render, screen, waitFor, fireEvent, act } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, afterEach, beforeEach, test } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { MemoryRouter, createRoutesStub } from "react-router";
|
||||
import { ReactElement } from "react";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { UserActions } from "#/components/features/sidebar/user-actions";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME } from "#/mocks/org-handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import { server } from "#/mocks/node";
|
||||
import { createMockWebClientConfig } from "#/mocks/settings-handlers";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
|
||||
vi.mock("react-router", async (importActual) => ({
|
||||
@@ -59,6 +62,20 @@ const renderUserActions = (props = { hasAvatar: true }) => {
|
||||
);
|
||||
};
|
||||
|
||||
// RouterStub and render helper for menu close delay tests
|
||||
const RouterStubForMenuCloseDelay = createRoutesStub([
|
||||
{
|
||||
path: "/",
|
||||
Component: () => (
|
||||
<UserActions user={{ avatar_url: "https://example.com/avatar.png" }} />
|
||||
),
|
||||
},
|
||||
]);
|
||||
|
||||
const renderUserActionsForMenuCloseDelay = () => {
|
||||
return renderWithProviders(<RouterStubForMenuCloseDelay initialEntries={["/"]} />);
|
||||
};
|
||||
|
||||
// Create mocks for all the hooks we need
|
||||
const useIsAuthedMock = vi
|
||||
.fn()
|
||||
@@ -347,7 +364,7 @@ describe("UserActions", () => {
|
||||
expect(contextMenu).toBeVisible();
|
||||
});
|
||||
|
||||
it("should have pointer-events-none on hover bridge pseudo-element to allow menu item clicks", async () => {
|
||||
it("should use state-based visibility for hover behavior instead of CSS pseudo-element", async () => {
|
||||
renderUserActions();
|
||||
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
@@ -356,19 +373,17 @@ describe("UserActions", () => {
|
||||
const contextMenu = screen.getByTestId("user-context-menu");
|
||||
const hoverBridgeContainer = contextMenu.parentElement;
|
||||
|
||||
// The hover bridge uses a ::before pseudo-element for diagonal mouse movement
|
||||
// This pseudo-element MUST have pointer-events-none to allow clicks through to menu items
|
||||
// The class should include "before:pointer-events-none" to prevent the hover bridge from blocking clicks
|
||||
expect(hoverBridgeContainer?.className).toContain(
|
||||
"before:pointer-events-none",
|
||||
);
|
||||
// The component uses state-based visibility with a 500ms delay for diagonal mouse movement
|
||||
// When visible, the container should have opacity-100 and pointer-events-auto
|
||||
expect(hoverBridgeContainer?.className).toContain("opacity-100");
|
||||
expect(hoverBridgeContainer?.className).toContain("pointer-events-auto");
|
||||
});
|
||||
|
||||
describe("Org selector dropdown state reset when context menu hides", () => {
|
||||
// These tests verify that the org selector dropdown resets its internal
|
||||
// state (search text, open/closed) when the context menu hides and
|
||||
// reappears. Without this, stale state persists because the context
|
||||
// menu is hidden via CSS (opacity/pointer-events) rather than unmounted.
|
||||
// reappears. The component uses a 500ms delay before hiding (to support
|
||||
// diagonal mouse movement).
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
@@ -400,8 +415,22 @@ describe("UserActions", () => {
|
||||
await user.type(input, "search text");
|
||||
expect(input).toHaveValue("search text");
|
||||
|
||||
// Unhover to hide context menu, then hover again
|
||||
// Unhover to trigger hide timeout, then wait for the 500ms delay to complete
|
||||
await user.unhover(userActions);
|
||||
|
||||
// Wait for the 500ms hide delay to complete and menu to actually hide
|
||||
await waitFor(
|
||||
() => {
|
||||
// The menu resets when it actually hides (after 500ms delay)
|
||||
// After hiding, hovering again should show a fresh menu
|
||||
},
|
||||
{ timeout: 600 },
|
||||
);
|
||||
|
||||
// Wait a bit more for the timeout to fire
|
||||
await new Promise((resolve) => setTimeout(resolve, 550));
|
||||
|
||||
// Now hover again to show the menu
|
||||
await user.hover(userActions);
|
||||
|
||||
// Org selector should be reset — showing selected org name, not search text
|
||||
@@ -434,8 +463,13 @@ describe("UserActions", () => {
|
||||
await user.type(input, "Acme");
|
||||
expect(input).toHaveValue("Acme");
|
||||
|
||||
// Unhover to hide context menu, then hover again
|
||||
// Unhover to trigger hide timeout
|
||||
await user.unhover(userActions);
|
||||
|
||||
// Wait for the 500ms hide delay to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 550));
|
||||
|
||||
// Now hover again to show the menu
|
||||
await user.hover(userActions);
|
||||
|
||||
// Wait for fresh component with org data
|
||||
@@ -454,4 +488,83 @@ describe("UserActions", () => {
|
||||
expect(screen.queryAllByRole("option")).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("menu close delay", () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
|
||||
// Mock config to return SaaS mode so useShouldShowUserFeatures returns true
|
||||
server.use(
|
||||
http.get("/api/v1/web-client/config", () =>
|
||||
HttpResponse.json(createMockWebClientConfig({ app_mode: "saas" })),
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
server.resetHandlers();
|
||||
});
|
||||
|
||||
it("should keep menu visible when mouse leaves and re-enters within 500ms", async () => {
|
||||
// Arrange - render and wait for queries to settle
|
||||
renderUserActionsForMenuCloseDelay();
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
|
||||
// Act - open menu
|
||||
await act(async () => {
|
||||
fireEvent.mouseEnter(userActions);
|
||||
});
|
||||
|
||||
// Assert - menu is visible
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
|
||||
// Act - leave and re-enter within 500ms
|
||||
await act(async () => {
|
||||
fireEvent.mouseLeave(userActions);
|
||||
await vi.advanceTimersByTimeAsync(200);
|
||||
fireEvent.mouseEnter(userActions);
|
||||
});
|
||||
|
||||
// Assert - menu should still be visible after waiting (pending close was cancelled)
|
||||
await act(async () => {
|
||||
await vi.advanceTimersByTimeAsync(500);
|
||||
});
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not close menu before 500ms delay when mouse leaves", async () => {
|
||||
// Arrange - render and wait for queries to settle
|
||||
renderUserActionsForMenuCloseDelay();
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
|
||||
// Act - open menu
|
||||
await act(async () => {
|
||||
fireEvent.mouseEnter(userActions);
|
||||
});
|
||||
|
||||
// Assert - menu is visible
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
|
||||
// Act - leave without re-entering, but check before timeout expires
|
||||
await act(async () => {
|
||||
fireEvent.mouseLeave(userActions);
|
||||
await vi.advanceTimersByTimeAsync(400); // Before the 500ms delay
|
||||
});
|
||||
|
||||
// Assert - menu should still be visible (delay hasn't expired yet)
|
||||
// Note: The menu is always in DOM but with opacity-0 when closed.
|
||||
// This test verifies the state hasn't changed yet (delay is working).
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { getEventContent } from "#/components/v1/chat";
|
||||
import { ActionEvent, ObservationEvent, SecurityRisk } from "#/types/v1/core";
|
||||
|
||||
const terminalActionEvent: ActionEvent = {
|
||||
id: "action-1",
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "agent",
|
||||
thought: [{ type: "text", text: "Checking repository status." }],
|
||||
thinking_blocks: [],
|
||||
action: {
|
||||
kind: "TerminalAction",
|
||||
command: "git status",
|
||||
is_input: false,
|
||||
timeout: null,
|
||||
reset: false,
|
||||
},
|
||||
tool_name: "terminal",
|
||||
tool_call_id: "tool-1",
|
||||
tool_call: {
|
||||
id: "tool-1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "terminal",
|
||||
arguments: '{"command":"git status"}',
|
||||
},
|
||||
},
|
||||
llm_response_id: "response-1",
|
||||
security_risk: SecurityRisk.LOW,
|
||||
summary: "Check repository status",
|
||||
};
|
||||
|
||||
const terminalObservationEvent: ObservationEvent = {
|
||||
id: "obs-1",
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "environment",
|
||||
tool_name: "terminal",
|
||||
tool_call_id: "tool-1",
|
||||
action_id: "action-1",
|
||||
observation: {
|
||||
kind: "TerminalObservation",
|
||||
content: [{ type: "text", text: "On branch main" }],
|
||||
command: "git status",
|
||||
exit_code: 0,
|
||||
is_error: false,
|
||||
timeout: false,
|
||||
metadata: {
|
||||
exit_code: 0,
|
||||
pid: 1,
|
||||
username: "openhands",
|
||||
hostname: "sandbox",
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
working_dir: "/workspace/project/OpenHands",
|
||||
py_interpreter_path: null,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
describe("getEventContent", () => {
|
||||
it("uses the action summary as the full action title", () => {
|
||||
const { title } = getEventContent(terminalActionEvent);
|
||||
|
||||
render(<>{title}</>);
|
||||
|
||||
expect(screen.getByText("Check repository status")).toBeInTheDocument();
|
||||
expect(screen.queryByText("$ git status")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("falls back to command-based title when summary is missing", () => {
|
||||
const actionWithoutSummary = { ...terminalActionEvent, summary: undefined };
|
||||
const { title } = getEventContent(actionWithoutSummary);
|
||||
|
||||
render(<>{title}</>);
|
||||
|
||||
// Without i18n loaded, the translation key renders as the raw key
|
||||
expect(screen.getByText("ACTION_MESSAGE$RUN")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByText("Check repository status"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("reuses the action summary as the full paired observation title", () => {
|
||||
const { title } = getEventContent(
|
||||
terminalObservationEvent,
|
||||
terminalActionEvent,
|
||||
);
|
||||
|
||||
render(<>{title}</>);
|
||||
|
||||
expect(screen.getByText("Check repository status")).toBeInTheDocument();
|
||||
expect(screen.queryByText("$ git status")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
WsClientProvider,
|
||||
useWsClient,
|
||||
} from "#/context/ws-client-provider";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
describe("Propagate error message", () => {
|
||||
it("should do nothing when no message was passed from server", () => {
|
||||
@@ -56,6 +57,7 @@ function TestComponent() {
|
||||
describe("WsClientProvider", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
useActiveConversation: () => {
|
||||
return { data: {
|
||||
|
||||
@@ -40,6 +40,7 @@ import {
|
||||
import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
|
||||
import { useEventStore } from "#/stores/use-event-store";
|
||||
import { isV1Event } from "#/types/v1/type-guards";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Mock useUserConversation to return V1 conversation data
|
||||
vi.mock("#/hooks/query/use-user-conversation", () => ({
|
||||
@@ -62,6 +63,10 @@ beforeAll(() => {
|
||||
mswServer.listen({ onUnhandledRequest: "bypass" });
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
mswServer.resetHandlers();
|
||||
// Clean up any React components
|
||||
|
||||
@@ -0,0 +1,299 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
|
||||
import { useNewConversationCommand } from "#/hooks/mutation/use-new-conversation-command";
|
||||
|
||||
const mockNavigate = vi.fn();
|
||||
|
||||
vi.mock("react-router", () => ({
|
||||
useNavigate: () => mockNavigate,
|
||||
useParams: () => ({ conversationId: "conv-123" }),
|
||||
}));
|
||||
|
||||
vi.mock("react-i18next", () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}));
|
||||
|
||||
const { mockToast } = vi.hoisted(() => {
|
||||
const mockToast = Object.assign(vi.fn(), {
|
||||
loading: vi.fn(),
|
||||
dismiss: vi.fn(),
|
||||
});
|
||||
return { mockToast };
|
||||
});
|
||||
|
||||
vi.mock("react-hot-toast", () => ({
|
||||
default: mockToast,
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/custom-toast-handlers", () => ({
|
||||
displaySuccessToast: vi.fn(),
|
||||
displayErrorToast: vi.fn(),
|
||||
TOAST_OPTIONS: { position: "top-right" },
|
||||
}));
|
||||
|
||||
const mockConversation = {
|
||||
conversation_id: "conv-123",
|
||||
sandbox_id: "sandbox-456",
|
||||
title: "Test Conversation",
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
last_updated_at: new Date().toISOString(),
|
||||
created_at: new Date().toISOString(),
|
||||
status: "RUNNING" as const,
|
||||
runtime_status: null,
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
conversation_version: "V1" as const,
|
||||
};
|
||||
|
||||
vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
useActiveConversation: () => ({
|
||||
data: mockConversation,
|
||||
}),
|
||||
}));
|
||||
|
||||
function makeStartTask(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
id: "task-789",
|
||||
created_by_user_id: null,
|
||||
status: "READY",
|
||||
detail: null,
|
||||
app_conversation_id: "new-conv-999",
|
||||
sandbox_id: "sandbox-456",
|
||||
agent_server_url: "http://agent-server.local",
|
||||
request: {
|
||||
sandbox_id: null,
|
||||
initial_message: null,
|
||||
processors: [],
|
||||
llm_model: null,
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
suggested_task: null,
|
||||
title: null,
|
||||
trigger: null,
|
||||
pr_number: [],
|
||||
parent_conversation_id: null,
|
||||
agent_type: "default",
|
||||
},
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("useNewConversationCommand", () => {
|
||||
let queryClient: QueryClient;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: { mutations: { retry: false } },
|
||||
});
|
||||
// Mock batchGetAppConversations to return V1 data with llm_model
|
||||
vi.spyOn(
|
||||
V1ConversationService,
|
||||
"batchGetAppConversations",
|
||||
).mockResolvedValue([
|
||||
{
|
||||
id: "conv-123",
|
||||
title: "Test Conversation",
|
||||
sandbox_id: "sandbox-456",
|
||||
sandbox_status: "RUNNING",
|
||||
execution_status: "IDLE",
|
||||
conversation_url: null,
|
||||
session_api_key: null,
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
trigger: null,
|
||||
pr_number: [],
|
||||
llm_model: "gpt-4o",
|
||||
metrics: null,
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
sub_conversation_ids: [],
|
||||
public: false,
|
||||
} as never,
|
||||
]);
|
||||
});
|
||||
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
|
||||
it("calls createConversation with sandbox_id and navigates on success", async () => {
|
||||
const readyTask = makeStartTask();
|
||||
const createSpy = vi
|
||||
.spyOn(V1ConversationService, "createConversation")
|
||||
.mockResolvedValue(readyTask as never);
|
||||
const getStartTaskSpy = vi
|
||||
.spyOn(V1ConversationService, "getStartTask")
|
||||
.mockResolvedValue(readyTask as never);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await result.current.mutateAsync();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(createSpy).toHaveBeenCalledWith(
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
"sandbox-456",
|
||||
"gpt-4o",
|
||||
);
|
||||
expect(getStartTaskSpy).toHaveBeenCalledWith("task-789");
|
||||
expect(mockNavigate).toHaveBeenCalledWith(
|
||||
"/conversations/new-conv-999",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("polls getStartTask until status is READY", async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true });
|
||||
|
||||
const workingTask = makeStartTask({
|
||||
status: "WORKING",
|
||||
app_conversation_id: null,
|
||||
});
|
||||
const readyTask = makeStartTask({ status: "READY" });
|
||||
|
||||
vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue(
|
||||
workingTask as never,
|
||||
);
|
||||
const getStartTaskSpy = vi
|
||||
.spyOn(V1ConversationService, "getStartTask")
|
||||
.mockResolvedValueOnce(workingTask as never)
|
||||
.mockResolvedValueOnce(readyTask as never);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
const mutatePromise = result.current.mutateAsync();
|
||||
|
||||
await vi.advanceTimersByTimeAsync(2000);
|
||||
await mutatePromise;
|
||||
|
||||
await waitFor(() => {
|
||||
expect(getStartTaskSpy).toHaveBeenCalledTimes(2);
|
||||
expect(mockNavigate).toHaveBeenCalledWith(
|
||||
"/conversations/new-conv-999",
|
||||
);
|
||||
});
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it("throws when task status is ERROR", async () => {
|
||||
const errorTask = makeStartTask({
|
||||
status: "ERROR",
|
||||
detail: "Sandbox crashed",
|
||||
app_conversation_id: null,
|
||||
});
|
||||
|
||||
vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue(
|
||||
errorTask as never,
|
||||
);
|
||||
vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue(
|
||||
errorTask as never,
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await expect(result.current.mutateAsync()).rejects.toThrow(
|
||||
"Sandbox crashed",
|
||||
);
|
||||
});
|
||||
|
||||
it("invalidates conversation list queries on success", async () => {
|
||||
const readyTask = makeStartTask();
|
||||
|
||||
vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
|
||||
const invalidateSpy = vi.spyOn(queryClient, "invalidateQueries");
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await result.current.mutateAsync();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(invalidateSpy).toHaveBeenCalledWith({
|
||||
queryKey: ["user", "conversations"],
|
||||
});
|
||||
expect(invalidateSpy).toHaveBeenCalledWith({
|
||||
queryKey: ["v1-batch-get-app-conversations"],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("creates a standalone conversation (not a sub-conversation) so it appears in the list", async () => {
|
||||
const readyTask = makeStartTask();
|
||||
const createSpy = vi
|
||||
.spyOn(V1ConversationService, "createConversation")
|
||||
.mockResolvedValue(readyTask as never);
|
||||
vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await result.current.mutateAsync();
|
||||
|
||||
await waitFor(() => {
|
||||
// parent_conversation_id should be undefined so the new conversation
|
||||
// is NOT a sub-conversation and will appear in the conversation list.
|
||||
expect(createSpy).toHaveBeenCalledWith(
|
||||
undefined, // selectedRepository (null from mock)
|
||||
undefined, // git_provider (null from mock)
|
||||
undefined, // initialUserMsg
|
||||
undefined, // selected_branch (null from mock)
|
||||
undefined, // conversationInstructions
|
||||
undefined, // suggestedTask
|
||||
undefined, // trigger
|
||||
undefined, // parent_conversation_id is NOT set
|
||||
undefined, // agent_type
|
||||
"sandbox-456", // sandbox_id IS set to reuse the sandbox
|
||||
"gpt-4o", // llm_model IS inherited from the original conversation
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("shows a loading toast immediately and dismisses it on success", async () => {
|
||||
const readyTask = makeStartTask();
|
||||
|
||||
vi.spyOn(V1ConversationService, "createConversation").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
vi.spyOn(V1ConversationService, "getStartTask").mockResolvedValue(
|
||||
readyTask as never,
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useNewConversationCommand(), { wrapper });
|
||||
|
||||
await result.current.mutateAsync();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToast.loading).toHaveBeenCalledWith(
|
||||
"CONVERSATION$CLEARING",
|
||||
expect.objectContaining({ id: "clear-conversation" }),
|
||||
);
|
||||
expect(mockToast.dismiss).toHaveBeenCalledWith("clear-conversation");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,10 +1,15 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
describe("useSaveSettings", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
it("should send an empty string for llm_api_key if an empty string is passed, otherwise undefined", async () => {
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
const { result } = renderHook(() => useSaveSettings(), {
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import React from "react";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { useGetSecrets } from "#/hooks/query/use-get-secrets";
|
||||
import { useApiKeys } from "#/hooks/query/use-api-keys";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { SecretsService } from "#/api/secrets-service";
|
||||
import ApiKeysClient from "#/api/api-keys";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({
|
||||
data: { app_mode: "saas" },
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
useIsAuthed: () => ({
|
||||
data: true,
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-is-on-intermediate-page", () => ({
|
||||
useIsOnIntermediatePage: () => false,
|
||||
}));
|
||||
|
||||
describe("Organization-scoped query hooks", () => {
|
||||
let queryClient: QueryClient;
|
||||
|
||||
const createWrapper = () => {
|
||||
return ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("useSettings", () => {
|
||||
it("should include organizationId in query key for proper cache isolation", async () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
|
||||
|
||||
const { result } = renderHook(() => useSettings(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
|
||||
// Verify the query was cached with the org-specific key
|
||||
const cachedData = queryClient.getQueryData(["settings", "org-1"]);
|
||||
expect(cachedData).toBeDefined();
|
||||
|
||||
// Verify no data is cached under the old key without org ID
|
||||
const oldKeyData = queryClient.getQueryData(["settings"]);
|
||||
expect(oldKeyData).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should refetch when organization changes", async () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
language: "en",
|
||||
});
|
||||
|
||||
// First render with org-1
|
||||
const { result, rerender } = renderHook(() => useSettings(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
expect(getSettingsSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Change organization
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-2" });
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
language: "es",
|
||||
});
|
||||
|
||||
// Rerender to pick up the new org ID
|
||||
rerender();
|
||||
|
||||
await waitFor(() => {
|
||||
// Should have fetched again for the new org
|
||||
expect(getSettingsSpy).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
// Verify both org caches exist independently
|
||||
const org1Data = queryClient.getQueryData(["settings", "org-1"]);
|
||||
const org2Data = queryClient.getQueryData(["settings", "org-2"]);
|
||||
expect(org1Data).toBeDefined();
|
||||
expect(org2Data).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("useGetSecrets", () => {
|
||||
it("should include organizationId in query key for proper cache isolation", async () => {
|
||||
const getSecretsSpy = vi.spyOn(SecretsService, "getSecrets");
|
||||
getSecretsSpy.mockResolvedValue([]);
|
||||
|
||||
const { result } = renderHook(() => useGetSecrets(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
|
||||
// Verify the query was cached with the org-specific key
|
||||
const cachedData = queryClient.getQueryData(["secrets", "org-1"]);
|
||||
expect(cachedData).toBeDefined();
|
||||
|
||||
// Verify no data is cached under the old key without org ID
|
||||
const oldKeyData = queryClient.getQueryData(["secrets"]);
|
||||
expect(oldKeyData).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should fetch different data when organization changes", async () => {
|
||||
const getSecretsSpy = vi.spyOn(SecretsService, "getSecrets");
|
||||
|
||||
// Mock different secrets for different orgs
|
||||
getSecretsSpy.mockResolvedValueOnce([
|
||||
{ name: "SECRET_ORG_1", description: "Org 1 secret" },
|
||||
]);
|
||||
|
||||
const { result, rerender } = renderHook(() => useGetSecrets(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
expect(result.current.data).toHaveLength(1);
|
||||
expect(result.current.data?.[0].name).toBe("SECRET_ORG_1");
|
||||
|
||||
// Change organization
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-2" });
|
||||
getSecretsSpy.mockResolvedValueOnce([
|
||||
{ name: "SECRET_ORG_2", description: "Org 2 secret" },
|
||||
]);
|
||||
|
||||
rerender();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data?.[0]?.name).toBe("SECRET_ORG_2");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("useApiKeys", () => {
|
||||
it("should include organizationId in query key for proper cache isolation", async () => {
|
||||
const getApiKeysSpy = vi.spyOn(ApiKeysClient, "getApiKeys");
|
||||
getApiKeysSpy.mockResolvedValue([]);
|
||||
|
||||
const { result } = renderHook(() => useApiKeys(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isFetched).toBe(true));
|
||||
|
||||
// Verify the query was cached with the org-specific key
|
||||
const cachedData = queryClient.getQueryData(["api-keys", "org-1"]);
|
||||
expect(cachedData).toBeDefined();
|
||||
|
||||
// Verify no data is cached under the old key without org ID
|
||||
const oldKeyData = queryClient.getQueryData(["api-keys"]);
|
||||
expect(oldKeyData).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Cache isolation between organizations", () => {
|
||||
it("should maintain separate caches for each organization", async () => {
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
|
||||
// Simulate fetching for org-1
|
||||
getSettingsSpy.mockResolvedValueOnce({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
language: "en",
|
||||
});
|
||||
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
const { rerender } = renderHook(() => useSettings(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(queryClient.getQueryData(["settings", "org-1"])).toBeDefined();
|
||||
});
|
||||
|
||||
// Switch to org-2
|
||||
getSettingsSpy.mockResolvedValueOnce({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
language: "fr",
|
||||
});
|
||||
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-2" });
|
||||
rerender();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(queryClient.getQueryData(["settings", "org-2"])).toBeDefined();
|
||||
});
|
||||
|
||||
// Switch back to org-1 - should use cached data, not refetch
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
rerender();
|
||||
|
||||
// org-1 data should still be in cache
|
||||
const org1Cache = queryClient.getQueryData(["settings", "org-1"]) as any;
|
||||
expect(org1Cache?.language).toBe("en");
|
||||
|
||||
// org-2 data should also still be in cache
|
||||
const org2Cache = queryClient.getQueryData(["settings", "org-2"]) as any;
|
||||
expect(org2Cache?.language).toBe("fr");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,64 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { Conversation } from "#/api/open-hands.types";
|
||||
import { useRuntimeIsReady } from "#/hooks/use-runtime-is-ready";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
|
||||
vi.mock("#/hooks/use-agent-state");
|
||||
vi.mock("#/hooks/query/use-active-conversation");
|
||||
|
||||
function asMockReturnValue<T>(value: Partial<T>): T {
|
||||
return value as T;
|
||||
}
|
||||
|
||||
function makeConversation(): Conversation {
|
||||
return {
|
||||
conversation_id: "conv-123",
|
||||
title: "Test Conversation",
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
last_updated_at: new Date().toISOString(),
|
||||
created_at: new Date().toISOString(),
|
||||
status: "RUNNING",
|
||||
runtime_status: null,
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
};
|
||||
}
|
||||
|
||||
describe("useRuntimeIsReady", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
vi.mocked(useActiveConversation).mockReturnValue(
|
||||
asMockReturnValue<ReturnType<typeof useActiveConversation>>({
|
||||
data: makeConversation(),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("treats agent errors as not ready by default", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.ERROR,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useRuntimeIsReady());
|
||||
|
||||
expect(result.current).toBe(false);
|
||||
});
|
||||
|
||||
it("allows runtime-backed tabs to stay ready when the agent errors", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.ERROR,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useRuntimeIsReady({ allowAgentError: true }),
|
||||
);
|
||||
|
||||
expect(result.current).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -205,7 +205,9 @@ describe("useWebSocket", () => {
|
||||
expect(result.current.isConnected).toBe(true);
|
||||
});
|
||||
|
||||
expect(onCloseSpy).not.toHaveBeenCalled();
|
||||
// Reset spy after connection is established to ignore any spurious
|
||||
// close events fired by the MSW mock during the handshake.
|
||||
onCloseSpy.mockClear();
|
||||
|
||||
// Unmount to trigger close
|
||||
unmount();
|
||||
|
||||
@@ -5,9 +5,11 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import AcceptTOS from "#/routes/accept-tos";
|
||||
import * as CaptureConsent from "#/utils/handle-capture-consent";
|
||||
import { openHands } from "#/api/open-hands-axios";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Mock the react-router hooks
|
||||
vi.mock("react-router", () => ({
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useNavigate: () => vi.fn(),
|
||||
useSearchParams: () => [
|
||||
{
|
||||
@@ -19,6 +21,7 @@ vi.mock("react-router", () => ({
|
||||
},
|
||||
},
|
||||
],
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
// Mock the axios instance
|
||||
@@ -54,6 +57,7 @@ const createWrapper = () => {
|
||||
|
||||
describe("AcceptTOS", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
vi.stubGlobal("location", { href: "" });
|
||||
});
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import AppSettingsScreen, { clientLoader } from "#/routes/app-settings";
|
||||
@@ -8,6 +8,11 @@ import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { AvailableLanguages } from "#/i18n";
|
||||
import * as CaptureConsent from "#/utils/handle-capture-consent";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
const renderAppSettingsScreen = () =>
|
||||
render(<AppSettingsScreen />, {
|
||||
|
||||
@@ -32,6 +32,7 @@ describe("Changes Tab", () => {
|
||||
vi.mocked(useUnifiedGetGitChanges).mockReturnValue({
|
||||
data: [],
|
||||
isLoading: false,
|
||||
isFetching: false,
|
||||
isSuccess: true,
|
||||
isError: false,
|
||||
error: null,
|
||||
@@ -50,6 +51,7 @@ describe("Changes Tab", () => {
|
||||
vi.mocked(useUnifiedGetGitChanges).mockReturnValue({
|
||||
data: [{ path: "src/file.ts", status: "M" }],
|
||||
isLoading: false,
|
||||
isFetching: false,
|
||||
isSuccess: true,
|
||||
isError: false,
|
||||
error: null,
|
||||
|
||||
@@ -5,12 +5,12 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import DeviceVerify from "#/routes/device-verify";
|
||||
|
||||
const { useIsAuthedMock, PROJ_USER_JOURNEY_MOCK } = vi.hoisted(() => ({
|
||||
const { useIsAuthedMock, ENABLE_PROJ_USER_JOURNEY_MOCK } = vi.hoisted(() => ({
|
||||
useIsAuthedMock: vi.fn(() => ({
|
||||
data: false as boolean | undefined,
|
||||
isLoading: false,
|
||||
})),
|
||||
PROJ_USER_JOURNEY_MOCK: vi.fn(() => true),
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
@@ -24,7 +24,7 @@ vi.mock("posthog-js/react", () => ({
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
PROJ_USER_JOURNEY: () => PROJ_USER_JOURNEY_MOCK(),
|
||||
ENABLE_PROJ_USER_JOURNEY: () => ENABLE_PROJ_USER_JOURNEY_MOCK(),
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
@@ -67,7 +67,7 @@ describe("DeviceVerify", () => {
|
||||
),
|
||||
);
|
||||
// Enable feature flag by default
|
||||
PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -254,7 +254,7 @@ describe("DeviceVerify", () => {
|
||||
});
|
||||
|
||||
it("should not include the EnterpriseBanner and be center-aligned when feature flag is disabled", async () => {
|
||||
PROJ_USER_JOURNEY_MOCK.mockReturnValue(false);
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(false);
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
|
||||
@@ -609,3 +609,193 @@ describe("New user welcome toast", () => {
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("HomepageCTA visibility", () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
vi.spyOn(AuthService, "authenticate").mockResolvedValue(true);
|
||||
|
||||
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
|
||||
|
||||
// Mock localStorage to enable the PROJ_USER_JOURNEY feature flag (CTA dismissal also uses localStorage)
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn((key: string) => {
|
||||
if (key === "FEATURE_PROJ_USER_JOURNEY") {
|
||||
return "true";
|
||||
}
|
||||
return null;
|
||||
}),
|
||||
setItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("should show HomepageCTA in SaaS mode when not dismissed and feature flag enabled", async () => {
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
isFetching: false,
|
||||
isError: false,
|
||||
});
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { app_mode: "saas", feature_flags: DEFAULT_FEATURE_FLAGS },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
providers_configured: ["github"],
|
||||
auth_url: "https://auth.example.com",
|
||||
feature_flags: DEFAULT_FEATURE_FLAGS,
|
||||
maintenance_start_time: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
error_message: null,
|
||||
updated_at: "2024-01-14T10:00:00Z",
|
||||
github_app_slug: null,
|
||||
});
|
||||
|
||||
renderHomeScreen();
|
||||
|
||||
await screen.findByTestId("home-screen");
|
||||
|
||||
const ctaLink = await screen.findByRole("link", {
|
||||
name: "CTA$LEARN_MORE",
|
||||
});
|
||||
expect(ctaLink).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show HomepageCTA in OSS mode even with feature flag enabled", async () => {
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
isFetching: false,
|
||||
isError: false,
|
||||
});
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { app_mode: "oss", feature_flags: DEFAULT_FEATURE_FLAGS },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "oss",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
providers_configured: ["github"],
|
||||
auth_url: "https://auth.example.com",
|
||||
feature_flags: DEFAULT_FEATURE_FLAGS,
|
||||
maintenance_start_time: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
error_message: null,
|
||||
updated_at: "2024-01-14T10:00:00Z",
|
||||
github_app_slug: null,
|
||||
});
|
||||
|
||||
renderHomeScreen();
|
||||
|
||||
await screen.findByTestId("home-screen");
|
||||
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show HomepageCTA when feature flag is disabled", async () => {
|
||||
// Override localStorage to disable the feature flag
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn(() => null), // No feature flags set
|
||||
setItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
});
|
||||
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
isFetching: false,
|
||||
isError: false,
|
||||
});
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { app_mode: "saas", feature_flags: DEFAULT_FEATURE_FLAGS },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
providers_configured: ["github"],
|
||||
auth_url: "https://auth.example.com",
|
||||
feature_flags: DEFAULT_FEATURE_FLAGS,
|
||||
maintenance_start_time: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
error_message: null,
|
||||
updated_at: "2024-01-14T10:00:00Z",
|
||||
github_app_slug: null,
|
||||
});
|
||||
|
||||
renderHomeScreen();
|
||||
|
||||
await screen.findByTestId("home-screen");
|
||||
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show HomepageCTA when dismissed in local storage", async () => {
|
||||
// Override localStorage to mark CTA as dismissed while keeping the feature flag enabled
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn((key: string) => {
|
||||
if (key === "FEATURE_PROJ_USER_JOURNEY") {
|
||||
return "true";
|
||||
}
|
||||
if (key === "homepage-cta-dismissed") {
|
||||
return "true";
|
||||
}
|
||||
return null;
|
||||
}),
|
||||
setItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
});
|
||||
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
isFetching: false,
|
||||
isError: false,
|
||||
});
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { app_mode: "saas", feature_flags: DEFAULT_FEATURE_FLAGS },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
providers_configured: ["github"],
|
||||
auth_url: "https://auth.example.com",
|
||||
feature_flags: DEFAULT_FEATURE_FLAGS,
|
||||
maintenance_start_time: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
error_message: null,
|
||||
updated_at: "2024-01-14T10:00:00Z",
|
||||
github_app_slug: null,
|
||||
});
|
||||
|
||||
renderHomeScreen();
|
||||
|
||||
await screen.findByTestId("home-screen");
|
||||
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -73,6 +73,11 @@ vi.mock("#/hooks/use-invitation", () => ({
|
||||
useInvitation: () => useInvitationMock(),
|
||||
}));
|
||||
|
||||
// Mock feature flags - enable by default for tests
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_PROJ_USER_JOURNEY: () => true,
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: LoginPage,
|
||||
|
||||
@@ -283,305 +283,6 @@ describe("Manage Org Route", () => {
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe("AddCreditsModal", () => {
|
||||
const openAddCreditsModal = async () => {
|
||||
const user = userEvent.setup();
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 }); // user is owner in org 1
|
||||
|
||||
const addCreditsButton = await waitFor(() => screen.getByText(/add/i));
|
||||
await user.click(addCreditsButton);
|
||||
|
||||
const addCreditsForm = screen.getByTestId("add-credits-form");
|
||||
expect(addCreditsForm).toBeInTheDocument();
|
||||
|
||||
return { user, addCreditsForm };
|
||||
};
|
||||
|
||||
describe("Button State Management", () => {
|
||||
it("should enable submit button initially when modal opens", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains invalid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains valid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "100");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button after validation error is shown", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Input Attributes & Placeholder", () => {
|
||||
it("should have min attribute set to 10", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("min", "10");
|
||||
});
|
||||
|
||||
it("should have max attribute set to 25000", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("max", "25000");
|
||||
});
|
||||
|
||||
it("should have step attribute set to 1", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("step", "1");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Message Display", () => {
|
||||
it("should not display error message initially when modal opens", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display error message after submitting amount above maximum", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MAXIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting decimal value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "50.5");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should replace error message when submitting different invalid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MAXIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form Submission Behavior", () => {
|
||||
it("should prevent submission when amount is invalid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should call createCheckoutSession with correct amount when valid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not call createCheckoutSession when validation fails", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
await user.click(nextButton);
|
||||
|
||||
// Verify mutation was not called
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_NEGATIVE_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should close modal on successful submission", async () => {
|
||||
const createCheckoutSessionSpy = vi
|
||||
.spyOn(BillingService, "createCheckoutSession")
|
||||
.mockResolvedValue("https://checkout.stripe.com/test-session");
|
||||
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("add-credits-form"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow API call when validation passes and clear any previous errors", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
// First submit invalid value
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Then submit valid value
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "100");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(100);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edge Cases", () => {
|
||||
it("should handle zero value correctly", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "0");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle whitespace-only input correctly", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
// Number inputs typically don't accept spaces, but test the behavior
|
||||
await user.type(amountInput, " ");
|
||||
await user.click(nextButton);
|
||||
|
||||
// Should not call API (empty/invalid input)
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should show add credits option for ADMIN role", async () => {
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import VSCodeTab from "#/routes/vscode-tab";
|
||||
import { useUnifiedVSCodeUrl } from "#/hooks/query/use-unified-vscode-url";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
|
||||
vi.mock("#/hooks/query/use-unified-vscode-url");
|
||||
vi.mock("#/hooks/use-agent-state");
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
VSCODE_IN_NEW_TAB: () => false,
|
||||
}));
|
||||
|
||||
function mockVSCodeUrlHook(
|
||||
value: Partial<ReturnType<typeof useUnifiedVSCodeUrl>>,
|
||||
) {
|
||||
vi.mocked(useUnifiedVSCodeUrl).mockReturnValue({
|
||||
data: { url: "http://localhost:3000/vscode", error: null },
|
||||
error: null,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
isSuccess: true,
|
||||
status: "success",
|
||||
refetch: vi.fn(),
|
||||
...value,
|
||||
} as ReturnType<typeof useUnifiedVSCodeUrl>);
|
||||
}
|
||||
|
||||
describe("VSCodeTab", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("keeps VSCode accessible when the agent is in an error state", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.ERROR,
|
||||
});
|
||||
mockVSCodeUrlHook({});
|
||||
|
||||
renderWithProviders(<VSCodeTab />);
|
||||
|
||||
expect(
|
||||
screen.queryByText("DIFF_VIEWER$WAITING_FOR_RUNTIME"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.getByTitle("VSCODE$TITLE")).toHaveAttribute(
|
||||
"src",
|
||||
"http://localhost:3000/vscode",
|
||||
);
|
||||
});
|
||||
|
||||
it("still waits while the runtime is starting", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.LOADING,
|
||||
});
|
||||
mockVSCodeUrlHook({});
|
||||
|
||||
renderWithProviders(<VSCodeTab />);
|
||||
|
||||
expect(
|
||||
screen.getByText("DIFF_VIEWER$WAITING_FOR_RUNTIME"),
|
||||
).toBeInTheDocument();
|
||||
expect(screen.queryByTitle("VSCODE$TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,8 +1,13 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { test, expect, describe, vi } from "vitest";
|
||||
import { test, expect, describe, vi, beforeEach } from "vitest";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { InteractiveChatBox } from "#/components/features/chat/interactive-chat-box";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "test-org-id" });
|
||||
});
|
||||
|
||||
// Mock the translation function
|
||||
vi.mock("react-i18next", async () => {
|
||||
@@ -29,14 +34,12 @@ vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
}));
|
||||
|
||||
// Mock React Router hooks
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual,
|
||||
useNavigate: () => vi.fn(),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
};
|
||||
});
|
||||
vi.mock("react-router", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-router")>()),
|
||||
useNavigate: () => vi.fn(),
|
||||
useParams: () => ({ conversationId: "test-conversation-id" }),
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
// Mock other hooks that might be used by the component
|
||||
vi.mock("#/hooks/use-user-providers", () => ({
|
||||
|
||||
@@ -4,27 +4,96 @@ import { getGitPath } from "#/utils/get-git-path";
|
||||
describe("getGitPath", () => {
|
||||
const conversationId = "abc123";
|
||||
|
||||
it("should return /workspace/project/{conversationId} when no repository is selected", () => {
|
||||
expect(getGitPath(conversationId, null)).toBe(`/workspace/project/${conversationId}`);
|
||||
expect(getGitPath(conversationId, undefined)).toBe(`/workspace/project/${conversationId}`);
|
||||
describe("without sandbox grouping (NO_GROUPING)", () => {
|
||||
it("should return /workspace/project when no repository is selected", () => {
|
||||
expect(getGitPath(conversationId, null, false)).toBe("/workspace/project");
|
||||
expect(getGitPath(conversationId, undefined, false)).toBe(
|
||||
"/workspace/project",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle standard owner/repo format (GitHub)", () => {
|
||||
expect(getGitPath(conversationId, "OpenHands/OpenHands", false)).toBe(
|
||||
"/workspace/project/OpenHands",
|
||||
);
|
||||
expect(getGitPath(conversationId, "facebook/react", false)).toBe(
|
||||
"/workspace/project/react",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle nested group paths (GitLab)", () => {
|
||||
expect(
|
||||
getGitPath(conversationId, "modernhealth/frontend-guild/pan", false),
|
||||
).toBe("/workspace/project/pan");
|
||||
expect(getGitPath(conversationId, "group/subgroup/repo", false)).toBe(
|
||||
"/workspace/project/repo",
|
||||
);
|
||||
expect(getGitPath(conversationId, "a/b/c/d/repo", false)).toBe(
|
||||
"/workspace/project/repo",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle single segment paths", () => {
|
||||
expect(getGitPath(conversationId, "repo", false)).toBe(
|
||||
"/workspace/project/repo",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle empty string", () => {
|
||||
expect(getGitPath(conversationId, "", false)).toBe("/workspace/project");
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle standard owner/repo format (GitHub)", () => {
|
||||
expect(getGitPath(conversationId, "OpenHands/OpenHands")).toBe(`/workspace/project/${conversationId}/OpenHands`);
|
||||
expect(getGitPath(conversationId, "facebook/react")).toBe(`/workspace/project/${conversationId}/react`);
|
||||
describe("with sandbox grouping enabled", () => {
|
||||
it("should return /workspace/project/{conversationId} when no repository is selected", () => {
|
||||
expect(getGitPath(conversationId, null, true)).toBe(
|
||||
`/workspace/project/${conversationId}`,
|
||||
);
|
||||
expect(getGitPath(conversationId, undefined, true)).toBe(
|
||||
`/workspace/project/${conversationId}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle standard owner/repo format (GitHub)", () => {
|
||||
expect(getGitPath(conversationId, "OpenHands/OpenHands", true)).toBe(
|
||||
`/workspace/project/${conversationId}/OpenHands`,
|
||||
);
|
||||
expect(getGitPath(conversationId, "facebook/react", true)).toBe(
|
||||
`/workspace/project/${conversationId}/react`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle nested group paths (GitLab)", () => {
|
||||
expect(
|
||||
getGitPath(conversationId, "modernhealth/frontend-guild/pan", true),
|
||||
).toBe(`/workspace/project/${conversationId}/pan`);
|
||||
expect(getGitPath(conversationId, "group/subgroup/repo", true)).toBe(
|
||||
`/workspace/project/${conversationId}/repo`,
|
||||
);
|
||||
expect(getGitPath(conversationId, "a/b/c/d/repo", true)).toBe(
|
||||
`/workspace/project/${conversationId}/repo`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle single segment paths", () => {
|
||||
expect(getGitPath(conversationId, "repo", true)).toBe(
|
||||
`/workspace/project/${conversationId}/repo`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle empty string", () => {
|
||||
expect(getGitPath(conversationId, "", true)).toBe(
|
||||
`/workspace/project/${conversationId}`,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle nested group paths (GitLab)", () => {
|
||||
expect(getGitPath(conversationId, "modernhealth/frontend-guild/pan")).toBe(`/workspace/project/${conversationId}/pan`);
|
||||
expect(getGitPath(conversationId, "group/subgroup/repo")).toBe(`/workspace/project/${conversationId}/repo`);
|
||||
expect(getGitPath(conversationId, "a/b/c/d/repo")).toBe(`/workspace/project/${conversationId}/repo`);
|
||||
});
|
||||
|
||||
it("should handle single segment paths", () => {
|
||||
expect(getGitPath(conversationId, "repo")).toBe(`/workspace/project/${conversationId}/repo`);
|
||||
});
|
||||
|
||||
it("should handle empty string", () => {
|
||||
expect(getGitPath(conversationId, "")).toBe(`/workspace/project/${conversationId}`);
|
||||
describe("default behavior (useSandboxGrouping defaults to false)", () => {
|
||||
it("should default to no sandbox grouping", () => {
|
||||
expect(getGitPath(conversationId, null)).toBe("/workspace/project");
|
||||
expect(getGitPath(conversationId, "owner/repo")).toBe(
|
||||
"/workspace/project/repo",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
import { describe, it, expect, beforeEach } from "vitest";
|
||||
import {
|
||||
LOCAL_STORAGE_KEYS,
|
||||
LoginMethod,
|
||||
setLoginMethod,
|
||||
getLoginMethod,
|
||||
clearLoginData,
|
||||
setCTADismissed,
|
||||
isCTADismissed,
|
||||
} from "#/utils/local-storage";
|
||||
|
||||
describe("local-storage utilities", () => {
|
||||
beforeEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
describe("Login method utilities", () => {
|
||||
describe("setLoginMethod", () => {
|
||||
it("stores the login method in local storage", () => {
|
||||
setLoginMethod(LoginMethod.GITHUB);
|
||||
expect(localStorage.getItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD)).toBe("github");
|
||||
});
|
||||
|
||||
it("stores different login methods correctly", () => {
|
||||
setLoginMethod(LoginMethod.GITLAB);
|
||||
expect(localStorage.getItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD)).toBe("gitlab");
|
||||
|
||||
setLoginMethod(LoginMethod.BITBUCKET);
|
||||
expect(localStorage.getItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD)).toBe("bitbucket");
|
||||
|
||||
setLoginMethod(LoginMethod.AZURE_DEVOPS);
|
||||
expect(localStorage.getItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD)).toBe("azure_devops");
|
||||
|
||||
setLoginMethod(LoginMethod.ENTERPRISE_SSO);
|
||||
expect(localStorage.getItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD)).toBe("enterprise_sso");
|
||||
|
||||
setLoginMethod(LoginMethod.BITBUCKET_DATA_CENTER);
|
||||
expect(localStorage.getItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD)).toBe("bitbucket_data_center");
|
||||
});
|
||||
|
||||
it("overwrites previous login method", () => {
|
||||
setLoginMethod(LoginMethod.GITHUB);
|
||||
setLoginMethod(LoginMethod.GITLAB);
|
||||
expect(localStorage.getItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD)).toBe("gitlab");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getLoginMethod", () => {
|
||||
it("returns null when no login method is set", () => {
|
||||
expect(getLoginMethod()).toBeNull();
|
||||
});
|
||||
|
||||
it("returns the stored login method", () => {
|
||||
localStorage.setItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD, "github");
|
||||
expect(getLoginMethod()).toBe(LoginMethod.GITHUB);
|
||||
});
|
||||
|
||||
it("returns correct login method for all types", () => {
|
||||
localStorage.setItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD, "gitlab");
|
||||
expect(getLoginMethod()).toBe(LoginMethod.GITLAB);
|
||||
|
||||
localStorage.setItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD, "bitbucket");
|
||||
expect(getLoginMethod()).toBe(LoginMethod.BITBUCKET);
|
||||
|
||||
localStorage.setItem(LOCAL_STORAGE_KEYS.LOGIN_METHOD, "azure_devops");
|
||||
expect(getLoginMethod()).toBe(LoginMethod.AZURE_DEVOPS);
|
||||
});
|
||||
});
|
||||
|
||||
describe("clearLoginData", () => {
|
||||
it("removes the login method from local storage", () => {
|
||||
setLoginMethod(LoginMethod.GITHUB);
|
||||
expect(getLoginMethod()).toBe(LoginMethod.GITHUB);
|
||||
|
||||
clearLoginData();
|
||||
expect(getLoginMethod()).toBeNull();
|
||||
});
|
||||
|
||||
it("does not throw when no login method is set", () => {
|
||||
expect(() => clearLoginData()).not.toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("CTA utilities", () => {
|
||||
describe("isCTADismissed", () => {
|
||||
it("returns false when CTA has not been dismissed", () => {
|
||||
expect(isCTADismissed("homepage")).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true when CTA has been dismissed", () => {
|
||||
localStorage.setItem("homepage-cta-dismissed", "true");
|
||||
expect(isCTADismissed("homepage")).toBe(true);
|
||||
});
|
||||
|
||||
it("returns false when storage value is not 'true'", () => {
|
||||
localStorage.setItem("homepage-cta-dismissed", "false");
|
||||
expect(isCTADismissed("homepage")).toBe(false);
|
||||
|
||||
localStorage.setItem("homepage-cta-dismissed", "invalid");
|
||||
expect(isCTADismissed("homepage")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("setCTADismissed", () => {
|
||||
it("sets the CTA as dismissed in local storage", () => {
|
||||
setCTADismissed("homepage");
|
||||
expect(localStorage.getItem("homepage-cta-dismissed")).toBe("true");
|
||||
});
|
||||
|
||||
it("generates correct key for homepage location", () => {
|
||||
setCTADismissed("homepage");
|
||||
expect(localStorage.getItem("homepage-cta-dismissed")).toBe("true");
|
||||
});
|
||||
});
|
||||
|
||||
describe("storage key format", () => {
|
||||
it("uses the correct key format: {location}-cta-dismissed", () => {
|
||||
setCTADismissed("homepage");
|
||||
|
||||
// Verify key exists with correct format
|
||||
expect(localStorage.getItem("homepage-cta-dismissed")).toBe("true");
|
||||
|
||||
// Verify other keys don't exist
|
||||
expect(localStorage.getItem("cta-dismissed")).toBeNull();
|
||||
expect(localStorage.getItem("homepage")).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("persistence", () => {
|
||||
it("dismissed state persists across multiple reads", () => {
|
||||
setCTADismissed("homepage");
|
||||
|
||||
expect(isCTADismissed("homepage")).toBe(true);
|
||||
expect(isCTADismissed("homepage")).toBe(true);
|
||||
expect(isCTADismissed("homepage")).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,10 +1,18 @@
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { PermissionKey } from "#/utils/org/permissions";
|
||||
import { OrganizationMember, OrganizationsQueryData } from "#/types/org";
|
||||
import {
|
||||
getAvailableRolesAUserCanAssign,
|
||||
getActiveOrganizationUser,
|
||||
} from "#/utils/org/permission-checks";
|
||||
import { getSelectedOrganizationIdFromStore } from "#/stores/selected-organization-store";
|
||||
import { queryClient } from "#/query-client-config";
|
||||
|
||||
// Mock dependencies for getActiveOrganizationUser tests
|
||||
// Mock dependencies
|
||||
vi.mock("#/api/organization-service/organization-service.api", () => ({
|
||||
organizationService: {
|
||||
getMe: vi.fn(),
|
||||
getOrganizations: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -12,49 +20,60 @@ vi.mock("#/stores/selected-organization-store", () => ({
|
||||
getSelectedOrganizationIdFromStore: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/query-client-getters", () => ({
|
||||
getMeFromQueryClient: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/query-client-config", () => ({
|
||||
queryClient: {
|
||||
getQueryData: vi.fn(),
|
||||
fetchQuery: vi.fn(),
|
||||
setQueryData: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Import after mocks are set up
|
||||
import {
|
||||
getAvailableRolesAUserCanAssign,
|
||||
getActiveOrganizationUser,
|
||||
} from "#/utils/org/permission-checks";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { getSelectedOrganizationIdFromStore } from "#/stores/selected-organization-store";
|
||||
import { getMeFromQueryClient } from "#/utils/query-client-getters";
|
||||
// Test fixtures
|
||||
const mockUser: OrganizationMember = {
|
||||
org_id: "org-1",
|
||||
user_id: "user-1",
|
||||
email: "test@example.com",
|
||||
role: "admin",
|
||||
llm_api_key: "",
|
||||
max_iterations: 100,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
status: "active",
|
||||
};
|
||||
|
||||
const mockOrganizationsData: OrganizationsQueryData = {
|
||||
items: [
|
||||
{ id: "org-1", name: "Org 1" },
|
||||
{ id: "org-2", name: "Org 2" },
|
||||
] as OrganizationsQueryData["items"],
|
||||
currentOrgId: "org-1",
|
||||
};
|
||||
|
||||
describe("getAvailableRolesAUserCanAssign", () => {
|
||||
it("returns empty array if user has no permissions", () => {
|
||||
const result = getAvailableRolesAUserCanAssign([]);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
it("returns empty array if user has no permissions", () => {
|
||||
const result = getAvailableRolesAUserCanAssign([]);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it("returns only roles the user has permission for", () => {
|
||||
const userPermissions: PermissionKey[] = [
|
||||
"change_user_role:member",
|
||||
"change_user_role:admin",
|
||||
];
|
||||
const result = getAvailableRolesAUserCanAssign(userPermissions);
|
||||
expect(result.sort()).toEqual(["admin", "member"].sort());
|
||||
});
|
||||
it("returns only roles the user has permission for", () => {
|
||||
const userPermissions: PermissionKey[] = [
|
||||
"change_user_role:member",
|
||||
"change_user_role:admin",
|
||||
];
|
||||
const result = getAvailableRolesAUserCanAssign(userPermissions);
|
||||
expect(result.sort()).toEqual(["admin", "member"].sort());
|
||||
});
|
||||
|
||||
it("returns all roles if user has all permissions", () => {
|
||||
const allPermissions: PermissionKey[] = [
|
||||
"change_user_role:member",
|
||||
"change_user_role:admin",
|
||||
"change_user_role:owner",
|
||||
];
|
||||
const result = getAvailableRolesAUserCanAssign(allPermissions);
|
||||
expect(result.sort()).toEqual(["member", "admin", "owner"].sort());
|
||||
});
|
||||
it("returns all roles if user has all permissions", () => {
|
||||
const allPermissions: PermissionKey[] = [
|
||||
"change_user_role:member",
|
||||
"change_user_role:admin",
|
||||
"change_user_role:owner",
|
||||
];
|
||||
const result = getAvailableRolesAUserCanAssign(allPermissions);
|
||||
expect(result.sort()).toEqual(["member", "admin", "owner"].sort());
|
||||
});
|
||||
});
|
||||
|
||||
describe("getActiveOrganizationUser", () => {
|
||||
@@ -62,18 +81,147 @@ describe("getActiveOrganizationUser", () => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should return undefined when API call throws an error", async () => {
|
||||
// Arrange: orgId exists, cache is empty, API call fails
|
||||
vi.mocked(getSelectedOrganizationIdFromStore).mockReturnValue("org-1");
|
||||
vi.mocked(getMeFromQueryClient).mockReturnValue(undefined);
|
||||
vi.mocked(organizationService.getMe).mockRejectedValue(
|
||||
new Error("Network error"),
|
||||
);
|
||||
describe("when orgId exists in store", () => {
|
||||
it("should fetch user directly using stored orgId", async () => {
|
||||
// Arrange
|
||||
vi.mocked(getSelectedOrganizationIdFromStore).mockReturnValue("org-1");
|
||||
vi.mocked(queryClient.fetchQuery).mockResolvedValue(mockUser);
|
||||
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
|
||||
// Assert: should return undefined instead of propagating the error
|
||||
expect(result).toBeUndefined();
|
||||
// Assert
|
||||
expect(result).toEqual(mockUser);
|
||||
expect(queryClient.getQueryData).not.toHaveBeenCalled();
|
||||
expect(queryClient.fetchQuery).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
queryKey: ["organizations", "org-1", "me"],
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("should return undefined when user fetch fails", async () => {
|
||||
// Arrange
|
||||
vi.mocked(getSelectedOrganizationIdFromStore).mockReturnValue("org-1");
|
||||
vi.mocked(queryClient.fetchQuery).mockRejectedValue(
|
||||
new Error("Network error"),
|
||||
);
|
||||
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
|
||||
// Assert
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("when orgId is null in store (page refresh scenario)", () => {
|
||||
beforeEach(() => {
|
||||
vi.mocked(getSelectedOrganizationIdFromStore).mockReturnValue(null);
|
||||
});
|
||||
|
||||
it("should use currentOrgId from cached organizations data", async () => {
|
||||
// Arrange
|
||||
vi.mocked(queryClient.getQueryData).mockReturnValue(
|
||||
mockOrganizationsData,
|
||||
);
|
||||
vi.mocked(queryClient.fetchQuery).mockResolvedValue(mockUser);
|
||||
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual(mockUser);
|
||||
expect(queryClient.getQueryData).toHaveBeenCalledWith(["organizations"]);
|
||||
expect(queryClient.fetchQuery).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
queryKey: ["organizations", "org-1", "me"],
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("should fallback to first org when currentOrgId is null", async () => {
|
||||
// Arrange
|
||||
const dataWithoutCurrentOrg: OrganizationsQueryData = {
|
||||
items: [
|
||||
{ id: "first-org" },
|
||||
{ id: "second-org" },
|
||||
] as OrganizationsQueryData["items"],
|
||||
currentOrgId: null,
|
||||
};
|
||||
vi.mocked(queryClient.getQueryData).mockReturnValue(
|
||||
dataWithoutCurrentOrg,
|
||||
);
|
||||
vi.mocked(queryClient.fetchQuery).mockResolvedValue(mockUser);
|
||||
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual(mockUser);
|
||||
expect(queryClient.fetchQuery).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
queryKey: ["organizations", "first-org", "me"],
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("should fetch organizations when not in cache", async () => {
|
||||
// Arrange
|
||||
vi.mocked(queryClient.getQueryData).mockReturnValue(undefined);
|
||||
vi.mocked(queryClient.fetchQuery)
|
||||
.mockResolvedValueOnce(mockOrganizationsData) // First call: fetch organizations
|
||||
.mockResolvedValueOnce(mockUser); // Second call: fetch user
|
||||
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
|
||||
// Assert
|
||||
expect(result).toEqual(mockUser);
|
||||
expect(queryClient.fetchQuery).toHaveBeenCalledTimes(2);
|
||||
expect(queryClient.fetchQuery).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.objectContaining({
|
||||
queryKey: ["organizations"],
|
||||
}),
|
||||
);
|
||||
expect(queryClient.fetchQuery).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.objectContaining({
|
||||
queryKey: ["organizations", "org-1", "me"],
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("should return undefined when fetching organizations fails", async () => {
|
||||
// Arrange
|
||||
vi.mocked(queryClient.getQueryData).mockReturnValue(undefined);
|
||||
vi.mocked(queryClient.fetchQuery).mockRejectedValue(
|
||||
new Error("Failed to fetch organizations"),
|
||||
);
|
||||
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
|
||||
// Assert
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should return undefined when no organizations exist", async () => {
|
||||
// Arrange
|
||||
const emptyData: OrganizationsQueryData = {
|
||||
items: [],
|
||||
currentOrgId: null,
|
||||
};
|
||||
vi.mocked(queryClient.getQueryData).mockReturnValue(emptyData);
|
||||
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
|
||||
// Assert
|
||||
expect(result).toBeUndefined();
|
||||
// Should not attempt to fetch user since there's no orgId
|
||||
expect(queryClient.fetchQuery).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,7 +11,7 @@ import { I18nKey } from "#/i18n/declaration";
|
||||
// Mock translations
|
||||
const t = (key: string) => {
|
||||
const translations: { [key: string]: string } = {
|
||||
COMMON$WAITING_FOR_SANDBOX: "Waiting For Sandbox",
|
||||
COMMON$WAITING_FOR_SANDBOX: "Waiting for sandbox",
|
||||
COMMON$STOPPING: "Stopping",
|
||||
COMMON$STARTING: "Starting",
|
||||
COMMON$SERVER_STOPPED: "Server stopped",
|
||||
@@ -69,7 +69,7 @@ describe("getStatusText", () => {
|
||||
t,
|
||||
});
|
||||
|
||||
expect(result).toBe(t(I18nKey.COMMON$WAITING_FOR_SANDBOX));
|
||||
expect(result).toBe("Waiting for sandbox");
|
||||
});
|
||||
|
||||
it("returns task detail when task status is ERROR and detail exists", () => {
|
||||
|
||||
Generated
+3
-4
@@ -15325,10 +15325,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/socket.io-parser": {
|
||||
"version": "4.2.5",
|
||||
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.5.tgz",
|
||||
"integrity": "sha512-bPMmpy/5WWKHea5Y/jYAP6k74A+hvmRCQaJuJB6I/ML5JZq/KfNieUVo/3Mh7SAqn7TyFdIo6wqYHInG1MU1bQ==",
|
||||
"license": "MIT",
|
||||
"version": "4.2.6",
|
||||
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.6.tgz",
|
||||
"integrity": "sha512-asJqbVBDsBCJx0pTqw3WfesSY0iRX+2xzWEWzrpcH7L6fLzrhyF8WPI8UaeM4YCuDfpwA/cgsdugMsmtz8EJeg==",
|
||||
"dependencies": {
|
||||
"@socket.io/component-emitter": "~3.1.0",
|
||||
"debug": "~4.4.1"
|
||||
|
||||
@@ -68,6 +68,8 @@ class V1ConversationService {
|
||||
trigger?: ConversationTrigger,
|
||||
parent_conversation_id?: string,
|
||||
agent_type?: "default" | "plan",
|
||||
sandbox_id?: string,
|
||||
llm_model?: string,
|
||||
): Promise<V1AppConversationStartTask> {
|
||||
const body: V1AppConversationStartRequest = {
|
||||
selected_repository: selectedRepository,
|
||||
@@ -78,6 +80,8 @@ class V1ConversationService {
|
||||
trigger,
|
||||
parent_conversation_id: parent_conversation_id || null,
|
||||
agent_type,
|
||||
sandbox_id: sandbox_id || null,
|
||||
llm_model: llm_model || null,
|
||||
};
|
||||
|
||||
// suggested_task implies the backend will construct the initial_message
|
||||
|
||||
@@ -17,7 +17,9 @@ class SettingsService {
|
||||
* Save the settings to the server. Only valid settings are saved.
|
||||
* @param settings - the settings to save
|
||||
*/
|
||||
static async saveSettings(settings: Partial<Settings>): Promise<boolean> {
|
||||
static async saveSettings(
|
||||
settings: Partial<Settings> & Record<string, unknown>,
|
||||
): Promise<boolean> {
|
||||
const data = await openHands.post("/api/settings", settings);
|
||||
return data.status === 200;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,9 @@ import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-not
|
||||
import { useRecaptcha } from "#/hooks/use-recaptcha";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { ENABLE_PROJ_USER_JOURNEY } from "#/utils/feature-flags";
|
||||
import { LoginCTA } from "./login-cta";
|
||||
|
||||
export interface LoginContentProps {
|
||||
githubAuthUrl: string | null;
|
||||
@@ -177,125 +180,133 @@ export function LoginContent({
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-col items-center w-full gap-12.5"
|
||||
data-testid="login-content"
|
||||
className={cn(
|
||||
"flex flex-col md:flex-row items-center md:items-stretch gap-6 h-full",
|
||||
)}
|
||||
>
|
||||
<div>
|
||||
<OpenHandsLogoWhite width={106} height={72} />
|
||||
</div>
|
||||
<div
|
||||
className={cn("flex flex-col items-center w-full gap-12.5")}
|
||||
data-testid="login-content"
|
||||
>
|
||||
<div>
|
||||
<OpenHandsLogoWhite width={106} height={72} />
|
||||
</div>
|
||||
|
||||
<h1 className="text-[39px] leading-5 font-medium text-white text-center">
|
||||
{t(I18nKey.AUTH$LETS_GET_STARTED)}
|
||||
</h1>
|
||||
<h1 className="text-[39px] leading-5 font-medium text-white text-center">
|
||||
{t(I18nKey.AUTH$LETS_GET_STARTED)}
|
||||
</h1>
|
||||
|
||||
{shouldShownHelperText && (
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
{emailVerified && (
|
||||
<p className="text-sm text-muted-foreground text-center">
|
||||
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
|
||||
</p>
|
||||
)}
|
||||
{hasDuplicatedEmail && (
|
||||
<p className="text-sm text-danger text-center">
|
||||
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
|
||||
</p>
|
||||
)}
|
||||
{recaptchaBlocked && (
|
||||
<p className="text-sm text-danger text-center max-w-125">
|
||||
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
|
||||
</p>
|
||||
)}
|
||||
{hasInvitation && (
|
||||
<p className="text-sm text-muted-foreground text-center">
|
||||
{t(I18nKey.AUTH$INVITATION_PENDING)}
|
||||
</p>
|
||||
)}
|
||||
{showBitbucket && (
|
||||
<p className="text-sm text-white text-center max-w-125">
|
||||
{t(I18nKey.AUTH$BITBUCKET_SIGNUP_DISABLED)}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{shouldShownHelperText && (
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
{emailVerified && (
|
||||
<p className="text-sm text-muted-foreground text-center">
|
||||
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
|
||||
</p>
|
||||
)}
|
||||
{hasDuplicatedEmail && (
|
||||
<p className="text-sm text-danger text-center">
|
||||
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
|
||||
</p>
|
||||
)}
|
||||
{recaptchaBlocked && (
|
||||
<p className="text-sm text-danger text-center max-w-125">
|
||||
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
|
||||
</p>
|
||||
)}
|
||||
{hasInvitation && (
|
||||
<p className="text-sm text-muted-foreground text-center">
|
||||
{t(I18nKey.AUTH$INVITATION_PENDING)}
|
||||
</p>
|
||||
)}
|
||||
{showBitbucket && (
|
||||
<p className="text-sm text-white text-center max-w-125">
|
||||
{t(I18nKey.AUTH$BITBUCKET_SIGNUP_DISABLED)}
|
||||
</p>
|
||||
{noProvidersConfigured ? (
|
||||
<div className="text-center p-4 text-muted-foreground">
|
||||
{t(I18nKey.AUTH$NO_PROVIDERS_CONFIGURED)}
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{showGithub && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleGitHubAuth}
|
||||
className={`${buttonBaseClasses} bg-[#9E28B0] text-white`}
|
||||
>
|
||||
<GitHubLogo width={14} height={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.GITHUB$CONNECT_TO_GITHUB)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showGitlab && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleGitLabAuth}
|
||||
className={`${buttonBaseClasses} bg-[#FC6B0E] text-white`}
|
||||
>
|
||||
<GitLabLogo width={14} height={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.GITLAB$CONNECT_TO_GITLAB)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showBitbucket && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleBitbucketAuth}
|
||||
className={`${buttonBaseClasses} bg-[#2684FF] text-white`}
|
||||
>
|
||||
<BitbucketLogo width={14} height={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.BITBUCKET$CONNECT_TO_BITBUCKET)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showBitbucketDataCenter && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleBitbucketDataCenterAuth}
|
||||
className={`${buttonBaseClasses} bg-[#2684FF] text-white`}
|
||||
>
|
||||
<BitbucketLogo width={14} height={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(
|
||||
I18nKey.BITBUCKET_DATA_CENTER$CONNECT_TO_BITBUCKET_DATA_CENTER,
|
||||
)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showEnterpriseSso && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleEnterpriseSsoAuth}
|
||||
className={`${buttonBaseClasses} bg-[#374151] text-white`}
|
||||
>
|
||||
<FaUserShield size={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.ENTERPRISE_SSO$CONNECT_TO_ENTERPRISE_SSO)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
{noProvidersConfigured ? (
|
||||
<div className="text-center p-4 text-muted-foreground">
|
||||
{t(I18nKey.AUTH$NO_PROVIDERS_CONFIGURED)}
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{showGithub && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleGitHubAuth}
|
||||
className={`${buttonBaseClasses} bg-[#9E28B0] text-white`}
|
||||
>
|
||||
<GitHubLogo width={14} height={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.GITHUB$CONNECT_TO_GITHUB)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showGitlab && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleGitLabAuth}
|
||||
className={`${buttonBaseClasses} bg-[#FC6B0E] text-white`}
|
||||
>
|
||||
<GitLabLogo width={14} height={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.GITLAB$CONNECT_TO_GITLAB)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showBitbucket && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleBitbucketAuth}
|
||||
className={`${buttonBaseClasses} bg-[#2684FF] text-white`}
|
||||
>
|
||||
<BitbucketLogo width={14} height={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.BITBUCKET$CONNECT_TO_BITBUCKET)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showBitbucketDataCenter && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleBitbucketDataCenterAuth}
|
||||
className={`${buttonBaseClasses} bg-[#2684FF] text-white`}
|
||||
>
|
||||
<BitbucketLogo width={14} height={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(
|
||||
I18nKey.BITBUCKET_DATA_CENTER$CONNECT_TO_BITBUCKET_DATA_CENTER,
|
||||
)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showEnterpriseSso && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleEnterpriseSsoAuth}
|
||||
className={`${buttonBaseClasses} bg-[#374151] text-white`}
|
||||
>
|
||||
<FaUserShield size={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.ENTERPRISE_SSO$CONNECT_TO_ENTERPRISE_SSO)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<TermsAndPrivacyNotice className="max-w-[320px] text-[#A3A3A3]" />
|
||||
</div>
|
||||
|
||||
<TermsAndPrivacyNotice className="max-w-[320px] text-[#A3A3A3]" />
|
||||
{appMode === "saas" && ENABLE_PROJ_USER_JOURNEY() && <LoginCTA />}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Card } from "#/ui/card";
|
||||
import { CardTitle } from "#/ui/card-title";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { cn } from "#/utils/utils";
|
||||
import StackedIcon from "#/icons/stacked.svg?react";
|
||||
import { useTracking } from "#/hooks/use-tracking";
|
||||
|
||||
export function LoginCTA() {
|
||||
const { t } = useTranslation();
|
||||
const { trackSaasSelfhostedInquiry } = useTracking();
|
||||
|
||||
const handleLearnMoreClick = () => {
|
||||
trackSaasSelfhostedInquiry({ location: "login_page" });
|
||||
};
|
||||
|
||||
return (
|
||||
<Card
|
||||
testId="login-cta"
|
||||
theme="dark"
|
||||
className={cn("w-full max-w-80 h-auto flex-col", "cta-card-gradient")}
|
||||
>
|
||||
<div className={cn("flex flex-col gap-[11px] p-6")}>
|
||||
<div className={cn("size-10")}>
|
||||
<StackedIcon width={40} height={40} />
|
||||
</div>
|
||||
|
||||
<CardTitle>{t(I18nKey.CTA$ENTERPRISE)}</CardTitle>
|
||||
|
||||
<Typography.Text className="text-[#8C8C8C] font-inter font-normal text-sm leading-5">
|
||||
{t(I18nKey.CTA$ENTERPRISE_DEPLOY)}
|
||||
</Typography.Text>
|
||||
|
||||
<ul
|
||||
className={cn(
|
||||
"text-[#8C8C8C] font-inter font-normal text-sm leading-5 list-disc list-inside flex flex-col gap-1",
|
||||
)}
|
||||
>
|
||||
<li>{t(I18nKey.CTA$FEATURE_ON_PREMISES)}</li>
|
||||
<li>{t(I18nKey.CTA$FEATURE_DATA_CONTROL)}</li>
|
||||
<li>{t(I18nKey.CTA$FEATURE_COMPLIANCE)}</li>
|
||||
<li>{t(I18nKey.CTA$FEATURE_SUPPORT)}</li>
|
||||
</ul>
|
||||
|
||||
<div className={cn("h-10 flex justify-start")}>
|
||||
<a
|
||||
href="https://openhands.dev/enterprise/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
onClick={handleLearnMoreClick}
|
||||
className={cn(
|
||||
"inline-flex items-center justify-center",
|
||||
"h-10 px-4 rounded",
|
||||
"bg-[#050505] border border-[#242424]",
|
||||
"text-white hover:bg-[#0a0a0a]",
|
||||
"font-semibold text-sm",
|
||||
)}
|
||||
>
|
||||
{t(I18nKey.CTA$LEARN_MORE)}
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
@@ -38,6 +38,8 @@ import { useTaskPolling } from "#/hooks/query/use-task-polling";
|
||||
import { useConversationWebSocket } from "#/contexts/conversation-websocket-context";
|
||||
import ChatStatusIndicator from "./chat-status-indicator";
|
||||
import { getStatusColor, getStatusText } from "#/utils/utils";
|
||||
import { useNewConversationCommand } from "#/hooks/mutation/use-new-conversation-command";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
|
||||
function getEntryPoint(
|
||||
hasRepository: boolean | null,
|
||||
@@ -80,6 +82,10 @@ export function ChatInterface() {
|
||||
setHitBottom,
|
||||
} = useScrollToBottom(scrollRef);
|
||||
const { data: config } = useConfig();
|
||||
const {
|
||||
mutate: newConversationCommand,
|
||||
isPending: isNewConversationPending,
|
||||
} = useNewConversationCommand();
|
||||
|
||||
const { curAgentState } = useAgentState();
|
||||
const { handleBuildPlanClick } = useHandleBuildPlanClick();
|
||||
@@ -146,6 +152,27 @@ export function ChatInterface() {
|
||||
originalImages: File[],
|
||||
originalFiles: File[],
|
||||
) => {
|
||||
// Handle /new command for V1 conversations
|
||||
if (content.trim() === "/new") {
|
||||
if (!isV1Conversation) {
|
||||
displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_V1_ONLY));
|
||||
return;
|
||||
}
|
||||
if (!params.conversationId) {
|
||||
displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_NO_ID));
|
||||
return;
|
||||
}
|
||||
if (totalEvents === 0) {
|
||||
displayErrorToast(t(I18nKey.CONVERSATION$CLEAR_EMPTY));
|
||||
return;
|
||||
}
|
||||
if (isNewConversationPending) {
|
||||
return;
|
||||
}
|
||||
newConversationCommand();
|
||||
return;
|
||||
}
|
||||
|
||||
// Create mutable copies of the arrays
|
||||
const images = [...originalImages];
|
||||
const files = [...originalFiles];
|
||||
@@ -338,7 +365,10 @@ export function ChatInterface() {
|
||||
/>
|
||||
)}
|
||||
|
||||
<InteractiveChatBox onSubmit={handleSendMessage} />
|
||||
<InteractiveChatBox
|
||||
onSubmit={handleSendMessage}
|
||||
disabled={isNewConversationPending}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{config?.app_mode !== "saas" && !isV1Conversation && (
|
||||
|
||||
@@ -12,6 +12,7 @@ interface ChatInputContainerProps {
|
||||
chatContainerRef: React.RefObject<HTMLDivElement | null>;
|
||||
isDragOver: boolean;
|
||||
disabled: boolean;
|
||||
isNewConversationPending?: boolean;
|
||||
showButton: boolean;
|
||||
buttonClassName: string;
|
||||
chatInputRef: React.RefObject<HTMLDivElement | null>;
|
||||
@@ -36,6 +37,7 @@ export function ChatInputContainer({
|
||||
chatContainerRef,
|
||||
isDragOver,
|
||||
disabled,
|
||||
isNewConversationPending = false,
|
||||
showButton,
|
||||
buttonClassName,
|
||||
chatInputRef,
|
||||
@@ -89,6 +91,7 @@ export function ChatInputContainer({
|
||||
<ChatInputRow
|
||||
chatInputRef={chatInputRef}
|
||||
disabled={disabled}
|
||||
isNewConversationPending={isNewConversationPending}
|
||||
showButton={showButton}
|
||||
buttonClassName={buttonClassName}
|
||||
handleFileIconClick={handleFileIconClick}
|
||||
|
||||
@@ -2,9 +2,11 @@ import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
import { cn } from "#/utils/utils";
|
||||
|
||||
interface ChatInputFieldProps {
|
||||
chatInputRef: React.RefObject<HTMLDivElement | null>;
|
||||
disabled?: boolean;
|
||||
onInput: () => void;
|
||||
onPaste: (e: React.ClipboardEvent) => void;
|
||||
onKeyDown: (e: React.KeyboardEvent) => void;
|
||||
@@ -14,6 +16,7 @@ interface ChatInputFieldProps {
|
||||
|
||||
export function ChatInputField({
|
||||
chatInputRef,
|
||||
disabled = false,
|
||||
onInput,
|
||||
onPaste,
|
||||
onKeyDown,
|
||||
@@ -36,8 +39,11 @@ export function ChatInputField({
|
||||
<div className="basis-0 flex flex-col font-normal grow justify-center leading-[0] min-h-px min-w-px overflow-ellipsis overflow-hidden relative shrink-0 text-[#d0d9fa] text-[16px] text-left">
|
||||
<div
|
||||
ref={chatInputRef}
|
||||
className="chat-input bg-transparent text-white text-[16px] font-normal leading-[20px] outline-none resize-none custom-scrollbar min-h-[20px] max-h-[400px] [text-overflow:inherit] [text-wrap-mode:inherit] [white-space-collapse:inherit] block whitespace-pre-wrap"
|
||||
contentEditable
|
||||
className={cn(
|
||||
"chat-input bg-transparent text-white text-[16px] font-normal leading-[20px] outline-none resize-none custom-scrollbar min-h-[20px] max-h-[400px] [text-overflow:inherit] [text-wrap-mode:inherit] [white-space-collapse:inherit] block whitespace-pre-wrap",
|
||||
disabled && "cursor-not-allowed opacity-50",
|
||||
)}
|
||||
contentEditable={!disabled}
|
||||
data-placeholder={
|
||||
isPlanMode
|
||||
? t(I18nKey.COMMON$LET_S_WORK_ON_A_PLAN)
|
||||
|
||||
@@ -7,6 +7,7 @@ import { ChatInputField } from "./chat-input-field";
|
||||
interface ChatInputRowProps {
|
||||
chatInputRef: React.RefObject<HTMLDivElement | null>;
|
||||
disabled: boolean;
|
||||
isNewConversationPending?: boolean;
|
||||
showButton: boolean;
|
||||
buttonClassName: string;
|
||||
handleFileIconClick: (isDisabled: boolean) => void;
|
||||
@@ -21,6 +22,7 @@ interface ChatInputRowProps {
|
||||
export function ChatInputRow({
|
||||
chatInputRef,
|
||||
disabled,
|
||||
isNewConversationPending = false,
|
||||
showButton,
|
||||
buttonClassName,
|
||||
handleFileIconClick,
|
||||
@@ -41,6 +43,7 @@ export function ChatInputRow({
|
||||
|
||||
<ChatInputField
|
||||
chatInputRef={chatInputRef}
|
||||
disabled={isNewConversationPending}
|
||||
onInput={onInput}
|
||||
onPaste={onPaste}
|
||||
onKeyDown={onKeyDown}
|
||||
|
||||
@@ -13,6 +13,7 @@ import { useConversationStore } from "#/stores/conversation-store";
|
||||
|
||||
export interface CustomChatInputProps {
|
||||
disabled?: boolean;
|
||||
isNewConversationPending?: boolean;
|
||||
showButton?: boolean;
|
||||
conversationStatus?: ConversationStatus | null;
|
||||
onSubmit: (message: string) => void;
|
||||
@@ -25,6 +26,7 @@ export interface CustomChatInputProps {
|
||||
|
||||
export function CustomChatInput({
|
||||
disabled = false,
|
||||
isNewConversationPending = false,
|
||||
showButton = true,
|
||||
conversationStatus = null,
|
||||
onSubmit,
|
||||
@@ -147,6 +149,7 @@ export function CustomChatInput({
|
||||
chatContainerRef={chatContainerRef}
|
||||
isDragOver={isDragOver}
|
||||
disabled={isDisabled}
|
||||
isNewConversationPending={isNewConversationPending}
|
||||
showButton={showButton}
|
||||
buttonClassName={buttonClassName}
|
||||
chatInputRef={chatInputRef}
|
||||
|
||||
@@ -13,9 +13,13 @@ import { isTaskPolling } from "#/utils/utils";
|
||||
|
||||
interface InteractiveChatBoxProps {
|
||||
onSubmit: (message: string, images: File[], files: File[]) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
|
||||
export function InteractiveChatBox({
|
||||
onSubmit,
|
||||
disabled = false,
|
||||
}: InteractiveChatBoxProps) {
|
||||
const {
|
||||
images,
|
||||
files,
|
||||
@@ -145,6 +149,7 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
|
||||
// Allow users to submit messages during LOADING state - they will be
|
||||
// queued server-side and delivered when the conversation becomes ready
|
||||
const isDisabled =
|
||||
disabled ||
|
||||
curAgentState === AgentState.AWAITING_USER_CONFIRMATION ||
|
||||
isTaskPolling(subConversationTaskStatus);
|
||||
|
||||
@@ -152,6 +157,7 @@ export function InteractiveChatBox({ onSubmit }: InteractiveChatBoxProps) {
|
||||
<div data-testid="interactive-chat-box">
|
||||
<CustomChatInput
|
||||
disabled={isDisabled}
|
||||
isNewConversationPending={disabled}
|
||||
onSubmit={handleSubmit}
|
||||
onFilesPaste={handleUpload}
|
||||
conversationStatus={conversation?.status || null}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import React from "react";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { useClickOutsideElement } from "#/hooks/use-click-outside-element";
|
||||
|
||||
interface ContextMenuContainerProps {
|
||||
children: React.ReactNode;
|
||||
onClose: () => void;
|
||||
testId?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ContextMenuContainer({
|
||||
children,
|
||||
onClose,
|
||||
testId,
|
||||
className,
|
||||
}: ContextMenuContainerProps) {
|
||||
const ref = useClickOutsideElement<HTMLDivElement>(onClose);
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={ref}
|
||||
data-testid={testId}
|
||||
className={cn(
|
||||
// Base styling - same for ALL modes (SaaS, OSS, mobile, desktop)
|
||||
"absolute rounded-[12px] p-[25px]",
|
||||
"bg-[#050505] border border-[#242424]",
|
||||
"text-white overflow-hidden z-[9999]",
|
||||
"context-menu-box-shadow",
|
||||
// Positioning
|
||||
"right-0 md:right-auto md:left-full md:bottom-0",
|
||||
"w-fit",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-row gap-4 items-stretch">{children}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { Card } from "#/ui/card";
|
||||
import { CardTitle } from "#/ui/card-title";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import StackedIcon from "#/icons/stacked.svg?react";
|
||||
import { useTracking } from "#/hooks/use-tracking";
|
||||
|
||||
export function ContextMenuCTA() {
|
||||
const { t } = useTranslation();
|
||||
const { trackSaasSelfhostedInquiry } = useTracking();
|
||||
|
||||
const handleLearnMoreClick = () => {
|
||||
trackSaasSelfhostedInquiry({ location: "context_menu" });
|
||||
};
|
||||
|
||||
return (
|
||||
<Card
|
||||
testId="context-menu-cta"
|
||||
theme="dark"
|
||||
className={cn(
|
||||
"w-[286px] min-h-[200px] rounded-[6px]",
|
||||
"flex-col justify-end",
|
||||
"cta-card-gradient",
|
||||
)}
|
||||
>
|
||||
<div
|
||||
data-testid="context-menu-cta-content"
|
||||
className={cn("flex flex-col gap-[11px] p-[25px]")}
|
||||
>
|
||||
<StackedIcon width={40} height={40} />
|
||||
|
||||
<CardTitle>{t(I18nKey.CTA$ENTERPRISE_TITLE)}</CardTitle>
|
||||
|
||||
<Typography.Text
|
||||
className={cn(
|
||||
"text-[#8C8C8C] font-inter font-normal",
|
||||
"text-[14px] leading-[20px]",
|
||||
)}
|
||||
>
|
||||
{t(I18nKey.CTA$ENTERPRISE_DESCRIPTION)}
|
||||
</Typography.Text>
|
||||
|
||||
<div className="flex mt-auto">
|
||||
<a
|
||||
href="https://openhands.dev/enterprise/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
onClick={handleLearnMoreClick}
|
||||
className={cn(
|
||||
"inline-flex items-center justify-center",
|
||||
"h-[40px] px-4 rounded-[4px]",
|
||||
"bg-[#050505] border border-[#242424]",
|
||||
"text-white hover:bg-[#0a0a0a]",
|
||||
"font-semibold text-sm",
|
||||
)}
|
||||
>
|
||||
{t(I18nKey.CTA$LEARN_MORE)}
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
+9
-3
@@ -20,7 +20,7 @@ export function ConversationTabTitle({
|
||||
conversationKey,
|
||||
}: ConversationTabTitleProps) {
|
||||
const { t } = useTranslation();
|
||||
const { refetch } = useUnifiedGetGitChanges();
|
||||
const { refetch, isFetching } = useUnifiedGetGitChanges();
|
||||
const { handleBuildPlanClick } = useHandleBuildPlanClick();
|
||||
const { curAgentState } = useAgentState();
|
||||
const { planContent } = useConversationStore();
|
||||
@@ -41,10 +41,16 @@ export function ConversationTabTitle({
|
||||
{conversationKey === "editor" && (
|
||||
<button
|
||||
type="button"
|
||||
className="flex w-[26px] py-1 justify-center items-center gap-[10px] rounded-[7px] hover:bg-[#474A54] cursor-pointer"
|
||||
className="flex w-[26px] py-1 justify-center items-center gap-[10px] rounded-[7px] hover:enabled:bg-[#474A54] cursor-pointer disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
onClick={handleRefresh}
|
||||
disabled={isFetching}
|
||||
>
|
||||
<RefreshIcon width={12.75} height={15} color="#ffffff" />
|
||||
<RefreshIcon
|
||||
width={12.75}
|
||||
height={15}
|
||||
color="#ffffff"
|
||||
className={isFetching ? "animate-spin" : ""}
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
{conversationKey === "planner" && (
|
||||
|
||||
+3
-2
@@ -1,14 +1,15 @@
|
||||
import { FaExternalLinkAlt } from "react-icons/fa";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { RUNTIME_INACTIVE_STATES } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useUnifiedVSCodeUrl } from "#/hooks/query/use-unified-vscode-url";
|
||||
import { RUNTIME_STARTING_STATES } from "#/types/agent-state";
|
||||
|
||||
export function VSCodeTooltipContent() {
|
||||
const { curAgentState } = useAgentState();
|
||||
const { t } = useTranslation();
|
||||
const { data, refetch } = useUnifiedVSCodeUrl();
|
||||
const isRuntimeStarting = RUNTIME_STARTING_STATES.includes(curAgentState);
|
||||
|
||||
const handleVSCodeClick = async (e: React.MouseEvent) => {
|
||||
e.preventDefault();
|
||||
@@ -29,7 +30,7 @@ export function VSCodeTooltipContent() {
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<span>{t(I18nKey.COMMON$CODE)}</span>
|
||||
{!RUNTIME_INACTIVE_STATES.includes(curAgentState) ? (
|
||||
{!isRuntimeStarting ? (
|
||||
<FaExternalLinkAlt
|
||||
className="w-3 h-3 text-inherit cursor-pointer"
|
||||
onClick={handleVSCodeClick}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { usePostHog } from "posthog-js/react";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { H2, Text } from "#/ui/typography";
|
||||
import CheckCircleFillIcon from "#/icons/check-circle-fill.svg?react";
|
||||
import { PROJ_USER_JOURNEY } from "#/utils/feature-flags";
|
||||
import { ENABLE_PROJ_USER_JOURNEY } from "#/utils/feature-flags";
|
||||
|
||||
const ENTERPRISE_FEATURE_KEYS: I18nKey[] = [
|
||||
I18nKey.ENTERPRISE$FEATURE_DATA_PRIVACY,
|
||||
@@ -16,7 +16,7 @@ export function EnterpriseBanner() {
|
||||
const { t } = useTranslation();
|
||||
const posthog = usePostHog();
|
||||
|
||||
if (!PROJ_USER_JOURNEY()) {
|
||||
if (!ENABLE_PROJ_USER_JOURNEY()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ export function ConnectToProviderMessage() {
|
||||
</div>
|
||||
<Link
|
||||
data-testid="navigate-to-settings-button"
|
||||
to="/settings/integrations"
|
||||
to="/settings"
|
||||
className="self-start w-full"
|
||||
>
|
||||
<BrandButton
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Dispatch, SetStateAction } from "react";
|
||||
import { Card } from "#/ui/card";
|
||||
import { CardTitle } from "#/ui/card-title";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { setCTADismissed } from "#/utils/local-storage";
|
||||
import { useTracking } from "#/hooks/use-tracking";
|
||||
import CloseIcon from "#/icons/close.svg?react";
|
||||
|
||||
interface HomepageCTAProps {
|
||||
setShouldShowCTA: Dispatch<SetStateAction<boolean>>;
|
||||
}
|
||||
|
||||
export function HomepageCTA({ setShouldShowCTA }: HomepageCTAProps) {
|
||||
const { t } = useTranslation();
|
||||
const { trackSaasSelfhostedInquiry } = useTracking();
|
||||
|
||||
const handleClose = () => {
|
||||
setCTADismissed("homepage");
|
||||
setShouldShowCTA(false);
|
||||
};
|
||||
|
||||
const handleLearnMoreClick = () => {
|
||||
trackSaasSelfhostedInquiry({ location: "home_page" });
|
||||
};
|
||||
|
||||
return (
|
||||
<Card theme="dark" className={cn("w-[320px] cta-card-gradient")}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleClose}
|
||||
className={cn(
|
||||
"absolute top-3 right-3 size-7 rounded-full",
|
||||
"border border-[#242424] bg-[#0A0A0A]",
|
||||
"flex items-center justify-center",
|
||||
"text-white/60 hover:text-white cursor-pointer",
|
||||
"shadow-[0px_1px_2px_-1px_#0000001A,0px_1px_3px_0px_#0000001A]",
|
||||
)}
|
||||
aria-label="Close"
|
||||
>
|
||||
<CloseIcon width={16} height={16} />
|
||||
</button>
|
||||
|
||||
<div className="p-6 flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-2">
|
||||
<CardTitle className="font-inter font-semibold text-xl leading-7 tracking-normal text-[#FAFAFA]">
|
||||
{t(I18nKey.CTA$ENTERPRISE_TITLE)}
|
||||
</CardTitle>
|
||||
|
||||
<Typography.Text className="font-inter font-normal text-sm leading-5 tracking-normal text-[#8C8C8C]">
|
||||
{t(I18nKey.CTA$ENTERPRISE_DESCRIPTION)}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
|
||||
<a
|
||||
href="https://openhands.dev/enterprise/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
onClick={handleLearnMoreClick}
|
||||
className={cn(
|
||||
"inline-flex items-center justify-center",
|
||||
"w-fit h-10 px-4 rounded",
|
||||
"bg-[#050505] border border-[#242424]",
|
||||
"text-white hover:bg-[#1a1a1a]",
|
||||
"font-semibold text-sm",
|
||||
)}
|
||||
>
|
||||
{t(I18nKey.CTA$LEARN_MORE)}
|
||||
</a>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user