mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-06 21:44:00 -05:00
Merge branch 'main' into reuse-running-sandboxes
This commit is contained in:
12
.github/CODEOWNERS
vendored
12
.github/CODEOWNERS
vendored
@@ -1,12 +1,8 @@
|
||||
# CODEOWNERS file for OpenHands repository
|
||||
# See https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
# Frontend code owners
|
||||
/frontend/ @amanape
|
||||
/openhands-ui/ @amanape
|
||||
|
||||
# Evaluation code owners
|
||||
/frontend/ @amanape @hieptl
|
||||
/openhands-ui/ @amanape @hieptl
|
||||
/openhands/ @tofarr @malhotra5 @hieptl
|
||||
/enterprise/ @chuckbutkus @tofarr @malhotra5
|
||||
/evaluation/ @xingyaoww @neubig
|
||||
|
||||
# Documentation code owners
|
||||
/docs/ @mamoodi
|
||||
|
||||
2
.github/workflows/check-package-versions.yml
vendored
2
.github/workflows/check-package-versions.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
|
||||
8
.github/workflows/e2e-tests.yml
vendored
8
.github/workflows/e2e-tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
poetry-version: 2.1.3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'poetry'
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
sudo apt-get install -y libgtk-3-0 libnotify4 libnss3 libxss1 libxtst6 xauth xvfb libgbm1 libasound2t64 netcat-openbsd
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: '22'
|
||||
cache: 'npm'
|
||||
@@ -192,7 +192,7 @@ jobs:
|
||||
|
||||
- name: Upload test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: playwright-report
|
||||
path: tests/e2e/test-results/
|
||||
@@ -200,7 +200,7 @@ jobs:
|
||||
|
||||
- name: Upload OpenHands logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: openhands-logs
|
||||
path: |
|
||||
|
||||
2
.github/workflows/fe-e2e-tests.yml
vendored
2
.github/workflows/fe-e2e-tests.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
working-directory: ./frontend
|
||||
run: npx playwright test --project=chromium
|
||||
- name: Upload Playwright report
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
|
||||
6
.github/workflows/ghcr-build.yml
vendored
6
.github/workflows/ghcr-build.yml
vendored
@@ -161,7 +161,7 @@ jobs:
|
||||
context: containers/runtime
|
||||
- name: Upload runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
@@ -268,7 +268,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Download runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
@@ -330,7 +330,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Download runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
|
||||
4
.github/workflows/openhands-resolver.yml
vendored
4
.github/workflows/openhands-resolver.yml
vendored
@@ -89,7 +89,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- name: Upgrade pip
|
||||
@@ -269,7 +269,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload output.jsonl as artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
if: always() # Upload even if the previous steps fail
|
||||
with:
|
||||
name: resolver-output
|
||||
|
||||
6
.github/workflows/py-tests.yml
vendored
6
.github/workflows/py-tests.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.runtime.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: coverage-openhands
|
||||
path: |
|
||||
@@ -95,7 +95,7 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.enterprise.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: coverage-enterprise
|
||||
path: ".coverage.enterprise.${{ matrix.python_version }}"
|
||||
@@ -113,7 +113,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/download-artifact@v5
|
||||
- uses: actions/download-artifact@v6
|
||||
id: download
|
||||
with:
|
||||
pattern: coverage-*
|
||||
|
||||
6
.github/workflows/vscode-extension-build.yml
vendored
6
.github/workflows/vscode-extension-build.yml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
node-version: '22'
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload VSCode extension artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: vscode-extension
|
||||
path: openhands/integrations/vscode/openhands-vscode-0.0.1.vsix
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Download .vsix artifact
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: vscode-extension
|
||||
path: ./
|
||||
|
||||
@@ -63,7 +63,7 @@ Frontend:
|
||||
- We use TanStack Query (fka React Query) for data fetching and cache management
|
||||
- Data Access Layer: API client methods are located in `frontend/src/api` and should never be called directly from UI components - they must always be wrapped with TanStack Query
|
||||
- Custom hooks are located in `frontend/src/hooks/query/` and `frontend/src/hooks/mutation/`
|
||||
- Query hooks should follow the pattern use[Resource] (e.g., `useConversationMicroagents`)
|
||||
- Query hooks should follow the pattern use[Resource] (e.g., `useConversationSkills`)
|
||||
- Mutation hooks should follow the pattern use[Action] (e.g., `useDeleteConversation`)
|
||||
- Architecture rule: UI components → TanStack Query hooks → Data Access Layer (`frontend/src/api`) → API endpoints
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:0.62-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.0-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
<div align="center">
|
||||
<a href="https://github.com/OpenHands/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/badge/LICENSE-MIT-20B2AA?style=for-the-badge" alt="MIT License"></a>
|
||||
<a href="https://docs.google.com/spreadsheets/d/1wOUdFCMyY6Nt0AIqF705KN4JKOWgeI4wUGUP60krXXs/edit?gid=811504672#gid=811504672"><img src="https://img.shields.io/badge/SWEBench-72.8-00cc00?logoColor=FFE165&style=for-the-badge" alt="Benchmark Score"></a>
|
||||
<a href="https://docs.google.com/spreadsheets/d/1wOUdFCMyY6Nt0AIqF705KN4JKOWgeI4wUGUP60krXXs/edit?gid=811504672#gid=811504672"><img src="https://img.shields.io/badge/SWEBench-77.6-00cc00?logoColor=FFE165&style=for-the-badge" alt="Benchmark Score"></a>
|
||||
<br/>
|
||||
<a href="https://docs.openhands.dev/sdk"><img src="https://img.shields.io/badge/Documentation-000?logo=googledocs&logoColor=FFE165&style=for-the-badge" alt="Check out the documentation"></a>
|
||||
<a href="https://arxiv.org/abs/2511.03690"><img src="https://img.shields.io/badge/Paper-000?logoColor=FFE165&logo=arxiv&style=for-the-badge" alt="Tech Report"></a>
|
||||
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:0.62-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.0-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -7,7 +7,7 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:0.62-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.0-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -31,9 +31,8 @@ RUN pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gsprea
|
||||
"pillow>=11.3.0"
|
||||
|
||||
WORKDIR /app
|
||||
COPY enterprise .
|
||||
COPY --chown=openhands:openhands --chmod=770 enterprise .
|
||||
|
||||
RUN chown -R openhands:openhands /app && chmod -R 770 /app
|
||||
USER openhands
|
||||
|
||||
# Command will be overridden by Kubernetes deployment template
|
||||
|
||||
@@ -721,6 +721,7 @@
|
||||
"https://$WEB_HOST/oauth/keycloak/callback",
|
||||
"https://$WEB_HOST/oauth/keycloak/offline/callback",
|
||||
"https://$WEB_HOST/slack/keycloak-callback",
|
||||
"https://$WEB_HOST/oauth/device/keycloak-callback",
|
||||
"https://$WEB_HOST/api/email/verified",
|
||||
"/realms/$KEYCLOAK_REALM_NAME/$KEYCLOAK_CLIENT_ID/*"
|
||||
],
|
||||
|
||||
@@ -116,7 +116,7 @@ lines.append('POSTHOG_CLIENT_KEY=test')
|
||||
lines.append('ENABLE_PROACTIVE_CONVERSATION_STARTERS=true')
|
||||
lines.append('MAX_CONCURRENT_CONVERSATIONS=10')
|
||||
lines.append('LITE_LLM_API_URL=https://llm-proxy.eval.all-hands.dev')
|
||||
lines.append('LITELLM_DEFAULT_MODEL=litellm_proxy/claude-sonnet-4-20250514')
|
||||
lines.append('LITELLM_DEFAULT_MODEL=litellm_proxy/claude-opus-4-5-20251101')
|
||||
lines.append(f'LITE_LLM_API_KEY={lite_llm_api_key}')
|
||||
lines.append('LOCAL_DEPLOYMENT=true')
|
||||
lines.append('DB_HOST=localhost')
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Create device_codes table for OAuth 2.0 Device Flow
|
||||
|
||||
Revision ID: 084
|
||||
Revises: 083
|
||||
Create Date: 2024-12-10 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '084'
|
||||
down_revision = '083'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Create device_codes table for OAuth 2.0 Device Flow."""
|
||||
op.create_table(
|
||||
'device_codes',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('device_code', sa.String(length=128), nullable=False),
|
||||
sa.Column('user_code', sa.String(length=16), nullable=False),
|
||||
sa.Column('status', sa.String(length=32), nullable=False),
|
||||
sa.Column('keycloak_user_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('authorized_at', sa.DateTime(timezone=True), nullable=True),
|
||||
# Rate limiting fields for RFC 8628 section 3.5 compliance
|
||||
sa.Column('last_poll_time', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('current_interval', sa.Integer(), nullable=False, default=5),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
# Create indexes for efficient lookups
|
||||
op.create_index(
|
||||
'ix_device_codes_device_code', 'device_codes', ['device_code'], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
'ix_device_codes_user_code', 'device_codes', ['user_code'], unique=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Drop device_codes table."""
|
||||
op.drop_index('ix_device_codes_user_code', table_name='device_codes')
|
||||
op.drop_index('ix_device_codes_device_code', table_name='device_codes')
|
||||
op.drop_table('device_codes')
|
||||
@@ -0,0 +1,41 @@
|
||||
"""add public column to conversation_metadata
|
||||
|
||||
Revision ID: 085
|
||||
Revises: 084
|
||||
Create Date: 2025-01-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '085'
|
||||
down_revision: Union[str, None] = '084'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.add_column(
|
||||
'conversation_metadata',
|
||||
sa.Column('public', sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_conversation_metadata_public'),
|
||||
'conversation_metadata',
|
||||
['public'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_index(
|
||||
op.f('ix_conversation_metadata_public'),
|
||||
table_name='conversation_metadata',
|
||||
)
|
||||
op.drop_column('conversation_metadata', 'public')
|
||||
82
enterprise/poetry.lock
generated
82
enterprise/poetry.lock
generated
@@ -4558,25 +4558,25 @@ valkey = ["valkey (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.80.7"
|
||||
version = "1.80.11"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "litellm-1.80.7-py3-none-any.whl", hash = "sha256:f7d993f78c1e0e4e1202b2a925cc6540b55b6e5fb055dd342d88b145ab3102ed"},
|
||||
{file = "litellm-1.80.7.tar.gz", hash = "sha256:3977a8d195aef842d01c18bf9e22984829363c6a4b54daf9a43c9dd9f190b42c"},
|
||||
{file = "litellm-1.80.11-py3-none-any.whl", hash = "sha256:406283d66ead77dc7ff0e0b2559c80e9e497d8e7c2257efb1cb9210a20d09d54"},
|
||||
{file = "litellm-1.80.11.tar.gz", hash = "sha256:c9fc63e7acb6360363238fe291bcff1488c59ff66020416d8376c0ee56414a19"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = ">=3.10"
|
||||
click = "*"
|
||||
fastuuid = ">=0.13.0"
|
||||
grpcio = ">=1.62.3,<1.68.0"
|
||||
grpcio = {version = ">=1.62.3,<1.68.0", markers = "python_version < \"3.14\""}
|
||||
httpx = ">=0.23.0"
|
||||
importlib-metadata = ">=6.8.0"
|
||||
jinja2 = ">=3.1.2,<4.0.0"
|
||||
jsonschema = ">=4.22.0,<5.0.0"
|
||||
jsonschema = ">=4.23.0,<5.0.0"
|
||||
openai = ">=2.8.0"
|
||||
pydantic = ">=2.5.0,<3.0.0"
|
||||
python-dotenv = ">=0.2.0"
|
||||
@@ -4587,7 +4587,7 @@ tokenizers = "*"
|
||||
caching = ["diskcache (>=5.6.1,<6.0.0)"]
|
||||
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-iam (>=2.19.1,<3.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "redisvl (>=0.4.1,<0.5.0) ; python_version >= \"3.9\" and python_version < \"3.14\"", "resend (>=0.8.0)"]
|
||||
mlflow = ["mlflow (>3.1.4) ; python_version >= \"3.10\""]
|
||||
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.22)", "litellm-proxy-extras (==0.4.9)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
|
||||
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.27)", "litellm-proxy-extras (==0.4.16)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
|
||||
semantic-router = ["semantic-router (>=0.1.12) ; python_version >= \"3.9\" and python_version < \"3.14\""]
|
||||
utils = ["numpydoc"]
|
||||
|
||||
@@ -4624,14 +4624,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "lmnr"
|
||||
version = "0.7.20"
|
||||
version = "0.7.24"
|
||||
description = "Python SDK for Laminar"
|
||||
optional = false
|
||||
python-versions = "<4,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "lmnr-0.7.20-py3-none-any.whl", hash = "sha256:5f9fa7444e6f96c25e097f66484ff29e632bdd1de0e9346948bf5595f4a8af38"},
|
||||
{file = "lmnr-0.7.20.tar.gz", hash = "sha256:1f484cd618db2d71af65f90a0b8b36d20d80dc91a5138b811575c8677bf7c4fd"},
|
||||
{file = "lmnr-0.7.24-py3-none-any.whl", hash = "sha256:ad780d4a62ece897048811f3368639c240a9329ab31027da8c96545137a3a08a"},
|
||||
{file = "lmnr-0.7.24.tar.gz", hash = "sha256:aa6973f46fc4ba95c9061c1feceb58afc02eb43c9376c21e32545371ff6123d7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4654,14 +4654,15 @@ tqdm = ">=4.0"
|
||||
|
||||
[package.extras]
|
||||
alephalpha = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)"]
|
||||
all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"]
|
||||
all = ["opentelemetry-instrumentation-alephalpha (>=0.47.1)", "opentelemetry-instrumentation-bedrock (>=0.47.1)", "opentelemetry-instrumentation-chromadb (>=0.47.1)", "opentelemetry-instrumentation-cohere (>=0.47.1)", "opentelemetry-instrumentation-crewai (>=0.47.1)", "opentelemetry-instrumentation-haystack (>=0.47.1)", "opentelemetry-instrumentation-lancedb (>=0.47.1)", "opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)", "opentelemetry-instrumentation-llamaindex (>=0.47.1)", "opentelemetry-instrumentation-marqo (>=0.47.1)", "opentelemetry-instrumentation-mcp (>=0.47.1)", "opentelemetry-instrumentation-milvus (>=0.47.1)", "opentelemetry-instrumentation-mistralai (>=0.47.1)", "opentelemetry-instrumentation-ollama (>=0.47.1)", "opentelemetry-instrumentation-pinecone (>=0.47.1)", "opentelemetry-instrumentation-qdrant (>=0.47.1)", "opentelemetry-instrumentation-replicate (>=0.47.1)", "opentelemetry-instrumentation-sagemaker (>=0.47.1)", "opentelemetry-instrumentation-together (>=0.47.1)", "opentelemetry-instrumentation-transformers (>=0.47.1)", "opentelemetry-instrumentation-vertexai (>=0.47.1)", "opentelemetry-instrumentation-watsonx (>=0.47.1)", "opentelemetry-instrumentation-weaviate (>=0.47.1)"]
|
||||
bedrock = ["opentelemetry-instrumentation-bedrock (>=0.47.1)"]
|
||||
chromadb = ["opentelemetry-instrumentation-chromadb (>=0.47.1)"]
|
||||
claude-agent-sdk = ["lmnr-claude-code-proxy (>=0.1.0a5)"]
|
||||
cohere = ["opentelemetry-instrumentation-cohere (>=0.47.1)"]
|
||||
crewai = ["opentelemetry-instrumentation-crewai (>=0.47.1)"]
|
||||
haystack = ["opentelemetry-instrumentation-haystack (>=0.47.1)"]
|
||||
lancedb = ["opentelemetry-instrumentation-lancedb (>=0.47.1)"]
|
||||
langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1)"]
|
||||
langchain = ["opentelemetry-instrumentation-langchain (>=0.47.1,<0.48.0)"]
|
||||
llamaindex = ["opentelemetry-instrumentation-llamaindex (>=0.47.1)"]
|
||||
marqo = ["opentelemetry-instrumentation-marqo (>=0.47.1)"]
|
||||
mcp = ["opentelemetry-instrumentation-mcp (>=0.47.1)"]
|
||||
@@ -5835,13 +5836,15 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.5.2"
|
||||
version = "1.7.1"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = []
|
||||
develop = false
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.7.1-py3-none-any.whl", hash = "sha256:e5c57f1b73293d00a68b77f9d290f59d9e2217d9df844fb01c7d2f929c3417f4"},
|
||||
{file = "openhands_agent_server-1.7.1.tar.gz", hash = "sha256:c82e1e6748ea3b4278ef2ee72f091dc37da6667c854b3aa3c0bc616086a82310"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiosqlite = ">=0.19"
|
||||
@@ -5855,16 +5858,9 @@ uvicorn = ">=0.31.1"
|
||||
websockets = ">=12"
|
||||
wsproto = ">=1.2.0"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/OpenHands/agent-sdk.git"
|
||||
reference = "34fcb39268229948d90bb3f9964b20a736f65777"
|
||||
resolved_reference = "34fcb39268229948d90bb3f9964b20a736f65777"
|
||||
subdirectory = "openhands-agent-server"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-ai"
|
||||
version = "0.0.0-post.5671+d772dd65a"
|
||||
version = "0.0.0-post.5742+ee50f333b"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
optional = false
|
||||
python-versions = "^3.12,<3.14"
|
||||
@@ -5900,15 +5896,15 @@ json-repair = "*"
|
||||
jupyter_kernel_gateway = "*"
|
||||
kubernetes = "^33.1.0"
|
||||
libtmux = ">=0.46.2"
|
||||
litellm = ">=1.74.3, <=1.80.7, !=1.64.4, !=1.67.*"
|
||||
litellm = ">=1.74.3, !=1.64.4, !=1.67.*"
|
||||
lmnr = "^0.7.20"
|
||||
memory-profiler = "^0.61.0"
|
||||
numpy = "*"
|
||||
openai = "2.8.0"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = {git = "https://github.com/OpenHands/agent-sdk.git", rev = "34fcb39268229948d90bb3f9964b20a736f65777", subdirectory = "openhands-agent-server"}
|
||||
openhands-sdk = {git = "https://github.com/OpenHands/agent-sdk.git", rev = "34fcb39268229948d90bb3f9964b20a736f65777", subdirectory = "openhands-sdk"}
|
||||
openhands-tools = {git = "https://github.com/OpenHands/agent-sdk.git", rev = "34fcb39268229948d90bb3f9964b20a736f65777", subdirectory = "openhands-tools"}
|
||||
openhands-agent-server = "1.7.1"
|
||||
openhands-sdk = "1.7.1"
|
||||
openhands-tools = "1.7.1"
|
||||
opentelemetry-api = "^1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = "^1.33.1"
|
||||
pathspec = "^0.12.1"
|
||||
@@ -5964,20 +5960,22 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.5.2"
|
||||
version = "1.7.1"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = []
|
||||
develop = false
|
||||
files = [
|
||||
{file = "openhands_sdk-1.7.1-py3-none-any.whl", hash = "sha256:e097e34dfbd45f38225ae2ff4830702424bcf742bc197b5a811540a75265b135"},
|
||||
{file = "openhands_sdk-1.7.1.tar.gz", hash = "sha256:e13d1fe8bf14dffd91e9080608072a989132c981cf9bfcd124fa4f7a68a13691"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0"
|
||||
fastmcp = ">=2.11.3"
|
||||
httpx = ">=0.27.0"
|
||||
litellm = ">=1.80.7"
|
||||
lmnr = ">=0.7.20"
|
||||
litellm = ">=1.80.10"
|
||||
lmnr = ">=0.7.24"
|
||||
pydantic = ">=2.11.7"
|
||||
python-frontmatter = ">=1.1.0"
|
||||
python-json-logger = ">=3.3.0"
|
||||
@@ -5987,22 +5985,17 @@ websockets = ">=12"
|
||||
[package.extras]
|
||||
boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/OpenHands/agent-sdk.git"
|
||||
reference = "34fcb39268229948d90bb3f9964b20a736f65777"
|
||||
resolved_reference = "34fcb39268229948d90bb3f9964b20a736f65777"
|
||||
subdirectory = "openhands-sdk"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.5.2"
|
||||
version = "1.7.1"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = []
|
||||
develop = false
|
||||
files = [
|
||||
{file = "openhands_tools-1.7.1-py3-none-any.whl", hash = "sha256:e25815f24925e94fbd4d8c3fd9b2147a0556fde595bf4f80a7dbba1014ea3c86"},
|
||||
{file = "openhands_tools-1.7.1.tar.gz", hash = "sha256:f3823f7bd302c78969c454730cf793eb63109ce2d986e78585989c53986cc966"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
bashlex = ">=0.18"
|
||||
@@ -6015,13 +6008,6 @@ openhands-sdk = "*"
|
||||
pydantic = ">=2.11.7"
|
||||
tom-swe = ">=1.0.3"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/OpenHands/agent-sdk.git"
|
||||
reference = "34fcb39268229948d90bb3f9964b20a736f65777"
|
||||
resolved_reference = "34fcb39268229948d90bb3f9964b20a736f65777"
|
||||
subdirectory = "openhands-tools"
|
||||
|
||||
[[package]]
|
||||
name = "openpyxl"
|
||||
version = "3.1.5"
|
||||
|
||||
@@ -34,8 +34,15 @@ from server.routes.integration.jira_dc import jira_dc_integration_router # noqa
|
||||
from server.routes.integration.linear import linear_integration_router # noqa: E402
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -60,10 +67,13 @@ base_app.mount('/internal/metrics', metrics_app())
|
||||
base_app.include_router(readiness_router) # Add routes for readiness checks
|
||||
base_app.include_router(api_router) # Add additional route for github auth
|
||||
base_app.include_router(oauth_router) # Add additional route for oauth callback
|
||||
base_app.include_router(oauth_device_router) # Add OAuth 2.0 Device Flow routes
|
||||
base_app.include_router(saas_user_router) # Add additional route SAAS user calls
|
||||
base_app.include_router(
|
||||
billing_router
|
||||
) # Add routes for credit management and Stripe payment integration
|
||||
base_app.include_router(shared_conversation_router)
|
||||
base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
@@ -97,6 +107,7 @@ base_app.include_router(
|
||||
event_webhook_router
|
||||
) # Add routes for Events in nested runtimes
|
||||
|
||||
|
||||
base_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=PERMITTED_CORS_ORIGINS,
|
||||
|
||||
@@ -38,3 +38,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
||||
'y',
|
||||
'on',
|
||||
)
|
||||
BLOCKED_EMAIL_DOMAINS = [
|
||||
domain.strip().lower()
|
||||
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
|
||||
if domain.strip()
|
||||
]
|
||||
|
||||
56
enterprise/server/auth/domain_blocker.py
Normal file
56
enterprise/server/auth/domain_blocker.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class DomainBlocker:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing DomainBlocker')
|
||||
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
|
||||
if self.blocked_domains:
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
|
||||
)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if domain blocking is enabled"""
|
||||
return bool(self.blocked_domains)
|
||||
|
||||
def _extract_domain(self, email: str) -> str | None:
|
||||
"""Extract and normalize email domain from email address"""
|
||||
if not email:
|
||||
return None
|
||||
try:
|
||||
# Extract domain part after @
|
||||
if '@' not in email:
|
||||
return None
|
||||
domain = email.split('@')[1].strip().lower()
|
||||
return domain if domain else None
|
||||
except Exception:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked"""
|
||||
if not self.is_active():
|
||||
return False
|
||||
|
||||
if not email:
|
||||
logger.debug('No email provided for domain check')
|
||||
return False
|
||||
|
||||
domain = self._extract_domain(email)
|
||||
if not domain:
|
||||
logger.debug(f'Could not extract domain from email: {email}')
|
||||
return False
|
||||
|
||||
is_blocked = domain in self.blocked_domains
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
else:
|
||||
logger.debug(f'Email domain {domain} is not blocked')
|
||||
|
||||
return is_blocked
|
||||
|
||||
|
||||
domain_blocker = DomainBlocker()
|
||||
109
enterprise/server/auth/email_validation.py
Normal file
109
enterprise/server/auth/email_validation.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Email validation utilities for preventing duplicate signups with + modifier."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def extract_base_email(email: str) -> str | None:
|
||||
"""Extract base email from an email address.
|
||||
|
||||
For emails with + modifier, extracts the base email (local part before + and @, plus domain).
|
||||
For emails without + modifier, returns the email as-is.
|
||||
|
||||
Examples:
|
||||
extract_base_email("joe+test@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe@example.com") -> "joe@example.com"
|
||||
extract_base_email("joe+openhands+test@example.com") -> "joe@example.com"
|
||||
|
||||
Args:
|
||||
email: The email address to process
|
||||
|
||||
Returns:
|
||||
The base email address, or None if email format is invalid
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = email.rsplit('@', 1)
|
||||
# Extract the part before + if it exists
|
||||
base_local = local_part.split('+', 1)[0]
|
||||
return f'{base_local}@{domain}'
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def has_plus_modifier(email: str) -> bool:
|
||||
"""Check if an email address contains a + modifier.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if email contains + before @, False otherwise
|
||||
"""
|
||||
if not email or '@' not in email:
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, _ = email.rsplit('@', 1)
|
||||
return '+' in local_part
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
def matches_base_email(email: str, base_email: str) -> bool:
|
||||
"""Check if an email matches a base email pattern.
|
||||
|
||||
An email matches if:
|
||||
- It is exactly the base email (e.g., joe@example.com)
|
||||
- It has the same base local part and domain, with or without + modifier
|
||||
(e.g., joe+test@example.com matches base joe@example.com)
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
base_email: The base email to match against
|
||||
|
||||
Returns:
|
||||
True if email matches the base pattern, False otherwise
|
||||
"""
|
||||
if not email or not base_email:
|
||||
return False
|
||||
|
||||
# Extract base from both emails for comparison
|
||||
email_base = extract_base_email(email)
|
||||
base_email_normalized = extract_base_email(base_email)
|
||||
|
||||
if not email_base or not base_email_normalized:
|
||||
return False
|
||||
|
||||
# Emails match if they have the same base
|
||||
return email_base.lower() == base_email_normalized.lower()
|
||||
|
||||
|
||||
def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None:
|
||||
"""Generate a regex pattern to match emails with the same base.
|
||||
|
||||
For base_email "joe@example.com", the pattern will match:
|
||||
- joe@example.com
|
||||
- joe+anything@example.com
|
||||
|
||||
Args:
|
||||
base_email: The base email address
|
||||
|
||||
Returns:
|
||||
A compiled regex pattern, or None if base_email is invalid
|
||||
"""
|
||||
base = extract_base_email(base_email)
|
||||
if not base:
|
||||
return None
|
||||
|
||||
try:
|
||||
local_part, domain = base.rsplit('@', 1)
|
||||
# Escape special regex characters in local part and domain
|
||||
escaped_local = re.escape(local_part)
|
||||
escaped_domain = re.escape(domain)
|
||||
# Pattern: joe@example.com OR joe+anything@example.com
|
||||
pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$'
|
||||
return re.compile(pattern, re.IGNORECASE)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
@@ -13,6 +13,7 @@ from server.auth.auth_error import (
|
||||
ExpiredError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@@ -153,8 +154,10 @@ class SaasUserAuth(UserAuth):
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
with session_maker() as session:
|
||||
tokens = session.query(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == self.user_id
|
||||
tokens = (
|
||||
session.query(AuthTokens)
|
||||
.where(AuthTokens.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
for token in tokens:
|
||||
@@ -312,6 +315,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
user_id = access_token_payload['sub']
|
||||
email = access_token_payload['email']
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
raise AuthError(
|
||||
'Access denied: Your email domain is not allowed to access this service'
|
||||
)
|
||||
|
||||
logger.debug('saas_user_auth_from_signed_token:return')
|
||||
|
||||
return SaasUserAuth(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
@@ -25,6 +26,11 @@ from server.auth.constants import (
|
||||
KEYCLOAK_SERVER_URL,
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
)
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
matches_base_email,
|
||||
)
|
||||
from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
@@ -509,6 +515,183 @@ class TokenManager:
|
||||
logger.info(f'Got user ID {keycloak_user_id} from email: {email}')
|
||||
return keycloak_user_id
|
||||
|
||||
async def _query_users_by_wildcard_pattern(
|
||||
self, local_part: str, domain: str
|
||||
) -> dict[str, dict]:
|
||||
"""Query Keycloak for users matching a wildcard email pattern.
|
||||
|
||||
Tries multiple query methods to find users with emails matching
|
||||
the pattern {local_part}*@{domain}. This catches the base email
|
||||
and all + modifier variants.
|
||||
|
||||
Args:
|
||||
local_part: The local part of the email (before @)
|
||||
domain: The domain part of the email (after @)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping user IDs to user objects
|
||||
"""
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
all_users = {}
|
||||
|
||||
# Query for users with emails matching the base pattern using wildcard
|
||||
# Pattern: {local_part}*@{domain} - catches base email and all + variants
|
||||
# This may also catch unintended matches (e.g., joesmith@example.com), but
|
||||
# they will be filtered out by the regex pattern check later
|
||||
# Use 'search' parameter for Keycloak 26+ (better wildcard support)
|
||||
wildcard_queries = [
|
||||
{'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first
|
||||
{'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter
|
||||
]
|
||||
|
||||
for query_params in wildcard_queries:
|
||||
try:
|
||||
users = await keycloak_admin.a_get_users(query_params)
|
||||
for user in users:
|
||||
all_users[user.get('id')] = user
|
||||
break # Success, no need to try fallback
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f'Wildcard query failed with {list(query_params.keys())[0]}: {e}'
|
||||
)
|
||||
continue # Try next query method
|
||||
|
||||
return all_users
|
||||
|
||||
def _find_duplicate_in_users(
|
||||
self, users: dict[str, dict], base_email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if any user in the provided list matches the base email pattern.
|
||||
|
||||
Filters users to find duplicates that match the base email pattern,
|
||||
excluding the current user.
|
||||
|
||||
Args:
|
||||
users: Dictionary mapping user IDs to user objects
|
||||
base_email: The base email to match against
|
||||
current_user_id: The user ID to exclude from the check
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found, False otherwise
|
||||
"""
|
||||
regex_pattern = get_base_email_regex_pattern(base_email)
|
||||
if not regex_pattern:
|
||||
logger.warning(
|
||||
f'Could not generate regex pattern for base email: {base_email}'
|
||||
)
|
||||
# Fallback to simple matching
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '').lower()
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and matches_base_email(user_email, base_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
else:
|
||||
for user in users.values():
|
||||
user_email = user.get('email', '')
|
||||
if (
|
||||
user_email
|
||||
and user.get('id') != current_user_id
|
||||
and regex_pattern.match(user_email)
|
||||
):
|
||||
logger.info(
|
||||
f'Found duplicate email: {user_email} matches base {base_email}'
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def check_duplicate_base_email(
|
||||
self, email: str, current_user_id: str
|
||||
) -> bool:
|
||||
"""Check if a user with the same base email already exists.
|
||||
|
||||
This method checks for duplicate signups using email + modifier.
|
||||
It checks if any user exists with the same base email, regardless of whether
|
||||
the provided email has a + modifier or not.
|
||||
|
||||
Examples:
|
||||
- If email is "joe+test@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com")
|
||||
- If email is "joe@example.com", it checks for existing users with
|
||||
base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com")
|
||||
|
||||
Args:
|
||||
email: The email address to check (may or may not contain + modifier)
|
||||
current_user_id: The user ID of the current user (to exclude from check)
|
||||
|
||||
Returns:
|
||||
True if a duplicate is found (excluding current user), False otherwise
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
|
||||
base_email = extract_base_email(email)
|
||||
if not base_email:
|
||||
logger.warning(f'Could not extract base email from: {email}')
|
||||
return False
|
||||
|
||||
try:
|
||||
local_part, domain = base_email.rsplit('@', 1)
|
||||
users = await self._query_users_by_wildcard_pattern(local_part, domain)
|
||||
return self._find_duplicate_in_users(users, base_email, current_user_id)
|
||||
|
||||
except KeycloakConnectionError:
|
||||
logger.exception('KeycloakConnectionError when checking duplicate email')
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error checking duplicate email: {e}')
|
||||
# On any error, allow signup to proceed (fail open)
|
||||
return False
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
retry=retry_if_exception_type(KeycloakConnectionError),
|
||||
before_sleep=_before_sleep_callback,
|
||||
)
|
||||
async def delete_keycloak_user(self, user_id: str) -> bool:
|
||||
"""Delete a user from Keycloak.
|
||||
|
||||
This method is used to clean up user accounts that were created
|
||||
but should not exist (e.g., duplicate email signups).
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to delete
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Use the sync method (python-keycloak doesn't have async delete_user)
|
||||
# Run it in a thread executor to avoid blocking the event loop
|
||||
await asyncio.to_thread(keycloak_admin.delete_user, user_id)
|
||||
logger.info(f'Successfully deleted Keycloak user {user_id}')
|
||||
return True
|
||||
except KeycloakConnectionError:
|
||||
logger.exception(f'KeycloakConnectionError when deleting user {user_id}')
|
||||
raise
|
||||
except KeycloakError as e:
|
||||
# User might not exist or already deleted
|
||||
logger.warning(
|
||||
f'KeycloakError when deleting user {user_id}: {e}',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}')
|
||||
return False
|
||||
|
||||
async def get_user_info_from_user_id(self, user_id: str) -> dict | None:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
@@ -527,6 +710,49 @@ class TokenManager:
|
||||
github_id = github_ids[0]
|
||||
return github_id
|
||||
|
||||
async def disable_keycloak_user(
|
||||
self, user_id: str, email: str | None = None
|
||||
) -> None:
|
||||
"""Disable a Keycloak user account.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID to disable
|
||||
email: Optional email address for logging purposes
|
||||
|
||||
This method attempts to disable the user account but will not raise exceptions.
|
||||
Errors are logged but do not prevent the operation from completing.
|
||||
"""
|
||||
try:
|
||||
keycloak_admin = get_keycloak_admin(self.external)
|
||||
# Get current user to preserve other fields
|
||||
user = await keycloak_admin.a_get_user(user_id)
|
||||
if user:
|
||||
# Update user with enabled=False to disable the account
|
||||
await keycloak_admin.a_update_user(
|
||||
user_id=user_id,
|
||||
payload={
|
||||
'enabled': False,
|
||||
'username': user.get('username', ''),
|
||||
'email': user.get('email', ''),
|
||||
'emailVerified': user.get('emailVerified', False),
|
||||
},
|
||||
)
|
||||
email_str = f', email: {email}' if email else ''
|
||||
logger.info(
|
||||
f'Disabled Keycloak account for user_id: {user_id}{email_str}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'User not found in Keycloak when attempting to disable: {user_id}'
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't raise - the caller should handle the blocking regardless
|
||||
email_str = f', email: {email}' if email else ''
|
||||
logger.error(
|
||||
f'Failed to disable Keycloak account for user_id: {user_id}{email_str}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def store_org_token(self, installation_id: int, installation_token: str):
|
||||
"""Store a GitHub App installation token.
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
4: 'claude-sonnet-4-20250514',
|
||||
5: 'claude-opus-4-5-20251101',
|
||||
}
|
||||
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
@@ -1,331 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import socketio
|
||||
from server.clustered_conversation_manager import ClusteredConversationManager
|
||||
from server.saas_nested_conversation_manager import SaasNestedConversationManager
|
||||
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import wait_all
|
||||
|
||||
_LEGACY_ENTRY_TIMEOUT_SECONDS = 3600
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegacyCacheEntry:
|
||||
"""Cache entry for legacy mode status."""
|
||||
|
||||
is_legacy: bool
|
||||
timestamp: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class LegacyConversationManager(ConversationManager):
|
||||
"""
|
||||
Conversation manager for use while migrating - since existing conversations are not nested!
|
||||
Separate class from SaasNestedConversationManager so it can be easliy removed in a few weeks.
|
||||
(As of 2025-07-23)
|
||||
"""
|
||||
|
||||
sio: socketio.AsyncServer
|
||||
config: OpenHandsConfig
|
||||
server_config: ServerConfig
|
||||
file_store: FileStore
|
||||
conversation_manager: SaasNestedConversationManager
|
||||
legacy_conversation_manager: ClusteredConversationManager
|
||||
_legacy_cache: dict[str, LegacyCacheEntry] = field(default_factory=dict)
|
||||
|
||||
async def __aenter__(self):
|
||||
await wait_all(
|
||||
[
|
||||
self.conversation_manager.__aenter__(),
|
||||
self.legacy_conversation_manager.__aenter__(),
|
||||
]
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await wait_all(
|
||||
[
|
||||
self.conversation_manager.__aexit__(exc_type, exc_value, traceback),
|
||||
self.legacy_conversation_manager.__aexit__(
|
||||
exc_type, exc_value, traceback
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def request_llm_completion(
|
||||
self,
|
||||
sid: str,
|
||||
service_id: str,
|
||||
llm_config: LLMConfig,
|
||||
messages: list[dict[str, str]],
|
||||
) -> str:
|
||||
session = self.get_agent_session(sid)
|
||||
llm_registry = session.llm_registry
|
||||
return llm_registry.request_extraneous_completion(
|
||||
service_id, llm_config, messages
|
||||
)
|
||||
|
||||
async def attach_to_conversation(
|
||||
self, sid: str, user_id: str | None = None
|
||||
) -> ServerConversation | None:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.attach_to_conversation(
|
||||
sid, user_id
|
||||
)
|
||||
return await self.conversation_manager.attach_to_conversation(sid, user_id)
|
||||
|
||||
async def detach_from_conversation(self, conversation: ServerConversation):
|
||||
if await self.should_start_in_legacy_mode(conversation.sid):
|
||||
return await self.legacy_conversation_manager.detach_from_conversation(
|
||||
conversation
|
||||
)
|
||||
return await self.conversation_manager.detach_from_conversation(conversation)
|
||||
|
||||
async def join_conversation(
|
||||
self,
|
||||
sid: str,
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
) -> AgentLoopInfo:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.join_conversation(
|
||||
sid, connection_id, settings, user_id
|
||||
)
|
||||
return await self.conversation_manager.join_conversation(
|
||||
sid, connection_id, settings, user_id
|
||||
)
|
||||
|
||||
def get_agent_session(self, sid: str):
|
||||
session = self.legacy_conversation_manager.get_agent_session(sid)
|
||||
if session is None:
|
||||
session = self.conversation_manager.get_agent_session(sid)
|
||||
return session
|
||||
|
||||
async def get_running_agent_loops(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> set[str]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
|
||||
# Get all running agent loops from both managers
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
self.legacy_conversation_manager.get_running_agent_loops(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Combine the results
|
||||
result = set()
|
||||
for sid in legacy_agent_loops:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
result.add(sid)
|
||||
|
||||
for sid in agent_loops:
|
||||
if not await self.should_start_in_legacy_mode(sid):
|
||||
result.add(sid)
|
||||
|
||||
return result
|
||||
|
||||
async def is_agent_loop_running(self, sid: str) -> bool:
|
||||
return bool(await self.get_running_agent_loops(filter_to_sids={sid}))
|
||||
|
||||
async def get_connections(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> dict[str, str]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_connections(user_id, filter_to_sids),
|
||||
self.legacy_conversation_manager.get_connections(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
legacy_agent_loops.update(agent_loops)
|
||||
return legacy_agent_loops
|
||||
|
||||
async def maybe_start_agent_loop(
|
||||
self,
|
||||
sid: str,
|
||||
settings: Settings,
|
||||
user_id: str, # type: ignore[override]
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
replay_json: str | None = None,
|
||||
) -> AgentLoopInfo:
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.maybe_start_agent_loop(
|
||||
sid, settings, user_id, initial_user_msg, replay_json
|
||||
)
|
||||
return await self.conversation_manager.maybe_start_agent_loop(
|
||||
sid, settings, user_id, initial_user_msg, replay_json
|
||||
)
|
||||
|
||||
async def send_to_event_stream(self, connection_id: str, data: dict):
|
||||
return await self.legacy_conversation_manager.send_to_event_stream(
|
||||
connection_id, data
|
||||
)
|
||||
|
||||
async def send_event_to_conversation(self, sid: str, data: dict):
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
await self.legacy_conversation_manager.send_event_to_conversation(sid, data)
|
||||
await self.conversation_manager.send_event_to_conversation(sid, data)
|
||||
|
||||
async def disconnect_from_session(self, connection_id: str):
|
||||
return await self.legacy_conversation_manager.disconnect_from_session(
|
||||
connection_id
|
||||
)
|
||||
|
||||
async def close_session(self, sid: str):
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
await self.legacy_conversation_manager.close_session(sid)
|
||||
await self.conversation_manager.close_session(sid)
|
||||
|
||||
async def get_agent_loop_info(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> list[AgentLoopInfo]:
|
||||
if filter_to_sids and len(filter_to_sids) == 1:
|
||||
sid = next(iter(filter_to_sids))
|
||||
if await self.should_start_in_legacy_mode(sid):
|
||||
return await self.legacy_conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
return await self.conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
)
|
||||
agent_loops, legacy_agent_loops = await wait_all(
|
||||
[
|
||||
self.conversation_manager.get_agent_loop_info(user_id, filter_to_sids),
|
||||
self.legacy_conversation_manager.get_agent_loop_info(
|
||||
user_id, filter_to_sids
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Combine results
|
||||
result = []
|
||||
legacy_sids = set()
|
||||
|
||||
# Add legacy agent loops
|
||||
for agent_loop in legacy_agent_loops:
|
||||
if await self.should_start_in_legacy_mode(agent_loop.conversation_id):
|
||||
result.append(agent_loop)
|
||||
legacy_sids.add(agent_loop.conversation_id)
|
||||
|
||||
# Add non-legacy agent loops
|
||||
for agent_loop in agent_loops:
|
||||
if (
|
||||
agent_loop.conversation_id not in legacy_sids
|
||||
and not await self.should_start_in_legacy_mode(
|
||||
agent_loop.conversation_id
|
||||
)
|
||||
):
|
||||
result.append(agent_loop)
|
||||
|
||||
return result
|
||||
|
||||
def _cleanup_expired_cache_entries(self):
|
||||
"""Remove expired entries from the local cache."""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key
|
||||
for key, entry in self._legacy_cache.items()
|
||||
if current_time - entry.timestamp > _LEGACY_ENTRY_TIMEOUT_SECONDS
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._legacy_cache[key]
|
||||
|
||||
async def should_start_in_legacy_mode(self, conversation_id: str) -> bool:
|
||||
"""
|
||||
Check if a conversation should run in legacy mode by directly checking the runtime.
|
||||
The /list method does not include stopped conversations even though the PVC for these
|
||||
may not yet have been deleted, so we need to check /sessions/{session_id} directly.
|
||||
"""
|
||||
# Clean up expired entries periodically
|
||||
self._cleanup_expired_cache_entries()
|
||||
|
||||
# First check the local cache
|
||||
if conversation_id in self._legacy_cache:
|
||||
cached_entry = self._legacy_cache[conversation_id]
|
||||
# Check if the cached value is still valid
|
||||
if time.time() - cached_entry.timestamp <= _LEGACY_ENTRY_TIMEOUT_SECONDS:
|
||||
return cached_entry.is_legacy
|
||||
|
||||
# If not in cache or expired, check the runtime directly
|
||||
runtime = await self.conversation_manager._get_runtime(conversation_id)
|
||||
is_legacy = self.is_legacy_runtime(runtime)
|
||||
|
||||
# Cache the result with current timestamp
|
||||
self._legacy_cache[conversation_id] = LegacyCacheEntry(is_legacy, time.time())
|
||||
|
||||
return is_legacy
|
||||
|
||||
def is_legacy_runtime(self, runtime: dict | None) -> bool:
|
||||
"""
|
||||
Determine if a runtime is a legacy runtime based on its command.
|
||||
|
||||
Args:
|
||||
runtime: The runtime dictionary or None if not found
|
||||
|
||||
Returns:
|
||||
bool: True if this is a legacy runtime, False otherwise
|
||||
"""
|
||||
if runtime is None:
|
||||
return False
|
||||
return 'openhands.server' not in runtime['command']
|
||||
|
||||
@classmethod
|
||||
def get_instance(
|
||||
cls,
|
||||
sio: socketio.AsyncServer,
|
||||
config: OpenHandsConfig,
|
||||
file_store: FileStore,
|
||||
server_config: ServerConfig,
|
||||
monitoring_listener: MonitoringListener,
|
||||
) -> ConversationManager:
|
||||
return LegacyConversationManager(
|
||||
sio=sio,
|
||||
config=config,
|
||||
server_config=server_config,
|
||||
file_store=file_store,
|
||||
conversation_manager=SaasNestedConversationManager.get_instance(
|
||||
sio, config, file_store, server_config, monitoring_listener
|
||||
),
|
||||
legacy_conversation_manager=ClusteredConversationManager.get_instance(
|
||||
sio, config, file_store, server_config, monitoring_listener
|
||||
),
|
||||
)
|
||||
@@ -152,17 +152,22 @@ class SetAuthCookieMiddleware:
|
||||
return False
|
||||
path = request.url.path
|
||||
|
||||
is_api_that_should_attach = path.startswith('/api') and path not in (
|
||||
ignore_paths = (
|
||||
'/api/options/config',
|
||||
'/api/keycloak/callback',
|
||||
'/api/billing/success',
|
||||
'/api/billing/cancel',
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
)
|
||||
if path in ignore_paths:
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
return is_api_that_should_attach or is_mcp
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
|
||||
async def _logout(self, request: Request):
|
||||
# Log out of keycloak - this prevents issues where you did not log in with the idp you believe you used
|
||||
|
||||
@@ -14,6 +14,7 @@ from server.auth.constants import (
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
ROLE_CHECK_ENABLED,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
@@ -145,7 +146,74 @@ async def keycloak_callback(
|
||||
content={'error': 'Missing user ID or username in response'},
|
||||
)
|
||||
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
|
||||
# Disable the Keycloak account
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': 'Access denied: Your email domain is not allowed to access this service'
|
||||
},
|
||||
)
|
||||
|
||||
# Check for duplicate email with + modifier
|
||||
if email:
|
||||
try:
|
||||
has_duplicate = await token_manager.check_duplicate_base_email(
|
||||
email, user_id
|
||||
)
|
||||
if has_duplicate:
|
||||
logger.warning(
|
||||
f'Blocked signup attempt for email {email} - duplicate base email found',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Delete the Keycloak user that was automatically created during OAuth
|
||||
# This prevents orphaned accounts in Keycloak
|
||||
# The delete_keycloak_user method already handles all errors internally
|
||||
deletion_success = await token_manager.delete_keycloak_user(user_id)
|
||||
if deletion_success:
|
||||
logger.info(
|
||||
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
|
||||
f'User may need to be manually cleaned up.'
|
||||
)
|
||||
|
||||
# Redirect to home page with query parameter indicating the issue
|
||||
home_url = f'{request.base_url}?duplicated_email=true'
|
||||
return RedirectResponse(home_url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log error but allow signup to proceed (fail open)
|
||||
logger.error(
|
||||
f'Error checking duplicate email for {email}: {e}',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
|
||||
# Check email verification status
|
||||
email_verified = user_info.get('email_verified', False)
|
||||
if not email_verified:
|
||||
# Send verification email
|
||||
# Import locally to avoid circular import with email.py
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = f'{request.base_url}?email_verification_required=true'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
|
||||
@@ -74,7 +74,7 @@ async def update_email(
|
||||
accepted_tos=user_auth.accepted_tos,
|
||||
)
|
||||
|
||||
await _verify_email(request=request, user_id=user_id)
|
||||
await verify_email(request=request, user_id=user_id)
|
||||
|
||||
logger.info(f'Updating email address for {user_id} to {email}')
|
||||
return response
|
||||
@@ -91,8 +91,10 @@ async def update_email(
|
||||
|
||||
|
||||
@api_router.put('/verify')
|
||||
async def verify_email(request: Request, user_id: str = Depends(get_user_id)):
|
||||
await _verify_email(request=request, user_id=user_id)
|
||||
async def resend_email_verification(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
):
|
||||
await verify_email(request=request, user_id=user_id)
|
||||
|
||||
logger.info(f'Resending verification email for {user_id}')
|
||||
return JSONResponse(
|
||||
@@ -124,10 +126,14 @@ async def verified_email(request: Request):
|
||||
return response
|
||||
|
||||
|
||||
async def _verify_email(request: Request, user_id: str):
|
||||
async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False):
|
||||
keycloak_admin = get_keycloak_admin()
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
redirect_uri = (
|
||||
f'{scheme}://{request.url.netloc}?email_verified=true'
|
||||
if is_auth_flow
|
||||
else f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
)
|
||||
logger.info(f'Redirect URI: {redirect_uri}')
|
||||
await keycloak_admin.a_send_verify_email(
|
||||
user_id=user_id,
|
||||
|
||||
324
enterprise/server/routes/oauth_device.py
Normal file
324
enterprise/server/routes/oauth_device.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""OAuth 2.0 Device Flow endpoints for CLI authentication."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes
|
||||
DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds
|
||||
|
||||
API_KEY_NAME = 'Device Link Access Key'
|
||||
KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DeviceAuthorizationResponse(BaseModel):
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_uri: str
|
||||
verification_uri_complete: str
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceTokenResponse(BaseModel):
|
||||
access_token: str # This will be the user's API key
|
||||
token_type: str = 'Bearer'
|
||||
expires_in: Optional[int] = None # API keys may not have expiration
|
||||
|
||||
|
||||
class DeviceTokenErrorResponse(BaseModel):
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
interval: Optional[int] = None # Required for slow_down error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router + stores
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
oauth_device_router = APIRouter(prefix='/oauth/device')
|
||||
device_code_store = DeviceCodeStore(session_maker)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _oauth_error(
|
||||
status_code: int,
|
||||
error: str,
|
||||
description: str,
|
||||
interval: Optional[int] = None,
|
||||
) -> JSONResponse:
|
||||
"""Return a JSON OAuth-style error response."""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=DeviceTokenErrorResponse(
|
||||
error=error,
|
||||
error_description=description,
|
||||
interval=interval,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@oauth_device_router.post('/authorize', response_model=DeviceAuthorizationResponse)
|
||||
async def device_authorization(
|
||||
http_request: Request,
|
||||
) -> DeviceAuthorizationResponse:
|
||||
"""Start device flow by generating device and user codes."""
|
||||
try:
|
||||
device_code_entry = device_code_store.create_device_code(
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
)
|
||||
|
||||
base_url = str(http_request.base_url).rstrip('/')
|
||||
verification_uri = f'{base_url}/oauth/device/verify'
|
||||
verification_uri_complete = (
|
||||
f'{verification_uri}?user_code={device_code_entry.user_code}'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Device authorization initiated',
|
||||
extra={'user_code': device_code_entry.user_code},
|
||||
)
|
||||
|
||||
return DeviceAuthorizationResponse(
|
||||
device_code=device_code_entry.device_code,
|
||||
user_code=device_code_entry.user_code,
|
||||
verification_uri=verification_uri,
|
||||
verification_uri_complete=verification_uri_complete,
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
interval=device_code_entry.current_interval,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception('Error in device authorization: %s', str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Internal server error',
|
||||
) from e
|
||||
|
||||
|
||||
@oauth_device_router.post('/token')
|
||||
async def device_token(device_code: str = Form(...)):
|
||||
"""Poll for a token until the user authorizes or the code expires."""
|
||||
try:
|
||||
device_code_entry = device_code_store.get_by_device_code(device_code)
|
||||
|
||||
if not device_code_entry:
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'invalid_grant',
|
||||
'Invalid device code',
|
||||
)
|
||||
|
||||
# Check rate limiting (RFC 8628 section 3.5)
|
||||
is_too_fast, current_interval = device_code_entry.check_rate_limit()
|
||||
if is_too_fast:
|
||||
# Update poll time and increase interval
|
||||
device_code_store.update_poll_time(device_code, increase_interval=True)
|
||||
logger.warning(
|
||||
'Client polling too fast, returning slow_down error',
|
||||
extra={
|
||||
'device_code': device_code[:8] + '...', # Log partial for privacy
|
||||
'new_interval': current_interval,
|
||||
},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'slow_down',
|
||||
f'Polling too frequently. Wait at least {current_interval} seconds between requests.',
|
||||
interval=current_interval,
|
||||
)
|
||||
|
||||
# Update poll time for successful rate limit check
|
||||
device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
|
||||
if device_code_entry.is_expired():
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'expired_token',
|
||||
'Device code has expired',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'denied':
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'access_denied',
|
||||
'User denied the authorization request',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'pending':
|
||||
return _oauth_error(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'authorization_pending',
|
||||
'User has not yet completed authorization',
|
||||
)
|
||||
|
||||
if device_code_entry.status == 'authorized':
|
||||
# Retrieve the specific API key for this device using the user_code
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
|
||||
device_api_key = api_key_store.retrieve_api_key_by_name(
|
||||
device_code_entry.keycloak_user_id, device_key_name
|
||||
)
|
||||
|
||||
if not device_api_key:
|
||||
logger.error(
|
||||
'No device API key found for authorized device',
|
||||
extra={
|
||||
'user_id': device_code_entry.keycloak_user_id,
|
||||
'user_code': device_code_entry.user_code,
|
||||
},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'API key not found',
|
||||
)
|
||||
|
||||
# Return the API key as access_token
|
||||
return DeviceTokenResponse(
|
||||
access_token=device_api_key,
|
||||
)
|
||||
|
||||
# Fallback for unexpected status values
|
||||
logger.error(
|
||||
'Unknown device code status',
|
||||
extra={'status': device_code_entry.status},
|
||||
)
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'Unknown device code status',
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception('Error in device token: %s', str(e))
|
||||
return _oauth_error(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'server_error',
|
||||
'Internal server error',
|
||||
)
|
||||
|
||||
|
||||
@oauth_device_router.post('/verify-authenticated')
|
||||
async def device_verification_authenticated(
|
||||
user_code: str = Form(...),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Process device verification for authenticated users (called by frontend)."""
|
||||
try:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Authentication required',
|
||||
)
|
||||
|
||||
# Validate device code
|
||||
device_code_entry = device_code_store.get_by_user_code(user_code)
|
||||
if not device_code_entry:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='The device code is invalid or has expired.',
|
||||
)
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This device code has already been processed.',
|
||||
)
|
||||
|
||||
# First, authorize the device code
|
||||
success = device_code_store.authorize_device_code(
|
||||
user_code=user_code,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(
|
||||
'Failed to authorize device code',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to authorize the device. Please try again.',
|
||||
)
|
||||
|
||||
# Only create API key AFTER successful authorization
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
try:
|
||||
# Create a unique API key for this device using user_code in the name
|
||||
device_key_name = f'{API_KEY_NAME} ({user_code})'
|
||||
api_key_store.create_api_key(
|
||||
user_id,
|
||||
name=device_key_name,
|
||||
expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME,
|
||||
)
|
||||
logger.info(
|
||||
'Created new device API key for user after successful authorization',
|
||||
extra={'user_id': user_id, 'user_code': user_code},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Failed to create device API key after authorization: %s', str(e)
|
||||
)
|
||||
|
||||
# Clean up: revert the device authorization since API key creation failed
|
||||
# This prevents the device from being in an authorized state without an API key
|
||||
try:
|
||||
device_code_store.deny_device_code(user_code)
|
||||
logger.info(
|
||||
'Reverted device authorization due to API key creation failure',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
except Exception as cleanup_error:
|
||||
logger.exception(
|
||||
'Failed to revert device authorization during cleanup: %s',
|
||||
str(cleanup_error),
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key for device access.',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Device code authorized with API key successfully',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Device authorized successfully!'},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception('Error in device verification: %s', str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred. Please try again.',
|
||||
)
|
||||
@@ -31,6 +31,7 @@ from openhands.events.event_store import EventStore
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||
from openhands.runtime.plugins.vscode import VSCodeRequirement
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.constants import ROOM_KEY
|
||||
@@ -71,10 +72,13 @@ RUNTIME_CONVERSATION_URL = RUNTIME_URL_PATTERN + (
|
||||
)
|
||||
|
||||
RUNTIME_USERNAME = os.getenv('RUNTIME_USERNAME')
|
||||
|
||||
SU_TO_USER = os.getenv('SU_TO_USER', 'false')
|
||||
truthy = {'1', 'true', 't', 'yes', 'y', 'on'}
|
||||
SU_TO_USER = str(SU_TO_USER.lower() in truthy).lower()
|
||||
|
||||
DISABLE_VSCODE_PLUGIN = os.getenv('DISABLE_VSCODE_PLUGIN', 'false').lower() == 'true'
|
||||
|
||||
# Time in seconds before a Redis entry is considered expired if not refreshed
|
||||
_REDIS_ENTRY_TIMEOUT_SECONDS = 300
|
||||
|
||||
@@ -799,6 +803,9 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['INIT_GIT_IN_EMPTY_WORKSPACE'] = '1'
|
||||
env_vars['ENABLE_V1'] = '0'
|
||||
env_vars['SU_TO_USER'] = SU_TO_USER
|
||||
env_vars['DISABLE_VSCODE_PLUGIN'] = str(DISABLE_VSCODE_PLUGIN).lower()
|
||||
env_vars['BROWSERGYM_DOWNLOAD_DIR'] = '/workspace/.downloads/'
|
||||
env_vars['PLAYWRIGHT_BROWSERS_PATH'] = '/opt/playwright-browsers'
|
||||
|
||||
# We need this for LLM traces tracking to identify the source of the LLM calls
|
||||
env_vars['WEB_HOST'] = WEB_HOST
|
||||
@@ -814,11 +821,18 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
if self._runtime_container_image:
|
||||
config.sandbox.runtime_container_image = self._runtime_container_image
|
||||
|
||||
plugins = [
|
||||
plugin
|
||||
for plugin in agent.sandbox_plugins
|
||||
if not (DISABLE_VSCODE_PLUGIN and isinstance(plugin, VSCodeRequirement))
|
||||
]
|
||||
logger.info(f'Loaded plugins for runtime {sid}: {plugins}')
|
||||
|
||||
runtime = RemoteRuntime(
|
||||
config=config,
|
||||
event_stream=None, # type: ignore[arg-type]
|
||||
sid=sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
plugins=plugins,
|
||||
# env_vars=env_vars,
|
||||
# status_callback: Callable[..., None] | None = None,
|
||||
attach_to_existing=False,
|
||||
|
||||
20
enterprise/server/sharing/README.md
Normal file
20
enterprise/server/sharing/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Sharing Package
|
||||
|
||||
This package contains functionality for sharing conversations.
|
||||
|
||||
## Components
|
||||
|
||||
- **shared.py**: Data models for shared conversations
|
||||
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
|
||||
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
|
||||
- **shared_event_service.py**: Service interface for accessing shared events
|
||||
- **shared_event_service_impl.py**: Implementation of the shared event service
|
||||
- **shared_conversation_router.py**: REST API endpoints for shared conversations
|
||||
- **shared_event_router.py**: REST API endpoints for shared events
|
||||
|
||||
## Features
|
||||
|
||||
- Read-only access to shared conversations
|
||||
- Event access for shared conversations
|
||||
- Search and filtering capabilities
|
||||
- Pagination support
|
||||
142
enterprise/server/sharing/filesystem_shared_event_service.py
Normal file
142
enterprise/server/sharing/filesystem_shared_event_service.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Implementation of SharedEventService.
|
||||
|
||||
This implementation provides read-only access to events from shared conversations:
|
||||
- Validates that the conversation is shared before returning events
|
||||
- Uses existing EventService for actual event retrieval
|
||||
- Uses SharedConversationInfoService for shared conversation validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.sdk import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedEventServiceImpl(SharedEventService):
|
||||
"""Implementation of SharedEventService that validates shared access."""
|
||||
|
||||
shared_conversation_info_service: SharedConversationInfoService
|
||||
event_service: EventService
|
||||
|
||||
async def get_shared_event(
|
||||
self, conversation_id: UUID, event_id: str
|
||||
) -> Event | None:
|
||||
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
return None
|
||||
|
||||
# If conversation is shared, get the event
|
||||
return await self.event_service.get_event(event_id)
|
||||
|
||||
async def search_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
# Return empty page if conversation is not shared
|
||||
return EventPage(items=[], next_page_id=None)
|
||||
|
||||
# If conversation is shared, search events for this conversation
|
||||
return await self.event_service.search_events(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def count_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
) -> int:
|
||||
"""Count events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
return 0
|
||||
|
||||
# If conversation is shared, count events for this conversation
|
||||
return await self.event_service.count_events(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
|
||||
class SharedEventServiceImplInjector(SharedEventServiceInjector):
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedEventService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_event_service,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_db_session(state, request) as db_session,
|
||||
get_event_service(state, request) as event_service,
|
||||
):
|
||||
shared_conversation_info_service = SQLSharedConversationInfoService(
|
||||
db_session=db_session
|
||||
)
|
||||
service = SharedEventServiceImpl(
|
||||
shared_conversation_info_service=shared_conversation_info_service,
|
||||
event_service=event_service,
|
||||
)
|
||||
yield service
|
||||
@@ -0,0 +1,66 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
|
||||
class SharedConversationInfoService(ABC):
|
||||
"""Service for accessing shared conversation info without user restrictions."""
|
||||
|
||||
@abstractmethod
|
||||
async def search_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
|
||||
@abstractmethod
|
||||
async def count_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count shared conversations."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_shared_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single shared conversation info, returning None if missing or not shared."""
|
||||
|
||||
async def batch_get_shared_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversation info, return None for any missing or non-shared."""
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
self.get_shared_conversation_info(conversation_id)
|
||||
for conversation_id in conversation_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class SharedConversationInfoServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
|
||||
):
|
||||
pass
|
||||
56
enterprise/server/sharing/shared_conversation_models.py
Normal file
56
enterprise/server/sharing/shared_conversation_models.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
# Simplified imports to avoid dependency chain issues
|
||||
# from openhands.integrations.service_types import ProviderType
|
||||
# from openhands.sdk.llm import MetricsSnapshot
|
||||
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
# For now, use Any to avoid import issues
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.agent_server.utils import OpenHandsUUID, utc_now
|
||||
|
||||
ProviderType = Any
|
||||
MetricsSnapshot = Any
|
||||
ConversationTrigger = Any
|
||||
|
||||
|
||||
class SharedConversation(BaseModel):
|
||||
"""Shared conversation info model with all fields from AppConversationInfo."""
|
||||
|
||||
id: OpenHandsUUID = Field(default_factory=uuid4)
|
||||
|
||||
created_by_user_id: str | None
|
||||
sandbox_id: str
|
||||
|
||||
selected_repository: str | None = None
|
||||
selected_branch: str | None = None
|
||||
git_provider: ProviderType | None = None
|
||||
title: str | None = None
|
||||
pr_number: list[int] = Field(default_factory=list)
|
||||
llm_model: str | None = None
|
||||
|
||||
metrics: MetricsSnapshot | None = None
|
||||
|
||||
parent_conversation_id: OpenHandsUUID | None = None
|
||||
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
|
||||
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
updated_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
|
||||
class SharedConversationSortOrder(Enum):
|
||||
CREATED_AT = 'CREATED_AT'
|
||||
CREATED_AT_DESC = 'CREATED_AT_DESC'
|
||||
UPDATED_AT = 'UPDATED_AT'
|
||||
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
|
||||
TITLE = 'TITLE'
|
||||
TITLE_DESC = 'TITLE_DESC'
|
||||
|
||||
|
||||
class SharedConversationPage(BaseModel):
|
||||
items: list[SharedConversation]
|
||||
next_page_id: str | None = None
|
||||
135
enterprise/server/sharing/shared_conversation_router.py
Normal file
135
enterprise/server/sharing/shared_conversation_router.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Shared Conversation router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoServiceInjector,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
|
||||
shared_conversation_info_service_dependency = Depends(
|
||||
SQLSharedConversationInfoServiceInjector().depends
|
||||
)
|
||||
|
||||
# Read methods
|
||||
|
||||
|
||||
@router.get('/search')
|
||||
async def search_shared_conversations(
|
||||
title__contains: Annotated[
|
||||
str | None,
|
||||
Query(title='Filter by title containing this string'),
|
||||
] = None,
|
||||
created_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
created_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at less than this datetime'),
|
||||
] = None,
|
||||
updated_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
updated_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at less than this datetime'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
SharedConversationSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
),
|
||||
] = 100,
|
||||
include_sub_conversations: Annotated[
|
||||
bool,
|
||||
Query(
|
||||
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
|
||||
),
|
||||
] = False,
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> SharedConversationPage:
|
||||
"""Search / List shared conversations."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_conversation_service.search_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
include_sub_conversations=include_sub_conversations,
|
||||
)
|
||||
|
||||
|
||||
@router.get('/count')
|
||||
async def count_shared_conversations(
|
||||
title__contains: Annotated[
|
||||
str | None,
|
||||
Query(title='Filter by title containing this string'),
|
||||
] = None,
|
||||
created_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
created_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by created_at less than this datetime'),
|
||||
] = None,
|
||||
updated_at__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at greater than or equal to this datetime'),
|
||||
] = None,
|
||||
updated_at__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Filter by updated_at less than this datetime'),
|
||||
] = None,
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> int:
|
||||
"""Count shared conversations matching the given filters."""
|
||||
return await shared_conversation_service.count_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
|
||||
@router.get('')
|
||||
async def batch_get_shared_conversations(
|
||||
ids: Annotated[list[str], Query()],
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
|
||||
assert len(ids) <= 100
|
||||
uuids = [UUID(id_) for id_ in ids]
|
||||
shared_conversation_info = (
|
||||
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
|
||||
)
|
||||
return shared_conversation_info
|
||||
126
enterprise/server/sharing/shared_event_router.py
Normal file
126
enterprise/server/sharing/shared_event_router.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Shared Event router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.filesystem_shared_event_service import (
|
||||
SharedEventServiceImplInjector,
|
||||
)
|
||||
from server.sharing.shared_event_service import SharedEventService
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.sdk import Event
|
||||
|
||||
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
|
||||
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
|
||||
|
||||
|
||||
# Read methods
|
||||
|
||||
|
||||
@router.get('/search')
|
||||
async def search_shared_events(
|
||||
conversation_id: Annotated[
|
||||
str,
|
||||
Query(title='Conversation ID to search events for'),
|
||||
],
|
||||
kind__eq: Annotated[
|
||||
EventKind | None,
|
||||
Query(title='Optional filter by event kind'),
|
||||
] = None,
|
||||
timestamp__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp greater than or equal to'),
|
||||
] = None,
|
||||
timestamp__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp less than'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
EventSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = EventSortOrder.TIMESTAMP,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> EventPage:
|
||||
"""Search / List events for a shared conversation."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_event_service.search_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get('/count')
|
||||
async def count_shared_events(
|
||||
conversation_id: Annotated[
|
||||
str,
|
||||
Query(title='Conversation ID to count events for'),
|
||||
],
|
||||
kind__eq: Annotated[
|
||||
EventKind | None,
|
||||
Query(title='Optional filter by event kind'),
|
||||
] = None,
|
||||
timestamp__gte: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp greater than or equal to'),
|
||||
] = None,
|
||||
timestamp__lt: Annotated[
|
||||
datetime | None,
|
||||
Query(title='Optional filter by timestamp less than'),
|
||||
] = None,
|
||||
sort_order: Annotated[
|
||||
EventSortOrder,
|
||||
Query(title='Sort order for results'),
|
||||
] = EventSortOrder.TIMESTAMP,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> int:
|
||||
"""Count events for a shared conversation matching the given filters."""
|
||||
return await shared_event_service.count_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
|
||||
@router.get('')
|
||||
async def batch_get_shared_events(
|
||||
conversation_id: Annotated[
|
||||
UUID,
|
||||
Query(title='Conversation ID to get events for'),
|
||||
],
|
||||
id: Annotated[list[str], Query()],
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> list[Event | None]:
|
||||
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
|
||||
assert len(id) <= 100
|
||||
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
|
||||
return events
|
||||
|
||||
|
||||
@router.get('/{conversation_id}/{event_id}')
|
||||
async def get_shared_event(
|
||||
conversation_id: UUID,
|
||||
event_id: str,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> Event | None:
|
||||
"""Get a single event from a shared conversation by conversation_id and event_id."""
|
||||
return await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
64
enterprise/server/sharing/shared_event_service.py
Normal file
64
enterprise/server/sharing/shared_event_service.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk import Event
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharedEventService(ABC):
|
||||
"""Event Service for getting events from shared conversations only."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_shared_event(
|
||||
self, conversation_id: UUID, event_id: str
|
||||
) -> Event | None:
|
||||
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||
|
||||
@abstractmethod
|
||||
async def search_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search events for a specific shared conversation."""
|
||||
|
||||
@abstractmethod
|
||||
async def count_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
) -> int:
|
||||
"""Count events for a specific shared conversation."""
|
||||
|
||||
async def batch_get_shared_events(
|
||||
self, conversation_id: UUID, event_ids: list[str]
|
||||
) -> list[Event | None]:
|
||||
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
self.get_shared_event(conversation_id, event_id)
|
||||
for event_id in event_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class SharedEventServiceInjector(
|
||||
DiscriminatedUnionMixin, Injector[SharedEventService], ABC
|
||||
):
|
||||
pass
|
||||
@@ -0,0 +1,282 @@
|
||||
"""SQL implementation of SharedConversationInfoService.
|
||||
|
||||
This implementation provides read-only access to shared conversations:
|
||||
- Direct database access without user permission checks
|
||||
- Filters only conversations marked as shared (currently public)
|
||||
- Full async/await support using SQL async db_sessions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
SharedConversationInfoServiceInjector,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def search_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
query = self._public_select()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
# Exclude sub-conversations (only include top-level conversations)
|
||||
query = query.where(
|
||||
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||
)
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == SharedConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == SharedConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [self._to_shared_conversation(row) for row in rows]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return SharedConversationPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_shared_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count shared conversations matching the given filters."""
|
||||
from sqlalchemy import func
|
||||
|
||||
query = select(func.count(StoredConversationMetadata.conversation_id))
|
||||
# Only include shared conversations
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
|
||||
query = self._apply_filters(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_shared_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single public conversation info, returning None if missing or not shared."""
|
||||
query = self._public_select().where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
stored = result.scalar_one_or_none()
|
||||
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
return self._to_shared_conversation(stored)
|
||||
|
||||
def _public_select(self):
|
||||
"""Create a select query that only returns public conversations."""
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
# Only include conversations marked as public
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
return query
|
||||
|
||||
def _apply_filters(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply common filters to a query."""
|
||||
if title__contains is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.title.contains(title__contains)
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.created_at >= created_at__gte
|
||||
)
|
||||
|
||||
if created_at__lt is not None:
|
||||
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
def _to_shared_conversation(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> SharedConversation:
|
||||
"""Convert StoredConversationMetadata to SharedConversation."""
|
||||
# V1 conversations should always have a sandbox_id
|
||||
sandbox_id = stored.sandbox_id
|
||||
assert sandbox_id is not None
|
||||
|
||||
# Rebuild token usage
|
||||
token_usage = TokenUsage(
|
||||
prompt_tokens=stored.prompt_tokens,
|
||||
completion_tokens=stored.completion_tokens,
|
||||
cache_read_tokens=stored.cache_read_tokens,
|
||||
cache_write_tokens=stored.cache_write_tokens,
|
||||
context_window=stored.context_window,
|
||||
per_turn_token=stored.per_turn_token,
|
||||
)
|
||||
|
||||
# Rebuild metrics object
|
||||
metrics = MetricsSnapshot(
|
||||
accumulated_cost=stored.accumulated_cost,
|
||||
max_budget_per_task=stored.max_budget_per_task,
|
||||
accumulated_token_usage=token_usage,
|
||||
)
|
||||
|
||||
# Get timestamps
|
||||
created_at = self._fix_timezone(stored.created_at)
|
||||
updated_at = self._fix_timezone(stored.last_updated_at)
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
git_provider=(
|
||||
ProviderType(stored.git_provider) if stored.git_provider else None
|
||||
),
|
||||
title=stored.title,
|
||||
pr_number=stored.pr_number,
|
||||
llm_model=stored.llm_model,
|
||||
metrics=metrics,
|
||||
parent_conversation_id=(
|
||||
UUID(stored.parent_conversation_id)
|
||||
if stored.parent_conversation_id
|
||||
else None
|
||||
),
|
||||
sub_conversation_ids=sub_conversation_ids or [],
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
def _fix_timezone(self, value: datetime) -> datetime:
|
||||
"""Sqlite does not store timezones - and since we can't update the existing models
|
||||
we assume UTC if the timezone is missing."""
|
||||
if not value.tzinfo:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
return value
|
||||
|
||||
|
||||
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedConversationInfoService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import get_db_session
|
||||
|
||||
async with get_db_session(state, request) as db_session:
|
||||
service = SQLSharedConversationInfoService(db_session=db_session)
|
||||
yield service
|
||||
@@ -17,10 +17,13 @@ from openhands.core.logger import openhands_logger as logger
|
||||
class ApiKeyStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
API_KEY_PREFIX = 'sk-oh-'
|
||||
|
||||
def generate_api_key(self, length: int = 32) -> str:
|
||||
"""Generate a random API key."""
|
||||
"""Generate a random API key with the sk-oh- prefix."""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{self.API_KEY_PREFIX}{random_part}'
|
||||
|
||||
def create_api_key(
|
||||
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
|
||||
@@ -57,9 +60,15 @@ class ApiKeyStore:
|
||||
return None
|
||||
|
||||
# Check if the key has expired
|
||||
if key_record.expires_at and key_record.expires_at < now:
|
||||
logger.info(f'API key has expired: {key_record.id}')
|
||||
return None
|
||||
if key_record.expires_at:
|
||||
# Handle timezone-naive datetime from database by assuming it's UTC
|
||||
expires_at = key_record.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.info(f'API key has expired: {key_record.id}')
|
||||
return None
|
||||
|
||||
# Update last_used_at timestamp
|
||||
session.execute(
|
||||
@@ -125,6 +134,33 @@ class ApiKeyStore:
|
||||
|
||||
return None
|
||||
|
||||
def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
|
||||
"""Retrieve an API key by name for a specific user."""
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
)
|
||||
return key_record.key if key_record else None
|
||||
|
||||
def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
"""Delete an API key by name for a specific user."""
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> ApiKeyStore:
|
||||
"""Get an instance of the ApiKeyStore."""
|
||||
|
||||
@@ -19,17 +19,23 @@ GCP_REGION = os.environ.get('GCP_REGION')
|
||||
|
||||
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
|
||||
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
|
||||
|
||||
# Initialize Cloud SQL Connector once at module level for GCP environments.
|
||||
_connector = None
|
||||
|
||||
|
||||
def _get_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
|
||||
def get_db_connection():
|
||||
global _connector
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
connector = Connector()
|
||||
if not _connector:
|
||||
_connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return connector.connect(
|
||||
return _connector.connect(
|
||||
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
|
||||
)
|
||||
|
||||
@@ -38,6 +44,7 @@ def _get_db_engine():
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
else:
|
||||
@@ -48,6 +55,7 @@ def _get_db_engine():
|
||||
host_string,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
|
||||
109
enterprise/storage/device_code.py
Normal file
109
enterprise/storage/device_code.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Device code storage model for OAuth 2.0 Device Flow."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class DeviceCodeStatus(Enum):
|
||||
"""Status of a device code authorization request."""
|
||||
|
||||
PENDING = 'pending'
|
||||
AUTHORIZED = 'authorized'
|
||||
EXPIRED = 'expired'
|
||||
DENIED = 'denied'
|
||||
|
||||
|
||||
class DeviceCode(Base):
|
||||
"""Device code for OAuth 2.0 Device Flow.
|
||||
|
||||
This stores the device codes issued during the device authorization flow,
|
||||
along with their status and associated user information once authorized.
|
||||
"""
|
||||
|
||||
__tablename__ = 'device_codes'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
device_code = Column(String(128), unique=True, nullable=False, index=True)
|
||||
user_code = Column(String(16), unique=True, nullable=False, index=True)
|
||||
status = Column(String(32), nullable=False, default=DeviceCodeStatus.PENDING.value)
|
||||
|
||||
# Keycloak user ID who authorized the device (set during verification)
|
||||
keycloak_user_id = Column(String(255), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
authorized_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Rate limiting fields for RFC 8628 section 3.5 compliance
|
||||
last_poll_time = Column(DateTime(timezone=True), nullable=True)
|
||||
current_interval = Column(Integer, nullable=False, default=5)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<DeviceCode(user_code='{self.user_code}', status='{self.status}')>"
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the device code has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
return now > self.expires_at
|
||||
|
||||
def is_pending(self) -> bool:
|
||||
"""Check if the device code is still pending authorization."""
|
||||
return self.status == DeviceCodeStatus.PENDING.value and not self.is_expired()
|
||||
|
||||
def is_authorized(self) -> bool:
|
||||
"""Check if the device code has been authorized."""
|
||||
return self.status == DeviceCodeStatus.AUTHORIZED.value
|
||||
|
||||
def authorize(self, user_id: str) -> None:
|
||||
"""Mark the device code as authorized."""
|
||||
self.status = DeviceCodeStatus.AUTHORIZED.value
|
||||
self.keycloak_user_id = user_id # Set the Keycloak user ID during authorization
|
||||
self.authorized_at = datetime.now(timezone.utc)
|
||||
|
||||
def deny(self) -> None:
|
||||
"""Mark the device code as denied."""
|
||||
self.status = DeviceCodeStatus.DENIED.value
|
||||
|
||||
def expire(self) -> None:
|
||||
"""Mark the device code as expired."""
|
||||
self.status = DeviceCodeStatus.EXPIRED.value
|
||||
|
||||
def check_rate_limit(self) -> tuple[bool, int]:
|
||||
"""Check if the client is polling too fast.
|
||||
|
||||
Returns:
|
||||
tuple: (is_too_fast, current_interval)
|
||||
- is_too_fast: True if client should receive slow_down error
|
||||
- current_interval: Current polling interval to use
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# If this is the first poll, allow it
|
||||
if self.last_poll_time is None:
|
||||
return False, self.current_interval
|
||||
|
||||
# Calculate time since last poll
|
||||
time_since_last_poll = (now - self.last_poll_time).total_seconds()
|
||||
|
||||
# Check if polling too fast
|
||||
if time_since_last_poll < self.current_interval:
|
||||
# Increase interval for slow_down (RFC 8628 section 3.5)
|
||||
new_interval = min(self.current_interval + 5, 60) # Cap at 60 seconds
|
||||
return True, new_interval
|
||||
|
||||
return False, self.current_interval
|
||||
|
||||
def update_poll_time(self, increase_interval: bool = False) -> None:
|
||||
"""Update the last poll time and optionally increase the interval.
|
||||
|
||||
Args:
|
||||
increase_interval: If True, increase the current interval for slow_down
|
||||
"""
|
||||
self.last_poll_time = datetime.now(timezone.utc)
|
||||
|
||||
if increase_interval:
|
||||
# Increase interval by 5 seconds, cap at 60 seconds (RFC 8628)
|
||||
self.current_interval = min(self.current_interval + 5, 60)
|
||||
167
enterprise/storage/device_code_store.py
Normal file
167
enterprise/storage/device_code_store.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Device code store for OAuth 2.0 Device Flow."""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.device_code import DeviceCode
|
||||
|
||||
|
||||
class DeviceCodeStore:
|
||||
"""Store for managing OAuth 2.0 device codes."""
|
||||
|
||||
def __init__(self, session_maker):
|
||||
self.session_maker = session_maker
|
||||
|
||||
def generate_user_code(self) -> str:
|
||||
"""Generate a human-readable user code (8 characters, uppercase letters and digits)."""
|
||||
# Use a mix of uppercase letters and digits, avoiding confusing characters
|
||||
alphabet = 'ABCDEFGHJKLMNPQRSTUVWXYZ23456789' # No I, O, 0, 1
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(8))
|
||||
|
||||
def generate_device_code(self) -> str:
|
||||
"""Generate a secure device code (128 characters)."""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(128))
|
||||
|
||||
def create_device_code(
|
||||
self,
|
||||
expires_in: int = 600, # 10 minutes default
|
||||
max_attempts: int = 10,
|
||||
) -> DeviceCode:
|
||||
"""Create a new device code entry.
|
||||
|
||||
Uses database constraints to ensure uniqueness, avoiding TOCTOU race conditions.
|
||||
Retries on constraint violations until unique codes are generated.
|
||||
|
||||
Args:
|
||||
expires_in: Expiration time in seconds
|
||||
max_attempts: Maximum number of attempts to generate unique codes
|
||||
|
||||
Returns:
|
||||
The created DeviceCode instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If unable to generate unique codes after max_attempts
|
||||
"""
|
||||
for attempt in range(max_attempts):
|
||||
user_code = self.generate_user_code()
|
||||
device_code = self.generate_device_code()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
||||
|
||||
device_code_entry = DeviceCode(
|
||||
device_code=device_code,
|
||||
user_code=user_code,
|
||||
keycloak_user_id=None, # Will be set during authorization
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
try:
|
||||
with self.session_maker() as session:
|
||||
session.add(device_code_entry)
|
||||
session.commit()
|
||||
session.refresh(device_code_entry)
|
||||
session.expunge(device_code_entry) # Detach from session cleanly
|
||||
return device_code_entry
|
||||
except IntegrityError:
|
||||
# Constraint violation - codes already exist, retry with new codes
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
f'Failed to generate unique device codes after {max_attempts} attempts'
|
||||
)
|
||||
|
||||
def get_by_device_code(self, device_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by device code."""
|
||||
with self.session_maker() as session:
|
||||
result = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
)
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
|
||||
def get_by_user_code(self, user_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by user code."""
|
||||
with self.session_maker() as session:
|
||||
result = session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
|
||||
def authorize_device_code(self, user_code: str, user_id: str) -> bool:
|
||||
"""Authorize a device code.
|
||||
|
||||
Args:
|
||||
user_code: The user code to authorize
|
||||
user_id: The user ID from Keycloak
|
||||
|
||||
Returns:
|
||||
True if authorization was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
return False
|
||||
|
||||
device_code_entry.authorize(user_id)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def deny_device_code(self, user_code: str) -> bool:
|
||||
"""Deny a device code authorization.
|
||||
|
||||
Args:
|
||||
user_code: The user code to deny
|
||||
|
||||
Returns:
|
||||
True if denial was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
if not device_code_entry.is_pending():
|
||||
return False
|
||||
|
||||
device_code_entry.deny()
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def update_poll_time(
|
||||
self, device_code: str, increase_interval: bool = False
|
||||
) -> bool:
|
||||
"""Update the poll time for a device code and optionally increase interval.
|
||||
|
||||
Args:
|
||||
device_code: The device code to update
|
||||
increase_interval: If True, increase the polling interval for slow_down
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
)
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
device_code_entry.update_poll_time(increase_interval)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
@@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore):
|
||||
kwargs.pop('context_window', None)
|
||||
kwargs.pop('per_turn_token', None)
|
||||
kwargs.pop('parent_conversation_id', None)
|
||||
kwargs.pop('public')
|
||||
|
||||
return ConversationMetadata(**kwargs)
|
||||
|
||||
|
||||
@@ -94,6 +94,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
|
||||
@@ -4,6 +4,8 @@ from uuid import uuid4
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import GITLAB_WEBHOOK_URL
|
||||
from sqlalchemy import text
|
||||
from storage.database import session_maker
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
@@ -258,6 +260,25 @@ class VerifyWebhookStatus:
|
||||
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
# Check if the table exists before proceeding
|
||||
# This handles cases where the CronJob runs before database migrations complete
|
||||
with session_maker() as session:
|
||||
query = text("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = 'gitlab_webhook'
|
||||
)
|
||||
""")
|
||||
result = await session.execute(query)
|
||||
table_exists = result.scalar() or False
|
||||
|
||||
if not table_exists:
|
||||
logger.info(
|
||||
'gitlab_webhook table does not exist yet, '
|
||||
'waiting for database migrations to complete'
|
||||
)
|
||||
return
|
||||
|
||||
# Get an instance of the webhook store
|
||||
webhook_store = await GitlabWebhookStore.get_instance()
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from storage.base import Base
|
||||
# Anything not loaded here may not have a table created for it.
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.device_code import DeviceCode # noqa: F401
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
|
||||
151
enterprise/tests/unit/server/routes/test_email_routes.py
Normal file
151
enterprise/tests/unit/server/routes/test_email_routes.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.email import verified_email, verify_email
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.url = MagicMock()
|
||||
request.url.hostname = 'localhost'
|
||||
request.url.netloc = 'localhost:8000'
|
||||
request.url.path = '/api/email/verified'
|
||||
request.base_url = 'http://localhost:8000/'
|
||||
request.headers = {}
|
||||
request.cookies = {}
|
||||
request.query_params = MagicMock()
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth():
|
||||
"""Create a mock SaasUserAuth object."""
|
||||
auth = MagicMock(spec=SaasUserAuth)
|
||||
auth.access_token = SecretStr('test_access_token')
|
||||
auth.refresh_token = SecretStr('test_refresh_token')
|
||||
auth.email = 'test@example.com'
|
||||
auth.email_verified = False
|
||||
auth.accepted_tos = True
|
||||
auth.refresh = AsyncMock()
|
||||
return auth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_default_behavior(mock_request):
|
||||
"""Test verify_email with default is_auth_flow=False."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['user_id'] == user_id
|
||||
assert (
|
||||
call_args.kwargs['redirect_uri'] == 'http://localhost:8000/api/email/verified'
|
||||
)
|
||||
assert 'client_id' in call_args.kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_with_auth_flow(mock_request):
|
||||
"""Test verify_email with is_auth_flow=True."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
|
||||
|
||||
# Assert
|
||||
mock_keycloak_admin.a_send_verify_email.assert_called_once()
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['user_id'] == user_id
|
||||
assert (
|
||||
call_args.kwargs['redirect_uri'] == 'http://localhost:8000?email_verified=true'
|
||||
)
|
||||
assert 'client_id' in call_args.kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_email_https_scheme(mock_request):
|
||||
"""Test verify_email uses https scheme for non-localhost hosts."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_keycloak_admin = AsyncMock()
|
||||
mock_keycloak_admin.a_send_verify_email = AsyncMock()
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
):
|
||||
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
|
||||
|
||||
# Assert
|
||||
call_args = mock_keycloak_admin.a_send_verify_email.call_args
|
||||
assert call_args.kwargs['redirect_uri'].startswith('https://')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_default_redirect(mock_request, mock_user_auth):
|
||||
"""Test verified_email redirects to /settings/user by default."""
|
||||
# Arrange
|
||||
mock_request.query_params.get.return_value = None
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert result.headers['location'] == 'http://localhost:8000/settings/user'
|
||||
mock_user_auth.refresh.assert_called_once()
|
||||
mock_set_cookie.assert_called_once()
|
||||
assert mock_user_auth.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_https_scheme(mock_request, mock_user_auth):
|
||||
"""Test verified_email uses https scheme for non-localhost hosts."""
|
||||
# Arrange
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_request.query_params.get.return_value = None
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.headers['location'].startswith('https://')
|
||||
mock_set_cookie.assert_called_once()
|
||||
# Verify secure flag is True for https
|
||||
call_kwargs = mock_set_cookie.call_args.kwargs
|
||||
assert call_kwargs['secure'] is True
|
||||
610
enterprise/tests/unit/server/routes/test_oauth_device.py
Normal file
610
enterprise/tests/unit/server/routes/test_oauth_device.py
Normal file
@@ -0,0 +1,610 @@
|
||||
"""Unit tests for OAuth2 Device Flow endpoints."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from server.routes.oauth_device import (
|
||||
device_authorization,
|
||||
device_token,
|
||||
device_verification_authenticated,
|
||||
)
|
||||
from storage.device_code import DeviceCode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_device_code_store():
|
||||
"""Mock device code store."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_key_store():
|
||||
"""Mock API key store."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_manager():
|
||||
"""Mock token manager."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Mock FastAPI request."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.base_url = 'https://test.example.com/'
|
||||
return request
|
||||
|
||||
|
||||
class TestDeviceAuthorization:
|
||||
"""Test device authorization endpoint."""
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_authorization_success(self, mock_store, mock_request):
|
||||
"""Test successful device authorization."""
|
||||
mock_device = DeviceCode(
|
||||
device_code='test-device-code-123',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=5, # Default interval
|
||||
)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
assert result.device_code == 'test-device-code-123'
|
||||
assert result.user_code == 'ABC12345'
|
||||
assert result.expires_in == 600
|
||||
assert result.interval == 5 # Should match device's current_interval
|
||||
assert 'verify' in result.verification_uri
|
||||
assert 'ABC12345' in result.verification_uri_complete
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_authorization_with_increased_interval(
|
||||
self, mock_store, mock_request
|
||||
):
|
||||
"""Test device authorization returns increased interval from rate limiting."""
|
||||
mock_device = DeviceCode(
|
||||
device_code='test-device-code-456',
|
||||
user_code='XYZ98765',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=15, # Increased interval from previous rate limiting
|
||||
)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
assert result.device_code == 'test-device-code-456'
|
||||
assert result.user_code == 'XYZ98765'
|
||||
assert result.expires_in == 600
|
||||
assert result.interval == 15 # Should match device's increased current_interval
|
||||
assert 'verify' in result.verification_uri
|
||||
assert 'XYZ98765' in result.verification_uri_complete
|
||||
|
||||
|
||||
class TestDeviceToken:
|
||||
"""Test device token endpoint."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'device_exists,status,expected_error',
|
||||
[
|
||||
(False, None, 'invalid_grant'),
|
||||
(True, 'expired', 'expired_token'),
|
||||
(True, 'denied', 'access_denied'),
|
||||
(True, 'pending', 'authorization_pending'),
|
||||
],
|
||||
)
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_token_error_cases(
|
||||
self, mock_store, device_exists, status, expected_error
|
||||
):
|
||||
"""Test various error cases for device token endpoint."""
|
||||
device_code = 'test-device-code'
|
||||
|
||||
if device_exists:
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_expired.return_value = status == 'expired'
|
||||
mock_device.status = status
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
else:
|
||||
mock_store.get_by_device_code.return_value = None
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
# Check error in response content
|
||||
content = result.body.decode()
|
||||
assert expected_error in content
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_device_token_success(self, mock_store, mock_api_key_class):
|
||||
"""Test successful device token retrieval."""
|
||||
device_code = 'test-device-code'
|
||||
|
||||
# Mock authorized device
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_expired.return_value = False
|
||||
mock_device.status = 'authorized'
|
||||
mock_device.keycloak_user_id = 'user-123'
|
||||
mock_device.user_code = (
|
||||
'ABC12345' # Add user_code for device-specific API key lookup
|
||||
)
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
# Mock API key retrieval
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key'
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Check that result is a DeviceTokenResponse
|
||||
assert result.access_token == 'test-api-key'
|
||||
assert result.token_type == 'Bearer'
|
||||
|
||||
# Verify that the correct device-specific API key name was used
|
||||
mock_api_key_store.retrieve_api_key_by_name.assert_called_once_with(
|
||||
'user-123', 'Device Link Access Key (ABC12345)'
|
||||
)
|
||||
|
||||
|
||||
class TestDeviceVerificationAuthenticated:
|
||||
"""Test device verification authenticated endpoint."""
|
||||
|
||||
async def test_verification_unauthenticated_user(self):
|
||||
"""Test verification with unauthenticated user."""
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(user_code='ABC12345', user_id=None)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_invalid_device_code(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test verification with invalid device code."""
|
||||
mock_store.get_by_user_code.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
user_code='INVALID', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_already_processed(self, mock_store, mock_api_key_class):
|
||||
"""Test verification with already processed device code."""
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = False
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_verification_success(self, mock_store, mock_api_key_class):
|
||||
"""Test successful device verification."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 200
|
||||
# Should NOT delete existing API keys (multiple devices allowed)
|
||||
mock_api_key_store.delete_api_key_by_name.assert_not_called()
|
||||
# Should create a new API key with device-specific name
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
call_args = mock_api_key_store.create_api_key.call_args
|
||||
assert call_args[1]['name'] == 'Device Link Access Key (ABC12345)'
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_multiple_device_authentication(self, mock_store, mock_api_key_class):
|
||||
"""Test that multiple devices can authenticate simultaneously."""
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Simulate two different devices
|
||||
device1_code = 'ABC12345'
|
||||
device2_code = 'XYZ67890'
|
||||
user_id = 'user-123'
|
||||
|
||||
# Mock device codes
|
||||
mock_device1 = MagicMock()
|
||||
mock_device1.is_pending.return_value = True
|
||||
mock_device2 = MagicMock()
|
||||
mock_device2.is_pending.return_value = True
|
||||
|
||||
# Configure mock store to return appropriate device for each user_code
|
||||
def get_by_user_code_side_effect(user_code):
|
||||
if user_code == device1_code:
|
||||
return mock_device1
|
||||
elif user_code == device2_code:
|
||||
return mock_device2
|
||||
return None
|
||||
|
||||
mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Authenticate first device
|
||||
result1 = await device_verification_authenticated(
|
||||
user_code=device1_code, user_id=user_id
|
||||
)
|
||||
|
||||
# Authenticate second device
|
||||
result2 = await device_verification_authenticated(
|
||||
user_code=device2_code, user_id=user_id
|
||||
)
|
||||
|
||||
# Both should succeed
|
||||
assert isinstance(result1, JSONResponse)
|
||||
assert result1.status_code == 200
|
||||
assert isinstance(result2, JSONResponse)
|
||||
assert result2.status_code == 200
|
||||
|
||||
# Should create two separate API keys with different names
|
||||
assert mock_api_key_store.create_api_key.call_count == 2
|
||||
|
||||
# Check that each device got a unique API key name
|
||||
call_args_list = mock_api_key_store.create_api_key.call_args_list
|
||||
device1_name = call_args_list[0][1]['name']
|
||||
device2_name = call_args_list[1][1]['name']
|
||||
|
||||
assert device1_name == f'Device Link Access Key ({device1_code})'
|
||||
assert device2_name == f'Device Link Access Key ({device2_code})'
|
||||
assert device1_name != device2_name # Ensure they're different
|
||||
|
||||
# Should NOT delete any existing API keys
|
||||
mock_api_key_store.delete_api_key_by_name.assert_not_called()
|
||||
|
||||
|
||||
class TestDeviceTokenRateLimiting:
|
||||
"""Test rate limiting for device token polling (RFC 8628 section 3.5)."""
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_first_poll_allowed(self, mock_store):
|
||||
"""Test that the first poll is always allowed."""
|
||||
# Create a device code with no previous poll time
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=None, # First poll
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return authorization_pending, not slow_down
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'authorization_pending' in content
|
||||
assert 'slow_down' not in content
|
||||
|
||||
# Should update poll time without increasing interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=False
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_normal_polling_allowed(self, mock_store):
|
||||
"""Test that normal polling (respecting interval) is allowed."""
|
||||
# Create a device code with last poll time 6 seconds ago (interval is 5)
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=6)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return authorization_pending, not slow_down
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'authorization_pending' in content
|
||||
assert 'slow_down' not in content
|
||||
|
||||
# Should update poll time without increasing interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=False
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_fast_polling_returns_slow_down(self, mock_store):
|
||||
"""Test that polling too fast returns slow_down error."""
|
||||
# Create a device code with last poll time 2 seconds ago (interval is 5)
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=2)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert 'interval' in content
|
||||
assert '10' in content # New interval should be 5 + 5 = 10
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_interval_increases_with_repeated_fast_polling(self, mock_store):
|
||||
"""Test that interval increases with repeated fast polling."""
|
||||
# Create a device code with higher current interval from previous slow_down
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=5) # 5 seconds ago
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=15, # Already increased from previous slow_down
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error with increased interval
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert '20' in content # New interval should be 15 + 5 = 20
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_interval_caps_at_maximum(self, mock_store):
|
||||
"""Test that interval is capped at maximum value."""
|
||||
# Create a device code with interval near maximum
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=30)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='pending',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=58, # Near maximum of 60
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should return slow_down error with capped interval
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
assert '60' in content # Should be capped at 60, not 63
|
||||
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_rate_limiting_with_authorized_device(self, mock_store):
|
||||
"""Test that rate limiting still applies to authorized devices."""
|
||||
# Create an authorized device code with recent poll
|
||||
last_poll = datetime.now(UTC) - timedelta(seconds=2)
|
||||
mock_device = DeviceCode(
|
||||
device_code='test_device_code',
|
||||
user_code='ABC123',
|
||||
status='authorized', # Device is authorized
|
||||
keycloak_user_id='user123',
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
# Should still return slow_down error even for authorized device
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 400
|
||||
content = result.body.decode()
|
||||
assert 'slow_down' in content
|
||||
|
||||
# Should update poll time and increase interval
|
||||
mock_store.update_poll_time.assert_called_with(
|
||||
'test_device_code', increase_interval=True
|
||||
)
|
||||
|
||||
|
||||
class TestDeviceVerificationTransactionIntegrity:
|
||||
"""Test transaction integrity for device verification to prevent orphaned API keys."""
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_authorization_failure_prevents_api_key_creation(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that if device authorization fails, no API key is created."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = False # Authorization fails
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should raise HTTPException due to authorization failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to authorize the device' in exc_info.value.detail
|
||||
|
||||
# API key should NOT be created since authorization failed
|
||||
mock_api_key_store.create_api_key.assert_not_called()
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_api_key_creation_failure_reverts_authorization(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that if API key creation fails after authorization, the authorization is reverted."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.return_value = True # Cleanup succeeds
|
||||
|
||||
# Mock API key store to fail on creation
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should raise HTTPException due to API key creation failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to create API key for device access' in exc_info.value.detail
|
||||
|
||||
# Authorization should have been attempted first
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
# API key creation should have been attempted after authorization
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
|
||||
# Authorization should be reverted due to API key creation failure
|
||||
mock_store.deny_device_code.assert_called_once_with('ABC12345')
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_api_key_creation_failure_cleanup_failure_logged(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that cleanup failure is logged but doesn't prevent the main error from being raised."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.side_effect = Exception(
|
||||
'Cleanup failed'
|
||||
) # Cleanup fails
|
||||
|
||||
# Mock API key store to fail on creation
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should still raise HTTPException for the original API key creation failure
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to create API key for device access' in exc_info.value.detail
|
||||
|
||||
# Both operations should have been attempted
|
||||
mock_store.authorize_device_code.assert_called_once()
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
mock_store.deny_device_code.assert_called_once_with('ABC12345')
|
||||
|
||||
@patch('server.routes.oauth_device.ApiKeyStore')
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_successful_flow_creates_api_key_after_authorization(
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test that in the successful flow, API key is created only after authorization."""
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
|
||||
# Mock API key store
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_verification_authenticated(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == 200
|
||||
|
||||
# Verify the order: authorization first, then API key creation
|
||||
mock_store.authorize_device_code.assert_called_once_with(
|
||||
user_code='ABC12345', user_id='user-123'
|
||||
)
|
||||
mock_api_key_store.create_api_key.assert_called_once()
|
||||
|
||||
# No cleanup should be needed in successful case
|
||||
mock_store.deny_device_code.assert_not_called()
|
||||
83
enterprise/tests/unit/storage/test_device_code.py
Normal file
83
enterprise/tests/unit/storage/test_device_code.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Unit tests for DeviceCode model."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from storage.device_code import DeviceCode, DeviceCodeStatus
|
||||
|
||||
|
||||
class TestDeviceCode:
|
||||
"""Test cases for DeviceCode model."""
|
||||
|
||||
@pytest.fixture
|
||||
def device_code(self):
|
||||
"""Create a test device code."""
|
||||
return DeviceCode(
|
||||
device_code='test-device-code-123',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(minutes=10),
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'expires_delta,expected',
|
||||
[
|
||||
(timedelta(minutes=5), False), # Future expiry
|
||||
(timedelta(minutes=-5), True), # Past expiry
|
||||
(timedelta(seconds=1), False), # Just future (not expired)
|
||||
],
|
||||
)
|
||||
def test_is_expired(self, expires_delta, expected):
|
||||
"""Test expiration check with various time deltas."""
|
||||
device_code = DeviceCode(
|
||||
device_code='test-device-code',
|
||||
user_code='ABC12345',
|
||||
expires_at=datetime.now(timezone.utc) + expires_delta,
|
||||
)
|
||||
assert device_code.is_expired() == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'status,expired,expected',
|
||||
[
|
||||
(DeviceCodeStatus.PENDING.value, False, True),
|
||||
(DeviceCodeStatus.PENDING.value, True, False),
|
||||
(DeviceCodeStatus.AUTHORIZED.value, False, False),
|
||||
(DeviceCodeStatus.DENIED.value, False, False),
|
||||
],
|
||||
)
|
||||
def test_is_pending(self, status, expired, expected):
|
||||
"""Test pending status check."""
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
if expired
|
||||
else datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
)
|
||||
device_code = DeviceCode(
|
||||
device_code='test-device-code',
|
||||
user_code='ABC12345',
|
||||
status=status,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
assert device_code.is_pending() == expected
|
||||
|
||||
def test_authorize(self, device_code):
|
||||
"""Test device authorization."""
|
||||
user_id = 'test-user-123'
|
||||
|
||||
device_code.authorize(user_id)
|
||||
|
||||
assert device_code.status == DeviceCodeStatus.AUTHORIZED.value
|
||||
assert device_code.keycloak_user_id == user_id
|
||||
assert device_code.authorized_at is not None
|
||||
assert isinstance(device_code.authorized_at, datetime)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'method,expected_status',
|
||||
[
|
||||
('deny', DeviceCodeStatus.DENIED.value),
|
||||
('expire', DeviceCodeStatus.EXPIRED.value),
|
||||
],
|
||||
)
|
||||
def test_status_changes(self, device_code, method, expected_status):
|
||||
"""Test status change methods."""
|
||||
getattr(device_code, method)()
|
||||
assert device_code.status == expected_status
|
||||
193
enterprise/tests/unit/storage/test_device_code_store.py
Normal file
193
enterprise/tests/unit/storage/test_device_code_store.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Unit tests for DeviceCodeStore."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.device_code import DeviceCode
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Mock session maker."""
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device_code_store(mock_session_maker):
|
||||
"""Create DeviceCodeStore instance."""
|
||||
return DeviceCodeStore(mock_session_maker)
|
||||
|
||||
|
||||
class TestDeviceCodeStore:
|
||||
"""Test cases for DeviceCodeStore."""
|
||||
|
||||
def test_generate_user_code(self, device_code_store):
|
||||
"""Test user code generation."""
|
||||
code = device_code_store.generate_user_code()
|
||||
|
||||
assert len(code) == 8
|
||||
assert code.isupper()
|
||||
# Should not contain confusing characters
|
||||
assert not any(char in code for char in 'IO01')
|
||||
|
||||
def test_generate_device_code(self, device_code_store):
|
||||
"""Test device code generation."""
|
||||
code = device_code_store.generate_device_code()
|
||||
|
||||
assert len(code) == 128
|
||||
assert code.isalnum()
|
||||
|
||||
def test_create_device_code_success(self, device_code_store, mock_session):
|
||||
"""Test successful device code creation."""
|
||||
# Mock successful creation (no IntegrityError)
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-123'
|
||||
mock_device_code.user_code = 'TESTCODE'
|
||||
|
||||
# Mock the session to return our mock device code after refresh
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
result = device_code_store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_create_device_code_with_retries(
|
||||
self, device_code_store, mock_session_maker
|
||||
):
|
||||
"""Test device code creation with constraint violation retries."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# First attempt fails with IntegrityError, second succeeds
|
||||
mock_session.commit.side_effect = [IntegrityError('', '', ''), None]
|
||||
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-456'
|
||||
mock_device_code.user_code = 'TESTCD2'
|
||||
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
result = store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
assert mock_session.add.call_count == 2 # Two attempts
|
||||
assert mock_session.commit.call_count == 2 # Two attempts
|
||||
|
||||
def test_create_device_code_max_attempts_exceeded(
|
||||
self, device_code_store, mock_session_maker
|
||||
):
|
||||
"""Test device code creation failure after max attempts."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# All attempts fail with IntegrityError
|
||||
mock_session.commit.side_effect = IntegrityError('', '', '')
|
||||
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match='Failed to generate unique device codes after 3 attempts',
|
||||
):
|
||||
store.create_device_code(expires_in=600, max_attempts=3)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'lookup_method,lookup_field',
|
||||
[
|
||||
('get_by_device_code', 'device_code'),
|
||||
('get_by_user_code', 'user_code'),
|
||||
],
|
||||
)
|
||||
def test_lookup_methods(
|
||||
self, device_code_store, mock_session, lookup_method, lookup_field
|
||||
):
|
||||
"""Test device code lookup methods."""
|
||||
test_code = 'test-code-123'
|
||||
mock_device_code = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device_code
|
||||
)
|
||||
|
||||
result = getattr(device_code_store, lookup_method)(test_code)
|
||||
|
||||
assert result == mock_device_code
|
||||
mock_session.query.assert_called_once_with(DeviceCode)
|
||||
mock_session.query.return_value.filter_by.assert_called_once_with(
|
||||
**{lookup_field: test_code}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'device_exists,is_pending,expected_result',
|
||||
[
|
||||
(True, True, True), # Success case
|
||||
(False, True, False), # Device not found
|
||||
(True, False, False), # Device not pending
|
||||
],
|
||||
)
|
||||
def test_authorize_device_code(
|
||||
self,
|
||||
device_code_store,
|
||||
mock_session,
|
||||
device_exists,
|
||||
is_pending,
|
||||
expected_result,
|
||||
):
|
||||
"""Test device code authorization."""
|
||||
user_code = 'ABC12345'
|
||||
user_id = 'test-user-123'
|
||||
|
||||
if device_exists:
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = is_pending
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device
|
||||
else:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
result = device_code_store.authorize_device_code(user_code, user_id)
|
||||
|
||||
assert result == expected_result
|
||||
if expected_result:
|
||||
mock_device.authorize.assert_called_once_with(user_id)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_deny_device_code(self, device_code_store, mock_session):
|
||||
"""Test device code denial."""
|
||||
user_code = 'ABC12345'
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device
|
||||
)
|
||||
|
||||
result = device_code_store.deny_device_code(user_code)
|
||||
|
||||
assert result is True
|
||||
mock_device.deny.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
@@ -25,10 +25,12 @@ def api_key_store(mock_session_maker):
|
||||
|
||||
|
||||
def test_generate_api_key(api_key_store):
|
||||
"""Test that generate_api_key returns a string of the expected length."""
|
||||
"""Test that generate_api_key returns a string with sk-oh- prefix and expected length."""
|
||||
key = api_key_store.generate_api_key(length=32)
|
||||
assert isinstance(key, str)
|
||||
assert len(key) == 32
|
||||
assert key.startswith('sk-oh-')
|
||||
# Total length should be prefix (6 chars) + random part (32 chars) = 38 chars
|
||||
assert len(key) == len('sk-oh-') + 32
|
||||
|
||||
|
||||
def test_create_api_key(api_key_store, mock_session):
|
||||
@@ -90,6 +92,50 @@ def test_validate_api_key_expired(api_key_store, mock_session):
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_api_key_expired_timezone_naive(api_key_store, mock_session):
|
||||
"""Test validating an expired API key with timezone-naive datetime from database."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
mock_key_record = MagicMock()
|
||||
# Simulate timezone-naive datetime as returned from database
|
||||
mock_key_record.expires_at = datetime.now() - timedelta(days=1) # No UTC timezone
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_api_key_valid_timezone_naive(api_key_store, mock_session):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
user_id = 'test-user-123'
|
||||
mock_key_record = MagicMock()
|
||||
mock_key_record.user_id = user_id
|
||||
# Simulate timezone-naive datetime as returned from database (future date)
|
||||
mock_key_record.expires_at = datetime.now() + timedelta(days=1) # No UTC timezone
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_api_key_not_found(api_key_store, mock_session):
|
||||
"""Test validating a non-existent API key."""
|
||||
# Setup
|
||||
|
||||
@@ -136,6 +136,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -184,6 +185,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -214,6 +216,82 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
mock_posthog.set.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
"""Test keycloak_callback when email is not verified."""
|
||||
# Arrange
|
||||
mock_verify_email = AsyncMock()
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': False,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'email_verification_required=true' in result.headers['location']
|
||||
mock_verify_email.assert_called_once_with(
|
||||
request=mock_request, user_id='test_user_id', is_auth_flow=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
"""Test keycloak_callback when email_verified field is missing (defaults to False)."""
|
||||
# Arrange
|
||||
mock_verify_email = AsyncMock()
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
# email_verified field is missing
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'email_verification_required=true' in result.headers['location']
|
||||
mock_verify_email.assert_called_once_with(
|
||||
request=mock_request, user_id='test_user_id', is_auth_flow=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
"""Test successful keycloak_callback without valid offline token."""
|
||||
@@ -248,6 +326,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
@@ -442,3 +521,418 @@ async def test_logout_without_refresh_token():
|
||||
|
||||
mock_token_manager.logout.assert_not_called()
|
||||
assert 'set-cookie' in result.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
"""Test keycloak_callback when email domain is blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert 'error' in result.body.decode()
|
||||
assert 'email domain is not allowed' in result.body.decode()
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
mock_token_manager.disable_keycloak_user.assert_called_once_with(
|
||||
'test_user_id', 'user@colsch.us'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
"""Test keycloak_callback when email domain is not blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
"""Test keycloak_callback when domain blocking is not active."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'user@colsch.us',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_missing_email(mock_request):
|
||||
"""Test keycloak_callback when user info does not contain email."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
# No email field
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
"""Test keycloak_callback when duplicate email is detected."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
"""Test keycloak_callback when duplicate is detected but deletion fails."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert 'duplicated_email=true' in result.headers['location']
|
||||
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
"""Test keycloak_callback when duplicate check raises exception."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(
|
||||
side_effect=Exception('Check failed')
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should proceed with normal flow despite exception (fail open)
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
"""Test keycloak_callback when no duplicate email is found."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
'email': 'joe+test@example.com',
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
|
||||
'joe+test@example.com', 'test_user_id'
|
||||
)
|
||||
# Should not delete user when no duplicate found
|
||||
mock_token_manager.delete_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
"""Test keycloak_callback when email is not in user_info."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value={
|
||||
'sub': 'test_user_id',
|
||||
'preferred_username': 'test_user',
|
||||
# No email field
|
||||
'identity_provider': 'github',
|
||||
'email_verified': True,
|
||||
}
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
# Should not check for duplicate when email is missing
|
||||
mock_token_manager.check_duplicate_base_email.assert_not_called()
|
||||
|
||||
181
enterprise/tests/unit/test_domain_blocker.py
Normal file
181
enterprise/tests/unit/test_domain_blocker.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Unit tests for DomainBlocker class."""
|
||||
|
||||
import pytest
|
||||
from server.auth.domain_blocker import DomainBlocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def domain_blocker():
|
||||
"""Create a DomainBlocker instance for testing."""
|
||||
return DomainBlocker()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'blocked_domains,expected',
|
||||
[
|
||||
(['colsch.us', 'other-domain.com'], True),
|
||||
(['example.com'], True),
|
||||
([], False),
|
||||
],
|
||||
)
|
||||
def test_is_active(domain_blocker, blocked_domains, expected):
|
||||
"""Test that is_active returns correct value based on blocked domains configuration."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = blocked_domains
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_active()
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected_domain',
|
||||
[
|
||||
('user@example.com', 'example.com'),
|
||||
('test@colsch.us', 'colsch.us'),
|
||||
('user.name@other-domain.com', 'other-domain.com'),
|
||||
('USER@EXAMPLE.COM', 'example.com'), # Case insensitive
|
||||
('user@EXAMPLE.COM', 'example.com'),
|
||||
(' user@example.com ', 'example.com'), # Whitespace handling
|
||||
],
|
||||
)
|
||||
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
|
||||
"""Test that _extract_domain correctly extracts and normalizes domains from valid emails."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected_domain
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'email,expected',
|
||||
[
|
||||
(None, None),
|
||||
('', None),
|
||||
('invalid-email', None),
|
||||
('user@', None), # Empty domain after @
|
||||
('no-at-sign', None),
|
||||
],
|
||||
)
|
||||
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
|
||||
"""Test that _extract_domain returns None for invalid email formats."""
|
||||
# Act
|
||||
result = domain_blocker._extract_domain(email)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_is_domain_blocked_when_inactive(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when blocking is not active."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = []
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_none_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email is None."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(None)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_empty_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email is empty."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_invalid_email(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when email format is invalid."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('invalid-email')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns True when domain is in blocked list."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_case_insensitive(domain_blocker):
|
||||
"""Test that is_domain_blocked performs case-insensitive domain matching."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
|
||||
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
|
||||
|
||||
# Act
|
||||
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
|
||||
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
|
||||
|
||||
# Assert
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_whitespace(domain_blocker):
|
||||
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
294
enterprise/tests/unit/test_email_validation.py
Normal file
294
enterprise/tests/unit/test_email_validation.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Tests for email validation utilities."""
|
||||
|
||||
import re
|
||||
|
||||
from server.auth.email_validation import (
|
||||
extract_base_email,
|
||||
get_base_email_regex_pattern,
|
||||
has_plus_modifier,
|
||||
matches_base_email,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractBaseEmail:
|
||||
"""Test cases for extract_base_email function."""
|
||||
|
||||
def test_extract_base_email_with_plus_modifier(self):
|
||||
"""Test extracting base email from email with + modifier."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_without_plus_modifier(self):
|
||||
"""Test that email without + modifier is returned as-is."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_multiple_plus_signs(self):
|
||||
"""Test extracting base email when multiple + signs exist."""
|
||||
# Arrange
|
||||
email = 'joe+openhands+test@example.com'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result == 'joe@example.com'
|
||||
|
||||
def test_extract_base_email_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns None."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_empty_string(self):
|
||||
"""Test that empty string returns None."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_extract_base_email_none(self):
|
||||
"""Test that None input returns None."""
|
||||
# Arrange
|
||||
email = None
|
||||
|
||||
# Act
|
||||
result = extract_base_email(email)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestHasPlusModifier:
|
||||
"""Test cases for has_plus_modifier function."""
|
||||
|
||||
def test_has_plus_modifier_true(self):
|
||||
"""Test detecting + modifier in email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_has_plus_modifier_false(self):
|
||||
"""Test that email without + modifier returns False."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_invalid_no_at_symbol(self):
|
||||
"""Test that invalid email without @ returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_has_plus_modifier_empty_string(self):
|
||||
"""Test that empty string returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
|
||||
# Act
|
||||
result = has_plus_modifier(email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMatchesBaseEmail:
|
||||
"""Test cases for matches_base_email function."""
|
||||
|
||||
def test_matches_base_email_exact_match(self):
|
||||
"""Test that exact base email matches."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_with_plus_variant(self):
|
||||
"""Test that email with + variant matches base email."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_different_base(self):
|
||||
"""Test that different base emails do not match."""
|
||||
# Arrange
|
||||
email = 'jane@example.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_different_domain(self):
|
||||
"""Test that same local part but different domain does not match."""
|
||||
# Arrange
|
||||
email = 'joe@other.com'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_matches_base_email_case_insensitive(self):
|
||||
"""Test that matching is case-insensitive."""
|
||||
# Arrange
|
||||
email = 'JOE+TEST@EXAMPLE.COM'
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_matches_base_email_empty_strings(self):
|
||||
"""Test that empty strings return False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
result = matches_base_email(email, base_email)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetBaseEmailRegexPattern:
|
||||
"""Test cases for get_base_email_regex_pattern function."""
|
||||
|
||||
def test_get_base_email_regex_pattern_valid(self):
|
||||
"""Test generating valid regex pattern for base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is not None
|
||||
assert isinstance(pattern, re.Pattern)
|
||||
assert pattern.match('joe@example.com') is not None
|
||||
assert pattern.match('joe+test@example.com') is not None
|
||||
assert pattern.match('joe+openhands@example.com') is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_matches_plus_variant(self):
|
||||
"""Test that regex pattern matches + variant."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe+test@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_base(self):
|
||||
"""Test that regex pattern rejects different base email."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('jane@example.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_rejects_different_domain(self):
|
||||
"""Test that regex pattern rejects different domain."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('joe@other.com')
|
||||
|
||||
# Assert
|
||||
assert match is None
|
||||
|
||||
def test_get_base_email_regex_pattern_case_insensitive(self):
|
||||
"""Test that regex pattern is case-insensitive."""
|
||||
# Arrange
|
||||
base_email = 'joe@example.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('JOE+TEST@EXAMPLE.COM')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_special_characters(self):
|
||||
"""Test that regex pattern handles special characters in email."""
|
||||
# Arrange
|
||||
base_email = 'user.name+tag@example-site.com'
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Act
|
||||
match = pattern.match('user.name+test@example-site.com')
|
||||
|
||||
# Assert
|
||||
assert match is not None
|
||||
|
||||
def test_get_base_email_regex_pattern_invalid_base_email(self):
|
||||
"""Test that invalid base email returns None."""
|
||||
# Arrange
|
||||
base_email = 'invalid-email'
|
||||
|
||||
# Act
|
||||
pattern = get_base_email_regex_pattern(base_email)
|
||||
|
||||
# Assert
|
||||
assert pattern is None
|
||||
@@ -1,485 +0,0 @@
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.legacy_conversation_manager import (
|
||||
_LEGACY_ENTRY_TIMEOUT_SECONDS,
|
||||
LegacyCacheEntry,
|
||||
LegacyConversationManager,
|
||||
)
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sio():
|
||||
"""Create a mock SocketIO server."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock OpenHands config."""
|
||||
return MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server_config():
|
||||
"""Create a mock server config."""
|
||||
return MagicMock(spec=ServerConfig)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_store():
|
||||
"""Create a mock file store."""
|
||||
return MagicMock(spec=InMemoryFileStore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_monitoring_listener():
|
||||
"""Create a mock monitoring listener."""
|
||||
return MagicMock(spec=MonitoringListener)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_manager():
|
||||
"""Create a mock SaasNestedConversationManager."""
|
||||
mock_cm = MagicMock()
|
||||
mock_cm._get_runtime = AsyncMock()
|
||||
return mock_cm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_legacy_conversation_manager():
|
||||
"""Create a mock ClusteredConversationManager."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def legacy_manager(
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_server_config,
|
||||
mock_file_store,
|
||||
mock_conversation_manager,
|
||||
mock_legacy_conversation_manager,
|
||||
):
|
||||
"""Create a LegacyConversationManager instance for testing."""
|
||||
return LegacyConversationManager(
|
||||
sio=mock_sio,
|
||||
config=mock_config,
|
||||
server_config=mock_server_config,
|
||||
file_store=mock_file_store,
|
||||
conversation_manager=mock_conversation_manager,
|
||||
legacy_conversation_manager=mock_legacy_conversation_manager,
|
||||
)
|
||||
|
||||
|
||||
class TestLegacyCacheEntry:
|
||||
"""Test the LegacyCacheEntry dataclass."""
|
||||
|
||||
def test_cache_entry_creation(self):
|
||||
"""Test creating a cache entry."""
|
||||
timestamp = time.time()
|
||||
entry = LegacyCacheEntry(is_legacy=True, timestamp=timestamp)
|
||||
|
||||
assert entry.is_legacy is True
|
||||
assert entry.timestamp == timestamp
|
||||
|
||||
def test_cache_entry_false(self):
|
||||
"""Test creating a cache entry with False value."""
|
||||
timestamp = time.time()
|
||||
entry = LegacyCacheEntry(is_legacy=False, timestamp=timestamp)
|
||||
|
||||
assert entry.is_legacy is False
|
||||
assert entry.timestamp == timestamp
|
||||
|
||||
|
||||
class TestLegacyConversationManagerCacheCleanup:
|
||||
"""Test cache cleanup functionality."""
|
||||
|
||||
def test_cleanup_expired_cache_entries_removes_expired(self, legacy_manager):
|
||||
"""Test that expired entries are removed from cache."""
|
||||
current_time = time.time()
|
||||
expired_time = current_time - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
valid_time = current_time - 100 # Well within timeout
|
||||
|
||||
# Add both expired and valid entries
|
||||
legacy_manager._legacy_cache = {
|
||||
'expired_conversation': LegacyCacheEntry(True, expired_time),
|
||||
'valid_conversation': LegacyCacheEntry(False, valid_time),
|
||||
'another_expired': LegacyCacheEntry(True, expired_time - 100),
|
||||
}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# Only valid entry should remain
|
||||
assert len(legacy_manager._legacy_cache) == 1
|
||||
assert 'valid_conversation' in legacy_manager._legacy_cache
|
||||
assert 'expired_conversation' not in legacy_manager._legacy_cache
|
||||
assert 'another_expired' not in legacy_manager._legacy_cache
|
||||
|
||||
def test_cleanup_expired_cache_entries_keeps_valid(self, legacy_manager):
|
||||
"""Test that valid entries are kept during cleanup."""
|
||||
current_time = time.time()
|
||||
valid_time = current_time - 100 # Well within timeout
|
||||
|
||||
legacy_manager._legacy_cache = {
|
||||
'valid_conversation_1': LegacyCacheEntry(True, valid_time),
|
||||
'valid_conversation_2': LegacyCacheEntry(False, valid_time - 50),
|
||||
}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# Both entries should remain
|
||||
assert len(legacy_manager._legacy_cache) == 2
|
||||
assert 'valid_conversation_1' in legacy_manager._legacy_cache
|
||||
assert 'valid_conversation_2' in legacy_manager._legacy_cache
|
||||
|
||||
def test_cleanup_expired_cache_entries_empty_cache(self, legacy_manager):
|
||||
"""Test cleanup with empty cache."""
|
||||
legacy_manager._legacy_cache = {}
|
||||
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
|
||||
|
||||
class TestIsLegacyRuntime:
|
||||
"""Test the is_legacy_runtime method."""
|
||||
|
||||
def test_is_legacy_runtime_none(self, legacy_manager):
|
||||
"""Test with None runtime."""
|
||||
result = legacy_manager.is_legacy_runtime(None)
|
||||
assert result is False
|
||||
|
||||
def test_is_legacy_runtime_legacy_command(self, legacy_manager):
|
||||
"""Test with legacy runtime command."""
|
||||
runtime = {'command': 'some_old_legacy_command'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_new_command(self, legacy_manager):
|
||||
"""Test with new runtime command containing openhands.server."""
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is False
|
||||
|
||||
def test_is_legacy_runtime_partial_match(self, legacy_manager):
|
||||
"""Test with command that partially matches but is still legacy."""
|
||||
runtime = {'command': 'openhands.client.start'}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_empty_command(self, legacy_manager):
|
||||
"""Test with empty command."""
|
||||
runtime = {'command': ''}
|
||||
result = legacy_manager.is_legacy_runtime(runtime)
|
||||
assert result is True
|
||||
|
||||
def test_is_legacy_runtime_missing_command_key(self, legacy_manager):
|
||||
"""Test with runtime missing command key."""
|
||||
runtime = {'other_key': 'value'}
|
||||
# This should raise a KeyError
|
||||
with pytest.raises(KeyError):
|
||||
legacy_manager.is_legacy_runtime(runtime)
|
||||
|
||||
|
||||
class TestShouldStartInLegacyMode:
|
||||
"""Test the should_start_in_legacy_mode method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_valid_entry_legacy(self, legacy_manager):
|
||||
"""Test cache hit with valid legacy entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
current_time = time.time()
|
||||
|
||||
# Add valid cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True, current_time - 100
|
||||
)
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is True
|
||||
# Should not call _get_runtime since we hit cache
|
||||
legacy_manager.conversation_manager._get_runtime.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_valid_entry_non_legacy(self, legacy_manager):
|
||||
"""Test cache hit with valid non-legacy entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
current_time = time.time()
|
||||
|
||||
# Add valid cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
False, current_time - 100
|
||||
)
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should not call _get_runtime since we hit cache
|
||||
legacy_manager.conversation_manager._get_runtime.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss_legacy_runtime(self, legacy_manager):
|
||||
"""Test cache miss with legacy runtime."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'old_command'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is True
|
||||
# Should call _get_runtime
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss_non_legacy_runtime(self, legacy_manager):
|
||||
"""Test cache miss with non-legacy runtime."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should call _get_runtime
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expired_entry(self, legacy_manager):
|
||||
"""Test with expired cache entry."""
|
||||
conversation_id = 'test_conversation'
|
||||
expired_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
# Add expired cache entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True,
|
||||
expired_time, # This should be considered expired
|
||||
)
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False # Runtime indicates non-legacy
|
||||
# Should call _get_runtime since cache is expired
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Should update cache with new result
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exactly_at_timeout(self, legacy_manager):
|
||||
"""Test with cache entry exactly at timeout boundary."""
|
||||
conversation_id = 'test_conversation'
|
||||
timeout_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS
|
||||
runtime = {'command': 'python -m openhands.server.listen'}
|
||||
|
||||
# Add cache entry exactly at timeout
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(
|
||||
True, timeout_time
|
||||
)
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
# Should treat as expired and fetch from runtime
|
||||
assert result is False
|
||||
legacy_manager.conversation_manager._get_runtime.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_returns_none(self, legacy_manager):
|
||||
"""Test when runtime returns None."""
|
||||
conversation_id = 'test_conversation'
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = None
|
||||
|
||||
result = await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
assert result is False
|
||||
# Should cache the result
|
||||
assert conversation_id in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conversation_id].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_called_on_each_invocation(self, legacy_manager):
|
||||
"""Test that cleanup is called on each invocation."""
|
||||
conversation_id = 'test_conversation'
|
||||
runtime = {'command': 'test'}
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
# Mock the cleanup method to verify it's called
|
||||
with patch.object(
|
||||
legacy_manager, '_cleanup_expired_cache_entries'
|
||||
) as mock_cleanup:
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_conversations_cached_independently(self, legacy_manager):
|
||||
"""Test that multiple conversations are cached independently."""
|
||||
conv1 = 'conversation_1'
|
||||
conv2 = 'conversation_2'
|
||||
|
||||
runtime1 = {'command': 'old_command'} # Legacy
|
||||
runtime2 = {'command': 'python -m openhands.server.listen'} # Non-legacy
|
||||
|
||||
# Mock to return different runtimes based on conversation_id
|
||||
def mock_get_runtime(conversation_id):
|
||||
if conversation_id == conv1:
|
||||
return runtime1
|
||||
return runtime2
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.side_effect = mock_get_runtime
|
||||
|
||||
result1 = await legacy_manager.should_start_in_legacy_mode(conv1)
|
||||
result2 = await legacy_manager.should_start_in_legacy_mode(conv2)
|
||||
|
||||
assert result1 is True
|
||||
assert result2 is False
|
||||
|
||||
# Both should be cached
|
||||
assert conv1 in legacy_manager._legacy_cache
|
||||
assert conv2 in legacy_manager._legacy_cache
|
||||
assert legacy_manager._legacy_cache[conv1].is_legacy is True
|
||||
assert legacy_manager._legacy_cache[conv2].is_legacy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_timestamp_updated_on_refresh(self, legacy_manager):
|
||||
"""Test that cache timestamp is updated when entry is refreshed."""
|
||||
conversation_id = 'test_conversation'
|
||||
old_time = time.time() - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
runtime = {'command': 'test'}
|
||||
|
||||
# Add expired entry
|
||||
legacy_manager._legacy_cache[conversation_id] = LegacyCacheEntry(True, old_time)
|
||||
legacy_manager.conversation_manager._get_runtime.return_value = runtime
|
||||
|
||||
# Record time before call
|
||||
before_call = time.time()
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
after_call = time.time()
|
||||
|
||||
# Timestamp should be updated
|
||||
cached_entry = legacy_manager._legacy_cache[conversation_id]
|
||||
assert cached_entry.timestamp >= before_call
|
||||
assert cached_entry.timestamp <= after_call
|
||||
|
||||
|
||||
class TestLegacyConversationManagerIntegration:
|
||||
"""Integration tests for LegacyConversationManager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_creates_proper_manager(
|
||||
self,
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_file_store,
|
||||
mock_server_config,
|
||||
mock_monitoring_listener,
|
||||
):
|
||||
"""Test that get_instance creates a properly configured manager."""
|
||||
with patch(
|
||||
'server.legacy_conversation_manager.SaasNestedConversationManager'
|
||||
) as mock_saas, patch(
|
||||
'server.legacy_conversation_manager.ClusteredConversationManager'
|
||||
) as mock_clustered:
|
||||
mock_saas.get_instance.return_value = MagicMock()
|
||||
mock_clustered.get_instance.return_value = MagicMock()
|
||||
|
||||
manager = LegacyConversationManager.get_instance(
|
||||
mock_sio,
|
||||
mock_config,
|
||||
mock_file_store,
|
||||
mock_server_config,
|
||||
mock_monitoring_listener,
|
||||
)
|
||||
|
||||
assert isinstance(manager, LegacyConversationManager)
|
||||
assert manager.sio == mock_sio
|
||||
assert manager.config == mock_config
|
||||
assert manager.file_store == mock_file_store
|
||||
assert manager.server_config == mock_server_config
|
||||
|
||||
# Verify that both nested managers are created
|
||||
mock_saas.get_instance.assert_called_once()
|
||||
mock_clustered.get_instance.assert_called_once()
|
||||
|
||||
def test_legacy_cache_initialized_empty(self, legacy_manager):
|
||||
"""Test that legacy cache is initialized as empty dict."""
|
||||
assert isinstance(legacy_manager._legacy_cache, dict)
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_runtime_raises_exception(self, legacy_manager):
|
||||
"""Test behavior when _get_runtime raises an exception."""
|
||||
conversation_id = 'test_conversation'
|
||||
|
||||
legacy_manager.conversation_manager._get_runtime.side_effect = Exception(
|
||||
'Runtime error'
|
||||
)
|
||||
|
||||
# Should propagate the exception
|
||||
with pytest.raises(Exception, match='Runtime error'):
|
||||
await legacy_manager.should_start_in_legacy_mode(conversation_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_very_large_cache(self, legacy_manager):
|
||||
"""Test behavior with a large number of cache entries."""
|
||||
current_time = time.time()
|
||||
|
||||
# Add many cache entries
|
||||
for i in range(1000):
|
||||
legacy_manager._legacy_cache[f'conversation_{i}'] = LegacyCacheEntry(
|
||||
i % 2 == 0, current_time - i
|
||||
)
|
||||
|
||||
# This should work without issues
|
||||
await legacy_manager.should_start_in_legacy_mode('new_conversation')
|
||||
|
||||
# Should have added one more entry
|
||||
assert len(legacy_manager._legacy_cache) == 1001
|
||||
|
||||
def test_cleanup_with_concurrent_modifications(self, legacy_manager):
|
||||
"""Test cleanup behavior when cache is modified during cleanup."""
|
||||
current_time = time.time()
|
||||
expired_time = current_time - _LEGACY_ENTRY_TIMEOUT_SECONDS - 1
|
||||
|
||||
# Add expired entries
|
||||
legacy_manager._legacy_cache = {
|
||||
f'conversation_{i}': LegacyCacheEntry(True, expired_time) for i in range(10)
|
||||
}
|
||||
|
||||
# This should work without raising exceptions
|
||||
legacy_manager._cleanup_expired_cache_entries()
|
||||
|
||||
# All entries should be removed
|
||||
assert len(legacy_manager._legacy_cache) == 0
|
||||
@@ -5,7 +5,12 @@ import jwt
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import BearerTokenError, CookieError, NoCredentialsError
|
||||
from server.auth.auth_error import (
|
||||
AuthError,
|
||||
BearerTokenError,
|
||||
CookieError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.saas_user_auth import (
|
||||
SaasUserAuth,
|
||||
get_api_key_from_header,
|
||||
@@ -647,3 +652,97 @@ def test_get_api_key_from_header_bearer_with_empty_token():
|
||||
# Assert that empty string from Bearer is returned (current behavior)
|
||||
# This tests the current implementation behavior
|
||||
assert api_key == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@colsch.us',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(AuthError) as exc_info:
|
||||
await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
assert 'email domain is not allowed' in str(exc_info.value)
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
assert result.email == 'user@example.com'
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
|
||||
'user@example.com'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
'exp': int(time.time()) + 3600,
|
||||
'email': 'user@colsch.us',
|
||||
'email_verified': True,
|
||||
}
|
||||
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
|
||||
|
||||
token_payload = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': 'test_refresh_token',
|
||||
}
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
|
||||
1
enterprise/tests/unit/test_sharing/__init__.py
Normal file
1
enterprise/tests/unit/test_sharing/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for sharing package."""
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Tests for public conversation models."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
|
||||
|
||||
def test_public_conversation_creation():
|
||||
"""Test that SharedConversation can be created with all required fields."""
|
||||
conversation_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository=None,
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
assert conversation.id == conversation_id
|
||||
assert conversation.title == 'Test Conversation'
|
||||
assert conversation.created_by_user_id == 'test_user'
|
||||
assert conversation.sandbox_id == 'test_sandbox'
|
||||
|
||||
|
||||
def test_public_conversation_page_creation():
|
||||
"""Test that SharedConversationPage can be created."""
|
||||
conversation_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository=None,
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
page = SharedConversationPage(
|
||||
items=[conversation],
|
||||
next_page_id='next_page',
|
||||
)
|
||||
|
||||
assert len(page.items) == 1
|
||||
assert page.items[0].id == conversation_id
|
||||
assert page.next_page_id == 'next_page'
|
||||
|
||||
|
||||
def test_public_conversation_sort_order_enum():
|
||||
"""Test that SharedConversationSortOrder enum has expected values."""
|
||||
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
|
||||
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
|
||||
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
|
||||
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
|
||||
assert hasattr(SharedConversationSortOrder, 'TITLE')
|
||||
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
|
||||
|
||||
|
||||
def test_public_conversation_optional_fields():
|
||||
"""Test that SharedConversation works with optional fields."""
|
||||
conversation_id = uuid4()
|
||||
parent_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository='owner/repo',
|
||||
parent_conversation_id=parent_id,
|
||||
llm_model='gpt-4',
|
||||
)
|
||||
|
||||
assert conversation.selected_repository == 'owner/repo'
|
||||
assert conversation.parent_conversation_id == parent_id
|
||||
assert conversation.llm_model == 'gpt-4'
|
||||
@@ -0,0 +1,430 @@
|
||||
"""Tests for SharedConversationInfoService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def shared_conversation_info_service(async_session):
|
||||
"""Create a SharedConversationInfoService for testing."""
|
||||
return SQLSharedConversationInfoService(db_session=async_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def app_conversation_service(async_session):
|
||||
"""Create an AppConversationInfoService for creating test data."""
|
||||
return SQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info():
|
||||
"""Create a sample conversation info for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
selected_repository='test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123],
|
||||
llm_model='gpt-4',
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=1.5,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=4096,
|
||||
per_turn_token=150,
|
||||
),
|
||||
),
|
||||
parent_conversation_id=None,
|
||||
sub_conversation_ids=[],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
public=True, # Make it public for testing
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_private_conversation_info():
|
||||
"""Create a sample private conversation info for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_private',
|
||||
selected_repository='test/private_repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Private Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[124],
|
||||
llm_model='gpt-4',
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=2.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=4096,
|
||||
per_turn_token=300,
|
||||
),
|
||||
),
|
||||
parent_conversation_id=None,
|
||||
sub_conversation_ids=[],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
public=False, # Make it private
|
||||
)
|
||||
|
||||
|
||||
class TestSharedConversationInfoService:
|
||||
"""Test cases for SharedConversationInfoService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_public_conversation(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns a public conversation."""
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
|
||||
# Retrieve it via public service
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
sample_conversation_info.id
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == sample_conversation_info.id
|
||||
assert result.title == sample_conversation_info.title
|
||||
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns None for private conversations."""
|
||||
# Create a private conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Try to retrieve it via public service
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
sample_private_conversation_info.id
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
|
||||
self, shared_conversation_info_service
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
|
||||
nonexistent_id = uuid4()
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
nonexistent_id
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_returns_only_public_conversations(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test that search only returns public conversations."""
|
||||
# Create both public and private conversations
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Search for all conversations
|
||||
result = (
|
||||
await shared_conversation_info_service.search_shared_conversation_info()
|
||||
)
|
||||
|
||||
# Should only return the public conversation
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].id == sample_conversation_info.id
|
||||
assert result.items[0].title == sample_conversation_info.title
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_with_title_filter(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
):
|
||||
"""Test searching with title filter."""
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
|
||||
# Search with matching title
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
title__contains='Test'
|
||||
)
|
||||
assert len(result.items) == 1
|
||||
|
||||
# Search with non-matching title
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
title__contains='NonExistent'
|
||||
)
|
||||
assert len(result.items) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_with_sort_order(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
):
|
||||
"""Test searching with different sort orders."""
|
||||
# Create multiple public conversations with different titles and timestamps
|
||||
conv1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_1',
|
||||
title='A First Conversation',
|
||||
created_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conv2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_2',
|
||||
title='B Second Conversation',
|
||||
created_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
await app_conversation_service.save_app_conversation_info(conv1)
|
||||
await app_conversation_service.save_app_conversation_info(conv2)
|
||||
|
||||
# Test sort by title ascending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.TITLE
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].title == 'A First Conversation'
|
||||
assert result.items[1].title == 'B Second Conversation'
|
||||
|
||||
# Test sort by title descending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.TITLE_DESC
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].title == 'B Second Conversation'
|
||||
assert result.items[1].title == 'A First Conversation'
|
||||
|
||||
# Test sort by created_at ascending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == conv1.id
|
||||
assert result.items[1].id == conv2.id
|
||||
|
||||
# Test sort by created_at descending (default)
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == conv2.id
|
||||
assert result.items[1].id == conv1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_shared_conversation_info(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test counting public conversations."""
|
||||
# Initially should be 0
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 0
|
||||
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 1
|
||||
|
||||
# Create a private conversation - count should remain 1
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_get_shared_conversation_info(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test batch getting public conversations."""
|
||||
# Create both public and private conversations
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Batch get both conversations
|
||||
result = (
|
||||
await shared_conversation_info_service.batch_get_shared_conversation_info(
|
||||
[sample_conversation_info.id, sample_private_conversation_info.id]
|
||||
)
|
||||
)
|
||||
|
||||
# Should return the public one and None for the private one
|
||||
assert len(result) == 2
|
||||
assert result[0] is not None
|
||||
assert result[0].id == sample_conversation_info.id
|
||||
assert result[1] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_pagination(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
):
|
||||
"""Test search with pagination."""
|
||||
# Create multiple public conversations
|
||||
conversations = []
|
||||
for i in range(5):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id=f'test_sandbox_{i}',
|
||||
title=f'Conversation {i}',
|
||||
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conversations.append(conv)
|
||||
await app_conversation_service.save_app_conversation_info(conv)
|
||||
|
||||
# Get first page with limit 2
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.next_page_id is not None
|
||||
|
||||
# Get next page
|
||||
result2 = (
|
||||
await shared_conversation_info_service.search_shared_conversation_info(
|
||||
limit=2,
|
||||
page_id=result.next_page_id,
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT,
|
||||
)
|
||||
)
|
||||
assert len(result2.items) == 2
|
||||
assert result2.next_page_id is not None
|
||||
|
||||
# Verify no overlap between pages
|
||||
page1_ids = {item.id for item in result.items}
|
||||
page2_ids = {item.id for item in result2.items}
|
||||
assert page1_ids.isdisjoint(page2_ids)
|
||||
@@ -0,0 +1,365 @@
|
||||
"""Tests for SharedEventService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.filesystem_shared_event_service import (
|
||||
SharedEventServiceImpl,
|
||||
)
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import SharedConversation
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_shared_conversation_info_service():
|
||||
"""Create a mock SharedConversationInfoService."""
|
||||
return AsyncMock(spec=SharedConversationInfoService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_service():
|
||||
"""Create a mock EventService."""
|
||||
return AsyncMock(spec=EventService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
|
||||
"""Create a SharedEventService for testing."""
|
||||
return SharedEventServiceImpl(
|
||||
shared_conversation_info_service=mock_shared_conversation_info_service,
|
||||
event_service=mock_event_service,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_public_conversation():
|
||||
"""Create a sample public conversation."""
|
||||
return SharedConversation(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Public Conversation',
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_event():
|
||||
"""Create a sample event."""
|
||||
# For testing purposes, we'll just use a mock that the EventPage can accept
|
||||
# The actual event creation is complex and not the focus of these tests
|
||||
return None
|
||||
|
||||
|
||||
class TestSharedEventService:
|
||||
"""Test cases for SharedEventService."""
|
||||
|
||||
async def test_get_shared_event_returns_event_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that get_shared_event returns an event for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
event_id = 'test_event_id'
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return an event
|
||||
mock_event_service.get_event.return_value = sample_event
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
|
||||
# Verify the result
|
||||
assert result == sample_event
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.get_event.assert_called_once_with(event_id)
|
||||
|
||||
async def test_get_shared_event_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that get_shared_event returns None for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
event_id = 'test_event_id'
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
|
||||
# Verify the result
|
||||
assert result is None
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.get_event.assert_not_called()
|
||||
|
||||
async def test_search_shared_events_returns_events_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that search_shared_events returns events for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_page = EventPage(items=[], next_page_id=None)
|
||||
mock_event_service.search_events.return_value = mock_event_page
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == mock_event_page
|
||||
assert len(result.items) == 0 # Empty list as we mocked
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.search_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
timestamp__gte=None,
|
||||
timestamp__lt=None,
|
||||
sort_order=EventSortOrder.TIMESTAMP,
|
||||
page_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
async def test_search_shared_events_returns_empty_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that search_shared_events returns empty page for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, EventPage)
|
||||
assert len(result.items) == 0
|
||||
assert result.next_page_id is None
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.search_events.assert_not_called()
|
||||
|
||||
async def test_count_shared_events_returns_count_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
):
|
||||
"""Test that count_shared_events returns count for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return a count
|
||||
mock_event_service.count_events.return_value = 5
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.count_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == 5
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.count_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
timestamp__gte=None,
|
||||
timestamp__lt=None,
|
||||
sort_order=EventSortOrder.TIMESTAMP,
|
||||
)
|
||||
|
||||
async def test_count_shared_events_returns_zero_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that count_shared_events returns 0 for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.count_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == 0
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.count_events.assert_not_called()
|
||||
|
||||
async def test_batch_get_shared_events_returns_events_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that batch_get_shared_events returns events for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
event_ids = ['event1', 'event2']
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_service.get_event.side_effect = [sample_event, None]
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.batch_get_shared_events(
|
||||
conversation_id, event_ids
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 2
|
||||
assert result[0] == sample_event
|
||||
assert result[1] is None
|
||||
|
||||
# Verify that get_shared_conversation_info was called for each event
|
||||
assert (
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
|
||||
== 2
|
||||
)
|
||||
# Verify that get_event was called for each event
|
||||
assert mock_event_service.get_event.call_count == 2
|
||||
|
||||
async def test_batch_get_shared_events_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that batch_get_shared_events returns None for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
event_ids = ['event1', 'event2']
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.batch_get_shared_events(
|
||||
conversation_id, event_ids
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 2
|
||||
assert result[0] is None
|
||||
assert result[1] is None
|
||||
|
||||
# Verify that get_shared_conversation_info was called for each event
|
||||
assert (
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
|
||||
== 2
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.get_event.assert_not_called()
|
||||
|
||||
async def test_search_shared_events_with_all_parameters(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
):
|
||||
"""Test search_shared_events with all parameters."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
|
||||
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_page = EventPage(items=[], next_page_id='next_page')
|
||||
mock_event_service.search_events.return_value = mock_event_page
|
||||
|
||||
# Call the method with all parameters
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ObservationEvent',
|
||||
timestamp__gte=timestamp_gte,
|
||||
timestamp__lt=timestamp_lt,
|
||||
sort_order=EventSortOrder.TIMESTAMP_DESC,
|
||||
page_id='current_page',
|
||||
limit=50,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == mock_event_page
|
||||
|
||||
mock_event_service.search_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ObservationEvent',
|
||||
timestamp__gte=timestamp_gte,
|
||||
timestamp__lt=timestamp_lt,
|
||||
sort_order=EventSortOrder.TIMESTAMP_DESC,
|
||||
page_id='current_page',
|
||||
limit=50,
|
||||
)
|
||||
@@ -1,6 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from keycloak.exceptions import KeycloakConnectionError, KeycloakError
|
||||
from server.auth.token_manager import TokenManager
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
@@ -32,6 +34,14 @@ def token_store(mock_session_maker, mock_config):
|
||||
return OfflineTokenStore('test_user_id', mock_session_maker, mock_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_manager():
|
||||
with patch('server.config.get_config') as mock_get_config:
|
||||
mock_config = mock_get_config.return_value
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
return TokenManager(external=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_token_new_record(token_store, mock_session):
|
||||
# Setup
|
||||
@@ -109,3 +119,419 @@ async def test_get_instance(mock_config):
|
||||
assert isinstance(result, OfflineTokenStore)
|
||||
assert result.user_id == test_user_id
|
||||
assert result.config == mock_config
|
||||
|
||||
|
||||
class TestCheckDuplicateBaseEmail:
|
||||
"""Test cases for check_duplicate_base_email method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_no_plus_modifier(self, token_manager):
|
||||
"""Test that emails without + modifier are still checked for duplicates."""
|
||||
# Arrange
|
||||
email = 'joe@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = False
|
||||
mock_query.return_value = {}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_query.assert_called_once()
|
||||
mock_find.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_empty_email(self, token_manager):
|
||||
"""Test that empty email returns False."""
|
||||
# Arrange
|
||||
email = ''
|
||||
current_user_id = 'user123'
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_invalid_email(self, token_manager):
|
||||
"""Test that invalid email returns False."""
|
||||
# Arrange
|
||||
email = 'invalid-email'
|
||||
current_user_id = 'user123'
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_duplicate_found(self, token_manager):
|
||||
"""Test that duplicate email is detected when found."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
existing_user = {
|
||||
'id': 'existing_user_id',
|
||||
'email': 'joe@example.com',
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = True
|
||||
mock_query.return_value = {'existing_user_id': existing_user}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_query.assert_called_once()
|
||||
mock_find.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_no_duplicate(self, token_manager):
|
||||
"""Test that no duplicate is found when none exists."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query,
|
||||
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
|
||||
):
|
||||
mock_find.return_value = False
|
||||
mock_query.return_value = {}
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_keycloak_connection_error(
|
||||
self, token_manager
|
||||
):
|
||||
"""Test that KeycloakConnectionError triggers retry and raises RetryError."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query:
|
||||
mock_query.side_effect = KeycloakConnectionError('Connection failed')
|
||||
|
||||
# Act & Assert
|
||||
# KeycloakConnectionError is re-raised, which triggers retry decorator
|
||||
# After retries exhaust (2 attempts), it raises RetryError
|
||||
from tenacity import RetryError
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await token_manager.check_duplicate_base_email(email, current_user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_duplicate_base_email_general_exception(self, token_manager):
|
||||
"""Test that general exceptions are handled gracefully."""
|
||||
# Arrange
|
||||
email = 'joe+test@example.com'
|
||||
current_user_id = 'user123'
|
||||
|
||||
with patch.object(
|
||||
token_manager, '_query_users_by_wildcard_pattern'
|
||||
) as mock_query:
|
||||
mock_query.side_effect = Exception('Unexpected error')
|
||||
|
||||
# Act
|
||||
result = await token_manager.check_duplicate_base_email(
|
||||
email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestQueryUsersByWildcardPattern:
|
||||
"""Test cases for _query_users_by_wildcard_pattern method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_success_with_search(
|
||||
self, token_manager
|
||||
):
|
||||
"""Test successful query using search parameter."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
mock_users = [
|
||||
{'id': 'user1', 'email': 'joe@example.com'},
|
||||
{'id': 'user2', 'email': 'joe+test@example.com'},
|
||||
]
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_users = AsyncMock(return_value=mock_users)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert 'user1' in result
|
||||
assert 'user2' in result
|
||||
mock_admin.a_get_users.assert_called_once_with(
|
||||
{'search': 'joe*@example.com'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_fallback_to_q(self, token_manager):
|
||||
"""Test fallback to q parameter when search fails."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
mock_users = [{'id': 'user1', 'email': 'joe@example.com'}]
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
# First call fails, second succeeds
|
||||
mock_admin.a_get_users = AsyncMock(
|
||||
side_effect=[Exception('Search failed'), mock_users]
|
||||
)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert 'user1' in result
|
||||
assert mock_admin.a_get_users.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_users_by_wildcard_pattern_empty_result(self, token_manager):
|
||||
"""Test query returns empty dict when no users found."""
|
||||
# Arrange
|
||||
local_part = 'joe'
|
||||
domain = 'example.com'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_users = AsyncMock(return_value=[])
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
result = await token_manager._query_users_by_wildcard_pattern(
|
||||
local_part, domain
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestFindDuplicateInUsers:
|
||||
"""Test cases for _find_duplicate_in_users method."""
|
||||
|
||||
def test_find_duplicate_in_users_with_regex_match(self, token_manager):
|
||||
"""Test finding duplicate using regex pattern."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
'user2': {'id': 'user2', 'email': 'joe+test@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user3'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_find_duplicate_in_users_fallback_to_simple_matching(self, token_manager):
|
||||
"""Test fallback to simple matching when regex pattern is None."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
}
|
||||
base_email = 'invalid-email' # Will cause regex pattern to be None
|
||||
current_user_id = 'user2'
|
||||
|
||||
with patch(
|
||||
'server.auth.token_manager.get_base_email_regex_pattern', return_value=None
|
||||
):
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should use fallback matching, but invalid base_email won't match
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_excludes_current_user(self, token_manager):
|
||||
"""Test that current user is excluded from duplicate check."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'joe@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user1' # Same as user in users dict
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_no_match(self, token_manager):
|
||||
"""Test that no duplicate is found when emails don't match."""
|
||||
# Arrange
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'email': 'jane@example.com'},
|
||||
}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user2'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_find_duplicate_in_users_empty_dict(self, token_manager):
|
||||
"""Test that empty users dict returns False."""
|
||||
# Arrange
|
||||
users: dict[str, dict] = {}
|
||||
base_email = 'joe@example.com'
|
||||
current_user_id = 'user1'
|
||||
|
||||
# Act
|
||||
result = token_manager._find_duplicate_in_users(
|
||||
users, base_email, current_user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestDeleteKeycloakUser:
|
||||
"""Test cases for delete_keycloak_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_success(self, token_manager):
|
||||
"""Test successful deletion of Keycloak user."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.return_value = None
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_to_thread.assert_called_once_with(mock_admin.delete_user, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_connection_error(self, token_manager):
|
||||
"""Test handling of KeycloakConnectionError triggers retry and raises RetryError."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = KeycloakConnectionError('Connection failed')
|
||||
|
||||
# Act & Assert
|
||||
# KeycloakConnectionError triggers retry decorator
|
||||
# After retries exhaust (2 attempts), it raises RetryError
|
||||
from tenacity import RetryError
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_keycloak_error(self, token_manager):
|
||||
"""Test handling of KeycloakError (e.g., user not found)."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = KeycloakError('User not found')
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_keycloak_user_general_exception(self, token_manager):
|
||||
"""Test handling of general exceptions."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
|
||||
with (
|
||||
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
|
||||
patch('asyncio.to_thread') as mock_to_thread,
|
||||
):
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.delete_user = MagicMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
mock_to_thread.side_effect = Exception('Unexpected error')
|
||||
|
||||
# Act
|
||||
result = await token_manager.delete_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.auth.token_manager import TokenManager, create_encryption_utility
|
||||
@@ -246,3 +246,103 @@ async def test_refresh(token_manager):
|
||||
mock_keycloak.return_value.a_refresh_token.assert_called_once_with(
|
||||
'test_refresh_token'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_success(token_manager):
|
||||
"""Test successful disabling of a Keycloak user account."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
email = 'user@colsch.us'
|
||||
mock_user = {
|
||||
'id': user_id,
|
||||
'username': 'testuser',
|
||||
'email': email,
|
||||
'emailVerified': True,
|
||||
}
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
|
||||
mock_admin.a_update_user = AsyncMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
payload={
|
||||
'enabled': False,
|
||||
'username': 'testuser',
|
||||
'email': email,
|
||||
'emailVerified': True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_without_email(token_manager):
|
||||
"""Test disabling Keycloak user without providing email."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
mock_user = {
|
||||
'id': user_id,
|
||||
'username': 'testuser',
|
||||
'email': 'user@example.com',
|
||||
'emailVerified': False,
|
||||
}
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
|
||||
mock_admin.a_update_user = AsyncMock()
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_not_found(token_manager):
|
||||
"""Test disabling Keycloak user when user is not found."""
|
||||
# Arrange
|
||||
user_id = 'nonexistent_user_id'
|
||||
email = 'user@colsch.us'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(return_value=None)
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Assert
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
mock_admin.a_update_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_keycloak_user_exception_handling(token_manager):
|
||||
"""Test that disable_keycloak_user handles exceptions gracefully without raising."""
|
||||
# Arrange
|
||||
user_id = 'test_user_id'
|
||||
email = 'user@colsch.us'
|
||||
|
||||
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
|
||||
mock_admin = MagicMock()
|
||||
mock_admin.a_get_user = AsyncMock(side_effect=Exception('Connection error'))
|
||||
mock_get_admin.return_value = mock_admin
|
||||
|
||||
# Act & Assert - should not raise exception
|
||||
await token_manager.disable_keycloak_user(user_id, email)
|
||||
|
||||
# Verify the method was called
|
||||
mock_admin.a_get_user.assert_called_once_with(user_id)
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
# Evaluation
|
||||
|
||||
> [!WARNING]
|
||||
> **This directory is deprecated.** Our new benchmarks are located at [OpenHands/benchmarks](https://github.com/OpenHands/benchmarks).
|
||||
>
|
||||
> If you have already implemented a benchmark in this directory and would like to contribute it, we are happy to have the contribution. However, if you are starting anew, please use the new location.
|
||||
|
||||
This folder contains code and resources to run experiments and evaluations.
|
||||
|
||||
## For Benchmark Users
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
"i18next/no-literal-string": "error",
|
||||
"unused-imports/no-unused-imports": "error",
|
||||
"prettier/prettier": ["error"],
|
||||
// Enforce using optional chaining (?.) instead of && chains for null/undefined checks
|
||||
"@typescript-eslint/prefer-optional-chain": "error",
|
||||
// Resolves https://stackoverflow.com/questions/59265981/typescript-eslint-missing-file-extension-ts-import-extensions/59268871#59268871
|
||||
"import/extensions": [
|
||||
"error",
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
public-hoist-pattern[]=*@nextui-org/*
|
||||
enable-pre-post-scripts=true
|
||||
146
frontend/__tests__/MSW.md
Normal file
146
frontend/__tests__/MSW.md
Normal file
@@ -0,0 +1,146 @@
|
||||
# Mock Service Worker (MSW) Guide
|
||||
|
||||
## Overview
|
||||
|
||||
[Mock Service Worker (MSW)](https://mswjs.io/) is an API mocking library that intercepts outgoing network requests at the network level. Unlike traditional mocking that patches `fetch` or `axios`, MSW uses a Service Worker in the browser and direct request interception in Node.js—making mocks transparent to your application code.
|
||||
|
||||
We use MSW in this project for:
|
||||
- **Testing**: Write reliable unit and integration tests without real network calls
|
||||
- **Development**: Run the frontend with mocked APIs when the backend isn't available or when working on features with pending backend APIs
|
||||
|
||||
The same mock handlers work in both environments, so you write them once and reuse everywhere.
|
||||
|
||||
## Relevant Files
|
||||
|
||||
- `src/mocks/handlers.ts` - Main handler registry that combines all domain handlers
|
||||
- `src/mocks/*-handlers.ts` - Domain-specific handlers (auth, billing, conversation, etc.)
|
||||
- `src/mocks/browser.ts` - Browser setup for development mode
|
||||
- `src/mocks/node.ts` - Node.js setup for tests
|
||||
- `vitest.setup.ts` - Global test setup with MSW lifecycle hooks
|
||||
|
||||
## Development Workflow
|
||||
|
||||
### Running with Mocked APIs
|
||||
|
||||
```sh
|
||||
# Run with API mocking enabled
|
||||
npm run dev:mock
|
||||
|
||||
# Run with API mocking + SaaS mode simulation
|
||||
npm run dev:mock:saas
|
||||
```
|
||||
|
||||
These commands set `VITE_MOCK_API=true` which activates the MSW Service Worker to intercept requests.
|
||||
|
||||
> [!NOTE]
|
||||
> **OSS vs SaaS Mode**
|
||||
>
|
||||
> OpenHands runs in two modes:
|
||||
> - **OSS mode**: For local/self-hosted deployments where users provide their own LLM API keys and configure git providers manually
|
||||
> - **SaaS mode**: For the cloud offering with billing, managed API keys, and OAuth-based GitHub integration
|
||||
>
|
||||
> Use `dev:mock:saas` when working on SaaS-specific features like billing, API key management, or subscription flows.
|
||||
|
||||
|
||||
## Writing Tests
|
||||
|
||||
### Service Layer Mocking (Recommended)
|
||||
|
||||
For most tests, mock at the service layer using `vi.spyOn`. This approach is explicit, test-scoped, and makes the scenario being tested clear.
|
||||
|
||||
```typescript
|
||||
import { vi } from "vitest";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
llm_model: "openai/gpt-4o",
|
||||
llm_api_key_set: true,
|
||||
// ... other settings
|
||||
});
|
||||
```
|
||||
|
||||
Use `mockResolvedValue` for success scenarios and `mockRejectedValue` for error scenarios:
|
||||
|
||||
```typescript
|
||||
getSettingsSpy.mockRejectedValue(new Error("Failed to fetch settings"));
|
||||
```
|
||||
|
||||
### Network Layer Mocking (Advanced)
|
||||
|
||||
For tests that need actual network-level behavior (WebSockets, testing retry logic, etc.), use `server.use()` to override handlers per test.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Reuse the global server instance** - Don't create new `setupServer()` calls in individual tests. The project already has a global MSW server configured in `vitest.setup.ts` that handles lifecycle (`server.listen()`, `server.resetHandlers()`, `server.close()`). Use `server.use()` to add runtime handlers for specific test scenarios.
|
||||
|
||||
```typescript
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { server } from "#/mocks/node";
|
||||
|
||||
it("should handle server errors", async () => {
|
||||
server.use(
|
||||
http.get("/api/my-endpoint", () => {
|
||||
return new HttpResponse(null, { status: 500 });
|
||||
}),
|
||||
);
|
||||
// ... test code
|
||||
});
|
||||
```
|
||||
|
||||
For WebSocket testing, see `__tests__/helpers/msw-websocket-setup.ts` for utilities.
|
||||
|
||||
## Adding New API Mocks
|
||||
|
||||
When adding new API endpoints, create mocks in both places to maintain 1:1 similarity with the backend:
|
||||
|
||||
### 1. Add to `src/mocks/` (for development)
|
||||
|
||||
Create or update a domain-specific handler file:
|
||||
|
||||
```typescript
|
||||
// src/mocks/my-feature-handlers.ts
|
||||
import { http, HttpResponse } from "msw";
|
||||
|
||||
export const MY_FEATURE_HANDLERS = [
|
||||
http.get("/api/my-feature", () => {
|
||||
return HttpResponse.json({
|
||||
data: "mock response",
|
||||
});
|
||||
}),
|
||||
];
|
||||
```
|
||||
|
||||
Register in `handlers.ts`:
|
||||
|
||||
```typescript
|
||||
import { MY_FEATURE_HANDLERS } from "./my-feature-handlers";
|
||||
|
||||
export const handlers = [
|
||||
// ... existing handlers
|
||||
...MY_FEATURE_HANDLERS,
|
||||
];
|
||||
```
|
||||
|
||||
### 2. Mock in tests for specific scenarios
|
||||
|
||||
In your test files, spy on the service method to control responses per test case:
|
||||
|
||||
```typescript
|
||||
import { vi } from "vitest";
|
||||
import MyFeatureService from "#/api/my-feature-service.api";
|
||||
|
||||
const spy = vi.spyOn(MyFeatureService, "getData");
|
||||
spy.mockResolvedValue({ data: "test-specific response" });
|
||||
```
|
||||
|
||||
See `__tests__/routes/llm-settings.test.tsx` for a real-world example of service layer mocking.
|
||||
|
||||
> [!TIP]
|
||||
> For guidance on creating service APIs, see `src/api/README.md`.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Keep mocks close to real API contracts** - Update mocks when backend changes
|
||||
- **Use service layer mocking for most tests** - It's simpler and more explicit
|
||||
- **Reserve network layer mocking for integration tests** - WebSockets, retry logic, etc.
|
||||
- **Export mock data from handler files** - Reuse in tests (e.g., `MOCK_DEFAULT_USER_SETTINGS`)
|
||||
149
frontend/__tests__/components/conversation-tab-title.test.tsx
Normal file
149
frontend/__tests__/components/conversation-tab-title.test.tsx
Normal file
@@ -0,0 +1,149 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { ConversationTabTitle } from "#/components/features/conversation/conversation-tabs/conversation-tab-title";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import V1GitService from "#/api/git-service/v1-git-service.api";
|
||||
|
||||
// Mock the services that the hook depends on
|
||||
vi.mock("#/api/git-service/git-service.api");
|
||||
vi.mock("#/api/git-service/v1-git-service.api");
|
||||
|
||||
// Mock the hooks that useUnifiedGetGitChanges depends on
|
||||
vi.mock("#/hooks/use-conversation-id", () => ({
|
||||
useConversationId: () => ({
|
||||
conversationId: "test-conversation-id",
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-active-conversation", () => ({
|
||||
useActiveConversation: () => ({
|
||||
data: {
|
||||
conversation_version: "V0",
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
selected_repository: null,
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-runtime-is-ready", () => ({
|
||||
useRuntimeIsReady: () => true,
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/get-git-path", () => ({
|
||||
getGitPath: () => "/workspace",
|
||||
}));
|
||||
|
||||
describe("ConversationTabTitle", () => {
|
||||
let queryClient: QueryClient;
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Mock GitService methods
|
||||
vi.mocked(GitService.getGitChanges).mockResolvedValue([]);
|
||||
vi.mocked(V1GitService.getGitChanges).mockResolvedValue([]);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
queryClient.clear();
|
||||
});
|
||||
|
||||
const renderWithProviders = (ui: React.ReactElement) => {
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>{ui}</QueryClientProvider>,
|
||||
);
|
||||
};
|
||||
|
||||
describe("Rendering", () => {
|
||||
it("should render the title", () => {
|
||||
// Arrange
|
||||
const title = "Test Title";
|
||||
|
||||
// Act
|
||||
renderWithProviders(
|
||||
<ConversationTabTitle title={title} conversationKey="browser" />,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText(title)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show refresh button when conversationKey is 'editor'", () => {
|
||||
// Arrange
|
||||
const title = "Changes";
|
||||
|
||||
// Act
|
||||
renderWithProviders(
|
||||
<ConversationTabTitle title={title} conversationKey="editor" />,
|
||||
);
|
||||
|
||||
// Assert
|
||||
const refreshButton = screen.getByRole("button");
|
||||
expect(refreshButton).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show refresh button when conversationKey is not 'editor'", () => {
|
||||
// Arrange
|
||||
const title = "Browser";
|
||||
|
||||
// Act
|
||||
renderWithProviders(
|
||||
<ConversationTabTitle title={title} conversationKey="browser" />,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByRole("button")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("User Interactions", () => {
|
||||
it("should call refetch and trigger GitService.getGitChanges when refresh button is clicked", async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
const title = "Changes";
|
||||
const mockGitChanges: Array<{
|
||||
path: string;
|
||||
status: "M" | "A" | "D" | "R" | "U";
|
||||
}> = [
|
||||
{ path: "file1.ts", status: "M" },
|
||||
{ path: "file2.ts", status: "A" },
|
||||
];
|
||||
|
||||
vi.mocked(GitService.getGitChanges).mockResolvedValue(mockGitChanges);
|
||||
|
||||
renderWithProviders(
|
||||
<ConversationTabTitle title={title} conversationKey="editor" />,
|
||||
);
|
||||
|
||||
const refreshButton = screen.getByRole("button");
|
||||
|
||||
// Wait for initial query to complete
|
||||
await waitFor(() => {
|
||||
expect(GitService.getGitChanges).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Clear the mock to track refetch calls
|
||||
vi.mocked(GitService.getGitChanges).mockClear();
|
||||
|
||||
// Act
|
||||
await user.click(refreshButton);
|
||||
|
||||
// Assert - refetch should trigger another service call
|
||||
await waitFor(() => {
|
||||
expect(GitService.getGitChanges).toHaveBeenCalledWith(
|
||||
"test-conversation-id",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,7 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { AuthModal } from "#/components/features/waitlist/auth-modal";
|
||||
|
||||
// Mock the useAuthUrl hook
|
||||
@@ -27,11 +28,13 @@ describe("AuthModal", () => {
|
||||
|
||||
it("should render the GitHub and GitLab buttons", () => {
|
||||
render(
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github", "gitlab"]}
|
||||
/>,
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github", "gitlab"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
@@ -49,11 +52,13 @@ describe("AuthModal", () => {
|
||||
const user = userEvent.setup();
|
||||
const mockUrl = "https://github.com/login/oauth/authorize";
|
||||
render(
|
||||
<AuthModal
|
||||
githubAuthUrl={mockUrl}
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>,
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl={mockUrl}
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
@@ -65,10 +70,14 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should render Terms of Service and Privacy Policy text with correct links", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Find the terms of service section using data-testid
|
||||
const termsSection = screen.getByTestId("auth-modal-terms-of-service");
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
|
||||
// Check that all text content is present in the paragraph
|
||||
@@ -105,8 +114,44 @@ describe("AuthModal", () => {
|
||||
expect(termsSection).toContainElement(privacyLink);
|
||||
});
|
||||
|
||||
it("should display email verified message when emailVerified prop is true", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
emailVerified={true}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display email verified message when emailVerified prop is false", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
emailVerified={false}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should open Terms of Service link in new tab", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const tosLink = screen.getByRole("link", {
|
||||
name: "COMMON$TERMS_OF_SERVICE",
|
||||
@@ -115,11 +160,58 @@ describe("AuthModal", () => {
|
||||
});
|
||||
|
||||
it("should open Privacy Policy link in new tab", () => {
|
||||
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const privacyLink = screen.getByRole("link", {
|
||||
name: "COMMON$PRIVACY_POLICY",
|
||||
});
|
||||
expect(privacyLink).toHaveAttribute("target", "_blank");
|
||||
});
|
||||
|
||||
describe("Duplicate email error message", () => {
|
||||
const renderAuthModalWithRouter = (initialEntries: string[]) => {
|
||||
const hasDuplicatedEmail = initialEntries.includes(
|
||||
"/?duplicated_email=true",
|
||||
);
|
||||
|
||||
return render(
|
||||
<MemoryRouter initialEntries={initialEntries}>
|
||||
<AuthModal
|
||||
githubAuthUrl="mock-url"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
hasDuplicatedEmail={hasDuplicatedEmail}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
};
|
||||
|
||||
it("should display error message when duplicated_email query parameter is true", () => {
|
||||
// Arrange
|
||||
const initialEntries = ["/?duplicated_email=true"];
|
||||
|
||||
// Act
|
||||
renderAuthModalWithRouter(initialEntries);
|
||||
|
||||
// Assert
|
||||
const errorMessage = screen.getByText("AUTH$DUPLICATE_EMAIL_ERROR");
|
||||
expect(errorMessage).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display error message when duplicated_email query parameter is missing", () => {
|
||||
// Arrange
|
||||
const initialEntries = ["/"];
|
||||
|
||||
// Act
|
||||
renderAuthModalWithRouter(initialEntries);
|
||||
|
||||
// Assert
|
||||
const errorMessage = screen.queryByText("AUTH$DUPLICATE_EMAIL_ERROR");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -23,6 +23,11 @@ describe("ConversationPanel", () => {
|
||||
Component: () => <ConversationPanel onClose={onCloseMock} />,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
// Add route to prevent "No routes matched location" warning
|
||||
Component: () => null,
|
||||
path: "/conversations/:conversationId",
|
||||
},
|
||||
]);
|
||||
|
||||
const renderConversationPanel = () => renderWithProviders(<RouterStub />);
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { AgentStatus } from "#/components/features/controls/agent-status";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
|
||||
vi.mock("#/hooks/use-agent-state");
|
||||
|
||||
vi.mock("#/hooks/use-conversation-id", () => ({
|
||||
useConversationId: () => ({ conversationId: "test-id" }),
|
||||
}));
|
||||
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<MemoryRouter>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
</MemoryRouter>
|
||||
);
|
||||
|
||||
const renderAgentStatus = ({
|
||||
isPausing = false,
|
||||
}: { isPausing?: boolean } = {}) =>
|
||||
render(
|
||||
<AgentStatus
|
||||
handleStop={vi.fn()}
|
||||
handleResumeAgent={vi.fn()}
|
||||
isPausing={isPausing}
|
||||
/>,
|
||||
{ wrapper },
|
||||
);
|
||||
|
||||
describe("AgentStatus - isLoading logic", () => {
|
||||
it("should show loading when curAgentState is INIT", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.INIT,
|
||||
});
|
||||
|
||||
renderAgentStatus();
|
||||
|
||||
expect(screen.getByTestId("agent-loading-spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show loading when isPausing is true, even if shouldShownAgentLoading is false", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
|
||||
renderAgentStatus({ isPausing: true });
|
||||
|
||||
expect(screen.getByTestId("agent-loading-spinner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should NOT update global shouldShownAgentLoading when only isPausing is true", () => {
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
|
||||
renderAgentStatus({ isPausing: true });
|
||||
|
||||
// Loading spinner shows (because isPausing)
|
||||
expect(screen.getByTestId("agent-loading-spinner")).toBeInTheDocument();
|
||||
|
||||
// But global state should be false (because shouldShownAgentLoading is false)
|
||||
const { shouldShownAgentLoading } = useConversationStore.getState();
|
||||
expect(shouldShownAgentLoading).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -42,7 +42,7 @@ vi.mock("react-i18next", async () => {
|
||||
BUTTON$EXPORT_CONVERSATION: "Export Conversation",
|
||||
BUTTON$DOWNLOAD_VIA_VSCODE: "Download via VS Code",
|
||||
BUTTON$SHOW_AGENT_TOOLS_AND_METADATA: "Show Agent Tools",
|
||||
CONVERSATION$SHOW_MICROAGENTS: "Show Microagents",
|
||||
CONVERSATION$SHOW_SKILLS: "Show Skills",
|
||||
BUTTON$DISPLAY_COST: "Display Cost",
|
||||
COMMON$CLOSE_CONVERSATION_STOP_RUNTIME:
|
||||
"Close Conversation (Stop Runtime)",
|
||||
@@ -290,7 +290,7 @@ describe("ConversationNameContextMenu", () => {
|
||||
onStop: vi.fn(),
|
||||
onDisplayCost: vi.fn(),
|
||||
onShowAgentTools: vi.fn(),
|
||||
onShowMicroagents: vi.fn(),
|
||||
onShowSkills: vi.fn(),
|
||||
onExportConversation: vi.fn(),
|
||||
onDownloadViaVSCode: vi.fn(),
|
||||
};
|
||||
@@ -304,7 +304,7 @@ describe("ConversationNameContextMenu", () => {
|
||||
expect(screen.getByTestId("stop-button")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("display-cost-button")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("show-agent-tools-button")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("show-microagents-button")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("show-skills-button")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByTestId("export-conversation-button"),
|
||||
).toBeInTheDocument();
|
||||
@@ -321,9 +321,7 @@ describe("ConversationNameContextMenu", () => {
|
||||
expect(
|
||||
screen.queryByTestId("show-agent-tools-button"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("show-microagents-button"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId("show-skills-button")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("export-conversation-button"),
|
||||
).not.toBeInTheDocument();
|
||||
@@ -410,19 +408,19 @@ describe("ConversationNameContextMenu", () => {
|
||||
|
||||
it("should call show microagents handler when show microagents button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onShowMicroagents = vi.fn();
|
||||
const onShowSkills = vi.fn();
|
||||
|
||||
renderWithProviders(
|
||||
<ConversationNameContextMenu
|
||||
{...defaultProps}
|
||||
onShowMicroagents={onShowMicroagents}
|
||||
onShowSkills={onShowSkills}
|
||||
/>,
|
||||
);
|
||||
|
||||
const showMicroagentsButton = screen.getByTestId("show-microagents-button");
|
||||
const showMicroagentsButton = screen.getByTestId("show-skills-button");
|
||||
await user.click(showMicroagentsButton);
|
||||
|
||||
expect(onShowMicroagents).toHaveBeenCalledTimes(1);
|
||||
expect(onShowSkills).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should call export conversation handler when export conversation button is clicked", async () => {
|
||||
@@ -519,7 +517,7 @@ describe("ConversationNameContextMenu", () => {
|
||||
onStop: vi.fn(),
|
||||
onDisplayCost: vi.fn(),
|
||||
onShowAgentTools: vi.fn(),
|
||||
onShowMicroagents: vi.fn(),
|
||||
onShowSkills: vi.fn(),
|
||||
onExportConversation: vi.fn(),
|
||||
onDownloadViaVSCode: vi.fn(),
|
||||
};
|
||||
@@ -541,8 +539,8 @@ describe("ConversationNameContextMenu", () => {
|
||||
expect(screen.getByTestId("show-agent-tools-button")).toHaveTextContent(
|
||||
"Show Agent Tools",
|
||||
);
|
||||
expect(screen.getByTestId("show-microagents-button")).toHaveTextContent(
|
||||
"Show Microagents",
|
||||
expect(screen.getByTestId("show-skills-button")).toHaveTextContent(
|
||||
"Show Skills",
|
||||
);
|
||||
expect(screen.getByTestId("export-conversation-button")).toHaveTextContent(
|
||||
"Export Conversation",
|
||||
|
||||
@@ -48,9 +48,12 @@ describe("MaintenanceBanner", () => {
|
||||
expect(button).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// maintenance-banner
|
||||
|
||||
it("handles invalid date gracefully", () => {
|
||||
// Suppress expected console.warn for invalid date parsing
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, "warn")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const invalidTime = "invalid-date";
|
||||
|
||||
render(
|
||||
@@ -62,6 +65,9 @@ describe("MaintenanceBanner", () => {
|
||||
// Check if the banner is rendered
|
||||
const banner = screen.queryByTestId("maintenance-banner");
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
|
||||
// Restore console.warn
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("click on dismiss button removes banner", () => {
|
||||
|
||||
@@ -12,7 +12,7 @@ import GitService from "#/api/git-service/git-service.api";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import { RepositoryMicroagent } from "#/types/microagent-management";
|
||||
import { Conversation } from "#/api/open-hands.types";
|
||||
import { useMicroagentManagementStore } from "#/state/microagent-management-store";
|
||||
import { useMicroagentManagementStore } from "#/stores/microagent-management-store";
|
||||
|
||||
// Mock hooks
|
||||
const mockUseUserProviders = vi.fn();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { MCPServerForm } from "../mcp-server-form";
|
||||
import { MCPServerForm } from "#/components/features/settings/mcp-settings/mcp-server-form";
|
||||
|
||||
// i18n mock
|
||||
vi.mock("react-i18next", () => ({
|
||||
@@ -1,6 +1,6 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { MCPServerList } from "../mcp-server-list";
|
||||
import { MCPServerList } from "#/components/features/settings/mcp-settings/mcp-server-list";
|
||||
|
||||
// Mock react-i18next
|
||||
vi.mock("react-i18next", () => ({
|
||||
@@ -0,0 +1,28 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach } from "vitest";
|
||||
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
|
||||
|
||||
describe("EmailVerificationModal", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render the email verification message", () => {
|
||||
// Arrange & Act
|
||||
render(<EmailVerificationModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the TermsAndPrivacyNotice component", () => {
|
||||
// Arrange & Act
|
||||
render(<EmailVerificationModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -6,7 +6,7 @@ import { InteractiveChatBox } from "#/components/features/chat/interactive-chat-
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { useConversationStore } from "#/state/conversation-store";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(),
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { MicroagentsModal } from "#/components/features/conversation-panel/microagents-modal";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
|
||||
// Mock the agent state hook
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the conversation ID hook
|
||||
vi.mock("#/hooks/use-conversation-id", () => ({
|
||||
useConversationId: () => ({ conversationId: "test-conversation-id" }),
|
||||
}));
|
||||
|
||||
describe("MicroagentsModal - Refresh Button", () => {
|
||||
const mockOnClose = vi.fn();
|
||||
const conversationId = "test-conversation-id";
|
||||
|
||||
const defaultProps = {
|
||||
onClose: mockOnClose,
|
||||
conversationId,
|
||||
};
|
||||
|
||||
const mockMicroagents = [
|
||||
{
|
||||
name: "Test Agent 1",
|
||||
type: "repo" as const,
|
||||
triggers: ["test", "example"],
|
||||
content: "This is test content for agent 1",
|
||||
},
|
||||
{
|
||||
name: "Test Agent 2",
|
||||
type: "knowledge" as const,
|
||||
triggers: ["help", "support"],
|
||||
content: "This is test content for agent 2",
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks before each test
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Setup default mock for getMicroagents
|
||||
vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({
|
||||
microagents: mockMicroagents,
|
||||
});
|
||||
|
||||
// Mock the agent state to return a ready state
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("Refresh Button Rendering", () => {
|
||||
it("should render the refresh button with correct text and test ID", async () => {
|
||||
renderWithProviders(<MicroagentsModal {...defaultProps} />);
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-microagents");
|
||||
expect(refreshButton).toBeInTheDocument();
|
||||
expect(refreshButton).toHaveTextContent("BUTTON$REFRESH");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Refresh Button Functionality", () => {
|
||||
it("should call refetch when refresh button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const refreshSpy = vi.spyOn(ConversationService, "getMicroagents");
|
||||
|
||||
renderWithProviders(<MicroagentsModal {...defaultProps} />);
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-microagents");
|
||||
|
||||
refreshSpy.mockClear();
|
||||
|
||||
await user.click(refreshButton);
|
||||
|
||||
expect(refreshSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
394
frontend/__tests__/components/modals/skills/skill-modal.test.tsx
Normal file
394
frontend/__tests__/components/modals/skills/skill-modal.test.tsx
Normal file
@@ -0,0 +1,394 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { SkillsModal } from "#/components/features/conversation-panel/skills-modal";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
// Mock the agent state hook
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock the conversation ID hook
|
||||
vi.mock("#/hooks/use-conversation-id", () => ({
|
||||
useConversationId: () => ({ conversationId: "test-conversation-id" }),
|
||||
}));
|
||||
|
||||
describe("SkillsModal - Refresh Button", () => {
|
||||
const mockOnClose = vi.fn();
|
||||
const conversationId = "test-conversation-id";
|
||||
|
||||
const defaultProps = {
|
||||
onClose: mockOnClose,
|
||||
conversationId,
|
||||
};
|
||||
|
||||
const mockSkills = [
|
||||
{
|
||||
name: "Test Agent 1",
|
||||
type: "repo" as const,
|
||||
triggers: ["test", "example"],
|
||||
content: "This is test content for agent 1",
|
||||
},
|
||||
{
|
||||
name: "Test Agent 2",
|
||||
type: "knowledge" as const,
|
||||
triggers: ["help", "support"],
|
||||
content: "This is test content for agent 2",
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks before each test
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Setup default mock for getMicroagents (V0)
|
||||
vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({
|
||||
microagents: mockSkills,
|
||||
});
|
||||
|
||||
// Mock the agent state to return a ready state
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("Refresh Button Rendering", () => {
|
||||
it("should render the refresh button with correct text and test ID", async () => {
|
||||
renderWithProviders(<SkillsModal {...defaultProps} />);
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-skills");
|
||||
expect(refreshButton).toBeInTheDocument();
|
||||
expect(refreshButton).toHaveTextContent("BUTTON$REFRESH");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Refresh Button Functionality", () => {
|
||||
it("should call refetch when refresh button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const refreshSpy = vi.spyOn(ConversationService, "getMicroagents");
|
||||
|
||||
renderWithProviders(<SkillsModal {...defaultProps} />);
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-skills");
|
||||
|
||||
// Clear previous calls to only track the click
|
||||
refreshSpy.mockClear();
|
||||
|
||||
await user.click(refreshButton);
|
||||
|
||||
// Verify the refresh triggered a new API call
|
||||
expect(refreshSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("useConversationSkills - V1 API Integration", () => {
|
||||
const conversationId = "test-conversation-id";
|
||||
|
||||
const mockMicroagents = [
|
||||
{
|
||||
name: "V0 Test Agent",
|
||||
type: "repo" as const,
|
||||
triggers: ["v0"],
|
||||
content: "V0 skill content",
|
||||
},
|
||||
];
|
||||
|
||||
const mockSkills = [
|
||||
{
|
||||
name: "V1 Test Skill",
|
||||
type: "knowledge" as const,
|
||||
triggers: ["v1", "skill"],
|
||||
content: "V1 skill content",
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Mock agent state
|
||||
vi.mocked(useAgentState).mockReturnValue({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("V0 API Usage (v1_enabled: false)", () => {
|
||||
it("should call v0 ConversationService.getMicroagents when v1_enabled is false", async () => {
|
||||
// Arrange
|
||||
const getMicroagentsSpy = vi
|
||||
.spyOn(ConversationService, "getMicroagents")
|
||||
.mockResolvedValue({ microagents: mockMicroagents });
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: false,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
await screen.findByText("V0 Test Agent");
|
||||
expect(getMicroagentsSpy).toHaveBeenCalledWith(conversationId);
|
||||
expect(getMicroagentsSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should display v0 skills correctly", async () => {
|
||||
// Arrange
|
||||
vi.spyOn(ConversationService, "getMicroagents").mockResolvedValue({
|
||||
microagents: mockMicroagents,
|
||||
});
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: false,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
const agentName = await screen.findByText("V0 Test Agent");
|
||||
expect(agentName).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("V1 API Usage (v1_enabled: true)", () => {
|
||||
it("should call v1 V1ConversationService.getSkills when v1_enabled is true", async () => {
|
||||
// Arrange
|
||||
const getSkillsSpy = vi
|
||||
.spyOn(V1ConversationService, "getSkills")
|
||||
.mockResolvedValue({ skills: mockSkills });
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: true,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
await screen.findByText("V1 Test Skill");
|
||||
expect(getSkillsSpy).toHaveBeenCalledWith(conversationId);
|
||||
expect(getSkillsSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should display v1 skills correctly", async () => {
|
||||
// Arrange
|
||||
vi.spyOn(V1ConversationService, "getSkills").mockResolvedValue({
|
||||
skills: mockSkills,
|
||||
});
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: true,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
const skillName = await screen.findByText("V1 Test Skill");
|
||||
expect(skillName).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should use v1 API when v1_enabled is true", async () => {
|
||||
// Arrange
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
v1_enabled: true,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
const getSkillsSpy = vi
|
||||
.spyOn(V1ConversationService, "getSkills")
|
||||
.mockResolvedValue({
|
||||
skills: mockSkills,
|
||||
});
|
||||
|
||||
// Act
|
||||
renderWithProviders(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert
|
||||
await screen.findByText("V1 Test Skill");
|
||||
// Verify v1 API was called
|
||||
expect(getSkillsSpy).toHaveBeenCalledWith(conversationId);
|
||||
});
|
||||
});
|
||||
|
||||
describe("API Switching on Settings Change", () => {
|
||||
it("should refetch using different API when v1_enabled setting changes", async () => {
|
||||
// Arrange
|
||||
const getMicroagentsSpy = vi
|
||||
.spyOn(ConversationService, "getMicroagents")
|
||||
.mockResolvedValue({ microagents: mockMicroagents });
|
||||
const getSkillsSpy = vi
|
||||
.spyOn(V1ConversationService, "getSkills")
|
||||
.mockResolvedValue({ skills: mockSkills });
|
||||
|
||||
const settingsSpy = vi
|
||||
.spyOn(SettingsService, "getSettings")
|
||||
.mockResolvedValue({
|
||||
v1_enabled: false,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act - Initial render with v1_enabled: false
|
||||
const { rerender } = renderWithProviders(
|
||||
<SkillsModal onClose={vi.fn()} />,
|
||||
);
|
||||
|
||||
// Assert - v0 API called initially
|
||||
await screen.findByText("V0 Test Agent");
|
||||
expect(getMicroagentsSpy).toHaveBeenCalledWith(conversationId);
|
||||
|
||||
// Arrange - Change settings to v1_enabled: true
|
||||
settingsSpy.mockResolvedValue({
|
||||
v1_enabled: true,
|
||||
llm_model: "test-model",
|
||||
llm_base_url: "",
|
||||
agent: "test-agent",
|
||||
language: "en",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
user_consents_to_analytics: null,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Act - Force re-render
|
||||
rerender(<SkillsModal onClose={vi.fn()} />);
|
||||
|
||||
// Assert - v1 API should be called after settings change
|
||||
await screen.findByText("V1 Test Skill");
|
||||
expect(getSkillsSpy).toHaveBeenCalledWith(conversationId);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,48 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { it, describe, expect } from "vitest";
|
||||
import { TermsAndPrivacyNotice } from "#/components/shared/terms-and-privacy-notice";
|
||||
|
||||
describe("TermsAndPrivacyNotice", () => {
|
||||
it("should render Terms of Service and Privacy Policy links", () => {
|
||||
// Arrange & Act
|
||||
render(<TermsAndPrivacyNotice />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toBeInTheDocument();
|
||||
|
||||
const tosLink = screen.getByRole("link", {
|
||||
name: "COMMON$TERMS_OF_SERVICE",
|
||||
});
|
||||
const privacyLink = screen.getByRole("link", {
|
||||
name: "COMMON$PRIVACY_POLICY",
|
||||
});
|
||||
|
||||
expect(tosLink).toBeInTheDocument();
|
||||
expect(tosLink).toHaveAttribute("href", "https://www.all-hands.dev/tos");
|
||||
expect(tosLink).toHaveAttribute("target", "_blank");
|
||||
expect(tosLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
|
||||
expect(privacyLink).toBeInTheDocument();
|
||||
expect(privacyLink).toHaveAttribute(
|
||||
"href",
|
||||
"https://www.all-hands.dev/privacy",
|
||||
);
|
||||
expect(privacyLink).toHaveAttribute("target", "_blank");
|
||||
expect(privacyLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
|
||||
it("should render all required text content", () => {
|
||||
// Arrange & Act
|
||||
render(<TermsAndPrivacyNotice />);
|
||||
|
||||
// Assert
|
||||
const termsSection = screen.getByTestId("terms-and-privacy-notice");
|
||||
expect(termsSection).toHaveTextContent(
|
||||
"AUTH$BY_SIGNING_UP_YOU_AGREE_TO_OUR",
|
||||
);
|
||||
expect(termsSection).toHaveTextContent("COMMON$TERMS_OF_SERVICE");
|
||||
expect(termsSection).toHaveTextContent("COMMON$AND");
|
||||
expect(termsSection).toHaveTextContent("COMMON$PRIVACY_POLICY");
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,7 @@
|
||||
import { act, screen } from "@testing-library/react";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { vi, describe, afterEach, it, expect } from "vitest";
|
||||
import { Command, useCommandStore } from "#/state/command-store";
|
||||
import { Command, useCommandStore } from "#/stores/command-store";
|
||||
import Terminal from "#/components/features/terminal/terminal";
|
||||
|
||||
const renderTerminal = (commands: Command[] = []) => {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { getObservationContent } from "../get-observation-content";
|
||||
import { getObservationContent } from "#/components/v1/chat/event-content-helpers/get-observation-content";
|
||||
import { ObservationEvent } from "#/types/v1/core";
|
||||
import { BrowserObservation } from "#/types/v1/core/base/observation";
|
||||
|
||||
@@ -6,13 +6,14 @@ import {
|
||||
beforeEach,
|
||||
afterAll,
|
||||
afterEach,
|
||||
vi,
|
||||
} from "vitest";
|
||||
import { screen, waitFor, render, cleanup } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store";
|
||||
import { useBrowserStore } from "#/stores/browser-store";
|
||||
import { useCommandStore } from "#/state/command-store";
|
||||
import { useCommandStore } from "#/stores/command-store";
|
||||
import {
|
||||
createMockMessageEvent,
|
||||
createMockUserMessageEvent,
|
||||
@@ -141,6 +142,11 @@ describe("Conversation WebSocket Handler", () => {
|
||||
});
|
||||
|
||||
it("should handle malformed/invalid event data gracefully", async () => {
|
||||
// Suppress expected console.warn for invalid JSON parsing
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, "warn")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
// Set up MSW to send various invalid events when connection is established
|
||||
mswServer.use(
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
@@ -203,6 +209,9 @@ describe("Conversation WebSocket Handler", () => {
|
||||
"valid-event-123",
|
||||
);
|
||||
expect(screen.getByTestId("ui-events-count")).toHaveTextContent("1");
|
||||
|
||||
// Restore console.warn
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -453,18 +462,10 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
http.get("/api/v1/events/count", ({ request }) => {
|
||||
const url = new URL(request.url);
|
||||
const conversationIdParam = url.searchParams.get(
|
||||
"conversation_id__eq",
|
||||
);
|
||||
|
||||
if (conversationIdParam === conversationId) {
|
||||
return HttpResponse.json(expectedEventCount);
|
||||
}
|
||||
|
||||
return HttpResponse.json(0);
|
||||
}),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(expectedEventCount),
|
||||
),
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
server.connect();
|
||||
// Send all history events
|
||||
@@ -520,18 +521,10 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
http.get("/api/v1/events/count", ({ request }) => {
|
||||
const url = new URL(request.url);
|
||||
const conversationIdParam = url.searchParams.get(
|
||||
"conversation_id__eq",
|
||||
);
|
||||
|
||||
if (conversationIdParam === conversationId) {
|
||||
return HttpResponse.json(0);
|
||||
}
|
||||
|
||||
return HttpResponse.json(0);
|
||||
}),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(0),
|
||||
),
|
||||
wsLink.addEventListener("connection", ({ server }) => {
|
||||
server.connect();
|
||||
// No events sent for empty history
|
||||
@@ -577,18 +570,10 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
http.get("/api/v1/events/count", ({ request }) => {
|
||||
const url = new URL(request.url);
|
||||
const conversationIdParam = url.searchParams.get(
|
||||
"conversation_id__eq",
|
||||
);
|
||||
|
||||
if (conversationIdParam === conversationId) {
|
||||
return HttpResponse.json(expectedEventCount);
|
||||
}
|
||||
|
||||
return HttpResponse.json(0);
|
||||
}),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(expectedEventCount),
|
||||
),
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
server.connect();
|
||||
// Send all history events
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
/* eslint-disable max-classes-per-file */
|
||||
import { beforeAll, describe, expect, it, vi, afterEach } from "vitest";
|
||||
import { useTerminal } from "#/hooks/use-terminal";
|
||||
import { Command, useCommandStore } from "#/state/command-store";
|
||||
import { Command, useCommandStore } from "#/stores/command-store";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
|
||||
// Mock the WsClient context
|
||||
@@ -43,6 +42,11 @@ describe("useTerminal", () => {
|
||||
write: vi.fn(),
|
||||
writeln: vi.fn(),
|
||||
dispose: vi.fn(),
|
||||
element: document.createElement("div"),
|
||||
}));
|
||||
|
||||
const mockFitAddon = vi.hoisted(() => ({
|
||||
fit: vi.fn(),
|
||||
}));
|
||||
|
||||
beforeAll(() => {
|
||||
@@ -68,6 +72,15 @@ describe("useTerminal", () => {
|
||||
writeln = mockTerminal.writeln;
|
||||
|
||||
dispose = mockTerminal.dispose;
|
||||
|
||||
element = mockTerminal.element;
|
||||
},
|
||||
}));
|
||||
|
||||
// mock FitAddon
|
||||
vi.mock("@xterm/addon-fit", () => ({
|
||||
FitAddon: class {
|
||||
fit = mockFitAddon.fit;
|
||||
},
|
||||
}));
|
||||
});
|
||||
@@ -96,4 +109,18 @@ describe("useTerminal", () => {
|
||||
expect(mockTerminal.writeln).toHaveBeenNthCalledWith(1, "echo hello");
|
||||
expect(mockTerminal.writeln).toHaveBeenNthCalledWith(2, "hello");
|
||||
});
|
||||
|
||||
it("should not call fit() when terminal.element is null", () => {
|
||||
// Temporarily set element to null to simulate terminal not being opened
|
||||
const originalElement = mockTerminal.element;
|
||||
mockTerminal.element = null as unknown as HTMLDivElement;
|
||||
|
||||
renderWithProviders(<TestTerminalComponent />);
|
||||
|
||||
// fit() should not be called because terminal.element is null
|
||||
expect(mockFitAddon.fit).not.toHaveBeenCalled();
|
||||
|
||||
// Restore original element
|
||||
mockTerminal.element = originalElement;
|
||||
});
|
||||
});
|
||||
|
||||
@@ -34,7 +34,11 @@ describe("useWebSocket", () => {
|
||||
}),
|
||||
);
|
||||
|
||||
beforeAll(() => mswServer.listen());
|
||||
beforeAll(() =>
|
||||
mswServer.listen({
|
||||
onUnhandledRequest: "warn",
|
||||
}),
|
||||
);
|
||||
afterEach(() => mswServer.resetHandlers());
|
||||
afterAll(() => mswServer.close());
|
||||
|
||||
|
||||
227
frontend/__tests__/router.md
Normal file
227
frontend/__tests__/router.md
Normal file
@@ -0,0 +1,227 @@
|
||||
# Testing with React Router
|
||||
|
||||
## Overview
|
||||
|
||||
React Router components and hooks require a routing context to function. In tests, we need to provide this context while maintaining control over the routing state.
|
||||
|
||||
This guide covers the two main approaches used in the OpenHands frontend:
|
||||
|
||||
1. **`createRoutesStub`** - Creates a complete route structure for testing components with their actual route configuration, loaders, and nested routes.
|
||||
2. **`MemoryRouter`** - Provides a minimal routing context for components that just need router hooks to work.
|
||||
|
||||
Choose your approach based on what your component actually needs from the router.
|
||||
|
||||
## When to Use Each Approach
|
||||
|
||||
### `createRoutesStub` (Recommended)
|
||||
|
||||
Use `createRoutesStub` when your component:
|
||||
- Relies on route parameters (`useParams`)
|
||||
- Uses loader data (`useLoaderData`) or `clientLoader`
|
||||
- Has nested routes or uses `<Outlet />`
|
||||
- Needs to test navigation between routes
|
||||
|
||||
> [!NOTE]
|
||||
> `createRoutesStub` is intended for unit testing **reusable components** that depend on router context. For testing full route/page components, consider E2E tests (Playwright, Cypress) instead.
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { render } from "@testing-library/react";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyRouteComponent,
|
||||
path: "/conversations/:conversationId",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/conversations/123"]} />);
|
||||
```
|
||||
|
||||
**With nested routes and loaders:**
|
||||
|
||||
```typescript
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: SettingsScreen,
|
||||
clientLoader,
|
||||
path: "/settings",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="llm-settings" />,
|
||||
path: "/settings",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="git-settings" />,
|
||||
path: "/settings/integrations",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/settings/integrations"]} />);
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> When using `clientLoader` from a Route module, you may encounter type mismatches. Use `@ts-expect-error` as a workaround:
|
||||
|
||||
```typescript
|
||||
import { clientLoader } from "@/routes/settings";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
path: "/settings",
|
||||
Component: SettingsScreen,
|
||||
// @ts-expect-error: loader types won't align between test and app code
|
||||
loader: clientLoader,
|
||||
},
|
||||
]);
|
||||
```
|
||||
|
||||
### `MemoryRouter`
|
||||
|
||||
Use `MemoryRouter` when your component:
|
||||
- Only needs basic routing context to render
|
||||
- Uses `<Link>` components but you don't need to test navigation
|
||||
- Doesn't depend on specific route parameters or loaders
|
||||
|
||||
```typescript
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { render } from "@testing-library/react";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
**With initial route:**
|
||||
|
||||
```typescript
|
||||
render(
|
||||
<MemoryRouter initialEntries={["/some/path"]}>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
## Anti-patterns to Avoid
|
||||
|
||||
### Using `BrowserRouter` in tests
|
||||
|
||||
`BrowserRouter` interacts with the actual browser history API, which can cause issues in test environments:
|
||||
|
||||
```typescript
|
||||
// ❌ Avoid
|
||||
render(
|
||||
<BrowserRouter>
|
||||
<MyComponent />
|
||||
</BrowserRouter>
|
||||
);
|
||||
|
||||
// ✅ Use MemoryRouter instead
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<MyComponent />
|
||||
</MemoryRouter>
|
||||
);
|
||||
```
|
||||
|
||||
### Mocking router hooks when `createRoutesStub` would work
|
||||
|
||||
Mocking hooks like `useParams` directly can be brittle and doesn't test the actual routing behavior:
|
||||
|
||||
```typescript
|
||||
// ❌ Avoid when possible
|
||||
vi.mock("react-router", async () => {
|
||||
const actual = await vi.importActual("react-router");
|
||||
return {
|
||||
...actual,
|
||||
useParams: () => ({ conversationId: "123" }),
|
||||
};
|
||||
});
|
||||
|
||||
// ✅ Prefer createRoutesStub - tests real routing behavior
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyComponent,
|
||||
path: "/conversations/:conversationId",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/conversations/123"]} />);
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Combining with `QueryClientProvider`
|
||||
|
||||
Many components need both routing and TanStack Query context:
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
});
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MyComponent,
|
||||
path: "/",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
```
|
||||
|
||||
### Testing navigation behavior
|
||||
|
||||
Verify that user interactions trigger the expected navigation:
|
||||
|
||||
```typescript
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: HomeScreen,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="settings-screen" />,
|
||||
path: "/settings",
|
||||
},
|
||||
]);
|
||||
|
||||
render(<RouterStub initialEntries={["/"]} />);
|
||||
|
||||
const user = userEvent.setup();
|
||||
await user.click(screen.getByRole("link", { name: /settings/i }));
|
||||
|
||||
expect(screen.getByTestId("settings-screen")).toBeInTheDocument();
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
### Codebase Examples
|
||||
|
||||
- [settings.test.tsx](__tests__/routes/settings.test.tsx) - `createRoutesStub` with nested routes and loaders
|
||||
- [home-screen.test.tsx](__tests__/routes/home-screen.test.tsx) - `createRoutesStub` with navigation testing
|
||||
- [chat-interface.test.tsx](__tests__/components/chat/chat-interface.test.tsx) - `MemoryRouter` usage
|
||||
|
||||
### Official Documentation
|
||||
|
||||
- [React Router Testing Guide](https://reactrouter.com/start/framework/testing) - Official guide on testing with `createRoutesStub`
|
||||
- [MemoryRouter API](https://reactrouter.com/api/declarative-routers/MemoryRouter) - API reference for `MemoryRouter`
|
||||
@@ -298,6 +298,7 @@ describe("Form submission", () => {
|
||||
gitlab: { token: "", host: "" },
|
||||
bitbucket: { token: "", host: "" },
|
||||
azure_devops: { token: "", host: "" },
|
||||
forgejo: { token: "", host: "" },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -320,6 +321,7 @@ describe("Form submission", () => {
|
||||
gitlab: { token: "test-token", host: "" },
|
||||
bitbucket: { token: "", host: "" },
|
||||
azure_devops: { token: "", host: "" },
|
||||
forgejo: { token: "", host: "" },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -342,6 +344,7 @@ describe("Form submission", () => {
|
||||
gitlab: { token: "", host: "" },
|
||||
bitbucket: { token: "test-token", host: "" },
|
||||
azure_devops: { token: "", host: "" },
|
||||
forgejo: { token: "", host: "" },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -364,6 +367,7 @@ describe("Form submission", () => {
|
||||
gitlab: { token: "", host: "" },
|
||||
bitbucket: { token: "", host: "" },
|
||||
azure_devops: { token: "test-token", host: "" },
|
||||
forgejo: { token: "", host: "" },
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ describe("Content", () => {
|
||||
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenHands");
|
||||
expect(model).toHaveValue("claude-sonnet-4-20250514");
|
||||
expect(model).toHaveValue("claude-opus-4-5-20251101");
|
||||
|
||||
expect(apiKey).toHaveValue("");
|
||||
expect(apiKey).toHaveProperty("placeholder", "");
|
||||
@@ -190,7 +190,7 @@ describe("Content", () => {
|
||||
const agent = screen.getByTestId("agent-input");
|
||||
const condensor = screen.getByTestId("enable-memory-condenser-switch");
|
||||
|
||||
expect(model).toHaveValue("openhands/claude-sonnet-4-20250514");
|
||||
expect(model).toHaveValue("openhands/claude-opus-4-5-20251101");
|
||||
expect(baseUrl).toHaveValue("");
|
||||
expect(apiKey).toHaveValue("");
|
||||
expect(apiKey).toHaveProperty("placeholder", "");
|
||||
@@ -910,6 +910,162 @@ describe("Form submission", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("View persistence after saving advanced settings", () => {
|
||||
it("should remain on Advanced view after saving when memory condenser is disabled", async () => {
|
||||
// Arrange: Start with default settings (basic view)
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
});
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
saveSettingsSpy.mockResolvedValue(true);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
// Verify we start in basic view
|
||||
expect(screen.getByTestId("llm-settings-form-basic")).toBeInTheDocument();
|
||||
|
||||
// Act: User manually switches to Advanced view
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
await userEvent.click(advancedSwitch);
|
||||
await screen.findByTestId("llm-settings-form-advanced");
|
||||
|
||||
// User disables memory condenser (advanced-only setting)
|
||||
const condenserSwitch = screen.getByTestId(
|
||||
"enable-memory-condenser-switch",
|
||||
);
|
||||
expect(condenserSwitch).toBeChecked();
|
||||
await userEvent.click(condenserSwitch);
|
||||
expect(condenserSwitch).not.toBeChecked();
|
||||
|
||||
// Mock the updated settings that will be returned after save
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
enable_default_condenser: false, // Now disabled
|
||||
});
|
||||
|
||||
// User saves settings
|
||||
const submitButton = screen.getByTestId("submit-button");
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert: View should remain on Advanced after save
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("llm-settings-form-advanced"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("llm-settings-form-basic"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(advancedSwitch).toBeChecked();
|
||||
});
|
||||
});
|
||||
|
||||
it("should remain on Advanced view after saving when condenser max size is customized", async () => {
|
||||
// Arrange: Start with default settings
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
});
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
saveSettingsSpy.mockResolvedValue(true);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
// Act: User manually switches to Advanced view
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
await userEvent.click(advancedSwitch);
|
||||
await screen.findByTestId("llm-settings-form-advanced");
|
||||
|
||||
// User sets custom condenser max size (advanced-only setting)
|
||||
const condenserMaxSizeInput = screen.getByTestId(
|
||||
"condenser-max-size-input",
|
||||
);
|
||||
await userEvent.clear(condenserMaxSizeInput);
|
||||
await userEvent.type(condenserMaxSizeInput, "200");
|
||||
|
||||
// Mock the updated settings that will be returned after save
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
condenser_max_size: 200, // Custom value
|
||||
});
|
||||
|
||||
// User saves settings
|
||||
const submitButton = screen.getByTestId("submit-button");
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert: View should remain on Advanced after save
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("llm-settings-form-advanced"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("llm-settings-form-basic"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(advancedSwitch).toBeChecked();
|
||||
});
|
||||
});
|
||||
|
||||
it("should remain on Advanced view after saving when search API key is set", async () => {
|
||||
// Arrange: Start with default settings (non-SaaS mode to show search API key field)
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "oss",
|
||||
GITHUB_CLIENT_ID: "fake-github-client-id",
|
||||
POSTHOG_CLIENT_KEY: "fake-posthog-client-key",
|
||||
FEATURE_FLAGS: {
|
||||
ENABLE_BILLING: false,
|
||||
HIDE_LLM_SETTINGS: false,
|
||||
ENABLE_JIRA: false,
|
||||
ENABLE_JIRA_DC: false,
|
||||
ENABLE_LINEAR: false,
|
||||
},
|
||||
});
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
search_api_key: "", // Default empty value
|
||||
});
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
saveSettingsSpy.mockResolvedValue(true);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
// Act: User manually switches to Advanced view
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
await userEvent.click(advancedSwitch);
|
||||
await screen.findByTestId("llm-settings-form-advanced");
|
||||
|
||||
// User sets search API key (advanced-only setting)
|
||||
const searchApiKeyInput = screen.getByTestId("search-api-key-input");
|
||||
await userEvent.type(searchApiKeyInput, "test-search-api-key");
|
||||
|
||||
// Mock the updated settings that will be returned after save
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
search_api_key: "test-search-api-key", // Now set
|
||||
});
|
||||
|
||||
// User saves settings
|
||||
const submitButton = screen.getByTestId("submit-button");
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert: View should remain on Advanced after save
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("llm-settings-form-advanced"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("llm-settings-form-basic"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(advancedSwitch).toBeChecked();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Status toasts", () => {
|
||||
describe("Basic form", () => {
|
||||
it("should call displaySuccessToast when the settings are saved", async () => {
|
||||
|
||||
242
frontend/__tests__/routes/root-layout.test.tsx
Normal file
242
frontend/__tests__/routes/root-layout.test.tsx
Normal file
@@ -0,0 +1,242 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import MainApp from "#/routes/root-layout";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
// Mock other hooks that are not the focus of these tests
|
||||
vi.mock("#/hooks/use-github-auth-url", () => ({
|
||||
useGitHubAuthUrl: () => "https://github.com/oauth/authorize",
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-is-on-tos-page", () => ({
|
||||
useIsOnTosPage: () => false,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-auto-login", () => ({
|
||||
useAutoLogin: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-auth-callback", () => ({
|
||||
useAuthCallback: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-migrate-user-consent", () => ({
|
||||
useMigrateUserConsent: () => ({
|
||||
migrateUserConsent: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-reo-tracking", () => ({
|
||||
useReoTracking: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-sync-posthog-consent", () => ({
|
||||
useSyncPostHogConsent: () => {},
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/custom-toast-handlers", () => ({
|
||||
displaySuccessToast: vi.fn(),
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MainApp,
|
||||
path: "/",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="outlet-content">Content</div>,
|
||||
path: "/",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
const createWrapper = () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
};
|
||||
|
||||
describe("MainApp - Email Verification Flow", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Default mocks for services
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
GITHUB_CLIENT_ID: "test-client-id",
|
||||
POSTHOG_CLIENT_KEY: "test-posthog-key",
|
||||
PROVIDERS_CONFIGURED: ["github"],
|
||||
AUTH_URL: "https://auth.example.com",
|
||||
FEATURE_FLAGS: {
|
||||
ENABLE_BILLING: false,
|
||||
HIDE_LLM_SETTINGS: false,
|
||||
ENABLE_JIRA: false,
|
||||
ENABLE_JIRA_DC: false,
|
||||
ENABLE_LINEAR: false,
|
||||
},
|
||||
});
|
||||
|
||||
vi.spyOn(AuthService, "authenticate").mockResolvedValue(true);
|
||||
|
||||
vi.spyOn(SettingsService, "getSettings").mockResolvedValue({
|
||||
language: "en",
|
||||
user_consents_to_analytics: true,
|
||||
llm_model: "",
|
||||
llm_base_url: "",
|
||||
agent: "",
|
||||
llm_api_key: null,
|
||||
llm_api_key_set: false,
|
||||
search_api_key_set: false,
|
||||
confirmation_mode: false,
|
||||
security_analyzer: null,
|
||||
remote_runtime_resource_factor: null,
|
||||
provider_tokens_set: {},
|
||||
enable_default_condenser: false,
|
||||
condenser_max_size: null,
|
||||
enable_sound_notifications: false,
|
||||
enable_proactive_conversation_starters: false,
|
||||
enable_solvability_analysis: false,
|
||||
max_budget_per_task: null,
|
||||
});
|
||||
|
||||
// Mock localStorage
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn(() => null),
|
||||
setItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("should display EmailVerificationModal when email_verification_required=true is in query params", async () => {
|
||||
// Arrange & Act
|
||||
render(
|
||||
<RouterStub initialEntries={["/?email_verification_required=true"]} />,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should set emailVerified state and pass to AuthModal when email_verified=true is in query params", async () => {
|
||||
// Arrange
|
||||
// Mock a 401 error to simulate unauthenticated user
|
||||
const axiosError = {
|
||||
response: { status: 401 },
|
||||
isAxiosError: true,
|
||||
};
|
||||
vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/?email_verified=true"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Assert - Wait for AuthModal to render (since user is not authenticated)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle both email_verification_required and email_verified params together", async () => {
|
||||
// Arrange & Act
|
||||
render(
|
||||
<RouterStub
|
||||
initialEntries={[
|
||||
"/?email_verification_required=true&email_verified=true",
|
||||
]}
|
||||
/>,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert - EmailVerificationModal should take precedence
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should remove query parameters from URL after processing", async () => {
|
||||
// Arrange & Act
|
||||
const { container } = render(
|
||||
<RouterStub initialEntries={["/?email_verification_required=true"]} />,
|
||||
{ wrapper: createWrapper() },
|
||||
);
|
||||
|
||||
// Assert - Wait for the modal to appear (which indicates processing happened)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify that the query parameter was processed by checking the modal appeared
|
||||
// The hook removes the parameter from the URL, so we verify the behavior indirectly
|
||||
expect(container).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display EmailVerificationModal when email_verification_required is not in query params", async () => {
|
||||
// Arrange - No query params set
|
||||
|
||||
// Act
|
||||
render(<RouterStub />, { wrapper: createWrapper() });
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should not display email verified message when email_verified is not in query params", async () => {
|
||||
// Arrange
|
||||
// Mock a 401 error to simulate unauthenticated user
|
||||
const axiosError = {
|
||||
response: { status: 401 },
|
||||
isAxiosError: true,
|
||||
};
|
||||
vi.spyOn(AuthService, "authenticate").mockRejectedValue(axiosError);
|
||||
|
||||
// Act
|
||||
render(<RouterStub />, { wrapper: createWrapper() });
|
||||
|
||||
// Assert - AuthModal should render but without email verified message
|
||||
await waitFor(() => {
|
||||
const authModal = screen.queryByText(
|
||||
"AUTH$SIGN_IN_WITH_IDENTITY_PROVIDER",
|
||||
);
|
||||
if (authModal) {
|
||||
expect(
|
||||
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
|
||||
).not.toBeInTheDocument();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,8 +1,8 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { handleStatusMessage } from "../actions";
|
||||
import { handleStatusMessage } from "#/services/actions";
|
||||
import { StatusMessage } from "#/types/message";
|
||||
import { queryClient } from "#/query-client-config";
|
||||
import { useStatusStore } from "#/state/status-store";
|
||||
import { useStatusStore } from "#/stores/status-store";
|
||||
import { trackError } from "#/utils/error-handler";
|
||||
|
||||
// Mock dependencies
|
||||
@@ -12,7 +12,7 @@ vi.mock("#/query-client-config", () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/state/status-store", () => ({
|
||||
vi.mock("#/stores/status-store", () => ({
|
||||
useStatusStore: {
|
||||
getState: vi.fn(() => ({
|
||||
setCurStatusMessage: vi.fn(),
|
||||
@@ -1,7 +1,7 @@
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import ActionType from "#/types/action-type";
|
||||
import { ActionMessage } from "#/types/message";
|
||||
import { useCommandStore } from "#/state/command-store";
|
||||
import { useCommandStore } from "#/stores/command-store";
|
||||
|
||||
const mockDispatch = vi.fn();
|
||||
const mockAppendInput = vi.fn();
|
||||
|
||||
@@ -9,7 +9,7 @@ import { useShouldShowUserFeatures } from "../src/hooks/use-should-show-user-fea
|
||||
vi.mock("../src/hooks/use-should-show-user-features");
|
||||
vi.mock("#/api/suggestions-service/suggestions-service.api", () => ({
|
||||
SuggestionsService: {
|
||||
getSuggestedTasks: vi.fn(),
|
||||
getSuggestedTasks: vi.fn().mockResolvedValue([]),
|
||||
},
|
||||
}));
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import toast from "react-hot-toast";
|
||||
import {
|
||||
displaySuccessToast,
|
||||
displayErrorToast,
|
||||
} from "../custom-toast-handlers";
|
||||
} from "#/utils/custom-toast-handlers";
|
||||
|
||||
// Mock react-hot-toast
|
||||
vi.mock("react-hot-toast", () => ({
|
||||
@@ -29,5 +29,75 @@ describe("hasAdvancedSettingsSet", () => {
|
||||
}),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
test("enable_default_condenser is disabled", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
enable_default_condenser: false,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("condenser_max_size is customized above default", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
condenser_max_size: 200,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("condenser_max_size is customized below default", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
condenser_max_size: 50,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("search_api_key is set to non-empty value", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
search_api_key: "test-api-key-123",
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test("search_api_key with whitespace is treated as set", () => {
|
||||
// Arrange
|
||||
const settings = {
|
||||
...DEFAULT_SETTINGS,
|
||||
search_api_key: " test-key ",
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = hasAdvancedSettingsSet(settings);
|
||||
|
||||
// Assert
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { parseMaxBudgetPerTask, extractSettings } from "../settings-utils";
|
||||
import { parseMaxBudgetPerTask, extractSettings } from "#/utils/settings-utils";
|
||||
|
||||
describe("parseMaxBudgetPerTask", () => {
|
||||
it("should return null for empty string", () => {
|
||||
@@ -1,5 +1,5 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { calculateToastDuration } from "../toast-duration";
|
||||
import { calculateToastDuration } from "#/utils/toast-duration";
|
||||
|
||||
describe("calculateToastDuration", () => {
|
||||
it("should return minimum duration for short messages", () => {
|
||||
@@ -1,5 +1,5 @@
|
||||
import { describe, it, expect, beforeEach, afterEach } from "vitest";
|
||||
import { transformVSCodeUrl } from "../vscode-url-helper";
|
||||
import { transformVSCodeUrl } from "#/utils/vscode-url-helper";
|
||||
|
||||
describe("transformVSCodeUrl", () => {
|
||||
const originalWindowLocation = window.location;
|
||||
2868
frontend/package-lock.json
generated
2868
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user