Compare commits

..

50 Commits

Author SHA1 Message Date
mamoodi
9885ddea33 Release 1.1.0 2025-12-30 09:28:57 -05:00
sp.wack
103e3ead0a hotfix(frontend): validate git changes response is array before mapping (#12208) 2025-12-30 12:33:09 +00:00
dependabot[bot]
d5e83d0f06 chore(deps): bump peter-evans/create-or-update-comment from 4 to 5 (#12192)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
2025-12-29 23:50:40 +00:00
dependabot[bot]
443918af3c chore(deps): bump docker/setup-qemu-action from 3.6.0 to 3.7.0 (#12193)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-30 00:25:56 +01:00
dependabot[bot]
910646d11f chore(deps): bump actions/cache from 4 to 5 (#12191)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-30 00:25:17 +01:00
Engel Nyst
d9d19043f1 chore: Mark V0 legacy files with clear headers and V1 pointers (#12165)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
2025-12-30 00:21:29 +01:00
Graham Neubig
4dec38c7ce fix(event-webhook): Improve error logging with exception type and stack trace (#12202)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-29 18:09:20 -05:00
Graham Neubig
c3f51d9dbe fix(billing): Add error handling for LiteLLM API failures in get_credits (#12201)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-29 23:01:55 +00:00
chuckbutkus
ecbd3ae749 Fix local dev deployments (#12198) 2025-12-29 16:18:02 -05:00
Hiep Le
8ee1394e8c feat: add button to authentication modal to resend verification email (#12179) 2025-12-30 02:12:14 +07:00
Tim O'Farrell
d628e1f20a feat: Add frontend support for public conversation sharing (#12047)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
2025-12-29 12:04:06 -07:00
sp.wack
1480d4acb0 fix(frontend): deduplicate events on WebSocket reconnect (#12197) 2025-12-29 19:03:48 +00:00
Hiep Le
58a70e8b0d fix(backend): preserve users custom llm settings during settings migrations (#12134)
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
2025-12-29 23:28:20 +07:00
Hiep Le
49e46a5fa1 refactor(backend): remove <sub> in slack response (#12135) 2025-12-29 23:27:48 +07:00
Hiep Le
2cf6494773 fix(backend): install_gitlab_webhooks.py is not functioning as expected (#12185) 2025-12-29 23:27:31 +07:00
Hiep Le
d3afbfa447 refactor(backend): add description field support for secrets (v1 conversations) (#12080) 2025-12-29 22:43:07 +07:00
Hiep Le
8d69b4066f fix(backend): exception occurs when running the latest code from the main branch (v1 conversations) (#12183) 2025-12-29 09:57:14 -05:00
dependabot[bot]
2261281656 chore(deps): bump @tanstack/react-query from 5.90.12 to 5.90.14 in /frontend in the version-all group (#12189)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-29 14:33:52 +00:00
sp.wack
d68b2cdd1a hotfix(frontend): fix provider type import (#12187) 2025-12-29 18:01:22 +04:00
dependabot[bot]
c70ecc8fe3 chore(deps): bump the version-all group across 1 directory with 6 updates (#12161)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-29 13:54:58 +00:00
Pedro Henrique
a3e85e2c2d test: Add MC/DC tests for loop pattern detector (stuck_detector) (#11600)
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-29 14:15:15 +01:00
Hiep Le
3bef4e6c2d refactor(frontend): update the error message for email addresses containing + during signup (#12178) 2025-12-29 19:36:28 +07:00
Engel Nyst
97654e6a5e Configurable conda/mamba channel_alias for runtime builds (#11516)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-29 00:40:57 +01:00
Tim O'Farrell
30114666ad Bump the SDK to 1.7.1 (#12182)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-28 18:57:08 +00:00
dependabot[bot]
ee50f333ba chore(deps): bump actions/upload-artifact from 4 to 5 (#11805)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-28 09:51:34 -05:00
dependabot[bot]
09d1748a14 build(deps): bump actions/setup-python from 5 to 6 (#11755)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-28 09:49:17 -05:00
dependabot[bot]
81519343c4 chore(deps): bump actions/download-artifact from 4 to 6 (#11524)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-28 09:49:02 -05:00
dependabot[bot]
f742811e81 chore(deps): bump actions/setup-node from 4 to 6 (#11442)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-28 08:58:26 -05:00
johba
f8e4b5562e Forgejo integration (#11111)
Co-authored-by: johba <admin@noreply.localhost>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: johba <johba@harb.eth>
Co-authored-by: enyst <engel.nyst@gmail.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: MrGeorgen <65063405+MrGeorgen@users.noreply.github.com>
Co-authored-by: MrGeorgen <moinl6162@gmail.com>
2025-12-27 15:57:31 -05:00
Tim O'Farrell
cb1d1f8a0d Fix install-hooks CronJob failing when gitlab_webhook table doesn't exist (#12167)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-26 10:53:21 -07:00
Tim O'Farrell
a829d10213 ALL-4634: implement public conversation sharing feature (#12044)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-26 10:02:01 -07:00
Tim O'Farrell
cb8c1fa263 ALL-4627 Database Fixes (#12156)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-26 09:19:51 -07:00
lif
c80f70392f fix(frontend): clean up console warnings in test suite (#12004)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-25 22:26:12 +04:00
Guy Elsmore-Paddock
94e6490a79 Use tini as Docker Runtime Init to Ensure Zombie Processes Get Reaped (#12133)
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
2025-12-25 06:16:52 +00:00
Tim O'Farrell
09af93a02a Agent server env override (#12068)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
2025-12-25 03:55:06 +00:00
shanemort1982
5407ea55aa Fix WebSocket localhost bug by passing DOCKER_HOST_ADDR to runtime containers (#12113)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-24 14:26:45 -07:00
Tim O'Farrell
fe1026ee8a Fix for re-creating deleted conversation (#12152) 2025-12-24 12:13:29 -07:00
Tim O'Farrell
6d14ce420e Implement Export feature for V1 conversations with comprehensive unit tests (#12030)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2025-12-24 17:50:57 +00:00
lif
36fe23aea3 fix(llm): retry LiteLLM bad gateway errors (#12117) 2025-12-24 06:37:12 -05:00
sp.wack
9049b95792 docs(frontend): React Router testing guide (#12145) 2025-12-24 14:21:55 +04:00
Hiep Le
e2b2aa52cd feat: require email verification for new signups (#12123) 2025-12-24 14:56:02 +07:00
Tim O'Farrell
dc99c7b62e Fix SQLAlchemy result handling in get_sandbox_by_session_api_key (#12148)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-24 00:11:16 +00:00
Tim O'Farrell
8bc1a47a78 Fix for error in get_sandbox_by_session_api_key (#12147) 2025-12-23 22:18:36 +00:00
Tim O'Farrell
8d0e7a92b8 ALL-4636 Resolution for connection leaks (#12144)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-23 19:02:56 +00:00
Hiep Le
f6e7628bff feat: prevent signups using email addresses with a plus sign and enforce the existing email pattern (#12124) 2025-12-24 01:48:05 +07:00
sp.wack
fae83230ee docs(frontend): Add API services guide for frontend development (#12132) 2025-12-23 12:57:55 +00:00
sp.wack
a9d2f72d72 docs(frontend): Add MSW testing guide for frontend development (#12131) 2025-12-23 16:32:27 +04:00
Tim O'Farrell
2b8f779b65 fix: Runtime pods fail to start due to missing Playwright browser path (#12130) 2025-12-22 17:04:10 +00:00
Hiep Le
10edb28729 fix(frontend): llm settings view resets to basic after saving (#12097) 2025-12-22 23:00:57 +07:00
Hiep Le
5553d3ca2e feat: support blocking specific email domains (#12115) 2025-12-21 19:49:11 +07:00
311 changed files with 14474 additions and 3410 deletions

View File

@@ -15,7 +15,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"

View File

@@ -27,7 +27,7 @@ jobs:
poetry-version: 2.1.3
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.12'
cache: 'poetry'
@@ -38,7 +38,7 @@ jobs:
sudo apt-get install -y libgtk-3-0 libnotify4 libnss3 libxss1 libxtst6 xauth xvfb libgbm1 libasound2t64 netcat-openbsd
- name: Setup Node.js
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: '22'
cache: 'npm'
@@ -192,7 +192,7 @@ jobs:
- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: playwright-report
path: tests/e2e/test-results/
@@ -200,7 +200,7 @@ jobs:
- name: Upload OpenHands logs
if: always()
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: openhands-logs
path: |

View File

@@ -43,7 +43,7 @@ jobs:
⚠️ This PR contains **migrations**
- name: Comment warning on PR
uses: peter-evans/create-or-update-comment@v4
uses: peter-evans/create-or-update-comment@v5
with:
issue-number: ${{ github.event.pull_request.number }}
comment-id: ${{ steps.find-comment.outputs.comment-id }}

View File

@@ -39,7 +39,7 @@ jobs:
working-directory: ./frontend
run: npx playwright test --project=chromium
- name: Upload Playwright report
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
if: always()
with:
name: playwright-report

View File

@@ -64,7 +64,7 @@ jobs:
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3.6.0
uses: docker/setup-qemu-action@v3.7.0
with:
image: tonistiigi/binfmt:latest
- name: Login to GHCR
@@ -102,7 +102,7 @@ jobs:
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3.6.0
uses: docker/setup-qemu-action@v3.7.0
with:
image: tonistiigi/binfmt:latest
- name: Login to GHCR
@@ -161,7 +161,7 @@ jobs:
context: containers/runtime
- name: Upload runtime source for fork
if: github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: runtime-src-${{ matrix.base_image.tag }}
path: containers/runtime
@@ -268,7 +268,7 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Download runtime source for fork
if: github.event.pull_request.head.repo.fork
uses: actions/download-artifact@v4
uses: actions/download-artifact@v6
with:
name: runtime-src-${{ matrix.base_image.tag }}
path: containers/runtime
@@ -330,7 +330,7 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Download runtime source for fork
if: github.event.pull_request.head.repo.fork
uses: actions/download-artifact@v4
uses: actions/download-artifact@v6
with:
name: runtime-src-${{ matrix.base_image.tag }}
path: containers/runtime

View File

@@ -89,7 +89,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Upgrade pip
@@ -118,7 +118,7 @@ jobs:
contains(github.event.review.body, '@openhands-agent-exp')
)
)
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ${{ env.pythonLocation }}/lib/python3.12/site-packages/*
key: ${{ runner.os }}-pip-openhands-resolver-${{ hashFiles('/tmp/requirements.txt') }}
@@ -269,7 +269,7 @@ jobs:
fi
- name: Upload output.jsonl as artifact
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
if: always() # Upload even if the previous steps fail
with:
name: resolver-output

View File

@@ -63,7 +63,7 @@ jobs:
env:
COVERAGE_FILE: ".coverage.runtime.${{ matrix.python_version }}"
- name: Store coverage file
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: coverage-openhands
path: |
@@ -95,7 +95,7 @@ jobs:
env:
COVERAGE_FILE: ".coverage.enterprise.${{ matrix.python_version }}"
- name: Store coverage file
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: coverage-enterprise
path: ".coverage.enterprise.${{ matrix.python_version }}"
@@ -113,7 +113,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v5
- uses: actions/download-artifact@v6
id: download
with:
pattern: coverage-*

View File

@@ -37,7 +37,7 @@ jobs:
node-version: '22'
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.12'
@@ -70,7 +70,7 @@ jobs:
fi
- name: Upload VSCode extension artifact
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: vscode-extension
path: openhands/integrations/vscode/openhands-vscode-0.0.1.vsix
@@ -142,7 +142,7 @@ jobs:
steps:
- name: Download .vsix artifact
uses: actions/download-artifact@v4
uses: actions/download-artifact@v6
with:
name: vscode-extension
path: ./

View File

@@ -161,7 +161,7 @@ poetry run pytest ./tests/unit/test_*.py
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.0-nikolaik`
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.1-nikolaik`
## Develop inside Docker container

View File

@@ -12,7 +12,7 @@ services:
- SANDBOX_API_HOSTNAME=host.docker.internal
- DOCKER_HOST_ADDR=host.docker.internal
#
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.0-nikolaik}
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.1-nikolaik}
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
ports:

View File

@@ -7,7 +7,7 @@ services:
image: openhands:latest
container_name: openhands-app-${DATE:-}
environment:
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.0-nikolaik}
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.1-nikolaik}
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
ports:

View File

@@ -50,7 +50,7 @@ First run this to retrieve Github App secrets
```
gcloud auth application-default login
gcloud config set project global-432717
local/decrypt_env.sh
enterprise_local/decrypt_env.sh /path/to/root/of/deploy/repo
```
Now run this to generate a `.env` file, which will used to run SAAS locally

4
enterprise/enterprise_local/decrypt_env.sh Normal file → Executable file
View File

@@ -4,12 +4,12 @@ set -euo pipefail
# Check if DEPLOY_DIR argument was provided
if [ $# -lt 1 ]; then
echo "Usage: $0 <DEPLOY_DIR>"
echo "Example: $0 /path/to/deploy"
echo "Example: $0 /path/to/root/of/deploy/repo"
exit 1
fi
# Normalize path (remove trailing slash)
DEPLOY_DIR="${DEPLOY_DIR%/}"
DEPLOY_DIR="${1%/}"
# Function to decrypt and rename
decrypt_and_move() {

View File

@@ -321,7 +321,7 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
The message with the conversation footer appended
"""
conversation_link = CONVERSATION_URL.format(conversation_id)
footer = f'\n\n<sub>[View full conversation]({conversation_link})</sub>'
footer = f'\n\n[View full conversation]({conversation_link})'
return message + footer

View File

@@ -0,0 +1,41 @@
"""add public column to conversation_metadata
Revision ID: 085
Revises: 084
Create Date: 2025-01-27 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '085'
down_revision: Union[str, None] = '084'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'conversation_metadata',
sa.Column('public', sa.Boolean(), nullable=True),
)
op.create_index(
op.f('ix_conversation_metadata_public'),
'conversation_metadata',
['public'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index(
op.f('ix_conversation_metadata_public'),
table_name='conversation_metadata',
)
op.drop_column('conversation_metadata', 'public')

42
enterprise/poetry.lock generated
View File

@@ -4558,25 +4558,25 @@ valkey = ["valkey (>=6)"]
[[package]]
name = "litellm"
version = "1.80.7"
version = "1.80.11"
description = "Library to easily interface with LLM API providers"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "litellm-1.80.7-py3-none-any.whl", hash = "sha256:f7d993f78c1e0e4e1202b2a925cc6540b55b6e5fb055dd342d88b145ab3102ed"},
{file = "litellm-1.80.7.tar.gz", hash = "sha256:3977a8d195aef842d01c18bf9e22984829363c6a4b54daf9a43c9dd9f190b42c"},
{file = "litellm-1.80.11-py3-none-any.whl", hash = "sha256:406283d66ead77dc7ff0e0b2559c80e9e497d8e7c2257efb1cb9210a20d09d54"},
{file = "litellm-1.80.11.tar.gz", hash = "sha256:c9fc63e7acb6360363238fe291bcff1488c59ff66020416d8376c0ee56414a19"},
]
[package.dependencies]
aiohttp = ">=3.10"
click = "*"
fastuuid = ">=0.13.0"
grpcio = ">=1.62.3,<1.68.0"
grpcio = {version = ">=1.62.3,<1.68.0", markers = "python_version < \"3.14\""}
httpx = ">=0.23.0"
importlib-metadata = ">=6.8.0"
jinja2 = ">=3.1.2,<4.0.0"
jsonschema = ">=4.22.0,<5.0.0"
jsonschema = ">=4.23.0,<5.0.0"
openai = ">=2.8.0"
pydantic = ">=2.5.0,<3.0.0"
python-dotenv = ">=0.2.0"
@@ -4587,7 +4587,7 @@ tokenizers = "*"
caching = ["diskcache (>=5.6.1,<6.0.0)"]
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-iam (>=2.19.1,<3.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "redisvl (>=0.4.1,<0.5.0) ; python_version >= \"3.9\" and python_version < \"3.14\"", "resend (>=0.8.0)"]
mlflow = ["mlflow (>3.1.4) ; python_version >= \"3.10\""]
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.22)", "litellm-proxy-extras (==0.4.9)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.27)", "litellm-proxy-extras (==0.4.16)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
semantic-router = ["semantic-router (>=0.1.12) ; python_version >= \"3.9\" and python_version < \"3.14\""]
utils = ["numpydoc"]
@@ -5836,14 +5836,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
[[package]]
name = "openhands-agent-server"
version = "1.6.0"
version = "1.7.1"
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_agent_server-1.6.0-py3-none-any.whl", hash = "sha256:e6ae865ac3e7a96b234e10a0faad23f6210e025bbf7721cb66bc7a71d160848c"},
{file = "openhands_agent_server-1.6.0.tar.gz", hash = "sha256:44ce7694ae2d4bb0666d318ef13e6618bd4dc73022c60354839fe6130e67d02a"},
{file = "openhands_agent_server-1.7.1-py3-none-any.whl", hash = "sha256:e5c57f1b73293d00a68b77f9d290f59d9e2217d9df844fb01c7d2f929c3417f4"},
{file = "openhands_agent_server-1.7.1.tar.gz", hash = "sha256:c82e1e6748ea3b4278ef2ee72f091dc37da6667c854b3aa3c0bc616086a82310"},
]
[package.dependencies]
@@ -5860,7 +5860,7 @@ wsproto = ">=1.2.0"
[[package]]
name = "openhands-ai"
version = "0.0.0-post.5687+7853b41ad"
version = "0.0.0-post.5742+ee50f333b"
description = "OpenHands: Code Less, Make More"
optional = false
python-versions = "^3.12,<3.14"
@@ -5896,15 +5896,15 @@ json-repair = "*"
jupyter_kernel_gateway = "*"
kubernetes = "^33.1.0"
libtmux = ">=0.46.2"
litellm = ">=1.74.3, <=1.80.7, !=1.64.4, !=1.67.*"
litellm = ">=1.74.3, !=1.64.4, !=1.67.*"
lmnr = "^0.7.20"
memory-profiler = "^0.61.0"
numpy = "*"
openai = "2.8.0"
openhands-aci = "0.3.2"
openhands-agent-server = "1.6.0"
openhands-sdk = "1.6.0"
openhands-tools = "1.6.0"
openhands-agent-server = "1.7.1"
openhands-sdk = "1.7.1"
openhands-tools = "1.7.1"
opentelemetry-api = "^1.33.1"
opentelemetry-exporter-otlp-proto-grpc = "^1.33.1"
pathspec = "^0.12.1"
@@ -5960,21 +5960,21 @@ url = ".."
[[package]]
name = "openhands-sdk"
version = "1.6.0"
version = "1.7.1"
description = "OpenHands SDK - Core functionality for building AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_sdk-1.6.0-py3-none-any.whl", hash = "sha256:94d2f87fb35406373da6728ae2d88584137f9e9b67fa0e940444c72f2e44e7d3"},
{file = "openhands_sdk-1.6.0.tar.gz", hash = "sha256:f45742350e3874a7f5b08befc4a9d5adc7e4454f7ab5f8391c519eee3116090f"},
{file = "openhands_sdk-1.7.1-py3-none-any.whl", hash = "sha256:e097e34dfbd45f38225ae2ff4830702424bcf742bc197b5a811540a75265b135"},
{file = "openhands_sdk-1.7.1.tar.gz", hash = "sha256:e13d1fe8bf14dffd91e9080608072a989132c981cf9bfcd124fa4f7a68a13691"},
]
[package.dependencies]
deprecation = ">=2.1.0"
fastmcp = ">=2.11.3"
httpx = ">=0.27.0"
litellm = ">=1.80.7"
litellm = ">=1.80.10"
lmnr = ">=0.7.24"
pydantic = ">=2.11.7"
python-frontmatter = ">=1.1.0"
@@ -5987,14 +5987,14 @@ boto3 = ["boto3 (>=1.35.0)"]
[[package]]
name = "openhands-tools"
version = "1.6.0"
version = "1.7.1"
description = "OpenHands Tools - Runtime tools for AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_tools-1.6.0-py3-none-any.whl", hash = "sha256:176556d44186536751b23fe052d3505492cc2afb8d52db20fb7a2cc0169cd57a"},
{file = "openhands_tools-1.6.0.tar.gz", hash = "sha256:d07ba31050fd4a7891a4c48388aa53ce9f703e17064ddbd59146d6c77e5980b3"},
{file = "openhands_tools-1.7.1-py3-none-any.whl", hash = "sha256:e25815f24925e94fbd4d8c3fd9b2147a0556fde595bf4f80a7dbba1014ea3c86"},
{file = "openhands_tools-1.7.1.tar.gz", hash = "sha256:f3823f7bd302c78969c454730cf793eb63109ce2d986e78585989c53986cc966"},
]
[package.dependencies]

View File

@@ -37,6 +37,12 @@ from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
from server.sharing.shared_conversation_router import ( # noqa: E402
router as shared_conversation_router,
)
from server.sharing.shared_event_router import ( # noqa: E402
router as shared_event_router,
)
from openhands.server.app import app as base_app # noqa: E402
from openhands.server.listen_socket import sio # noqa: E402
@@ -66,6 +72,8 @@ base_app.include_router(saas_user_router) # Add additional route SAAS user call
base_app.include_router(
billing_router
) # Add routes for credit management and Stripe payment integration
base_app.include_router(shared_conversation_router)
base_app.include_router(shared_event_router)
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
if GITHUB_APP_CLIENT_ID:
@@ -99,6 +107,7 @@ base_app.include_router(
event_webhook_router
) # Add routes for Events in nested runtimes
base_app.add_middleware(
CORSMiddleware,
allow_origins=PERMITTED_CORS_ORIGINS,

View File

@@ -38,3 +38,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
'y',
'on',
)
BLOCKED_EMAIL_DOMAINS = [
domain.strip().lower()
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
if domain.strip()
]

View File

@@ -0,0 +1,56 @@
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
from openhands.core.logger import openhands_logger as logger
class DomainBlocker:
def __init__(self) -> None:
logger.debug('Initializing DomainBlocker')
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
if self.blocked_domains:
logger.info(
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
)
def is_active(self) -> bool:
"""Check if domain blocking is enabled"""
return bool(self.blocked_domains)
def _extract_domain(self, email: str) -> str | None:
"""Extract and normalize email domain from email address"""
if not email:
return None
try:
# Extract domain part after @
if '@' not in email:
return None
domain = email.split('@')[1].strip().lower()
return domain if domain else None
except Exception:
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
return None
def is_domain_blocked(self, email: str) -> bool:
"""Check if email domain is blocked"""
if not self.is_active():
return False
if not email:
logger.debug('No email provided for domain check')
return False
domain = self._extract_domain(email)
if not domain:
logger.debug(f'Could not extract domain from email: {email}')
return False
is_blocked = domain in self.blocked_domains
if is_blocked:
logger.warning(f'Email domain {domain} is blocked for email: {email}')
else:
logger.debug(f'Email domain {domain} is not blocked')
return is_blocked
domain_blocker = DomainBlocker()

View File

@@ -0,0 +1,109 @@
"""Email validation utilities for preventing duplicate signups with + modifier."""
import re
def extract_base_email(email: str) -> str | None:
"""Extract base email from an email address.
For emails with + modifier, extracts the base email (local part before + and @, plus domain).
For emails without + modifier, returns the email as-is.
Examples:
extract_base_email("joe+test@example.com") -> "joe@example.com"
extract_base_email("joe@example.com") -> "joe@example.com"
extract_base_email("joe+openhands+test@example.com") -> "joe@example.com"
Args:
email: The email address to process
Returns:
The base email address, or None if email format is invalid
"""
if not email or '@' not in email:
return None
try:
local_part, domain = email.rsplit('@', 1)
# Extract the part before + if it exists
base_local = local_part.split('+', 1)[0]
return f'{base_local}@{domain}'
except (ValueError, AttributeError):
return None
def has_plus_modifier(email: str) -> bool:
"""Check if an email address contains a + modifier.
Args:
email: The email address to check
Returns:
True if email contains + before @, False otherwise
"""
if not email or '@' not in email:
return False
try:
local_part, _ = email.rsplit('@', 1)
return '+' in local_part
except (ValueError, AttributeError):
return False
def matches_base_email(email: str, base_email: str) -> bool:
"""Check if an email matches a base email pattern.
An email matches if:
- It is exactly the base email (e.g., joe@example.com)
- It has the same base local part and domain, with or without + modifier
(e.g., joe+test@example.com matches base joe@example.com)
Args:
email: The email address to check
base_email: The base email to match against
Returns:
True if email matches the base pattern, False otherwise
"""
if not email or not base_email:
return False
# Extract base from both emails for comparison
email_base = extract_base_email(email)
base_email_normalized = extract_base_email(base_email)
if not email_base or not base_email_normalized:
return False
# Emails match if they have the same base
return email_base.lower() == base_email_normalized.lower()
def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None:
"""Generate a regex pattern to match emails with the same base.
For base_email "joe@example.com", the pattern will match:
- joe@example.com
- joe+anything@example.com
Args:
base_email: The base email address
Returns:
A compiled regex pattern, or None if base_email is invalid
"""
base = extract_base_email(base_email)
if not base:
return None
try:
local_part, domain = base.rsplit('@', 1)
# Escape special regex characters in local part and domain
escaped_local = re.escape(local_part)
escaped_domain = re.escape(domain)
# Pattern: joe@example.com OR joe+anything@example.com
pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$'
return re.compile(pattern, re.IGNORECASE)
except (ValueError, AttributeError):
return None

View File

@@ -13,6 +13,7 @@ from server.auth.auth_error import (
ExpiredError,
NoCredentialsError,
)
from server.auth.domain_blocker import domain_blocker
from server.auth.token_manager import TokenManager
from server.config import get_config
from server.logger import logger
@@ -153,8 +154,10 @@ class SaasUserAuth(UserAuth):
try:
# TODO: I think we can do this in a single request if we refactor
with session_maker() as session:
tokens = session.query(AuthTokens).where(
AuthTokens.keycloak_user_id == self.user_id
tokens = (
session.query(AuthTokens)
.where(AuthTokens.keycloak_user_id == self.user_id)
.all()
)
for token in tokens:
@@ -312,6 +315,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
user_id = access_token_payload['sub']
email = access_token_payload['email']
email_verified = access_token_payload['email_verified']
# Check if email domain is blocked
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)
raise AuthError(
'Access denied: Your email domain is not allowed to access this service'
)
logger.debug('saas_user_auth_from_signed_token:return')
return SaasUserAuth(

View File

@@ -1,3 +1,4 @@
import asyncio
import base64
import hashlib
import json
@@ -25,6 +26,11 @@ from server.auth.constants import (
KEYCLOAK_SERVER_URL,
KEYCLOAK_SERVER_URL_EXT,
)
from server.auth.email_validation import (
extract_base_email,
get_base_email_regex_pattern,
matches_base_email,
)
from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
from server.config import get_config
from server.logger import logger
@@ -509,6 +515,183 @@ class TokenManager:
logger.info(f'Got user ID {keycloak_user_id} from email: {email}')
return keycloak_user_id
async def _query_users_by_wildcard_pattern(
self, local_part: str, domain: str
) -> dict[str, dict]:
"""Query Keycloak for users matching a wildcard email pattern.
Tries multiple query methods to find users with emails matching
the pattern {local_part}*@{domain}. This catches the base email
and all + modifier variants.
Args:
local_part: The local part of the email (before @)
domain: The domain part of the email (after @)
Returns:
Dictionary mapping user IDs to user objects
"""
keycloak_admin = get_keycloak_admin(self.external)
all_users = {}
# Query for users with emails matching the base pattern using wildcard
# Pattern: {local_part}*@{domain} - catches base email and all + variants
# This may also catch unintended matches (e.g., joesmith@example.com), but
# they will be filtered out by the regex pattern check later
# Use 'search' parameter for Keycloak 26+ (better wildcard support)
wildcard_queries = [
{'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first
{'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter
]
for query_params in wildcard_queries:
try:
users = await keycloak_admin.a_get_users(query_params)
for user in users:
all_users[user.get('id')] = user
break # Success, no need to try fallback
except Exception as e:
logger.debug(
f'Wildcard query failed with {list(query_params.keys())[0]}: {e}'
)
continue # Try next query method
return all_users
def _find_duplicate_in_users(
self, users: dict[str, dict], base_email: str, current_user_id: str
) -> bool:
"""Check if any user in the provided list matches the base email pattern.
Filters users to find duplicates that match the base email pattern,
excluding the current user.
Args:
users: Dictionary mapping user IDs to user objects
base_email: The base email to match against
current_user_id: The user ID to exclude from the check
Returns:
True if a duplicate is found, False otherwise
"""
regex_pattern = get_base_email_regex_pattern(base_email)
if not regex_pattern:
logger.warning(
f'Could not generate regex pattern for base email: {base_email}'
)
# Fallback to simple matching
for user in users.values():
user_email = user.get('email', '').lower()
if (
user_email
and user.get('id') != current_user_id
and matches_base_email(user_email, base_email)
):
logger.info(
f'Found duplicate email: {user_email} matches base {base_email}'
)
return True
else:
for user in users.values():
user_email = user.get('email', '')
if (
user_email
and user.get('id') != current_user_id
and regex_pattern.match(user_email)
):
logger.info(
f'Found duplicate email: {user_email} matches base {base_email}'
)
return True
return False
@retry(
stop=stop_after_attempt(2),
retry=retry_if_exception_type(KeycloakConnectionError),
before_sleep=_before_sleep_callback,
)
async def check_duplicate_base_email(
self, email: str, current_user_id: str
) -> bool:
"""Check if a user with the same base email already exists.
This method checks for duplicate signups using email + modifier.
It checks if any user exists with the same base email, regardless of whether
the provided email has a + modifier or not.
Examples:
- If email is "joe+test@example.com", it checks for existing users with
base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com")
- If email is "joe@example.com", it checks for existing users with
base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com")
Args:
email: The email address to check (may or may not contain + modifier)
current_user_id: The user ID of the current user (to exclude from check)
Returns:
True if a duplicate is found (excluding current user), False otherwise
"""
if not email:
return False
base_email = extract_base_email(email)
if not base_email:
logger.warning(f'Could not extract base email from: {email}')
return False
try:
local_part, domain = base_email.rsplit('@', 1)
users = await self._query_users_by_wildcard_pattern(local_part, domain)
return self._find_duplicate_in_users(users, base_email, current_user_id)
except KeycloakConnectionError:
logger.exception('KeycloakConnectionError when checking duplicate email')
raise
except Exception as e:
logger.exception(f'Unexpected error checking duplicate email: {e}')
# On any error, allow signup to proceed (fail open)
return False
@retry(
stop=stop_after_attempt(2),
retry=retry_if_exception_type(KeycloakConnectionError),
before_sleep=_before_sleep_callback,
)
async def delete_keycloak_user(self, user_id: str) -> bool:
"""Delete a user from Keycloak.
This method is used to clean up user accounts that were created
but should not exist (e.g., duplicate email signups).
Args:
user_id: The Keycloak user ID to delete
Returns:
True if deletion was successful, False otherwise
"""
try:
keycloak_admin = get_keycloak_admin(self.external)
# Use the sync method (python-keycloak doesn't have async delete_user)
# Run it in a thread executor to avoid blocking the event loop
await asyncio.to_thread(keycloak_admin.delete_user, user_id)
logger.info(f'Successfully deleted Keycloak user {user_id}')
return True
except KeycloakConnectionError:
logger.exception(f'KeycloakConnectionError when deleting user {user_id}')
raise
except KeycloakError as e:
# User might not exist or already deleted
logger.warning(
f'KeycloakError when deleting user {user_id}: {e}',
extra={'user_id': user_id, 'error': str(e)},
)
return False
except Exception as e:
logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}')
return False
async def get_user_info_from_user_id(self, user_id: str) -> dict | None:
keycloak_admin = get_keycloak_admin(self.external)
user = await keycloak_admin.a_get_user(user_id)
@@ -527,6 +710,49 @@ class TokenManager:
github_id = github_ids[0]
return github_id
async def disable_keycloak_user(
self, user_id: str, email: str | None = None
) -> None:
"""Disable a Keycloak user account.
Args:
user_id: The Keycloak user ID to disable
email: Optional email address for logging purposes
This method attempts to disable the user account but will not raise exceptions.
Errors are logged but do not prevent the operation from completing.
"""
try:
keycloak_admin = get_keycloak_admin(self.external)
# Get current user to preserve other fields
user = await keycloak_admin.a_get_user(user_id)
if user:
# Update user with enabled=False to disable the account
await keycloak_admin.a_update_user(
user_id=user_id,
payload={
'enabled': False,
'username': user.get('username', ''),
'email': user.get('email', ''),
'emailVerified': user.get('emailVerified', False),
},
)
email_str = f', email: {email}' if email else ''
logger.info(
f'Disabled Keycloak account for user_id: {user_id}{email_str}'
)
else:
logger.warning(
f'User not found in Keycloak when attempting to disable: {user_id}'
)
except Exception as e:
# Log error but don't raise - the caller should handle the blocking regardless
email_str = f', email: {email}' if email else ''
logger.error(
f'Failed to disable Keycloak account for user_id: {user_id}{email_str}: {str(e)}',
exc_info=True,
)
def store_org_token(self, installation_id: int, installation_token: str):
"""Store a GitHub App installation token.

View File

@@ -159,6 +159,7 @@ class SetAuthCookieMiddleware:
'/api/billing/cancel',
'/api/billing/customer-setup-success',
'/api/billing/stripe-webhook',
'/api/email/resend',
'/oauth/device/authorize',
'/oauth/device/token',
)

View File

@@ -14,6 +14,7 @@ from server.auth.constants import (
KEYCLOAK_SERVER_URL_EXT,
ROLE_CHECK_ENABLED,
)
from server.auth.domain_blocker import domain_blocker
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
@@ -145,7 +146,76 @@ async def keycloak_callback(
content={'error': 'Missing user ID or username in response'},
)
email = user_info.get('email')
user_id = user_info['sub']
# Check if email domain is blocked
email = user_info.get('email')
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
)
# Disable the Keycloak account
await token_manager.disable_keycloak_user(user_id, email)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': 'Access denied: Your email domain is not allowed to access this service'
},
)
# Check for duplicate email with + modifier
if email:
try:
has_duplicate = await token_manager.check_duplicate_base_email(
email, user_id
)
if has_duplicate:
logger.warning(
f'Blocked signup attempt for email {email} - duplicate base email found',
extra={'user_id': user_id, 'email': email},
)
# Delete the Keycloak user that was automatically created during OAuth
# This prevents orphaned accounts in Keycloak
# The delete_keycloak_user method already handles all errors internally
deletion_success = await token_manager.delete_keycloak_user(user_id)
if deletion_success:
logger.info(
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
)
else:
logger.warning(
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
f'User may need to be manually cleaned up.'
)
# Redirect to home page with query parameter indicating the issue
home_url = f'{request.base_url}?duplicated_email=true'
return RedirectResponse(home_url, status_code=302)
except Exception as e:
# Log error but allow signup to proceed (fail open)
logger.error(
f'Error checking duplicate email for {email}: {e}',
extra={'user_id': user_id, 'email': email},
)
# Check email verification status
email_verified = user_info.get('email_verified', False)
if not email_verified:
# Send verification email
# Import locally to avoid circular import with email.py
from server.routes.email import verify_email
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
redirect_url = (
f'{request.base_url}?email_verification_required=true&user_id={user_id}'
)
response = RedirectResponse(redirect_url, status_code=302)
return response
# default to github IDP for now.
# TODO: remove default once Keycloak is updated universally with the new attribute.
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)

View File

@@ -111,10 +111,24 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
if not stripe_service.STRIPE_API_KEY:
return GetCreditsResponse()
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
user_json = await _get_litellm_user(client, user_id)
credits = calculate_credits(user_json['user_info'])
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
try:
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
user_json = await _get_litellm_user(client, user_id)
credits = calculate_credits(user_json['user_info'])
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
except httpx.HTTPStatusError as e:
logger.error(
f'litellm_get_user_failed: {type(e).__name__}: {e}',
extra={
'user_id': user_id,
'status_code': e.response.status_code,
},
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve credit balance from billing service',
)
# Endpoint to retrieve user's current subscription access

View File

@@ -7,6 +7,7 @@ from server.auth.constants import KEYCLOAK_CLIENT_ID
from server.auth.keycloak_manager import get_keycloak_admin
from server.auth.saas_user_auth import SaasUserAuth
from server.routes.auth import set_response_cookie
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
@@ -28,6 +29,11 @@ class EmailUpdate(BaseModel):
return v
class ResendEmailVerificationRequest(BaseModel):
user_id: str | None = None
is_auth_flow: bool = False
@api_router.post('')
async def update_email(
email_data: EmailUpdate, request: Request, user_id: str = Depends(get_user_id)
@@ -74,7 +80,7 @@ async def update_email(
accepted_tos=user_auth.accepted_tos,
)
await _verify_email(request=request, user_id=user_id)
await verify_email(request=request, user_id=user_id)
logger.info(f'Updating email address for {user_id} to {email}')
return response
@@ -90,9 +96,41 @@ async def update_email(
)
@api_router.put('/verify')
async def verify_email(request: Request, user_id: str = Depends(get_user_id)):
await _verify_email(request=request, user_id=user_id)
@api_router.put('/resend')
async def resend_email_verification(
request: Request,
body: ResendEmailVerificationRequest | None = None,
):
# Get user_id from body if provided, otherwise from auth
user_id: str | None = None
if body and body.user_id:
user_id = body.user_id
else:
try:
user_id = await get_user_id(request)
except Exception:
pass
if not user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='user_id is required in request body or user must be authenticated',
)
# Check rate limit (uses user_id if available, otherwise falls back to IP)
# Use 30 seconds for user-based rate limiting to match frontend cooldown
await check_rate_limit_by_user_id(
request=request,
key_prefix='email_resend',
user_id=user_id,
user_rate_limit_seconds=30,
ip_rate_limit_seconds=60, # 1 minute for IP-based limiting (more lenient)
)
# Get is_auth_flow from body if provided, default to False
is_auth_flow = body.is_auth_flow if body else False
await verify_email(request=request, user_id=user_id, is_auth_flow=is_auth_flow)
logger.info(f'Resending verification email for {user_id}')
return JSONResponse(
@@ -124,10 +162,13 @@ async def verified_email(request: Request):
return response
async def _verify_email(request: Request, user_id: str):
async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False):
keycloak_admin = get_keycloak_admin()
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
if is_auth_flow:
redirect_uri = f'{scheme}://{request.url.netloc}?email_verified=true'
else:
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
logger.info(f'Redirect URI: {redirect_uri}')
await keycloak_admin.a_send_verify_email(
user_id=user_id,

View File

@@ -134,12 +134,12 @@ async def _process_batch_operations_background(
)
except Exception as e:
logger.error(
'error_processing_batch_operation',
f'error_processing_batch_operation: {type(e).__name__}: {e}',
extra={
'path': batch_op.path,
'method': str(batch_op.method),
'error': str(e),
},
exc_info=True,
)

View File

@@ -804,6 +804,8 @@ class SaasNestedConversationManager(ConversationManager):
env_vars['ENABLE_V1'] = '0'
env_vars['SU_TO_USER'] = SU_TO_USER
env_vars['DISABLE_VSCODE_PLUGIN'] = str(DISABLE_VSCODE_PLUGIN).lower()
env_vars['BROWSERGYM_DOWNLOAD_DIR'] = '/workspace/.downloads/'
env_vars['PLAYWRIGHT_BROWSERS_PATH'] = '/opt/playwright-browsers'
# We need this for LLM traces tracking to identify the source of the LLM calls
env_vars['WEB_HOST'] = WEB_HOST

View File

@@ -0,0 +1,20 @@
# Sharing Package
This package contains functionality for sharing conversations.
## Components
- **shared.py**: Data models for shared conversations
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
- **shared_event_service.py**: Service interface for accessing shared events
- **shared_event_service_impl.py**: Implementation of the shared event service
- **shared_conversation_router.py**: REST API endpoints for shared conversations
- **shared_event_router.py**: REST API endpoints for shared events
## Features
- Read-only access to shared conversations
- Event access for shared conversations
- Search and filtering capabilities
- Pagination support

View File

@@ -0,0 +1,142 @@
"""Implementation of SharedEventService.
This implementation provides read-only access to events from shared conversations:
- Validates that the conversation is shared before returning events
- Uses existing EventService for actual event retrieval
- Uses SharedConversationInfoService for shared conversation validation
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_event_service import (
SharedEventService,
SharedEventServiceInjector,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import InjectorState
from openhands.sdk import Event
logger = logging.getLogger(__name__)
@dataclass
class SharedEventServiceImpl(SharedEventService):
"""Implementation of SharedEventService that validates shared access."""
shared_conversation_info_service: SharedConversationInfoService
event_service: EventService
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
return None
# If conversation is shared, get the event
return await self.event_service.get_event(event_id)
async def search_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
# Return empty page if conversation is not shared
return EventPage(items=[], next_page_id=None)
# If conversation is shared, search events for this conversation
return await self.event_service.search_events(
conversation_id__eq=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
async def count_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
return 0
# If conversation is shared, count events for this conversation
return await self.event_service.count_events(
conversation_id__eq=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
class SharedEventServiceImplInjector(SharedEventServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedEventService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import (
get_db_session,
get_event_service,
)
async with (
get_db_session(state, request) as db_session,
get_event_service(state, request) as event_service,
):
shared_conversation_info_service = SQLSharedConversationInfoService(
db_session=db_session
)
service = SharedEventServiceImpl(
shared_conversation_info_service=shared_conversation_info_service,
event_service=event_service,
)
yield service

View File

@@ -0,0 +1,66 @@
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from openhands.app_server.services.injector import Injector
from openhands.sdk.utils.models import DiscriminatedUnionMixin
class SharedConversationInfoService(ABC):
"""Service for accessing shared conversation info without user restrictions."""
@abstractmethod
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
@abstractmethod
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations."""
@abstractmethod
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single shared conversation info, returning None if missing or not shared."""
async def batch_get_shared_conversation_info(
self, conversation_ids: list[UUID]
) -> list[SharedConversation | None]:
"""Get a batch of shared conversation info, return None for any missing or non-shared."""
return await asyncio.gather(
*[
self.get_shared_conversation_info(conversation_id)
for conversation_id in conversation_ids
]
)
class SharedConversationInfoServiceInjector(
DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
):
pass

View File

@@ -0,0 +1,56 @@
from datetime import datetime
from enum import Enum
# Simplified imports to avoid dependency chain issues
# from openhands.integrations.service_types import ProviderType
# from openhands.sdk.llm import MetricsSnapshot
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
# For now, use Any to avoid import issues
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from openhands.agent_server.utils import OpenHandsUUID, utc_now
ProviderType = Any
MetricsSnapshot = Any
ConversationTrigger = Any
class SharedConversation(BaseModel):
"""Shared conversation info model with all fields from AppConversationInfo."""
id: OpenHandsUUID = Field(default_factory=uuid4)
created_by_user_id: str | None
sandbox_id: str
selected_repository: str | None = None
selected_branch: str | None = None
git_provider: ProviderType | None = None
title: str | None = None
pr_number: list[int] = Field(default_factory=list)
llm_model: str | None = None
metrics: MetricsSnapshot | None = None
parent_conversation_id: OpenHandsUUID | None = None
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
class SharedConversationSortOrder(Enum):
CREATED_AT = 'CREATED_AT'
CREATED_AT_DESC = 'CREATED_AT_DESC'
UPDATED_AT = 'UPDATED_AT'
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
TITLE = 'TITLE'
TITLE_DESC = 'TITLE_DESC'
class SharedConversationPage(BaseModel):
items: list[SharedConversation]
next_page_id: str | None = None

View File

@@ -0,0 +1,135 @@
"""Shared Conversation router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoServiceInjector,
)
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
shared_conversation_info_service_dependency = Depends(
SQLSharedConversationInfoServiceInjector().depends
)
# Read methods
@router.get('/search')
async def search_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
sort_order: Annotated[
SharedConversationSortOrder,
Query(title='Sort order for results'),
] = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(
title='The max number of results in the page',
gt=0,
lte=100,
),
] = 100,
include_sub_conversations: Annotated[
bool,
Query(
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
),
] = False,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> SharedConversationPage:
"""Search / List shared conversations."""
assert limit > 0
assert limit <= 100
return await shared_conversation_service.search_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
include_sub_conversations=include_sub_conversations,
)
@router.get('/count')
async def count_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> int:
"""Count shared conversations matching the given filters."""
return await shared_conversation_service.count_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
@router.get('')
async def batch_get_shared_conversations(
ids: Annotated[list[str], Query()],
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> list[SharedConversation | None]:
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
assert len(ids) <= 100
uuids = [UUID(id_) for id_ in ids]
shared_conversation_info = (
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
)
return shared_conversation_info

View File

@@ -0,0 +1,126 @@
"""Shared Event router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImplInjector,
)
from server.sharing.shared_event_service import SharedEventService
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
# Read methods
@router.get('/search')
async def search_shared_events(
conversation_id: Annotated[
str,
Query(title='Conversation ID to search events for'),
],
kind__eq: Annotated[
EventKind | None,
Query(title='Optional filter by event kind'),
] = None,
timestamp__gte: Annotated[
datetime | None,
Query(title='Optional filter by timestamp greater than or equal to'),
] = None,
timestamp__lt: Annotated[
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
] = 100,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> EventPage:
"""Search / List events for a shared conversation."""
assert limit > 0
assert limit <= 100
return await shared_event_service.search_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
@router.get('/count')
async def count_shared_events(
conversation_id: Annotated[
str,
Query(title='Conversation ID to count events for'),
],
kind__eq: Annotated[
EventKind | None,
Query(title='Optional filter by event kind'),
] = None,
timestamp__gte: Annotated[
datetime | None,
Query(title='Optional filter by timestamp greater than or equal to'),
] = None,
timestamp__lt: Annotated[
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> int:
"""Count events for a shared conversation matching the given filters."""
return await shared_event_service.count_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
@router.get('')
async def batch_get_shared_events(
conversation_id: Annotated[
UUID,
Query(title='Conversation ID to get events for'),
],
id: Annotated[list[str], Query()],
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> list[Event | None]:
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
assert len(id) <= 100
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
return events
@router.get('/{conversation_id}/{event_id}')
async def get_shared_event(
conversation_id: UUID,
event_id: str,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> Event | None:
"""Get a single event from a shared conversation by conversation_id and event_id."""
return await shared_event_service.get_shared_event(conversation_id, event_id)

View File

@@ -0,0 +1,64 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import Injector
from openhands.sdk import Event
from openhands.sdk.utils.models import DiscriminatedUnionMixin
_logger = logging.getLogger(__name__)
class SharedEventService(ABC):
"""Event Service for getting events from shared conversations only."""
@abstractmethod
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
@abstractmethod
async def search_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search events for a specific shared conversation."""
@abstractmethod
async def count_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
async def batch_get_shared_events(
self, conversation_id: UUID, event_ids: list[str]
) -> list[Event | None]:
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
return await asyncio.gather(
*[
self.get_shared_event(conversation_id, event_id)
for event_id in event_ids
]
)
class SharedEventServiceInjector(
DiscriminatedUnionMixin, Injector[SharedEventService], ABC
):
pass

View File

@@ -0,0 +1,282 @@
"""SQL implementation of SharedConversationInfoService.
This implementation provides read-only access to shared conversations:
- Direct database access without user permission checks
- Filters only conversations marked as shared (currently public)
- Full async/await support using SQL async db_sessions
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
SharedConversationInfoServiceInjector,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from openhands.app_server.services.injector import InjectorState
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
logger = logging.getLogger(__name__)
@dataclass
class SQLSharedConversationInfoService(SharedConversationInfoService):
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
db_session: AsyncSession
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
query = self._public_select()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
# Exclude sub-conversations (only include top-level conversations)
query = query.where(
StoredConversationMetadata.parent_conversation_id.is_(None)
)
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
# Add sort order
if sort_order == SharedConversationSortOrder.CREATED_AT:
query = query.order_by(StoredConversationMetadata.created_at)
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.created_at.desc())
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
query = query.order_by(StoredConversationMetadata.last_updated_at)
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
elif sort_order == SharedConversationSortOrder.TITLE:
query = query.order_by(StoredConversationMetadata.title)
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
query = query.order_by(StoredConversationMetadata.title.desc())
# Apply pagination
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Apply limit and get one extra to check if there are more results
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.scalars().all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [self._to_shared_conversation(row) for row in rows]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
return SharedConversationPage(items=items, next_page_id=next_page_id)
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations matching the given filters."""
from sqlalchemy import func
query = select(func.count(StoredConversationMetadata.conversation_id))
# Only include shared conversations
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
result = await self.db_session.execute(query)
return result.scalar() or 0
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single public conversation info, returning None if missing or not shared."""
query = self._public_select().where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result = await self.db_session.execute(query)
stored = result.scalar_one_or_none()
if stored is None:
return None
return self._to_shared_conversation(stored)
def _public_select(self):
"""Create a select query that only returns public conversations."""
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'
)
# Only include conversations marked as public
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
return query
def _apply_filters(
self,
query,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
):
"""Apply common filters to a query."""
if title__contains is not None:
query = query.where(
StoredConversationMetadata.title.contains(title__contains)
)
if created_at__gte is not None:
query = query.where(
StoredConversationMetadata.created_at >= created_at__gte
)
if created_at__lt is not None:
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
if updated_at__gte is not None:
query = query.where(
StoredConversationMetadata.last_updated_at >= updated_at__gte
)
if updated_at__lt is not None:
query = query.where(
StoredConversationMetadata.last_updated_at < updated_at__lt
)
return query
def _to_shared_conversation(
self,
stored: StoredConversationMetadata,
sub_conversation_ids: list[UUID] | None = None,
) -> SharedConversation:
"""Convert StoredConversationMetadata to SharedConversation."""
# V1 conversations should always have a sandbox_id
sandbox_id = stored.sandbox_id
assert sandbox_id is not None
# Rebuild token usage
token_usage = TokenUsage(
prompt_tokens=stored.prompt_tokens,
completion_tokens=stored.completion_tokens,
cache_read_tokens=stored.cache_read_tokens,
cache_write_tokens=stored.cache_write_tokens,
context_window=stored.context_window,
per_turn_token=stored.per_turn_token,
)
# Rebuild metrics object
metrics = MetricsSnapshot(
accumulated_cost=stored.accumulated_cost,
max_budget_per_task=stored.max_budget_per_task,
accumulated_token_usage=token_usage,
)
# Get timestamps
created_at = self._fix_timezone(stored.created_at)
updated_at = self._fix_timezone(stored.last_updated_at)
return SharedConversation(
id=UUID(stored.conversation_id),
created_by_user_id=stored.user_id if stored.user_id else None,
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,
git_provider=(
ProviderType(stored.git_provider) if stored.git_provider else None
),
title=stored.title,
pr_number=stored.pr_number,
llm_model=stored.llm_model,
metrics=metrics,
parent_conversation_id=(
UUID(stored.parent_conversation_id)
if stored.parent_conversation_id
else None
),
sub_conversation_ids=sub_conversation_ids or [],
created_at=created_at,
updated_at=updated_at,
)
def _fix_timezone(self, value: datetime) -> datetime:
"""Sqlite does not store timezones - and since we can't update the existing models
we assume UTC if the timezone is missing."""
if not value.tzinfo:
value = value.replace(tzinfo=UTC)
return value
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedConversationInfoService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import get_db_session
async with get_db_session(state, request) as db_session:
service = SQLSharedConversationInfoService(db_session=db_session)
yield service

View File

@@ -0,0 +1,83 @@
from fastapi import HTTPException, Request, status
from openhands.core.logger import openhands_logger as logger
from openhands.server.shared import sio
# Rate limiting constants
RATE_LIMIT_USER_SECONDS = 120 # 2 minutes per user_id
RATE_LIMIT_IP_SECONDS = 300 # 5 minutes per IP address
async def check_rate_limit_by_user_id(
request: Request,
key_prefix: str,
user_id: str | None,
user_rate_limit_seconds: int = RATE_LIMIT_USER_SECONDS,
ip_rate_limit_seconds: int = RATE_LIMIT_IP_SECONDS,
) -> None:
"""
Check rate limit for requests, using user_id when available, falling back to IP address.
Uses Redis to store rate limit keys with expiration. If a key already exists,
it means the rate limit is active and the request will be rejected.
Args:
request: FastAPI Request object
key_prefix: Prefix for the Redis key (e.g., "email_resend")
user_id: User ID if available, None otherwise
user_rate_limit_seconds: Rate limit window in seconds for user_id-based limiting (default: 120)
ip_rate_limit_seconds: Rate limit window in seconds for IP-based limiting (default: 300)
Raises:
HTTPException: If rate limit is exceeded (429 status code)
"""
try:
redis = sio.manager.redis
if not redis:
# If Redis is unavailable, log warning and allow request (fail open)
logger.warning('Redis unavailable for rate limiting, allowing request')
return
if user_id:
# Rate limit by user_id (primary method)
rate_limit_key = f'{key_prefix}:{user_id}'
rate_limit_seconds = user_rate_limit_seconds
else:
# Fallback to IP address rate limiting
client_ip = request.client.host if request.client else 'unknown'
rate_limit_key = f'{key_prefix}:ip:{client_ip}'
rate_limit_seconds = ip_rate_limit_seconds
# Try to set the key with expiration. If it already exists (nx=True fails),
# it means the rate limit is active
created = await redis.set(rate_limit_key, 1, nx=True, ex=rate_limit_seconds)
if not created:
logger.info(
f'Rate limit exceeded for {rate_limit_key}',
extra={
'user_id': user_id,
'ip': request.client.host if request.client else 'unknown',
},
)
# Format error message based on duration
if rate_limit_seconds < 60:
wait_message = f'{rate_limit_seconds} seconds'
elif rate_limit_seconds % 60 == 0:
wait_message = f'{rate_limit_seconds // 60} minute{"s" if rate_limit_seconds // 60 != 1 else ""}'
else:
minutes = rate_limit_seconds // 60
seconds = rate_limit_seconds % 60
wait_message = f'{minutes} minute{"s" if minutes != 1 else ""} and {seconds} second{"s" if seconds != 1 else ""}'
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f'Too many requests. Please wait {wait_message} before trying again.',
)
except HTTPException:
# Re-raise HTTPException (rate limit exceeded)
raise
except Exception as e:
# Log error but allow request (fail open) to avoid blocking legitimate users
logger.warning(f'Error checking rate limit: {e}', exc_info=True)
return

View File

@@ -19,17 +19,23 @@ GCP_REGION = os.environ.get('GCP_REGION')
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
# Initialize Cloud SQL Connector once at module level for GCP environments.
_connector = None
def _get_db_engine():
if GCP_DB_INSTANCE: # GCP environments
def get_db_connection():
global _connector
from google.cloud.sql.connector import Connector
connector = Connector()
if not _connector:
_connector = Connector()
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
return connector.connect(
return _connector.connect(
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
)
@@ -38,6 +44,7 @@ def _get_db_engine():
creator=get_db_connection,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)
else:
@@ -48,6 +55,7 @@ def _get_db_engine():
host_string,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)

View File

@@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore):
kwargs.pop('context_window', None)
kwargs.pop('per_turn_token', None)
kwargs.pop('parent_conversation_id', None)
kwargs.pop('public')
return ConversationMetadata(**kwargs)

View File

@@ -19,6 +19,7 @@ from server.constants import (
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
REQUIRE_PAYMENT,
USER_SETTINGS_VERSION_TO_MODEL,
get_default_litellm_model,
)
from server.logger import logger
@@ -202,6 +203,53 @@ class SaasSettingsStore(SettingsStore):
)
return None
def _has_custom_settings(
self, settings: Settings, old_user_version: int | None
) -> bool:
"""
Check if user has custom LLM settings that should be preserved.
Returns True if user customized either model or base_url.
Args:
settings: The user's current settings
old_user_version: The user's old settings version, if any
Returns:
True if user has custom settings, False if using old defaults
"""
# Normalize values
user_model = (
settings.llm_model.strip()
if settings.llm_model and settings.llm_model.strip()
else None
)
user_base_url = (
settings.llm_base_url.strip()
if settings.llm_base_url and settings.llm_base_url.strip()
else None
)
# Custom base_url = definitely custom settings (BYOK)
if user_base_url and user_base_url != LITE_LLM_API_URL:
return True
# No model set = using defaults
if not user_model:
return False
# Check if model matches old version's default
if (
old_user_version
and old_user_version < CURRENT_USER_SETTINGS_VERSION
and old_user_version in USER_SETTINGS_VERSION_TO_MODEL
):
old_default_base = USER_SETTINGS_VERSION_TO_MODEL[old_user_version]
user_model_base = user_model.split('/')[-1]
if user_model_base == old_default_base:
return False # Matches old default
return True # Custom model
async def update_settings_with_litellm_default(
self, settings: Settings
) -> Settings | None:
@@ -213,6 +261,17 @@ class SaasSettingsStore(SettingsStore):
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
key = LITE_LLM_API_KEY
# Check if user has custom settings
has_custom = self._has_custom_settings(settings, settings.user_version)
# Determine model to use (needed before LiteLLM user creation)
llm_model_to_use = (
settings.llm_model
if has_custom and settings.llm_model
else get_default_litellm_model()
)
if not local_deploy:
# Get user info to add to litellm
token_manager = TokenManager()
@@ -276,7 +335,7 @@ class SaasSettingsStore(SettingsStore):
# Create the new litellm user
response = await self._create_user_in_lite_llm(
client, email, max_budget, spend
client, email, max_budget, spend, llm_model_to_use
)
if not response.is_success:
logger.warning(
@@ -285,7 +344,7 @@ class SaasSettingsStore(SettingsStore):
)
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
response = await self._create_user_in_lite_llm(
client, None, max_budget, spend
client, None, max_budget, spend, llm_model_to_use
)
# User failed to create in litellm - this is an unforseen error state...
@@ -311,11 +370,17 @@ class SaasSettingsStore(SettingsStore):
extra={'user_id': self.user_id},
)
if has_custom:
settings.llm_model = settings.llm_model or get_default_litellm_model()
settings.llm_base_url = settings.llm_base_url or LITE_LLM_API_URL
settings.llm_api_key = settings.llm_api_key or SecretStr(key)
else:
settings.llm_model = get_default_litellm_model()
settings.llm_base_url = LITE_LLM_API_URL
settings.llm_api_key = SecretStr(key)
settings.agent = 'CodeActAgent'
# Use the model corresponding to the current user settings version
settings.llm_model = get_default_litellm_model()
settings.llm_api_key = SecretStr(key)
settings.llm_base_url = LITE_LLM_API_URL
return settings
@classmethod
@@ -398,7 +463,12 @@ class SaasSettingsStore(SettingsStore):
)
async def _create_user_in_lite_llm(
self, client: httpx.AsyncClient, email: str | None, max_budget: int, spend: int
self,
client: httpx.AsyncClient,
email: str | None,
max_budget: int,
spend: int,
llm_model: str,
):
response = await client.post(
f'{LITE_LLM_API_URL}/user/new',
@@ -413,7 +483,7 @@ class SaasSettingsStore(SettingsStore):
'send_invite_email': False,
'metadata': {
'version': CURRENT_USER_SETTINGS_VERSION,
'model': get_default_litellm_model(),
'model': llm_model,
},
'key_alias': f'OpenHands Cloud - user {self.user_id}',
},

View File

@@ -4,6 +4,8 @@ from uuid import uuid4
from integrations.types import GitLabResourceType
from integrations.utils import GITLAB_WEBHOOK_URL
from sqlalchemy import text
from storage.database import a_session_maker
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
from storage.gitlab_webhook_store import GitlabWebhookStore
@@ -258,6 +260,25 @@ class VerifyWebhookStatus:
from integrations.gitlab.gitlab_service import SaaSGitLabService
# Check if the table exists before proceeding
# This handles cases where the CronJob runs before database migrations complete
async with a_session_maker() as session:
query = text("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'gitlab_webhook'
)
""")
result = await session.execute(query)
table_exists = result.scalar() or False
if not table_exists:
logger.info(
'gitlab_webhook table does not exist yet, '
'waiting for database migrations to complete'
)
return
# Get an instance of the webhook store
webhook_store = await GitlabWebhookStore.get_instance()

View File

@@ -1,7 +1,12 @@
"""Tests for enterprise integrations utils module."""
from unittest.mock import patch
import pytest
from integrations.utils import get_summary_for_agent_state
from integrations.utils import (
append_conversation_footer,
get_summary_for_agent_state,
)
from openhands.core.schema.agent import AgentState
from openhands.events.observation.agent import AgentStateChangedObservation
@@ -157,3 +162,138 @@ class TestGetSummaryForAgentState:
assert 'try again later' in result.lower()
# RATE_LIMITED doesn't include conversation link in response
assert self.conversation_link not in result
class TestAppendConversationFooter:
"""Test cases for append_conversation_footer function."""
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_appends_footer_with_markdown_link(self):
"""Test that footer is appended with correct markdown link format."""
# Arrange
message = 'This is a test message'
conversation_id = 'test-conv-123'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert result.startswith(message)
assert (
'[View full conversation](https://example.com/conversations/test-conv-123)'
in result
)
assert result.endswith(
'[View full conversation](https://example.com/conversations/test-conv-123)'
)
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_footer_does_not_contain_html_tags(self):
"""Test that footer does not contain HTML tags like <sub>."""
# Arrange
message = 'Test message'
conversation_id = 'test-conv-456'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert '<sub>' not in result
assert '</sub>' not in result
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_footer_format_with_newlines(self):
"""Test that footer is properly separated with newlines."""
# Arrange
message = 'Original message content'
conversation_id = 'test-conv-789'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert (
result
== 'Original message content\n\n[View full conversation](https://example.com/conversations/test-conv-789)'
)
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_empty_message_still_appends_footer(self):
"""Test that footer is appended even when message is empty."""
# Arrange
message = ''
conversation_id = 'empty-msg-conv'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert result.startswith('\n\n')
assert (
'[View full conversation](https://example.com/conversations/empty-msg-conv)'
in result
)
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_conversation_id_with_special_characters(self):
"""Test that footer handles conversation IDs with special characters."""
# Arrange
message = 'Test message'
conversation_id = 'conv-123_abc-456'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
expected_url = 'https://example.com/conversations/conv-123_abc-456'
assert expected_url in result
assert '[View full conversation]' in result
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_multiline_message_preserves_content(self):
"""Test that multiline messages are preserved correctly."""
# Arrange
message = 'Line 1\nLine 2\nLine 3'
conversation_id = 'multiline-conv'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert result.startswith('Line 1\nLine 2\nLine 3')
assert '\n\n[View full conversation]' in result
assert message in result
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_footer_contains_only_markdown_syntax(self):
"""Test that footer uses only markdown syntax, not HTML."""
# Arrange
message = 'Test message'
conversation_id = 'markdown-test'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
footer_part = result[len(message) :]
# Should only contain markdown link syntax: [text](url)
assert footer_part.startswith('\n\n[')
assert '](' in footer_part
assert footer_part.endswith(')')
# Should not contain any HTML tags (specifically <sub> tags that were removed)
assert '<sub>' not in footer_part
assert '</sub>' not in footer_part

View File

@@ -0,0 +1,361 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import SecretStr
from server.auth.saas_user_auth import SaasUserAuth
from server.routes.email import (
ResendEmailVerificationRequest,
resend_email_verification,
verified_email,
verify_email,
)
@pytest.fixture
def mock_request():
"""Create a mock request object."""
request = MagicMock(spec=Request)
request.url = MagicMock()
request.url.hostname = 'localhost'
request.url.netloc = 'localhost:8000'
request.url.path = '/api/email/verified'
request.base_url = 'http://localhost:8000/'
request.headers = {}
request.cookies = {}
request.query_params = MagicMock()
return request
@pytest.fixture
def mock_user_auth():
"""Create a mock SaasUserAuth object."""
auth = MagicMock(spec=SaasUserAuth)
auth.access_token = SecretStr('test_access_token')
auth.refresh_token = SecretStr('test_refresh_token')
auth.email = 'test@example.com'
auth.email_verified = False
auth.accepted_tos = True
auth.refresh = AsyncMock()
return auth
@pytest.mark.asyncio
async def test_verify_email_default_behavior(mock_request):
"""Test verify_email with default is_auth_flow=False."""
# Arrange
user_id = 'test_user_id'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
# Act
with patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
):
await verify_email(request=mock_request, user_id=user_id)
# Assert
mock_keycloak_admin.a_send_verify_email.assert_called_once()
call_args = mock_keycloak_admin.a_send_verify_email.call_args
assert call_args.kwargs['user_id'] == user_id
assert (
call_args.kwargs['redirect_uri'] == 'http://localhost:8000/api/email/verified'
)
assert 'client_id' in call_args.kwargs
@pytest.mark.asyncio
async def test_verify_email_with_auth_flow(mock_request):
"""Test verify_email with is_auth_flow=True."""
# Arrange
user_id = 'test_user_id'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
# Act
with patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
):
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
# Assert
mock_keycloak_admin.a_send_verify_email.assert_called_once()
call_args = mock_keycloak_admin.a_send_verify_email.call_args
assert call_args.kwargs['user_id'] == user_id
assert (
call_args.kwargs['redirect_uri'] == 'http://localhost:8000?email_verified=true'
)
assert 'client_id' in call_args.kwargs
@pytest.mark.asyncio
async def test_verify_email_https_scheme(mock_request):
"""Test verify_email uses https scheme for non-localhost hosts."""
# Arrange
user_id = 'test_user_id'
mock_request.url.hostname = 'example.com'
mock_request.url.netloc = 'example.com'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
# Act
with patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
):
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
# Assert
call_args = mock_keycloak_admin.a_send_verify_email.call_args
assert call_args.kwargs['redirect_uri'].startswith('https://')
@pytest.mark.asyncio
async def test_verified_email_default_redirect(mock_request, mock_user_auth):
"""Test verified_email redirects to /settings/user by default."""
# Arrange
mock_request.query_params.get.return_value = None
# Act
with (
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
):
result = await verified_email(mock_request)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert result.headers['location'] == 'http://localhost:8000/settings/user'
mock_user_auth.refresh.assert_called_once()
mock_set_cookie.assert_called_once()
assert mock_user_auth.email_verified is True
@pytest.mark.asyncio
async def test_verified_email_https_scheme(mock_request, mock_user_auth):
"""Test verified_email uses https scheme for non-localhost hosts."""
# Arrange
mock_request.url.hostname = 'example.com'
mock_request.url.netloc = 'example.com'
mock_request.query_params.get.return_value = None
# Act
with (
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
):
result = await verified_email(mock_request)
# Assert
assert isinstance(result, RedirectResponse)
assert result.headers['location'].startswith('https://')
mock_set_cookie.assert_called_once()
# Verify secure flag is True for https
call_kwargs = mock_set_cookie.call_args.kwargs
assert call_kwargs['secure'] is True
@pytest.mark.asyncio
async def test_resend_email_verification_with_user_id_from_body_succeeds(mock_request):
"""Test resend_email_verification succeeds when user_id is provided in body."""
# Arrange
user_id = 'test_user_id'
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=False)
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
patch('server.routes.email.logger') as mock_logger,
):
mock_rate_limit.return_value = None # Rate limit check passes
# Act
result = await resend_email_verification(request=mock_request, body=body)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_200_OK
assert 'message' in result.body.decode()
mock_rate_limit.assert_called_once_with(
request=mock_request,
key_prefix='email_resend',
user_id=user_id,
user_rate_limit_seconds=30,
ip_rate_limit_seconds=60,
)
mock_keycloak_admin.a_send_verify_email.assert_called_once()
# Logger is called multiple times (verify_email and resend_email_verification)
# Check that the resend message was logged
assert any(
'Resending verification email for' in str(call)
for call in mock_logger.info.call_args_list
)
@pytest.mark.asyncio
async def test_resend_email_verification_with_user_id_from_auth_succeeds(mock_request):
"""Test resend_email_verification succeeds when user_id comes from authentication."""
# Arrange
user_id = 'test_user_id'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch(
'server.routes.email.get_user_id', return_value=user_id
) as mock_get_user_id,
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
):
mock_rate_limit.return_value = None # Rate limit check passes
# Act
result = await resend_email_verification(request=mock_request, body=None)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_200_OK
mock_get_user_id.assert_called_once_with(mock_request)
mock_rate_limit.assert_called_once_with(
request=mock_request,
key_prefix='email_resend',
user_id=user_id,
user_rate_limit_seconds=30,
ip_rate_limit_seconds=60,
)
@pytest.mark.asyncio
async def test_resend_email_verification_without_user_id_returns_400(mock_request):
"""Test resend_email_verification returns 400 when user_id is not available."""
# Arrange
with patch(
'server.routes.email.get_user_id', side_effect=Exception('Not authenticated')
):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await resend_email_verification(request=mock_request, body=None)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
assert 'user_id is required' in exc_info.value.detail
@pytest.mark.asyncio
async def test_resend_email_verification_rate_limit_exceeded_returns_429(mock_request):
"""Test resend_email_verification returns 429 when rate limit is exceeded."""
# Arrange
user_id = 'test_user_id'
body = ResendEmailVerificationRequest(user_id=user_id)
with (
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
):
mock_rate_limit.side_effect = HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail='Too many requests. Please wait 2 minutes before trying again.',
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await resend_email_verification(request=mock_request, body=body)
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
assert 'Too many requests' in exc_info.value.detail
mock_rate_limit.assert_called_once()
@pytest.mark.asyncio
async def test_resend_email_verification_with_is_auth_flow_true(mock_request):
"""Test resend_email_verification passes is_auth_flow to verify_email."""
# Arrange
user_id = 'test_user_id'
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=True)
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
):
mock_rate_limit.return_value = None
# Act
await resend_email_verification(request=mock_request, body=body)
# Assert
mock_keycloak_admin.a_send_verify_email.assert_called_once()
call_args = mock_keycloak_admin.a_send_verify_email.call_args
# Verify that verify_email was called with is_auth_flow=True
# We check this indirectly by verifying the redirect_uri
assert 'email_verified=true' in call_args.kwargs['redirect_uri']
@pytest.mark.asyncio
async def test_resend_email_verification_with_is_auth_flow_false(mock_request):
"""Test resend_email_verification uses default is_auth_flow=False when not specified."""
# Arrange
user_id = 'test_user_id'
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=False)
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
):
mock_rate_limit.return_value = None
# Act
await resend_email_verification(request=mock_request, body=body)
# Assert
mock_keycloak_admin.a_send_verify_email.assert_called_once()
call_args = mock_keycloak_admin.a_send_verify_email.call_args
# Verify that verify_email was called with is_auth_flow=False
assert '/api/email/verified' in call_args.kwargs['redirect_uri']
@pytest.mark.asyncio
async def test_resend_email_verification_body_none_uses_auth(mock_request):
"""Test resend_email_verification uses auth when body is None."""
# Arrange
user_id = 'test_user_id'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch(
'server.routes.email.get_user_id', return_value=user_id
) as mock_get_user_id,
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
):
mock_rate_limit.return_value = None
# Act
result = await resend_email_verification(request=mock_request, body=None)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_200_OK
mock_get_user_id.assert_called_once()
mock_rate_limit.assert_called_once_with(
request=mock_request,
key_prefix='email_resend',
user_id=user_id,
user_rate_limit_seconds=30,
ip_rate_limit_seconds=60,
)

View File

@@ -699,12 +699,11 @@ class TestProcessBatchOperationsBackground:
# Should not raise exceptions
await _process_batch_operations_background(batch_ops, 'test-api-key')
# Should log the error
mock_logger.error.assert_called_once_with(
'error_processing_batch_operation',
extra={
'path': 'invalid-path',
'method': 'BatchMethod.POST',
'error': mock_logger.error.call_args[1]['extra']['error'],
},
)
# Should log the error with exception type and message in the log message
mock_logger.error.assert_called_once()
call_args = mock_logger.error.call_args
log_message = call_args[0][0]
assert log_message.startswith('error_processing_batch_operation:')
assert call_args[1]['extra']['path'] == 'invalid-path'
assert call_args[1]['extra']['method'] == 'BatchMethod.POST'
assert call_args[1]['exc_info'] is True

View File

@@ -0,0 +1,290 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request, status
from server.utils.rate_limit_utils import (
RATE_LIMIT_IP_SECONDS,
RATE_LIMIT_USER_SECONDS,
check_rate_limit_by_user_id,
)
@pytest.fixture
def mock_request():
"""Create a mock request object."""
request = MagicMock(spec=Request)
request.client = MagicMock()
request.client.host = '192.168.1.1'
return request
@pytest.fixture
def mock_redis():
"""Create a mock Redis client."""
redis = AsyncMock()
redis.set = AsyncMock(return_value=True) # First call succeeds (key doesn't exist)
return redis
@pytest.mark.asyncio
async def test_rate_limit_by_user_id_first_request_succeeds(mock_request, mock_redis):
"""Test that first request with user_id succeeds and sets rate limit key."""
# Arrange
user_id = 'test_user_id'
key_prefix = 'email_resend'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:{user_id}', 1, nx=True, ex=RATE_LIMIT_USER_SECONDS
)
mock_logger.warning.assert_not_called()
mock_logger.info.assert_not_called()
@pytest.mark.asyncio
async def test_rate_limit_by_user_id_second_request_within_window_fails(
mock_request, mock_redis
):
"""Test that second request with same user_id within rate limit window fails."""
# Arrange
user_id = 'test_user_id'
key_prefix = 'email_resend'
mock_redis.set = AsyncMock(return_value=False) # Key already exists
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
assert 'Too many requests' in exc_info.value.detail
assert f'{RATE_LIMIT_USER_SECONDS // 60} minutes' in exc_info.value.detail
mock_logger.info.assert_called_once()
@pytest.mark.asyncio
async def test_rate_limit_by_ip_when_user_id_is_none(mock_request, mock_redis):
"""Test that rate limiting falls back to IP address when user_id is None."""
# Arrange
key_prefix = 'email_resend'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=None
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:ip:{mock_request.client.host}',
1,
nx=True,
ex=RATE_LIMIT_IP_SECONDS,
)
mock_logger.warning.assert_not_called()
@pytest.mark.asyncio
async def test_rate_limit_by_ip_second_request_within_window_fails(
mock_request, mock_redis
):
"""Test that second request from same IP within rate limit window fails."""
# Arrange
key_prefix = 'email_resend'
mock_redis.set = AsyncMock(return_value=False) # Key already exists
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
):
mock_sio.manager.redis = mock_redis
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=None
)
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
assert f'{RATE_LIMIT_IP_SECONDS // 60} minutes' in exc_info.value.detail
@pytest.mark.asyncio
async def test_rate_limit_redis_unavailable_fails_open(mock_request):
"""Test that rate limiting fails open when Redis is unavailable."""
# Arrange
key_prefix = 'email_resend'
user_id = 'test_user_id'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = None # Redis unavailable
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
# Assert
mock_logger.warning.assert_called_once_with(
'Redis unavailable for rate limiting, allowing request'
)
@pytest.mark.asyncio
async def test_rate_limit_redis_exception_fails_open(mock_request, mock_redis):
"""Test that rate limiting fails open when Redis raises an exception."""
# Arrange
key_prefix = 'email_resend'
user_id = 'test_user_id'
mock_redis.set = AsyncMock(side_effect=Exception('Redis connection error'))
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
# Assert
mock_logger.warning.assert_called_once()
assert 'Error checking rate limit' in str(mock_logger.warning.call_args[0][0])
@pytest.mark.asyncio
async def test_rate_limit_custom_key_prefix(mock_request, mock_redis):
"""Test that different key prefixes create different rate limit keys."""
# Arrange
user_id = 'test_user_id'
key_prefix = 'password_reset'
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:{user_id}', 1, nx=True, ex=RATE_LIMIT_USER_SECONDS
)
@pytest.mark.asyncio
async def test_rate_limit_custom_rate_limit_seconds(mock_request, mock_redis):
"""Test that custom rate limit seconds are used correctly."""
# Arrange
user_id = 'test_user_id'
key_prefix = 'email_resend'
custom_user_seconds = 60
custom_ip_seconds = 180
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request,
key_prefix=key_prefix,
user_id=user_id,
user_rate_limit_seconds=custom_user_seconds,
ip_rate_limit_seconds=custom_ip_seconds,
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:{user_id}', 1, nx=True, ex=custom_user_seconds
)
@pytest.mark.asyncio
async def test_rate_limit_ip_with_unknown_client(mock_request, mock_redis):
"""Test that rate limiting handles missing client host gracefully."""
# Arrange
key_prefix = 'email_resend'
mock_request.client = None # No client information
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=None
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:ip:unknown', 1, nx=True, ex=RATE_LIMIT_IP_SECONDS
)
@pytest.mark.asyncio
async def test_rate_limit_different_users_have_separate_limits(
mock_request, mock_redis
):
"""Test that different user_ids have separate rate limit keys."""
# Arrange
key_prefix = 'email_resend'
user_id_1 = 'user_1'
user_id_2 = 'user_2'
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id_1
)
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id_2
)
# Assert
assert mock_redis.set.call_count == 2
# Extract call arguments properly
call_args_list = [
(call[0][0], call[0][1], call[1]['nx'], call[1]['ex'])
for call in mock_redis.set.call_args_list
]
assert (
f'{key_prefix}:{user_id_1}',
1,
True,
RATE_LIMIT_USER_SECONDS,
) in call_args_list
assert (
f'{key_prefix}:{user_id_2}',
1,
True,
RATE_LIMIT_USER_SECONDS,
) in call_args_list

View File

@@ -234,3 +234,53 @@ async def test_middleware_with_other_auth_error(middleware, mock_request):
assert 'set-cookie' in result.headers
# Logger should be called for non-NoCredentialsError
mock_logger.warning.assert_called_once()
@pytest.mark.asyncio
async def test_middleware_ignores_email_resend_path(
middleware, mock_request, mock_response
):
"""Test middleware ignores /api/email/resend path and doesn't require authentication."""
# Arrange
mock_request.cookies = {}
mock_request.url = MagicMock()
mock_request.url.hostname = 'localhost'
mock_request.url.path = '/api/email/resend'
mock_call_next = AsyncMock(return_value=mock_response)
# Act
result = await middleware(mock_request, mock_call_next)
# Assert
assert result == mock_response
mock_call_next.assert_called_once_with(mock_request)
# Should not raise NoCredentialsError even without auth cookie
@pytest.mark.asyncio
async def test_middleware_ignores_email_resend_path_no_tos_check(
middleware, mock_request, mock_response
):
"""Test middleware doesn't check TOS for /api/email/resend path."""
# Arrange
mock_request.cookies = {'keycloak_auth': 'test_cookie'}
mock_request.url = MagicMock()
mock_request.url.hostname = 'localhost'
mock_request.url.path = '/api/email/resend'
mock_call_next = AsyncMock(return_value=mock_response)
with (
patch('server.middleware.jwt.decode') as mock_decode,
patch('server.middleware.config') as mock_config,
):
# Even with accepted_tos=False, should not raise TosNotAcceptedError
mock_decode.return_value = {'accepted_tos': False}
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
# Act
result = await middleware(mock_request, mock_call_next)
# Assert
assert result == mock_response
mock_call_next.assert_called_once_with(mock_request)
# Should not raise TosNotAcceptedError for this path

View File

@@ -136,6 +136,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
@@ -184,6 +185,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
@@ -214,6 +216,84 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
mock_posthog.set.assert_called_once()
@pytest.mark.asyncio
async def test_keycloak_callback_email_not_verified(mock_request):
"""Test keycloak_callback when email is not verified."""
# Arrange
mock_verify_email = AsyncMock()
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.email.verify_email', mock_verify_email),
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': False,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_verifier.is_active.return_value = False
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert 'email_verification_required=true' in result.headers['location']
assert 'user_id=test_user_id' in result.headers['location']
mock_verify_email.assert_called_once_with(
request=mock_request, user_id='test_user_id', is_auth_flow=True
)
@pytest.mark.asyncio
async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
"""Test keycloak_callback when email_verified field is missing (defaults to False)."""
# Arrange
mock_verify_email = AsyncMock()
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.email.verify_email', mock_verify_email),
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
# email_verified field is missing
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_verifier.is_active.return_value = False
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert 'email_verification_required=true' in result.headers['location']
assert 'user_id=test_user_id' in result.headers['location']
mock_verify_email.assert_called_once_with(
request=mock_request, user_id='test_user_id', is_auth_flow=True
)
@pytest.mark.asyncio
async def test_keycloak_callback_success_without_offline_token(mock_request):
"""Test successful keycloak_callback without valid offline token."""
@@ -248,6 +328,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
@@ -442,3 +523,418 @@ async def test_logout_without_refresh_token():
mock_token_manager.logout.assert_not_called()
assert 'set-cookie' in result.headers
@pytest.mark.asyncio
async def test_keycloak_callback_blocked_email_domain(mock_request):
"""Test keycloak_callback when email domain is blocked."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'user@colsch.us',
'identity_provider': 'github',
}
)
mock_token_manager.disable_keycloak_user = AsyncMock()
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_401_UNAUTHORIZED
assert 'error' in result.body.decode()
assert 'email domain is not allowed' in result.body.decode()
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
mock_token_manager.disable_keycloak_user.assert_called_once_with(
'test_user_id', 'user@colsch.us'
)
@pytest.mark.asyncio
async def test_keycloak_callback_allowed_email_domain(mock_request):
"""Test keycloak_callback when email domain is not blocked."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'user@example.com',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = False
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
'user@example.com'
)
mock_token_manager.disable_keycloak_user.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
"""Test keycloak_callback when domain blocking is not active."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'user@colsch.us',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = False
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
mock_domain_blocker.is_domain_blocked.assert_not_called()
mock_token_manager.disable_keycloak_user.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_missing_email(mock_request):
"""Test keycloak_callback when user info does not contain email."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': True,
# No email field
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = True
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
mock_domain_blocker.is_domain_blocked.assert_not_called()
mock_token_manager.disable_keycloak_user.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_duplicate_email_detected(mock_request):
"""Test keycloak_callback when duplicate email is detected."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
):
# Arrange
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'joe+test@example.com',
'identity_provider': 'github',
}
)
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert 'duplicated_email=true' in result.headers['location']
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
'joe+test@example.com', 'test_user_id'
)
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
@pytest.mark.asyncio
async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
"""Test keycloak_callback when duplicate is detected but deletion fails."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
):
# Arrange
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'joe+test@example.com',
'identity_provider': 'github',
}
)
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert 'duplicated_email=true' in result.headers['location']
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
@pytest.mark.asyncio
async def test_keycloak_callback_duplicate_check_exception(mock_request):
"""Test keycloak_callback when duplicate check raises exception."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
# Arrange
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'joe+test@example.com',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.check_duplicate_base_email = AsyncMock(
side_effect=Exception('Check failed')
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
# Should proceed with normal flow despite exception (fail open)
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
@pytest.mark.asyncio
async def test_keycloak_callback_no_duplicate_email(mock_request):
"""Test keycloak_callback when no duplicate email is found."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
# Arrange
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'joe+test@example.com',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
'joe+test@example.com', 'test_user_id'
)
# Should not delete user when no duplicate found
mock_token_manager.delete_keycloak_user.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_no_email_in_user_info(mock_request):
"""Test keycloak_callback when email is not in user_info."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
# Arrange
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
# No email field
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
# Should not check for duplicate when email is missing
mock_token_manager.check_duplicate_base_email.assert_not_called()

View File

@@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import stripe
from fastapi import HTTPException, Request, status
from httpx import HTTPStatusError, Response
from httpx import Response
from integrations.stripe_service import has_payment_method
from server.routes.billing import (
CreateBillingSessionResponse,
@@ -78,8 +78,6 @@ def mock_subscription_request():
@pytest.mark.asyncio
async def test_get_credits_lite_llm_error():
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
mock_response = Response(
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
)
@@ -88,11 +86,12 @@ async def test_get_credits_lite_llm_error():
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
with patch('httpx.AsyncClient', return_value=mock_client):
with pytest.raises(HTTPStatusError) as exc_info:
await get_credits(mock_request)
with pytest.raises(HTTPException) as exc_info:
await get_credits('mock_user')
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert (
exc_info.value.response.status_code
== status.HTTP_500_INTERNAL_SERVER_ERROR
exc_info.value.detail
== 'Failed to retrieve credit balance from billing service'
)

View File

@@ -0,0 +1,181 @@
"""Unit tests for DomainBlocker class."""
import pytest
from server.auth.domain_blocker import DomainBlocker
@pytest.fixture
def domain_blocker():
"""Create a DomainBlocker instance for testing."""
return DomainBlocker()
@pytest.mark.parametrize(
'blocked_domains,expected',
[
(['colsch.us', 'other-domain.com'], True),
(['example.com'], True),
([], False),
],
)
def test_is_active(domain_blocker, blocked_domains, expected):
"""Test that is_active returns correct value based on blocked domains configuration."""
# Arrange
domain_blocker.blocked_domains = blocked_domains
# Act
result = domain_blocker.is_active()
# Assert
assert result == expected
@pytest.mark.parametrize(
'email,expected_domain',
[
('user@example.com', 'example.com'),
('test@colsch.us', 'colsch.us'),
('user.name@other-domain.com', 'other-domain.com'),
('USER@EXAMPLE.COM', 'example.com'), # Case insensitive
('user@EXAMPLE.COM', 'example.com'),
(' user@example.com ', 'example.com'), # Whitespace handling
],
)
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
"""Test that _extract_domain correctly extracts and normalizes domains from valid emails."""
# Act
result = domain_blocker._extract_domain(email)
# Assert
assert result == expected_domain
@pytest.mark.parametrize(
'email,expected',
[
(None, None),
('', None),
('invalid-email', None),
('user@', None), # Empty domain after @
('no-at-sign', None),
],
)
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
"""Test that _extract_domain returns None for invalid email formats."""
# Act
result = domain_blocker._extract_domain(email)
# Assert
assert result == expected
def test_is_domain_blocked_when_inactive(domain_blocker):
"""Test that is_domain_blocked returns False when blocking is not active."""
# Arrange
domain_blocker.blocked_domains = []
# Act
result = domain_blocker.is_domain_blocked('user@colsch.us')
# Assert
assert result is False
def test_is_domain_blocked_with_none_email(domain_blocker):
"""Test that is_domain_blocked returns False when email is None."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked(None)
# Assert
assert result is False
def test_is_domain_blocked_with_empty_email(domain_blocker):
"""Test that is_domain_blocked returns False when email is empty."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked('')
# Assert
assert result is False
def test_is_domain_blocked_with_invalid_email(domain_blocker):
"""Test that is_domain_blocked returns False when email format is invalid."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked('invalid-email')
# Assert
assert result is False
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
# Act
result = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result is False
def test_is_domain_blocked_domain_blocked(domain_blocker):
"""Test that is_domain_blocked returns True when domain is in blocked list."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
# Act
result = domain_blocker.is_domain_blocked('user@colsch.us')
# Assert
assert result is True
def test_is_domain_blocked_case_insensitive(domain_blocker):
"""Test that is_domain_blocked performs case-insensitive domain matching."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
# Assert
assert result is True
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
# Act
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
# Assert
assert result1 is True
assert result2 is True
assert result3 is False
def test_is_domain_blocked_with_whitespace(domain_blocker):
"""Test that is_domain_blocked handles emails with whitespace correctly."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
# Assert
assert result is True

View File

@@ -0,0 +1,294 @@
"""Tests for email validation utilities."""
import re
from server.auth.email_validation import (
extract_base_email,
get_base_email_regex_pattern,
has_plus_modifier,
matches_base_email,
)
class TestExtractBaseEmail:
"""Test cases for extract_base_email function."""
def test_extract_base_email_with_plus_modifier(self):
"""Test extracting base email from email with + modifier."""
# Arrange
email = 'joe+test@example.com'
# Act
result = extract_base_email(email)
# Assert
assert result == 'joe@example.com'
def test_extract_base_email_without_plus_modifier(self):
"""Test that email without + modifier is returned as-is."""
# Arrange
email = 'joe@example.com'
# Act
result = extract_base_email(email)
# Assert
assert result == 'joe@example.com'
def test_extract_base_email_multiple_plus_signs(self):
"""Test extracting base email when multiple + signs exist."""
# Arrange
email = 'joe+openhands+test@example.com'
# Act
result = extract_base_email(email)
# Assert
assert result == 'joe@example.com'
def test_extract_base_email_invalid_no_at_symbol(self):
"""Test that invalid email without @ returns None."""
# Arrange
email = 'invalid-email'
# Act
result = extract_base_email(email)
# Assert
assert result is None
def test_extract_base_email_empty_string(self):
"""Test that empty string returns None."""
# Arrange
email = ''
# Act
result = extract_base_email(email)
# Assert
assert result is None
def test_extract_base_email_none(self):
"""Test that None input returns None."""
# Arrange
email = None
# Act
result = extract_base_email(email)
# Assert
assert result is None
class TestHasPlusModifier:
"""Test cases for has_plus_modifier function."""
def test_has_plus_modifier_true(self):
"""Test detecting + modifier in email."""
# Arrange
email = 'joe+test@example.com'
# Act
result = has_plus_modifier(email)
# Assert
assert result is True
def test_has_plus_modifier_false(self):
"""Test that email without + modifier returns False."""
# Arrange
email = 'joe@example.com'
# Act
result = has_plus_modifier(email)
# Assert
assert result is False
def test_has_plus_modifier_invalid_no_at_symbol(self):
"""Test that invalid email without @ returns False."""
# Arrange
email = 'invalid-email'
# Act
result = has_plus_modifier(email)
# Assert
assert result is False
def test_has_plus_modifier_empty_string(self):
"""Test that empty string returns False."""
# Arrange
email = ''
# Act
result = has_plus_modifier(email)
# Assert
assert result is False
class TestMatchesBaseEmail:
"""Test cases for matches_base_email function."""
def test_matches_base_email_exact_match(self):
"""Test that exact base email matches."""
# Arrange
email = 'joe@example.com'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is True
def test_matches_base_email_with_plus_variant(self):
"""Test that email with + variant matches base email."""
# Arrange
email = 'joe+test@example.com'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is True
def test_matches_base_email_different_base(self):
"""Test that different base emails do not match."""
# Arrange
email = 'jane@example.com'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is False
def test_matches_base_email_different_domain(self):
"""Test that same local part but different domain does not match."""
# Arrange
email = 'joe@other.com'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is False
def test_matches_base_email_case_insensitive(self):
"""Test that matching is case-insensitive."""
# Arrange
email = 'JOE+TEST@EXAMPLE.COM'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is True
def test_matches_base_email_empty_strings(self):
"""Test that empty strings return False."""
# Arrange
email = ''
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is False
class TestGetBaseEmailRegexPattern:
"""Test cases for get_base_email_regex_pattern function."""
def test_get_base_email_regex_pattern_valid(self):
"""Test generating valid regex pattern for base email."""
# Arrange
base_email = 'joe@example.com'
# Act
pattern = get_base_email_regex_pattern(base_email)
# Assert
assert pattern is not None
assert isinstance(pattern, re.Pattern)
assert pattern.match('joe@example.com') is not None
assert pattern.match('joe+test@example.com') is not None
assert pattern.match('joe+openhands@example.com') is not None
def test_get_base_email_regex_pattern_matches_plus_variant(self):
"""Test that regex pattern matches + variant."""
# Arrange
base_email = 'joe@example.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('joe+test@example.com')
# Assert
assert match is not None
def test_get_base_email_regex_pattern_rejects_different_base(self):
"""Test that regex pattern rejects different base email."""
# Arrange
base_email = 'joe@example.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('jane@example.com')
# Assert
assert match is None
def test_get_base_email_regex_pattern_rejects_different_domain(self):
"""Test that regex pattern rejects different domain."""
# Arrange
base_email = 'joe@example.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('joe@other.com')
# Assert
assert match is None
def test_get_base_email_regex_pattern_case_insensitive(self):
"""Test that regex pattern is case-insensitive."""
# Arrange
base_email = 'joe@example.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('JOE+TEST@EXAMPLE.COM')
# Assert
assert match is not None
def test_get_base_email_regex_pattern_special_characters(self):
"""Test that regex pattern handles special characters in email."""
# Arrange
base_email = 'user.name+tag@example-site.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('user.name+test@example-site.com')
# Assert
assert match is not None
def test_get_base_email_regex_pattern_invalid_base_email(self):
"""Test that invalid base email returns None."""
# Arrange
base_email = 'invalid-email'
# Act
pattern = get_base_email_regex_pattern(base_email)
# Assert
assert pattern is None

View File

@@ -6,6 +6,7 @@ from server.constants import (
CURRENT_USER_SETTINGS_VERSION,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
get_default_litellm_model,
)
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
@@ -393,10 +394,11 @@ async def test_create_user_in_lite_llm(settings_store):
mock_response = AsyncMock()
mock_response.is_success = True
mock_client.post.return_value = mock_response
test_model = 'custom-model/test-model'
# Test with email
await settings_store._create_user_in_lite_llm(
mock_client, 'test@example.com', 50, 10
mock_client, 'test@example.com', 50, 10, test_model
)
# Get the actual call arguments
@@ -412,11 +414,11 @@ async def test_create_user_in_lite_llm(settings_store):
assert call_args['json']['auto_create_key'] is True
assert call_args['json']['send_invite_email'] is False
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
assert 'model' in call_args['json']['metadata']
assert call_args['json']['metadata']['model'] == test_model
# Test with None email
mock_client.post.reset_mock()
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15)
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15, test_model)
# Get the actual call arguments
call_args = mock_client.post.call_args[1]
@@ -431,12 +433,12 @@ async def test_create_user_in_lite_llm(settings_store):
assert call_args['json']['auto_create_key'] is True
assert call_args['json']['send_invite_email'] is False
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
assert 'model' in call_args['json']['metadata']
assert call_args['json']['metadata']['model'] == test_model
# Verify response is returned correctly
assert (
await settings_store._create_user_in_lite_llm(
mock_client, 'email@test.com', 30, 7
mock_client, 'email@test.com', 30, 7, test_model
)
== mock_response
)
@@ -464,3 +466,808 @@ async def test_encryption(settings_store):
# But we should be able to decrypt it when loading
loaded_settings = await settings_store.load()
assert loaded_settings.llm_api_key.get_secret_value() == 'secret_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_preserves_custom_model(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has a custom LLM model set
custom_model = 'anthropic/claude-3-5-sonnet-20241022'
settings = Settings(llm_model=custom_model)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Custom model is preserved
assert updated_settings is not None
assert updated_settings.llm_model == custom_model
assert updated_settings.agent == 'CodeActAgent'
assert updated_settings.llm_api_key is not None
# Assert: LiteLLM metadata contains user's custom model
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
assert call_args['json']['metadata']['model'] == custom_model
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_uses_default_when_no_model(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has no model set (new user scenario)
settings = Settings()
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'newuser@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default model is assigned
assert updated_settings is not None
expected_default = get_default_litellm_model()
assert updated_settings.llm_model == expected_default
assert updated_settings.agent == 'CodeActAgent'
# Assert: LiteLLM metadata contains default model
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
assert call_args['json']['metadata']['model'] == expected_default
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_empty_string_model(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has empty string as model (edge case)
settings = Settings(llm_model='')
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default model is used (empty string treated as no model)
assert updated_settings is not None
expected_default = get_default_litellm_model()
assert updated_settings.llm_model == expected_default
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_whitespace_model(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has whitespace-only model (edge case)
settings = Settings(llm_model=' ')
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default model is used (whitespace treated as no model)
assert updated_settings is not None
expected_default = get_default_litellm_model()
assert updated_settings.llm_model == expected_default
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_preserves_custom_api_key(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has a custom API key and custom model (so has_custom=True)
custom_api_key = 'sk-custom-user-api-key-12345'
custom_model = 'gpt-4'
settings = Settings(llm_model=custom_model, llm_api_key=SecretStr(custom_api_key))
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Custom API key is preserved when user has custom settings
assert updated_settings is not None
assert updated_settings.llm_api_key.get_secret_value() == custom_api_key
assert updated_settings.llm_api_key.get_secret_value() != 'test_api_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_preserves_custom_base_url(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has a custom base URL
custom_base_url = 'https://api.custom-llm-provider.com/v1'
settings = Settings(llm_base_url=custom_base_url)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Custom base URL is preserved
assert updated_settings is not None
assert updated_settings.llm_base_url == custom_base_url
assert updated_settings.llm_base_url != LITE_LLM_API_URL
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_preserves_custom_api_key_and_base_url(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has both custom API key and base URL
custom_api_key = 'sk-custom-user-api-key-67890'
custom_base_url = 'https://api.another-llm-provider.com/v1'
custom_model = 'openai/gpt-4'
settings = Settings(
llm_model=custom_model,
llm_api_key=SecretStr(custom_api_key),
llm_base_url=custom_base_url,
)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: All custom settings are preserved
assert updated_settings is not None
assert updated_settings.llm_model == custom_model
assert updated_settings.llm_api_key.get_secret_value() == custom_api_key
assert updated_settings.llm_base_url == custom_base_url
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_uses_default_api_key_when_none(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has no API key set
settings = Settings(llm_api_key=None)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default LiteLLM API key is assigned
assert updated_settings is not None
assert updated_settings.llm_api_key is not None
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_uses_default_base_url_when_none(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has no base URL set
settings = Settings(llm_base_url=None)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default LiteLLM base URL is assigned (using mocked value)
assert updated_settings is not None
assert updated_settings.llm_base_url == 'http://test.url'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_empty_api_key(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has empty string as API key (edge case)
settings = Settings(llm_api_key=SecretStr(''))
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default API key is used (empty string treated as no key)
assert updated_settings is not None
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_empty_base_url(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has empty string as base URL (edge case)
settings = Settings(llm_base_url='')
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default base URL is used (empty string treated as no URL)
assert updated_settings is not None
assert updated_settings.llm_base_url == 'http://test.url'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_whitespace_api_key(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has whitespace-only API key (edge case)
settings = Settings(llm_api_key=SecretStr(' '))
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default API key is used (whitespace treated as no key)
assert updated_settings is not None
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_whitespace_base_url(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has whitespace-only base URL (edge case)
settings = Settings(llm_base_url=' ')
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default base URL is used (whitespace treated as no URL)
assert updated_settings is not None
assert updated_settings.llm_base_url == 'http://test.url'
# Tests for version migration and helper methods
@pytest.mark.asyncio
async def test_has_custom_settings_with_custom_base_url(settings_store):
# Arrange: User with custom base URL (BYOR)
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_base_url='http://custom.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: Custom base URL detected
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_with_default_base_url(settings_store):
# Arrange: User with default base URL
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No custom settings (no model set)
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_no_model(settings_store):
# Arrange: User with no model set
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_model=None, llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No custom settings (using defaults)
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_empty_model(settings_store):
# Arrange: User with empty model
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_model='', llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No custom settings (empty treated as no model)
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_whitespace_model(settings_store):
# Arrange: User with whitespace-only model
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_model=' ', llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No custom settings (whitespace treated as no model)
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_custom_model(settings_store):
# Arrange: User with custom model
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_model='gpt-4', llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: Custom model detected
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_matches_old_default_model(settings_store):
# Arrange: User with old version and model matching old default
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
):
settings = Settings(
llm_model='litellm_proxy/prod/claude-3-5-sonnet-20241022',
llm_base_url='http://default.url',
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 1)
# Assert: Matches old default, so not custom
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_matches_old_default_by_base_name(settings_store):
# Arrange: User with old version and model matching old default by base name
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
):
settings = Settings(
llm_model='anthropic/claude-3-5-sonnet-20241022',
llm_base_url='http://default.url',
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 1)
# Assert: Matches old default by base name, so not custom
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_old_version_but_custom_model(settings_store):
# Arrange: User with old version but custom model
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
):
settings = Settings(llm_model='gpt-4', llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 1)
# Assert: Custom model detected
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_with_current_version(settings_store):
# Arrange: User with current version
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
),
):
settings = Settings(
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 5)
# Assert: Current version, so model is custom (not old default)
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_with_none_version(settings_store):
# Arrange: User with no version
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No version, so model is custom
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_with_invalid_version(settings_store):
# Arrange: User with invalid version
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
):
settings = Settings(
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 99)
# Assert: Invalid version, so model is custom
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_normalizes_whitespace(settings_store):
# Arrange: Settings with whitespace in values
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(
llm_model=' claude-3-5-sonnet-20241022 ',
llm_base_url=' http://default.url ',
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: Whitespace is normalized, custom model detected
assert has_custom is True
@pytest.mark.asyncio
async def test_update_settings_upgrades_user_from_old_defaults(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User with old version using old defaults
old_version = 1
old_model = 'litellm_proxy/prod/claude-3-5-sonnet-20241022'
settings = Settings(llm_model=old_model, llm_base_url=LITE_LLM_API_URL)
# Use a consistent test URL
test_base_url = 'http://test.url'
with (
patch('storage.saas_settings_store.session_maker', session_maker),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
),
patch(
'storage.saas_settings_store.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch('storage.saas_settings_store.CURRENT_USER_SETTINGS_VERSION', 5),
patch('storage.saas_settings_store.LITE_LLM_API_URL', test_base_url),
patch(
'storage.saas_settings_store.get_default_litellm_model',
return_value='litellm_proxy/prod/claude-opus-4-5-20251101',
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
):
# Create existing user settings with old version
with session_maker() as session:
existing_settings = UserSettings(
keycloak_user_id=settings_store.user_id,
user_version=old_version,
llm_model=old_model,
llm_base_url=test_base_url,
)
session.add(existing_settings)
session.commit()
# Update settings to use test_base_url
# Set user_version to match the database so _has_custom_settings can detect old defaults
settings = Settings(
llm_model=old_model, llm_base_url=test_base_url, user_version=old_version
)
# Act: Update settings
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Settings upgraded to new defaults
assert updated_settings is not None
assert (
updated_settings.llm_model == 'litellm_proxy/prod/claude-opus-4-5-20251101'
)
assert updated_settings.llm_base_url == test_base_url
@pytest.mark.asyncio
async def test_update_settings_preserves_custom_settings_during_upgrade(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User with old version but custom settings
old_version = 1
custom_model = 'gpt-4'
custom_base_url = 'http://custom.url'
settings = Settings(llm_model=custom_model, llm_base_url=custom_base_url)
with (
patch('storage.saas_settings_store.session_maker', session_maker),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
):
# Create existing user settings with old version
with session_maker() as session:
existing_settings = UserSettings(
keycloak_user_id=settings_store.user_id,
user_version=old_version,
llm_model=custom_model,
llm_base_url=custom_base_url,
)
session.add(existing_settings)
session.commit()
# Act: Update settings
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Custom settings preserved
assert updated_settings is not None
assert updated_settings.llm_model == custom_model
assert updated_settings.llm_base_url == custom_base_url
@pytest.mark.asyncio
async def test_update_settings_migrates_billing_margin_v3_to_v4(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User with version 3 and billing margin
old_version = 3
billing_margin = 2.0
max_budget = 10.0
spend = 5.0
settings = Settings()
mock_get_response = AsyncMock()
mock_get_response.is_success = True
mock_get_response.json = MagicMock(
return_value={'user_info': {'max_budget': max_budget, 'spend': spend}}
)
mock_post_response = AsyncMock()
mock_post_response.is_success = True
mock_post_response.json = MagicMock(return_value={'key': 'test_api_key'})
with (
patch('storage.saas_settings_store.session_maker', session_maker),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('httpx.AsyncClient') as mock_client,
):
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_get_response
)
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_post_response
)
# Create existing user settings with version 3 and billing margin
with session_maker() as session:
existing_settings = UserSettings(
keycloak_user_id=settings_store.user_id,
user_version=old_version,
billing_margin=billing_margin,
)
session.add(existing_settings)
session.commit()
# Act: Update settings
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Settings updated
assert updated_settings is not None
# Assert: Billing margin applied to budget
call_args = mock_client.return_value.__aenter__.return_value.post.call_args[1]
assert call_args['json']['max_budget'] == max_budget * billing_margin
assert call_args['json']['spend'] == spend * billing_margin
# Assert: Billing margin reset to 1.0
with session_maker() as session:
updated_user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == settings_store.user_id)
.first()
)
assert updated_user_settings.billing_margin == 1.0
@pytest.mark.asyncio
async def test_update_settings_skips_billing_margin_migration_when_already_v4(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User with version 4
version = 4
billing_margin = 2.0
max_budget = 10.0
spend = 5.0
settings = Settings()
mock_get_response = AsyncMock()
mock_get_response.is_success = True
mock_get_response.json = MagicMock(
return_value={'user_info': {'max_budget': max_budget, 'spend': spend}}
)
mock_post_response = AsyncMock()
mock_post_response.is_success = True
mock_post_response.json = MagicMock(return_value={'key': 'test_api_key'})
with (
patch('storage.saas_settings_store.session_maker', session_maker),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('httpx.AsyncClient') as mock_client,
):
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_get_response
)
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_post_response
)
# Create existing user settings with version 4
with session_maker() as session:
existing_settings = UserSettings(
keycloak_user_id=settings_store.user_id,
user_version=version,
billing_margin=billing_margin,
)
session.add(existing_settings)
session.commit()
# Act: Update settings
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Settings updated
assert updated_settings is not None
# Assert: Billing margin NOT applied (version >= 4)
call_args = mock_client.return_value.__aenter__.return_value.post.call_args[1]
assert call_args['json']['max_budget'] == max_budget
assert call_args['json']['spend'] == spend

View File

@@ -5,7 +5,12 @@ import jwt
import pytest
from fastapi import Request
from pydantic import SecretStr
from server.auth.auth_error import BearerTokenError, CookieError, NoCredentialsError
from server.auth.auth_error import (
AuthError,
BearerTokenError,
CookieError,
NoCredentialsError,
)
from server.auth.saas_user_auth import (
SaasUserAuth,
get_api_key_from_header,
@@ -647,3 +652,97 @@ def test_get_api_key_from_header_bearer_with_empty_token():
# Assert that empty string from Bearer is returned (current behavior)
# This tests the current implementation behavior
assert api_key == ''
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
"""Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked."""
# Arrange
access_payload = {
'sub': 'test_user_id',
'exp': int(time.time()) + 3600,
'email': 'user@colsch.us',
'email_verified': True,
}
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
token_payload = {
'access_token': access_token,
'refresh_token': 'test_refresh_token',
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = True
# Act & Assert
with pytest.raises(AuthError) as exc_info:
await saas_user_auth_from_signed_token(signed_token)
assert 'email domain is not allowed' in str(exc_info.value)
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
# Arrange
access_payload = {
'sub': 'test_user_id',
'exp': int(time.time()) + 3600,
'email': 'user@example.com',
'email_verified': True,
}
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
token_payload = {
'access_token': access_token,
'refresh_token': 'test_refresh_token',
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = False
# Act
result = await saas_user_auth_from_signed_token(signed_token)
# Assert
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
assert result.email == 'user@example.com'
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
'user@example.com'
)
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
# Arrange
access_payload = {
'sub': 'test_user_id',
'exp': int(time.time()) + 3600,
'email': 'user@colsch.us',
'email_verified': True,
}
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
token_payload = {
'access_token': access_token,
'refresh_token': 'test_refresh_token',
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = False
# Act
result = await saas_user_auth_from_signed_token(signed_token)
# Assert
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
mock_domain_blocker.is_domain_blocked.assert_not_called()

View File

@@ -0,0 +1 @@
"""Tests for sharing package."""

View File

@@ -0,0 +1,91 @@
"""Tests for public conversation models."""
from datetime import datetime
from uuid import uuid4
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
def test_public_conversation_creation():
"""Test that SharedConversation can be created with all required fields."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
assert conversation.id == conversation_id
assert conversation.title == 'Test Conversation'
assert conversation.created_by_user_id == 'test_user'
assert conversation.sandbox_id == 'test_sandbox'
def test_public_conversation_page_creation():
"""Test that SharedConversationPage can be created."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
page = SharedConversationPage(
items=[conversation],
next_page_id='next_page',
)
assert len(page.items) == 1
assert page.items[0].id == conversation_id
assert page.next_page_id == 'next_page'
def test_public_conversation_sort_order_enum():
"""Test that SharedConversationSortOrder enum has expected values."""
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'TITLE')
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
def test_public_conversation_optional_fields():
"""Test that SharedConversation works with optional fields."""
conversation_id = uuid4()
parent_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository='owner/repo',
parent_conversation_id=parent_id,
llm_model='gpt-4',
)
assert conversation.selected_repository == 'owner/repo'
assert conversation.parent_conversation_id == parent_id
assert conversation.llm_model == 'gpt-4'

View File

@@ -0,0 +1,430 @@
"""Tests for SharedConversationInfoService."""
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import uuid4
import pytest
from server.sharing.shared_conversation_models import (
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
async def shared_conversation_info_service(async_session):
"""Create a SharedConversationInfoService for testing."""
return SQLSharedConversationInfoService(db_session=async_session)
@pytest.fixture
async def app_conversation_service(async_session):
"""Create an AppConversationInfoService for creating test data."""
return SQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
)
@pytest.fixture
def sample_conversation_info():
"""Create a sample conversation info for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox',
selected_repository='test/repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Test Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[123],
llm_model='gpt-4',
metrics=MetricsSnapshot(
accumulated_cost=1.5,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4096,
per_turn_token=150,
),
),
parent_conversation_id=None,
sub_conversation_ids=[],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
public=True, # Make it public for testing
)
@pytest.fixture
def sample_private_conversation_info():
"""Create a sample private conversation info for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_private',
selected_repository='test/private_repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Private Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[124],
llm_model='gpt-4',
metrics=MetricsSnapshot(
accumulated_cost=2.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4096,
per_turn_token=300,
),
),
parent_conversation_id=None,
sub_conversation_ids=[],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
public=False, # Make it private
)
class TestSharedConversationInfoService:
"""Test cases for SharedConversationInfoService."""
@pytest.mark.asyncio
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_public_conversation(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test that get_shared_conversation_info returns a public conversation."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Retrieve it via public service
result = await shared_conversation_info_service.get_shared_conversation_info(
sample_conversation_info.id
)
assert result is not None
assert result.id == sample_conversation_info.id
assert result.title == sample_conversation_info.title
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
self,
shared_conversation_info_service,
app_conversation_service,
sample_private_conversation_info,
):
"""Test that get_shared_conversation_info returns None for private conversations."""
# Create a private conversation
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Try to retrieve it via public service
result = await shared_conversation_info_service.get_shared_conversation_info(
sample_private_conversation_info.id
)
assert result is None
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
self, shared_conversation_info_service
):
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
nonexistent_id = uuid4()
result = await shared_conversation_info_service.get_shared_conversation_info(
nonexistent_id
)
assert result is None
@pytest.mark.asyncio
async def test_search_shared_conversation_info_returns_only_public_conversations(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test that search only returns public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Search for all conversations
result = (
await shared_conversation_info_service.search_shared_conversation_info()
)
# Should only return the public conversation
assert len(result.items) == 1
assert result.items[0].id == sample_conversation_info.id
assert result.items[0].title == sample_conversation_info.title
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_title_filter(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test searching with title filter."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Search with matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='Test'
)
assert len(result.items) == 1
# Search with non-matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='NonExistent'
)
assert len(result.items) == 0
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_sort_order(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test searching with different sort orders."""
# Create multiple public conversations with different titles and timestamps
conv1 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_1',
title='A First Conversation',
created_at=datetime(2023, 1, 1, tzinfo=UTC),
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conv2 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_2',
title='B Second Conversation',
created_at=datetime(2023, 1, 2, tzinfo=UTC),
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_conversation_service.save_app_conversation_info(conv1)
await app_conversation_service.save_app_conversation_info(conv2)
# Test sort by title ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE
)
assert len(result.items) == 2
assert result.items[0].title == 'A First Conversation'
assert result.items[1].title == 'B Second Conversation'
# Test sort by title descending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE_DESC
)
assert len(result.items) == 2
assert result.items[0].title == 'B Second Conversation'
assert result.items[1].title == 'A First Conversation'
# Test sort by created_at ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.items[0].id == conv1.id
assert result.items[1].id == conv2.id
# Test sort by created_at descending (default)
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
)
assert len(result.items) == 2
assert result.items[0].id == conv2.id
assert result.items[1].id == conv1.id
@pytest.mark.asyncio
async def test_count_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test counting public conversations."""
# Initially should be 0
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 0
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
# Create a private conversation - count should remain 1
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
@pytest.mark.asyncio
async def test_batch_get_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test batch getting public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Batch get both conversations
result = (
await shared_conversation_info_service.batch_get_shared_conversation_info(
[sample_conversation_info.id, sample_private_conversation_info.id]
)
)
# Should return the public one and None for the private one
assert len(result) == 2
assert result[0] is not None
assert result[0].id == sample_conversation_info.id
assert result[1] is None
@pytest.mark.asyncio
async def test_search_with_pagination(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test search with pagination."""
# Create multiple public conversations
conversations = []
for i in range(5):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id=f'test_sandbox_{i}',
title=f'Conversation {i}',
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conversations.append(conv)
await app_conversation_service.save_app_conversation_info(conv)
# Get first page with limit 2
result = await shared_conversation_info_service.search_shared_conversation_info(
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.next_page_id is not None
# Get next page
result2 = (
await shared_conversation_info_service.search_shared_conversation_info(
limit=2,
page_id=result.next_page_id,
sort_order=SharedConversationSortOrder.CREATED_AT,
)
)
assert len(result2.items) == 2
assert result2.next_page_id is not None
# Verify no overlap between pages
page1_ids = {item.id for item in result.items}
page2_ids = {item.id for item in result2.items}
assert page1_ids.isdisjoint(page2_ids)

View File

@@ -0,0 +1,365 @@
"""Tests for SharedEventService."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImpl,
)
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_conversation_models import SharedConversation
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
@pytest.fixture
def mock_shared_conversation_info_service():
"""Create a mock SharedConversationInfoService."""
return AsyncMock(spec=SharedConversationInfoService)
@pytest.fixture
def mock_event_service():
"""Create a mock EventService."""
return AsyncMock(spec=EventService)
@pytest.fixture
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
"""Create a SharedEventService for testing."""
return SharedEventServiceImpl(
shared_conversation_info_service=mock_shared_conversation_info_service,
event_service=mock_event_service,
)
@pytest.fixture
def sample_public_conversation():
"""Create a sample public conversation."""
return SharedConversation(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Public Conversation',
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
@pytest.fixture
def sample_event():
"""Create a sample event."""
# For testing purposes, we'll just use a mock that the EventPage can accept
# The actual event creation is complex and not the focus of these tests
return None
class TestSharedEventService:
"""Test cases for SharedEventService."""
async def test_get_shared_event_returns_event_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that get_shared_event returns an event for a public conversation."""
conversation_id = sample_public_conversation.id
event_id = 'test_event_id'
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return an event
mock_event_service.get_event.return_value = sample_event
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result == sample_event
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.get_event.assert_called_once_with(event_id)
async def test_get_shared_event_returns_none_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that get_shared_event returns None for a private conversation."""
conversation_id = uuid4()
event_id = 'test_event_id'
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_returns_events_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that search_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id=None)
mock_event_service.search_events.return_value = mock_event_page
# Call the method
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
kind__eq='ActionEvent',
limit=10,
)
# Verify the result
assert result == mock_event_page
assert len(result.items) == 0 # Empty list as we mocked
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
page_id=None,
limit=10,
)
async def test_search_shared_events_returns_empty_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that search_shared_events returns empty page for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
limit=10,
)
# Verify the result
assert isinstance(result, EventPage)
assert len(result.items) == 0
assert result.next_page_id is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.search_events.assert_not_called()
async def test_count_shared_events_returns_count_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
):
"""Test that count_shared_events returns count for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return a count
mock_event_service.count_events.return_value = 5
# Call the method
result = await shared_event_service.count_shared_events(
conversation_id=conversation_id,
kind__eq='ActionEvent',
)
# Verify the result
assert result == 5
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.count_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
)
async def test_count_shared_events_returns_zero_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that count_shared_events returns 0 for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.count_shared_events(
conversation_id=conversation_id,
)
# Verify the result
assert result == 0
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.count_events.assert_not_called()
async def test_batch_get_shared_events_returns_events_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that batch_get_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
event_ids = ['event1', 'event2']
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_service.get_event.side_effect = [sample_event, None]
# Call the method
result = await shared_event_service.batch_get_shared_events(
conversation_id, event_ids
)
# Verify the result
assert len(result) == 2
assert result[0] == sample_event
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Verify that get_event was called for each event
assert mock_event_service.get_event.call_count == 2
async def test_batch_get_shared_events_returns_none_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that batch_get_shared_events returns None for a private conversation."""
conversation_id = uuid4()
event_ids = ['event1', 'event2']
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.batch_get_shared_events(
conversation_id, event_ids
)
# Verify the result
assert len(result) == 2
assert result[0] is None
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Event service should not be called
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_with_all_parameters(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
):
"""Test search_shared_events with all parameters."""
conversation_id = sample_public_conversation.id
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id='next_page')
mock_event_service.search_events.return_value = mock_event_page
# Call the method with all parameters
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,
sort_order=EventSortOrder.TIMESTAMP_DESC,
page_id='current_page',
limit=50,
)
# Verify the result
assert result == mock_event_page
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,
sort_order=EventSortOrder.TIMESTAMP_DESC,
page_id='current_page',
limit=50,
)

View File

@@ -1,6 +1,8 @@
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from keycloak.exceptions import KeycloakConnectionError, KeycloakError
from server.auth.token_manager import TokenManager
from sqlalchemy.orm import Session
from storage.offline_token_store import OfflineTokenStore
from storage.stored_offline_token import StoredOfflineToken
@@ -32,6 +34,14 @@ def token_store(mock_session_maker, mock_config):
return OfflineTokenStore('test_user_id', mock_session_maker, mock_config)
@pytest.fixture
def token_manager():
with patch('server.config.get_config') as mock_get_config:
mock_config = mock_get_config.return_value
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
return TokenManager(external=False)
@pytest.mark.asyncio
async def test_store_token_new_record(token_store, mock_session):
# Setup
@@ -109,3 +119,419 @@ async def test_get_instance(mock_config):
assert isinstance(result, OfflineTokenStore)
assert result.user_id == test_user_id
assert result.config == mock_config
class TestCheckDuplicateBaseEmail:
"""Test cases for check_duplicate_base_email method."""
@pytest.mark.asyncio
async def test_check_duplicate_base_email_no_plus_modifier(self, token_manager):
"""Test that emails without + modifier are still checked for duplicates."""
# Arrange
email = 'joe@example.com'
current_user_id = 'user123'
with (
patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query,
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
):
mock_find.return_value = False
mock_query.return_value = {}
# Act
result = await token_manager.check_duplicate_base_email(
email, current_user_id
)
# Assert
assert result is False
mock_query.assert_called_once()
mock_find.assert_called_once()
@pytest.mark.asyncio
async def test_check_duplicate_base_email_empty_email(self, token_manager):
"""Test that empty email returns False."""
# Arrange
email = ''
current_user_id = 'user123'
# Act
result = await token_manager.check_duplicate_base_email(email, current_user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_duplicate_base_email_invalid_email(self, token_manager):
"""Test that invalid email returns False."""
# Arrange
email = 'invalid-email'
current_user_id = 'user123'
# Act
result = await token_manager.check_duplicate_base_email(email, current_user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_duplicate_base_email_duplicate_found(self, token_manager):
"""Test that duplicate email is detected when found."""
# Arrange
email = 'joe+test@example.com'
current_user_id = 'user123'
existing_user = {
'id': 'existing_user_id',
'email': 'joe@example.com',
}
with (
patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query,
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
):
mock_find.return_value = True
mock_query.return_value = {'existing_user_id': existing_user}
# Act
result = await token_manager.check_duplicate_base_email(
email, current_user_id
)
# Assert
assert result is True
mock_query.assert_called_once()
mock_find.assert_called_once()
@pytest.mark.asyncio
async def test_check_duplicate_base_email_no_duplicate(self, token_manager):
"""Test that no duplicate is found when none exists."""
# Arrange
email = 'joe+test@example.com'
current_user_id = 'user123'
with (
patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query,
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
):
mock_find.return_value = False
mock_query.return_value = {}
# Act
result = await token_manager.check_duplicate_base_email(
email, current_user_id
)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_duplicate_base_email_keycloak_connection_error(
self, token_manager
):
"""Test that KeycloakConnectionError triggers retry and raises RetryError."""
# Arrange
email = 'joe+test@example.com'
current_user_id = 'user123'
with patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query:
mock_query.side_effect = KeycloakConnectionError('Connection failed')
# Act & Assert
# KeycloakConnectionError is re-raised, which triggers retry decorator
# After retries exhaust (2 attempts), it raises RetryError
from tenacity import RetryError
with pytest.raises(RetryError):
await token_manager.check_duplicate_base_email(email, current_user_id)
@pytest.mark.asyncio
async def test_check_duplicate_base_email_general_exception(self, token_manager):
"""Test that general exceptions are handled gracefully."""
# Arrange
email = 'joe+test@example.com'
current_user_id = 'user123'
with patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query:
mock_query.side_effect = Exception('Unexpected error')
# Act
result = await token_manager.check_duplicate_base_email(
email, current_user_id
)
# Assert
assert result is False
class TestQueryUsersByWildcardPattern:
"""Test cases for _query_users_by_wildcard_pattern method."""
@pytest.mark.asyncio
async def test_query_users_by_wildcard_pattern_success_with_search(
self, token_manager
):
"""Test successful query using search parameter."""
# Arrange
local_part = 'joe'
domain = 'example.com'
mock_users = [
{'id': 'user1', 'email': 'joe@example.com'},
{'id': 'user2', 'email': 'joe+test@example.com'},
]
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_users = AsyncMock(return_value=mock_users)
mock_get_admin.return_value = mock_admin
# Act
result = await token_manager._query_users_by_wildcard_pattern(
local_part, domain
)
# Assert
assert len(result) == 2
assert 'user1' in result
assert 'user2' in result
mock_admin.a_get_users.assert_called_once_with(
{'search': 'joe*@example.com'}
)
@pytest.mark.asyncio
async def test_query_users_by_wildcard_pattern_fallback_to_q(self, token_manager):
"""Test fallback to q parameter when search fails."""
# Arrange
local_part = 'joe'
domain = 'example.com'
mock_users = [{'id': 'user1', 'email': 'joe@example.com'}]
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
# First call fails, second succeeds
mock_admin.a_get_users = AsyncMock(
side_effect=[Exception('Search failed'), mock_users]
)
mock_get_admin.return_value = mock_admin
# Act
result = await token_manager._query_users_by_wildcard_pattern(
local_part, domain
)
# Assert
assert len(result) == 1
assert 'user1' in result
assert mock_admin.a_get_users.call_count == 2
@pytest.mark.asyncio
async def test_query_users_by_wildcard_pattern_empty_result(self, token_manager):
"""Test query returns empty dict when no users found."""
# Arrange
local_part = 'joe'
domain = 'example.com'
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_users = AsyncMock(return_value=[])
mock_get_admin.return_value = mock_admin
# Act
result = await token_manager._query_users_by_wildcard_pattern(
local_part, domain
)
# Assert
assert result == {}
class TestFindDuplicateInUsers:
"""Test cases for _find_duplicate_in_users method."""
def test_find_duplicate_in_users_with_regex_match(self, token_manager):
"""Test finding duplicate using regex pattern."""
# Arrange
users = {
'user1': {'id': 'user1', 'email': 'joe@example.com'},
'user2': {'id': 'user2', 'email': 'joe+test@example.com'},
}
base_email = 'joe@example.com'
current_user_id = 'user3'
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
assert result is True
def test_find_duplicate_in_users_fallback_to_simple_matching(self, token_manager):
"""Test fallback to simple matching when regex pattern is None."""
# Arrange
users = {
'user1': {'id': 'user1', 'email': 'joe@example.com'},
}
base_email = 'invalid-email' # Will cause regex pattern to be None
current_user_id = 'user2'
with patch(
'server.auth.token_manager.get_base_email_regex_pattern', return_value=None
):
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
# Should use fallback matching, but invalid base_email won't match
assert result is False
def test_find_duplicate_in_users_excludes_current_user(self, token_manager):
"""Test that current user is excluded from duplicate check."""
# Arrange
users = {
'user1': {'id': 'user1', 'email': 'joe@example.com'},
}
base_email = 'joe@example.com'
current_user_id = 'user1' # Same as user in users dict
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
assert result is False
def test_find_duplicate_in_users_no_match(self, token_manager):
"""Test that no duplicate is found when emails don't match."""
# Arrange
users = {
'user1': {'id': 'user1', 'email': 'jane@example.com'},
}
base_email = 'joe@example.com'
current_user_id = 'user2'
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
assert result is False
def test_find_duplicate_in_users_empty_dict(self, token_manager):
"""Test that empty users dict returns False."""
# Arrange
users: dict[str, dict] = {}
base_email = 'joe@example.com'
current_user_id = 'user1'
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
assert result is False
class TestDeleteKeycloakUser:
"""Test cases for delete_keycloak_user method."""
@pytest.mark.asyncio
async def test_delete_keycloak_user_success(self, token_manager):
"""Test successful deletion of Keycloak user."""
# Arrange
user_id = 'test_user_id'
with (
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
patch('asyncio.to_thread') as mock_to_thread,
):
mock_admin = MagicMock()
mock_admin.delete_user = MagicMock()
mock_get_admin.return_value = mock_admin
mock_to_thread.return_value = None
# Act
result = await token_manager.delete_keycloak_user(user_id)
# Assert
assert result is True
mock_to_thread.assert_called_once_with(mock_admin.delete_user, user_id)
@pytest.mark.asyncio
async def test_delete_keycloak_user_connection_error(self, token_manager):
"""Test handling of KeycloakConnectionError triggers retry and raises RetryError."""
# Arrange
user_id = 'test_user_id'
with (
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
patch('asyncio.to_thread') as mock_to_thread,
):
mock_admin = MagicMock()
mock_admin.delete_user = MagicMock()
mock_get_admin.return_value = mock_admin
mock_to_thread.side_effect = KeycloakConnectionError('Connection failed')
# Act & Assert
# KeycloakConnectionError triggers retry decorator
# After retries exhaust (2 attempts), it raises RetryError
from tenacity import RetryError
with pytest.raises(RetryError):
await token_manager.delete_keycloak_user(user_id)
@pytest.mark.asyncio
async def test_delete_keycloak_user_keycloak_error(self, token_manager):
"""Test handling of KeycloakError (e.g., user not found)."""
# Arrange
user_id = 'test_user_id'
with (
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
patch('asyncio.to_thread') as mock_to_thread,
):
mock_admin = MagicMock()
mock_admin.delete_user = MagicMock()
mock_get_admin.return_value = mock_admin
mock_to_thread.side_effect = KeycloakError('User not found')
# Act
result = await token_manager.delete_keycloak_user(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_delete_keycloak_user_general_exception(self, token_manager):
"""Test handling of general exceptions."""
# Arrange
user_id = 'test_user_id'
with (
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
patch('asyncio.to_thread') as mock_to_thread,
):
mock_admin = MagicMock()
mock_admin.delete_user = MagicMock()
mock_get_admin.return_value = mock_admin
mock_to_thread.side_effect = Exception('Unexpected error')
# Act
result = await token_manager.delete_keycloak_user(user_id)
# Assert
assert result is False

View File

@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from server.auth.token_manager import TokenManager, create_encryption_utility
@@ -246,3 +246,103 @@ async def test_refresh(token_manager):
mock_keycloak.return_value.a_refresh_token.assert_called_once_with(
'test_refresh_token'
)
@pytest.mark.asyncio
async def test_disable_keycloak_user_success(token_manager):
"""Test successful disabling of a Keycloak user account."""
# Arrange
user_id = 'test_user_id'
email = 'user@colsch.us'
mock_user = {
'id': user_id,
'username': 'testuser',
'email': email,
'emailVerified': True,
}
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
mock_admin.a_update_user = AsyncMock()
mock_get_admin.return_value = mock_admin
# Act
await token_manager.disable_keycloak_user(user_id, email)
# Assert
mock_admin.a_get_user.assert_called_once_with(user_id)
mock_admin.a_update_user.assert_called_once_with(
user_id=user_id,
payload={
'enabled': False,
'username': 'testuser',
'email': email,
'emailVerified': True,
},
)
@pytest.mark.asyncio
async def test_disable_keycloak_user_without_email(token_manager):
"""Test disabling Keycloak user without providing email."""
# Arrange
user_id = 'test_user_id'
mock_user = {
'id': user_id,
'username': 'testuser',
'email': 'user@example.com',
'emailVerified': False,
}
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
mock_admin.a_update_user = AsyncMock()
mock_get_admin.return_value = mock_admin
# Act
await token_manager.disable_keycloak_user(user_id)
# Assert
mock_admin.a_get_user.assert_called_once_with(user_id)
mock_admin.a_update_user.assert_called_once()
@pytest.mark.asyncio
async def test_disable_keycloak_user_not_found(token_manager):
"""Test disabling Keycloak user when user is not found."""
# Arrange
user_id = 'nonexistent_user_id'
email = 'user@colsch.us'
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_user = AsyncMock(return_value=None)
mock_get_admin.return_value = mock_admin
# Act
await token_manager.disable_keycloak_user(user_id, email)
# Assert
mock_admin.a_get_user.assert_called_once_with(user_id)
mock_admin.a_update_user.assert_not_called()
@pytest.mark.asyncio
async def test_disable_keycloak_user_exception_handling(token_manager):
"""Test that disable_keycloak_user handles exceptions gracefully without raising."""
# Arrange
user_id = 'test_user_id'
email = 'user@colsch.us'
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_user = AsyncMock(side_effect=Exception('Connection error'))
mock_get_admin.return_value = mock_admin
# Act & Assert - should not raise exception
await token_manager.disable_keycloak_user(user_id, email)
# Verify the method was called
mock_admin.a_get_user.assert_called_once_with(user_id)

146
frontend/__tests__/MSW.md Normal file
View File

@@ -0,0 +1,146 @@
# Mock Service Worker (MSW) Guide
## Overview
[Mock Service Worker (MSW)](https://mswjs.io/) is an API mocking library that intercepts outgoing network requests at the network level. Unlike traditional mocking that patches `fetch` or `axios`, MSW uses a Service Worker in the browser and direct request interception in Node.js—making mocks transparent to your application code.
We use MSW in this project for:
- **Testing**: Write reliable unit and integration tests without real network calls
- **Development**: Run the frontend with mocked APIs when the backend isn't available or when working on features with pending backend APIs
The same mock handlers work in both environments, so you write them once and reuse everywhere.
## Relevant Files
- `src/mocks/handlers.ts` - Main handler registry that combines all domain handlers
- `src/mocks/*-handlers.ts` - Domain-specific handlers (auth, billing, conversation, etc.)
- `src/mocks/browser.ts` - Browser setup for development mode
- `src/mocks/node.ts` - Node.js setup for tests
- `vitest.setup.ts` - Global test setup with MSW lifecycle hooks
## Development Workflow
### Running with Mocked APIs
```sh
# Run with API mocking enabled
npm run dev:mock
# Run with API mocking + SaaS mode simulation
npm run dev:mock:saas
```
These commands set `VITE_MOCK_API=true` which activates the MSW Service Worker to intercept requests.
> [!NOTE]
> **OSS vs SaaS Mode**
>
> OpenHands runs in two modes:
> - **OSS mode**: For local/self-hosted deployments where users provide their own LLM API keys and configure git providers manually
> - **SaaS mode**: For the cloud offering with billing, managed API keys, and OAuth-based GitHub integration
>
> Use `dev:mock:saas` when working on SaaS-specific features like billing, API key management, or subscription flows.
## Writing Tests
### Service Layer Mocking (Recommended)
For most tests, mock at the service layer using `vi.spyOn`. This approach is explicit, test-scoped, and makes the scenario being tested clear.
```typescript
import { vi } from "vitest";
import SettingsService from "#/api/settings-service/settings-service.api";
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
getSettingsSpy.mockResolvedValue({
llm_model: "openai/gpt-4o",
llm_api_key_set: true,
// ... other settings
});
```
Use `mockResolvedValue` for success scenarios and `mockRejectedValue` for error scenarios:
```typescript
getSettingsSpy.mockRejectedValue(new Error("Failed to fetch settings"));
```
### Network Layer Mocking (Advanced)
For tests that need actual network-level behavior (WebSockets, testing retry logic, etc.), use `server.use()` to override handlers per test.
> [!IMPORTANT]
> **Reuse the global server instance** - Don't create new `setupServer()` calls in individual tests. The project already has a global MSW server configured in `vitest.setup.ts` that handles lifecycle (`server.listen()`, `server.resetHandlers()`, `server.close()`). Use `server.use()` to add runtime handlers for specific test scenarios.
```typescript
import { http, HttpResponse } from "msw";
import { server } from "#/mocks/node";
it("should handle server errors", async () => {
server.use(
http.get("/api/my-endpoint", () => {
return new HttpResponse(null, { status: 500 });
}),
);
// ... test code
});
```
For WebSocket testing, see `__tests__/helpers/msw-websocket-setup.ts` for utilities.
## Adding New API Mocks
When adding new API endpoints, create mocks in both places to maintain 1:1 similarity with the backend:
### 1. Add to `src/mocks/` (for development)
Create or update a domain-specific handler file:
```typescript
// src/mocks/my-feature-handlers.ts
import { http, HttpResponse } from "msw";
export const MY_FEATURE_HANDLERS = [
http.get("/api/my-feature", () => {
return HttpResponse.json({
data: "mock response",
});
}),
];
```
Register in `handlers.ts`:
```typescript
import { MY_FEATURE_HANDLERS } from "./my-feature-handlers";
export const handlers = [
// ... existing handlers
...MY_FEATURE_HANDLERS,
];
```
### 2. Mock in tests for specific scenarios
In your test files, spy on the service method to control responses per test case:
```typescript
import { vi } from "vitest";
import MyFeatureService from "#/api/my-feature-service.api";
const spy = vi.spyOn(MyFeatureService, "getData");
spy.mockResolvedValue({ data: "test-specific response" });
```
See `__tests__/routes/llm-settings.test.tsx` for a real-world example of service layer mocking.
> [!TIP]
> For guidance on creating service APIs, see `src/api/README.md`.
## Best Practices
- **Keep mocks close to real API contracts** - Update mocks when backend changes
- **Use service layer mocking for most tests** - It's simpler and more explicit
- **Reserve network layer mocking for integration tests** - WebSockets, retry logic, etc.
- **Export mock data from handler files** - Reuse in tests (e.g., `MOCK_DEFAULT_USER_SETTINGS`)

View File

@@ -0,0 +1,18 @@
import { test, expect, vi } from "vitest";
import axios from "axios";
import V1GitService from "../../src/api/git-service/v1-git-service.api";
vi.mock("axios");
test("getGitChanges throws when response is not an array (dead runtime returns HTML)", async () => {
const htmlResponse = "<!DOCTYPE html><html>...</html>";
vi.mocked(axios.get).mockResolvedValue({ data: htmlResponse });
await expect(
V1GitService.getGitChanges(
"http://localhost:3000/api/conversations/123",
"test-api-key",
"/workspace",
),
).rejects.toThrow("Invalid response from runtime");
});

View File

@@ -1,6 +1,7 @@
import { render, screen } from "@testing-library/react";
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
import userEvent from "@testing-library/user-event";
import { MemoryRouter } from "react-router";
import { AuthModal } from "#/components/features/waitlist/auth-modal";
// Mock the useAuthUrl hook
@@ -27,11 +28,13 @@ describe("AuthModal", () => {
it("should render the GitHub and GitLab buttons", () => {
render(
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
providersConfigured={["github", "gitlab"]}
/>,
<MemoryRouter>
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
providersConfigured={["github", "gitlab"]}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
@@ -49,11 +52,13 @@ describe("AuthModal", () => {
const user = userEvent.setup();
const mockUrl = "https://github.com/login/oauth/authorize";
render(
<AuthModal
githubAuthUrl={mockUrl}
appMode="saas"
providersConfigured={["github"]}
/>,
<MemoryRouter>
<AuthModal
githubAuthUrl={mockUrl}
appMode="saas"
providersConfigured={["github"]}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
@@ -65,10 +70,14 @@ describe("AuthModal", () => {
});
it("should render Terms of Service and Privacy Policy text with correct links", () => {
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
render(
<MemoryRouter>
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
</MemoryRouter>,
);
// Find the terms of service section using data-testid
const termsSection = screen.getByTestId("auth-modal-terms-of-service");
const termsSection = screen.getByTestId("terms-and-privacy-notice");
expect(termsSection).toBeInTheDocument();
// Check that all text content is present in the paragraph
@@ -105,8 +114,44 @@ describe("AuthModal", () => {
expect(termsSection).toContainElement(privacyLink);
});
it("should display email verified message when emailVerified prop is true", () => {
render(
<MemoryRouter>
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
emailVerified={true}
/>
</MemoryRouter>,
);
expect(
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
).toBeInTheDocument();
});
it("should not display email verified message when emailVerified prop is false", () => {
render(
<MemoryRouter>
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
emailVerified={false}
/>
</MemoryRouter>,
);
expect(
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
).not.toBeInTheDocument();
});
it("should open Terms of Service link in new tab", () => {
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
render(
<MemoryRouter>
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
</MemoryRouter>,
);
const tosLink = screen.getByRole("link", {
name: "COMMON$TERMS_OF_SERVICE",
@@ -115,11 +160,58 @@ describe("AuthModal", () => {
});
it("should open Privacy Policy link in new tab", () => {
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
render(
<MemoryRouter>
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
</MemoryRouter>,
);
const privacyLink = screen.getByRole("link", {
name: "COMMON$PRIVACY_POLICY",
});
expect(privacyLink).toHaveAttribute("target", "_blank");
});
describe("Duplicate email error message", () => {
const renderAuthModalWithRouter = (initialEntries: string[]) => {
const hasDuplicatedEmail = initialEntries.includes(
"/?duplicated_email=true",
);
return render(
<MemoryRouter initialEntries={initialEntries}>
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
providersConfigured={["github"]}
hasDuplicatedEmail={hasDuplicatedEmail}
/>
</MemoryRouter>,
);
};
it("should display error message when duplicated_email query parameter is true", () => {
// Arrange
const initialEntries = ["/?duplicated_email=true"];
// Act
renderAuthModalWithRouter(initialEntries);
// Assert
const errorMessage = screen.getByText("AUTH$DUPLICATE_EMAIL_ERROR");
expect(errorMessage).toBeInTheDocument();
});
it("should not display error message when duplicated_email query parameter is missing", () => {
// Arrange
const initialEntries = ["/"];
// Act
renderAuthModalWithRouter(initialEntries);
// Assert
const errorMessage = screen.queryByText("AUTH$DUPLICATE_EMAIL_ERROR");
expect(errorMessage).not.toBeInTheDocument();
});
});
});

View File

@@ -23,6 +23,11 @@ describe("ConversationPanel", () => {
Component: () => <ConversationPanel onClose={onCloseMock} />,
path: "/",
},
{
// Add route to prevent "No routes matched location" warning
Component: () => null,
path: "/conversations/:conversationId",
},
]);
const renderConversationPanel = () => renderWithProviders(<RouterStub />);

View File

@@ -48,9 +48,12 @@ describe("MaintenanceBanner", () => {
expect(button).toBeInTheDocument();
});
// maintenance-banner
it("handles invalid date gracefully", () => {
// Suppress expected console.warn for invalid date parsing
const consoleWarnSpy = vi
.spyOn(console, "warn")
.mockImplementation(() => {});
const invalidTime = "invalid-date";
render(
@@ -62,6 +65,9 @@ describe("MaintenanceBanner", () => {
// Check if the banner is rendered
const banner = screen.queryByTestId("maintenance-banner");
expect(banner).not.toBeInTheDocument();
// Restore console.warn
consoleWarnSpy.mockRestore();
});
it("click on dismiss button removes banner", () => {

View File

@@ -0,0 +1,187 @@
import React from "react";
import { screen, waitFor } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { MemoryRouter } from "react-router";
import { emailService } from "#/api/email-service/email-service.api";
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
import * as ToastHandlers from "#/utils/custom-toast-handlers";
import { renderWithProviders, createAxiosError } from "../../../../test-utils";
describe("EmailVerificationModal", () => {
const mockOnClose = vi.fn();
const resendEmailVerificationSpy = vi.spyOn(
emailService,
"resendEmailVerification",
);
const displaySuccessToastSpy = vi.spyOn(ToastHandlers, "displaySuccessToast");
const displayErrorToastSpy = vi.spyOn(ToastHandlers, "displayErrorToast");
const renderWithRouter = (ui: React.ReactElement) => {
return renderWithProviders(<MemoryRouter>{ui}</MemoryRouter>);
};
beforeEach(() => {
vi.clearAllMocks();
});
it("should render the email verification message", () => {
// Arrange & Act
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
// Assert
expect(
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
).toBeInTheDocument();
});
it("should render the TermsAndPrivacyNotice component", () => {
// Arrange & Act
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
// Assert
const termsSection = screen.getByTestId("terms-and-privacy-notice");
expect(termsSection).toBeInTheDocument();
});
it("should render resend verification button", () => {
// Arrange & Act
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
// Assert
expect(
screen.getByText("SETTINGS$RESEND_VERIFICATION"),
).toBeInTheDocument();
});
it("should call resendEmailVerification when the button is clicked", async () => {
// Arrange
const userId = "test_user_id";
resendEmailVerificationSpy.mockResolvedValue({
message: "Email verification message sent",
});
renderWithRouter(
<EmailVerificationModal onClose={mockOnClose} userId={userId} />,
);
// Act
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
await userEvent.click(resendButton);
// Assert
await waitFor(() => {
expect(resendEmailVerificationSpy).toHaveBeenCalledWith({
userId,
isAuthFlow: true,
});
});
});
it("should display success toast when resend succeeds", async () => {
// Arrange
resendEmailVerificationSpy.mockResolvedValue({
message: "Email verification message sent",
});
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
// Act
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
await userEvent.click(resendButton);
// Assert
await waitFor(() => {
expect(displaySuccessToastSpy).toHaveBeenCalledWith(
"SETTINGS$VERIFICATION_EMAIL_SENT",
);
});
});
it("should display rate limit error message when receiving 429 status", async () => {
// Arrange
const rateLimitError = createAxiosError(429, "Too Many Requests", {
detail: "Too many requests. Please wait 2 minutes before trying again.",
});
resendEmailVerificationSpy.mockRejectedValue(rateLimitError);
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
// Act
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
await userEvent.click(resendButton);
// Assert
await waitFor(() => {
expect(displayErrorToastSpy).toHaveBeenCalledWith(
"Too many requests. Please wait 2 minutes before trying again.",
);
});
});
it("should display generic error message when receiving non-429 error", async () => {
// Arrange
const genericError = createAxiosError(500, "Internal Server Error", {
error: "Internal server error",
});
resendEmailVerificationSpy.mockRejectedValue(genericError);
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
// Act
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
await userEvent.click(resendButton);
// Assert
await waitFor(() => {
expect(displayErrorToastSpy).toHaveBeenCalledWith(
"SETTINGS$FAILED_TO_RESEND_VERIFICATION",
);
});
});
it("should disable button and show sending text while request is pending", async () => {
// Arrange
let resolvePromise: (value: { message: string }) => void;
const pendingPromise = new Promise<{ message: string }>((resolve) => {
resolvePromise = resolve;
});
resendEmailVerificationSpy.mockReturnValue(pendingPromise);
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
// Act
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
await userEvent.click(resendButton);
// Assert
await waitFor(() => {
const sendingButton = screen.getByText("SETTINGS$SENDING");
expect(sendingButton).toBeInTheDocument();
expect(sendingButton).toBeDisabled();
});
// Cleanup
resolvePromise!({ message: "Email verification message sent" });
});
it("should re-enable button after request completes", async () => {
// Arrange
resendEmailVerificationSpy.mockResolvedValue({
message: "Email verification message sent",
});
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
// Act
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
await userEvent.click(resendButton);
// Assert
await waitFor(() => {
expect(resendEmailVerificationSpy).toHaveBeenCalled();
});
// After successful send, the button will be disabled due to cooldown
// So we just verify the mutation was called successfully
await waitFor(() => {
const button = screen.getByRole("button");
expect(button).toBeDisabled(); // Button is disabled during cooldown
expect(button).toHaveTextContent(/SETTINGS\$RESEND_VERIFICATION/);
});
});
});

View File

@@ -0,0 +1,48 @@
import { render, screen } from "@testing-library/react";
import { it, describe, expect } from "vitest";
import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice";
describe("TermsAndPrivacyNotice", () => {
it("should render Terms of Service and Privacy Policy links", () => {
// Arrange & Act
render(<TermsAndPrivacyNotice />);
// Assert
const termsSection = screen.getByTestId("terms-and-privacy-notice");
expect(termsSection).toBeInTheDocument();
const tosLink = screen.getByRole("link", {
name: "COMMON$TERMS_OF_SERVICE",
});
const privacyLink = screen.getByRole("link", {
name: "COMMON$PRIVACY_POLICY",
});
expect(tosLink).toBeInTheDocument();
expect(tosLink).toHaveAttribute("href", "https://www.all-hands.dev/tos");
expect(tosLink).toHaveAttribute("target", "_blank");
expect(tosLink).toHaveAttribute("rel", "noopener noreferrer");
expect(privacyLink).toBeInTheDocument();
expect(privacyLink).toHaveAttribute(
"href",
"https://www.all-hands.dev/privacy",
);
expect(privacyLink).toHaveAttribute("target", "_blank");
expect(privacyLink).toHaveAttribute("rel", "noopener noreferrer");
});
it("should render all required text content", () => {
// Arrange & Act
render(<TermsAndPrivacyNotice />);
// Assert
const termsSection = screen.getByTestId("terms-and-privacy-notice");
expect(termsSection).toHaveTextContent(
"AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR",
);
expect(termsSection).toHaveTextContent("COMMON$TERMS_OF_SERVICE");
expect(termsSection).toHaveTextContent("COMMON$AND");
expect(termsSection).toHaveTextContent("COMMON$PRIVACY_POLICY");
});
});

View File

@@ -6,6 +6,7 @@ import {
beforeEach,
afterAll,
afterEach,
vi,
} from "vitest";
import { screen, waitFor, render, cleanup } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
@@ -34,6 +35,7 @@ import {
} from "#/contexts/conversation-websocket-context";
import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
import { useEventStore } from "#/stores/use-event-store";
import { isV1Event } from "#/types/v1/type-guards";
// MSW WebSocket mock setup
const { wsLink, server: mswServer } = conversationWebSocketTestSetup();
@@ -141,6 +143,11 @@ describe("Conversation WebSocket Handler", () => {
});
it("should handle malformed/invalid event data gracefully", async () => {
// Suppress expected console.warn for invalid JSON parsing
const consoleWarnSpy = vi
.spyOn(console, "warn")
.mockImplementation(() => {});
// Set up MSW to send various invalid events when connection is established
mswServer.use(
wsLink.addEventListener("connection", ({ client, server }) => {
@@ -203,6 +210,9 @@ describe("Conversation WebSocket Handler", () => {
"valid-event-123",
);
expect(screen.getByTestId("ui-events-count")).toHaveTextContent("1");
// Restore console.warn
consoleWarnSpy.mockRestore();
});
});
@@ -415,6 +425,101 @@ describe("Conversation WebSocket Handler", () => {
// based on how the actual reconnection logic is implemented
});
it("should not create duplicate events when WebSocket reconnects with resend_all=true", async () => {
const conversationId = "test-conversation-reconnect";
let connectionCount = 0;
// Clear event store before test
useEventStore.getState().clearEvents();
// Create mock events that will be sent on each connection
const mockHistoryEvents = [
createMockUserMessageEvent({ id: "event-1" }),
createMockMessageEvent({ id: "event-2" }),
createMockMessageEvent({ id: "event-3" }),
];
// Set up MSW to mock event count API and WebSocket
// The WebSocket will resend all events on each connection (simulating resend_all=true behavior)
mswServer.use(
http.get(
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
() => HttpResponse.json(3),
),
wsLink.addEventListener("connection", ({ client, server }) => {
connectionCount += 1;
server.connect();
// Send all history events on EVERY connection (simulating resend_all=true)
mockHistoryEvents.forEach((event) => {
client.send(JSON.stringify(event));
});
// On first connection, simulate a disconnect after events are sent
if (connectionCount === 1) {
setTimeout(() => {
client.close(1006, "Simulated disconnect");
}, 100);
}
}),
);
// Render with WebSocket context
renderWithWebSocketContext(
<ConnectionStatusComponent />,
conversationId,
`http://localhost:3000/api/conversations/${conversationId}`,
);
// Wait for initial connection and events
await waitFor(() => {
expect(screen.getByTestId("connection-state")).toHaveTextContent(
"OPEN",
);
});
await waitFor(() => {
expect(useEventStore.getState().events.length).toBe(3);
});
// Wait for disconnect
await waitFor(() => {
expect(screen.getByTestId("connection-state")).toHaveTextContent(
"CLOSED",
);
});
// Wait for reconnection
await waitFor(
() => {
expect(screen.getByTestId("connection-state")).toHaveTextContent(
"OPEN",
);
},
{ timeout: 5000 },
);
// Give time for resent events to be processed
await new Promise((resolve) => {
setTimeout(resolve, 200);
});
// After reconnection, events should NOT be duplicated
// The server sends 3 events again (resend_all=true), but we should deduplicate
const { events } = useEventStore.getState();
const v1Events = events.filter(isV1Event);
const uniqueEventIds = [...new Set(v1Events.map((e) => e.id))];
// This assertion will FAIL with current implementation (showing the bug)
// Expected: 3 events (deduplicated)
// Actual: 6 events (duplicated)
expect(v1Events.length).toBe(3);
expect(uniqueEventIds.length).toBe(3);
// Verify we actually had 2 connections
expect(connectionCount).toBe(2);
});
it.todo("should track and display errors with proper metadata");
it.todo("should set appropriate error states on connection failures");
it.todo(

View File

@@ -1,174 +0,0 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { renderHook, waitFor } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import React from "react";
import {
useGitHubIssuesPRs,
useRefreshGitHubIssuesPRs,
} from "../src/hooks/query/use-github-issues-prs";
import { useShouldShowUserFeatures } from "../src/hooks/use-should-show-user-features";
import GitHubIssuesPRsService from "../src/api/github-service/github-issues-prs.api";
// Mock the dependencies
vi.mock("../src/hooks/use-should-show-user-features");
vi.mock("../src/api/github-service/github-issues-prs.api", () => ({
default: {
getGitHubItems: vi.fn(),
buildItemUrl: vi.fn(),
},
}));
const mockUseShouldShowUserFeatures = vi.mocked(useShouldShowUserFeatures);
const mockGetGitHubItems = vi.mocked(GitHubIssuesPRsService.getGitHubItems);
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
});
return ({ children }: { children: React.ReactNode }) =>
React.createElement(QueryClientProvider, { client: queryClient }, children);
};
describe("useGitHubIssuesPRs", () => {
beforeEach(() => {
vi.clearAllMocks();
localStorage.clear();
mockUseShouldShowUserFeatures.mockReturnValue(false);
});
afterEach(() => {
localStorage.clear();
});
it("should be disabled when useShouldShowUserFeatures returns false", () => {
mockUseShouldShowUserFeatures.mockReturnValue(false);
const { result } = renderHook(() => useGitHubIssuesPRs(), {
wrapper: createWrapper(),
});
expect(result.current.isLoading).toBe(false);
expect(result.current.isFetching).toBe(false);
});
it("should be enabled when useShouldShowUserFeatures returns true", () => {
mockUseShouldShowUserFeatures.mockReturnValue(true);
mockGetGitHubItems.mockResolvedValue({
items: [],
cached_at: new Date().toISOString(),
});
const { result } = renderHook(() => useGitHubIssuesPRs(), {
wrapper: createWrapper(),
});
// When enabled, the query should be loading/fetching
expect(result.current.isLoading).toBe(true);
});
it("should fetch and return GitHub items", async () => {
mockUseShouldShowUserFeatures.mockReturnValue(true);
const mockItems = [
{
git_provider: "github" as const,
item_type: "issue" as const,
status: "OPEN_ISSUE" as const,
repo: "test/repo",
number: 1,
title: "Test Issue",
author: "testuser",
assignees: ["testuser"],
created_at: "2024-01-01T00:00:00Z",
updated_at: "2024-01-01T00:00:00Z",
url: "https://github.com/test/repo/issues/1",
},
];
mockGetGitHubItems.mockResolvedValue({
items: mockItems,
cached_at: new Date().toISOString(),
});
const { result } = renderHook(() => useGitHubIssuesPRs(), {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(result.current.isSuccess).toBe(true);
});
expect(result.current.data?.items).toEqual(mockItems);
});
it("should filter by item type", async () => {
mockUseShouldShowUserFeatures.mockReturnValue(true);
const mockItems = [
{
git_provider: "github" as const,
item_type: "issue" as const,
status: "OPEN_ISSUE" as const,
repo: "test/repo",
number: 1,
title: "Test Issue",
author: "testuser",
assignees: ["testuser"],
created_at: "2024-01-01T00:00:00Z",
updated_at: "2024-01-01T00:00:00Z",
url: "https://github.com/test/repo/issues/1",
},
];
mockGetGitHubItems.mockResolvedValue({
items: mockItems,
cached_at: new Date().toISOString(),
});
const { result } = renderHook(
() => useGitHubIssuesPRs({ itemType: "issues" }),
{
wrapper: createWrapper(),
},
);
await waitFor(() => {
expect(result.current.isSuccess).toBe(true);
});
expect(mockGetGitHubItems).toHaveBeenCalledWith({ itemType: "issues" });
});
});
describe("useRefreshGitHubIssuesPRs", () => {
beforeEach(() => {
vi.clearAllMocks();
localStorage.clear();
});
afterEach(() => {
localStorage.clear();
});
it("should clear localStorage cache when called", () => {
// Set up some cached data
localStorage.setItem(
"github-issues-prs-cache",
JSON.stringify({
data: { items: [], cached_at: new Date().toISOString() },
timestamp: Date.now(),
}),
);
const { result } = renderHook(() => useRefreshGitHubIssuesPRs(), {
wrapper: createWrapper(),
});
// Call the refresh function
result.current();
// Check that localStorage was cleared
expect(localStorage.getItem("github-issues-prs-cache")).toBeNull();
});
});

View File

@@ -34,7 +34,11 @@ describe("useWebSocket", () => {
}),
);
beforeAll(() => mswServer.listen());
beforeAll(() =>
mswServer.listen({
onUnhandledRequest: "warn",
}),
);
afterEach(() => mswServer.resetHandlers());
afterAll(() => mswServer.close());

View File

@@ -0,0 +1,227 @@
# Testing with React Router
## Overview
React Router components and hooks require a routing context to function. In tests, we need to provide this context while maintaining control over the routing state.
This guide covers the two main approaches used in the OpenHands frontend:
1. **`createRoutesStub`** - Creates a complete route structure for testing components with their actual route configuration, loaders, and nested routes.
2. **`MemoryRouter`** - Provides a minimal routing context for components that just need router hooks to work.
Choose your approach based on what your component actually needs from the router.
## When to Use Each Approach
### `createRoutesStub` (Recommended)
Use `createRoutesStub` when your component:
- Relies on route parameters (`useParams`)
- Uses loader data (`useLoaderData`) or `clientLoader`
- Has nested routes or uses `<Outlet />`
- Needs to test navigation between routes
> [!NOTE]
> `createRoutesStub` is intended for unit testing **reusable components** that depend on router context. For testing full route/page components, consider E2E tests (Playwright, Cypress) instead.
```typescript
import { createRoutesStub } from "react-router";
import { render } from "@testing-library/react";
const RouterStub = createRoutesStub([
{
Component: MyRouteComponent,
path: "/conversations/:conversationId",
},
]);
render(<RouterStub initialEntries={["/conversations/123"]} />);
```
**With nested routes and loaders:**
```typescript
const RouterStub = createRoutesStub([
{
Component: SettingsScreen,
clientLoader,
path: "/settings",
children: [
{
Component: () => <div data-testid="llm-settings" />,
path: "/settings",
},
{
Component: () => <div data-testid="git-settings" />,
path: "/settings/integrations",
},
],
},
]);
render(<RouterStub initialEntries={["/settings/integrations"]} />);
```
> [!TIP]
> When using `clientLoader` from a Route module, you may encounter type mismatches. Use `@ts-expect-error` as a workaround:
```typescript
import { clientLoader } from "@/routes/settings";
const RouterStub = createRoutesStub([
{
path: "/settings",
Component: SettingsScreen,
// @ts-expect-error: loader types won't align between test and app code
loader: clientLoader,
},
]);
```
### `MemoryRouter`
Use `MemoryRouter` when your component:
- Only needs basic routing context to render
- Uses `<Link>` components but you don't need to test navigation
- Doesn't depend on specific route parameters or loaders
```typescript
import { MemoryRouter } from "react-router";
import { render } from "@testing-library/react";
render(
<MemoryRouter>
<MyComponent />
</MemoryRouter>
);
```
**With initial route:**
```typescript
render(
<MemoryRouter initialEntries={["/some/path"]}>
<MyComponent />
</MemoryRouter>
);
```
## Anti-patterns to Avoid
### Using `BrowserRouter` in tests
`BrowserRouter` interacts with the actual browser history API, which can cause issues in test environments:
```typescript
// ❌ Avoid
render(
<BrowserRouter>
<MyComponent />
</BrowserRouter>
);
// ✅ Use MemoryRouter instead
render(
<MemoryRouter>
<MyComponent />
</MemoryRouter>
);
```
### Mocking router hooks when `createRoutesStub` would work
Mocking hooks like `useParams` directly can be brittle and doesn't test the actual routing behavior:
```typescript
// ❌ Avoid when possible
vi.mock("react-router", async () => {
const actual = await vi.importActual("react-router");
return {
...actual,
useParams: () => ({ conversationId: "123" }),
};
});
// ✅ Prefer createRoutesStub - tests real routing behavior
const RouterStub = createRoutesStub([
{
Component: MyComponent,
path: "/conversations/:conversationId",
},
]);
render(<RouterStub initialEntries={["/conversations/123"]} />);
```
## Common Patterns
### Combining with `QueryClientProvider`
Many components need both routing and TanStack Query context:
```typescript
import { createRoutesStub } from "react-router";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
});
const RouterStub = createRoutesStub([
{
Component: MyComponent,
path: "/",
},
]);
render(<RouterStub />, {
wrapper: ({ children }) => (
<QueryClientProvider client={queryClient}>
{children}
</QueryClientProvider>
),
});
```
### Testing navigation behavior
Verify that user interactions trigger the expected navigation:
```typescript
import { createRoutesStub } from "react-router";
import { screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
const RouterStub = createRoutesStub([
{
Component: HomeScreen,
path: "/",
},
{
Component: () => <div data-testid="settings-screen" />,
path: "/settings",
},
]);
render(<RouterStub initialEntries={["/"]} />);
const user = userEvent.setup();
await user.click(screen.getByRole("link", { name: /settings/i }));
expect(screen.getByTestId("settings-screen")).toBeInTheDocument();
```
## See Also
### Codebase Examples
- [settings.test.tsx](__tests__/routes/settings.test.tsx) - `createRoutesStub` with nested routes and loaders
- [home-screen.test.tsx](__tests__/routes/home-screen.test.tsx) - `createRoutesStub` with navigation testing
- [chat-interface.test.tsx](__tests__/components/chat/chat-interface.test.tsx) - `MemoryRouter` usage
### Official Documentation
- [React Router Testing Guide](https://reactrouter.com/start/framework/testing) - Official guide on testing with `createRoutesStub`
- [MemoryRouter API](https://reactrouter.com/api/declarative-routers/MemoryRouter) - API reference for `MemoryRouter`

View File

@@ -298,6 +298,7 @@ describe("Form submission", () => {
gitlab: { token: "", host: "" },
bitbucket: { token: "", host: "" },
azure_devops: { token: "", host: "" },
forgejo: { token: "", host: "" },
});
});
@@ -320,6 +321,7 @@ describe("Form submission", () => {
gitlab: { token: "test-token", host: "" },
bitbucket: { token: "", host: "" },
azure_devops: { token: "", host: "" },
forgejo: { token: "", host: "" },
});
});
@@ -342,6 +344,7 @@ describe("Form submission", () => {
gitlab: { token: "", host: "" },
bitbucket: { token: "test-token", host: "" },
azure_devops: { token: "", host: "" },
forgejo: { token: "", host: "" },
});
});
@@ -364,6 +367,7 @@ describe("Form submission", () => {
gitlab: { token: "", host: "" },
bitbucket: { token: "", host: "" },
azure_devops: { token: "test-token", host: "" },
forgejo: { token: "", host: "" },
});
});

View File

@@ -1,316 +0,0 @@
import { render, screen, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
import userEvent from "@testing-library/user-event";
import { MemoryRouter, Routes, Route } from "react-router";
import GitHubIssuesPRsPage from "#/routes/github-issues-prs";
import GitHubIssuesPRsService from "#/api/github-service/github-issues-prs.api";
import ConversationService from "#/api/conversation-service/conversation-service.api";
import { useShouldShowUserFeatures } from "#/hooks/use-should-show-user-features";
import { useUserProviders } from "#/hooks/use-user-providers";
// Mock the services
vi.mock("#/api/github-service/github-issues-prs.api", () => ({
default: {
getGitHubItems: vi.fn(),
buildItemUrl: vi.fn((provider, repo, number, type) => {
if (provider === "github") {
return `https://github.com/${repo}/${type === "issue" ? "issues" : "pull"}/${number}`;
}
return "";
}),
},
}));
vi.mock("#/api/conversation-service/conversation-service.api", () => ({
default: {
searchConversations: vi.fn(),
createConversation: vi.fn(),
},
}));
vi.mock("#/hooks/use-should-show-user-features", () => ({
useShouldShowUserFeatures: vi.fn(),
}));
vi.mock("#/hooks/use-user-providers", () => ({
useUserProviders: vi.fn(),
}));
// Mock react-i18next to return the key as the translation
vi.mock("react-i18next", () => ({
useTranslation: () => ({
t: (key: string) => key,
i18n: { language: "en" },
}),
}));
const mockGetGitHubItems = vi.mocked(GitHubIssuesPRsService.getGitHubItems);
const mockSearchConversations = vi.mocked(
ConversationService.searchConversations,
);
const mockUseShouldShowUserFeatures = vi.mocked(useShouldShowUserFeatures);
const mockUseUserProviders = vi.mocked(useUserProviders);
const renderGitHubIssuesPRsPage = () =>
render(
<QueryClientProvider client={new QueryClient()}>
<MemoryRouter initialEntries={["/github-issues-prs"]}>
<Routes>
<Route path="/github-issues-prs" element={<GitHubIssuesPRsPage />} />
<Route
path="/conversations/:conversationId"
element={<div data-testid="conversation-screen" />}
/>
</Routes>
</MemoryRouter>
</QueryClientProvider>,
);
const MOCK_GITHUB_ITEMS = [
{
git_provider: "github" as const,
item_type: "issue" as const,
status: "OPEN_ISSUE" as const,
repo: "test/repo",
number: 1,
title: "Test Issue",
author: "testuser",
assignees: ["testuser"],
created_at: "2024-01-01T00:00:00Z",
updated_at: "2024-01-01T00:00:00Z",
url: "https://github.com/test/repo/issues/1",
},
{
git_provider: "github" as const,
item_type: "pr" as const,
status: "OPEN_PR" as const,
repo: "test/repo",
number: 2,
title: "Test PR",
author: "testuser",
assignees: [],
created_at: "2024-01-01T00:00:00Z",
updated_at: "2024-01-01T00:00:00Z",
url: "https://github.com/test/repo/pull/2",
},
];
describe("GitHubIssuesPRsPage", () => {
beforeEach(() => {
vi.clearAllMocks();
localStorage.clear();
mockUseShouldShowUserFeatures.mockReturnValue(true);
mockUseUserProviders.mockReturnValue({
providers: ["github"],
isLoadingSettings: false,
});
mockGetGitHubItems.mockResolvedValue({
items: MOCK_GITHUB_ITEMS,
cached_at: new Date().toISOString(),
});
mockSearchConversations.mockResolvedValue([]);
});
it("should render the page title", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(screen.getByText("GITHUB_ISSUES_PRS$TITLE")).toBeInTheDocument();
});
});
it("should render the view type selector", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
// The view label has a colon after it
expect(screen.getByText("GITHUB_ISSUES_PRS$VIEW:")).toBeInTheDocument();
expect(screen.getByRole("combobox")).toBeInTheDocument();
});
});
it("should render filter checkboxes", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(
screen.getByText("GITHUB_ISSUES_PRS$ASSIGNED_TO_ME"),
).toBeInTheDocument();
expect(
screen.getByText("GITHUB_ISSUES_PRS$AUTHORED_BY_ME"),
).toBeInTheDocument();
});
});
it("should render the refresh button", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(screen.getByText("GITHUB_ISSUES_PRS$REFRESH")).toBeInTheDocument();
});
});
it("should display GitHub items when loaded", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(screen.getByText("Test Issue")).toBeInTheDocument();
expect(screen.getByText("Test PR")).toBeInTheDocument();
});
});
it("should display item status badges", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(
screen.getByText("GITHUB_ISSUES_PRS$OPEN_ISSUE"),
).toBeInTheDocument();
expect(screen.getByText("GITHUB_ISSUES_PRS$OPEN_PR")).toBeInTheDocument();
});
});
it("should display Start Session buttons for items without related conversations", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
const startButtons = screen.getAllByText(
"GITHUB_ISSUES_PRS$START_SESSION",
);
expect(startButtons.length).toBe(2);
});
});
it("should filter items when view type is changed to issues", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(screen.getByText("Test Issue")).toBeInTheDocument();
expect(screen.getByText("Test PR")).toBeInTheDocument();
});
// Change view type to issues
const select = screen.getByRole("combobox");
await userEvent.selectOptions(select, "issues");
await waitFor(() => {
expect(screen.getByText("Test Issue")).toBeInTheDocument();
expect(screen.queryByText("Test PR")).not.toBeInTheDocument();
});
});
it("should filter items when view type is changed to PRs", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(screen.getByText("Test Issue")).toBeInTheDocument();
expect(screen.getByText("Test PR")).toBeInTheDocument();
});
// Change view type to PRs
const select = screen.getByRole("combobox");
await userEvent.selectOptions(select, "prs");
await waitFor(() => {
expect(screen.queryByText("Test Issue")).not.toBeInTheDocument();
expect(screen.getByText("Test PR")).toBeInTheDocument();
});
});
it("should display Resume Session button when a related conversation exists", async () => {
mockSearchConversations.mockResolvedValue([
{
conversation_id: "conv-1",
title: "Working on #1",
selected_repository: "test/repo",
selected_branch: "main",
git_provider: "github",
pr_number: [1],
created_at: "2024-01-01T00:00:00Z",
last_updated_at: "2024-01-01T00:00:00Z",
status: "RUNNING",
runtime_status: null,
url: null,
session_api_key: null,
},
]);
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(
screen.getByText("GITHUB_ISSUES_PRS$RESUME_SESSION"),
).toBeInTheDocument();
});
});
it("should show empty state when no items are found", async () => {
mockGetGitHubItems.mockResolvedValue({
items: [],
cached_at: new Date().toISOString(),
});
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(
screen.getByText("GITHUB_ISSUES_PRS$NO_ITEMS"),
).toBeInTheDocument();
});
});
it("should display View on GitHub links", async () => {
renderGitHubIssuesPRsPage();
await waitFor(() => {
const viewLinks = screen.getAllByText("GITHUB_ISSUES_PRS$VIEW_ON_GITHUB");
expect(viewLinks.length).toBe(2);
});
});
it("should show loading state while checking settings", async () => {
mockUseUserProviders.mockReturnValue({
providers: [],
isLoadingSettings: true,
});
renderGitHubIssuesPRsPage();
// Should show loading spinner
expect(screen.getByTestId("loading-spinner")).toBeInTheDocument();
});
it("should show no-token message when GitHub token is not configured", async () => {
mockUseUserProviders.mockReturnValue({
providers: [],
isLoadingSettings: false,
});
renderGitHubIssuesPRsPage();
await waitFor(() => {
expect(screen.getByText("GITHUB_ISSUES_PRS$NO_TOKEN")).toBeInTheDocument();
expect(
screen.getByText("GITHUB_ISSUES_PRS$CONFIGURE_TOKEN"),
).toBeInTheDocument();
});
});
it("should have a link to git settings when no token is configured", async () => {
mockUseUserProviders.mockReturnValue({
providers: [],
isLoadingSettings: false,
});
renderGitHubIssuesPRsPage();
await waitFor(() => {
const configureLink = screen.getByText("GITHUB_ISSUES_PRS$CONFIGURE_TOKEN");
expect(configureLink).toHaveAttribute("href", "/settings/git");
});
});
});

View File

@@ -910,6 +910,162 @@ describe("Form submission", () => {
});
});
describe("View persistence after saving advanced settings", () => {
it("should remain on Advanced view after saving when memory condenser is disabled", async () => {
// Arrange: Start with default settings (basic view)
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
getSettingsSpy.mockResolvedValue({
...MOCK_DEFAULT_USER_SETTINGS,
});
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
saveSettingsSpy.mockResolvedValue(true);
renderLlmSettingsScreen();
await screen.findByTestId("llm-settings-screen");
// Verify we start in basic view
expect(screen.getByTestId("llm-settings-form-basic")).toBeInTheDocument();
// Act: User manually switches to Advanced view
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
await userEvent.click(advancedSwitch);
await screen.findByTestId("llm-settings-form-advanced");
// User disables memory condenser (advanced-only setting)
const condenserSwitch = screen.getByTestId(
"enable-memory-condenser-switch",
);
expect(condenserSwitch).toBeChecked();
await userEvent.click(condenserSwitch);
expect(condenserSwitch).not.toBeChecked();
// Mock the updated settings that will be returned after save
getSettingsSpy.mockResolvedValue({
...MOCK_DEFAULT_USER_SETTINGS,
enable_default_condenser: false, // Now disabled
});
// User saves settings
const submitButton = screen.getByTestId("submit-button");
await userEvent.click(submitButton);
// Assert: View should remain on Advanced after save
await waitFor(() => {
expect(
screen.getByTestId("llm-settings-form-advanced"),
).toBeInTheDocument();
expect(
screen.queryByTestId("llm-settings-form-basic"),
).not.toBeInTheDocument();
expect(advancedSwitch).toBeChecked();
});
});
it("should remain on Advanced view after saving when condenser max size is customized", async () => {
// Arrange: Start with default settings
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
getSettingsSpy.mockResolvedValue({
...MOCK_DEFAULT_USER_SETTINGS,
});
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
saveSettingsSpy.mockResolvedValue(true);
renderLlmSettingsScreen();
await screen.findByTestId("llm-settings-screen");
// Act: User manually switches to Advanced view
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
await userEvent.click(advancedSwitch);
await screen.findByTestId("llm-settings-form-advanced");
// User sets custom condenser max size (advanced-only setting)
const condenserMaxSizeInput = screen.getByTestId(
"condenser-max-size-input",
);
await userEvent.clear(condenserMaxSizeInput);
await userEvent.type(condenserMaxSizeInput, "200");
// Mock the updated settings that will be returned after save
getSettingsSpy.mockResolvedValue({
...MOCK_DEFAULT_USER_SETTINGS,
condenser_max_size: 200, // Custom value
});
// User saves settings
const submitButton = screen.getByTestId("submit-button");
await userEvent.click(submitButton);
// Assert: View should remain on Advanced after save
await waitFor(() => {
expect(
screen.getByTestId("llm-settings-form-advanced"),
).toBeInTheDocument();
expect(
screen.queryByTestId("llm-settings-form-basic"),
).not.toBeInTheDocument();
expect(advancedSwitch).toBeChecked();
});
});
it("should remain on Advanced view after saving when search API key is set", async () => {
// Arrange: Start with default settings (non-SaaS mode to show search API key field)
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
getConfigSpy.mockResolvedValue({
APP_MODE: "oss",
GITHUB_CLIENT_ID: "fake-github-client-id",
POSTHOG_CLIENT_KEY: "fake-posthog-client-key",
FEATURE_FLAGS: {
ENABLE_BILLING: false,
HIDE_LLM_SETTINGS: false,
ENABLE_JIRA: false,
ENABLE_JIRA_DC: false,
ENABLE_LINEAR: false,
},
});
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
getSettingsSpy.mockResolvedValue({
...MOCK_DEFAULT_USER_SETTINGS,
search_api_key: "", // Default empty value
});
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
saveSettingsSpy.mockResolvedValue(true);
renderLlmSettingsScreen();
await screen.findByTestId("llm-settings-screen");
// Act: User manually switches to Advanced view
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
await userEvent.click(advancedSwitch);
await screen.findByTestId("llm-settings-form-advanced");
// User sets search API key (advanced-only setting)
const searchApiKeyInput = screen.getByTestId("search-api-key-input");
await userEvent.type(searchApiKeyInput, "test-search-api-key");
// Mock the updated settings that will be returned after save
getSettingsSpy.mockResolvedValue({
...MOCK_DEFAULT_USER_SETTINGS,
search_api_key: "test-search-api-key", // Now set
});
// User saves settings
const submitButton = screen.getByTestId("submit-button");
await userEvent.click(submitButton);
// Assert: View should remain on Advanced after save
await waitFor(() => {
expect(
screen.getByTestId("llm-settings-form-advanced"),
).toBeInTheDocument();
expect(
screen.queryByTestId("llm-settings-form-basic"),
).not.toBeInTheDocument();
expect(advancedSwitch).toBeChecked();
});
});
});
describe("Status toasts", () => {
describe("Basic form", () => {
it("should call displaySuccessToast when the settings are saved", async () => {

View File

@@ -0,0 +1,242 @@
import { render, screen, waitFor } from "@testing-library/react";
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { createRoutesStub } from "react-router";
import MainApp from "#/routes/root-layout";
import OptionService from "#/api/option-service/option-service.api";
import AuthService from "#/api/auth-service/auth-service.api";
import SettingsService from "#/api/settings-service/settings-service.api";
// Mock other hooks that are not the focus of these tests
vi.mock("#/hooks/use-github-auth-url", () => ({
useGitHubAuthUrl: () => "https://github.com/oauth/authorize",
}));
vi.mock("#/hooks/use-is-on-tos-page", () => ({
useIsOnTosPage: () => false,
}));
vi.mock("#/hooks/use-auto-login", () => ({
useAutoLogin: () => {},
}));
vi.mock("#/hooks/use-auth-callback", () => ({
useAuthCallback: () => {},
}));
vi.mock("#/hooks/use-migrate-user-consent", () => ({
useMigrateUserConsent: () => ({
migrateUserConsent: vi.fn(),
}),
}));
vi.mock("#/hooks/use-reo-tracking", () => ({
useReoTracking: () => {},
}));
vi.mock("#/hooks/use-sync-posthog-consent", () => ({
useSyncPostHogConsent: () => {},
}));
vi.mock("#/utils/custom-toast-handlers", () => ({
displaySuccessToast: vi.fn(),
}));
const RouterStub = createRoutesStub([
{
Component: MainApp,
path: "/",
children: [
{
Component: () => <div data-testid="outlet-content">Content</div>,
path: "/",
},
],
},
]);
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
});
return ({ children }: { children: React.ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
};
describe("MainApp - Email Verification Flow", () => {
beforeEach(() => {
vi.clearAllMocks();
// Default mocks for services
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
APP_MODE: "saas",
GITHUB_CLIENT_ID: "test-client-id",
POSTHOG_CLIENT_KEY: "test-posthog-key",
PROVIDERS_CONFIGURED: ["github"],
AUTH_URL: "https://auth.example.com",
FEATURE_FLAGS: {
ENABLE_BILLING: false,
HIDE_LLM_SETTINGS: false,
ENABLE_JIRA: false,
ENABLE_JIRA_DC: false,
ENABLE_LINEAR: false,
},
});
vi.spyOn(AuthService, "authenticate").mockResolvedValue(true);
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
language: "en",
user_consents_to_analytics: true,
llm_model: "",
llm_base_url: "",
agent: "",
llm_api_key: null,
llm_api_key_set: false,
search_api_key_set: false,
confirmation_mode: false,
security_analyzer: null,
remote_runtime_resource_factor: null,
provider_tokens_set: {},
enable_default_condenser: false,
condenser_max_size: null,
enable_sound_notifications: false,
enable_proactive_conversation_starters: false,
enable_solvability_analysis: false,
max_budget_per_task: null,
});
// Mock localStorage
vi.stubGlobal("localStorage", {
getItem: vi.fn(() => null),
setItem: vi.fn(),
removeItem: vi.fn(),
clear: vi.fn(),
});
});
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
it("should display EmailVerificationModal when email_verification_required=true is in query params", async () => {
// Arrange & Act
render(
<RouterStub initialEntries={["/?email_verification_required=true"]} />,
{ wrapper: createWrapper() },
);
// Assert
await waitFor(() => {
expect(
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
).toBeInTheDocument();
});
});
it("should set emailVerified state and pass to AuthModal when email_verified=true is in query params", async () => {
// Arrange
// Mock a 401 error to simulate unauthenticated user
const axiosError = {
response: { status: 401 },
isAxiosError: true,
};
vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError);
// Act
render(<RouterStub initialEntries={["/?email_verified=true"]} />, {
wrapper: createWrapper(),
});
// Assert - Wait for AuthModal to render (since user is not authenticated)
await waitFor(() => {
expect(
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
).toBeInTheDocument();
});
});
it("should handle both email_verification_required and email_verified params together", async () => {
// Arrange & Act
render(
<RouterStub
initialEntries={[
"/?email_verification_required=true&email_verified=true",
]}
/>,
{ wrapper: createWrapper() },
);
// Assert - EmailVerificationModal should take precedence
await waitFor(() => {
expect(
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
).toBeInTheDocument();
});
});
it("should remove query parameters from URL after processing", async () => {
// Arrange & Act
const { container } = render(
<RouterStub initialEntries={["/?email_verification_required=true"]} />,
{ wrapper: createWrapper() },
);
// Assert - Wait for the modal to appear (which indicates processing happened)
await waitFor(() => {
expect(
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
).toBeInTheDocument();
});
// Verify that the query parameter was processed by checking the modal appeared
// The hook removes the parameter from the URL, so we verify the behavior indirectly
expect(container).toBeInTheDocument();
});
it("should not display EmailVerificationModal when email_verification_required is not in query params", async () => {
// Arrange - No query params set
// Act
render(<RouterStub />, { wrapper: createWrapper() });
// Assert
await waitFor(() => {
expect(
screen.queryByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
).not.toBeInTheDocument();
});
});
it("should not display email verified message when email_verified is not in query params", async () => {
// Arrange
// Mock a 401 error to simulate unauthenticated user
const axiosError = {
response: { status: 401 },
isAxiosError: true,
};
vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError);
// Act
render(<RouterStub />, { wrapper: createWrapper() });
// Assert - AuthModal should render but without email verified message
await waitFor(() => {
const authModal = screen.queryByText(
"AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER",
);
if (authModal) {
expect(
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
).not.toBeInTheDocument();
}
});
});
});

View File

@@ -1,76 +0,0 @@
import { describe, it, expect } from "vitest";
import GitHubIssuesPRsService from "../../src/api/github-service/github-issues-prs.api";
describe("GitHubIssuesPRsService", () => {
describe("buildItemUrl", () => {
it("should build correct GitHub issue URL", () => {
const url = GitHubIssuesPRsService.buildItemUrl(
"github",
"owner/repo",
123,
"issue",
);
expect(url).toBe("https://github.com/owner/repo/issues/123");
});
it("should build correct GitHub PR URL", () => {
const url = GitHubIssuesPRsService.buildItemUrl(
"github",
"owner/repo",
456,
"pr",
);
expect(url).toBe("https://github.com/owner/repo/pull/456");
});
it("should build correct GitLab issue URL", () => {
const url = GitHubIssuesPRsService.buildItemUrl(
"gitlab",
"owner/repo",
123,
"issue",
);
expect(url).toBe("https://gitlab.com/owner/repo/-/issues/123");
});
it("should build correct GitLab MR URL", () => {
const url = GitHubIssuesPRsService.buildItemUrl(
"gitlab",
"owner/repo",
456,
"pr",
);
expect(url).toBe("https://gitlab.com/owner/repo/-/merge_requests/456");
});
it("should build correct Bitbucket issue URL", () => {
const url = GitHubIssuesPRsService.buildItemUrl(
"bitbucket",
"owner/repo",
123,
"issue",
);
expect(url).toBe("https://bitbucket.org/owner/repo/issues/123");
});
it("should build correct Bitbucket PR URL", () => {
const url = GitHubIssuesPRsService.buildItemUrl(
"bitbucket",
"owner/repo",
456,
"pr",
);
expect(url).toBe("https://bitbucket.org/owner/repo/pull-requests/456");
});
it("should return empty string for unknown provider", () => {
const url = GitHubIssuesPRsService.buildItemUrl(
"unknown" as any,
"owner/repo",
123,
"issue",
);
expect(url).toBe("");
});
});
});

View File

@@ -9,7 +9,7 @@ import { useShouldShowUserFeatures } from "../src/hooks/use-should-show-user-fea
vi.mock("../src/hooks/use-should-show-user-features");
vi.mock("#/api/suggestions-service/suggestions-service.api", () => ({
SuggestionsService: {
getSuggestedTasks: vi.fn(),
getSuggestedTasks: vi.fn().mockResolvedValue([]),
},
}));

View File

@@ -29,5 +29,75 @@ describe("hasAdvancedSettingsSet", () => {
}),
).toBe(true);
});
test("enable_default_condenser is disabled", () => {
// Arrange
const settings = {
...DEFAULT_SETTINGS,
enable_default_condenser: false,
};
// Act
const result = hasAdvancedSettingsSet(settings);
// Assert
expect(result).toBe(true);
});
test("condenser_max_size is customized above default", () => {
// Arrange
const settings = {
...DEFAULT_SETTINGS,
condenser_max_size: 200,
};
// Act
const result = hasAdvancedSettingsSet(settings);
// Assert
expect(result).toBe(true);
});
test("condenser_max_size is customized below default", () => {
// Arrange
const settings = {
...DEFAULT_SETTINGS,
condenser_max_size: 50,
};
// Act
const result = hasAdvancedSettingsSet(settings);
// Assert
expect(result).toBe(true);
});
test("search_api_key is set to non-empty value", () => {
// Arrange
const settings = {
...DEFAULT_SETTINGS,
search_api_key: "test-api-key-123",
};
// Act
const result = hasAdvancedSettingsSet(settings);
// Assert
expect(result).toBe(true);
});
test("search_api_key with whitespace is treated as set", () => {
// Arrange
const settings = {
...DEFAULT_SETTINGS,
search_api_key: " test-key ",
};
// Act
const result = hasAdvancedSettingsSet(settings);
// Assert
expect(result).toBe(true);
});
});
});

File diff suppressed because it is too large Load Diff

View File

@@ -1,22 +1,22 @@
{
"name": "openhands-frontend",
"version": "1.0.0",
"version": "1.1.0",
"private": true,
"type": "module",
"engines": {
"node": ">=22.0.0"
},
"dependencies": {
"@heroui/react": "2.8.6",
"@heroui/react": "2.8.7",
"@microlink/react-json-view": "^1.26.2",
"@monaco-editor/react": "^4.7.0-rc.0",
"@react-router/node": "^7.11.0",
"@react-router/serve": "^7.11.0",
"@tailwindcss/vite": "^4.1.18",
"@tanstack/react-query": "^5.90.12",
"@tanstack/react-query": "^5.90.14",
"@uidotdev/usehooks": "^2.4.1",
"@xterm/addon-fit": "^0.10.0",
"@xterm/xterm": "^5.4.0",
"@xterm/addon-fit": "^0.11.0",
"@xterm/xterm": "^6.0.0",
"axios": "^1.13.2",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
@@ -29,7 +29,7 @@
"isbot": "^5.1.32",
"lucide-react": "^0.562.0",
"monaco-editor": "^0.55.1",
"posthog-js": "^1.309.1",
"posthog-js": "^1.310.1",
"react": "^19.2.3",
"react-dom": "^19.2.3",
"react-hot-toast": "^2.6.0",
@@ -41,7 +41,7 @@
"remark-breaks": "^4.0.0",
"remark-gfm": "^4.0.1",
"sirv-cli": "^3.0.1",
"socket.io-client": "^4.8.1",
"socket.io-client": "^4.8.3",
"tailwind-merge": "^3.4.0",
"tailwind-scrollbar": "^4.0.2",
"vite": "^7.3.0",
@@ -109,7 +109,7 @@
"eslint-plugin-react-hooks": "^4.6.2",
"eslint-plugin-unused-imports": "^4.2.0",
"husky": "^9.1.7",
"jsdom": "^27.3.0",
"jsdom": "^27.4.0",
"lint-staged": "^16.2.7",
"msw": "^2.6.6",
"prettier": "^3.7.3",

102
frontend/src/api/README.md Normal file
View File

@@ -0,0 +1,102 @@
# API Services Guide
## Overview
Services are the abstraction layer between frontend components and backend APIs. They encapsulate HTTP requests using the shared `openHands` axios instance and provide typed methods for each endpoint.
Each service is a plain object with async methods.
## Structure
Each service lives in its own directory:
```
src/api/
├── billing-service/
│ ├── billing-service.api.ts # Service methods
│ └── billing.types.ts # Types and interfaces
├── organization-service/
│ ├── organization-service.api.ts
│ └── organization.types.ts
└── open-hands-axios.ts # Shared axios instance
```
## Creating a Service
Use an object literal with named export. Use object destructuring for parameters to make calls self-documenting.
```typescript
// feature-service/feature-service.api.ts
import { openHands } from "../open-hands-axios";
import { Feature, CreateFeatureParams } from "./feature.types";
export const featureService = {
getFeature: async ({ id }: { id: string }) => {
const { data } = await openHands.get<Feature>(`/api/features/${id}`);
return data;
},
createFeature: async ({ name, description }: CreateFeatureParams) => {
const { data } = await openHands.post<Feature>("/api/features", {
name,
description,
});
return data;
},
};
```
### Types
Define types in a separate file within the same directory:
```typescript
// feature-service/feature.types.ts
export interface Feature {
id: string;
name: string;
description: string;
}
export interface CreateFeatureParams {
name: string;
description: string;
}
```
## Usage
> [!IMPORTANT]
> **Don't call services directly in components.** Wrap them in TanStack Query hooks.
>
> Why? TanStack Query provides:
> - **Caching** - Avoid redundant network requests
> - **Deduplication** - Multiple components requesting the same data share one request
> - **Loading/error states** - Built-in `isLoading`, `isError`, `data` states
> - **Background refetching** - Data stays fresh automatically
>
> Hooks location:
> - `src/hooks/query/` for data fetching (`useQuery`)
> - `src/hooks/mutation/` for writes/updates (`useMutation`)
```typescript
// src/hooks/query/use-feature.ts
import { useQuery } from "@tanstack/react-query";
import { featureService } from "#/api/feature-service/feature-service.api";
export const useFeature = (id: string) => {
return useQuery({
queryKey: ["feature", id],
queryFn: () => featureService.getFeature({ id }),
});
};
```
## Naming Conventions
| Item | Convention | Example |
|------|------------|---------|
| Directory | `feature-service/` | `billing-service/` |
| Service file | `feature-service.api.ts` | `billing-service.api.ts` |
| Types file | `feature.types.ts` | `billing.types.ts` |
| Export name | `featureService` | `billingService` |

View File

@@ -298,6 +298,23 @@ class V1ConversationService {
return data;
}
/**
* Update a V1 conversation's public flag
* @param conversationId The conversation ID
* @param isPublic Whether the conversation should be public
* @returns Updated conversation info
*/
static async updateConversationPublicFlag(
conversationId: string,
isPublic: boolean,
): Promise<V1AppConversation> {
const { data } = await openHands.patch<V1AppConversation>(
`/api/v1/app-conversations/${conversationId}`,
{ public: isPublic },
);
return data;
}
/**
* Read a file from a specific conversation's sandbox workspace
* @param conversationId The conversation ID
@@ -317,6 +334,21 @@ class V1ConversationService {
return data;
}
/**
* Download a conversation trajectory as a zip file
* @param conversationId The conversation ID
* @returns A blob containing the zip file
*/
static async downloadConversation(conversationId: string): Promise<Blob> {
const response = await openHands.get(
`/api/v1/app-conversations/${conversationId}/download`,
{
responseType: "blob",
},
);
return response.data;
}
/**
* Get all skills associated with a V1 conversation
* @param conversationId The conversation ID

View File

@@ -98,6 +98,7 @@ export interface V1AppConversation {
execution_status: V1ConversationExecutionStatus | null;
conversation_url: string | null;
session_api_key: string | null;
public?: boolean;
}
export interface Skill {

View File

@@ -0,0 +1,35 @@
import { openHands } from "../open-hands-axios";
import {
ResendEmailVerificationParams,
ResendEmailVerificationResponse,
} from "./email.types";
/**
* Email Service API - Handles all email-related API endpoints
*/
export const emailService = {
/**
* Resend email verification to the user's registered email address
* @param userId - Optional user ID to send verification email for
* @param isAuthFlow - Whether this is part of the authentication flow
* @returns The response message indicating the email was sent
*/
resendEmailVerification: async ({
userId,
isAuthFlow,
}: ResendEmailVerificationParams): Promise<ResendEmailVerificationResponse> => {
const body: { user_id?: string; is_auth_flow?: boolean } = {};
if (userId) {
body.user_id = userId;
}
if (isAuthFlow !== undefined) {
body.is_auth_flow = isAuthFlow;
}
const { data } = await openHands.put<ResendEmailVerificationResponse>(
"/api/email/resend",
body,
{ withCredentials: true },
);
return data;
},
};

View File

@@ -0,0 +1,8 @@
export interface ResendEmailVerificationParams {
userId?: string | null;
isAuthFlow?: boolean;
}
export interface ResendEmailVerificationResponse {
message: string;
}

View File

@@ -131,9 +131,18 @@ class GitService {
repository: string,
page: number = 1,
perPage: number = 30,
selectedProvider?: Provider,
): Promise<PaginatedBranchesResponse> {
const { data } = await openHands.get<PaginatedBranchesResponse>(
`/api/user/repository/branches?repository=${encodeURIComponent(repository)}&page=${page}&per_page=${perPage}`,
`/api/user/repository/branches`,
{
params: {
repository,
page,
per_page: perPage,
selected_provider: selectedProvider,
},
},
);
return data;

View File

@@ -53,6 +53,13 @@ class V1GitService {
// V1 API returns V1GitChangeStatus types, we need to map them to V0 format
const { data } = await axios.get<V1GitChange[]>(url, { headers });
// Validate response is an array (could be HTML error page if runtime is dead)
if (!Array.isArray(data)) {
throw new Error(
"Invalid response from runtime - runtime may be unavailable",
);
}
// Map V1 statuses to V0 format for compatibility
return data.map((change) => ({
status: mapV1ToV0Status(change.status),

View File

@@ -1,124 +0,0 @@
import { openHands } from "../open-hands-axios";
import { Provider } from "#/types/settings";
export type GitHubItemType = "issue" | "pr";
export type GitHubItemStatus =
| "MERGE_CONFLICTS"
| "FAILING_CHECKS"
| "UNRESOLVED_COMMENTS"
| "OPEN_ISSUE"
| "OPEN_PR";
export interface GitHubItem {
git_provider: Provider;
item_type: GitHubItemType;
status: GitHubItemStatus;
repo: string;
number: number;
title: string;
author: string;
assignees: string[];
created_at: string;
updated_at: string;
url: string;
}
export interface GitHubItemsFilter {
itemType?: "issues" | "prs" | "all";
assignedToMe?: boolean;
authoredByMe?: boolean;
}
export interface GitHubItemsResponse {
items: GitHubItem[];
cached_at?: string;
}
/**
* GitHub Issues/PRs Service - Handles fetching GitHub issues and pull requests
*/
class GitHubIssuesPRsService {
/**
* Get GitHub issues and PRs for the authenticated user
* This uses the existing suggested-tasks endpoint and transforms the data
*/
static async getGitHubItems(
filter?: GitHubItemsFilter,
): Promise<GitHubItemsResponse> {
const { data } = await openHands.get<
Array<{
git_provider: Provider;
task_type: GitHubItemStatus;
repo: string;
issue_number: number;
title: string;
}>
>("/api/user/suggested-tasks");
// Transform the suggested tasks into GitHubItems
const items: GitHubItem[] = data.map((task) => ({
git_provider: task.git_provider,
item_type: task.task_type === "OPEN_ISSUE" ? "issue" : "pr",
status: task.task_type,
repo: task.repo,
number: task.issue_number,
title: task.title,
author: "", // Not available from suggested-tasks endpoint
assignees: [], // Not available from suggested-tasks endpoint
created_at: new Date().toISOString(), // Not available from suggested-tasks endpoint
updated_at: new Date().toISOString(), // Not available from suggested-tasks endpoint
url: GitHubIssuesPRsService.buildItemUrl(
task.git_provider,
task.repo,
task.issue_number,
task.task_type === "OPEN_ISSUE" ? "issue" : "pr",
),
}));
// Apply filters
let filteredItems = items;
if (filter?.itemType === "issues") {
filteredItems = filteredItems.filter(
(item) => item.item_type === "issue",
);
} else if (filter?.itemType === "prs") {
filteredItems = filteredItems.filter((item) => item.item_type === "pr");
}
// Note: assignedToMe and authoredByMe filters would require additional API data
// For now, the suggested-tasks endpoint already returns:
// - PRs authored by the user
// - Issues assigned to the user
// So these filters are implicitly applied by the backend
return {
items: filteredItems,
cached_at: new Date().toISOString(),
};
}
/**
* Build the URL for a GitHub item
*/
static buildItemUrl(
provider: Provider,
repo: string,
number: number,
itemType: GitHubItemType,
): string {
if (provider === "github") {
return `https://github.com/${repo}/${itemType === "issue" ? "issues" : "pull"}/${number}`;
}
if (provider === "gitlab") {
return `https://gitlab.com/${repo}/-/${itemType === "issue" ? "issues" : "merge_requests"}/${number}`;
}
if (provider === "bitbucket") {
return `https://bitbucket.org/${repo}/${itemType === "issue" ? "issues" : "pull-requests"}/${number}`;
}
return "";
}
}
export default GitHubIssuesPRsService;

View File

@@ -78,6 +78,7 @@ export interface Conversation {
pr_number?: number[] | null;
conversation_version?: "V0" | "V1";
sub_conversation_ids?: string[];
public?: boolean;
}
export interface ResultSet<T> {

View File

@@ -0,0 +1,60 @@
import { OpenHandsEvent } from "#/types/v1/core";
import { openHands } from "./open-hands-axios";
export interface SharedConversation {
id: string;
created_by_user_id: string | null;
sandbox_id: string;
selected_repository: string | null;
selected_branch: string | null;
git_provider: string | null;
title: string | null;
pr_number: number[];
llm_model: string | null;
metrics: unknown | null;
parent_conversation_id: string | null;
sub_conversation_ids: string[];
created_at: string;
updated_at: string;
}
export interface EventPage {
items: OpenHandsEvent[];
next_page_id: string | null;
}
export const sharedConversationService = {
/**
* Get a single shared conversation by ID
*/
async getSharedConversation(
conversationId: string,
): Promise<SharedConversation | null> {
const response = await openHands.get<(SharedConversation | null)[]>(
"/api/shared-conversations",
{ params: { ids: conversationId } },
);
return response.data[0] || null;
},
/**
* Get events for a shared conversation
*/
async getSharedConversationEvents(
conversationId: string,
limit: number = 100,
pageId?: string,
): Promise<EventPage> {
const response = await openHands.get<EventPage>(
"/api/shared-events/search",
{
params: {
conversation_id: conversationId,
limit,
...(pageId && { page_id: pageId }),
},
},
);
return response.data;
},
};

View File

@@ -11,6 +11,7 @@ interface ConversationCardActionsProps {
onStop?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onEdit?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onDownloadConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
conversationStatus?: ConversationStatus;
conversationId?: string;
showOptions?: boolean;
@@ -23,6 +24,7 @@ export function ConversationCardActions({
onStop,
onEdit,
onDownloadViaVSCode,
onDownloadConversation,
conversationStatus,
conversationId,
showOptions,
@@ -62,6 +64,9 @@ export function ConversationCardActions({
onDownloadViaVSCode={
conversationId && showOptions ? onDownloadViaVSCode : undefined
}
onDownloadConversation={
conversationId ? onDownloadConversation : undefined
}
position="bottom"
/>
</div>

View File

@@ -24,6 +24,7 @@ interface ConversationCardContextMenuProps {
onShowAgentTools?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onShowSkills?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onDownloadConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
position?: "top" | "bottom";
}
@@ -39,24 +40,27 @@ export function ConversationCardContextMenu({
onShowAgentTools,
onShowSkills,
onDownloadViaVSCode,
onDownloadConversation,
position = "bottom",
}: ConversationCardContextMenuProps) {
const { t } = useTranslation();
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
const generateSection = useCallback(
(items: React.ReactNode[], isLast?: boolean) => {
(items: React.ReactNode[], sectionKey: string, isLast?: boolean) => {
const filteredItems = items.filter((i) => i != null);
if (filteredItems.length > 0) {
return !isLast
? [
...filteredItems,
<Divider key="conversation-card-context-menu-divider" />,
]
: filteredItems;
return !isLast ? (
<React.Fragment key={sectionKey}>
{filteredItems}
<Divider />
</React.Fragment>
) : (
<React.Fragment key={sectionKey}>{filteredItems}</React.Fragment>
);
}
return [];
return null;
},
[],
);
@@ -69,76 +73,104 @@ export function ConversationCardContextMenu({
alignment="right"
className="mt-0"
>
{generateSection([
onEdit && (
<ContextMenuListItem
testId="edit-button"
onClick={onEdit}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<EditIcon width={16} height={16} />}
text={t(I18nKey.BUTTON$RENAME)}
/>
</ContextMenuListItem>
),
])}
{generateSection([
onShowAgentTools && (
<ContextMenuListItem
testId="show-agent-tools-button"
onClick={onShowAgentTools}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<ToolsIcon width={16} height={16} />}
text={t(I18nKey.BUTTON$SHOW_AGENT_TOOLS_AND_METADATA)}
/>
</ContextMenuListItem>
),
onShowSkills && (
<ContextMenuListItem
testId="show-skills-button"
onClick={onShowSkills}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<RobotIcon width={16} height={16} />}
text={t(I18nKey.CONVERSATION$SHOW_SKILLS)}
/>
</ContextMenuListItem>
),
])}
{generateSection([
onStop && (
<ContextMenuListItem
testId="stop-button"
onClick={onStop}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<CloseIcon width={16} height={16} />}
text={t(I18nKey.COMMON$CLOSE_CONVERSATION_STOP_RUNTIME)}
/>
</ContextMenuListItem>
),
onDownloadViaVSCode && (
<ContextMenuListItem
testId="download-vscode-button"
onClick={onDownloadViaVSCode}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<DownloadIcon width={16} height={16} />}
text={t(I18nKey.BUTTON$DOWNLOAD_VIA_VSCODE)}
/>
</ContextMenuListItem>
),
])}
{generateSection(
[
onEdit && (
<ContextMenuListItem
key="edit-button"
testId="edit-button"
onClick={onEdit}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<EditIcon width={16} height={16} />}
text={t(I18nKey.BUTTON$RENAME)}
/>
</ContextMenuListItem>
),
],
"edit-section",
)}
{generateSection(
[
onShowAgentTools && (
<ContextMenuListItem
key="show-agent-tools-button"
testId="show-agent-tools-button"
onClick={onShowAgentTools}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<ToolsIcon width={16} height={16} />}
text={t(I18nKey.BUTTON$SHOW_AGENT_TOOLS_AND_METADATA)}
/>
</ContextMenuListItem>
),
onShowSkills && (
<ContextMenuListItem
key="show-skills-button"
testId="show-skills-button"
onClick={onShowSkills}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<RobotIcon width={16} height={16} />}
text={t(I18nKey.CONVERSATION$SHOW_SKILLS)}
/>
</ContextMenuListItem>
),
],
"tools-section",
)}
{generateSection(
[
onStop && (
<ContextMenuListItem
key="stop-button"
testId="stop-button"
onClick={onStop}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<CloseIcon width={16} height={16} />}
text={t(I18nKey.COMMON$CLOSE_CONVERSATION_STOP_RUNTIME)}
/>
</ContextMenuListItem>
),
onDownloadViaVSCode && (
<ContextMenuListItem
key="download-vscode-button"
testId="download-vscode-button"
onClick={onDownloadViaVSCode}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<DownloadIcon width={16} height={16} />}
text={t(I18nKey.BUTTON$DOWNLOAD_VIA_VSCODE)}
/>
</ContextMenuListItem>
),
onDownloadConversation && (
<ContextMenuListItem
key="download-trajectory-button"
testId="download-trajectory-button"
onClick={onDownloadConversation}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<DownloadIcon width={16} height={16} />}
text={t(I18nKey.BUTTON$EXPORT_CONVERSATION)}
/>
</ContextMenuListItem>
),
],
"control-section",
)}
{generateSection(
[
onDisplayCost && (
<ContextMenuListItem
key="display-cost-button"
testId="display-cost-button"
onClick={onDisplayCost}
className={contextMenuListItemClassName}
@@ -151,6 +183,7 @@ export function ConversationCardContextMenu({
),
onDelete && (
<ContextMenuListItem
key="delete-button"
testId="delete-button"
onClick={onDelete}
className={contextMenuListItemClassName}
@@ -158,10 +191,11 @@ export function ConversationCardContextMenu({
<ConversationNameContextMenuIconText
icon={<DeleteIcon width={16} height={16} />}
text={t(I18nKey.COMMON$DELETE_CONVERSATION)}
/>{" "}
/>
</ContextMenuListItem>
),
],
"info-section",
true,
)}
</ContextMenu>

View File

@@ -8,6 +8,7 @@ import { RepositorySelection } from "#/api/open-hands.types";
import { ConversationCardHeader } from "./conversation-card-header";
import { ConversationCardActions } from "./conversation-card-actions";
import { ConversationCardFooter } from "./conversation-card-footer";
import { useDownloadConversation } from "#/hooks/use-download-conversation";
interface ConversationCardProps {
onClick?: () => void;
@@ -46,6 +47,7 @@ export function ConversationCard({
}: ConversationCardProps) {
const posthog = usePostHog();
const [titleMode, setTitleMode] = React.useState<"view" | "edit">("view");
const { mutateAsync: downloadConversation } = useDownloadConversation();
const onTitleSave = (newTitle: string) => {
if (newTitle !== "" && newTitle !== title) {
@@ -101,6 +103,18 @@ export function ConversationCard({
onContextMenuToggle?.(false);
};
const handleDownloadConversation = async (
event: React.MouseEvent<HTMLButtonElement>,
) => {
event.preventDefault();
event.stopPropagation();
if (conversationId && conversationVersion === "V1") {
await downloadConversation(conversationId);
}
onContextMenuToggle?.(false);
};
const hasContextMenu = !!(onDelete || onChangeTitle || showOptions);
return (
@@ -130,6 +144,11 @@ export function ConversationCard({
onStop={onStop && handleStop}
onEdit={onChangeTitle && handleEdit}
onDownloadViaVSCode={handleDownloadViaVSCode}
onDownloadConversation={
conversationVersion === "V1"
? handleDownloadConversation
: undefined
}
conversationStatus={conversationStatus}
conversationId={conversationId}
showOptions={showOptions}

View File

@@ -7,6 +7,7 @@ import { ContextMenuListItem } from "../context-menu/context-menu-list-item";
import { Divider } from "#/ui/divider";
import { I18nKey } from "#/i18n/declaration";
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
import { useConfig } from "#/hooks/query/use-config";
import EditIcon from "#/icons/u-edit.svg?react";
import RobotIcon from "#/icons/u-robot.svg?react";
@@ -16,6 +17,7 @@ import DownloadIcon from "#/icons/u-download.svg?react";
import CreditCardIcon from "#/icons/u-credit-card.svg?react";
import CloseIcon from "#/icons/u-close.svg?react";
import DeleteIcon from "#/icons/u-delete.svg?react";
import LinkIcon from "#/icons/link-external.svg?react";
import { ConversationNameContextMenuIconText } from "./conversation-name-context-menu-icon-text";
import { CONTEXT_MENU_ICON_TEXT_CLASSNAME } from "#/utils/constants";
@@ -34,6 +36,9 @@ interface ConversationNameContextMenuProps {
onShowSkills?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onExportConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onTogglePublic?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onDownloadConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
onCopyShareLink?: (event: React.MouseEvent<HTMLButtonElement>) => void;
position?: "top" | "bottom";
}
@@ -47,6 +52,9 @@ export function ConversationNameContextMenu({
onShowSkills,
onExportConversation,
onDownloadViaVSCode,
onTogglePublic,
onDownloadConversation,
onCopyShareLink,
position = "bottom",
}: ConversationNameContextMenuProps) {
const { width } = useWindowSize();
@@ -54,11 +62,17 @@ export function ConversationNameContextMenu({
const { t } = useTranslation();
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
const { data: conversation } = useActiveConversation();
const { data: config } = useConfig();
// This is a temporary measure and may be re-enabled in the future
const isV1Conversation = conversation?.conversation_version === "V1";
const hasDownload = Boolean(onDownloadViaVSCode);
// Check if we should show the public sharing option
// Only show for V1 conversations in SAAS mode
const shouldShowPublicSharing =
isV1Conversation && config?.APP_MODE === "saas" && onTogglePublic;
const hasDownload = Boolean(onDownloadViaVSCode || onDownloadConversation);
const hasExport = Boolean(onExportConversation);
const hasTools = Boolean(onShowAgentTools || onShowSkills);
const hasInfo = Boolean(onDisplayCost);
@@ -118,9 +132,9 @@ export function ConversationNameContextMenu({
</ContextMenuListItem>
)}
{(hasExport || hasDownload) && !isV1Conversation && (
{(hasExport || hasDownload) && !isV1Conversation ? (
<Divider testId="separator-export" />
)}
) : null}
{onExportConversation && !isV1Conversation && (
<ContextMenuListItem
@@ -150,10 +164,22 @@ export function ConversationNameContextMenu({
</ContextMenuListItem>
)}
{(hasInfo || hasControl) && !isV1Conversation && (
<Divider testId="separator-info-control" />
{onDownloadConversation && isV1Conversation && (
<ContextMenuListItem
testId="download-trajectory-button"
onClick={onDownloadConversation}
className={contextMenuListItemClassName}
>
<ConversationNameContextMenuIconText
icon={<DownloadIcon width={16} height={16} />}
text={t(I18nKey.BUTTON$EXPORT_CONVERSATION)}
className={CONTEXT_MENU_ICON_TEXT_CLASSNAME}
/>
</ContextMenuListItem>
)}
{(hasInfo || hasControl) && <Divider testId="separator-info-control" />}
{onDisplayCost && (
<ContextMenuListItem
testId="display-cost-button"
@@ -168,6 +194,36 @@ export function ConversationNameContextMenu({
</ContextMenuListItem>
)}
{shouldShowPublicSharing && (
<ContextMenuListItem
testId="share-publicly-button"
onClick={onTogglePublic}
className={contextMenuListItemClassName}
>
<div className="flex items-center gap-2 justify-between w-full">
<div className="flex items-center gap-2">
<input
type="checkbox"
checked={conversation?.public || false}
className="w-4 h-4 ml-2"
/>
<span>{t(I18nKey.CONVERSATION$SHARE_PUBLICLY)}</span>
</div>
{conversation?.public && onCopyShareLink && (
<button
type="button"
data-testid="copy-share-link-button"
onClick={onCopyShareLink}
className="p-1 hover:bg-[#717888] rounded"
title={t(I18nKey.BUTTON$COPY_TO_CLIPBOARD)}
>
<LinkIcon width={16} height={16} />
</button>
)}
</div>
</ContextMenuListItem>
)}
{onStop && (
<ContextMenuListItem
testId="stop-button"

View File

@@ -6,6 +6,7 @@ import { useUpdateConversation } from "#/hooks/mutation/use-update-conversation"
import { useConversationNameContextMenu } from "#/hooks/use-conversation-name-context-menu";
import { displaySuccessToast } from "#/utils/custom-toast-handlers";
import { I18nKey } from "#/i18n/declaration";
import { ENABLE_PUBLIC_CONVERSATION_SHARING } from "#/utils/feature-flags";
import { EllipsisButton } from "../conversation-panel/ellipsis-button";
import { ConversationNameContextMenu } from "./conversation-name-context-menu";
import { SystemMessageModal } from "../conversation-panel/system-message-modal";
@@ -30,10 +31,13 @@ export function ConversationName() {
handleDelete,
handleStop,
handleDownloadViaVSCode,
handleDownloadConversation,
handleDisplayCost,
handleShowAgentTools,
handleShowSkills,
handleExportConversation,
handleTogglePublic,
handleCopyShareLink,
handleConfirmDelete,
handleConfirmStop,
metricsModalVisible,
@@ -50,6 +54,7 @@ export function ConversationName() {
shouldShowStop,
shouldShowDownload,
shouldShowExport,
shouldShowDownloadConversation,
shouldShowDisplayCost,
shouldShowAgentTools,
shouldShowSkills,
@@ -177,6 +182,21 @@ export function ConversationName() {
onDownloadViaVSCode={
shouldShowDownload ? handleDownloadViaVSCode : undefined
}
onTogglePublic={
ENABLE_PUBLIC_CONVERSATION_SHARING()
? handleTogglePublic
: undefined
}
onCopyShareLink={
ENABLE_PUBLIC_CONVERSATION_SHARING()
? handleCopyShareLink
: undefined
}
onDownloadConversation={
shouldShowDownloadConversation
? handleDownloadConversation
: undefined
}
position="bottom"
/>
)}

View File

@@ -79,6 +79,7 @@ export function BranchDropdownMenu({
menuRef={menuRef}
renderItem={renderItem}
renderEmptyState={renderEmptyState}
itemKey={(branch) => branch.name}
/>
</div>
);

View File

@@ -211,6 +211,7 @@ export function GitProviderDropdown({
getItemProps={getItemProps}
renderItem={renderItem}
renderEmptyState={renderEmptyState}
itemKey={(provider) => provider}
/>
<ErrorMessage isError={!!errorMessage} message={errorMessage} />

View File

@@ -369,6 +369,7 @@ export function GitRepoDropdown({
stickyFooterItem={stickyFooterItem}
testId="git-repo-dropdown-menu"
numberOfRecentItems={recentRepositories.length}
itemKey={(repo) => repo.id}
/>
<ErrorMessage isError={isError} />

View File

@@ -39,6 +39,7 @@ export function ConversationStatusIndicator({
ariaLabel={statusLabel}
placement="right"
showArrow
asSpan
className="p-0 border-0 bg-transparent hover:opacity-100"
tooltipClassName="bg-[#1a1a1a] text-white text-xs shadow-lg"
>

View File

@@ -20,63 +20,59 @@ export function RecentConversation({ conversation }: RecentConversationProps) {
conversation.selected_repository && conversation.selected_branch;
return (
<Link to={`/conversations/${conversation.conversation_id}`}>
<button
type="button"
className="flex flex-col gap-1 p-[14px] cursor-pointer w-full rounded-lg hover:bg-[#5C5D62] transition-all duration-300 text-left"
>
<div className="flex items-center gap-2 pl-1">
<ConversationStatusIndicator
conversationStatus={conversation.status}
/>
<span className="text-xs text-white leading-6 font-normal">
{conversation.title}
</span>
</div>
<div className="flex items-center justify-between text-xs text-[#A3A3A3] leading-4 font-normal">
<div className="flex items-center gap-3">
{hasRepository ? (
<div className="flex items-center gap-2">
<GitProviderIcon
gitProvider={conversation.git_provider as Provider}
/>
<span
className="max-w-[124px] truncate"
title={conversation.selected_repository || ""}
>
{conversation.selected_repository}
</span>
</div>
) : (
<div className="flex items-center gap-1">
<RepoForkedIcon width={12} height={12} color="#A3A3A3" />
<span className="max-w-[124px] truncate">
{t(I18nKey.COMMON$NO_REPOSITORY)}
</span>
</div>
)}
{hasRepository ? (
<div className="flex items-center gap-1">
<CodeBranchIcon width={12} height={12} color="#A3A3A3" />
<span
className="max-w-[124px] truncate"
title={conversation.selected_branch || ""}
>
{conversation.selected_branch}
</span>
</div>
) : null}
</div>
{(conversation.created_at || conversation.last_updated_at) && (
<span>
{formatTimeDelta(
conversation.created_at || conversation.last_updated_at,
)}{" "}
{t(I18nKey.CONVERSATION$AGO)}
</span>
<Link
to={`/conversations/${conversation.conversation_id}`}
className="flex flex-col gap-1 p-[14px] cursor-pointer w-full rounded-lg hover:bg-[#5C5D62] transition-all duration-300 text-left"
>
<div className="flex items-center gap-2 pl-1">
<ConversationStatusIndicator conversationStatus={conversation.status} />
<span className="text-xs text-white leading-6 font-normal">
{conversation.title}
</span>
</div>
<div className="flex items-center justify-between text-xs text-[#A3A3A3] leading-4 font-normal">
<div className="flex items-center gap-3">
{hasRepository ? (
<div className="flex items-center gap-2">
<GitProviderIcon
gitProvider={conversation.git_provider as Provider}
/>
<span
className="max-w-[124px] truncate"
title={conversation.selected_repository || ""}
>
{conversation.selected_repository}
</span>
</div>
) : (
<div className="flex items-center gap-1">
<RepoForkedIcon width={12} height={12} color="#A3A3A3" />
<span className="max-w-[124px] truncate">
{t(I18nKey.COMMON$NO_REPOSITORY)}
</span>
</div>
)}
{hasRepository ? (
<div className="flex items-center gap-1">
<CodeBranchIcon width={12} height={12} color="#A3A3A3" />
<span
className="max-w-[124px] truncate"
title={conversation.selected_branch || ""}
>
{conversation.selected_branch}
</span>
</div>
) : null}
</div>
</button>
{(conversation.created_at || conversation.last_updated_at) && (
<span>
{formatTimeDelta(
conversation.created_at || conversation.last_updated_at,
)}{" "}
{t(I18nKey.CONVERSATION$AGO)}
</span>
)}
</div>
</Link>
);
}

View File

@@ -33,6 +33,7 @@ export interface GenericDropdownMenuProps<T> {
stickyFooterItem?: React.ReactNode;
testId?: string;
numberOfRecentItems?: number;
itemKey: (item: T) => string | number;
}
export function GenericDropdownMenu<T>({
@@ -51,12 +52,28 @@ export function GenericDropdownMenu<T>({
stickyFooterItem,
testId,
numberOfRecentItems = 0,
itemKey,
}: GenericDropdownMenuProps<T>) {
if (!isOpen) return null;
const hasItems = filteredItems.length > 0;
const showEmptyState = !hasItems && !stickyTopItem && !stickyFooterItem;
// Always render the menu container (even when closed) so getMenuProps is always called
// This prevents the downshift warning about forgetting to call getMenuProps
if (!isOpen) {
return (
<div className="relative">
<ul
// eslint-disable-next-line react/jsx-props-no-spreading
{...getMenuProps({
ref: menuRef,
className: "hidden",
"data-testid": testId,
})}
/>
</div>
);
}
return (
<div className="relative">
<div
@@ -85,21 +102,24 @@ export function GenericDropdownMenu<T>({
) : (
<>
{stickyTopItem}
{filteredItems.map((item, index) => (
<>
{renderItem(
item,
index,
highlightedIndex,
selectedItem,
getItemProps,
)}
{numberOfRecentItems > 0 &&
index === numberOfRecentItems - 1 && (
<div className="border-b border-[#727987] bg-[#454545] pb-1 mb-1 h-[1px]" />
{filteredItems.map((item, index) => {
const key = itemKey(item);
return (
<React.Fragment key={key}>
{renderItem(
item,
index,
highlightedIndex,
selectedItem,
getItemProps,
)}
</>
))}
{numberOfRecentItems > 0 &&
index === numberOfRecentItems - 1 && (
<div className="border-b border-[#727987] bg-[#454545] pb-1 mb-1 h-[1px]" />
)}
</React.Fragment>
);
})}
</>
)}
</ul>

View File

@@ -17,9 +17,10 @@ export function MicroagentManagementAccordionTitle({
<TooltipButton
tooltip={repository.full_name}
ariaLabel={repository.full_name}
className="text-white text-base font-normal bg-transparent p-0 min-w-0 h-auto cursor-pointer truncate max-w-[194px] translate-y-[-1px]"
testId="repository-name-tooltip"
placement="bottom"
asSpan
className="text-white text-base font-normal bg-transparent p-0 min-w-0 h-auto cursor-pointer truncate max-w-[194px] translate-y-[-1px]"
>
<span>{repository.full_name}</span>
</TooltipButton>

Some files were not shown because too many files have changed in this diff Show More