mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
50 Commits
feature/gi
...
1.1.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9885ddea33 | ||
|
|
103e3ead0a | ||
|
|
d5e83d0f06 | ||
|
|
443918af3c | ||
|
|
910646d11f | ||
|
|
d9d19043f1 | ||
|
|
4dec38c7ce | ||
|
|
c3f51d9dbe | ||
|
|
ecbd3ae749 | ||
|
|
8ee1394e8c | ||
|
|
d628e1f20a | ||
|
|
1480d4acb0 | ||
|
|
58a70e8b0d | ||
|
|
49e46a5fa1 | ||
|
|
2cf6494773 | ||
|
|
d3afbfa447 | ||
|
|
8d69b4066f | ||
|
|
2261281656 | ||
|
|
d68b2cdd1a | ||
|
|
c70ecc8fe3 | ||
|
|
a3e85e2c2d | ||
|
|
3bef4e6c2d | ||
|
|
97654e6a5e | ||
|
|
30114666ad | ||
|
|
ee50f333ba | ||
|
|
09d1748a14 | ||
|
|
81519343c4 | ||
|
|
f742811e81 | ||
|
|
f8e4b5562e | ||
|
|
cb1d1f8a0d | ||
|
|
a829d10213 | ||
|
|
cb8c1fa263 | ||
|
|
c80f70392f | ||
|
|
94e6490a79 | ||
|
|
09af93a02a | ||
|
|
5407ea55aa | ||
|
|
fe1026ee8a | ||
|
|
6d14ce420e | ||
|
|
36fe23aea3 | ||
|
|
9049b95792 | ||
|
|
e2b2aa52cd | ||
|
|
dc99c7b62e | ||
|
|
8bc1a47a78 | ||
|
|
8d0e7a92b8 | ||
|
|
f6e7628bff | ||
|
|
fae83230ee | ||
|
|
a9d2f72d72 | ||
|
|
2b8f779b65 | ||
|
|
10edb28729 | ||
|
|
5553d3ca2e |
2
.github/workflows/check-package-versions.yml
vendored
2
.github/workflows/check-package-versions.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
|
||||
8
.github/workflows/e2e-tests.yml
vendored
8
.github/workflows/e2e-tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
poetry-version: 2.1.3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'poetry'
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
sudo apt-get install -y libgtk-3-0 libnotify4 libnss3 libxss1 libxtst6 xauth xvfb libgbm1 libasound2t64 netcat-openbsd
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '22'
|
||||
cache: 'npm'
|
||||
@@ -192,7 +192,7 @@ jobs:
|
||||
|
||||
- name: Upload test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: playwright-report
|
||||
path: tests/e2e/test-results/
|
||||
@@ -200,7 +200,7 @@ jobs:
|
||||
|
||||
- name: Upload OpenHands logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: openhands-logs
|
||||
path: |
|
||||
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
⚠️ This PR contains **migrations**
|
||||
|
||||
- name: Comment warning on PR
|
||||
uses: peter-evans/create-or-update-comment@v4
|
||||
uses: peter-evans/create-or-update-comment@v5
|
||||
with:
|
||||
issue-number: ${{ github.event.pull_request.number }}
|
||||
comment-id: ${{ steps.find-comment.outputs.comment-id }}
|
||||
|
||||
2
.github/workflows/fe-e2e-tests.yml
vendored
2
.github/workflows/fe-e2e-tests.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
working-directory: ./frontend
|
||||
run: npx playwright test --project=chromium
|
||||
- name: Upload Playwright report
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
|
||||
10
.github/workflows/ghcr-build.yml
vendored
10
.github/workflows/ghcr-build.yml
vendored
@@ -64,7 +64,7 @@ jobs:
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3.6.0
|
||||
uses: docker/setup-qemu-action@v3.7.0
|
||||
with:
|
||||
image: tonistiigi/binfmt:latest
|
||||
- name: Login to GHCR
|
||||
@@ -102,7 +102,7 @@ jobs:
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3.6.0
|
||||
uses: docker/setup-qemu-action@v3.7.0
|
||||
with:
|
||||
image: tonistiigi/binfmt:latest
|
||||
- name: Login to GHCR
|
||||
@@ -161,7 +161,7 @@ jobs:
|
||||
context: containers/runtime
|
||||
- name: Upload runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
@@ -268,7 +268,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Download runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
@@ -330,7 +330,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Download runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
|
||||
6
.github/workflows/openhands-resolver.yml
vendored
6
.github/workflows/openhands-resolver.yml
vendored
@@ -89,7 +89,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- name: Upgrade pip
|
||||
@@ -118,7 +118,7 @@ jobs:
|
||||
contains(github.event.review.body, '@openhands-agent-exp')
|
||||
)
|
||||
)
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ${{ env.pythonLocation }}/lib/python3.12/site-packages/*
|
||||
key: ${{ runner.os }}-pip-openhands-resolver-${{ hashFiles('/tmp/requirements.txt') }}
|
||||
@@ -269,7 +269,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload output.jsonl as artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
if: always() # Upload even if the previous steps fail
|
||||
with:
|
||||
name: resolver-output
|
||||
|
||||
6
.github/workflows/py-tests.yml
vendored
6
.github/workflows/py-tests.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.runtime.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: coverage-openhands
|
||||
path: |
|
||||
@@ -95,7 +95,7 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.enterprise.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: coverage-enterprise
|
||||
path: ".coverage.enterprise.${{ matrix.python_version }}"
|
||||
@@ -113,7 +113,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/download-artifact@v5
|
||||
- uses: actions/download-artifact@v6
|
||||
id: download
|
||||
with:
|
||||
pattern: coverage-*
|
||||
|
||||
6
.github/workflows/vscode-extension-build.yml
vendored
6
.github/workflows/vscode-extension-build.yml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
node-version: '22'
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload VSCode extension artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: vscode-extension
|
||||
path: openhands/integrations/vscode/openhands-vscode-0.0.1.vsix
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Download .vsix artifact
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: vscode-extension
|
||||
path: ./
|
||||
|
||||
@@ -161,7 +161,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.0-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.1-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.0-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.1-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -7,7 +7,7 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.0-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.1-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -50,7 +50,7 @@ First run this to retrieve Github App secrets
|
||||
```
|
||||
gcloud auth application-default login
|
||||
gcloud config set project global-432717
|
||||
local/decrypt_env.sh
|
||||
enterprise_local/decrypt_env.sh /path/to/root/of/deploy/repo
|
||||
```
|
||||
|
||||
Now run this to generate a `.env` file, which will used to run SAAS locally
|
||||
|
||||
4
enterprise/enterprise_local/decrypt_env.sh
Normal file → Executable file
4
enterprise/enterprise_local/decrypt_env.sh
Normal file → Executable file
@@ -4,12 +4,12 @@ set -euo pipefail
|
||||
# Check if DEPLOY_DIR argument was provided
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: $0 <DEPLOY_DIR>"
|
||||
echo "Example: $0 /path/to/deploy"
|
||||
echo "Example: $0 /path/to/root/of/deploy/repo"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Normalize path (remove trailing slash)
|
||||
DEPLOY_DIR="${DEPLOY_DIR%/}"
|
||||
DEPLOY_DIR="${1%/}"
|
||||
|
||||
# Function to decrypt and rename
|
||||
decrypt_and_move() {
|
||||
|
||||
@@ -321,7 +321,7 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
The message with the conversation footer appended
|
||||
"""
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
footer = f'\n\n<sub>[View full conversation]({conversation_link})</sub>'
|
||||
footer = f'\n\n[View full conversation]({conversation_link})'
|
||||
return message + footer
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""add public column to conversation_metadata
|
||||
|
||||
Revision ID: 085
|
||||
Revises: 084
|
||||
Create Date: 2025-01-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '085'
|
||||
down_revision: Union[str, None] = '084'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.add_column(
|
||||
'conversation_metadata',
|
||||
sa.Column('public', sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_conversation_metadata_public'),
|
||||
'conversation_metadata',
|
||||
['public'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_index(
|
||||
op.f('ix_conversation_metadata_public'),
|
||||
table_name='conversation_metadata',
|
||||
)
|
||||
op.drop_column('conversation_metadata', 'public')
|
||||
42
enterprise/poetry.lock
generated
42
enterprise/poetry.lock
generated
@@ -4558,25 +4558,25 @@ valkey = ["valkey (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.80.7"
|
||||
version = "1.80.11"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "litellm-1.80.7-py3-none-any.whl", hash = "sha256:f7d993f78c1e0e4e1202b2a925cc6540b55b6e5fb055dd342d88b145ab3102ed"},
|
||||
{file = "litellm-1.80.7.tar.gz", hash = "sha256:3977a8d195aef842d01c18bf9e22984829363c6a4b54daf9a43c9dd9f190b42c"},
|
||||
{file = "litellm-1.80.11-py3-none-any.whl", hash = "sha256:406283d66ead77dc7ff0e0b2559c80e9e497d8e7c2257efb1cb9210a20d09d54"},
|
||||
{file = "litellm-1.80.11.tar.gz", hash = "sha256:c9fc63e7acb6360363238fe291bcff1488c59ff66020416d8376c0ee56414a19"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = ">=3.10"
|
||||
click = "*"
|
||||
fastuuid = ">=0.13.0"
|
||||
grpcio = ">=1.62.3,<1.68.0"
|
||||
grpcio = {version = ">=1.62.3,<1.68.0", markers = "python_version < \"3.14\""}
|
||||
httpx = ">=0.23.0"
|
||||
importlib-metadata = ">=6.8.0"
|
||||
jinja2 = ">=3.1.2,<4.0.0"
|
||||
jsonschema = ">=4.22.0,<5.0.0"
|
||||
jsonschema = ">=4.23.0,<5.0.0"
|
||||
openai = ">=2.8.0"
|
||||
pydantic = ">=2.5.0,<3.0.0"
|
||||
python-dotenv = ">=0.2.0"
|
||||
@@ -4587,7 +4587,7 @@ tokenizers = "*"
|
||||
caching = ["diskcache (>=5.6.1,<6.0.0)"]
|
||||
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-iam (>=2.19.1,<3.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "redisvl (>=0.4.1,<0.5.0) ; python_version >= \"3.9\" and python_version < \"3.14\"", "resend (>=0.8.0)"]
|
||||
mlflow = ["mlflow (>3.1.4) ; python_version >= \"3.10\""]
|
||||
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.22)", "litellm-proxy-extras (==0.4.9)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
|
||||
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.27)", "litellm-proxy-extras (==0.4.16)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
|
||||
semantic-router = ["semantic-router (>=0.1.12) ; python_version >= \"3.9\" and python_version < \"3.14\""]
|
||||
utils = ["numpydoc"]
|
||||
|
||||
@@ -5836,14 +5836,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.6.0"
|
||||
version = "1.7.1"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.6.0-py3-none-any.whl", hash = "sha256:e6ae865ac3e7a96b234e10a0faad23f6210e025bbf7721cb66bc7a71d160848c"},
|
||||
{file = "openhands_agent_server-1.6.0.tar.gz", hash = "sha256:44ce7694ae2d4bb0666d318ef13e6618bd4dc73022c60354839fe6130e67d02a"},
|
||||
{file = "openhands_agent_server-1.7.1-py3-none-any.whl", hash = "sha256:e5c57f1b73293d00a68b77f9d290f59d9e2217d9df844fb01c7d2f929c3417f4"},
|
||||
{file = "openhands_agent_server-1.7.1.tar.gz", hash = "sha256:c82e1e6748ea3b4278ef2ee72f091dc37da6667c854b3aa3c0bc616086a82310"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5860,7 +5860,7 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-ai"
|
||||
version = "0.0.0-post.5687+7853b41ad"
|
||||
version = "0.0.0-post.5742+ee50f333b"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
optional = false
|
||||
python-versions = "^3.12,<3.14"
|
||||
@@ -5896,15 +5896,15 @@ json-repair = "*"
|
||||
jupyter_kernel_gateway = "*"
|
||||
kubernetes = "^33.1.0"
|
||||
libtmux = ">=0.46.2"
|
||||
litellm = ">=1.74.3, <=1.80.7, !=1.64.4, !=1.67.*"
|
||||
litellm = ">=1.74.3, !=1.64.4, !=1.67.*"
|
||||
lmnr = "^0.7.20"
|
||||
memory-profiler = "^0.61.0"
|
||||
numpy = "*"
|
||||
openai = "2.8.0"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.6.0"
|
||||
openhands-sdk = "1.6.0"
|
||||
openhands-tools = "1.6.0"
|
||||
openhands-agent-server = "1.7.1"
|
||||
openhands-sdk = "1.7.1"
|
||||
openhands-tools = "1.7.1"
|
||||
opentelemetry-api = "^1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = "^1.33.1"
|
||||
pathspec = "^0.12.1"
|
||||
@@ -5960,21 +5960,21 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.6.0"
|
||||
version = "1.7.1"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.6.0-py3-none-any.whl", hash = "sha256:94d2f87fb35406373da6728ae2d88584137f9e9b67fa0e940444c72f2e44e7d3"},
|
||||
{file = "openhands_sdk-1.6.0.tar.gz", hash = "sha256:f45742350e3874a7f5b08befc4a9d5adc7e4454f7ab5f8391c519eee3116090f"},
|
||||
{file = "openhands_sdk-1.7.1-py3-none-any.whl", hash = "sha256:e097e34dfbd45f38225ae2ff4830702424bcf742bc197b5a811540a75265b135"},
|
||||
{file = "openhands_sdk-1.7.1.tar.gz", hash = "sha256:e13d1fe8bf14dffd91e9080608072a989132c981cf9bfcd124fa4f7a68a13691"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0"
|
||||
fastmcp = ">=2.11.3"
|
||||
httpx = ">=0.27.0"
|
||||
litellm = ">=1.80.7"
|
||||
litellm = ">=1.80.10"
|
||||
lmnr = ">=0.7.24"
|
||||
pydantic = ">=2.11.7"
|
||||
python-frontmatter = ">=1.1.0"
|
||||
@@ -5987,14 +5987,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.6.0"
|
||||
version = "1.7.1"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.6.0-py3-none-any.whl", hash = "sha256:176556d44186536751b23fe052d3505492cc2afb8d52db20fb7a2cc0169cd57a"},
|
||||
{file = "openhands_tools-1.6.0.tar.gz", hash = "sha256:d07ba31050fd4a7891a4c48388aa53ce9f703e17064ddbd59146d6c77e5980b3"},
|
||||
{file = "openhands_tools-1.7.1-py3-none-any.whl", hash = "sha256:e25815f24925e94fbd4d8c3fd9b2147a0556fde595bf4f80a7dbba1014ea3c86"},
|
||||
{file = "openhands_tools-1.7.1.tar.gz", hash = "sha256:f3823f7bd302c78969c454730cf793eb63109ce2d986e78585989c53986cc966"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -37,6 +37,12 @@ from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -66,6 +72,8 @@ base_app.include_router(saas_user_router) # Add additional route SAAS user call
|
||||
base_app.include_router(
|
||||
billing_router
|
||||
) # Add routes for credit management and Stripe payment integration
|
||||
base_app.include_router(shared_conversation_router)
|
||||
base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
@@ -99,6 +107,7 @@ base_app.include_router(
|
||||
event_webhook_router
|
||||
) # Add routes for Events in nested runtimes
|
||||
|
||||
|
||||
base_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=PERMITTED_CORS_ORIGINS,
|
||||
|
||||
@@ -38,3 +38,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
||||
'y',
|
||||
'on',
|
||||
)
|
||||
BLOCKED_EMAIL_DOMAINS = [
|
||||
domain.strip().lower()
|
||||
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
|
||||
if domain.strip()
|
||||
]
|
||||
|
||||
56
enterprise/server/auth/domain_blocker.py
Normal file
56
enterprise/server/auth/domain_blocker.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class DomainBlocker:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing DomainBlocker')
|
||||
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
|
||||
if self.blocked_domains:
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
|
||||
)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if domain blocking is enabled"""
|
||||
return bool(self.blocked_domains)
|
||||
|
||||
def _extract_domain(self, email: str) -> str | None:
|
||||
"""Extract and normalize email domain from email address"""
|
||||
if not email:
|
||||
return None
|
||||
try:
|
||||
# Extract domain part after @
|
||||
if '@' not in email:
|
||||
return None
|
||||
domain = email.split('@')[1].strip().lower()
|
||||
return domain if domain else None
|
||||
except Exception:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked"""
|
||||
if not self.is_active():
|
||||
return False
|
||||
|
||||
if not email:
|
||||
logger.debug('No email provided for domain check')
|
||||
return False
|
||||
|
||||
domain = self._extract_domain(email)
|
||||
if not domain:
|
||||
logger.debug(f'Could not extract domain from email: {email}')
|
||||
return False
|
||||
|
||||
is_blocked = domain in self.blocked_domains
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
else:
|
||||
logger.debug(f'Email domain {domain} is not blocked')
|
||||
|
||||
return is_blocked
|
||||
|
||||
|
||||
domain_blocker = DomainBlocker()
|
||||
109
enterprise/server/auth/email_validation.py
Normal file
109
enterprise/server/auth/email_validation.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Email validation utilities for preventing duplicate signups with + modifier."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def extract_base_email(email: str) -> str | None:
|
||||
"""Extract base email from an email address.
|
||||
|
||||
For emails with + modifier, extracts the base email (local part before + and @, plus domain).
|
||||
For emails without + modifier, returns the email as-is.
|
||||
|
||||
Examples:
|
||||
extract_base_email("joe+test@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe+openhands+test@example.com") -> "joe@example.com"
|
||||
|
||||
Args:
|
||||
email: The email address to process
|
||||
|
||||
Returns:
|
||||
The base email address, or None if email format is invalid
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = email.rsplit('@', 1)
|
||||
# Extract the part before + if it exists
|
||||
base_local = local_part.split('+', 1)[0]
|
||||
return f'{base_local}@{domain}'
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def has_plus_modifier(email: str) -> bool:
|
||||
"""Check if an email address contains a + modifier.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if email contains + before @, False otherwise
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, _ = email.rsplit('@', 1)
|
||||
return '+' in local_part
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
def matches_base_email(email: str, base_email: str) -> bool:
|
||||
"""Check if an email matches a base email pattern.
|
||||
|
||||
An email matches if:
|
||||
- It is exactly the base email (e.g., joe@example.com)
|
||||
- It has the same base local part and domain, with or without + modifier
|
||||
(e.g., joe+test@example.com matches base joe@example.com)
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
base_email: The base email to match against
|
||||
|
||||
Returns:
|
||||
True if email matches the base pattern, False otherwise
|
||||
"""
|
||||
if not email or not base_email:
|
||||
return False
|
||||
|
||||
# Extract base from both emails for comparison
|
||||
email_base = extract_base_email(email)
|
||||
base_email_normalized = extract_base_email(base_email)
|
||||
|
||||
if not email_base or not base_email_normalized:
|
||||
return False
|
||||
|
||||
# Emails match if they have the same base
|
||||
return email_base.lower() == base_email_normalized.lower()
|
||||
|
||||
|
||||
def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None:
|
||||
"""Generate a regex pattern to match emails with the same base.
|
||||
|
||||
For base_email "joe@example.com", the pattern will match:
|
||||
- joe@example.com
|
||||
- joe+anything@example.com
|
||||
|
||||
Args:
|
||||
base_email: The base email address
|
||||
|
||||
Returns:
|
||||
A compiled regex pattern, or None if base_email is invalid
|
||||
"""
|
||||
base = extract_base_email(base_email)
|
||||
if not base:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = base.rsplit('@', 1)
|
||||
# Escape special regex characters in local part and domain
|
||||
escaped_local = re.escape(local_part)
|
||||
escaped_domain = re.escape(domain)
|
||||
# Pattern: joe@example.com OR joe+anything@example.com
|
||||
pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$'
|
||||
return re.compile(pattern, re.IGNORECASE)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
@@ -13,6 +13,7 @@ from server.auth.auth_error import (
|
||||
ExpiredError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@@ -153,8 +154,10 @@ class SaasUserAuth(UserAuth):
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
with session_maker() as session:
|
||||
tokens = session.query(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == self.user_id
|
||||
tokens = (
|
||||
session.query(AuthTokens)
|
||||
.where(AuthTokens.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
for token in tokens:
|
||||
@@ -312,6 +315,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
user_id = access_token_payload['sub']
|
||||
email = access_token_payload['email']
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
raise AuthError(
|
||||
'Access denied: Your email domain is not allowed to access this service'
|
||||
)
|
||||
|
||||
logger.debug('saas_user_auth_from_signed_token:return')
|
||||
|
||||
return SaasUserAuth(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
@@ -25,6 +26,11 @@ from server.auth.constants import (
|
||||
KEYCLOAK_SERVER_URL,
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
)
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
matches_base_email,
|
||||
)
|
||||
from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@@ -509,6 +515,183 @@ class TokenManager:
|
||||
logger.info(f'Got user ID {keycloak_user_id} from email: {email}')
|
||||
return keycloak_user_id
|
||||
|
||||
async def _query_users_by_wildcard_pattern(
|
||||
self, local_part: str, domain: str
|
||||
) -> dict[str, dict]:
|
||||
"""Query Keycloak for users matching a wildcard email pattern.
|
||||
|
||||
Tries multiple query methods to find users with emails matching
|
||||
the pattern {local_part}*@{domain}. This catches the base email
|
||||
and all + modifier variants.
|
||||
|
||||
Args:
|
||||
local_part: The local part of the email (before @)
|
||||
domain: The domain part of the email (after @)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping user IDs to user objects
|
||||
"""
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
all_users = {}
|
||||
|
||||
# Query for users with emails matching the base pattern using wildcard
|
||||
# Pattern: {local_part}*@{domain} - catches base email and all + variants
|
||||
# This may also catch unintended matches (e.g., joesmith@example.com), but
|
||||
# they will be filtered out by the regex pattern check later
|
||||
# Use 'search' parameter for Keycloak 26+ (better wildcard support)
|
||||
wildcard_queries = [
|
||||
{'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first
|
||||
{'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter
|
||||
]
|
||||
|
||||
for query_params in wildcard_queries:
|
||||
try:
|
||||
users = await keycloak_admin.a_get_users(query_params)
|
||||
for user in users:
|
||||
all_users[user.get('id')] = user
|
||||
break # Success, no need to try fallback
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f'Wildcard query failed with {list(query_params.keys())[0]}: {e}'
|
||||
)
|
||||
continue # Try next query method
|
||||
|
||||
return all_users
|
||||
|
||||
def _find_duplicate_in_users(
|
||||
self, users: dict[str, dict], base_email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if any user in the provided list matches the base email pattern.
|
||||
|
||||
Filters users to find duplicates that match the base email pattern,
|
||||
excluding the current user.
|
||||
|
||||
Args:
|
||||
users: Dictionary mapping user IDs to user objects
|
||||
base_email: The base email to match against
|
||||
current_user_id: The user ID to exclude from the check
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found, False otherwise
|
||||
"""
|
||||
regex_pattern = get_base_email_regex_pattern(base_email)
|
||||
if not regex_pattern:
|
||||
logger.warning(
|
||||
f'Could not generate regex pattern for base email: {base_email}'
|
||||
)
|
||||
# Fallback to simple matching
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '').lower()
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and matches_base_email(user_email, base_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
else:
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '')
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and regex_pattern.match(user_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def check_duplicate_base_email(
|
||||
self, email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if a user with the same base email already exists.
|
||||
|
||||
This method checks for duplicate signups using email + modifier.
|
||||
It checks if any user exists with the same base email, regardless of whether
|
||||
the provided email has a + modifier or not.
|
||||
|
||||
Examples:
|
||||
- If email is "joe+test@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com")
|
||||
- If email is "joe@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com")
|
||||
|
||||
Args:
|
||||
email: The email address to check (may or may not contain + modifier)
|
||||
current_user_id: The user ID of the current user (to exclude from check)
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found (excluding current user), False otherwise
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
|
||||
base_email = extract_base_email(email)
|
||||
if not base_email:
|
||||
logger.warning(f'Could not extract base email from: {email}')
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, domain = base_email.rsplit('@', 1)
|
||||
users = await self._query_users_by_wildcard_pattern(local_part, domain)
|
||||
return self._find_duplicate_in_users(users, base_email, current_user_id)
|
||||
|
||||
except KeycloakConnectionError:
|
||||
logger.exception('KeycloakConnectionError when checking duplicate email')
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error checking duplicate email: {e}')
|
||||
# On any error, allow signup to proceed (fail open)
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def delete_keycloak_user(self, user_id: str) -> bool:
|
||||
"""Delete a user from Keycloak.
|
||||
|
||||
This method is used to clean up user accounts that were created
|
||||
but should not exist (e.g., duplicate email signups).
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to delete
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Use the sync method (python-keycloak doesn't have async delete_user)
|
||||
# Run it in a thread executor to avoid blocking the event loop
|
||||
await asyncio.to_thread(keycloak_admin.delete_user, user_id)
|
||||
logger.info(f'Successfully deleted Keycloak user {user_id}')
|
||||
return True
|
||||
except KeycloakConnectionError:
|
||||
logger.exception(f'KeycloakConnectionError when deleting user {user_id}')
|
||||
raise
|
||||
except KeycloakError as e:
|
||||
# User might not exist or already deleted
|
||||
logger.warning(
|
||||
f'KeycloakError when deleting user {user_id}: {e}',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}')
|
||||
return False
|
||||
|
||||
async def get_user_info_from_user_id(self, user_id: str) -> dict | None:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
@@ -527,6 +710,49 @@ class TokenManager:
|
||||
github_id = github_ids[0]
|
||||
return github_id
|
||||
|
||||
async def disable_keycloak_user(
|
||||
self, user_id: str, email: str | None = None
|
||||
) -> None:
|
||||
"""Disable a Keycloak user account.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to disable
|
||||
email: Optional email address for logging purposes
|
||||
|
||||
This method attempts to disable the user account but will not raise exceptions.
|
||||
Errors are logged but do not prevent the operation from completing.
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Get current user to preserve other fields
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
if user:
|
||||
# Update user with enabled=False to disable the account
|
||||
await keycloak_admin.a_update_user(
|
||||
user_id=user_id,
|
||||
payload={
|
||||
'enabled': False,
|
||||
'username': user.get('username', ''),
|
||||
'email': user.get('email', ''),
|
||||
'emailVerified': user.get('emailVerified', False),
|
||||
},
|
||||
)
|
||||
email_str = f', email: {email}' if email else ''
|
||||
logger.info(
|
||||
f'Disabled Keycloak account for user_id: {user_id}{email_str}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'User not found in Keycloak when attempting to disable: {user_id}'
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't raise - the caller should handle the blocking regardless
|
||||
email_str = f', email: {email}' if email else ''
|
||||
logger.error(
|
||||
f'Failed to disable Keycloak account for user_id: {user_id}{email_str}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def store_org_token(self, installation_id: int, installation_token: str):
|
||||
"""Store a GitHub App installation token.
|
||||
|
||||
|
||||
@@ -159,6 +159,7 @@ class SetAuthCookieMiddleware:
|
||||
'/api/billing/cancel',
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/api/email/resend',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from server.auth.constants import (
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
ROLE_CHECK_ENABLED,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
@@ -145,7 +146,76 @@ async def keycloak_callback(
|
||||
content={'error': 'Missing user ID or username in response'},
|
||||
)
|
||||
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
|
||||
# Disable the Keycloak account
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': 'Access denied: Your email domain is not allowed to access this service'
|
||||
},
|
||||
)
|
||||
|
||||
# Check for duplicate email with + modifier
|
||||
if email:
|
||||
try:
|
||||
has_duplicate = await token_manager.check_duplicate_base_email(
|
||||
email, user_id
|
||||
)
|
||||
if has_duplicate:
|
||||
logger.warning(
|
||||
f'Blocked signup attempt for email {email} - duplicate base email found',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Delete the Keycloak user that was automatically created during OAuth
|
||||
# This prevents orphaned accounts in Keycloak
|
||||
# The delete_keycloak_user method already handles all errors internally
|
||||
deletion_success = await token_manager.delete_keycloak_user(user_id)
|
||||
if deletion_success:
|
||||
logger.info(
|
||||
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
|
||||
f'User may need to be manually cleaned up.'
|
||||
)
|
||||
|
||||
# Redirect to home page with query parameter indicating the issue
|
||||
home_url = f'{request.base_url}?duplicated_email=true'
|
||||
return RedirectResponse(home_url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log error but allow signup to proceed (fail open)
|
||||
logger.error(
|
||||
f'Error checking duplicate email for {email}: {e}',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Check email verification status
|
||||
email_verified = user_info.get('email_verified', False)
|
||||
if not email_verified:
|
||||
# Send verification email
|
||||
# Import locally to avoid circular import with email.py
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = (
|
||||
f'{request.base_url}?email_verification_required=true&user_id={user_id}'
|
||||
)
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
|
||||
@@ -111,10 +111,24 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f'litellm_get_user_failed: {type(e).__name__}: {e}',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': e.response.status_code,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve credit balance from billing service',
|
||||
)
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
|
||||
@@ -7,6 +7,7 @@ from server.auth.constants import KEYCLOAK_CLIENT_ID
|
||||
from server.auth.keycloak_manager import get_keycloak_admin
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.auth import set_response_cookie
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -28,6 +29,11 @@ class EmailUpdate(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class ResendEmailVerificationRequest(BaseModel):
|
||||
user_id: str | None = None
|
||||
is_auth_flow: bool = False
|
||||
|
||||
|
||||
@api_router.post('')
|
||||
async def update_email(
|
||||
email_data: EmailUpdate, request: Request, user_id: str = Depends(get_user_id)
|
||||
@@ -74,7 +80,7 @@ async def update_email(
|
||||
accepted_tos=user_auth.accepted_tos,
|
||||
)
|
||||
|
||||
await _verify_email(request=request, user_id=user_id)
|
||||
await verify_email(request=request, user_id=user_id)
|
||||
|
||||
logger.info(f'Updating email address for {user_id} to {email}')
|
||||
return response
|
||||
@@ -90,9 +96,41 @@ async def update_email(
|
||||
)
|
||||
|
||||
|
||||
@api_router.put('/verify')
|
||||
async def verify_email(request: Request, user_id: str = Depends(get_user_id)):
|
||||
await _verify_email(request=request, user_id=user_id)
|
||||
@api_router.put('/resend')
|
||||
async def resend_email_verification(
|
||||
request: Request,
|
||||
body: ResendEmailVerificationRequest | None = None,
|
||||
):
|
||||
# Get user_id from body if provided, otherwise from auth
|
||||
user_id: str | None = None
|
||||
if body and body.user_id:
|
||||
user_id = body.user_id
|
||||
else:
|
||||
try:
|
||||
user_id = await get_user_id(request)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='user_id is required in request body or user must be authenticated',
|
||||
)
|
||||
|
||||
# Check rate limit (uses user_id if available, otherwise falls back to IP)
|
||||
# Use 30 seconds for user-based rate limiting to match frontend cooldown
|
||||
await check_rate_limit_by_user_id(
|
||||
request=request,
|
||||
key_prefix='email_resend',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=30,
|
||||
ip_rate_limit_seconds=60, # 1 minute for IP-based limiting (more lenient)
|
||||
)
|
||||
|
||||
# Get is_auth_flow from body if provided, default to False
|
||||
is_auth_flow = body.is_auth_flow if body else False
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=is_auth_flow)
|
||||
|
||||
logger.info(f'Resending verification email for {user_id}')
|
||||
return JSONResponse(
|
||||
@@ -124,10 +162,13 @@ async def verified_email(request: Request):
|
||||
return response
|
||||
|
||||
|
||||
async def _verify_email(request: Request, user_id: str):
|
||||
async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False):
|
||||
keycloak_admin = get_keycloak_admin()
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
if is_auth_flow:
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}?email_verified=true'
|
||||
else:
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
logger.info(f'Redirect URI: {redirect_uri}')
|
||||
await keycloak_admin.a_send_verify_email(
|
||||
user_id=user_id,
|
||||
|
||||
@@ -134,12 +134,12 @@ async def _process_batch_operations_background(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'error_processing_batch_operation',
|
||||
f'error_processing_batch_operation: {type(e).__name__}: {e}',
|
||||
extra={
|
||||
'path': batch_op.path,
|
||||
'method': str(batch_op.method),
|
||||
'error': str(e),
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -804,6 +804,8 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['ENABLE_V1'] = '0'
|
||||
env_vars['SU_TO_USER'] = SU_TO_USER
|
||||
env_vars['DISABLE_VSCODE_PLUGIN'] = str(DISABLE_VSCODE_PLUGIN).lower()
|
||||
env_vars['BROWSERGYM_DOWNLOAD_DIR'] = '/workspace/.downloads/'
|
||||
env_vars['PLAYWRIGHT_BROWSERS_PATH'] = '/opt/playwright-browsers'
|
||||
|
||||
# We need this for LLM traces tracking to identify the source of the LLM calls
|
||||
env_vars['WEB_HOST'] = WEB_HOST
|
||||
|
||||
20
enterprise/server/sharing/README.md
Normal file
20
enterprise/server/sharing/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Sharing Package
|
||||
|
||||
This package contains functionality for sharing conversations.
|
||||
|
||||
## Components
|
||||
|
||||
- **shared.py**: Data models for shared conversations
|
||||
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
|
||||
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
|
||||
- **shared_event_service.py**: Service interface for accessing shared events
|
||||
- **shared_event_service_impl.py**: Implementation of the shared event service
|
||||
- **shared_conversation_router.py**: REST API endpoints for shared conversations
|
||||
- **shared_event_router.py**: REST API endpoints for shared events
|
||||
|
||||
## Features
|
||||
|
||||
- Read-only access to shared conversations
|
||||
- Event access for shared conversations
|
||||
- Search and filtering capabilities
|
||||
- Pagination support
|
||||
142
enterprise/server/sharing/filesystem_shared_event_service.py
Normal file
142
enterprise/server/sharing/filesystem_shared_event_service.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Implementation of SharedEventService.
|
||||
|
||||
This implementation provides read-only access to events from shared conversations:
|
||||
- Validates that the conversation is shared before returning events
|
||||
- Uses existing EventService for actual event retrieval
|
||||
- Uses SharedConversationInfoService for shared conversation validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.sdk import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedEventServiceImpl(SharedEventService):
|
||||
"""Implementation of SharedEventService that validates shared access."""
|
||||
|
||||
shared_conversation_info_service: SharedConversationInfoService
|
||||
event_service: EventService
|
||||
|
||||
async def get_shared_event(
|
||||
self, conversation_id: UUID, event_id: str
|
||||
) -> Event | None:
|
||||
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
return None
|
||||
|
||||
# If conversation is shared, get the event
|
||||
return await self.event_service.get_event(event_id)
|
||||
|
||||
async def search_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
# Return empty page if conversation is not shared
|
||||
return EventPage(items=[], next_page_id=None)
|
||||
|
||||
# If conversation is shared, search events for this conversation
|
||||
return await self.event_service.search_events(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def count_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
) -> int:
|
||||
"""Count events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
return 0
|
||||
|
||||
# If conversation is shared, count events for this conversation
|
||||
return await self.event_service.count_events(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
|
||||
class SharedEventServiceImplInjector(SharedEventServiceInjector):
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedEventService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_event_service,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_db_session(state, request) as db_session,
|
||||
get_event_service(state, request) as event_service,
|
||||
):
|
||||
shared_conversation_info_service = SQLSharedConversationInfoService(
|
||||
db_session=db_session
|
||||
)
|
||||
service = SharedEventServiceImpl(
|
||||
shared_conversation_info_service=shared_conversation_info_service,
|
||||
event_service=event_service,
|
||||
)
|
||||
yield service
|
||||
@@ -0,0 +1,66 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
|
||||
class SharedConversationInfoService(ABC):
|
||||
"""Service for accessing shared conversation info without user restrictions."""
|
||||
|
||||
@abstractmethod
|
||||
async def search_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
|
||||
@abstractmethod
|
||||
async def count_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count shared conversations."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_shared_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single shared conversation info, returning None if missing or not shared."""
|
||||
|
||||
async def batch_get_shared_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversation info, return None for any missing or non-shared."""
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
self.get_shared_conversation_info(conversation_id)
|
||||
for conversation_id in conversation_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class SharedConversationInfoServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
|
||||
):
|
||||
pass
|
||||
56
enterprise/server/sharing/shared_conversation_models.py
Normal file
56
enterprise/server/sharing/shared_conversation_models.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
# Simplified imports to avoid dependency chain issues
|
||||
# from openhands.integrations.service_types import ProviderType
|
||||
# from openhands.sdk.llm import MetricsSnapshot
|
||||
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
# For now, use Any to avoid import issues
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.agent_server.utils import OpenHandsUUID, utc_now
|
||||
|
||||
ProviderType = Any
|
||||
MetricsSnapshot = Any
|
||||
ConversationTrigger = Any
|
||||
|
||||
|
||||
class SharedConversation(BaseModel):
|
||||
"""Shared conversation info model with all fields from AppConversationInfo."""
|
||||
|
||||
id: OpenHandsUUID = Field(default_factory=uuid4)
|
||||
|
||||
created_by_user_id: str | None
|
||||
sandbox_id: str
|
||||
|
||||
selected_repository: str | None = None
|
||||
selected_branch: str | None = None
|
||||
git_provider: ProviderType | None = None
|
||||
title: str | None = None
|
||||
pr_number: list[int] = Field(default_factory=list)
|
||||
llm_model: str | None = None
|
||||
|
||||
metrics: MetricsSnapshot | None = None
|
||||
|
||||
parent_conversation_id: OpenHandsUUID | None = None
|
||||
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
|
||||
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
updated_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
|
||||
class SharedConversationSortOrder(Enum):
|
||||
CREATED_AT = 'CREATED_AT'
|
||||
CREATED_AT_DESC = 'CREATED_AT_DESC'
|
||||
UPDATED_AT = 'UPDATED_AT'
|
||||
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
|
||||
TITLE = 'TITLE'
|
||||
TITLE_DESC = 'TITLE_DESC'
|
||||
|
||||
|
||||
class SharedConversationPage(BaseModel):
|
||||
items: list[SharedConversation]
|
||||
next_page_id: str | None = None
|
||||
135
enterprise/server/sharing/shared_conversation_router.py
Normal file
135
enterprise/server/sharing/shared_conversation_router.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Shared Conversation router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoServiceInjector,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
|
||||
shared_conversation_info_service_dependency = Depends(
|
||||
SQLSharedConversationInfoServiceInjector().depends
|
||||
)
|
||||
|
||||
# Read methods
|
||||
|
||||
|
||||
@router.get('/search')
|
||||
async def search_shared_conversations(
|
||||
title__contains: Annotated[
|
||||
str | None,
|
||||
Query(title='Filter by title containing this string'),
|
||||
] = None,
|
||||
created_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
created_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at less than this datetime'),
|
||||
] = None,
|
||||
updated_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
updated_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at less than this datetime'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
SharedConversationSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
),
|
||||
] = 100,
|
||||
include_sub_conversations: Annotated[
|
||||
bool,
|
||||
Query(
|
||||
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
|
||||
),
|
||||
] = False,
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> SharedConversationPage:
|
||||
"""Search / List shared conversations."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_conversation_service.search_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
include_sub_conversations=include_sub_conversations,
|
||||
)
|
||||
|
||||
|
||||
@router.get('/count')
|
||||
async def count_shared_conversations(
|
||||
title__contains: Annotated[
|
||||
str | None,
|
||||
Query(title='Filter by title containing this string'),
|
||||
] = None,
|
||||
created_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
created_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at less than this datetime'),
|
||||
] = None,
|
||||
updated_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
updated_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at less than this datetime'),
|
||||
] = None,
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> int:
|
||||
"""Count shared conversations matching the given filters."""
|
||||
return await shared_conversation_service.count_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
|
||||
@router.get('')
|
||||
async def batch_get_shared_conversations(
|
||||
ids: Annotated[list[str], Query()],
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
|
||||
assert len(ids) <= 100
|
||||
uuids = [UUID(id_) for id_ in ids]
|
||||
shared_conversation_info = (
|
||||
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
|
||||
)
|
||||
return shared_conversation_info
|
||||
126
enterprise/server/sharing/shared_event_router.py
Normal file
126
enterprise/server/sharing/shared_event_router.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Shared Event router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.filesystem_shared_event_service import (
|
||||
SharedEventServiceImplInjector,
|
||||
)
|
||||
from server.sharing.shared_event_service import SharedEventService
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.sdk import Event
|
||||
|
||||
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
|
||||
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
|
||||
|
||||
|
||||
# Read methods
|
||||
|
||||
|
||||
@router.get('/search')
|
||||
async def search_shared_events(
|
||||
conversation_id: Annotated[
|
||||
str,
|
||||
Query(title='Conversation ID to search events for'),
|
||||
],
|
||||
kind__eq: Annotated[
|
||||
EventKind | None,
|
||||
Query(title='Optional filter by event kind'),
|
||||
] = None,
|
||||
timestamp__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp greater than or equal to'),
|
||||
] = None,
|
||||
timestamp__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp less than'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
EventSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = EventSortOrder.TIMESTAMP,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> EventPage:
|
||||
"""Search / List events for a shared conversation."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_event_service.search_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get('/count')
|
||||
async def count_shared_events(
|
||||
conversation_id: Annotated[
|
||||
str,
|
||||
Query(title='Conversation ID to count events for'),
|
||||
],
|
||||
kind__eq: Annotated[
|
||||
EventKind | None,
|
||||
Query(title='Optional filter by event kind'),
|
||||
] = None,
|
||||
timestamp__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp greater than or equal to'),
|
||||
] = None,
|
||||
timestamp__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp less than'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
EventSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = EventSortOrder.TIMESTAMP,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> int:
|
||||
"""Count events for a shared conversation matching the given filters."""
|
||||
return await shared_event_service.count_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
|
||||
@router.get('')
|
||||
async def batch_get_shared_events(
|
||||
conversation_id: Annotated[
|
||||
UUID,
|
||||
Query(title='Conversation ID to get events for'),
|
||||
],
|
||||
id: Annotated[list[str], Query()],
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> list[Event | None]:
|
||||
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
|
||||
assert len(id) <= 100
|
||||
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
|
||||
return events
|
||||
|
||||
|
||||
@router.get('/{conversation_id}/{event_id}')
|
||||
async def get_shared_event(
|
||||
conversation_id: UUID,
|
||||
event_id: str,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> Event | None:
|
||||
"""Get a single event from a shared conversation by conversation_id and event_id."""
|
||||
return await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
64
enterprise/server/sharing/shared_event_service.py
Normal file
64
enterprise/server/sharing/shared_event_service.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk import Event
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharedEventService(ABC):
|
||||
"""Event Service for getting events from shared conversations only."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_shared_event(
|
||||
self, conversation_id: UUID, event_id: str
|
||||
) -> Event | None:
|
||||
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||
|
||||
@abstractmethod
|
||||
async def search_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search events for a specific shared conversation."""
|
||||
|
||||
@abstractmethod
|
||||
async def count_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
) -> int:
|
||||
"""Count events for a specific shared conversation."""
|
||||
|
||||
async def batch_get_shared_events(
|
||||
self, conversation_id: UUID, event_ids: list[str]
|
||||
) -> list[Event | None]:
|
||||
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
self.get_shared_event(conversation_id, event_id)
|
||||
for event_id in event_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class SharedEventServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[SharedEventService], ABC
|
||||
):
|
||||
pass
|
||||
@@ -0,0 +1,282 @@
|
||||
"""SQL implementation of SharedConversationInfoService.
|
||||
|
||||
This implementation provides read-only access to shared conversations:
|
||||
- Direct database access without user permission checks
|
||||
- Filters only conversations marked as shared (currently public)
|
||||
- Full async/await support using SQL async db_sessions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
SharedConversationInfoServiceInjector,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def search_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
query = self._public_select()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
# Exclude sub-conversations (only include top-level conversations)
|
||||
query = query.where(
|
||||
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||
)
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == SharedConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == SharedConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [self._to_shared_conversation(row) for row in rows]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return SharedConversationPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count shared conversations matching the given filters."""
|
||||
from sqlalchemy import func
|
||||
|
||||
query = select(func.count(StoredConversationMetadata.conversation_id))
|
||||
# Only include shared conversations
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_shared_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single public conversation info, returning None if missing or not shared."""
|
||||
query = self._public_select().where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
stored = result.scalar_one_or_none()
|
||||
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
return self._to_shared_conversation(stored)
|
||||
|
||||
def _public_select(self):
|
||||
"""Create a select query that only returns public conversations."""
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
# Only include conversations marked as public
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
return query
|
||||
|
||||
def _apply_filters(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply common filters to a query."""
|
||||
if title__contains is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.title.contains(title__contains)
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.created_at >= created_at__gte
|
||||
)
|
||||
|
||||
if created_at__lt is not None:
|
||||
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
def _to_shared_conversation(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> SharedConversation:
|
||||
"""Convert StoredConversationMetadata to SharedConversation."""
|
||||
# V1 conversations should always have a sandbox_id
|
||||
sandbox_id = stored.sandbox_id
|
||||
assert sandbox_id is not None
|
||||
|
||||
# Rebuild token usage
|
||||
token_usage = TokenUsage(
|
||||
prompt_tokens=stored.prompt_tokens,
|
||||
completion_tokens=stored.completion_tokens,
|
||||
cache_read_tokens=stored.cache_read_tokens,
|
||||
cache_write_tokens=stored.cache_write_tokens,
|
||||
context_window=stored.context_window,
|
||||
per_turn_token=stored.per_turn_token,
|
||||
)
|
||||
|
||||
# Rebuild metrics object
|
||||
metrics = MetricsSnapshot(
|
||||
accumulated_cost=stored.accumulated_cost,
|
||||
max_budget_per_task=stored.max_budget_per_task,
|
||||
accumulated_token_usage=token_usage,
|
||||
)
|
||||
|
||||
# Get timestamps
|
||||
created_at = self._fix_timezone(stored.created_at)
|
||||
updated_at = self._fix_timezone(stored.last_updated_at)
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
git_provider=(
|
||||
ProviderType(stored.git_provider) if stored.git_provider else None
|
||||
),
|
||||
title=stored.title,
|
||||
pr_number=stored.pr_number,
|
||||
llm_model=stored.llm_model,
|
||||
metrics=metrics,
|
||||
parent_conversation_id=(
|
||||
UUID(stored.parent_conversation_id)
|
||||
if stored.parent_conversation_id
|
||||
else None
|
||||
),
|
||||
sub_conversation_ids=sub_conversation_ids or [],
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
def _fix_timezone(self, value: datetime) -> datetime:
|
||||
"""Sqlite does not store timezones - and since we can't update the existing models
|
||||
we assume UTC if the timezone is missing."""
|
||||
if not value.tzinfo:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
return value
|
||||
|
||||
|
||||
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedConversationInfoService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import get_db_session
|
||||
|
||||
async with get_db_session(state, request) as db_session:
|
||||
service = SQLSharedConversationInfoService(db_session=db_session)
|
||||
yield service
|
||||
83
enterprise/server/utils/rate_limit_utils.py
Normal file
83
enterprise/server/utils/rate_limit_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.shared import sio
|
||||
|
||||
# Rate limiting constants
|
||||
RATE_LIMIT_USER_SECONDS = 120 # 2 minutes per user_id
|
||||
RATE_LIMIT_IP_SECONDS = 300 # 5 minutes per IP address
|
||||
|
||||
|
||||
async def check_rate_limit_by_user_id(
|
||||
request: Request,
|
||||
key_prefix: str,
|
||||
user_id: str | None,
|
||||
user_rate_limit_seconds: int = RATE_LIMIT_USER_SECONDS,
|
||||
ip_rate_limit_seconds: int = RATE_LIMIT_IP_SECONDS,
|
||||
) -> None:
|
||||
"""
|
||||
Check rate limit for requests, using user_id when available, falling back to IP address.
|
||||
|
||||
Uses Redis to store rate limit keys with expiration. If a key already exists,
|
||||
it means the rate limit is active and the request will be rejected.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
key_prefix: Prefix for the Redis key (e.g., "email_resend")
|
||||
user_id: User ID if available, None otherwise
|
||||
user_rate_limit_seconds: Rate limit window in seconds for user_id-based limiting (default: 120)
|
||||
ip_rate_limit_seconds: Rate limit window in seconds for IP-based limiting (default: 300)
|
||||
|
||||
Raises:
|
||||
HTTPException: If rate limit is exceeded (429 status code)
|
||||
"""
|
||||
try:
|
||||
redis = sio.manager.redis
|
||||
if not redis:
|
||||
# If Redis is unavailable, log warning and allow request (fail open)
|
||||
logger.warning('Redis unavailable for rate limiting, allowing request')
|
||||
return
|
||||
|
||||
if user_id:
|
||||
# Rate limit by user_id (primary method)
|
||||
rate_limit_key = f'{key_prefix}:{user_id}'
|
||||
rate_limit_seconds = user_rate_limit_seconds
|
||||
else:
|
||||
# Fallback to IP address rate limiting
|
||||
client_ip = request.client.host if request.client else 'unknown'
|
||||
rate_limit_key = f'{key_prefix}:ip:{client_ip}'
|
||||
rate_limit_seconds = ip_rate_limit_seconds
|
||||
|
||||
# Try to set the key with expiration. If it already exists (nx=True fails),
|
||||
# it means the rate limit is active
|
||||
created = await redis.set(rate_limit_key, 1, nx=True, ex=rate_limit_seconds)
|
||||
|
||||
if not created:
|
||||
logger.info(
|
||||
f'Rate limit exceeded for {rate_limit_key}',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'ip': request.client.host if request.client else 'unknown',
|
||||
},
|
||||
)
|
||||
# Format error message based on duration
|
||||
if rate_limit_seconds < 60:
|
||||
wait_message = f'{rate_limit_seconds} seconds'
|
||||
elif rate_limit_seconds % 60 == 0:
|
||||
wait_message = f'{rate_limit_seconds // 60} minute{"s" if rate_limit_seconds // 60 != 1 else ""}'
|
||||
else:
|
||||
minutes = rate_limit_seconds // 60
|
||||
seconds = rate_limit_seconds % 60
|
||||
wait_message = f'{minutes} minute{"s" if minutes != 1 else ""} and {seconds} second{"s" if seconds != 1 else ""}'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f'Too many requests. Please wait {wait_message} before trying again.',
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException (rate limit exceeded)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log error but allow request (fail open) to avoid blocking legitimate users
|
||||
logger.warning(f'Error checking rate limit: {e}', exc_info=True)
|
||||
return
|
||||
@@ -19,17 +19,23 @@ GCP_REGION = os.environ.get('GCP_REGION')
|
||||
|
||||
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
|
||||
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
|
||||
|
||||
# Initialize Cloud SQL Connector once at module level for GCP environments.
|
||||
_connector = None
|
||||
|
||||
|
||||
def _get_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
|
||||
def get_db_connection():
|
||||
global _connector
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
connector = Connector()
|
||||
if not _connector:
|
||||
_connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return connector.connect(
|
||||
return _connector.connect(
|
||||
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
|
||||
)
|
||||
|
||||
@@ -38,6 +44,7 @@ def _get_db_engine():
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
else:
|
||||
@@ -48,6 +55,7 @@ def _get_db_engine():
|
||||
host_string,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore):
|
||||
kwargs.pop('context_window', None)
|
||||
kwargs.pop('per_turn_token', None)
|
||||
kwargs.pop('parent_conversation_id', None)
|
||||
kwargs.pop('public')
|
||||
|
||||
return ConversationMetadata(**kwargs)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
REQUIRE_PAYMENT,
|
||||
USER_SETTINGS_VERSION_TO_MODEL,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
@@ -202,6 +203,53 @@ class SaasSettingsStore(SettingsStore):
|
||||
)
|
||||
return None
|
||||
|
||||
def _has_custom_settings(
|
||||
self, settings: Settings, old_user_version: int | None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has custom LLM settings that should be preserved.
|
||||
Returns True if user customized either model or base_url.
|
||||
|
||||
Args:
|
||||
settings: The user's current settings
|
||||
old_user_version: The user's old settings version, if any
|
||||
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
settings.llm_model.strip()
|
||||
if settings.llm_model and settings.llm_model.strip()
|
||||
else None
|
||||
)
|
||||
user_base_url = (
|
||||
settings.llm_base_url.strip()
|
||||
if settings.llm_base_url and settings.llm_base_url.strip()
|
||||
else None
|
||||
)
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
return True
|
||||
|
||||
# No model set = using defaults
|
||||
if not user_model:
|
||||
return False
|
||||
|
||||
# Check if model matches old version's default
|
||||
if (
|
||||
old_user_version
|
||||
and old_user_version < CURRENT_USER_SETTINGS_VERSION
|
||||
and old_user_version in USER_SETTINGS_VERSION_TO_MODEL
|
||||
):
|
||||
old_default_base = USER_SETTINGS_VERSION_TO_MODEL[old_user_version]
|
||||
user_model_base = user_model.split('/')[-1]
|
||||
if user_model_base == old_default_base:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
|
||||
async def update_settings_with_litellm_default(
|
||||
self, settings: Settings
|
||||
) -> Settings | None:
|
||||
@@ -213,6 +261,17 @@ class SaasSettingsStore(SettingsStore):
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
|
||||
# Check if user has custom settings
|
||||
has_custom = self._has_custom_settings(settings, settings.user_version)
|
||||
|
||||
# Determine model to use (needed before LiteLLM user creation)
|
||||
llm_model_to_use = (
|
||||
settings.llm_model
|
||||
if has_custom and settings.llm_model
|
||||
else get_default_litellm_model()
|
||||
)
|
||||
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
@@ -276,7 +335,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
|
||||
# Create the new litellm user
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, email, max_budget, spend
|
||||
client, email, max_budget, spend, llm_model_to_use
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
@@ -285,7 +344,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, None, max_budget, spend
|
||||
client, None, max_budget, spend, llm_model_to_use
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
@@ -311,11 +370,17 @@ class SaasSettingsStore(SettingsStore):
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
if has_custom:
|
||||
settings.llm_model = settings.llm_model or get_default_litellm_model()
|
||||
settings.llm_base_url = settings.llm_base_url or LITE_LLM_API_URL
|
||||
settings.llm_api_key = settings.llm_api_key or SecretStr(key)
|
||||
else:
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
|
||||
settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
@@ -398,7 +463,12 @@ class SaasSettingsStore(SettingsStore):
|
||||
)
|
||||
|
||||
async def _create_user_in_lite_llm(
|
||||
self, client: httpx.AsyncClient, email: str | None, max_budget: int, spend: int
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
max_budget: int,
|
||||
spend: int,
|
||||
llm_model: str,
|
||||
):
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
@@ -413,7 +483,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': CURRENT_USER_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
'model': llm_model,
|
||||
},
|
||||
'key_alias': f'OpenHands Cloud - user {self.user_id}',
|
||||
},
|
||||
|
||||
@@ -4,6 +4,8 @@ from uuid import uuid4
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import GITLAB_WEBHOOK_URL
|
||||
from sqlalchemy import text
|
||||
from storage.database import a_session_maker
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
@@ -258,6 +260,25 @@ class VerifyWebhookStatus:
|
||||
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
# Check if the table exists before proceeding
|
||||
# This handles cases where the CronJob runs before database migrations complete
|
||||
async with a_session_maker() as session:
|
||||
query = text("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = 'gitlab_webhook'
|
||||
)
|
||||
""")
|
||||
result = await session.execute(query)
|
||||
table_exists = result.scalar() or False
|
||||
|
||||
if not table_exists:
|
||||
logger.info(
|
||||
'gitlab_webhook table does not exist yet, '
|
||||
'waiting for database migrations to complete'
|
||||
)
|
||||
return
|
||||
|
||||
# Get an instance of the webhook store
|
||||
webhook_store = await GitlabWebhookStore.get_instance()
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""Tests for enterprise integrations utils module."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from integrations.utils import get_summary_for_agent_state
|
||||
from integrations.utils import (
|
||||
append_conversation_footer,
|
||||
get_summary_for_agent_state,
|
||||
)
|
||||
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
@@ -157,3 +162,138 @@ class TestGetSummaryForAgentState:
|
||||
assert 'try again later' in result.lower()
|
||||
# RATE_LIMITED doesn't include conversation link in response
|
||||
assert self.conversation_link not in result
|
||||
|
||||
|
||||
class TestAppendConversationFooter:
|
||||
"""Test cases for append_conversation_footer function."""
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_appends_footer_with_markdown_link(self):
|
||||
"""Test that footer is appended with correct markdown link format."""
|
||||
# Arrange
|
||||
message = 'This is a test message'
|
||||
conversation_id = 'test-conv-123'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith(message)
|
||||
assert (
|
||||
'[View full conversation](https://example.com/conversations/test-conv-123)'
|
||||
in result
|
||||
)
|
||||
assert result.endswith(
|
||||
'[View full conversation](https://example.com/conversations/test-conv-123)'
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_does_not_contain_html_tags(self):
|
||||
"""Test that footer does not contain HTML tags like <sub>."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'test-conv-456'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert '<sub>' not in result
|
||||
assert '</sub>' not in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_format_with_newlines(self):
|
||||
"""Test that footer is properly separated with newlines."""
|
||||
# Arrange
|
||||
message = 'Original message content'
|
||||
conversation_id = 'test-conv-789'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
result
|
||||
== 'Original message content\n\n[View full conversation](https://example.com/conversations/test-conv-789)'
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_empty_message_still_appends_footer(self):
|
||||
"""Test that footer is appended even when message is empty."""
|
||||
# Arrange
|
||||
message = ''
|
||||
conversation_id = 'empty-msg-conv'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith('\n\n')
|
||||
assert (
|
||||
'[View full conversation](https://example.com/conversations/empty-msg-conv)'
|
||||
in result
|
||||
)
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_conversation_id_with_special_characters(self):
|
||||
"""Test that footer handles conversation IDs with special characters."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'conv-123_abc-456'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
expected_url = 'https://example.com/conversations/conv-123_abc-456'
|
||||
assert expected_url in result
|
||||
assert '[View full conversation]' in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_multiline_message_preserves_content(self):
|
||||
"""Test that multiline messages are preserved correctly."""
|
||||
# Arrange
|
||||
message = 'Line 1\nLine 2\nLine 3'
|
||||
conversation_id = 'multiline-conv'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
assert result.startswith('Line 1\nLine 2\nLine 3')
|
||||
assert '\n\n[View full conversation]' in result
|
||||
assert message in result
|
||||
|
||||
@patch(
|
||||
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
|
||||
)
|
||||
def test_footer_contains_only_markdown_syntax(self):
|
||||
"""Test that footer uses only markdown syntax, not HTML."""
|
||||
# Arrange
|
||||
message = 'Test message'
|
||||
conversation_id = 'markdown-test'
|
||||
|
||||
# Act
|
||||
result = append_conversation_footer(message, conversation_id)
|
||||
|
||||
# Assert
|
||||
footer_part = result[len(message) :]
|
||||
# Should only contain markdown link syntax: [text](url)
|
||||
assert footer_part.startswith('\n\n[')
|
||||
assert '](' in footer_part
|
||||
assert footer_part.endswith(')')
|
||||
# Should not contain any HTML tags (specifically <sub> tags that were removed)
|
||||
assert '<sub>' not in footer_part
|
||||
assert '</sub>' not in footer_part
|
||||
|
||||
361
enterprise/tests/unit/server/routes/test_email_routes.py
Normal file
361
enterprise/tests/unit/server/routes/test_email_routes.py
Normal file
@@ -0,0 +1,361 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.email import (
|
||||
ResendEmailVerificationRequest,
|
||||
resend_email_verification,
|
||||
verified_email,
|
||||
verify_email,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.url = MagicMock()
|
||||
request.url.hostname = 'localhost'
|
||||
request.url.netloc = 'localhost:8000'
|
||||
request.url.path = '/api/email/verified'
|
||||
request.base_url = 'http://localhost:8000/'
|
||||
request.headers = {}
|
||||
request.cookies = {}
|
||||
request.query_params = MagicMock()
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth():
|
||||
"""Create a mock SaasUserAuth object."""
|
||||
auth = MagicMock(spec=SaasUserAuth)
|
||||
auth.access_token = SecretStr('test_access_token')
|
||||
auth.refresh_token = SecretStr('test_refresh_token')
|
||||
auth.email = 'test@example.com'
|
||||
auth.email_verified = False
|
||||
auth.accepted_tos = True
|
||||
auth.refresh = AsyncMock()
|
||||
return auth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_default_behavior(mock_request):
|
||||
"""Test verify_email with default is_auth_flow=False."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['user_id'] == user_id
|
||||
assert (
|
||||
call_args.kwargs['redirect_uri'] == 'http://localhost:8000/api/email/verified'
|
||||
)
|
||||
assert 'client_id' in call_args.kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_with_auth_flow(mock_request):
|
||||
"""Test verify_email with is_auth_flow=True."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['user_id'] == user_id
|
||||
assert (
|
||||
call_args.kwargs['redirect_uri'] == 'http://localhost:8000?email_verified=true'
|
||||
)
|
||||
assert 'client_id' in call_args.kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_https_scheme(mock_request):
|
||||
"""Test verify_email uses https scheme for non-localhost hosts."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
|
||||
|
||||
# Assert
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['redirect_uri'].startswith('https://')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_default_redirect(mock_request, mock_user_auth):
|
||||
"""Test verified_email redirects to /settings/user by default."""
|
||||
# Arrange
|
||||
mock_request.query_params.get.return_value = None
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert result.headers['location'] == 'http://localhost:8000/settings/user'
|
||||
mock_user_auth.refresh.assert_called_once()
|
||||
mock_set_cookie.assert_called_once()
|
||||
assert mock_user_auth.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_https_scheme(mock_request, mock_user_auth):
|
||||
"""Test verified_email uses https scheme for non-localhost hosts."""
|
||||
# Arrange
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_request.query_params.get.return_value = None
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.headers['location'].startswith('https://')
|
||||
mock_set_cookie.assert_called_once()
|
||||
# Verify secure flag is True for https
|
||||
call_kwargs = mock_set_cookie.call_args.kwargs
|
||||
assert call_kwargs['secure'] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_with_user_id_from_body_succeeds(mock_request):
|
||||
"""Test resend_email_verification succeeds when user_id is provided in body."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=False)
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
patch('server.routes.email.logger') as mock_logger,
|
||||
):
|
||||
mock_rate_limit.return_value = None # Rate limit check passes
|
||||
|
||||
# Act
|
||||
result = await resend_email_verification(request=mock_request, body=body)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
assert 'message' in result.body.decode()
|
||||
mock_rate_limit.assert_called_once_with(
|
||||
request=mock_request,
|
||||
key_prefix='email_resend',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=30,
|
||||
ip_rate_limit_seconds=60,
|
||||
)
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
# Logger is called multiple times (verify_email and resend_email_verification)
|
||||
# Check that the resend message was logged
|
||||
assert any(
|
||||
'Resending verification email for' in str(call)
|
||||
for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_with_user_id_from_auth_succeeds(mock_request):
|
||||
"""Test resend_email_verification succeeds when user_id comes from authentication."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.email.get_user_id', return_value=user_id
|
||||
) as mock_get_user_id,
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
):
|
||||
mock_rate_limit.return_value = None # Rate limit check passes
|
||||
|
||||
# Act
|
||||
result = await resend_email_verification(request=mock_request, body=None)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
mock_get_user_id.assert_called_once_with(mock_request)
|
||||
mock_rate_limit.assert_called_once_with(
|
||||
request=mock_request,
|
||||
key_prefix='email_resend',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=30,
|
||||
ip_rate_limit_seconds=60,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_without_user_id_returns_400(mock_request):
|
||||
"""Test resend_email_verification returns 400 when user_id is not available."""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.email.get_user_id', side_effect=Exception('Not authenticated')
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await resend_email_verification(request=mock_request, body=None)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert 'user_id is required' in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_rate_limit_exceeded_returns_429(mock_request):
|
||||
"""Test resend_email_verification returns 429 when rate limit is exceeded."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
body = ResendEmailVerificationRequest(user_id=user_id)
|
||||
|
||||
with (
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
):
|
||||
mock_rate_limit.side_effect = HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail='Too many requests. Please wait 2 minutes before trying again.',
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await resend_email_verification(request=mock_request, body=body)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert 'Too many requests' in exc_info.value.detail
|
||||
mock_rate_limit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_with_is_auth_flow_true(mock_request):
|
||||
"""Test resend_email_verification passes is_auth_flow to verify_email."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=True)
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
):
|
||||
mock_rate_limit.return_value = None
|
||||
|
||||
# Act
|
||||
await resend_email_verification(request=mock_request, body=body)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
# Verify that verify_email was called with is_auth_flow=True
|
||||
# We check this indirectly by verifying the redirect_uri
|
||||
assert 'email_verified=true' in call_args.kwargs['redirect_uri']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_with_is_auth_flow_false(mock_request):
|
||||
"""Test resend_email_verification uses default is_auth_flow=False when not specified."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=False)
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
):
|
||||
mock_rate_limit.return_value = None
|
||||
|
||||
# Act
|
||||
await resend_email_verification(request=mock_request, body=body)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
# Verify that verify_email was called with is_auth_flow=False
|
||||
assert '/api/email/verified' in call_args.kwargs['redirect_uri']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_body_none_uses_auth(mock_request):
|
||||
"""Test resend_email_verification uses auth when body is None."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.email.get_user_id', return_value=user_id
|
||||
) as mock_get_user_id,
|
||||
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
):
|
||||
mock_rate_limit.return_value = None
|
||||
|
||||
# Act
|
||||
result = await resend_email_verification(request=mock_request, body=None)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
mock_get_user_id.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with(
|
||||
request=mock_request,
|
||||
key_prefix='email_resend',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=30,
|
||||
ip_rate_limit_seconds=60,
|
||||
)
|
||||
@@ -699,12 +699,11 @@ class TestProcessBatchOperationsBackground:
|
||||
# Should not raise exceptions
|
||||
await _process_batch_operations_background(batch_ops, 'test-api-key')
|
||||
|
||||
# Should log the error
|
||||
mock_logger.error.assert_called_once_with(
|
||||
'error_processing_batch_operation',
|
||||
extra={
|
||||
'path': 'invalid-path',
|
||||
'method': 'BatchMethod.POST',
|
||||
'error': mock_logger.error.call_args[1]['extra']['error'],
|
||||
},
|
||||
)
|
||||
# Should log the error with exception type and message in the log message
|
||||
mock_logger.error.assert_called_once()
|
||||
call_args = mock_logger.error.call_args
|
||||
log_message = call_args[0][0]
|
||||
assert log_message.startswith('error_processing_batch_operation:')
|
||||
assert call_args[1]['extra']['path'] == 'invalid-path'
|
||||
assert call_args[1]['extra']['method'] == 'BatchMethod.POST'
|
||||
assert call_args[1]['exc_info'] is True
|
||||
|
||||
290
enterprise/tests/unit/server/test_rate_limit_utils.py
Normal file
290
enterprise/tests/unit/server/test_rate_limit_utils.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request, status
|
||||
from server.utils.rate_limit_utils import (
|
||||
RATE_LIMIT_IP_SECONDS,
|
||||
RATE_LIMIT_USER_SECONDS,
|
||||
check_rate_limit_by_user_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.client = MagicMock()
|
||||
request.client.host = '192.168.1.1'
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
"""Create a mock Redis client."""
|
||||
redis = AsyncMock()
|
||||
redis.set = AsyncMock(return_value=True) # First call succeeds (key doesn't exist)
|
||||
return redis
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_by_user_id_first_request_succeeds(mock_request, mock_redis):
|
||||
"""Test that first request with user_id succeeds and sets rate limit key."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
key_prefix = 'email_resend'
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:{user_id}', 1, nx=True, ex=RATE_LIMIT_USER_SECONDS
|
||||
)
|
||||
mock_logger.warning.assert_not_called()
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_by_user_id_second_request_within_window_fails(
|
||||
mock_request, mock_redis
|
||||
):
|
||||
"""Test that second request with same user_id within rate limit window fails."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
key_prefix = 'email_resend'
|
||||
mock_redis.set = AsyncMock(return_value=False) # Key already exists
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert 'Too many requests' in exc_info.value.detail
|
||||
assert f'{RATE_LIMIT_USER_SECONDS // 60} minutes' in exc_info.value.detail
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_by_ip_when_user_id_is_none(mock_request, mock_redis):
|
||||
"""Test that rate limiting falls back to IP address when user_id is None."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=None
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:ip:{mock_request.client.host}',
|
||||
1,
|
||||
nx=True,
|
||||
ex=RATE_LIMIT_IP_SECONDS,
|
||||
)
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_by_ip_second_request_within_window_fails(
|
||||
mock_request, mock_redis
|
||||
):
|
||||
"""Test that second request from same IP within rate limit window fails."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
mock_redis.set = AsyncMock(return_value=False) # Key already exists
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=None
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert f'{RATE_LIMIT_IP_SECONDS // 60} minutes' in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_redis_unavailable_fails_open(mock_request):
|
||||
"""Test that rate limiting fails open when Redis is unavailable."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = None # Redis unavailable
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
'Redis unavailable for rate limiting, allowing request'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_redis_exception_fails_open(mock_request, mock_redis):
|
||||
"""Test that rate limiting fails open when Redis raises an exception."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
user_id = 'test_user_id'
|
||||
mock_redis.set = AsyncMock(side_effect=Exception('Redis connection error'))
|
||||
|
||||
with (
|
||||
patch('server.utils.rate_limit_utils.sio') as mock_sio,
|
||||
patch('server.utils.rate_limit_utils.logger') as mock_logger,
|
||||
):
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert 'Error checking rate limit' in str(mock_logger.warning.call_args[0][0])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_custom_key_prefix(mock_request, mock_redis):
|
||||
"""Test that different key prefixes create different rate limit keys."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
key_prefix = 'password_reset'
|
||||
|
||||
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:{user_id}', 1, nx=True, ex=RATE_LIMIT_USER_SECONDS
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_custom_rate_limit_seconds(mock_request, mock_redis):
|
||||
"""Test that custom rate limit seconds are used correctly."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
key_prefix = 'email_resend'
|
||||
custom_user_seconds = 60
|
||||
custom_ip_seconds = 180
|
||||
|
||||
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request,
|
||||
key_prefix=key_prefix,
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=custom_user_seconds,
|
||||
ip_rate_limit_seconds=custom_ip_seconds,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:{user_id}', 1, nx=True, ex=custom_user_seconds
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_ip_with_unknown_client(mock_request, mock_redis):
|
||||
"""Test that rate limiting handles missing client host gracefully."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
mock_request.client = None # No client information
|
||||
|
||||
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=None
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(
|
||||
f'{key_prefix}:ip:unknown', 1, nx=True, ex=RATE_LIMIT_IP_SECONDS
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_different_users_have_separate_limits(
|
||||
mock_request, mock_redis
|
||||
):
|
||||
"""Test that different user_ids have separate rate limit keys."""
|
||||
# Arrange
|
||||
key_prefix = 'email_resend'
|
||||
user_id_1 = 'user_1'
|
||||
user_id_2 = 'user_2'
|
||||
|
||||
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
# Act
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id_1
|
||||
)
|
||||
await check_rate_limit_by_user_id(
|
||||
request=mock_request, key_prefix=key_prefix, user_id=user_id_2
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert mock_redis.set.call_count == 2
|
||||
# Extract call arguments properly
|
||||
call_args_list = [
|
||||
(call[0][0], call[0][1], call[1]['nx'], call[1]['ex'])
|
||||
for call in mock_redis.set.call_args_list
|
||||
]
|
||||
assert (
|
||||
f'{key_prefix}:{user_id_1}',
|
||||
1,
|
||||
True,
|
||||
RATE_LIMIT_USER_SECONDS,
|
||||
) in call_args_list
|
||||
assert (
|
||||
f'{key_prefix}:{user_id_2}',
|
||||
1,
|
||||
True,
|
||||
RATE_LIMIT_USER_SECONDS,
|
||||
) in call_args_list
|
||||
@@ -234,3 +234,53 @@ async def test_middleware_with_other_auth_error(middleware, mock_request):
|
||||
assert 'set-cookie' in result.headers
|
||||
# Logger should be called for non-NoCredentialsError
|
||||
mock_logger.warning.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_ignores_email_resend_path(
|
||||
middleware, mock_request, mock_response
|
||||
):
|
||||
"""Test middleware ignores /api/email/resend path and doesn't require authentication."""
|
||||
# Arrange
|
||||
mock_request.cookies = {}
|
||||
mock_request.url = MagicMock()
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.path = '/api/email/resend'
|
||||
mock_call_next = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Act
|
||||
result = await middleware(mock_request, mock_call_next)
|
||||
|
||||
# Assert
|
||||
assert result == mock_response
|
||||
mock_call_next.assert_called_once_with(mock_request)
|
||||
# Should not raise NoCredentialsError even without auth cookie
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_ignores_email_resend_path_no_tos_check(
|
||||
middleware, mock_request, mock_response
|
||||
):
|
||||
"""Test middleware doesn't check TOS for /api/email/resend path."""
|
||||
# Arrange
|
||||
mock_request.cookies = {'keycloak_auth': 'test_cookie'}
|
||||
mock_request.url = MagicMock()
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.path = '/api/email/resend'
|
||||
mock_call_next = AsyncMock(return_value=mock_response)
|
||||
|
||||
with (
|
||||
patch('server.middleware.jwt.decode') as mock_decode,
|
||||
patch('server.middleware.config') as mock_config,
|
||||
):
|
||||
# Even with accepted_tos=False, should not raise TosNotAcceptedError
|
||||
mock_decode.return_value = {'accepted_tos': False}
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
|
||||
# Act
|
||||
result = await middleware(mock_request, mock_call_next)
|
||||
|
||||
# Assert
|
||||
assert result == mock_response
|
||||
mock_call_next.assert_called_once_with(mock_request)
|
||||
# Should not raise TosNotAcceptedError for this path
|
||||
|
||||
@@ -136,6 +136,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -184,6 +185,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -214,6 +216,84 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
mock_posthog.set.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
"""Test keycloak_callback when email is not verified."""
|
||||
# Arrange
|
||||
mock_verify_email = AsyncMock()
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': False,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'email_verification_required=true' in result.headers['location']
|
||||
assert 'user_id=test_user_id' in result.headers['location']
|
||||
mock_verify_email.assert_called_once_with(
|
||||
request=mock_request, user_id='test_user_id', is_auth_flow=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
"""Test keycloak_callback when email_verified field is missing (defaults to False)."""
|
||||
# Arrange
|
||||
mock_verify_email = AsyncMock()
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
# email_verified field is missing
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'email_verification_required=true' in result.headers['location']
|
||||
assert 'user_id=test_user_id' in result.headers['location']
|
||||
mock_verify_email.assert_called_once_with(
|
||||
request=mock_request, user_id='test_user_id', is_auth_flow=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
"""Test successful keycloak_callback without valid offline token."""
|
||||
@@ -248,6 +328,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -442,3 +523,418 @@ async def test_logout_without_refresh_token():
|
||||
|
||||
mock_token_manager.logout.assert_not_called()
|
||||
assert 'set-cookie' in result.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
"""Test keycloak_callback when email domain is blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert 'error' in result.body.decode()
|
||||
assert 'email domain is not allowed' in result.body.decode()
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
mock_token_manager.disable_keycloak_user.assert_called_once_with(
|
||||
'test_user_id', 'user@colsch.us'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
"""Test keycloak_callback when email domain is not blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
"""Test keycloak_callback when domain blocking is not active."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_missing_email(mock_request):
|
||||
"""Test keycloak_callback when user info does not contain email."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
# No email field
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
"""Test keycloak_callback when duplicate email is detected."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
"""Test keycloak_callback when duplicate is detected but deletion fails."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
"""Test keycloak_callback when duplicate check raises exception."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(
|
||||
side_effect=Exception('Check failed')
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should proceed with normal flow despite exception (fail open)
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
"""Test keycloak_callback when no duplicate email is found."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
# Should not delete user when no duplicate found
|
||||
mock_token_manager.delete_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
"""Test keycloak_callback when email is not in user_info."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
# No email field
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
# Should not check for duplicate when email is missing
|
||||
mock_token_manager.check_duplicate_base_email.assert_not_called()
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import HTTPStatusError, Response
|
||||
from httpx import Response
|
||||
from integrations.stripe_service import has_payment_method
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
@@ -78,8 +78,6 @@ def mock_subscription_request():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
|
||||
|
||||
mock_response = Response(
|
||||
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
|
||||
)
|
||||
@@ -88,11 +86,12 @@ async def test_get_credits_lite_llm_error():
|
||||
|
||||
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
with pytest.raises(HTTPStatusError) as exc_info:
|
||||
await get_credits(mock_request)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_credits('mock_user')
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert (
|
||||
exc_info.value.response.status_code
|
||||
== status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
exc_info.value.detail
|
||||
== 'Failed to retrieve credit balance from billing service'
|
||||
)
|
||||
|
||||
|
||||
|
||||
181
enterprise/tests/unit/test_domain_blocker.py
Normal file
181
enterprise/tests/unit/test_domain_blocker.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Unit tests for DomainBlocker class."""
|
||||
|
||||
import pytest
|
||||
from server.auth.domain_blocker import DomainBlocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def domain_blocker():
|
||||
"""Create a DomainBlocker instance for testing."""
|
||||
return DomainBlocker()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'blocked_domains,expected',
|
||||
[
|
||||
(['colsch.us', 'other-domain.com'], True),
|
||||
(['example.com'], True),
|
||||
([], False),
|
||||
],
|
||||
)
|
||||
def test_is_active(domain_blocker, blocked_domains, expected):
|
||||
"""Test that is_active returns correct value based on blocked domains configuration."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = blocked_domains
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_active()
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected_domain',
|
||||
[
|
||||
('user@example.com', 'example.com'),
|
||||
('test@colsch.us', 'colsch.us'),
|
||||
('user.name@other-domain.com', 'other-domain.com'),
|
||||
('USER@EXAMPLE.COM', 'example.com'), # Case insensitive
|
||||
('user@EXAMPLE.COM', 'example.com'),
|
||||
(' user@example.com ', 'example.com'), # Whitespace handling
|
||||
],
|
||||
)
|
||||
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
|
||||
"""Test that _extract_domain correctly extracts and normalizes domains from valid emails."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected_domain
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected',
|
||||
[
|
||||
(None, None),
|
||||
('', None),
|
||||
('invalid-email', None),
|
||||
('user@', None), # Empty domain after @
|
||||
('no-at-sign', None),
|
||||
],
|
||||
)
|
||||
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
|
||||
"""Test that _extract_domain returns None for invalid email formats."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_is_domain_blocked_when_inactive(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when blocking is not active."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = []
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_none_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email is None."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(None)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_empty_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email is empty."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_invalid_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email format is invalid."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('invalid-email')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns True when domain is in blocked list."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_case_insensitive(domain_blocker):
|
||||
"""Test that is_domain_blocked performs case-insensitive domain matching."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
|
||||
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
|
||||
|
||||
# Act
|
||||
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
|
||||
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
|
||||
|
||||
# Assert
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_whitespace(domain_blocker):
|
||||
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
294
enterprise/tests/unit/test_email_validation.py
Normal file
294
enterprise/tests/unit/test_email_validation.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Tests for email validation utilities."""
|
||||
|
||||
import re
|
||||
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
has_plus_modifier,
|
||||
matches_base_email,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractBaseEmail:
|
||||
"""Test cases for extract_base_email function."""
|
||||
|
||||
def test_extract_base_email_with_plus_modifier(self):
|
||||
"""Test extracting base email from email with + modifier."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_without_plus_modifier(self):
|
||||
"""Test that email without + modifier is returned as-is."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_multiple_plus_signs(self):
|
||||
"""Test extracting base email when multiple + signs exist."""
|
||||
# Arrange
|
||||
email = 'joe+openhands+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns None."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_empty_string(self):
|
||||
"""Test that empty string returns None."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_none(self):
|
||||
"""Test that None input returns None."""
|
||||
# Arrange
|
||||
email = None
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestHasPlusModifier:
|
||||
"""Test cases for has_plus_modifier function."""
|
||||
|
||||
def test_has_plus_modifier_true(self):
|
||||
"""Test detecting + modifier in email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_has_plus_modifier_false(self):
|
||||
"""Test that email without + modifier returns False."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_empty_string(self):
|
||||
"""Test that empty string returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMatchesBaseEmail:
|
||||
"""Test cases for matches_base_email function."""
|
||||
|
||||
def test_matches_base_email_exact_match(self):
|
||||
"""Test that exact base email matches."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_with_plus_variant(self):
|
||||
"""Test that email with + variant matches base email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_different_base(self):
|
||||
"""Test that different base emails do not match."""
|
||||
# Arrange
|
||||
email = 'jane@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_different_domain(self):
|
||||
"""Test that same local part but different domain does not match."""
|
||||
# Arrange
|
||||
email = 'joe@other.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_case_insensitive(self):
|
||||
"""Test that matching is case-insensitive."""
|
||||
# Arrange
|
||||
email = 'JOE+TEST@EXAMPLE.COM'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_empty_strings(self):
|
||||
"""Test that empty strings return False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetBaseEmailRegexPattern:
|
||||
"""Test cases for get_base_email_regex_pattern function."""
|
||||
|
||||
def test_get_base_email_regex_pattern_valid(self):
|
||||
"""Test generating valid regex pattern for base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is not None
|
||||
assert isinstance(pattern, re.Pattern)
|
||||
assert pattern.match('joe@example.com') is not None
|
||||
assert pattern.match('joe+test@example.com') is not None
|
||||
assert pattern.match('joe+openhands@example.com') is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_matches_plus_variant(self):
|
||||
"""Test that regex pattern matches + variant."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe+test@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_base(self):
|
||||
"""Test that regex pattern rejects different base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('jane@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_domain(self):
|
||||
"""Test that regex pattern rejects different domain."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe@other.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_case_insensitive(self):
|
||||
"""Test that regex pattern is case-insensitive."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('JOE+TEST@EXAMPLE.COM')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_special_characters(self):
|
||||
"""Test that regex pattern handles special characters in email."""
|
||||
# Arrange
|
||||
base_email = 'user.name+tag@example-site.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('user.name+test@example-site.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_invalid_base_email(self):
|
||||
"""Test that invalid base email returns None."""
|
||||
# Arrange
|
||||
base_email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is None
|
||||
@@ -6,6 +6,7 @@ from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
@@ -393,10 +394,11 @@ async def test_create_user_in_lite_llm(settings_store):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_client.post.return_value = mock_response
|
||||
test_model = 'custom-model/test-model'
|
||||
|
||||
# Test with email
|
||||
await settings_store._create_user_in_lite_llm(
|
||||
mock_client, 'test@example.com', 50, 10
|
||||
mock_client, 'test@example.com', 50, 10, test_model
|
||||
)
|
||||
|
||||
# Get the actual call arguments
|
||||
@@ -412,11 +414,11 @@ async def test_create_user_in_lite_llm(settings_store):
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
assert call_args['json']['metadata']['model'] == test_model
|
||||
|
||||
# Test with None email
|
||||
mock_client.post.reset_mock()
|
||||
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15)
|
||||
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15, test_model)
|
||||
|
||||
# Get the actual call arguments
|
||||
call_args = mock_client.post.call_args[1]
|
||||
@@ -431,12 +433,12 @@ async def test_create_user_in_lite_llm(settings_store):
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
assert call_args['json']['metadata']['model'] == test_model
|
||||
|
||||
# Verify response is returned correctly
|
||||
assert (
|
||||
await settings_store._create_user_in_lite_llm(
|
||||
mock_client, 'email@test.com', 30, 7
|
||||
mock_client, 'email@test.com', 30, 7, test_model
|
||||
)
|
||||
== mock_response
|
||||
)
|
||||
@@ -464,3 +466,808 @@ async def test_encryption(settings_store):
|
||||
# But we should be able to decrypt it when loading
|
||||
loaded_settings = await settings_store.load()
|
||||
assert loaded_settings.llm_api_key.get_secret_value() == 'secret_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_preserves_custom_model(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has a custom LLM model set
|
||||
custom_model = 'anthropic/claude-3-5-sonnet-20241022'
|
||||
settings = Settings(llm_model=custom_model)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Custom model is preserved
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_model == custom_model
|
||||
assert updated_settings.agent == 'CodeActAgent'
|
||||
assert updated_settings.llm_api_key is not None
|
||||
|
||||
# Assert: LiteLLM metadata contains user's custom model
|
||||
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
|
||||
assert call_args['json']['metadata']['model'] == custom_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_uses_default_when_no_model(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has no model set (new user scenario)
|
||||
settings = Settings()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'newuser@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default model is assigned
|
||||
assert updated_settings is not None
|
||||
expected_default = get_default_litellm_model()
|
||||
assert updated_settings.llm_model == expected_default
|
||||
assert updated_settings.agent == 'CodeActAgent'
|
||||
|
||||
# Assert: LiteLLM metadata contains default model
|
||||
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
|
||||
assert call_args['json']['metadata']['model'] == expected_default
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_empty_string_model(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has empty string as model (edge case)
|
||||
settings = Settings(llm_model='')
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default model is used (empty string treated as no model)
|
||||
assert updated_settings is not None
|
||||
expected_default = get_default_litellm_model()
|
||||
assert updated_settings.llm_model == expected_default
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_whitespace_model(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has whitespace-only model (edge case)
|
||||
settings = Settings(llm_model=' ')
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default model is used (whitespace treated as no model)
|
||||
assert updated_settings is not None
|
||||
expected_default = get_default_litellm_model()
|
||||
assert updated_settings.llm_model == expected_default
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_preserves_custom_api_key(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has a custom API key and custom model (so has_custom=True)
|
||||
custom_api_key = 'sk-custom-user-api-key-12345'
|
||||
custom_model = 'gpt-4'
|
||||
settings = Settings(llm_model=custom_model, llm_api_key=SecretStr(custom_api_key))
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Custom API key is preserved when user has custom settings
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_api_key.get_secret_value() == custom_api_key
|
||||
assert updated_settings.llm_api_key.get_secret_value() != 'test_api_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_preserves_custom_base_url(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has a custom base URL
|
||||
custom_base_url = 'https://api.custom-llm-provider.com/v1'
|
||||
settings = Settings(llm_base_url=custom_base_url)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Custom base URL is preserved
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_base_url == custom_base_url
|
||||
assert updated_settings.llm_base_url != LITE_LLM_API_URL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_preserves_custom_api_key_and_base_url(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has both custom API key and base URL
|
||||
custom_api_key = 'sk-custom-user-api-key-67890'
|
||||
custom_base_url = 'https://api.another-llm-provider.com/v1'
|
||||
custom_model = 'openai/gpt-4'
|
||||
settings = Settings(
|
||||
llm_model=custom_model,
|
||||
llm_api_key=SecretStr(custom_api_key),
|
||||
llm_base_url=custom_base_url,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: All custom settings are preserved
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_model == custom_model
|
||||
assert updated_settings.llm_api_key.get_secret_value() == custom_api_key
|
||||
assert updated_settings.llm_base_url == custom_base_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_uses_default_api_key_when_none(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has no API key set
|
||||
settings = Settings(llm_api_key=None)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default LiteLLM API key is assigned
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_api_key is not None
|
||||
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_uses_default_base_url_when_none(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has no base URL set
|
||||
settings = Settings(llm_base_url=None)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default LiteLLM base URL is assigned (using mocked value)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_empty_api_key(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has empty string as API key (edge case)
|
||||
settings = Settings(llm_api_key=SecretStr(''))
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default API key is used (empty string treated as no key)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_empty_base_url(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has empty string as base URL (edge case)
|
||||
settings = Settings(llm_base_url='')
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default base URL is used (empty string treated as no URL)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_whitespace_api_key(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has whitespace-only API key (edge case)
|
||||
settings = Settings(llm_api_key=SecretStr(' '))
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default API key is used (whitespace treated as no key)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_handles_whitespace_base_url(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User has whitespace-only base URL (edge case)
|
||||
settings = Settings(llm_base_url=' ')
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
|
||||
):
|
||||
# Act: Update settings with LiteLLM defaults
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Default base URL is used (whitespace treated as no URL)
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
# Tests for version migration and helper methods
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_custom_base_url(settings_store):
|
||||
# Arrange: User with custom base URL (BYOR)
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_base_url='http://custom.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: Custom base URL detected
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_default_base_url(settings_store):
|
||||
# Arrange: User with default base URL
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No custom settings (no model set)
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_no_model(settings_store):
|
||||
# Arrange: User with no model set
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_model=None, llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No custom settings (using defaults)
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_empty_model(settings_store):
|
||||
# Arrange: User with empty model
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_model='', llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No custom settings (empty treated as no model)
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_whitespace_model(settings_store):
|
||||
# Arrange: User with whitespace-only model
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_model=' ', llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No custom settings (whitespace treated as no model)
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_custom_model(settings_store):
|
||||
# Arrange: User with custom model
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(llm_model='gpt-4', llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: Custom model detected
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_matches_old_default_model(settings_store):
|
||||
# Arrange: User with old version and model matching old default
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
):
|
||||
settings = Settings(
|
||||
llm_model='litellm_proxy/prod/claude-3-5-sonnet-20241022',
|
||||
llm_base_url='http://default.url',
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 1)
|
||||
|
||||
# Assert: Matches old default, so not custom
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_matches_old_default_by_base_name(settings_store):
|
||||
# Arrange: User with old version and model matching old default by base name
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
):
|
||||
settings = Settings(
|
||||
llm_model='anthropic/claude-3-5-sonnet-20241022',
|
||||
llm_base_url='http://default.url',
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 1)
|
||||
|
||||
# Assert: Matches old default by base name, so not custom
|
||||
assert has_custom is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_old_version_but_custom_model(settings_store):
|
||||
# Arrange: User with old version but custom model
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
):
|
||||
settings = Settings(llm_model='gpt-4', llm_base_url='http://default.url')
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 1)
|
||||
|
||||
# Assert: Custom model detected
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_current_version(settings_store):
|
||||
# Arrange: User with current version
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
|
||||
),
|
||||
):
|
||||
settings = Settings(
|
||||
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 5)
|
||||
|
||||
# Assert: Current version, so model is custom (not old default)
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_none_version(settings_store):
|
||||
# Arrange: User with no version
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(
|
||||
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: No version, so model is custom
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_with_invalid_version(settings_store):
|
||||
# Arrange: User with invalid version
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
):
|
||||
settings = Settings(
|
||||
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, 99)
|
||||
|
||||
# Assert: Invalid version, so model is custom
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_custom_settings_normalizes_whitespace(settings_store):
|
||||
# Arrange: Settings with whitespace in values
|
||||
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
|
||||
settings = Settings(
|
||||
llm_model=' claude-3-5-sonnet-20241022 ',
|
||||
llm_base_url=' http://default.url ',
|
||||
)
|
||||
|
||||
# Act: Check if has custom settings
|
||||
has_custom = settings_store._has_custom_settings(settings, None)
|
||||
|
||||
# Assert: Whitespace is normalized, custom model detected
|
||||
assert has_custom is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_upgrades_user_from_old_defaults(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User with old version using old defaults
|
||||
old_version = 1
|
||||
old_model = 'litellm_proxy/prod/claude-3-5-sonnet-20241022'
|
||||
settings = Settings(llm_model=old_model, llm_base_url=LITE_LLM_API_URL)
|
||||
|
||||
# Use a consistent test URL
|
||||
test_base_url = 'http://test.url'
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
|
||||
),
|
||||
patch(
|
||||
'storage.saas_settings_store.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
|
||||
),
|
||||
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch('storage.saas_settings_store.CURRENT_USER_SETTINGS_VERSION', 5),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', test_base_url),
|
||||
patch(
|
||||
'storage.saas_settings_store.get_default_litellm_model',
|
||||
return_value='litellm_proxy/prod/claude-opus-4-5-20251101',
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
):
|
||||
# Create existing user settings with old version
|
||||
with session_maker() as session:
|
||||
existing_settings = UserSettings(
|
||||
keycloak_user_id=settings_store.user_id,
|
||||
user_version=old_version,
|
||||
llm_model=old_model,
|
||||
llm_base_url=test_base_url,
|
||||
)
|
||||
session.add(existing_settings)
|
||||
session.commit()
|
||||
|
||||
# Update settings to use test_base_url
|
||||
# Set user_version to match the database so _has_custom_settings can detect old defaults
|
||||
settings = Settings(
|
||||
llm_model=old_model, llm_base_url=test_base_url, user_version=old_version
|
||||
)
|
||||
|
||||
# Act: Update settings
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Settings upgraded to new defaults
|
||||
assert updated_settings is not None
|
||||
assert (
|
||||
updated_settings.llm_model == 'litellm_proxy/prod/claude-opus-4-5-20251101'
|
||||
)
|
||||
assert updated_settings.llm_base_url == test_base_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_preserves_custom_settings_during_upgrade(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User with old version but custom settings
|
||||
old_version = 1
|
||||
custom_model = 'gpt-4'
|
||||
custom_base_url = 'http://custom.url'
|
||||
settings = Settings(llm_model=custom_model, llm_base_url=custom_base_url)
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
|
||||
{1: 'claude-3-5-sonnet-20241022'},
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
):
|
||||
# Create existing user settings with old version
|
||||
with session_maker() as session:
|
||||
existing_settings = UserSettings(
|
||||
keycloak_user_id=settings_store.user_id,
|
||||
user_version=old_version,
|
||||
llm_model=custom_model,
|
||||
llm_base_url=custom_base_url,
|
||||
)
|
||||
session.add(existing_settings)
|
||||
session.commit()
|
||||
|
||||
# Act: Update settings
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Custom settings preserved
|
||||
assert updated_settings is not None
|
||||
assert updated_settings.llm_model == custom_model
|
||||
assert updated_settings.llm_base_url == custom_base_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_migrates_billing_margin_v3_to_v4(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User with version 3 and billing margin
|
||||
old_version = 3
|
||||
billing_margin = 2.0
|
||||
max_budget = 10.0
|
||||
spend = 5.0
|
||||
|
||||
settings = Settings()
|
||||
|
||||
mock_get_response = AsyncMock()
|
||||
mock_get_response.is_success = True
|
||||
mock_get_response.json = MagicMock(
|
||||
return_value={'user_info': {'max_budget': max_budget, 'spend': spend}}
|
||||
)
|
||||
|
||||
mock_post_response = AsyncMock()
|
||||
mock_post_response.is_success = True
|
||||
mock_post_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
):
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_get_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_post_response
|
||||
)
|
||||
|
||||
# Create existing user settings with version 3 and billing margin
|
||||
with session_maker() as session:
|
||||
existing_settings = UserSettings(
|
||||
keycloak_user_id=settings_store.user_id,
|
||||
user_version=old_version,
|
||||
billing_margin=billing_margin,
|
||||
)
|
||||
session.add(existing_settings)
|
||||
session.commit()
|
||||
|
||||
# Act: Update settings
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Settings updated
|
||||
assert updated_settings is not None
|
||||
|
||||
# Assert: Billing margin applied to budget
|
||||
call_args = mock_client.return_value.__aenter__.return_value.post.call_args[1]
|
||||
assert call_args['json']['max_budget'] == max_budget * billing_margin
|
||||
assert call_args['json']['spend'] == spend * billing_margin
|
||||
|
||||
# Assert: Billing margin reset to 1.0
|
||||
with session_maker() as session:
|
||||
updated_user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == settings_store.user_id)
|
||||
.first()
|
||||
)
|
||||
assert updated_user_settings.billing_margin == 1.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_skips_billing_margin_migration_when_already_v4(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# Arrange: User with version 4
|
||||
version = 4
|
||||
billing_margin = 2.0
|
||||
max_budget = 10.0
|
||||
spend = 5.0
|
||||
|
||||
settings = Settings()
|
||||
|
||||
mock_get_response = AsyncMock()
|
||||
mock_get_response.is_success = True
|
||||
mock_get_response.json = MagicMock(
|
||||
return_value={'user_info': {'max_budget': max_budget, 'spend': spend}}
|
||||
)
|
||||
|
||||
mock_post_response = AsyncMock()
|
||||
mock_post_response.is_success = True
|
||||
mock_post_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'user@example.com'}),
|
||||
),
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
):
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_get_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_post_response
|
||||
)
|
||||
|
||||
# Create existing user settings with version 4
|
||||
with session_maker() as session:
|
||||
existing_settings = UserSettings(
|
||||
keycloak_user_id=settings_store.user_id,
|
||||
user_version=version,
|
||||
billing_margin=billing_margin,
|
||||
)
|
||||
session.add(existing_settings)
|
||||
session.commit()
|
||||
|
||||
# Act: Update settings
|
||||
updated_settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
|
||||
# Assert: Settings updated
|
||||
assert updated_settings is not None
|
||||
|
||||
# Assert: Billing margin NOT applied (version >= 4)
|
||||
call_args = mock_client.return_value.__aenter__.return_value.post.call_args[1]
|
||||
assert call_args['json']['max_budget'] == max_budget
|
||||
assert call_args['json']['spend'] == spend
|
||||
|
||||
@@ -5,7 +5,12 @@ import jwt
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import BearerTokenError, CookieError, NoCredentialsError
|
||||
from server.auth.auth_error import (
|
||||
AuthError,
|
||||
BearerTokenError,
|
||||
CookieError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.saas_user_auth import (
|
||||
SaasUserAuth,
|
||||
get_api_key_from_header,
|
||||
@@ -647,3 +652,97 @@ def test_get_api_key_from_header_bearer_with_empty_token():
|
||||
# Assert that empty string from Bearer is returned (current behavior)
|
||||
# This tests the current implementation behavior
|
||||
assert api_key == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@colsch.us',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(AuthError) as exc_info:
|
||||
await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
assert 'email domain is not allowed' in str(exc_info.value)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.email == 'user@example.com'
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@colsch.us',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
|
||||
1
enterprise/tests/unit/test_sharing/__init__.py
Normal file
1
enterprise/tests/unit/test_sharing/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for sharing package."""
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Tests for public conversation models."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
|
||||
|
||||
def test_public_conversation_creation():
|
||||
"""Test that SharedConversation can be created with all required fields."""
|
||||
conversation_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository=None,
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
assert conversation.id == conversation_id
|
||||
assert conversation.title == 'Test Conversation'
|
||||
assert conversation.created_by_user_id == 'test_user'
|
||||
assert conversation.sandbox_id == 'test_sandbox'
|
||||
|
||||
|
||||
def test_public_conversation_page_creation():
|
||||
"""Test that SharedConversationPage can be created."""
|
||||
conversation_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository=None,
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
page = SharedConversationPage(
|
||||
items=[conversation],
|
||||
next_page_id='next_page',
|
||||
)
|
||||
|
||||
assert len(page.items) == 1
|
||||
assert page.items[0].id == conversation_id
|
||||
assert page.next_page_id == 'next_page'
|
||||
|
||||
|
||||
def test_public_conversation_sort_order_enum():
|
||||
"""Test that SharedConversationSortOrder enum has expected values."""
|
||||
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
|
||||
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
|
||||
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
|
||||
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
|
||||
assert hasattr(SharedConversationSortOrder, 'TITLE')
|
||||
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
|
||||
|
||||
|
||||
def test_public_conversation_optional_fields():
|
||||
"""Test that SharedConversation works with optional fields."""
|
||||
conversation_id = uuid4()
|
||||
parent_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository='owner/repo',
|
||||
parent_conversation_id=parent_id,
|
||||
llm_model='gpt-4',
|
||||
)
|
||||
|
||||
assert conversation.selected_repository == 'owner/repo'
|
||||
assert conversation.parent_conversation_id == parent_id
|
||||
assert conversation.llm_model == 'gpt-4'
|
||||
@@ -0,0 +1,430 @@
|
||||
"""Tests for SharedConversationInfoService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def shared_conversation_info_service(async_session):
|
||||
"""Create a SharedConversationInfoService for testing."""
|
||||
return SQLSharedConversationInfoService(db_session=async_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def app_conversation_service(async_session):
|
||||
"""Create an AppConversationInfoService for creating test data."""
|
||||
return SQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info():
|
||||
"""Create a sample conversation info for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
selected_repository='test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123],
|
||||
llm_model='gpt-4',
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=1.5,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=4096,
|
||||
per_turn_token=150,
|
||||
),
|
||||
),
|
||||
parent_conversation_id=None,
|
||||
sub_conversation_ids=[],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
public=True, # Make it public for testing
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_private_conversation_info():
|
||||
"""Create a sample private conversation info for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_private',
|
||||
selected_repository='test/private_repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Private Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[124],
|
||||
llm_model='gpt-4',
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=2.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=4096,
|
||||
per_turn_token=300,
|
||||
),
|
||||
),
|
||||
parent_conversation_id=None,
|
||||
sub_conversation_ids=[],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
public=False, # Make it private
|
||||
)
|
||||
|
||||
|
||||
class TestSharedConversationInfoService:
|
||||
"""Test cases for SharedConversationInfoService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_public_conversation(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns a public conversation."""
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
|
||||
# Retrieve it via public service
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
sample_conversation_info.id
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == sample_conversation_info.id
|
||||
assert result.title == sample_conversation_info.title
|
||||
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns None for private conversations."""
|
||||
# Create a private conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Try to retrieve it via public service
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
sample_private_conversation_info.id
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
|
||||
self, shared_conversation_info_service
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
|
||||
nonexistent_id = uuid4()
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
nonexistent_id
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_returns_only_public_conversations(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test that search only returns public conversations."""
|
||||
# Create both public and private conversations
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Search for all conversations
|
||||
result = (
|
||||
await shared_conversation_info_service.search_shared_conversation_info()
|
||||
)
|
||||
|
||||
# Should only return the public conversation
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].id == sample_conversation_info.id
|
||||
assert result.items[0].title == sample_conversation_info.title
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_with_title_filter(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
):
|
||||
"""Test searching with title filter."""
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
|
||||
# Search with matching title
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
title__contains='Test'
|
||||
)
|
||||
assert len(result.items) == 1
|
||||
|
||||
# Search with non-matching title
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
title__contains='NonExistent'
|
||||
)
|
||||
assert len(result.items) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_with_sort_order(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
):
|
||||
"""Test searching with different sort orders."""
|
||||
# Create multiple public conversations with different titles and timestamps
|
||||
conv1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_1',
|
||||
title='A First Conversation',
|
||||
created_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conv2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_2',
|
||||
title='B Second Conversation',
|
||||
created_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
await app_conversation_service.save_app_conversation_info(conv1)
|
||||
await app_conversation_service.save_app_conversation_info(conv2)
|
||||
|
||||
# Test sort by title ascending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.TITLE
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].title == 'A First Conversation'
|
||||
assert result.items[1].title == 'B Second Conversation'
|
||||
|
||||
# Test sort by title descending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.TITLE_DESC
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].title == 'B Second Conversation'
|
||||
assert result.items[1].title == 'A First Conversation'
|
||||
|
||||
# Test sort by created_at ascending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == conv1.id
|
||||
assert result.items[1].id == conv2.id
|
||||
|
||||
# Test sort by created_at descending (default)
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == conv2.id
|
||||
assert result.items[1].id == conv1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_shared_conversation_info(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test counting public conversations."""
|
||||
# Initially should be 0
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 0
|
||||
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 1
|
||||
|
||||
# Create a private conversation - count should remain 1
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_get_shared_conversation_info(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test batch getting public conversations."""
|
||||
# Create both public and private conversations
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Batch get both conversations
|
||||
result = (
|
||||
await shared_conversation_info_service.batch_get_shared_conversation_info(
|
||||
[sample_conversation_info.id, sample_private_conversation_info.id]
|
||||
)
|
||||
)
|
||||
|
||||
# Should return the public one and None for the private one
|
||||
assert len(result) == 2
|
||||
assert result[0] is not None
|
||||
assert result[0].id == sample_conversation_info.id
|
||||
assert result[1] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_pagination(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
):
|
||||
"""Test search with pagination."""
|
||||
# Create multiple public conversations
|
||||
conversations = []
|
||||
for i in range(5):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id=f'test_sandbox_{i}',
|
||||
title=f'Conversation {i}',
|
||||
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conversations.append(conv)
|
||||
await app_conversation_service.save_app_conversation_info(conv)
|
||||
|
||||
# Get first page with limit 2
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.next_page_id is not None
|
||||
|
||||
# Get next page
|
||||
result2 = (
|
||||
await shared_conversation_info_service.search_shared_conversation_info(
|
||||
limit=2,
|
||||
page_id=result.next_page_id,
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT,
|
||||
)
|
||||
)
|
||||
assert len(result2.items) == 2
|
||||
assert result2.next_page_id is not None
|
||||
|
||||
# Verify no overlap between pages
|
||||
page1_ids = {item.id for item in result.items}
|
||||
page2_ids = {item.id for item in result2.items}
|
||||
assert page1_ids.isdisjoint(page2_ids)
|
||||
@@ -0,0 +1,365 @@
|
||||
"""Tests for SharedEventService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.filesystem_shared_event_service import (
|
||||
SharedEventServiceImpl,
|
||||
)
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import SharedConversation
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_shared_conversation_info_service():
|
||||
"""Create a mock SharedConversationInfoService."""
|
||||
return AsyncMock(spec=SharedConversationInfoService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_service():
|
||||
"""Create a mock EventService."""
|
||||
return AsyncMock(spec=EventService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
|
||||
"""Create a SharedEventService for testing."""
|
||||
return SharedEventServiceImpl(
|
||||
shared_conversation_info_service=mock_shared_conversation_info_service,
|
||||
event_service=mock_event_service,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_public_conversation():
|
||||
"""Create a sample public conversation."""
|
||||
return SharedConversation(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Public Conversation',
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_event():
|
||||
"""Create a sample event."""
|
||||
# For testing purposes, we'll just use a mock that the EventPage can accept
|
||||
# The actual event creation is complex and not the focus of these tests
|
||||
return None
|
||||
|
||||
|
||||
class TestSharedEventService:
|
||||
"""Test cases for SharedEventService."""
|
||||
|
||||
async def test_get_shared_event_returns_event_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that get_shared_event returns an event for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
event_id = 'test_event_id'
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return an event
|
||||
mock_event_service.get_event.return_value = sample_event
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
|
||||
# Verify the result
|
||||
assert result == sample_event
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.get_event.assert_called_once_with(event_id)
|
||||
|
||||
async def test_get_shared_event_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that get_shared_event returns None for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
event_id = 'test_event_id'
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
|
||||
# Verify the result
|
||||
assert result is None
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.get_event.assert_not_called()
|
||||
|
||||
async def test_search_shared_events_returns_events_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that search_shared_events returns events for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_page = EventPage(items=[], next_page_id=None)
|
||||
mock_event_service.search_events.return_value = mock_event_page
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == mock_event_page
|
||||
assert len(result.items) == 0 # Empty list as we mocked
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.search_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
timestamp__gte=None,
|
||||
timestamp__lt=None,
|
||||
sort_order=EventSortOrder.TIMESTAMP,
|
||||
page_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
async def test_search_shared_events_returns_empty_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that search_shared_events returns empty page for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, EventPage)
|
||||
assert len(result.items) == 0
|
||||
assert result.next_page_id is None
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.search_events.assert_not_called()
|
||||
|
||||
async def test_count_shared_events_returns_count_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
):
|
||||
"""Test that count_shared_events returns count for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return a count
|
||||
mock_event_service.count_events.return_value = 5
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.count_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == 5
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.count_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
timestamp__gte=None,
|
||||
timestamp__lt=None,
|
||||
sort_order=EventSortOrder.TIMESTAMP,
|
||||
)
|
||||
|
||||
async def test_count_shared_events_returns_zero_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that count_shared_events returns 0 for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.count_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == 0
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.count_events.assert_not_called()
|
||||
|
||||
async def test_batch_get_shared_events_returns_events_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that batch_get_shared_events returns events for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
event_ids = ['event1', 'event2']
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_service.get_event.side_effect = [sample_event, None]
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.batch_get_shared_events(
|
||||
conversation_id, event_ids
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 2
|
||||
assert result[0] == sample_event
|
||||
assert result[1] is None
|
||||
|
||||
# Verify that get_shared_conversation_info was called for each event
|
||||
assert (
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
|
||||
== 2
|
||||
)
|
||||
# Verify that get_event was called for each event
|
||||
assert mock_event_service.get_event.call_count == 2
|
||||
|
||||
async def test_batch_get_shared_events_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that batch_get_shared_events returns None for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
event_ids = ['event1', 'event2']
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.batch_get_shared_events(
|
||||
conversation_id, event_ids
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 2
|
||||
assert result[0] is None
|
||||
assert result[1] is None
|
||||
|
||||
# Verify that get_shared_conversation_info was called for each event
|
||||
assert (
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
|
||||
== 2
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.get_event.assert_not_called()
|
||||
|
||||
async def test_search_shared_events_with_all_parameters(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
):
|
||||
"""Test search_shared_events with all parameters."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
|
||||
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_page = EventPage(items=[], next_page_id='next_page')
|
||||
mock_event_service.search_events.return_value = mock_event_page
|
||||
|
||||
# Call the method with all parameters
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ObservationEvent',
|
||||
timestamp__gte=timestamp_gte,
|
||||
timestamp__lt=timestamp_lt,
|
||||
sort_order=EventSortOrder.TIMESTAMP_DESC,
|
||||
page_id='current_page',
|
||||
limit=50,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == mock_event_page
|
||||
|
||||
mock_event_service.search_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ObservationEvent',
|
||||
timestamp__gte=timestamp_gte,
|
||||
timestamp__lt=timestamp_lt,
|
||||
sort_order=EventSortOrder.TIMESTAMP_DESC,
|
||||
page_id='current_page',
|
||||
limit=50,
|
||||
)
|
||||
@@ -1,6 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from keycloak.exceptions import KeycloakConnectionError, KeycloakError
|
||||
from server.auth.token_manager import TokenManager
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
@@ -32,6 +34,14 @@ def token_store(mock_session_maker, mock_config):
|
||||
return OfflineTokenStore('test_user_id', mock_session_maker, mock_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_manager():
|
||||
with patch('server.config.get_config') as mock_get_config:
|
||||
mock_config = mock_get_config.return_value
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
return TokenManager(external=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_token_new_record(token_store, mock_session):
|
||||
# Setup
|
||||
@@ -109,3 +119,419 @@ async def test_get_instance(mock_config):
|
||||
assert isinstance(result, OfflineTokenStore)
|
||||
assert result.user_id == test_user_id
|
||||
assert result.config == mock_config
|
||||
|
||||
|
||||
class TestCheckDuplicateBaseEmail:
|
||||
"""Test cases for check_duplicate_base_email method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_no_plus_modifier(self, token_manager):
|
||||
"""Test that emails without + modifier are still checked for duplicates."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = False
|
||||
mock_query.return_value = {}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_query.assert_called_once()
|
||||
mock_find.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_empty_email(self, token_manager):
|
||||
"""Test that empty email returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
current_user_id = 'user123'
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_invalid_email(self, token_manager):
|
||||
"""Test that invalid email returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
current_user_id = 'user123'
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_duplicate_found(self, token_manager):
|
||||
"""Test that duplicate email is detected when found."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
existing_user = {
|
||||
'id': 'existing_user_id',
|
||||
'email': 'joe@example.com',
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = True
|
||||
mock_query.return_value = {'existing_user_id': existing_user}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_query.assert_called_once()
|
||||
mock_find.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_no_duplicate(self, token_manager):
|
||||
"""Test that no duplicate is found when none exists."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = False
|
||||
mock_query.return_value = {}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_keycloak_connection_error(
|
||||
self, token_manager
|
||||
):
|
||||
"""Test that KeycloakConnectionError triggers retry and raises RetryError."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query:
|
||||
mock_query.side_effect = KeycloakConnectionError('Connection failed')
|
||||
|
||||
# Act & Assert
|
||||
# KeycloakConnectionError is re-raised, which triggers retry decorator
|
||||
# After retries exhaust (2 attempts), it raises RetryError
|
||||
from tenacity import RetryError
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_general_exception(self, token_manager):
|
||||
"""Test that general exceptions are handled gracefully."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query:
|
||||
mock_query.side_effect = Exception('Unexpected error')
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestQueryUsersByWildcardPattern:
|
||||
"""Test cases for _query_users_by_wildcard_pattern method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_success_with_search(
|
||||
self, token_manager
|
||||
):
|
||||
"""Test successful query using search parameter."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
mock_users = [
|
||||
{'id': 'user1', 'email': 'joe@example.com'},
|
||||
{'id': 'user2', 'email': 'joe+test@example.com'},
|
||||
]
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_users = AsyncMock(return_value=mock_users)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert 'user1' in result
|
||||
assert 'user2' in result
|
||||
mock_admin.a_get_users.assert_called_once_with(
|
||||
{'search': 'joe*@example.com'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_fallback_to_q(self, token_manager):
|
||||
"""Test fallback to q parameter when search fails."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
mock_users = [{'id': 'user1', 'email': 'joe@example.com'}]
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
# First call fails, second succeeds
|
||||
mock_admin.a_get_users = AsyncMock(
|
||||
side_effect=[Exception('Search failed'), mock_users]
|
||||
)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert 'user1' in result
|
||||
assert mock_admin.a_get_users.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_empty_result(self, token_manager):
|
||||
"""Test query returns empty dict when no users found."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_users = AsyncMock(return_value=[])
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestFindDuplicateInUsers:
|
||||
"""Test cases for _find_duplicate_in_users method."""
|
||||
|
||||
def test_find_duplicate_in_users_with_regex_match(self, token_manager):
|
||||
"""Test finding duplicate using regex pattern."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
'user2': {'id': 'user2', 'email': 'joe+test@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user3'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_find_duplicate_in_users_fallback_to_simple_matching(self, token_manager):
|
||||
"""Test fallback to simple matching when regex pattern is None."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
}
|
||||
base_email = 'invalid-email' # Will cause regex pattern to be None
|
||||
current_user_id = 'user2'
|
||||
|
||||
with patch(
|
||||
'server.auth.token_manager.get_base_email_regex_pattern', return_value=None
|
||||
):
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should use fallback matching, but invalid base_email won't match
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_excludes_current_user(self, token_manager):
|
||||
"""Test that current user is excluded from duplicate check."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user1' # Same as user in users dict
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_no_match(self, token_manager):
|
||||
"""Test that no duplicate is found when emails don't match."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'jane@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user2'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_empty_dict(self, token_manager):
|
||||
"""Test that empty users dict returns False."""
|
||||
# Arrange
|
||||
users: dict[str, dict] = {}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user1'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestDeleteKeycloakUser:
|
||||
"""Test cases for delete_keycloak_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_success(self, token_manager):
|
||||
"""Test successful deletion of Keycloak user."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.return_value = None
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_to_thread.assert_called_once_with(mock_admin.delete_user, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_connection_error(self, token_manager):
|
||||
"""Test handling of KeycloakConnectionError triggers retry and raises RetryError."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = KeycloakConnectionError('Connection failed')
|
||||
|
||||
# Act & Assert
|
||||
# KeycloakConnectionError triggers retry decorator
|
||||
# After retries exhaust (2 attempts), it raises RetryError
|
||||
from tenacity import RetryError
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_keycloak_error(self, token_manager):
|
||||
"""Test handling of KeycloakError (e.g., user not found)."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = KeycloakError('User not found')
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_general_exception(self, token_manager):
|
||||
"""Test handling of general exceptions."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = Exception('Unexpected error')
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.auth.token_manager import TokenManager, create_encryption_utility
|
||||
@@ -246,3 +246,103 @@ async def test_refresh(token_manager):
|
||||
mock_keycloak.return_value.a_refresh_token.assert_called_once_with(
|
||||
'test_refresh_token'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_success(token_manager):
|
||||
"""Test successful disabling of a Keycloak user account."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
email = 'user@colsch.us'
|
||||
mock_user = {
|
||||
'id': user_id,
|
||||
'username': 'testuser',
|
||||
'email': email,
|
||||
'emailVerified': True,
|
||||
}
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
|
||||
mock_admin.a_update_user = AsyncMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
payload={
|
||||
'enabled': False,
|
||||
'username': 'testuser',
|
||||
'email': email,
|
||||
'emailVerified': True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_without_email(token_manager):
|
||||
"""Test disabling Keycloak user without providing email."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_user = {
|
||||
'id': user_id,
|
||||
'username': 'testuser',
|
||||
'email': 'user@example.com',
|
||||
'emailVerified': False,
|
||||
}
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
|
||||
mock_admin.a_update_user = AsyncMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_not_found(token_manager):
|
||||
"""Test disabling Keycloak user when user is not found."""
|
||||
# Arrange
|
||||
user_id = 'nonexistent_user_id'
|
||||
email = 'user@colsch.us'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=None)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_exception_handling(token_manager):
|
||||
"""Test that disable_keycloak_user handles exceptions gracefully without raising."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
email = 'user@colsch.us'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(side_effect=Exception('Connection error'))
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act & Assert - should not raise exception
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Verify the method was called
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
|
||||
146
frontend/__tests__/MSW.md
Normal file
146
frontend/__tests__/MSW.md
Normal file
@@ -0,0 +1,146 @@
|
||||
# Mock Service Worker (MSW) Guide
|
||||
|
||||
## Overview
|
||||
|
||||
[Mock Service Worker (MSW)](https://mswjs.io/) is an API mocking library that intercepts outgoing network requests at the network level. Unlike traditional mocking that patches `fetch` or `axios`, MSW uses a Service Worker in the browser and direct request interception in Node.js—making mocks transparent to your application code.
|
||||
|
||||
We use MSW in this project for:
|
||||
- **Testing**: Write reliable unit and integration tests without real network calls
|
||||
- **Development**: Run the frontend with mocked APIs when the backend isn't available or when working on features with pending backend APIs
|
||||
|
||||
The same mock handlers work in both environments, so you write them once and reuse everywhere.
|
||||
|
||||
## Relevant Files
|
||||
|
||||
- `src/mocks/handlers.ts` - Main handler registry that combines all domain handlers
|
||||
- `src/mocks/*-handlers.ts` - Domain-specific handlers (auth, billing, conversation, etc.)
|
||||
- `src/mocks/browser.ts` - Browser setup for development mode
|
||||
- `src/mocks/node.ts` - Node.js setup for tests
|
||||
- `vitest.setup.ts` - Global test setup with MSW lifecycle hooks
|
||||
|
||||
## Development Workflow
|
||||
|
||||
### Running with Mocked APIs
|
||||
|
||||
```sh
|
||||
# Run with API mocking enabled
|
||||
npm run dev:mock
|
||||
|
||||
# Run with API mocking + SaaS mode simulation
|
||||
npm run dev:mock:saas
|
||||
```
|
||||
|
||||
These commands set `VITE_MOCK_API=true` which activates the MSW Service Worker to intercept requests.
|
||||
|
||||
> [!NOTE]
|
||||
> **OSS vs SaaS Mode**
|
||||
>
|
||||
> OpenHands runs in two modes:
|
||||
> - **OSS mode**: For local/self-hosted deployments where users provide their own LLM API keys and configure git providers manually
|
||||
> - **SaaS mode**: For the cloud offering with billing, managed API keys, and OAuth-based GitHub integration
|
||||
>
|
||||
> Use `dev:mock:saas` when working on SaaS-specific features like billing, API key management, or subscription flows.
|
||||
|
||||
|
||||
## Writing Tests
|
||||
|
||||
### Service Layer Mocking (Recommended)
|
||||
|
||||
For most tests, mock at the service layer using `vi.spyOn`. This approach is explicit, test-scoped, and makes the scenario being tested clear.
|
||||
|
||||
```typescript
|
||||
import { vi } from "vitest";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
llm_model: "openai/gpt-4o",
|
||||
llm_api_key_set: true,
|
||||
// ... other settings
|
||||
});
|
||||
```
|
||||
|
||||
Use `mockResolvedValue` for success scenarios and `mockRejectedValue` for error scenarios:
|
||||
|
||||
```typescript
|
||||
getSettingsSpy.mockRejectedValue(new Error("Failed to fetch settings"));
|
||||
```
|
||||
|
||||
### Network Layer Mocking (Advanced)
|
||||
|
||||
For tests that need actual network-level behavior (WebSockets, testing retry logic, etc.), use `server.use()` to override handlers per test.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Reuse the global server instance** - Don't create new `setupServer()` calls in individual tests. The project already has a global MSW server configured in `vitest.setup.ts` that handles lifecycle (`server.listen()`, `server.resetHandlers()`, `server.close()`). Use `server.use()` to add runtime handlers for specific test scenarios.
|
||||
|
||||
```typescript
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { server } from "#/mocks/node";
|
||||
|
||||
it("should handle server errors", async () => {
|
||||
server.use(
|
||||
http.get("/api/my-endpoint", () => {
|
||||
return new HttpResponse(null, { status: 500 });
|
||||
}),
|
||||
);
|
||||
// ... test code
|
||||
});
|
||||
```
|
||||
|
||||
For WebSocket testing, see `__tests__/helpers/msw-websocket-setup.ts` for utilities.
|
||||
|
||||
## Adding New API Mocks
|
||||
|
||||
When adding new API endpoints, create mocks in both places to maintain 1:1 similarity with the backend:
|
||||
|
||||
### 1. Add to `src/mocks/` (for development)
|
||||
|
||||
Create or update a domain-specific handler file:
|
||||
|
||||
```typescript
|
||||
// src/mocks/my-feature-handlers.ts
|
||||
import { http, HttpResponse } from "msw";
|
||||
|
||||
export const MY_FEATURE_HANDLERS = [
|
||||
http.get("/api/my-feature", () => {
|
||||
return HttpResponse.json({
|
||||
data: "mock response",
|
||||
});
|
||||
}),
|
||||
];
|
||||
```
|
||||
|
||||
Register in `handlers.ts`:
|
||||
|
||||
```typescript
|
||||
import { MY_FEATURE_HANDLERS } from "./my-feature-handlers";
|
||||
|
||||
export const handlers = [
|
||||
// ... existing handlers
|
||||
...MY_FEATURE_HANDLERS,
|
||||
];
|
||||
```
|
||||
|
||||
### 2. Mock in tests for specific scenarios
|
||||
|
||||
In your test files, spy on the service method to control responses per test case:
|
||||
|
||||
```typescript
|
||||
import { vi } from "vitest";
|
||||
import MyFeatureService from "#/api/my-feature-service.api";
|
||||
|
||||
const spy = vi.spyOn(MyFeatureService, "getData");
|
||||
spy.mockResolvedValue({ data: "test-specific response" });
|
||||
```
|
||||
|
||||
See `__tests__/routes/llm-settings.test.tsx` for a real-world example of service layer mocking.
|
||||
|
||||
> [!TIP]
|
||||
> For guidance on creating service APIs, see `src/api/README.md`.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Keep mocks close to real API contracts** - Update mocks when backend changes
|
||||
- **Use service layer mocking for most tests** - It's simpler and more explicit
|
||||
- **Reserve network layer mocking for integration tests** - WebSockets, retry logic, etc.
|
||||
- **Export mock data from handler files** - Reuse in tests (e.g., `MOCK_DEFAULT_USER_SETTINGS`)
|
||||
18
frontend/__tests__/api/v1-git-service.test.ts
Normal file
18
frontend/__tests__/api/v1-git-service.test.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
import { test, expect, vi } from "vitest";
|
||||
import axios from "axios";
|
||||
import V1GitService from "../../src/api/git-service/v1-git-service.api";
|
||||
|
||||
vi.mock("axios");
|
||||
|
||||
test("getGitChanges throws when response is not an array (dead runtime returns HTML)", async () => {
|
||||
const htmlResponse = "<!DOCTYPE html><html>...</html>";
|
||||
vi.mocked(axios.get).mockResolvedValue({ data: htmlResponse });
|
||||
|
||||
await expect(
|
||||
V1GitService.getGitChanges(
|
||||
"http://localhost:3000/api/conversations/123",
|
||||
"test-api-key",
|
||||
"/workspace",
|
||||
),
|
||||
).rejects.toThrow("Invalid response from runtime");
|
||||
});
|
||||
@@ -1,6 +1,7 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { AuthModal } from "#/components/features/waitlist/auth-modal";
|
||||
|
||||
// Mock the useAuthUrl hook
|
||||
@@ -27,11 +28,13 @@ describe("AuthModal", () => {
|
||||
|
||||
it("should render the GitHub and GitLab buttons", () => {
|
||||
render(
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github", "gitlab"]}
|
||||
/>,
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github", "gitlab"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
@@ -49,11 +52,13 @@ describe("AuthModal", () => {
|
||||
const user = userEvent.setup();
|
||||
const mockUrl = "https://github.com/login/oauth/authorize";
|
||||
render(
|
||||
<AuthModal
|
||||
githubAuthUrl={mockUrl}
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>,
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl={mockUrl}
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
@@ -65,10 +70,14 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should render Terms of Service and Privacy Policy text with correct links", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Find the terms of service section using data-testid
|
||||
const termsSection = screen.getByTestId("auth-modal-terms-of-service");
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
|
||||
// Check that all text content is present in the paragraph
|
||||
@@ -105,8 +114,44 @@ describe("AuthModal", () => {
|
||||
expect(termsSection).toContainElement(privacyLink);
|
||||
});
|
||||
|
||||
it("should display email verified message when emailVerified prop is true", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
emailVerified={true}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display email verified message when emailVerified prop is false", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
emailVerified={false}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should open Terms of Service link in new tab", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const tosLink = screen.getByRole("link", {
|
||||
name: "COMMON$TERMS_OF_SERVICE",
|
||||
@@ -115,11 +160,58 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should open Privacy Policy link in new tab", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const privacyLink = screen.getByRole("link", {
|
||||
name: "COMMON$PRIVACY_POLICY",
|
||||
});
|
||||
expect(privacyLink).toHaveAttribute("target", "_blank");
|
||||
});
|
||||
|
||||
describe("Duplicate email error message", () => {
|
||||
const renderAuthModalWithRouter = (initialEntries: string[]) => {
|
||||
const hasDuplicatedEmail = initialEntries.includes(
|
||||
"/?duplicated_email=true",
|
||||
);
|
||||
|
||||
return render(
|
||||
<MemoryRouter initialEntries={initialEntries}>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
hasDuplicatedEmail={hasDuplicatedEmail}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
};
|
||||
|
||||
it("should display error message when duplicated_email query parameter is true", () => {
|
||||
// Arrange
|
||||
const initialEntries = ["/?duplicated_email=true"];
|
||||
|
||||
// Act
|
||||
renderAuthModalWithRouter(initialEntries);
|
||||
|
||||
// Assert
|
||||
const errorMessage = screen.getByText("AUTH$DUPLICATE_EMAIL_ERROR");
|
||||
expect(errorMessage).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display error message when duplicated_email query parameter is missing", () => {
|
||||
// Arrange
|
||||
const initialEntries = ["/"];
|
||||
|
||||
// Act
|
||||
renderAuthModalWithRouter(initialEntries);
|
||||
|
||||
// Assert
|
||||
const errorMessage = screen.queryByText("AUTH$DUPLICATE_EMAIL_ERROR");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -23,6 +23,11 @@ describe("ConversationPanel", () => {
|
||||
Component: () => <ConversationPanel onClose={onCloseMock} />,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
// Add route to prevent "No routes matched location" warning
|
||||
Component: () => null,
|
||||
path: "/conversations/:conversationId",
|
||||
},
|
||||
]);
|
||||
|
||||
const renderConversationPanel = () => renderWithProviders(<RouterStub />);
|
||||
|
||||
@@ -48,9 +48,12 @@ describe("MaintenanceBanner", () => {
|
||||
expect(button).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// maintenance-banner
|
||||
|
||||
it("handles invalid date gracefully", () => {
|
||||
// Suppress expected console.warn for invalid date parsing
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, "warn")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const invalidTime = "invalid-date";
|
||||
|
||||
render(
|
||||
@@ -62,6 +65,9 @@ describe("MaintenanceBanner", () => {
|
||||
// Check if the banner is rendered
|
||||
const banner = screen.queryByTestId("maintenance-banner");
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
|
||||
// Restore console.warn
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("click on dismiss button removes banner", () => {
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
import React from "react";
|
||||
import { screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { emailService } from "#/api/email-service/email-service.api";
|
||||
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
import { renderWithProviders, createAxiosError } from "../../../../test-utils";
|
||||
|
||||
describe("EmailVerificationModal", () => {
|
||||
const mockOnClose = vi.fn();
|
||||
|
||||
const resendEmailVerificationSpy = vi.spyOn(
|
||||
emailService,
|
||||
"resendEmailVerification",
|
||||
);
|
||||
const displaySuccessToastSpy = vi.spyOn(ToastHandlers, "displaySuccessToast");
|
||||
const displayErrorToastSpy = vi.spyOn(ToastHandlers, "displayErrorToast");
|
||||
|
||||
const renderWithRouter = (ui: React.ReactElement) => {
|
||||
return renderWithProviders(<MemoryRouter>{ui}</MemoryRouter>);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render the email verification message", () => {
|
||||
// Arrange & Act
|
||||
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
|
||||
|
||||
// Assert
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the TermsAndPrivacyNotice component", () => {
|
||||
// Arrange & Act
|
||||
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render resend verification button", () => {
|
||||
// Arrange & Act
|
||||
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
|
||||
|
||||
// Assert
|
||||
expect(
|
||||
screen.getByText("SETTINGS$RESEND_VERIFICATION"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call resendEmailVerification when the button is clicked", async () => {
|
||||
// Arrange
|
||||
const userId = "test_user_id";
|
||||
resendEmailVerificationSpy.mockResolvedValue({
|
||||
message: "Email verification message sent",
|
||||
});
|
||||
renderWithRouter(
|
||||
<EmailVerificationModal onClose={mockOnClose} userId={userId} />,
|
||||
);
|
||||
|
||||
// Act
|
||||
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
|
||||
await userEvent.click(resendButton);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(resendEmailVerificationSpy).toHaveBeenCalledWith({
|
||||
userId,
|
||||
isAuthFlow: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should display success toast when resend succeeds", async () => {
|
||||
// Arrange
|
||||
resendEmailVerificationSpy.mockResolvedValue({
|
||||
message: "Email verification message sent",
|
||||
});
|
||||
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
|
||||
|
||||
// Act
|
||||
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
|
||||
await userEvent.click(resendButton);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(displaySuccessToastSpy).toHaveBeenCalledWith(
|
||||
"SETTINGS$VERIFICATION_EMAIL_SENT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should display rate limit error message when receiving 429 status", async () => {
|
||||
// Arrange
|
||||
const rateLimitError = createAxiosError(429, "Too Many Requests", {
|
||||
detail: "Too many requests. Please wait 2 minutes before trying again.",
|
||||
});
|
||||
resendEmailVerificationSpy.mockRejectedValue(rateLimitError);
|
||||
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
|
||||
|
||||
// Act
|
||||
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
|
||||
await userEvent.click(resendButton);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(displayErrorToastSpy).toHaveBeenCalledWith(
|
||||
"Too many requests. Please wait 2 minutes before trying again.",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should display generic error message when receiving non-429 error", async () => {
|
||||
// Arrange
|
||||
const genericError = createAxiosError(500, "Internal Server Error", {
|
||||
error: "Internal server error",
|
||||
});
|
||||
resendEmailVerificationSpy.mockRejectedValue(genericError);
|
||||
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
|
||||
|
||||
// Act
|
||||
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
|
||||
await userEvent.click(resendButton);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(displayErrorToastSpy).toHaveBeenCalledWith(
|
||||
"SETTINGS$FAILED_TO_RESEND_VERIFICATION",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should disable button and show sending text while request is pending", async () => {
|
||||
// Arrange
|
||||
let resolvePromise: (value: { message: string }) => void;
|
||||
const pendingPromise = new Promise<{ message: string }>((resolve) => {
|
||||
resolvePromise = resolve;
|
||||
});
|
||||
resendEmailVerificationSpy.mockReturnValue(pendingPromise);
|
||||
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
|
||||
|
||||
// Act
|
||||
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
|
||||
await userEvent.click(resendButton);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
const sendingButton = screen.getByText("SETTINGS$SENDING");
|
||||
expect(sendingButton).toBeInTheDocument();
|
||||
expect(sendingButton).toBeDisabled();
|
||||
});
|
||||
|
||||
// Cleanup
|
||||
resolvePromise!({ message: "Email verification message sent" });
|
||||
});
|
||||
|
||||
it("should re-enable button after request completes", async () => {
|
||||
// Arrange
|
||||
resendEmailVerificationSpy.mockResolvedValue({
|
||||
message: "Email verification message sent",
|
||||
});
|
||||
renderWithRouter(<EmailVerificationModal onClose={mockOnClose} />);
|
||||
|
||||
// Act
|
||||
const resendButton = screen.getByText("SETTINGS$RESEND_VERIFICATION");
|
||||
await userEvent.click(resendButton);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(resendEmailVerificationSpy).toHaveBeenCalled();
|
||||
});
|
||||
// After successful send, the button will be disabled due to cooldown
|
||||
// So we just verify the mutation was called successfully
|
||||
await waitFor(() => {
|
||||
const button = screen.getByRole("button");
|
||||
expect(button).toBeDisabled(); // Button is disabled during cooldown
|
||||
expect(button).toHaveTextContent(/SETTINGS\$RESEND_VERIFICATION/);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,48 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect } from "vitest";
|
||||
import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice";
|
||||
|
||||
describe("TermsAndPrivacyNotice", () => {
|
||||
it("should render Terms of Service and Privacy Policy links", () => {
|
||||
// Arrange & Act
|
||||
render(<TermsAndPrivacyNotice />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
|
||||
const tosLink = screen.getByRole("link", {
|
||||
name: "COMMON$TERMS_OF_SERVICE",
|
||||
});
|
||||
const privacyLink = screen.getByRole("link", {
|
||||
name: "COMMON$PRIVACY_POLICY",
|
||||
});
|
||||
|
||||
expect(tosLink).toBeInTheDocument();
|
||||
expect(tosLink).toHaveAttribute("href", "https://www.all-hands.dev/tos");
|
||||
expect(tosLink).toHaveAttribute("target", "_blank");
|
||||
expect(tosLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
|
||||
expect(privacyLink).toBeInTheDocument();
|
||||
expect(privacyLink).toHaveAttribute(
|
||||
"href",
|
||||
"https://www.all-hands.dev/privacy",
|
||||
);
|
||||
expect(privacyLink).toHaveAttribute("target", "_blank");
|
||||
expect(privacyLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
|
||||
it("should render all required text content", () => {
|
||||
// Arrange & Act
|
||||
render(<TermsAndPrivacyNotice />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toHaveTextContent(
|
||||
"AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR",
|
||||
);
|
||||
expect(termsSection).toHaveTextContent("COMMON$TERMS_OF_SERVICE");
|
||||
expect(termsSection).toHaveTextContent("COMMON$AND");
|
||||
expect(termsSection).toHaveTextContent("COMMON$PRIVACY_POLICY");
|
||||
});
|
||||
});
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
beforeEach,
|
||||
afterAll,
|
||||
afterEach,
|
||||
vi,
|
||||
} from "vitest";
|
||||
import { screen, waitFor, render, cleanup } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
@@ -34,6 +35,7 @@ import {
|
||||
} from "#/contexts/conversation-websocket-context";
|
||||
import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
|
||||
import { useEventStore } from "#/stores/use-event-store";
|
||||
import { isV1Event } from "#/types/v1/type-guards";
|
||||
|
||||
// MSW WebSocket mock setup
|
||||
const { wsLink, server: mswServer } = conversationWebSocketTestSetup();
|
||||
@@ -141,6 +143,11 @@ describe("Conversation WebSocket Handler", () => {
|
||||
});
|
||||
|
||||
it("should handle malformed/invalid event data gracefully", async () => {
|
||||
// Suppress expected console.warn for invalid JSON parsing
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, "warn")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
// Set up MSW to send various invalid events when connection is established
|
||||
mswServer.use(
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
@@ -203,6 +210,9 @@ describe("Conversation WebSocket Handler", () => {
|
||||
"valid-event-123",
|
||||
);
|
||||
expect(screen.getByTestId("ui-events-count")).toHaveTextContent("1");
|
||||
|
||||
// Restore console.warn
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -415,6 +425,101 @@ describe("Conversation WebSocket Handler", () => {
|
||||
// based on how the actual reconnection logic is implemented
|
||||
});
|
||||
|
||||
it("should not create duplicate events when WebSocket reconnects with resend_all=true", async () => {
|
||||
const conversationId = "test-conversation-reconnect";
|
||||
let connectionCount = 0;
|
||||
|
||||
// Clear event store before test
|
||||
useEventStore.getState().clearEvents();
|
||||
|
||||
// Create mock events that will be sent on each connection
|
||||
const mockHistoryEvents = [
|
||||
createMockUserMessageEvent({ id: "event-1" }),
|
||||
createMockMessageEvent({ id: "event-2" }),
|
||||
createMockMessageEvent({ id: "event-3" }),
|
||||
];
|
||||
|
||||
// Set up MSW to mock event count API and WebSocket
|
||||
// The WebSocket will resend all events on each connection (simulating resend_all=true behavior)
|
||||
mswServer.use(
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(3),
|
||||
),
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
connectionCount += 1;
|
||||
server.connect();
|
||||
|
||||
// Send all history events on EVERY connection (simulating resend_all=true)
|
||||
mockHistoryEvents.forEach((event) => {
|
||||
client.send(JSON.stringify(event));
|
||||
});
|
||||
|
||||
// On first connection, simulate a disconnect after events are sent
|
||||
if (connectionCount === 1) {
|
||||
setTimeout(() => {
|
||||
client.close(1006, "Simulated disconnect");
|
||||
}, 100);
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(
|
||||
<ConnectionStatusComponent />,
|
||||
conversationId,
|
||||
`http://localhost:3000/api/conversations/${conversationId}`,
|
||||
);
|
||||
|
||||
// Wait for initial connection and events
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"OPEN",
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(useEventStore.getState().events.length).toBe(3);
|
||||
});
|
||||
|
||||
// Wait for disconnect
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"CLOSED",
|
||||
);
|
||||
});
|
||||
|
||||
// Wait for reconnection
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"OPEN",
|
||||
);
|
||||
},
|
||||
{ timeout: 5000 },
|
||||
);
|
||||
|
||||
// Give time for resent events to be processed
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 200);
|
||||
});
|
||||
|
||||
// After reconnection, events should NOT be duplicated
|
||||
// The server sends 3 events again (resend_all=true), but we should deduplicate
|
||||
const { events } = useEventStore.getState();
|
||||
const v1Events = events.filter(isV1Event);
|
||||
const uniqueEventIds = [...new Set(v1Events.map((e) => e.id))];
|
||||
|
||||
// This assertion will FAIL with current implementation (showing the bug)
|
||||
// Expected: 3 events (deduplicated)
|
||||
// Actual: 6 events (duplicated)
|
||||
expect(v1Events.length).toBe(3);
|
||||
expect(uniqueEventIds.length).toBe(3);
|
||||
|
||||
// Verify we actually had 2 connections
|
||||
expect(connectionCount).toBe(2);
|
||||
});
|
||||
|
||||
it.todo("should track and display errors with proper metadata");
|
||||
it.todo("should set appropriate error states on connection failures");
|
||||
it.todo(
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import React from "react";
|
||||
import {
|
||||
useGitHubIssuesPRs,
|
||||
useRefreshGitHubIssuesPRs,
|
||||
} from "../src/hooks/query/use-github-issues-prs";
|
||||
import { useShouldShowUserFeatures } from "../src/hooks/use-should-show-user-features";
|
||||
import GitHubIssuesPRsService from "../src/api/github-service/github-issues-prs.api";
|
||||
|
||||
// Mock the dependencies
|
||||
vi.mock("../src/hooks/use-should-show-user-features");
|
||||
vi.mock("../src/api/github-service/github-issues-prs.api", () => ({
|
||||
default: {
|
||||
getGitHubItems: vi.fn(),
|
||||
buildItemUrl: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const mockUseShouldShowUserFeatures = vi.mocked(useShouldShowUserFeatures);
|
||||
const mockGetGitHubItems = vi.mocked(GitHubIssuesPRsService.getGitHubItems);
|
||||
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return ({ children }: { children: React.ReactNode }) =>
|
||||
React.createElement(QueryClientProvider, { client: queryClient }, children);
|
||||
};
|
||||
|
||||
describe("useGitHubIssuesPRs", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
localStorage.clear();
|
||||
mockUseShouldShowUserFeatures.mockReturnValue(false);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
it("should be disabled when useShouldShowUserFeatures returns false", () => {
|
||||
mockUseShouldShowUserFeatures.mockReturnValue(false);
|
||||
|
||||
const { result } = renderHook(() => useGitHubIssuesPRs(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.isLoading).toBe(false);
|
||||
expect(result.current.isFetching).toBe(false);
|
||||
});
|
||||
|
||||
it("should be enabled when useShouldShowUserFeatures returns true", () => {
|
||||
mockUseShouldShowUserFeatures.mockReturnValue(true);
|
||||
mockGetGitHubItems.mockResolvedValue({
|
||||
items: [],
|
||||
cached_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useGitHubIssuesPRs(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// When enabled, the query should be loading/fetching
|
||||
expect(result.current.isLoading).toBe(true);
|
||||
});
|
||||
|
||||
it("should fetch and return GitHub items", async () => {
|
||||
mockUseShouldShowUserFeatures.mockReturnValue(true);
|
||||
const mockItems = [
|
||||
{
|
||||
git_provider: "github" as const,
|
||||
item_type: "issue" as const,
|
||||
status: "OPEN_ISSUE" as const,
|
||||
repo: "test/repo",
|
||||
number: 1,
|
||||
title: "Test Issue",
|
||||
author: "testuser",
|
||||
assignees: ["testuser"],
|
||||
created_at: "2024-01-01T00:00:00Z",
|
||||
updated_at: "2024-01-01T00:00:00Z",
|
||||
url: "https://github.com/test/repo/issues/1",
|
||||
},
|
||||
];
|
||||
mockGetGitHubItems.mockResolvedValue({
|
||||
items: mockItems,
|
||||
cached_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useGitHubIssuesPRs(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isSuccess).toBe(true);
|
||||
});
|
||||
|
||||
expect(result.current.data?.items).toEqual(mockItems);
|
||||
});
|
||||
|
||||
it("should filter by item type", async () => {
|
||||
mockUseShouldShowUserFeatures.mockReturnValue(true);
|
||||
const mockItems = [
|
||||
{
|
||||
git_provider: "github" as const,
|
||||
item_type: "issue" as const,
|
||||
status: "OPEN_ISSUE" as const,
|
||||
repo: "test/repo",
|
||||
number: 1,
|
||||
title: "Test Issue",
|
||||
author: "testuser",
|
||||
assignees: ["testuser"],
|
||||
created_at: "2024-01-01T00:00:00Z",
|
||||
updated_at: "2024-01-01T00:00:00Z",
|
||||
url: "https://github.com/test/repo/issues/1",
|
||||
},
|
||||
];
|
||||
mockGetGitHubItems.mockResolvedValue({
|
||||
items: mockItems,
|
||||
cached_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useGitHubIssuesPRs({ itemType: "issues" }),
|
||||
{
|
||||
wrapper: createWrapper(),
|
||||
},
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isSuccess).toBe(true);
|
||||
});
|
||||
|
||||
expect(mockGetGitHubItems).toHaveBeenCalledWith({ itemType: "issues" });
|
||||
});
|
||||
});
|
||||
|
||||
describe("useRefreshGitHubIssuesPRs", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
it("should clear localStorage cache when called", () => {
|
||||
// Set up some cached data
|
||||
localStorage.setItem(
|
||||
"github-issues-prs-cache",
|
||||
JSON.stringify({
|
||||
data: { items: [], cached_at: new Date().toISOString() },
|
||||
timestamp: Date.now(),
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() => useRefreshGitHubIssuesPRs(), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Call the refresh function
|
||||
result.current();
|
||||
|
||||
// Check that localStorage was cleared
|
||||
expect(localStorage.getItem("github-issues-prs-cache")).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -34,7 +34,11 @@ describe("useWebSocket", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
beforeAll(() => mswServer.listen());
|
||||
beforeAll(() =>
|
||||
mswServer.listen({
|
||||
onUnhandledRequest: "warn",
|
||||
}),
|
||||
);
|
||||
afterEach(() => mswServer.resetHandlers());
|
||||
afterAll(() => mswServer.close());
|
||||
|
||||
|
||||
227
frontend/__tests__/router.md
Normal file
227
frontend/__tests__/router.md
Normal file
@@ -0,0 +1,227 @@
|
||||
# Testing with React Router
|
||||
|
||||
## Overview
|
||||
|
||||
React Router components and hooks require a routing context to function. In tests, we need to provide this context while maintaining control over the routing state.
|
||||
|
||||
This guide covers the two main approaches used in the OpenHands frontend:
|
||||
|
||||
1. **`createRoutesStub`** - Creates a complete route structure for testing components with their actual route configuration, loaders, and nested routes.
|
||||
2. **`MemoryRouter`** - Provides a minimal routing context for components that just need router hooks to work.
|
||||
|
||||
Choose your approach based on what your component actually needs from the router.
|
||||
|
||||
## When to Use Each Approach
|
||||
|
||||
### `createRoutesStub` (Recommended)
|
||||
|
||||
Use `createRoutesStub` when your component:
|
||||
- Relies on route parameters (`useParams`)
|
||||
- Uses loader data (`useLoaderData`) or `clientLoader`
|
||||
- Has nested routes or uses `<Outlet />`
|
||||
- Needs to test navigation between routes
|
||||
|
||||
> [!NOTE]
|
||||
> `createRoutesStub` is intended for unit testing **reusable components** that depend on router context. For testing full route/page components, consider E2E tests (Playwright, Cypress) instead.
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { render } from "@testing-library/react";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyRouteComponent,
|
||||
path: "/conversations/:conversationId",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/conversations/123"]} />);
|
||||
```
|
||||
|
||||
**With nested routes and loaders:**
|
||||
|
||||
```typescript
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: SettingsScreen,
|
||||
clientLoader,
|
||||
path: "/settings",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="llm-settings" />,
|
||||
path: "/settings",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="git-settings" />,
|
||||
path: "/settings/integrations",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/settings/integrations"]} />);
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> When using `clientLoader` from a Route module, you may encounter type mismatches. Use `@ts-expect-error` as a workaround:
|
||||
|
||||
```typescript
|
||||
import { clientLoader } from "@/routes/settings";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
path: "/settings",
|
||||
Component: SettingsScreen,
|
||||
// @ts-expect-error: loader types won't align between test and app code
|
||||
loader: clientLoader,
|
||||
},
|
||||
]);
|
||||
```
|
||||
|
||||
### `MemoryRouter`
|
||||
|
||||
Use `MemoryRouter` when your component:
|
||||
- Only needs basic routing context to render
|
||||
- Uses `<Link>` components but you don't need to test navigation
|
||||
- Doesn't depend on specific route parameters or loaders
|
||||
|
||||
```typescript
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { render } from "@testing-library/react";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
**With initial route:**
|
||||
|
||||
```typescript
|
||||
render(
|
||||
<MemoryRouter initialEntries={["/some/path"]}>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
## Anti-patterns to Avoid
|
||||
|
||||
### Using `BrowserRouter` in tests
|
||||
|
||||
`BrowserRouter` interacts with the actual browser history API, which can cause issues in test environments:
|
||||
|
||||
```typescript
|
||||
// ❌ Avoid
|
||||
render(
|
||||
<BrowserRouter>
|
||||
<MyComponent />
|
||||
</BrowserRouter>
|
||||
);
|
||||
|
||||
// ✅ Use MemoryRouter instead
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
### Mocking router hooks when `createRoutesStub` would work
|
||||
|
||||
Mocking hooks like `useParams` directly can be brittle and doesn't test the actual routing behavior:
|
||||
|
||||
```typescript
|
||||
// ❌ Avoid when possible
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual,
|
||||
useParams: () => ({ conversationId: "123" }),
|
||||
};
|
||||
});
|
||||
|
||||
// ✅ Prefer createRoutesStub - tests real routing behavior
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyComponent,
|
||||
path: "/conversations/:conversationId",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/conversations/123"]} />);
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Combining with `QueryClientProvider`
|
||||
|
||||
Many components need both routing and TanStack Query context:
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
});
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyComponent,
|
||||
path: "/",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
```
|
||||
|
||||
### Testing navigation behavior
|
||||
|
||||
Verify that user interactions trigger the expected navigation:
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: HomeScreen,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="settings-screen" />,
|
||||
path: "/settings",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/"]} />);
|
||||
|
||||
const user = userEvent.setup();
|
||||
await user.click(screen.getByRole("link", { name: /settings/i }));
|
||||
|
||||
expect(screen.getByTestId("settings-screen")).toBeInTheDocument();
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
### Codebase Examples
|
||||
|
||||
- [settings.test.tsx](__tests__/routes/settings.test.tsx) - `createRoutesStub` with nested routes and loaders
|
||||
- [home-screen.test.tsx](__tests__/routes/home-screen.test.tsx) - `createRoutesStub` with navigation testing
|
||||
- [chat-interface.test.tsx](__tests__/components/chat/chat-interface.test.tsx) - `MemoryRouter` usage
|
||||
|
||||
### Official Documentation
|
||||
|
||||
- [React Router Testing Guide](https://reactrouter.com/start/framework/testing) - Official guide on testing with `createRoutesStub`
|
||||
- [MemoryRouter API](https://reactrouter.com/api/declarative-routers/MemoryRouter) - API reference for `MemoryRouter`
|
||||
@@ -298,6 +298,7 @@ describe("Form submission", () => {
|
||||
gitlab: { token: "", host: "" },
|
||||
bitbucket: { token: "", host: "" },
|
||||
azure_devops: { token: "", host: "" },
|
||||
forgejo: { token: "", host: "" },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -320,6 +321,7 @@ describe("Form submission", () => {
|
||||
gitlab: { token: "test-token", host: "" },
|
||||
bitbucket: { token: "", host: "" },
|
||||
azure_devops: { token: "", host: "" },
|
||||
forgejo: { token: "", host: "" },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -342,6 +344,7 @@ describe("Form submission", () => {
|
||||
gitlab: { token: "", host: "" },
|
||||
bitbucket: { token: "test-token", host: "" },
|
||||
azure_devops: { token: "", host: "" },
|
||||
forgejo: { token: "", host: "" },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -364,6 +367,7 @@ describe("Form submission", () => {
|
||||
gitlab: { token: "", host: "" },
|
||||
bitbucket: { token: "", host: "" },
|
||||
azure_devops: { token: "test-token", host: "" },
|
||||
forgejo: { token: "", host: "" },
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,316 +0,0 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { MemoryRouter, Routes, Route } from "react-router";
|
||||
import GitHubIssuesPRsPage from "#/routes/github-issues-prs";
|
||||
import GitHubIssuesPRsService from "#/api/github-service/github-issues-prs.api";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import { useShouldShowUserFeatures } from "#/hooks/use-should-show-user-features";
|
||||
import { useUserProviders } from "#/hooks/use-user-providers";
|
||||
|
||||
// Mock the services
|
||||
vi.mock("#/api/github-service/github-issues-prs.api", () => ({
|
||||
default: {
|
||||
getGitHubItems: vi.fn(),
|
||||
buildItemUrl: vi.fn((provider, repo, number, type) => {
|
||||
if (provider === "github") {
|
||||
return `https://github.com/${repo}/${type === "issue" ? "issues" : "pull"}/${number}`;
|
||||
}
|
||||
return "";
|
||||
}),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/api/conversation-service/conversation-service.api", () => ({
|
||||
default: {
|
||||
searchConversations: vi.fn(),
|
||||
createConversation: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-should-show-user-features", () => ({
|
||||
useShouldShowUserFeatures: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-user-providers", () => ({
|
||||
useUserProviders: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock react-i18next to return the key as the translation
|
||||
vi.mock("react-i18next", () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
i18n: { language: "en" },
|
||||
}),
|
||||
}));
|
||||
|
||||
const mockGetGitHubItems = vi.mocked(GitHubIssuesPRsService.getGitHubItems);
|
||||
const mockSearchConversations = vi.mocked(
|
||||
ConversationService.searchConversations,
|
||||
);
|
||||
const mockUseShouldShowUserFeatures = vi.mocked(useShouldShowUserFeatures);
|
||||
const mockUseUserProviders = vi.mocked(useUserProviders);
|
||||
|
||||
const renderGitHubIssuesPRsPage = () =>
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<MemoryRouter initialEntries={["/github-issues-prs"]}>
|
||||
<Routes>
|
||||
<Route path="/github-issues-prs" element={<GitHubIssuesPRsPage />} />
|
||||
<Route
|
||||
path="/conversations/:conversationId"
|
||||
element={<div data-testid="conversation-screen" />}
|
||||
/>
|
||||
</Routes>
|
||||
</MemoryRouter>
|
||||
</QueryClientProvider>,
|
||||
);
|
||||
|
||||
const MOCK_GITHUB_ITEMS = [
|
||||
{
|
||||
git_provider: "github" as const,
|
||||
item_type: "issue" as const,
|
||||
status: "OPEN_ISSUE" as const,
|
||||
repo: "test/repo",
|
||||
number: 1,
|
||||
title: "Test Issue",
|
||||
author: "testuser",
|
||||
assignees: ["testuser"],
|
||||
created_at: "2024-01-01T00:00:00Z",
|
||||
updated_at: "2024-01-01T00:00:00Z",
|
||||
url: "https://github.com/test/repo/issues/1",
|
||||
},
|
||||
{
|
||||
git_provider: "github" as const,
|
||||
item_type: "pr" as const,
|
||||
status: "OPEN_PR" as const,
|
||||
repo: "test/repo",
|
||||
number: 2,
|
||||
title: "Test PR",
|
||||
author: "testuser",
|
||||
assignees: [],
|
||||
created_at: "2024-01-01T00:00:00Z",
|
||||
updated_at: "2024-01-01T00:00:00Z",
|
||||
url: "https://github.com/test/repo/pull/2",
|
||||
},
|
||||
];
|
||||
|
||||
describe("GitHubIssuesPRsPage", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
localStorage.clear();
|
||||
|
||||
mockUseShouldShowUserFeatures.mockReturnValue(true);
|
||||
mockUseUserProviders.mockReturnValue({
|
||||
providers: ["github"],
|
||||
isLoadingSettings: false,
|
||||
});
|
||||
|
||||
mockGetGitHubItems.mockResolvedValue({
|
||||
items: MOCK_GITHUB_ITEMS,
|
||||
cached_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
mockSearchConversations.mockResolvedValue([]);
|
||||
});
|
||||
|
||||
it("should render the page title", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("GITHUB_ISSUES_PRS$TITLE")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render the view type selector", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
// The view label has a colon after it
|
||||
expect(screen.getByText("GITHUB_ISSUES_PRS$VIEW:")).toBeInTheDocument();
|
||||
expect(screen.getByRole("combobox")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render filter checkboxes", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("GITHUB_ISSUES_PRS$ASSIGNED_TO_ME"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("GITHUB_ISSUES_PRS$AUTHORED_BY_ME"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render the refresh button", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("GITHUB_ISSUES_PRS$REFRESH")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should display GitHub items when loaded", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Test Issue")).toBeInTheDocument();
|
||||
expect(screen.getByText("Test PR")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should display item status badges", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("GITHUB_ISSUES_PRS$OPEN_ISSUE"),
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("GITHUB_ISSUES_PRS$OPEN_PR")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should display Start Session buttons for items without related conversations", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
const startButtons = screen.getAllByText(
|
||||
"GITHUB_ISSUES_PRS$START_SESSION",
|
||||
);
|
||||
expect(startButtons.length).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
it("should filter items when view type is changed to issues", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Test Issue")).toBeInTheDocument();
|
||||
expect(screen.getByText("Test PR")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Change view type to issues
|
||||
const select = screen.getByRole("combobox");
|
||||
await userEvent.selectOptions(select, "issues");
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Test Issue")).toBeInTheDocument();
|
||||
expect(screen.queryByText("Test PR")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should filter items when view type is changed to PRs", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Test Issue")).toBeInTheDocument();
|
||||
expect(screen.getByText("Test PR")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Change view type to PRs
|
||||
const select = screen.getByRole("combobox");
|
||||
await userEvent.selectOptions(select, "prs");
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Test Issue")).not.toBeInTheDocument();
|
||||
expect(screen.getByText("Test PR")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should display Resume Session button when a related conversation exists", async () => {
|
||||
mockSearchConversations.mockResolvedValue([
|
||||
{
|
||||
conversation_id: "conv-1",
|
||||
title: "Working on #1",
|
||||
selected_repository: "test/repo",
|
||||
selected_branch: "main",
|
||||
git_provider: "github",
|
||||
pr_number: [1],
|
||||
created_at: "2024-01-01T00:00:00Z",
|
||||
last_updated_at: "2024-01-01T00:00:00Z",
|
||||
status: "RUNNING",
|
||||
runtime_status: null,
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
},
|
||||
]);
|
||||
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("GITHUB_ISSUES_PRS$RESUME_SESSION"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should show empty state when no items are found", async () => {
|
||||
mockGetGitHubItems.mockResolvedValue({
|
||||
items: [],
|
||||
cached_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("GITHUB_ISSUES_PRS$NO_ITEMS"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should display View on GitHub links", async () => {
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
const viewLinks = screen.getAllByText("GITHUB_ISSUES_PRS$VIEW_ON_GITHUB");
|
||||
expect(viewLinks.length).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
it("should show loading state while checking settings", async () => {
|
||||
mockUseUserProviders.mockReturnValue({
|
||||
providers: [],
|
||||
isLoadingSettings: true,
|
||||
});
|
||||
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
// Should show loading spinner
|
||||
expect(screen.getByTestId("loading-spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show no-token message when GitHub token is not configured", async () => {
|
||||
mockUseUserProviders.mockReturnValue({
|
||||
providers: [],
|
||||
isLoadingSettings: false,
|
||||
});
|
||||
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("GITHUB_ISSUES_PRS$NO_TOKEN")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("GITHUB_ISSUES_PRS$CONFIGURE_TOKEN"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should have a link to git settings when no token is configured", async () => {
|
||||
mockUseUserProviders.mockReturnValue({
|
||||
providers: [],
|
||||
isLoadingSettings: false,
|
||||
});
|
||||
|
||||
renderGitHubIssuesPRsPage();
|
||||
|
||||
await waitFor(() => {
|
||||
const configureLink = screen.getByText("GITHUB_ISSUES_PRS$CONFIGURE_TOKEN");
|
||||
expect(configureLink).toHaveAttribute("href", "/settings/git");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -910,6 +910,162 @@ describe("Form submission", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("View persistence after saving advanced settings", () => {
|
||||
it("should remain on Advanced view after saving when memory condenser is disabled", async () => {
|
||||
// Arrange: Start with default settings (basic view)
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
});
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
saveSettingsSpy.mockResolvedValue(true);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
// Verify we start in basic view
|
||||
expect(screen.getByTestId("llm-settings-form-basic")).toBeInTheDocument();
|
||||
|
||||
// Act: User manually switches to Advanced view
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
await userEvent.click(advancedSwitch);
|
||||
await screen.findByTestId("llm-settings-form-advanced");
|
||||
|
||||
// User disables memory condenser (advanced-only setting)
|
||||
const condenserSwitch = screen.getByTestId(
|
||||
"enable-memory-condenser-switch",
|
||||
);
|
||||
expect(condenserSwitch).toBeChecked();
|
||||
await userEvent.click(condenserSwitch);
|
||||
expect(condenserSwitch).not.toBeChecked();
|
||||
|
||||
// Mock the updated settings that will be returned after save
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
enable_default_condenser: false, // Now disabled
|
||||
});
|
||||
|
||||
// User saves settings
|
||||
const submitButton = screen.getByTestId("submit-button");
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert: View should remain on Advanced after save
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("llm-settings-form-advanced"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("llm-settings-form-basic"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(advancedSwitch).toBeChecked();
|
||||
});
|
||||
});
|
||||
|
||||
it("should remain on Advanced view after saving when condenser max size is customized", async () => {
|
||||
// Arrange: Start with default settings
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
});
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
saveSettingsSpy.mockResolvedValue(true);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
// Act: User manually switches to Advanced view
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
await userEvent.click(advancedSwitch);
|
||||
await screen.findByTestId("llm-settings-form-advanced");
|
||||
|
||||
// User sets custom condenser max size (advanced-only setting)
|
||||
const condenserMaxSizeInput = screen.getByTestId(
|
||||
"condenser-max-size-input",
|
||||
);
|
||||
await userEvent.clear(condenserMaxSizeInput);
|
||||
await userEvent.type(condenserMaxSizeInput, "200");
|
||||
|
||||
// Mock the updated settings that will be returned after save
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
condenser_max_size: 200, // Custom value
|
||||
});
|
||||
|
||||
// User saves settings
|
||||
const submitButton = screen.getByTestId("submit-button");
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert: View should remain on Advanced after save
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("llm-settings-form-advanced"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("llm-settings-form-basic"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(advancedSwitch).toBeChecked();
|
||||
});
|
||||
});
|
||||
|
||||
it("should remain on Advanced view after saving when search API key is set", async () => {
|
||||
// Arrange: Start with default settings (non-SaaS mode to show search API key field)
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "oss",
|
||||
GITHUB_CLIENT_ID: "fake-github-client-id",
|
||||
POSTHOG_CLIENT_KEY: "fake-posthog-client-key",
|
||||
FEATURE_FLAGS: {
|
||||
ENABLE_BILLING: false,
|
||||
HIDE_LLM_SETTINGS: false,
|
||||
ENABLE_JIRA: false,
|
||||
ENABLE_JIRA_DC: false,
|
||||
ENABLE_LINEAR: false,
|
||||
},
|
||||
});
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
search_api_key: "", // Default empty value
|
||||
});
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
saveSettingsSpy.mockResolvedValue(true);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
// Act: User manually switches to Advanced view
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
await userEvent.click(advancedSwitch);
|
||||
await screen.findByTestId("llm-settings-form-advanced");
|
||||
|
||||
// User sets search API key (advanced-only setting)
|
||||
const searchApiKeyInput = screen.getByTestId("search-api-key-input");
|
||||
await userEvent.type(searchApiKeyInput, "test-search-api-key");
|
||||
|
||||
// Mock the updated settings that will be returned after save
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
search_api_key: "test-search-api-key", // Now set
|
||||
});
|
||||
|
||||
// User saves settings
|
||||
const submitButton = screen.getByTestId("submit-button");
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert: View should remain on Advanced after save
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("llm-settings-form-advanced"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("llm-settings-form-basic"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(advancedSwitch).toBeChecked();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Status toasts", () => {
|
||||
describe("Basic form", () => {
|
||||
it("should call displaySuccessToast when the settings are saved", async () => {
|
||||
|
||||
242
frontend/__tests__/routes/root-layout.test.tsx
Normal file
242
frontend/__tests__/routes/root-layout.test.tsx
Normal file
@@ -0,0 +1,242 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import MainApp from "#/routes/root-layout";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
// Mock other hooks that are not the focus of these tests
|
||||
vi.mock("#/hooks/use-github-auth-url", () => ({
|
||||
useGitHubAuthUrl: () => "https://github.com/oauth/authorize",
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-is-on-tos-page", () => ({
|
||||
useIsOnTosPage: () => false,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-auto-login", () => ({
|
||||
useAutoLogin: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-auth-callback", () => ({
|
||||
useAuthCallback: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-migrate-user-consent", () => ({
|
||||
useMigrateUserConsent: () => ({
|
||||
migrateUserConsent: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-reo-tracking", () => ({
|
||||
useReoTracking: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-sync-posthog-consent", () => ({
|
||||
useSyncPostHogConsent: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/custom-toast-handlers", () => ({
|
||||
displaySuccessToast: vi.fn(),
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MainApp,
|
||||
path: "/",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="outlet-content">Content</div>,
|
||||
path: "/",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
};
|
||||
|
||||
describe("MainApp - Email Verification Flow", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Default mocks for services
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
GITHUB_CLIENT_ID: "test-client-id",
|
||||
POSTHOG_CLIENT_KEY: "test-posthog-key",
|
||||
PROVIDERS_CONFIGURED: ["github"],
|
||||
AUTH_URL: "https://auth.example.com",
|
||||
FEATURE_FLAGS: {
|
||||
ENABLE_BILLING: false,
|
||||
HIDE_LLM_SETTINGS: false,
|
||||
ENABLE_JIRA: false,
|
||||
ENABLE_JIRA_DC: false,
|
||||
ENABLE_LINEAR: false,
|
||||
},
|
||||
});
|
||||
|
||||
vi.spyOn(AuthService, "authenticate").mockResolvedValue(true);
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
language: "en",
|
||||
user_consents_to_analytics: true,
|
||||
llm_model: "",
|
||||
llm_base_url: "",
|
||||
agent: "",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Mock localStorage
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn(() => null),
|
||||
setItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("should display EmailVerificationModal when email_verification_required=true is in query params", async () => {
|
||||
// Arrange & Act
|
||||
render(
|
||||
<RouterStub initialEntries={["/?email_verification_required=true"]} />,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should set emailVerified state and pass to AuthModal when email_verified=true is in query params", async () => {
|
||||
// Arrange
|
||||
// Mock a 401 error to simulate unauthenticated user
|
||||
const axiosError = {
|
||||
response: { status: 401 },
|
||||
isAxiosError: true,
|
||||
};
|
||||
vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/?email_verified=true"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Assert - Wait for AuthModal to render (since user is not authenticated)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle both email_verification_required and email_verified params together", async () => {
|
||||
// Arrange & Act
|
||||
render(
|
||||
<RouterStub
|
||||
initialEntries={[
|
||||
"/?email_verification_required=true&email_verified=true",
|
||||
]}
|
||||
/>,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert - EmailVerificationModal should take precedence
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should remove query parameters from URL after processing", async () => {
|
||||
// Arrange & Act
|
||||
const { container } = render(
|
||||
<RouterStub initialEntries={["/?email_verification_required=true"]} />,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert - Wait for the modal to appear (which indicates processing happened)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify that the query parameter was processed by checking the modal appeared
|
||||
// The hook removes the parameter from the URL, so we verify the behavior indirectly
|
||||
expect(container).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display EmailVerificationModal when email_verification_required is not in query params", async () => {
|
||||
// Arrange - No query params set
|
||||
|
||||
// Act
|
||||
render(<RouterStub />, { wrapper: createWrapper() });
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should not display email verified message when email_verified is not in query params", async () => {
|
||||
// Arrange
|
||||
// Mock a 401 error to simulate unauthenticated user
|
||||
const axiosError = {
|
||||
response: { status: 401 },
|
||||
isAxiosError: true,
|
||||
};
|
||||
vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError);
|
||||
|
||||
// Act
|
||||
render(<RouterStub />, { wrapper: createWrapper() });
|
||||
|
||||
// Assert - AuthModal should render but without email verified message
|
||||
await waitFor(() => {
|
||||
const authModal = screen.queryByText(
|
||||
"AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER",
|
||||
);
|
||||
if (authModal) {
|
||||
expect(
|
||||
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).not.toBeInTheDocument();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,76 +0,0 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import GitHubIssuesPRsService from "../../src/api/github-service/github-issues-prs.api";
|
||||
|
||||
describe("GitHubIssuesPRsService", () => {
|
||||
describe("buildItemUrl", () => {
|
||||
it("should build correct GitHub issue URL", () => {
|
||||
const url = GitHubIssuesPRsService.buildItemUrl(
|
||||
"github",
|
||||
"owner/repo",
|
||||
123,
|
||||
"issue",
|
||||
);
|
||||
expect(url).toBe("https://github.com/owner/repo/issues/123");
|
||||
});
|
||||
|
||||
it("should build correct GitHub PR URL", () => {
|
||||
const url = GitHubIssuesPRsService.buildItemUrl(
|
||||
"github",
|
||||
"owner/repo",
|
||||
456,
|
||||
"pr",
|
||||
);
|
||||
expect(url).toBe("https://github.com/owner/repo/pull/456");
|
||||
});
|
||||
|
||||
it("should build correct GitLab issue URL", () => {
|
||||
const url = GitHubIssuesPRsService.buildItemUrl(
|
||||
"gitlab",
|
||||
"owner/repo",
|
||||
123,
|
||||
"issue",
|
||||
);
|
||||
expect(url).toBe("https://gitlab.com/owner/repo/-/issues/123");
|
||||
});
|
||||
|
||||
it("should build correct GitLab MR URL", () => {
|
||||
const url = GitHubIssuesPRsService.buildItemUrl(
|
||||
"gitlab",
|
||||
"owner/repo",
|
||||
456,
|
||||
"pr",
|
||||
);
|
||||
expect(url).toBe("https://gitlab.com/owner/repo/-/merge_requests/456");
|
||||
});
|
||||
|
||||
it("should build correct Bitbucket issue URL", () => {
|
||||
const url = GitHubIssuesPRsService.buildItemUrl(
|
||||
"bitbucket",
|
||||
"owner/repo",
|
||||
123,
|
||||
"issue",
|
||||
);
|
||||
expect(url).toBe("https://bitbucket.org/owner/repo/issues/123");
|
||||
});
|
||||
|
||||
it("should build correct Bitbucket PR URL", () => {
|
||||
const url = GitHubIssuesPRsService.buildItemUrl(
|
||||
"bitbucket",
|
||||
"owner/repo",
|
||||
456,
|
||||
"pr",
|
||||
);
|
||||
expect(url).toBe("https://bitbucket.org/owner/repo/pull-requests/456");
|
||||
});
|
||||
|
||||
it("should return empty string for unknown provider", () => {
|
||||
const url = GitHubIssuesPRsService.buildItemUrl(
|
||||
"unknown" as any,
|
||||
"owner/repo",
|
||||
123,
|
||||
"issue",
|
||||
);
|
||||
expect(url).toBe("");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -9,7 +9,7 @@ import { useShouldShowUserFeatures } from "../src/hooks/use-should-show-user-fea
|
||||
vi.mock("../src/hooks/use-should-show-user-features");
|
||||
vi.mock("#/api/suggestions-service/suggestions-service.api", () => ({
|
||||
SuggestionsService: {
|
||||
getSuggestedTasks: vi.fn(),
|
||||
getSuggestedTasks: vi.fn().mockResolvedValue([]),
|
||||
},
|
||||
}));
|
||||
|
||||
|
||||
@@ -29,5 +29,75 @@ describe("hasAdvancedSettingsSet", () => {
|
||||
}),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
test("enable_default_condenser is disabled", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
enable_default_condenser: false,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("condenser_max_size is customized above default", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
condenser_max_size: 200,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("condenser_max_size is customized below default", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
condenser_max_size: 50,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("search_api_key is set to non-empty value", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
search_api_key: "test-api-key-123",
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("search_api_key with whitespace is treated as set", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
search_api_key: " test-key ",
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
1807
frontend/package-lock.json
generated
1807
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -1,22 +1,22 @@
|
||||
{
|
||||
"name": "openhands-frontend",
|
||||
"version": "1.0.0",
|
||||
"version": "1.1.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=22.0.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"@heroui/react": "2.8.6",
|
||||
"@heroui/react": "2.8.7",
|
||||
"@microlink/react-json-view": "^1.26.2",
|
||||
"@monaco-editor/react": "^4.7.0-rc.0",
|
||||
"@react-router/node": "^7.11.0",
|
||||
"@react-router/serve": "^7.11.0",
|
||||
"@tailwindcss/vite": "^4.1.18",
|
||||
"@tanstack/react-query": "^5.90.12",
|
||||
"@tanstack/react-query": "^5.90.14",
|
||||
"@uidotdev/usehooks": "^2.4.1",
|
||||
"@xterm/addon-fit": "^0.10.0",
|
||||
"@xterm/xterm": "^5.4.0",
|
||||
"@xterm/addon-fit": "^0.11.0",
|
||||
"@xterm/xterm": "^6.0.0",
|
||||
"axios": "^1.13.2",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
@@ -29,7 +29,7 @@
|
||||
"isbot": "^5.1.32",
|
||||
"lucide-react": "^0.562.0",
|
||||
"monaco-editor": "^0.55.1",
|
||||
"posthog-js": "^1.309.1",
|
||||
"posthog-js": "^1.310.1",
|
||||
"react": "^19.2.3",
|
||||
"react-dom": "^19.2.3",
|
||||
"react-hot-toast": "^2.6.0",
|
||||
@@ -41,7 +41,7 @@
|
||||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"sirv-cli": "^3.0.1",
|
||||
"socket.io-client": "^4.8.1",
|
||||
"socket.io-client": "^4.8.3",
|
||||
"tailwind-merge": "^3.4.0",
|
||||
"tailwind-scrollbar": "^4.0.2",
|
||||
"vite": "^7.3.0",
|
||||
@@ -109,7 +109,7 @@
|
||||
"eslint-plugin-react-hooks": "^4.6.2",
|
||||
"eslint-plugin-unused-imports": "^4.2.0",
|
||||
"husky": "^9.1.7",
|
||||
"jsdom": "^27.3.0",
|
||||
"jsdom": "^27.4.0",
|
||||
"lint-staged": "^16.2.7",
|
||||
"msw": "^2.6.6",
|
||||
"prettier": "^3.7.3",
|
||||
|
||||
102
frontend/src/api/README.md
Normal file
102
frontend/src/api/README.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# API Services Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Services are the abstraction layer between frontend components and backend APIs. They encapsulate HTTP requests using the shared `openHands` axios instance and provide typed methods for each endpoint.
|
||||
|
||||
Each service is a plain object with async methods.
|
||||
|
||||
## Structure
|
||||
|
||||
Each service lives in its own directory:
|
||||
|
||||
```
|
||||
src/api/
|
||||
├── billing-service/
|
||||
│ ├── billing-service.api.ts # Service methods
|
||||
│ └── billing.types.ts # Types and interfaces
|
||||
├── organization-service/
|
||||
│ ├── organization-service.api.ts
|
||||
│ └── organization.types.ts
|
||||
└── open-hands-axios.ts # Shared axios instance
|
||||
```
|
||||
|
||||
## Creating a Service
|
||||
|
||||
Use an object literal with named export. Use object destructuring for parameters to make calls self-documenting.
|
||||
|
||||
```typescript
|
||||
// feature-service/feature-service.api.ts
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import { Feature, CreateFeatureParams } from "./feature.types";
|
||||
|
||||
export const featureService = {
|
||||
getFeature: async ({ id }: { id: string }) => {
|
||||
const { data } = await openHands.get<Feature>(`/api/features/${id}`);
|
||||
return data;
|
||||
},
|
||||
|
||||
createFeature: async ({ name, description }: CreateFeatureParams) => {
|
||||
const { data } = await openHands.post<Feature>("/api/features", {
|
||||
name,
|
||||
description,
|
||||
});
|
||||
return data;
|
||||
},
|
||||
};
|
||||
```
|
||||
|
||||
### Types
|
||||
|
||||
Define types in a separate file within the same directory:
|
||||
|
||||
```typescript
|
||||
// feature-service/feature.types.ts
|
||||
export interface Feature {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface CreateFeatureParams {
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Don't call services directly in components.** Wrap them in TanStack Query hooks.
|
||||
>
|
||||
> Why? TanStack Query provides:
|
||||
> - **Caching** - Avoid redundant network requests
|
||||
> - **Deduplication** - Multiple components requesting the same data share one request
|
||||
> - **Loading/error states** - Built-in `isLoading`, `isError`, `data` states
|
||||
> - **Background refetching** - Data stays fresh automatically
|
||||
>
|
||||
> Hooks location:
|
||||
> - `src/hooks/query/` for data fetching (`useQuery`)
|
||||
> - `src/hooks/mutation/` for writes/updates (`useMutation`)
|
||||
|
||||
```typescript
|
||||
// src/hooks/query/use-feature.ts
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { featureService } from "#/api/feature-service/feature-service.api";
|
||||
|
||||
export const useFeature = (id: string) => {
|
||||
return useQuery({
|
||||
queryKey: ["feature", id],
|
||||
queryFn: () => featureService.getFeature({ id }),
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
## Naming Conventions
|
||||
|
||||
| Item | Convention | Example |
|
||||
|------|------------|---------|
|
||||
| Directory | `feature-service/` | `billing-service/` |
|
||||
| Service file | `feature-service.api.ts` | `billing-service.api.ts` |
|
||||
| Types file | `feature.types.ts` | `billing.types.ts` |
|
||||
| Export name | `featureService` | `billingService` |
|
||||
@@ -298,6 +298,23 @@ class V1ConversationService {
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a V1 conversation's public flag
|
||||
* @param conversationId The conversation ID
|
||||
* @param isPublic Whether the conversation should be public
|
||||
* @returns Updated conversation info
|
||||
*/
|
||||
static async updateConversationPublicFlag(
|
||||
conversationId: string,
|
||||
isPublic: boolean,
|
||||
): Promise<V1AppConversation> {
|
||||
const { data } = await openHands.patch<V1AppConversation>(
|
||||
`/api/v1/app-conversations/${conversationId}`,
|
||||
{ public: isPublic },
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a file from a specific conversation's sandbox workspace
|
||||
* @param conversationId The conversation ID
|
||||
@@ -317,6 +334,21 @@ class V1ConversationService {
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Download a conversation trajectory as a zip file
|
||||
* @param conversationId The conversation ID
|
||||
* @returns A blob containing the zip file
|
||||
*/
|
||||
static async downloadConversation(conversationId: string): Promise<Blob> {
|
||||
const response = await openHands.get(
|
||||
`/api/v1/app-conversations/${conversationId}/download`,
|
||||
{
|
||||
responseType: "blob",
|
||||
},
|
||||
);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all skills associated with a V1 conversation
|
||||
* @param conversationId The conversation ID
|
||||
|
||||
@@ -98,6 +98,7 @@ export interface V1AppConversation {
|
||||
execution_status: V1ConversationExecutionStatus | null;
|
||||
conversation_url: string | null;
|
||||
session_api_key: string | null;
|
||||
public?: boolean;
|
||||
}
|
||||
|
||||
export interface Skill {
|
||||
|
||||
35
frontend/src/api/email-service/email-service.api.ts
Normal file
35
frontend/src/api/email-service/email-service.api.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import {
|
||||
ResendEmailVerificationParams,
|
||||
ResendEmailVerificationResponse,
|
||||
} from "./email.types";
|
||||
|
||||
/**
|
||||
* Email Service API - Handles all email-related API endpoints
|
||||
*/
|
||||
export const emailService = {
|
||||
/**
|
||||
* Resend email verification to the user's registered email address
|
||||
* @param userId - Optional user ID to send verification email for
|
||||
* @param isAuthFlow - Whether this is part of the authentication flow
|
||||
* @returns The response message indicating the email was sent
|
||||
*/
|
||||
resendEmailVerification: async ({
|
||||
userId,
|
||||
isAuthFlow,
|
||||
}: ResendEmailVerificationParams): Promise<ResendEmailVerificationResponse> => {
|
||||
const body: { user_id?: string; is_auth_flow?: boolean } = {};
|
||||
if (userId) {
|
||||
body.user_id = userId;
|
||||
}
|
||||
if (isAuthFlow !== undefined) {
|
||||
body.is_auth_flow = isAuthFlow;
|
||||
}
|
||||
const { data } = await openHands.put<ResendEmailVerificationResponse>(
|
||||
"/api/email/resend",
|
||||
body,
|
||||
{ withCredentials: true },
|
||||
);
|
||||
return data;
|
||||
},
|
||||
};
|
||||
8
frontend/src/api/email-service/email.types.ts
Normal file
8
frontend/src/api/email-service/email.types.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
export interface ResendEmailVerificationParams {
|
||||
userId?: string | null;
|
||||
isAuthFlow?: boolean;
|
||||
}
|
||||
|
||||
export interface ResendEmailVerificationResponse {
|
||||
message: string;
|
||||
}
|
||||
@@ -131,9 +131,18 @@ class GitService {
|
||||
repository: string,
|
||||
page: number = 1,
|
||||
perPage: number = 30,
|
||||
selectedProvider?: Provider,
|
||||
): Promise<PaginatedBranchesResponse> {
|
||||
const { data } = await openHands.get<PaginatedBranchesResponse>(
|
||||
`/api/user/repository/branches?repository=${encodeURIComponent(repository)}&page=${page}&per_page=${perPage}`,
|
||||
`/api/user/repository/branches`,
|
||||
{
|
||||
params: {
|
||||
repository,
|
||||
page,
|
||||
per_page: perPage,
|
||||
selected_provider: selectedProvider,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
return data;
|
||||
|
||||
@@ -53,6 +53,13 @@ class V1GitService {
|
||||
// V1 API returns V1GitChangeStatus types, we need to map them to V0 format
|
||||
const { data } = await axios.get<V1GitChange[]>(url, { headers });
|
||||
|
||||
// Validate response is an array (could be HTML error page if runtime is dead)
|
||||
if (!Array.isArray(data)) {
|
||||
throw new Error(
|
||||
"Invalid response from runtime - runtime may be unavailable",
|
||||
);
|
||||
}
|
||||
|
||||
// Map V1 statuses to V0 format for compatibility
|
||||
return data.map((change) => ({
|
||||
status: mapV1ToV0Status(change.status),
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import { Provider } from "#/types/settings";
|
||||
|
||||
export type GitHubItemType = "issue" | "pr";
|
||||
|
||||
export type GitHubItemStatus =
|
||||
| "MERGE_CONFLICTS"
|
||||
| "FAILING_CHECKS"
|
||||
| "UNRESOLVED_COMMENTS"
|
||||
| "OPEN_ISSUE"
|
||||
| "OPEN_PR";
|
||||
|
||||
export interface GitHubItem {
|
||||
git_provider: Provider;
|
||||
item_type: GitHubItemType;
|
||||
status: GitHubItemStatus;
|
||||
repo: string;
|
||||
number: number;
|
||||
title: string;
|
||||
author: string;
|
||||
assignees: string[];
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export interface GitHubItemsFilter {
|
||||
itemType?: "issues" | "prs" | "all";
|
||||
assignedToMe?: boolean;
|
||||
authoredByMe?: boolean;
|
||||
}
|
||||
|
||||
export interface GitHubItemsResponse {
|
||||
items: GitHubItem[];
|
||||
cached_at?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* GitHub Issues/PRs Service - Handles fetching GitHub issues and pull requests
|
||||
*/
|
||||
class GitHubIssuesPRsService {
|
||||
/**
|
||||
* Get GitHub issues and PRs for the authenticated user
|
||||
* This uses the existing suggested-tasks endpoint and transforms the data
|
||||
*/
|
||||
static async getGitHubItems(
|
||||
filter?: GitHubItemsFilter,
|
||||
): Promise<GitHubItemsResponse> {
|
||||
const { data } = await openHands.get<
|
||||
Array<{
|
||||
git_provider: Provider;
|
||||
task_type: GitHubItemStatus;
|
||||
repo: string;
|
||||
issue_number: number;
|
||||
title: string;
|
||||
}>
|
||||
>("/api/user/suggested-tasks");
|
||||
|
||||
// Transform the suggested tasks into GitHubItems
|
||||
const items: GitHubItem[] = data.map((task) => ({
|
||||
git_provider: task.git_provider,
|
||||
item_type: task.task_type === "OPEN_ISSUE" ? "issue" : "pr",
|
||||
status: task.task_type,
|
||||
repo: task.repo,
|
||||
number: task.issue_number,
|
||||
title: task.title,
|
||||
author: "", // Not available from suggested-tasks endpoint
|
||||
assignees: [], // Not available from suggested-tasks endpoint
|
||||
created_at: new Date().toISOString(), // Not available from suggested-tasks endpoint
|
||||
updated_at: new Date().toISOString(), // Not available from suggested-tasks endpoint
|
||||
url: GitHubIssuesPRsService.buildItemUrl(
|
||||
task.git_provider,
|
||||
task.repo,
|
||||
task.issue_number,
|
||||
task.task_type === "OPEN_ISSUE" ? "issue" : "pr",
|
||||
),
|
||||
}));
|
||||
|
||||
// Apply filters
|
||||
let filteredItems = items;
|
||||
|
||||
if (filter?.itemType === "issues") {
|
||||
filteredItems = filteredItems.filter(
|
||||
(item) => item.item_type === "issue",
|
||||
);
|
||||
} else if (filter?.itemType === "prs") {
|
||||
filteredItems = filteredItems.filter((item) => item.item_type === "pr");
|
||||
}
|
||||
|
||||
// Note: assignedToMe and authoredByMe filters would require additional API data
|
||||
// For now, the suggested-tasks endpoint already returns:
|
||||
// - PRs authored by the user
|
||||
// - Issues assigned to the user
|
||||
// So these filters are implicitly applied by the backend
|
||||
|
||||
return {
|
||||
items: filteredItems,
|
||||
cached_at: new Date().toISOString(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the URL for a GitHub item
|
||||
*/
|
||||
static buildItemUrl(
|
||||
provider: Provider,
|
||||
repo: string,
|
||||
number: number,
|
||||
itemType: GitHubItemType,
|
||||
): string {
|
||||
if (provider === "github") {
|
||||
return `https://github.com/${repo}/${itemType === "issue" ? "issues" : "pull"}/${number}`;
|
||||
}
|
||||
if (provider === "gitlab") {
|
||||
return `https://gitlab.com/${repo}/-/${itemType === "issue" ? "issues" : "merge_requests"}/${number}`;
|
||||
}
|
||||
if (provider === "bitbucket") {
|
||||
return `https://bitbucket.org/${repo}/${itemType === "issue" ? "issues" : "pull-requests"}/${number}`;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
export default GitHubIssuesPRsService;
|
||||
@@ -78,6 +78,7 @@ export interface Conversation {
|
||||
pr_number?: number[] | null;
|
||||
conversation_version?: "V0" | "V1";
|
||||
sub_conversation_ids?: string[];
|
||||
public?: boolean;
|
||||
}
|
||||
|
||||
export interface ResultSet<T> {
|
||||
|
||||
60
frontend/src/api/shared-conversation-service.api.ts
Normal file
60
frontend/src/api/shared-conversation-service.api.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
import { OpenHandsEvent } from "#/types/v1/core";
|
||||
import { openHands } from "./open-hands-axios";
|
||||
|
||||
export interface SharedConversation {
|
||||
id: string;
|
||||
created_by_user_id: string | null;
|
||||
sandbox_id: string;
|
||||
selected_repository: string | null;
|
||||
selected_branch: string | null;
|
||||
git_provider: string | null;
|
||||
title: string | null;
|
||||
pr_number: number[];
|
||||
llm_model: string | null;
|
||||
metrics: unknown | null;
|
||||
parent_conversation_id: string | null;
|
||||
sub_conversation_ids: string[];
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export interface EventPage {
|
||||
items: OpenHandsEvent[];
|
||||
next_page_id: string | null;
|
||||
}
|
||||
|
||||
export const sharedConversationService = {
|
||||
/**
|
||||
* Get a single shared conversation by ID
|
||||
*/
|
||||
async getSharedConversation(
|
||||
conversationId: string,
|
||||
): Promise<SharedConversation | null> {
|
||||
const response = await openHands.get<(SharedConversation | null)[]>(
|
||||
"/api/shared-conversations",
|
||||
{ params: { ids: conversationId } },
|
||||
);
|
||||
return response.data[0] || null;
|
||||
},
|
||||
|
||||
/**
|
||||
* Get events for a shared conversation
|
||||
*/
|
||||
async getSharedConversationEvents(
|
||||
conversationId: string,
|
||||
limit: number = 100,
|
||||
pageId?: string,
|
||||
): Promise<EventPage> {
|
||||
const response = await openHands.get<EventPage>(
|
||||
"/api/shared-events/search",
|
||||
{
|
||||
params: {
|
||||
conversation_id: conversationId,
|
||||
limit,
|
||||
...(pageId && { page_id: pageId }),
|
||||
},
|
||||
},
|
||||
);
|
||||
return response.data;
|
||||
},
|
||||
};
|
||||
@@ -11,6 +11,7 @@ interface ConversationCardActionsProps {
|
||||
onStop?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onEdit?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
conversationStatus?: ConversationStatus;
|
||||
conversationId?: string;
|
||||
showOptions?: boolean;
|
||||
@@ -23,6 +24,7 @@ export function ConversationCardActions({
|
||||
onStop,
|
||||
onEdit,
|
||||
onDownloadViaVSCode,
|
||||
onDownloadConversation,
|
||||
conversationStatus,
|
||||
conversationId,
|
||||
showOptions,
|
||||
@@ -62,6 +64,9 @@ export function ConversationCardActions({
|
||||
onDownloadViaVSCode={
|
||||
conversationId && showOptions ? onDownloadViaVSCode : undefined
|
||||
}
|
||||
onDownloadConversation={
|
||||
conversationId ? onDownloadConversation : undefined
|
||||
}
|
||||
position="bottom"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -24,6 +24,7 @@ interface ConversationCardContextMenuProps {
|
||||
onShowAgentTools?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onShowSkills?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
position?: "top" | "bottom";
|
||||
}
|
||||
|
||||
@@ -39,24 +40,27 @@ export function ConversationCardContextMenu({
|
||||
onShowAgentTools,
|
||||
onShowSkills,
|
||||
onDownloadViaVSCode,
|
||||
onDownloadConversation,
|
||||
position = "bottom",
|
||||
}: ConversationCardContextMenuProps) {
|
||||
const { t } = useTranslation();
|
||||
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
|
||||
|
||||
const generateSection = useCallback(
|
||||
(items: React.ReactNode[], isLast?: boolean) => {
|
||||
(items: React.ReactNode[], sectionKey: string, isLast?: boolean) => {
|
||||
const filteredItems = items.filter((i) => i != null);
|
||||
|
||||
if (filteredItems.length > 0) {
|
||||
return !isLast
|
||||
? [
|
||||
...filteredItems,
|
||||
<Divider key="conversation-card-context-menu-divider" />,
|
||||
]
|
||||
: filteredItems;
|
||||
return !isLast ? (
|
||||
<React.Fragment key={sectionKey}>
|
||||
{filteredItems}
|
||||
<Divider />
|
||||
</React.Fragment>
|
||||
) : (
|
||||
<React.Fragment key={sectionKey}>{filteredItems}</React.Fragment>
|
||||
);
|
||||
}
|
||||
return [];
|
||||
return null;
|
||||
},
|
||||
[],
|
||||
);
|
||||
@@ -69,76 +73,104 @@ export function ConversationCardContextMenu({
|
||||
alignment="right"
|
||||
className="mt-0"
|
||||
>
|
||||
{generateSection([
|
||||
onEdit && (
|
||||
<ContextMenuListItem
|
||||
testId="edit-button"
|
||||
onClick={onEdit}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<EditIcon width={16} height={16} />}
|
||||
text={t(I18nKey.BUTTON$RENAME)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
])}
|
||||
{generateSection([
|
||||
onShowAgentTools && (
|
||||
<ContextMenuListItem
|
||||
testId="show-agent-tools-button"
|
||||
onClick={onShowAgentTools}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<ToolsIcon width={16} height={16} />}
|
||||
text={t(I18nKey.BUTTON$SHOW_AGENT_TOOLS_AND_METADATA)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
onShowSkills && (
|
||||
<ContextMenuListItem
|
||||
testId="show-skills-button"
|
||||
onClick={onShowSkills}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<RobotIcon width={16} height={16} />}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_SKILLS)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
])}
|
||||
{generateSection([
|
||||
onStop && (
|
||||
<ContextMenuListItem
|
||||
testId="stop-button"
|
||||
onClick={onStop}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<CloseIcon width={16} height={16} />}
|
||||
text={t(I18nKey.COMMON$CLOSE_CONVERSATION_STOP_RUNTIME)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
onDownloadViaVSCode && (
|
||||
<ContextMenuListItem
|
||||
testId="download-vscode-button"
|
||||
onClick={onDownloadViaVSCode}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<DownloadIcon width={16} height={16} />}
|
||||
text={t(I18nKey.BUTTON$DOWNLOAD_VIA_VSCODE)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
])}
|
||||
{generateSection(
|
||||
[
|
||||
onEdit && (
|
||||
<ContextMenuListItem
|
||||
key="edit-button"
|
||||
testId="edit-button"
|
||||
onClick={onEdit}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<EditIcon width={16} height={16} />}
|
||||
text={t(I18nKey.BUTTON$RENAME)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
],
|
||||
"edit-section",
|
||||
)}
|
||||
{generateSection(
|
||||
[
|
||||
onShowAgentTools && (
|
||||
<ContextMenuListItem
|
||||
key="show-agent-tools-button"
|
||||
testId="show-agent-tools-button"
|
||||
onClick={onShowAgentTools}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<ToolsIcon width={16} height={16} />}
|
||||
text={t(I18nKey.BUTTON$SHOW_AGENT_TOOLS_AND_METADATA)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
onShowSkills && (
|
||||
<ContextMenuListItem
|
||||
key="show-skills-button"
|
||||
testId="show-skills-button"
|
||||
onClick={onShowSkills}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<RobotIcon width={16} height={16} />}
|
||||
text={t(I18nKey.CONVERSATION$SHOW_SKILLS)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
],
|
||||
"tools-section",
|
||||
)}
|
||||
{generateSection(
|
||||
[
|
||||
onStop && (
|
||||
<ContextMenuListItem
|
||||
key="stop-button"
|
||||
testId="stop-button"
|
||||
onClick={onStop}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<CloseIcon width={16} height={16} />}
|
||||
text={t(I18nKey.COMMON$CLOSE_CONVERSATION_STOP_RUNTIME)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
onDownloadViaVSCode && (
|
||||
<ContextMenuListItem
|
||||
key="download-vscode-button"
|
||||
testId="download-vscode-button"
|
||||
onClick={onDownloadViaVSCode}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<DownloadIcon width={16} height={16} />}
|
||||
text={t(I18nKey.BUTTON$DOWNLOAD_VIA_VSCODE)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
onDownloadConversation && (
|
||||
<ContextMenuListItem
|
||||
key="download-trajectory-button"
|
||||
testId="download-trajectory-button"
|
||||
onClick={onDownloadConversation}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<DownloadIcon width={16} height={16} />}
|
||||
text={t(I18nKey.BUTTON$EXPORT_CONVERSATION)}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
],
|
||||
"control-section",
|
||||
)}
|
||||
{generateSection(
|
||||
[
|
||||
onDisplayCost && (
|
||||
<ContextMenuListItem
|
||||
key="display-cost-button"
|
||||
testId="display-cost-button"
|
||||
onClick={onDisplayCost}
|
||||
className={contextMenuListItemClassName}
|
||||
@@ -151,6 +183,7 @@ export function ConversationCardContextMenu({
|
||||
),
|
||||
onDelete && (
|
||||
<ContextMenuListItem
|
||||
key="delete-button"
|
||||
testId="delete-button"
|
||||
onClick={onDelete}
|
||||
className={contextMenuListItemClassName}
|
||||
@@ -158,10 +191,11 @@ export function ConversationCardContextMenu({
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<DeleteIcon width={16} height={16} />}
|
||||
text={t(I18nKey.COMMON$DELETE_CONVERSATION)}
|
||||
/>{" "}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
),
|
||||
],
|
||||
"info-section",
|
||||
true,
|
||||
)}
|
||||
</ContextMenu>
|
||||
|
||||
@@ -8,6 +8,7 @@ import { RepositorySelection } from "#/api/open-hands.types";
|
||||
import { ConversationCardHeader } from "./conversation-card-header";
|
||||
import { ConversationCardActions } from "./conversation-card-actions";
|
||||
import { ConversationCardFooter } from "./conversation-card-footer";
|
||||
import { useDownloadConversation } from "#/hooks/use-download-conversation";
|
||||
|
||||
interface ConversationCardProps {
|
||||
onClick?: () => void;
|
||||
@@ -46,6 +47,7 @@ export function ConversationCard({
|
||||
}: ConversationCardProps) {
|
||||
const posthog = usePostHog();
|
||||
const [titleMode, setTitleMode] = React.useState<"view" | "edit">("view");
|
||||
const { mutateAsync: downloadConversation } = useDownloadConversation();
|
||||
|
||||
const onTitleSave = (newTitle: string) => {
|
||||
if (newTitle !== "" && newTitle !== title) {
|
||||
@@ -101,6 +103,18 @@ export function ConversationCard({
|
||||
onContextMenuToggle?.(false);
|
||||
};
|
||||
|
||||
const handleDownloadConversation = async (
|
||||
event: React.MouseEvent<HTMLButtonElement>,
|
||||
) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
|
||||
if (conversationId && conversationVersion === "V1") {
|
||||
await downloadConversation(conversationId);
|
||||
}
|
||||
onContextMenuToggle?.(false);
|
||||
};
|
||||
|
||||
const hasContextMenu = !!(onDelete || onChangeTitle || showOptions);
|
||||
|
||||
return (
|
||||
@@ -130,6 +144,11 @@ export function ConversationCard({
|
||||
onStop={onStop && handleStop}
|
||||
onEdit={onChangeTitle && handleEdit}
|
||||
onDownloadViaVSCode={handleDownloadViaVSCode}
|
||||
onDownloadConversation={
|
||||
conversationVersion === "V1"
|
||||
? handleDownloadConversation
|
||||
: undefined
|
||||
}
|
||||
conversationStatus={conversationStatus}
|
||||
conversationId={conversationId}
|
||||
showOptions={showOptions}
|
||||
|
||||
@@ -7,6 +7,7 @@ import { ContextMenuListItem } from "../context-menu/context-menu-list-item";
|
||||
import { Divider } from "#/ui/divider";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
|
||||
import EditIcon from "#/icons/u-edit.svg?react";
|
||||
import RobotIcon from "#/icons/u-robot.svg?react";
|
||||
@@ -16,6 +17,7 @@ import DownloadIcon from "#/icons/u-download.svg?react";
|
||||
import CreditCardIcon from "#/icons/u-credit-card.svg?react";
|
||||
import CloseIcon from "#/icons/u-close.svg?react";
|
||||
import DeleteIcon from "#/icons/u-delete.svg?react";
|
||||
import LinkIcon from "#/icons/link-external.svg?react";
|
||||
import { ConversationNameContextMenuIconText } from "./conversation-name-context-menu-icon-text";
|
||||
import { CONTEXT_MENU_ICON_TEXT_CLASSNAME } from "#/utils/constants";
|
||||
|
||||
@@ -34,6 +36,9 @@ interface ConversationNameContextMenuProps {
|
||||
onShowSkills?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onExportConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadViaVSCode?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onTogglePublic?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onDownloadConversation?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
onCopyShareLink?: (event: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
position?: "top" | "bottom";
|
||||
}
|
||||
|
||||
@@ -47,6 +52,9 @@ export function ConversationNameContextMenu({
|
||||
onShowSkills,
|
||||
onExportConversation,
|
||||
onDownloadViaVSCode,
|
||||
onTogglePublic,
|
||||
onDownloadConversation,
|
||||
onCopyShareLink,
|
||||
position = "bottom",
|
||||
}: ConversationNameContextMenuProps) {
|
||||
const { width } = useWindowSize();
|
||||
@@ -54,11 +62,17 @@ export function ConversationNameContextMenu({
|
||||
const { t } = useTranslation();
|
||||
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
|
||||
const { data: conversation } = useActiveConversation();
|
||||
const { data: config } = useConfig();
|
||||
|
||||
// This is a temporary measure and may be re-enabled in the future
|
||||
const isV1Conversation = conversation?.conversation_version === "V1";
|
||||
|
||||
const hasDownload = Boolean(onDownloadViaVSCode);
|
||||
// Check if we should show the public sharing option
|
||||
// Only show for V1 conversations in SAAS mode
|
||||
const shouldShowPublicSharing =
|
||||
isV1Conversation && config?.APP_MODE === "saas" && onTogglePublic;
|
||||
|
||||
const hasDownload = Boolean(onDownloadViaVSCode || onDownloadConversation);
|
||||
const hasExport = Boolean(onExportConversation);
|
||||
const hasTools = Boolean(onShowAgentTools || onShowSkills);
|
||||
const hasInfo = Boolean(onDisplayCost);
|
||||
@@ -118,9 +132,9 @@ export function ConversationNameContextMenu({
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{(hasExport || hasDownload) && !isV1Conversation && (
|
||||
{(hasExport || hasDownload) && !isV1Conversation ? (
|
||||
<Divider testId="separator-export" />
|
||||
)}
|
||||
) : null}
|
||||
|
||||
{onExportConversation && !isV1Conversation && (
|
||||
<ContextMenuListItem
|
||||
@@ -150,10 +164,22 @@ export function ConversationNameContextMenu({
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{(hasInfo || hasControl) && !isV1Conversation && (
|
||||
<Divider testId="separator-info-control" />
|
||||
{onDownloadConversation && isV1Conversation && (
|
||||
<ContextMenuListItem
|
||||
testId="download-trajectory-button"
|
||||
onClick={onDownloadConversation}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ConversationNameContextMenuIconText
|
||||
icon={<DownloadIcon width={16} height={16} />}
|
||||
text={t(I18nKey.BUTTON$EXPORT_CONVERSATION)}
|
||||
className={CONTEXT_MENU_ICON_TEXT_CLASSNAME}
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{(hasInfo || hasControl) && <Divider testId="separator-info-control" />}
|
||||
|
||||
{onDisplayCost && (
|
||||
<ContextMenuListItem
|
||||
testId="display-cost-button"
|
||||
@@ -168,6 +194,36 @@ export function ConversationNameContextMenu({
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{shouldShowPublicSharing && (
|
||||
<ContextMenuListItem
|
||||
testId="share-publicly-button"
|
||||
onClick={onTogglePublic}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<div className="flex items-center gap-2 justify-between w-full">
|
||||
<div className="flex items-center gap-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={conversation?.public || false}
|
||||
className="w-4 h-4 ml-2"
|
||||
/>
|
||||
<span>{t(I18nKey.CONVERSATION$SHARE_PUBLICLY)}</span>
|
||||
</div>
|
||||
{conversation?.public && onCopyShareLink && (
|
||||
<button
|
||||
type="button"
|
||||
data-testid="copy-share-link-button"
|
||||
onClick={onCopyShareLink}
|
||||
className="p-1 hover:bg-[#717888] rounded"
|
||||
title={t(I18nKey.BUTTON$COPY_TO_CLIPBOARD)}
|
||||
>
|
||||
<LinkIcon width={16} height={16} />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
|
||||
{onStop && (
|
||||
<ContextMenuListItem
|
||||
testId="stop-button"
|
||||
|
||||
@@ -6,6 +6,7 @@ import { useUpdateConversation } from "#/hooks/mutation/use-update-conversation"
|
||||
import { useConversationNameContextMenu } from "#/hooks/use-conversation-name-context-menu";
|
||||
import { displaySuccessToast } from "#/utils/custom-toast-handlers";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { ENABLE_PUBLIC_CONVERSATION_SHARING } from "#/utils/feature-flags";
|
||||
import { EllipsisButton } from "../conversation-panel/ellipsis-button";
|
||||
import { ConversationNameContextMenu } from "./conversation-name-context-menu";
|
||||
import { SystemMessageModal } from "../conversation-panel/system-message-modal";
|
||||
@@ -30,10 +31,13 @@ export function ConversationName() {
|
||||
handleDelete,
|
||||
handleStop,
|
||||
handleDownloadViaVSCode,
|
||||
handleDownloadConversation,
|
||||
handleDisplayCost,
|
||||
handleShowAgentTools,
|
||||
handleShowSkills,
|
||||
handleExportConversation,
|
||||
handleTogglePublic,
|
||||
handleCopyShareLink,
|
||||
handleConfirmDelete,
|
||||
handleConfirmStop,
|
||||
metricsModalVisible,
|
||||
@@ -50,6 +54,7 @@ export function ConversationName() {
|
||||
shouldShowStop,
|
||||
shouldShowDownload,
|
||||
shouldShowExport,
|
||||
shouldShowDownloadConversation,
|
||||
shouldShowDisplayCost,
|
||||
shouldShowAgentTools,
|
||||
shouldShowSkills,
|
||||
@@ -177,6 +182,21 @@ export function ConversationName() {
|
||||
onDownloadViaVSCode={
|
||||
shouldShowDownload ? handleDownloadViaVSCode : undefined
|
||||
}
|
||||
onTogglePublic={
|
||||
ENABLE_PUBLIC_CONVERSATION_SHARING()
|
||||
? handleTogglePublic
|
||||
: undefined
|
||||
}
|
||||
onCopyShareLink={
|
||||
ENABLE_PUBLIC_CONVERSATION_SHARING()
|
||||
? handleCopyShareLink
|
||||
: undefined
|
||||
}
|
||||
onDownloadConversation={
|
||||
shouldShowDownloadConversation
|
||||
? handleDownloadConversation
|
||||
: undefined
|
||||
}
|
||||
position="bottom"
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -79,6 +79,7 @@ export function BranchDropdownMenu({
|
||||
menuRef={menuRef}
|
||||
renderItem={renderItem}
|
||||
renderEmptyState={renderEmptyState}
|
||||
itemKey={(branch) => branch.name}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -211,6 +211,7 @@ export function GitProviderDropdown({
|
||||
getItemProps={getItemProps}
|
||||
renderItem={renderItem}
|
||||
renderEmptyState={renderEmptyState}
|
||||
itemKey={(provider) => provider}
|
||||
/>
|
||||
|
||||
<ErrorMessage isError={!!errorMessage} message={errorMessage} />
|
||||
|
||||
@@ -369,6 +369,7 @@ export function GitRepoDropdown({
|
||||
stickyFooterItem={stickyFooterItem}
|
||||
testId="git-repo-dropdown-menu"
|
||||
numberOfRecentItems={recentRepositories.length}
|
||||
itemKey={(repo) => repo.id}
|
||||
/>
|
||||
|
||||
<ErrorMessage isError={isError} />
|
||||
|
||||
@@ -39,6 +39,7 @@ export function ConversationStatusIndicator({
|
||||
ariaLabel={statusLabel}
|
||||
placement="right"
|
||||
showArrow
|
||||
asSpan
|
||||
className="p-0 border-0 bg-transparent hover:opacity-100"
|
||||
tooltipClassName="bg-[#1a1a1a] text-white text-xs shadow-lg"
|
||||
>
|
||||
|
||||
@@ -20,63 +20,59 @@ export function RecentConversation({ conversation }: RecentConversationProps) {
|
||||
conversation.selected_repository && conversation.selected_branch;
|
||||
|
||||
return (
|
||||
<Link to={`/conversations/${conversation.conversation_id}`}>
|
||||
<button
|
||||
type="button"
|
||||
className="flex flex-col gap-1 p-[14px] cursor-pointer w-full rounded-lg hover:bg-[#5C5D62] transition-all duration-300 text-left"
|
||||
>
|
||||
<div className="flex items-center gap-2 pl-1">
|
||||
<ConversationStatusIndicator
|
||||
conversationStatus={conversation.status}
|
||||
/>
|
||||
<span className="text-xs text-white leading-6 font-normal">
|
||||
{conversation.title}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between text-xs text-[#A3A3A3] leading-4 font-normal">
|
||||
<div className="flex items-center gap-3">
|
||||
{hasRepository ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<GitProviderIcon
|
||||
gitProvider={conversation.git_provider as Provider}
|
||||
/>
|
||||
<span
|
||||
className="max-w-[124px] truncate"
|
||||
title={conversation.selected_repository || ""}
|
||||
>
|
||||
{conversation.selected_repository}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center gap-1">
|
||||
<RepoForkedIcon width={12} height={12} color="#A3A3A3" />
|
||||
<span className="max-w-[124px] truncate">
|
||||
{t(I18nKey.COMMON$NO_REPOSITORY)}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{hasRepository ? (
|
||||
<div className="flex items-center gap-1">
|
||||
<CodeBranchIcon width={12} height={12} color="#A3A3A3" />
|
||||
<span
|
||||
className="max-w-[124px] truncate"
|
||||
title={conversation.selected_branch || ""}
|
||||
>
|
||||
{conversation.selected_branch}
|
||||
</span>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
{(conversation.created_at || conversation.last_updated_at) && (
|
||||
<span>
|
||||
{formatTimeDelta(
|
||||
conversation.created_at || conversation.last_updated_at,
|
||||
)}{" "}
|
||||
{t(I18nKey.CONVERSATION$AGO)}
|
||||
</span>
|
||||
<Link
|
||||
to={`/conversations/${conversation.conversation_id}`}
|
||||
className="flex flex-col gap-1 p-[14px] cursor-pointer w-full rounded-lg hover:bg-[#5C5D62] transition-all duration-300 text-left"
|
||||
>
|
||||
<div className="flex items-center gap-2 pl-1">
|
||||
<ConversationStatusIndicator conversationStatus={conversation.status} />
|
||||
<span className="text-xs text-white leading-6 font-normal">
|
||||
{conversation.title}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between text-xs text-[#A3A3A3] leading-4 font-normal">
|
||||
<div className="flex items-center gap-3">
|
||||
{hasRepository ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<GitProviderIcon
|
||||
gitProvider={conversation.git_provider as Provider}
|
||||
/>
|
||||
<span
|
||||
className="max-w-[124px] truncate"
|
||||
title={conversation.selected_repository || ""}
|
||||
>
|
||||
{conversation.selected_repository}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center gap-1">
|
||||
<RepoForkedIcon width={12} height={12} color="#A3A3A3" />
|
||||
<span className="max-w-[124px] truncate">
|
||||
{t(I18nKey.COMMON$NO_REPOSITORY)}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{hasRepository ? (
|
||||
<div className="flex items-center gap-1">
|
||||
<CodeBranchIcon width={12} height={12} color="#A3A3A3" />
|
||||
<span
|
||||
className="max-w-[124px] truncate"
|
||||
title={conversation.selected_branch || ""}
|
||||
>
|
||||
{conversation.selected_branch}
|
||||
</span>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
</button>
|
||||
{(conversation.created_at || conversation.last_updated_at) && (
|
||||
<span>
|
||||
{formatTimeDelta(
|
||||
conversation.created_at || conversation.last_updated_at,
|
||||
)}{" "}
|
||||
{t(I18nKey.CONVERSATION$AGO)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</Link>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ export interface GenericDropdownMenuProps<T> {
|
||||
stickyFooterItem?: React.ReactNode;
|
||||
testId?: string;
|
||||
numberOfRecentItems?: number;
|
||||
itemKey: (item: T) => string | number;
|
||||
}
|
||||
|
||||
export function GenericDropdownMenu<T>({
|
||||
@@ -51,12 +52,28 @@ export function GenericDropdownMenu<T>({
|
||||
stickyFooterItem,
|
||||
testId,
|
||||
numberOfRecentItems = 0,
|
||||
itemKey,
|
||||
}: GenericDropdownMenuProps<T>) {
|
||||
if (!isOpen) return null;
|
||||
|
||||
const hasItems = filteredItems.length > 0;
|
||||
const showEmptyState = !hasItems && !stickyTopItem && !stickyFooterItem;
|
||||
|
||||
// Always render the menu container (even when closed) so getMenuProps is always called
|
||||
// This prevents the downshift warning about forgetting to call getMenuProps
|
||||
if (!isOpen) {
|
||||
return (
|
||||
<div className="relative">
|
||||
<ul
|
||||
// eslint-disable-next-line react/jsx-props-no-spreading
|
||||
{...getMenuProps({
|
||||
ref: menuRef,
|
||||
className: "hidden",
|
||||
"data-testid": testId,
|
||||
})}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
<div
|
||||
@@ -85,21 +102,24 @@ export function GenericDropdownMenu<T>({
|
||||
) : (
|
||||
<>
|
||||
{stickyTopItem}
|
||||
{filteredItems.map((item, index) => (
|
||||
<>
|
||||
{renderItem(
|
||||
item,
|
||||
index,
|
||||
highlightedIndex,
|
||||
selectedItem,
|
||||
getItemProps,
|
||||
)}
|
||||
{numberOfRecentItems > 0 &&
|
||||
index === numberOfRecentItems - 1 && (
|
||||
<div className="border-b border-[#727987] bg-[#454545] pb-1 mb-1 h-[1px]" />
|
||||
{filteredItems.map((item, index) => {
|
||||
const key = itemKey(item);
|
||||
return (
|
||||
<React.Fragment key={key}>
|
||||
{renderItem(
|
||||
item,
|
||||
index,
|
||||
highlightedIndex,
|
||||
selectedItem,
|
||||
getItemProps,
|
||||
)}
|
||||
</>
|
||||
))}
|
||||
{numberOfRecentItems > 0 &&
|
||||
index === numberOfRecentItems - 1 && (
|
||||
<div className="border-b border-[#727987] bg-[#454545] pb-1 mb-1 h-[1px]" />
|
||||
)}
|
||||
</React.Fragment>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
</ul>
|
||||
|
||||
@@ -17,9 +17,10 @@ export function MicroagentManagementAccordionTitle({
|
||||
<TooltipButton
|
||||
tooltip={repository.full_name}
|
||||
ariaLabel={repository.full_name}
|
||||
className="text-white text-base font-normal bg-transparent p-0 min-w-0 h-auto cursor-pointer truncate max-w-[194px] translate-y-[-1px]"
|
||||
testId="repository-name-tooltip"
|
||||
placement="bottom"
|
||||
asSpan
|
||||
className="text-white text-base font-normal bg-transparent p-0 min-w-0 h-auto cursor-pointer truncate max-w-[194px] translate-y-[-1px]"
|
||||
>
|
||||
<span>{repository.full_name}</span>
|
||||
</TooltipButton>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user