Compare commits

..

68 Commits

Author SHA1 Message Date
openhands
4f16748505 Merge branch 'main' into fix/settings-storage
Resolve merge conflict in enterprise/storage/saas_settings_store.py:
- Keep the fix that sets default llm_base_url if not set
- Use is_openhands_model from openhands.utils.llm (already imported)

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-19 21:35:47 +00:00
Hiep Le
8927ac2230 fix(backend): organization members now see correct shared credit balance (#12942) 2026-02-20 01:34:53 +07:00
Rohit Malhotra
f3429e33ca Fix Resend sync to respect deleted users (#12904)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-19 17:43:15 +00:00
Tim O'Farrell
7cd219792b Add type hints and use model objects in api_keys.py endpoints (#12939)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-19 08:40:46 -07:00
Hiep Le
2aabe2ed8c fix(backend): add organization filtering to V1 conversation queries (#12923) 2026-02-19 20:39:28 +07:00
Tim O'Farrell
731a9a813e More readable logs for local debugging (#12926) 2026-02-19 02:27:57 -07:00
Tim O'Farrell
123e556fed Added endpoint for readiness probe (#12927) 2026-02-19 02:27:35 -07:00
Chujiang
6676cae249 fix: add missing type hints and improve test logging (#12810)
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-19 00:58:39 +01:00
Clay Arnold
fede37b496 fix: add claude-opus-4-6 to temperature/top_p guard (#12874)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 00:33:17 +01:00
Hiep Le
3bcd6f18df fix(backend): set user email fields from user_info during create_user (#12921) 2026-02-19 02:06:20 +07:00
Rohit Malhotra
0da18440c2 Mention free MiniMax usage and drop free credits (#12918)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 13:54:05 -05:00
Hiep Le
ac76e10048 refactor(backend): include current_org_id in organization list response (#12915) 2026-02-18 20:35:40 +07:00
Hiep Le
b98bae8b5f refactor(backend): rename orgmemberresponse.role_name to role (#12914) 2026-02-18 20:23:07 +07:00
Tim O'Farrell
516721d1ee fix: add default uuid4 to event_callback_result primary key (#12908)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 05:57:13 -07:00
Hiep Le
4d6f66ca28 feat: add user invitation logic (#12883)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 13:24:19 +07:00
chuckbutkus
b18568da0b Feature/permission based authorization (#12906)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 01:17:39 -05:00
mamoodi
83dd3c169c Release 1.4.0 (#12897) 2026-02-17 13:09:29 -05:00
Tim O'Farrell
35bddb14f1 fix: preserve import order in clean_proactive_convo_table.py (#12901)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: OpenHands Bot <contact@all-hands.dev>
2026-02-17 17:52:54 +00:00
Tim O'Farrell
e8425218e2 Remove alembic errors dumped into logs by cron jobs (#12900) 2026-02-17 17:22:54 +00:00
Rohit Malhotra
0a879fa781 Grant free credits after minimum purchase (#12899)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-17 11:00:42 -05:00
Hiep Le
41e142bbab fix(backend): system prompt override (planning agent) (#12893) 2026-02-17 16:15:26 +07:00
Engel Nyst
b06b9eedac fix: wire suggested task prompts for V1 (#12787)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-16 23:57:32 +01:00
Tim O'Farrell
a9afafa991 Default model for new users is minimax (#12889) 2026-02-16 12:24:30 -07:00
mamoodi
663ace4b39 Add saas-rel* branch pattern to ghcr-build workflow (#12888)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-16 12:27:37 -05:00
Hiep Le
2d085a6e0a fix(frontend): add auto-scroll when new messages arrive in chat (#12885) 2026-02-16 23:46:14 +07:00
Hiep Le
8b7112abe8 refactor(frontend): hide planning preview component when plan content is empty (#12879) 2026-02-16 18:35:20 +07:00
Hiep Le
34547ba947 fix(backend): enable byor key export after purchasing credits (#12862) 2026-02-16 17:02:06 +07:00
Graham Neubig
5f958ab60d fix: suppress alembic INFO logs before import to prevent Datadog misclassification (#12691)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-12 14:32:39 -05:00
Hiep Le
d7656bf1c9 refactor(backend): rename user role to member across the system (#12853) 2026-02-13 00:45:47 +07:00
Tim O'Farrell
2bc107564c Support list_files and get_trajectory for nested conversation managers (#12850)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: OpenHands Bot <contact@all-hands.dev>
2026-02-12 10:39:00 -07:00
Tim O'Farrell
85eb1e1504 Check event types before making expensive API calls in GitHub webhook handler (#12819)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-12 09:33:59 -07:00
OpenHands Bot
cd235cc8c7 Bump SDK packages to v1.11.4 (#12839)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
2026-02-11 10:55:46 -07:00
Graham Neubig
40f52dfabc Use lowercase minimax-m2.5 for consistency (#12840)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-12 01:29:17 +08:00
Hiep Le
bab7bf85e8 fix(backend): prevent org deletion from setting current_org_id to NULL (#12817) 2026-02-12 00:15:21 +07:00
Hiep Le
c856537f65 refactor(backend): update the patch organization api to support organization name updates (#12834) 2026-02-12 00:08:43 +07:00
Graham Neubig
736f5b2255 Add MiniMax-M2.5 model support (#12835)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-11 16:57:22 +00:00
chuckbutkus
c1d9d11772 Log all exceptions in get_user() when authentication fails (#12836)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: OpenHands Bot <contact@all-hands.dev>
2026-02-11 11:49:13 -05:00
sp.wack
85244499fe fix(frontend): performance and loading state bugs (#12821) 2026-02-11 15:34:52 +00:00
Hiep Le
c55084e223 fix(backend): read RECAPTCHA_SITE_KEY from environment in V1 web client config (#12830) 2026-02-11 18:59:52 +07:00
Tim O'Farrell
e3bb75deb4 fix(enterprise): use poetry.lock for reproducible dependency builds (#12820)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: OpenHands Bot <contact@all-hands.dev>
2026-02-11 04:51:12 -07:00
Hiep Le
1948200762 chore: update sdk to the latest version (#12811) 2026-02-11 12:57:08 +07:00
Tim O'Farrell
affe0af361 Add debug logging for sandbox startup health checks (#12814)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-10 07:23:43 -07:00
Hiep Le
f20c956196 feat(backend): implement org member patch api (#12800) 2026-02-10 20:01:24 +07:00
Alexander Grattan
4a089a3a0d fix(docs): update Gray Swan API links and onboarding instructions in security README (#12809) 2026-02-10 10:14:49 +00:00
Hiep Le
aa0b2d0b74 feat(backend): add api for switching between orgs (#12799) 2026-02-10 14:22:52 +07:00
Hiep Le
bef9b80b9d fix(frontend): add missing border radius to conversation loading on first load (#12796) 2026-02-09 21:36:07 +07:00
Graham Neubig
c4a90b1f89 Fix Resend ValidationError by adding email validation (#12511)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-08 09:47:39 -05:00
sp.wack
0d13c57d9f feat(backend): org get me route (#12760)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-02-07 16:11:25 +07:00
Graham Neubig
b3422f1275 Add PR Review by OpenHands workflow (#12784)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-06 17:26:16 -05:00
Xingyao Wang
f139a9970b feat: add SANDBOX_STARTUP_GRACE_SECONDS env var for configurable startup timeout (#12741)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-07 06:12:29 +08:00
Jamie Chicago
54d156122c Add automated PR review workflow using OpenHands (#12698)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2026-02-06 19:02:55 +00:00
Tim O'Farrell
ac072bf686 feat(frontend): change alert banner from solid background to border style (#12783)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-06 18:05:29 +00:00
Hiep Le
a53812c029 feat(backend): develop delete /api/organizations/{orgid}/members/{userid} api (#12734) 2026-02-07 00:50:47 +07:00
Tim O'Farrell
1d1c0925b5 refactor: Move check_byor_export_enabled to OrgService and add tests (PR #12753 followup) (#12782)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-06 17:03:03 +00:00
Hiep Le
872f41e3c0 feat(backend): implement get /api/organizations/{orgId}/members api (#12735) 2026-02-06 23:47:30 +07:00
Tim O'Farrell
d43ff82534 feat: Add BYOR export flag to org for LLM key access control (#12753)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-02-06 09:30:12 -07:00
huangkevin-apr
8cd8c011b2 fix(a11y): Add aria-label to Sidebar component (#12728) 2026-02-06 22:32:52 +07:00
Tim O'Farrell
5c68b10983 (Frontend) Migrate to new /api/v1/web-client/config endpoint (#12479)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-02-06 08:31:40 -07:00
Graham Neubig
a97fad1976 fix: Add PostHog error tracking for V1 AgentErrorEvent and ConversationErrorEvent (#12543)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-06 09:51:01 -05:00
Graham Neubig
4c3542a91c fix: use appropriate log level for webhook installation results (#12493)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-06 09:01:37 -05:00
Tim O'Farrell
f460057f58 chore: add deprecation notices to all runtime directory files (#12772)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-06 05:15:02 -07:00
MkDev11
4fa2ad0f47 fix: add exponential backoff retry for env var export when bash session is busy (#12748)
Co-authored-by: mkdev11 <MkDev11@users.noreply.github.com>
2026-02-06 05:07:17 -07:00
Hiep Le
dd8be12809 feat(backend): return is_personal field in OrgResponse (#12777) 2026-02-06 19:01:06 +07:00
Tim O'Farrell
89475095d9 Preload callback processor class to prevent Pydantic Deserialization Error (#12776) 2026-02-06 04:29:28 -07:00
Tim O'Farrell
05d5f8848a Fix V1 GitHub conversations failing to clone repository (#12775)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-06 03:08:55 -07:00
Hiep Le
ee2885eb0b feat: store plan.md file in appropriate configuration folders (#12713) 2026-02-06 16:09:39 +07:00
Tim O'Farrell
545257f870 Refactor: Add LLM provider utilities and improve API base URL detection (#12766)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-05 14:22:32 -07:00
Chuck Butkus
a071b0651e Try key gen fix again 2026-02-04 20:23:36 -05:00
289 changed files with 17702 additions and 1999 deletions

View File

@@ -9,6 +9,7 @@ on:
push:
branches:
- main
- "saas-rel-*"
tags:
- "*"
pull_request:

View File

@@ -0,0 +1,127 @@
---
name: PR Review by OpenHands
on:
# Use pull_request_target to allow fork PRs to access secrets when triggered by maintainers
# Security: This workflow runs when:
# 1. A new PR is opened (non-draft), OR
# 2. A draft PR is marked as ready for review, OR
# 3. A maintainer adds the 'review-this' label, OR
# 4. A maintainer requests openhands-agent or all-hands-bot as a reviewer
# Only users with write access can add labels or request reviews, ensuring security.
# The PR code is explicitly checked out for review, but secrets are only accessible
# because the workflow runs in the base repository context
pull_request_target:
types: [opened, ready_for_review, labeled, review_requested]
permissions:
contents: read
pull-requests: write
issues: write
jobs:
pr-review:
# Run when one of the following conditions is met:
# 1. A new non-draft PR is opened by a trusted contributor, OR
# 2. A draft PR is converted to ready for review by a trusted contributor, OR
# 3. 'review-this' label is added, OR
# 4. openhands-agent or all-hands-bot is requested as a reviewer
# Note: FIRST_TIME_CONTRIBUTOR PRs require manual trigger via label/reviewer request
if: |
(github.event.action == 'opened' && github.event.pull_request.draft == false && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR') ||
(github.event.action == 'ready_for_review' && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR') ||
github.event.label.name == 'review-this' ||
github.event.requested_reviewer.login == 'openhands-agent' ||
github.event.requested_reviewer.login == 'all-hands-bot'
concurrency:
group: pr-review-${{ github.event.pull_request.number }}
cancel-in-progress: true
runs-on: blacksmith-4vcpu-ubuntu-2404
env:
LLM_MODEL: litellm_proxy/claude-sonnet-4-5-20250929
LLM_BASE_URL: https://llm-proxy.app.all-hands.dev
# PR context will be automatically provided by the agent script
PR_NUMBER: ${{ github.event.pull_request.number }}
PR_TITLE: ${{ github.event.pull_request.title }}
PR_BODY: ${{ github.event.pull_request.body }}
PR_BASE_BRANCH: ${{ github.event.pull_request.base.ref }}
PR_HEAD_BRANCH: ${{ github.event.pull_request.head.ref }}
REPO_NAME: ${{ github.repository }}
steps:
- name: Checkout software-agent-sdk repository
uses: actions/checkout@v5
with:
repository: OpenHands/software-agent-sdk
path: software-agent-sdk
- name: Checkout PR repository
uses: actions/checkout@v5
with:
# When using pull_request_target, explicitly checkout the PR branch
# This ensures we review the actual PR code (including fork PRs)
repository: ${{ github.event.pull_request.head.repo.full_name }}
ref: ${{ github.event.pull_request.head.ref }}
fetch-depth: 0
# Security: Don't persist credentials to prevent untrusted PR code from using them
persist-credentials: false
path: pr-repo
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.13'
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- name: Install GitHub CLI
run: |
# Install GitHub CLI for posting review comments
sudo apt-get update
sudo apt-get install -y gh
- name: Install OpenHands dependencies
run: |
# Install OpenHands SDK and tools from local checkout
uv pip install --system ./software-agent-sdk/openhands-sdk ./software-agent-sdk/openhands-tools
- name: Check required configuration
env:
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
run: |
if [ -z "$LLM_API_KEY" ]; then
echo "Error: LLM_API_KEY secret is not set."
exit 1
fi
echo "PR Number: $PR_NUMBER"
echo "PR Title: $PR_TITLE"
echo "Repository: $REPO_NAME"
echo "LLM model: $LLM_MODEL"
if [ -n "$LLM_BASE_URL" ]; then
echo "LLM base URL: $LLM_BASE_URL"
fi
- name: Run PR review
env:
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
GITHUB_TOKEN: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
run: |
# Change to the PR repository directory so agent can analyze the code
cd pr-repo
# Run the PR review script from the software-agent-sdk checkout
uv run python ../software-agent-sdk/examples/03_github_workflows/02_pr_review/agent_script.py
- name: Upload logs as artifact
uses: actions/upload-artifact@v5
if: always()
with:
name: openhands-pr-review-logs
path: |
*.log
output/
retention-days: 7

View File

@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
### OpenHands Cloud
This is a deployment of OpenHands GUI, running on hosted infrastructure.
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
OpenHands Cloud comes with source-available features and integrations:
- Integrations with Slack, Jira, and Linear

View File

@@ -23,12 +23,23 @@ RUN apt-get update && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Install Python packages with security fixes
RUN /app/.venv/bin/pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace "posthog>=6.0.0" "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy google-cloud-recaptcha-enterprise && \
# Update packages with known CVE fixes
/app/.venv/bin/pip install --upgrade \
"mcp>=1.10.0" \
"pillow>=11.3.0"
# Install poetry and export before importing current code.
RUN /app/.venv/bin/pip install poetry poetry-plugin-export
# Install Python dependencies from poetry.lock for reproducible builds
# Copy lock files first for better Docker layer caching
COPY --chown=openhands:openhands enterprise/pyproject.toml enterprise/poetry.lock /tmp/enterprise/
RUN cd /tmp/enterprise && \
# Export only main dependencies with hashes for supply chain security
/app/.venv/bin/poetry export --only main -o requirements.txt && \
# Remove the local path dependency (openhands-ai is already in base image)
sed -i '/^-e /d; /openhands-ai/d' requirements.txt && \
# Install pinned dependencies from lock file
/app/.venv/bin/pip install -r requirements.txt && \
# Cleanup - return to /app before removing /tmp/enterprise
cd /app && \
rm -rf /tmp/enterprise && \
/app/.venv/bin/pip uninstall -y poetry poetry-plugin-export
WORKDIR /app
COPY --chown=openhands:openhands --chmod=770 enterprise .

View File

@@ -28,9 +28,11 @@ class SaaSExperimentManager(ExperimentManager):
return agent
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
agent = agent.model_copy(
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
)
# Skip experiment for planning agents which require their specialized prompt
if agent.system_prompt_filename != 'system_prompt_planning.j2':
agent = agent.model_copy(
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
)
return agent

View File

@@ -145,11 +145,7 @@ class GithubManager(Manager):
).get('body', ''):
return False
if GithubFactory.is_eligible_for_conversation_starter(
message
) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
await GithubFactory.trigger_conversation_starter(message)
# Check event types before making expensive API calls (e.g., _user_has_write_access_to_repo)
if not (
GithubFactory.is_labeled_issue(message)
or GithubFactory.is_issue_comment(message)
@@ -159,8 +155,17 @@ class GithubManager(Manager):
return False
logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')
user_has_write_access = self._user_has_write_access_to_repo(
installation_id, repo_name, username
)
return self._user_has_write_access_to_repo(installation_id, repo_name, username)
if (
GithubFactory.is_eligible_for_conversation_starter(message)
and user_has_write_access
):
await GithubFactory.trigger_conversation_starter(message)
return user_has_write_access
async def receive_message(self, message: Message):
self._confirm_incoming_source_type(message)

View File

@@ -167,17 +167,15 @@ async def install_webhook_on_resource(
scopes=SCOPES,
)
logger.info(
'Creating new webhook',
extra={
'webhook_id': webhook_id,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
},
)
log_extra = {
'webhook_id': webhook_id,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
}
if status == WebhookStatus.RATE_LIMITED:
logger.warning('Rate limited while creating webhook', extra=log_extra)
raise BreakLoopException()
if webhook_id:
@@ -191,9 +189,8 @@ async def install_webhook_on_resource(
'webhook_uuid': webhook_uuid, # required to identify which webhook installation is sending payload
},
)
logger.info(
f'Installed webhook for {webhook.user_id} on {resource_type}:{resource_id}'
)
logger.info('Created new webhook', extra=log_extra)
else:
logger.error('Failed to create webhook', extra=log_extra)
return webhook_id, status

View File

@@ -1,6 +1,6 @@
from openhands.app_server.user.user_context import UserContext
from openhands.app_server.user.user_models import UserInfo
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.integrations.service_types import ProviderType
from openhands.sdk.secret import SecretSource, StaticSecret
from openhands.server.user_auth.user_auth import UserAuth
@@ -14,6 +14,7 @@ class ResolverUserContext(UserContext):
saas_user_auth: UserAuth,
):
self.saas_user_auth = saas_user_auth
self._provider_handler: ProviderHandler | None = None
async def get_user_id(self) -> str | None:
return await self.saas_user_auth.get_user_id()
@@ -29,12 +30,26 @@ class ResolverUserContext(UserContext):
return UserInfo(id=user_id)
async def _get_provider_handler(self) -> ProviderHandler:
"""Get or create a ProviderHandler for git operations."""
if self._provider_handler is None:
provider_tokens = await self.saas_user_auth.get_provider_tokens()
if provider_tokens is None:
raise ValueError('No provider tokens available')
user_id = await self.saas_user_auth.get_user_id()
self._provider_handler = ProviderHandler(
provider_tokens=provider_tokens, external_auth_id=user_id
)
return self._provider_handler
async def get_authenticated_git_url(
self, repository: str, is_optional: bool = False
) -> str:
# This would need to be implemented based on the git provider tokens
# For now, return a basic HTTPS URL
return f'https://github.com/{repository}.git'
provider_handler = await self._get_provider_handler()
url = await provider_handler.get_authenticated_git_url(
repository, is_optional=is_optional
)
return url
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
# Return the appropriate token string from git_provider_tokens

View File

@@ -1,10 +1,15 @@
import logging
import os
from logging.config import fileConfig
from alembic import context
from google.cloud.sql.connector import Connector
from sqlalchemy import create_engine
from storage.base import Base
# Suppress alembic.runtime.plugins INFO logs during import to prevent non-JSON logs in production
# These plugin setup messages would otherwise appear before logging is configured
logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
from alembic import context # noqa: E402
from google.cloud.sql.connector import Connector # noqa: E402
from sqlalchemy import create_engine # noqa: E402
from storage.base import Base # noqa: E402
target_metadata = Base.metadata

View File

@@ -0,0 +1,46 @@
"""Add byor_export_enabled flag to org table.
Revision ID: 091
Revises: 090
Create Date: 2025-01-15 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '091'
down_revision: Union[str, None] = '090'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add byor_export_enabled column to org table with default false
op.add_column(
'org',
sa.Column(
'byor_export_enabled',
sa.Boolean,
nullable=False,
server_default=sa.text('false'),
),
)
# Set byor_export_enabled to true for orgs that have completed billing sessions
op.execute(
sa.text("""
UPDATE org SET byor_export_enabled = TRUE
WHERE id IN (
SELECT DISTINCT org_id FROM billing_sessions
WHERE status = 'completed' AND org_id IS NOT NULL
)
""")
)
def downgrade() -> None:
op.drop_column('org', 'byor_export_enabled')

View File

@@ -0,0 +1,29 @@
"""Rename 'user' role to 'member' in role table.
Revision ID: 092
Revises: 091
Create Date: 2025-02-12 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '092'
down_revision: Union[str, None] = '091'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Rename 'user' role to 'member' for clarity
# This avoids confusion between the 'user' role and the 'user' entity/account
op.execute(sa.text("UPDATE role SET name = 'member' WHERE name = 'user'"))
def downgrade() -> None:
# Revert 'member' role back to 'user'
op.execute(sa.text("UPDATE role SET name = 'user' WHERE name = 'member'"))

View File

@@ -0,0 +1,37 @@
"""Add pending_free_credits flag to org table.
Revision ID: 093
Revises: 092
Create Date: 2025-02-17 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '093'
down_revision: Union[str, None] = '092'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add pending_free_credits column to org table with default false.
# New orgs will have this set to TRUE at creation time.
# Existing orgs default to FALSE (not eligible - they already got $10 at signup).
op.add_column(
'org',
sa.Column(
'pending_free_credits',
sa.Boolean,
nullable=False,
server_default=sa.text('false'),
),
)
def downgrade() -> None:
op.drop_column('org', 'pending_free_credits')

View File

@@ -0,0 +1,110 @@
"""create org_invitation table
Revision ID: 094
Revises: 093
Create Date: 2026-02-18 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '094'
down_revision: Union[str, None] = '093'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create org_invitation table
op.create_table(
'org_invitation',
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
sa.Column('token', sa.String(64), nullable=False),
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('email', sa.String(255), nullable=False),
sa.Column('role_id', sa.Integer, nullable=False),
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
'status',
sa.String(20),
nullable=False,
server_default=sa.text("'pending'"),
),
sa.Column(
'created_at',
sa.DateTime,
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('accepted_at', sa.DateTime, nullable=True),
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
# Foreign key constraints
sa.ForeignKeyConstraint(
['org_id'],
['org.id'],
name='org_invitation_org_fkey',
ondelete='CASCADE',
),
sa.ForeignKeyConstraint(
['role_id'],
['role.id'],
name='org_invitation_role_fkey',
),
sa.ForeignKeyConstraint(
['inviter_id'],
['user.id'],
name='org_invitation_inviter_fkey',
),
sa.ForeignKeyConstraint(
['accepted_by_user_id'],
['user.id'],
name='org_invitation_accepter_fkey',
),
)
# Create indexes
op.create_index(
'ix_org_invitation_token',
'org_invitation',
['token'],
unique=True,
)
op.create_index(
'ix_org_invitation_org_id',
'org_invitation',
['org_id'],
)
op.create_index(
'ix_org_invitation_email',
'org_invitation',
['email'],
)
op.create_index(
'ix_org_invitation_status',
'org_invitation',
['status'],
)
# Composite index for checking pending invitations
op.create_index(
'ix_org_invitation_org_email_status',
'org_invitation',
['org_id', 'email', 'status'],
)
def downgrade() -> None:
# Drop indexes
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
# Drop table
op.drop_table('org_invitation')

View File

@@ -0,0 +1,37 @@
"""Drop pending_free_credits column from org table.
Revision ID: 095
Revises: 094
Create Date: 2025-02-18 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '095'
down_revision: Union[str, None] = '094'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Drop the pending_free_credits column from org table.
# This column was used for tracking free credit eligibility but is no longer needed.
op.drop_column('org', 'pending_free_credits')
def downgrade() -> None:
# Re-add pending_free_credits column with default false.
op.add_column(
'org',
sa.Column(
'pending_free_credits',
sa.Boolean,
nullable=False,
server_default=sa.text('false'),
),
)

View File

@@ -0,0 +1,67 @@
"""Create resend_synced_users table.
Revision ID: 096
Revises: 095
Create Date: 2025-02-17 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '096'
down_revision: Union[str, None] = '095'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create resend_synced_users table for tracking users synced to Resend audiences."""
op.create_table(
'resend_synced_users',
sa.Column(
'id',
sa.UUID(as_uuid=True),
nullable=False,
primary_key=True,
),
sa.Column('email', sa.String(), nullable=False),
sa.Column('audience_id', sa.String(), nullable=False),
sa.Column(
'synced_at',
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('keycloak_user_id', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint(
'email', 'audience_id', name='uq_resend_synced_email_audience'
),
)
# Create index on email for fast lookups
op.create_index(
'ix_resend_synced_users_email',
'resend_synced_users',
['email'],
)
# Create index on audience_id for filtering by audience
op.create_index(
'ix_resend_synced_users_audience_id',
'resend_synced_users',
['audience_id'],
)
def downgrade() -> None:
"""Drop resend_synced_users table."""
op.drop_index(
'ix_resend_synced_users_audience_id', table_name='resend_synced_users'
)
op.drop_index('ix_resend_synced_users_email', table_name='resend_synced_users')
op.drop_table('resend_synced_users')

210
enterprise/poetry.lock generated
View File

@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
[[package]]
name = "openhands-agent-server"
version = "1.11.1"
version = "1.11.4"
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_agent_server-1.11.1-py3-none-any.whl", hash = "sha256:28e3ca670114c7a936a33f2d193238fbdc75f429c4e0bb99a03b14e6c01663c9"},
{file = "openhands_agent_server-1.11.1.tar.gz", hash = "sha256:06eaf8b8eda4ca05de24751a7d269b22f611328c6cb2b4b91f2486011228b69a"},
{file = "openhands_agent_server-1.11.4-py3-none-any.whl", hash = "sha256:739bdb774dbfcd23d6e87ee6ee32bc0999f22300037506b6dd33e9ea67fa5c2a"},
{file = "openhands_agent_server-1.11.4.tar.gz", hash = "sha256:41247f7022a046eb50ca3b552bc6d12bfa9776e1bd27d0989da91b9f7ac77ca2"},
]
[package.dependencies]
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
numpy = "*"
openai = "2.8"
openhands-aci = "0.3.2"
openhands-agent-server = "1.11.1"
openhands-sdk = "1.11.1"
openhands-tools = "1.11.1"
openhands-agent-server = "1.11.4"
openhands-sdk = "1.11.4"
openhands-tools = "1.11.4"
opentelemetry-api = ">=1.33.1"
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
pathspec = ">=0.12.1"
@@ -6225,14 +6225,14 @@ url = ".."
[[package]]
name = "openhands-sdk"
version = "1.11.1"
version = "1.11.4"
description = "OpenHands SDK - Core functionality for building AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_sdk-1.11.1-py3-none-any.whl", hash = "sha256:10ee0777286b149db21bdeeadb6d4c57f461da4049a4ba07576e7228b5c76c85"},
{file = "openhands_sdk-1.11.1.tar.gz", hash = "sha256:57f5884d0596a8659b7c0cdbe86ebaa74c810c4e2645fcff45f0113894dd9376"},
{file = "openhands_sdk-1.11.4-py3-none-any.whl", hash = "sha256:9f4607c5d94b56fbcd533207026ee892779dd50e29bce79277ff82454a4f76d5"},
{file = "openhands_sdk-1.11.4.tar.gz", hash = "sha256:4088744f6b8856eeab22d3bc17e47d1736ea7ced945c2fa126bd7d48c14bb313"},
]
[package.dependencies]
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
[[package]]
name = "openhands-tools"
version = "1.11.1"
version = "1.11.4"
description = "OpenHands Tools - Runtime tools for AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_tools-1.11.1-py3-none-any.whl", hash = "sha256:0b64763def90dda5b6545a356a437437c2029ec9bc47a4e6dac5c06dea6a4e77"},
{file = "openhands_tools-1.11.1.tar.gz", hash = "sha256:2a71d2d0619ca631b3b7f5bd741bfdf97f7ebe6f96dc2540f79b9a688a6309fc"},
{file = "openhands_tools-1.11.4-py3-none-any.whl", hash = "sha256:efd721b73e87a0dac69171a76931363fa59fcde98107ca86081ee7bf0253673a"},
{file = "openhands_tools-1.11.4.tar.gz", hash = "sha256:80671b1ea8c85a5247a75ea2340ae31d76363e9c723b104699a9a77e66d2043c"},
]
[package.dependencies]
@@ -6851,103 +6851,103 @@ scramp = ">=1.4.5"
[[package]]
name = "pillow"
version = "12.1.0"
version = "12.1.1"
description = "Python Imaging Library (fork)"
optional = false
python-versions = ">=3.10"
groups = ["main", "test"]
files = [
{file = "pillow-12.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:fb125d860738a09d363a88daa0f59c4533529a90e564785e20fe875b200b6dbd"},
{file = "pillow-12.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cad302dc10fac357d3467a74a9561c90609768a6f73a1923b0fd851b6486f8b0"},
{file = "pillow-12.1.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a40905599d8079e09f25027423aed94f2823adaf2868940de991e53a449e14a8"},
{file = "pillow-12.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:92a7fe4225365c5e3a8e598982269c6d6698d3e783b3b1ae979e7819f9cd55c1"},
{file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f10c98f49227ed8383d28174ee95155a675c4ed7f85e2e573b04414f7e371bda"},
{file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8637e29d13f478bc4f153d8daa9ffb16455f0a6cb287da1b432fdad2bfbd66c7"},
{file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:21e686a21078b0f9cb8c8a961d99e6a4ddb88e0fc5ea6e130172ddddc2e5221a"},
{file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2415373395a831f53933c23ce051021e79c8cd7979822d8cc478547a3f4da8ef"},
{file = "pillow-12.1.0-cp310-cp310-win32.whl", hash = "sha256:e75d3dba8fc1ddfec0cd752108f93b83b4f8d6ab40e524a95d35f016b9683b09"},
{file = "pillow-12.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:64efdf00c09e31efd754448a383ea241f55a994fd079866b92d2bbff598aad91"},
{file = "pillow-12.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:f188028b5af6b8fb2e9a76ac0f841a575bd1bd396e46ef0840d9b88a48fdbcea"},
{file = "pillow-12.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:a83e0850cb8f5ac975291ebfc4170ba481f41a28065277f7f735c202cd8e0af3"},
{file = "pillow-12.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b6e53e82ec2db0717eabb276aa56cf4e500c9a7cec2c2e189b55c24f65a3e8c0"},
{file = "pillow-12.1.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40a8e3b9e8773876d6e30daed22f016509e3987bab61b3b7fe309d7019a87451"},
{file = "pillow-12.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:800429ac32c9b72909c671aaf17ecd13110f823ddb7db4dfef412a5587c2c24e"},
{file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b022eaaf709541b391ee069f0022ee5b36c709df71986e3f7be312e46f42c84"},
{file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f345e7bc9d7f368887c712aa5054558bad44d2a301ddf9248599f4161abc7c0"},
{file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d70347c8a5b7ccd803ec0c85c8709f036e6348f1e6a5bf048ecd9c64d3550b8b"},
{file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1fcc52d86ce7a34fd17cb04e87cfdb164648a3662a6f20565910a99653d66c18"},
{file = "pillow-12.1.0-cp311-cp311-win32.whl", hash = "sha256:3ffaa2f0659e2f740473bcf03c702c39a8d4b2b7ffc629052028764324842c64"},
{file = "pillow-12.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:806f3987ffe10e867bab0ddad45df1148a2b98221798457fa097ad85d6e8bc75"},
{file = "pillow-12.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:9f5fefaca968e700ad1a4a9de98bf0869a94e397fe3524c4c9450c1445252304"},
{file = "pillow-12.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a332ac4ccb84b6dde65dbace8431f3af08874bf9770719d32a635c4ef411b18b"},
{file = "pillow-12.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:907bfa8a9cb790748a9aa4513e37c88c59660da3bcfffbd24a7d9e6abf224551"},
{file = "pillow-12.1.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efdc140e7b63b8f739d09a99033aa430accce485ff78e6d311973a67b6bf3208"},
{file = "pillow-12.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bef9768cab184e7ae6e559c032e95ba8d07b3023c289f79a2bd36e8bf85605a5"},
{file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:742aea052cf5ab5034a53c3846165bc3ce88d7c38e954120db0ab867ca242661"},
{file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6dfc2af5b082b635af6e08e0d1f9f1c4e04d17d4e2ca0ef96131e85eda6eb17"},
{file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:609e89d9f90b581c8d16358c9087df76024cf058fa693dd3e1e1620823f39670"},
{file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43b4899cfd091a9693a1278c4982f3e50f7fb7cff5153b05174b4afc9593b616"},
{file = "pillow-12.1.0-cp312-cp312-win32.whl", hash = "sha256:aa0c9cc0b82b14766a99fbe6084409972266e82f459821cd26997a488a7261a7"},
{file = "pillow-12.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:d70534cea9e7966169ad29a903b99fc507e932069a881d0965a1a84bb57f6c6d"},
{file = "pillow-12.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:65b80c1ee7e14a87d6a068dd3b0aea268ffcabfe0498d38661b00c5b4b22e74c"},
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:7b5dd7cbae20285cdb597b10eb5a2c13aa9de6cde9bb64a3c1317427b1db1ae1"},
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:29a4cef9cb672363926f0470afc516dbf7305a14d8c54f7abbb5c199cd8f8179"},
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:681088909d7e8fa9e31b9799aaa59ba5234c58e5e4f1951b4c4d1082a2e980e0"},
{file = "pillow-12.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:983976c2ab753166dc66d36af6e8ec15bb511e4a25856e2227e5f7e00a160587"},
{file = "pillow-12.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:db44d5c160a90df2d24a24760bbd37607d53da0b34fb546c4c232af7192298ac"},
{file = "pillow-12.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6b7a9d1db5dad90e2991645874f708e87d9a3c370c243c2d7684d28f7e133e6b"},
{file = "pillow-12.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6258f3260986990ba2fa8a874f8b6e808cf5abb51a94015ca3dc3c68aa4f30ea"},
{file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e115c15e3bc727b1ca3e641a909f77f8ca72a64fff150f666fcc85e57701c26c"},
{file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6741e6f3074a35e47c77b23a4e4f2d90db3ed905cb1c5e6e0d49bff2045632bc"},
{file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:935b9d1aed48fcfb3f838caac506f38e29621b44ccc4f8a64d575cb1b2a88644"},
{file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5fee4c04aad8932da9f8f710af2c1a15a83582cfb884152a9caa79d4efcdbf9c"},
{file = "pillow-12.1.0-cp313-cp313-win32.whl", hash = "sha256:a786bf667724d84aa29b5db1c61b7bfdde380202aaca12c3461afd6b71743171"},
{file = "pillow-12.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:461f9dfdafa394c59cd6d818bdfdbab4028b83b02caadaff0ffd433faf4c9a7a"},
{file = "pillow-12.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:9212d6b86917a2300669511ed094a9406888362e085f2431a7da985a6b124f45"},
{file = "pillow-12.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:00162e9ca6d22b7c3ee8e61faa3c3253cd19b6a37f126cad04f2f88b306f557d"},
{file = "pillow-12.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7d6daa89a00b58c37cb1747ec9fb7ac3bc5ffd5949f5888657dfddde6d1312e0"},
{file = "pillow-12.1.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2479c7f02f9d505682dc47df8c0ea1fc5e264c4d1629a5d63fe3e2334b89554"},
{file = "pillow-12.1.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f188d580bd870cda1e15183790d1cc2fa78f666e76077d103edf048eed9c356e"},
{file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0fde7ec5538ab5095cc02df38ee99b0443ff0e1c847a045554cf5f9af1f4aa82"},
{file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ed07dca4a8464bada6139ab38f5382f83e5f111698caf3191cb8dbf27d908b4"},
{file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f45bd71d1fa5e5749587613037b172e0b3b23159d1c00ef2fc920da6f470e6f0"},
{file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:277518bf4fe74aa91489e1b20577473b19ee70fb97c374aa50830b279f25841b"},
{file = "pillow-12.1.0-cp313-cp313t-win32.whl", hash = "sha256:7315f9137087c4e0ee73a761b163fc9aa3b19f5f606a7fc08d83fd3e4379af65"},
{file = "pillow-12.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0ddedfaa8b5f0b4ffbc2fa87b556dc59f6bb4ecb14a53b33f9189713ae8053c0"},
{file = "pillow-12.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:80941e6d573197a0c28f394753de529bb436b1ca990ed6e765cf42426abc39f8"},
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:5cb7bc1966d031aec37ddb9dcf15c2da5b2e9f7cc3ca7c54473a20a927e1eb91"},
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:97e9993d5ed946aba26baf9c1e8cf18adbab584b99f452ee72f7ee8acb882796"},
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:414b9a78e14ffeb98128863314e62c3f24b8a86081066625700b7985b3f529bd"},
{file = "pillow-12.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e6bdb408f7c9dd2a5ff2b14a3b0bb6d4deb29fb9961e6eb3ae2031ae9a5cec13"},
{file = "pillow-12.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3413c2ae377550f5487991d444428f1a8ae92784aac79caa8b1e3b89b175f77e"},
{file = "pillow-12.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e5dcbe95016e88437ecf33544ba5db21ef1b8dd6e1b434a2cb2a3d605299e643"},
{file = "pillow-12.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d0a7735df32ccbcc98b98a1ac785cc4b19b580be1bdf0aeb5c03223220ea09d5"},
{file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0c27407a2d1b96774cbc4a7594129cc027339fd800cd081e44497722ea1179de"},
{file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15c794d74303828eaa957ff8070846d0efe8c630901a1c753fdc63850e19ecd9"},
{file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c990547452ee2800d8506c4150280757f88532f3de2a58e3022e9b179107862a"},
{file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b63e13dd27da389ed9475b3d28510f0f954bca0041e8e551b2a4eb1eab56a39a"},
{file = "pillow-12.1.0-cp314-cp314-win32.whl", hash = "sha256:1a949604f73eb07a8adab38c4fe50791f9919344398bdc8ac6b307f755fc7030"},
{file = "pillow-12.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:4f9f6a650743f0ddee5593ac9e954ba1bdbc5e150bc066586d4f26127853ab94"},
{file = "pillow-12.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:808b99604f7873c800c4840f55ff389936ef1948e4e87645eaf3fccbc8477ac4"},
{file = "pillow-12.1.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bc11908616c8a283cf7d664f77411a5ed2a02009b0097ff8abbba5e79128ccf2"},
{file = "pillow-12.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:896866d2d436563fa2a43a9d72f417874f16b5545955c54a64941e87c1376c61"},
{file = "pillow-12.1.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8e178e3e99d3c0ea8fc64b88447f7cac8ccf058af422a6cedc690d0eadd98c51"},
{file = "pillow-12.1.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:079af2fb0c599c2ec144ba2c02766d1b55498e373b3ac64687e43849fbbef5bc"},
{file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdec5e43377761c5dbca620efb69a77f6855c5a379e32ac5b158f54c84212b14"},
{file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:565c986f4b45c020f5421a4cea13ef294dde9509a8577f29b2fc5edc7587fff8"},
{file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:43aca0a55ce1eefc0aefa6253661cb54571857b1a7b2964bd8a1e3ef4b729924"},
{file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0deedf2ea233722476b3a81e8cdfbad786f7adbed5d848469fa59fe52396e4ef"},
{file = "pillow-12.1.0-cp314-cp314t-win32.whl", hash = "sha256:b17fbdbe01c196e7e159aacb889e091f28e61020a8abeac07b68079b6e626988"},
{file = "pillow-12.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27b9baecb428899db6c0de572d6d305cfaf38ca1596b5c0542a5182e3e74e8c6"},
{file = "pillow-12.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f61333d817698bdcdd0f9d7793e365ac3d2a21c1f1eb02b32ad6aefb8d8ea831"},
{file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ca94b6aac0d7af2a10ba08c0f888b3d5114439b6b3ef39968378723622fed377"},
{file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:351889afef0f485b84078ea40fe33727a0492b9af3904661b0abbafee0355b72"},
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb0984b30e973f7e2884362b7d23d0a348c7143ee559f38ef3eaab640144204c"},
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:84cabc7095dd535ca934d57e9ce2a72ffd216e435a84acb06b2277b1de2689bd"},
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53d8b764726d3af1a138dd353116f774e3862ec7e3794e0c8781e30db0f35dfc"},
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5da841d81b1a05ef940a8567da92decaa15bc4d7dedb540a8c219ad83d91808a"},
{file = "pillow-12.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:75af0b4c229ac519b155028fa1be632d812a519abba9b46b20e50c6caa184f19"},
{file = "pillow-12.1.0.tar.gz", hash = "sha256:5c5ae0a06e9ea030ab786b0251b32c7e4ce10e58d983c0d5c56029455180b5b9"},
{file = "pillow-12.1.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1f1625b72740fdda5d77b4def688eb8fd6490975d06b909fd19f13f391e077e0"},
{file = "pillow-12.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:178aa072084bd88ec759052feca8e56cbb14a60b39322b99a049e58090479713"},
{file = "pillow-12.1.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b66e95d05ba806247aaa1561f080abc7975daf715c30780ff92a20e4ec546e1b"},
{file = "pillow-12.1.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89c7e895002bbe49cdc5426150377cbbc04767d7547ed145473f496dfa40408b"},
{file = "pillow-12.1.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a5cbdcddad0af3da87cb16b60d23648bc3b51967eb07223e9fed77a82b457c4"},
{file = "pillow-12.1.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9f51079765661884a486727f0729d29054242f74b46186026582b4e4769918e4"},
{file = "pillow-12.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:99c1506ea77c11531d75e3a412832a13a71c7ebc8192ab9e4b2e355555920e3e"},
{file = "pillow-12.1.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:36341d06738a9f66c8287cf8b876d24b18db9bd8740fa0672c74e259ad408cff"},
{file = "pillow-12.1.1-cp310-cp310-win32.whl", hash = "sha256:6c52f062424c523d6c4db85518774cc3d50f5539dd6eed32b8f6229b26f24d40"},
{file = "pillow-12.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:c6008de247150668a705a6338156efb92334113421ceecf7438a12c9a12dab23"},
{file = "pillow-12.1.1-cp310-cp310-win_arm64.whl", hash = "sha256:1a9b0ee305220b392e1124a764ee4265bd063e54a751a6b62eff69992f457fa9"},
{file = "pillow-12.1.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e879bb6cd5c73848ef3b2b48b8af9ff08c5b71ecda8048b7dd22d8a33f60be32"},
{file = "pillow-12.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:365b10bb9417dd4498c0e3b128018c4a624dc11c7b97d8cc54effe3b096f4c38"},
{file = "pillow-12.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d4ce8e329c93845720cd2014659ca67eac35f6433fd3050393d85f3ecef0dad5"},
{file = "pillow-12.1.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc354a04072b765eccf2204f588a7a532c9511e8b9c7f900e1b64e3e33487090"},
{file = "pillow-12.1.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7e7976bf1910a8116b523b9f9f58bf410f3e8aa330cd9a2bb2953f9266ab49af"},
{file = "pillow-12.1.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:597bd9c8419bc7c6af5604e55847789b69123bbe25d65cc6ad3012b4f3c98d8b"},
{file = "pillow-12.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2c1fc0f2ca5f96a3c8407e41cca26a16e46b21060fe6d5b099d2cb01412222f5"},
{file = "pillow-12.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:578510d88c6229d735855e1f278aa305270438d36a05031dfaae5067cc8eb04d"},
{file = "pillow-12.1.1-cp311-cp311-win32.whl", hash = "sha256:7311c0a0dcadb89b36b7025dfd8326ecfa36964e29913074d47382706e516a7c"},
{file = "pillow-12.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:fbfa2a7c10cc2623f412753cddf391c7f971c52ca40a3f65dc5039b2939e8563"},
{file = "pillow-12.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:b81b5e3511211631b3f672a595e3221252c90af017e399056d0faabb9538aa80"},
{file = "pillow-12.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab323b787d6e18b3d91a72fc99b1a2c28651e4358749842b8f8dfacd28ef2052"},
{file = "pillow-12.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:adebb5bee0f0af4909c30db0d890c773d1a92ffe83da908e2e9e720f8edf3984"},
{file = "pillow-12.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb66b7cc26f50977108790e2456b7921e773f23db5630261102233eb355a3b79"},
{file = "pillow-12.1.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:aee2810642b2898bb187ced9b349e95d2a7272930796e022efaf12e99dccd293"},
{file = "pillow-12.1.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a0b1cd6232e2b618adcc54d9882e4e662a089d5768cd188f7c245b4c8c44a397"},
{file = "pillow-12.1.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7aac39bcf8d4770d089588a2e1dd111cbaa42df5a94be3114222057d68336bd0"},
{file = "pillow-12.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ab174cd7d29a62dd139c44bf74b698039328f45cb03b4596c43473a46656b2f3"},
{file = "pillow-12.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:339ffdcb7cbeaa08221cd401d517d4b1fe7a9ed5d400e4a8039719238620ca35"},
{file = "pillow-12.1.1-cp312-cp312-win32.whl", hash = "sha256:5d1f9575a12bed9e9eedd9a4972834b08c97a352bd17955ccdebfeca5913fa0a"},
{file = "pillow-12.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:21329ec8c96c6e979cd0dfd29406c40c1d52521a90544463057d2aaa937d66a6"},
{file = "pillow-12.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:af9a332e572978f0218686636610555ae3defd1633597be015ed50289a03c523"},
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:d242e8ac078781f1de88bf823d70c1a9b3c7950a44cdf4b7c012e22ccbcd8e4e"},
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:02f84dfad02693676692746df05b89cf25597560db2857363a208e393429f5e9"},
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:e65498daf4b583091ccbb2556c7000abf0f3349fcd57ef7adc9a84a394ed29f6"},
{file = "pillow-12.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6c6db3b84c87d48d0088943bf33440e0c42370b99b1c2a7989216f7b42eede60"},
{file = "pillow-12.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8b7e5304e34942bf62e15184219a7b5ad4ff7f3bb5cca4d984f37df1a0e1aee2"},
{file = "pillow-12.1.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:18e5bddd742a44b7e6b1e773ab5db102bd7a94c32555ba656e76d319d19c3850"},
{file = "pillow-12.1.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc44ef1f3de4f45b50ccf9136999d71abb99dca7706bc75d222ed350b9fd2289"},
{file = "pillow-12.1.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5a8eb7ed8d4198bccbd07058416eeec51686b498e784eda166395a23eb99138e"},
{file = "pillow-12.1.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47b94983da0c642de92ced1702c5b6c292a84bd3a8e1d1702ff923f183594717"},
{file = "pillow-12.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:518a48c2aab7ce596d3bf79d0e275661b846e86e4d0e7dec34712c30fe07f02a"},
{file = "pillow-12.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a550ae29b95c6dc13cf69e2c9dc5747f814c54eeb2e32d683e5e93af56caa029"},
{file = "pillow-12.1.1-cp313-cp313-win32.whl", hash = "sha256:a003d7422449f6d1e3a34e3dd4110c22148336918ddbfc6a32581cd54b2e0b2b"},
{file = "pillow-12.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:344cf1e3dab3be4b1fa08e449323d98a2a3f819ad20f4b22e77a0ede31f0faa1"},
{file = "pillow-12.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:5c0dd1636633e7e6a0afe7bf6a51a14992b7f8e60de5789018ebbdfae55b040a"},
{file = "pillow-12.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0330d233c1a0ead844fc097a7d16c0abff4c12e856c0b325f231820fee1f39da"},
{file = "pillow-12.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5dae5f21afb91322f2ff791895ddd8889e5e947ff59f71b46041c8ce6db790bc"},
{file = "pillow-12.1.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e0c664be47252947d870ac0d327fea7e63985a08794758aa8af5b6cb6ec0c9c"},
{file = "pillow-12.1.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:691ab2ac363b8217f7d31b3497108fb1f50faab2f75dfb03284ec2f217e87bf8"},
{file = "pillow-12.1.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e9e8064fb1cc019296958595f6db671fba95209e3ceb0c4734c9baf97de04b20"},
{file = "pillow-12.1.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:472a8d7ded663e6162dafdf20015c486a7009483ca671cece7a9279b512fcb13"},
{file = "pillow-12.1.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:89b54027a766529136a06cfebeecb3a04900397a3590fd252160b888479517bf"},
{file = "pillow-12.1.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:86172b0831b82ce4f7877f280055892b31179e1576aa00d0df3bb1bbf8c3e524"},
{file = "pillow-12.1.1-cp313-cp313t-win32.whl", hash = "sha256:44ce27545b6efcf0fdbdceb31c9a5bdea9333e664cda58a7e674bb74608b3986"},
{file = "pillow-12.1.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a285e3eb7a5a45a2ff504e31f4a8d1b12ef62e84e5411c6804a42197c1cf586c"},
{file = "pillow-12.1.1-cp313-cp313t-win_arm64.whl", hash = "sha256:cc7d296b5ea4d29e6570dabeaed58d31c3fea35a633a69679fb03d7664f43fb3"},
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:417423db963cb4be8bac3fc1204fe61610f6abeed1580a7a2cbb2fbda20f12af"},
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:b957b71c6b2387610f556a7eb0828afbe40b4a98036fc0d2acfa5a44a0c2036f"},
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:097690ba1f2efdeb165a20469d59d8bb03c55fb6621eb2041a060ae8ea3e9642"},
{file = "pillow-12.1.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2815a87ab27848db0321fb78c7f0b2c8649dee134b7f2b80c6a45c6831d75ccd"},
{file = "pillow-12.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f7ed2c6543bad5a7d5530eb9e78c53132f93dfa44a28492db88b41cdab885202"},
{file = "pillow-12.1.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:652a2c9ccfb556235b2b501a3a7cf3742148cd22e04b5625c5fe057ea3e3191f"},
{file = "pillow-12.1.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d6e4571eedf43af33d0fc233a382a76e849badbccdf1ac438841308652a08e1f"},
{file = "pillow-12.1.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b574c51cf7d5d62e9be37ba446224b59a2da26dc4c1bb2ecbe936a4fb1a7cb7f"},
{file = "pillow-12.1.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a37691702ed687799de29a518d63d4682d9016932db66d4e90c345831b02fb4e"},
{file = "pillow-12.1.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f95c00d5d6700b2b890479664a06e754974848afaae5e21beb4d83c106923fd0"},
{file = "pillow-12.1.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:559b38da23606e68681337ad74622c4dbba02254fc9cb4488a305dd5975c7eeb"},
{file = "pillow-12.1.1-cp314-cp314-win32.whl", hash = "sha256:03edcc34d688572014ff223c125a3f77fb08091e4607e7745002fc214070b35f"},
{file = "pillow-12.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:50480dcd74fa63b8e78235957d302d98d98d82ccbfac4c7e12108ba9ecbdba15"},
{file = "pillow-12.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:5cb1785d97b0c3d1d1a16bc1d710c4a0049daefc4935f3a8f31f827f4d3d2e7f"},
{file = "pillow-12.1.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:1f90cff8aa76835cba5769f0b3121a22bd4eb9e6884cfe338216e557a9a548b8"},
{file = "pillow-12.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1f1be78ce9466a7ee64bfda57bdba0f7cc499d9794d518b854816c41bf0aa4e9"},
{file = "pillow-12.1.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:42fc1f4677106188ad9a55562bbade416f8b55456f522430fadab3cef7cd4e60"},
{file = "pillow-12.1.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98edb152429ab62a1818039744d8fbb3ccab98a7c29fc3d5fcef158f3f1f68b7"},
{file = "pillow-12.1.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d470ab1178551dd17fdba0fef463359c41aaa613cdcd7ff8373f54be629f9f8f"},
{file = "pillow-12.1.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6408a7b064595afcab0a49393a413732a35788f2a5092fdc6266952ed67de586"},
{file = "pillow-12.1.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5d8c41325b382c07799a3682c1c258469ea2ff97103c53717b7893862d0c98ce"},
{file = "pillow-12.1.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:c7697918b5be27424e9ce568193efd13d925c4481dd364e43f5dff72d33e10f8"},
{file = "pillow-12.1.1-cp314-cp314t-win32.whl", hash = "sha256:d2912fd8114fc5545aa3a4b5576512f64c55a03f3ebcca4c10194d593d43ea36"},
{file = "pillow-12.1.1-cp314-cp314t-win_amd64.whl", hash = "sha256:4ceb838d4bd9dab43e06c363cab2eebf63846d6a4aeaea283bbdfd8f1a8ed58b"},
{file = "pillow-12.1.1-cp314-cp314t-win_arm64.whl", hash = "sha256:7b03048319bfc6170e93bd60728a1af51d3dd7704935feb228c4d4faab35d334"},
{file = "pillow-12.1.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:600fd103672b925fe62ed08e0d874ea34d692474df6f4bf7ebe148b30f89f39f"},
{file = "pillow-12.1.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:665e1b916b043cef294bc54d47bf02d87e13f769bc4bc5fa225a24b3a6c5aca9"},
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:495c302af3aad1ca67420ddd5c7bd480c8867ad173528767d906428057a11f0e"},
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8fd420ef0c52c88b5a035a0886f367748c72147b2b8f384c9d12656678dfdfa9"},
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f975aa7ef9684ce7e2c18a3aa8f8e2106ce1e46b94ab713d156b2898811651d3"},
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8089c852a56c2966cf18835db62d9b34fef7ba74c726ad943928d494fa7f4735"},
{file = "pillow-12.1.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:cb9bb857b2d057c6dfc72ac5f3b44836924ba15721882ef103cecb40d002d80e"},
{file = "pillow-12.1.1.tar.gz", hash = "sha256:9ad8fa5937ab05218e2b6a4cff30295ad35afd2f83ac592e68c0d871bb0fdbc4"},
]
[package.extras]
@@ -14917,4 +14917,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = "^3.12,<3.14"
content-hash = "b5cbb1e25176845ac9f95650a802667e2f8be1a536e3e55a9269b5af5a42e3fc"
content-hash = "1cad6029269393af67155e930c72eae2c03da02e4b3a3699823f6168c14a4218"

View File

@@ -44,6 +44,12 @@ httpx = "*"
scikit-learn = "^1.7.0"
shap = "^0.48.0"
google-cloud-recaptcha-enterprise = "^1.24.0"
# Dependencies previously only in Dockerfile, now managed via poetry.lock
prometheus-client = "^0.24.0"
pandas = "^2.2.0"
numpy = "^2.2.0"
mcp = "^1.10.0"
pillow = "^12.1.0"
[tool.poetry.group.dev.dependencies]
ruff = "0.8.3"

View File

@@ -38,6 +38,12 @@ from server.routes.integration.linear import linear_integration_router # noqa:
from server.routes.integration.slack import slack_router # noqa: E402
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.org_invitations import ( # noqa: E402
accept_router as invitation_accept_router,
)
from server.routes.org_invitations import ( # noqa: E402
invitation_router,
)
from server.routes.orgs import org_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
@@ -78,8 +84,15 @@ base_app.include_router(shared_event_router)
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
if GITHUB_APP_CLIENT_ID:
# Make sure that the callback processor is loaded here so we don't get an error when deserializing
from integrations.github.github_v1_callback_processor import ( # noqa: E402
GithubV1CallbackProcessor,
)
from server.routes.integration.github import github_integration_router # noqa: E402
# Bludgeon mypy into not deleting my import
logger.debug(f'Loaded {GithubV1CallbackProcessor.__name__}')
base_app.include_router(
github_integration_router
) # Add additional route for integration webhook events
@@ -92,6 +105,8 @@ if GITLAB_APP_CLIENT_ID:
base_app.include_router(api_keys_router) # Add routes for API key management
base_app.include_router(org_router) # Add routes for organization management
base_app.include_router(invitation_router) # Add routes for org invitation management
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
add_github_proxy_routes(base_app)
add_debugging_routes(
base_app

View File

@@ -0,0 +1,306 @@
"""
Permission-based authorization dependencies for API endpoints.
This module provides FastAPI dependencies for checking user permissions
within organizations. It uses a permission-based authorization model where
roles (owner, admin, member) are mapped to specific permissions.
Permissions are defined in the Permission enum and mapped to roles via
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
the familiar role-based hierarchy.
Usage:
from server.auth.authorization import (
Permission,
require_permission,
)
@router.get('/{org_id}/settings')
async def get_settings(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
):
# Only users with VIEW_LLM_SETTINGS permission can access
...
@router.patch('/{org_id}/settings')
async def update_settings(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
):
# Only users with EDIT_LLM_SETTINGS permission can access
...
"""
from enum import Enum
from uuid import UUID
from fastapi import Depends, HTTPException, status
from storage.org_member_store import OrgMemberStore
from storage.role import Role
from storage.role_store import RoleStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
class Permission(str, Enum):
"""Permissions that can be assigned to roles."""
# Secrets
MANAGE_SECRETS = 'manage_secrets'
# MCP
MANAGE_MCP = 'manage_mcp'
# Integrations
MANAGE_INTEGRATIONS = 'manage_integrations'
# Application Settings
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
# API Keys
MANAGE_API_KEYS = 'manage_api_keys'
# LLM Settings
VIEW_LLM_SETTINGS = 'view_llm_settings'
EDIT_LLM_SETTINGS = 'edit_llm_settings'
# Billing
VIEW_BILLING = 'view_billing'
ADD_CREDITS = 'add_credits'
# Organization Members
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
# Organization Management
VIEW_ORG_SETTINGS = 'view_org_settings'
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
DELETE_ORGANIZATION = 'delete_organization'
# Temporary permissions until we finish the API updates.
EDIT_ORG_SETTINGS = 'edit_org_settings'
class RoleName(str, Enum):
"""Role names used in the system."""
OWNER = 'owner'
ADMIN = 'admin'
MEMBER = 'member'
# Permission mappings for each role
ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
RoleName.OWNER: frozenset(
[
# Settings (Full access)
Permission.MANAGE_SECRETS,
Permission.MANAGE_MCP,
Permission.MANAGE_INTEGRATIONS,
Permission.MANAGE_APPLICATION_SETTINGS,
Permission.MANAGE_API_KEYS,
Permission.VIEW_LLM_SETTINGS,
Permission.EDIT_LLM_SETTINGS,
Permission.VIEW_BILLING,
Permission.ADD_CREDITS,
# Organization Members
Permission.INVITE_USER_TO_ORGANIZATION,
Permission.CHANGE_USER_ROLE_MEMBER,
Permission.CHANGE_USER_ROLE_ADMIN,
Permission.CHANGE_USER_ROLE_OWNER,
# Organization Management
Permission.VIEW_ORG_SETTINGS,
Permission.EDIT_ORG_SETTINGS,
# Organization Management (Owner only)
Permission.CHANGE_ORGANIZATION_NAME,
Permission.DELETE_ORGANIZATION,
]
),
RoleName.ADMIN: frozenset(
[
# Settings (Full access)
Permission.MANAGE_SECRETS,
Permission.MANAGE_MCP,
Permission.MANAGE_INTEGRATIONS,
Permission.MANAGE_APPLICATION_SETTINGS,
Permission.MANAGE_API_KEYS,
Permission.VIEW_LLM_SETTINGS,
Permission.EDIT_LLM_SETTINGS,
Permission.VIEW_BILLING,
Permission.ADD_CREDITS,
# Organization Members
Permission.INVITE_USER_TO_ORGANIZATION,
Permission.CHANGE_USER_ROLE_MEMBER,
Permission.CHANGE_USER_ROLE_ADMIN,
# Organization Management
Permission.VIEW_ORG_SETTINGS,
Permission.EDIT_ORG_SETTINGS,
]
),
RoleName.MEMBER: frozenset(
[
# Settings (Full access)
Permission.MANAGE_SECRETS,
Permission.MANAGE_MCP,
Permission.MANAGE_INTEGRATIONS,
Permission.MANAGE_APPLICATION_SETTINGS,
Permission.MANAGE_API_KEYS,
# Settings (View only)
Permission.VIEW_ORG_SETTINGS,
Permission.VIEW_LLM_SETTINGS,
]
),
}
def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
"""
Get the user's role in an organization (synchronous version).
Args:
user_id: User ID (string that will be converted to UUID)
org_id: Organization ID, or None to use the user's current organization
Returns:
Role object if user is a member, None otherwise
"""
from uuid import UUID as parse_uuid
if org_id is None:
org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id))
else:
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
return None
return RoleStore.get_role_by_id(org_member.role_id)
async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None:
"""
Get the user's role in an organization (async version).
Args:
user_id: User ID (string that will be converted to UUID)
org_id: Organization ID, or None to use the user's current organization
Returns:
Role object if user is a member, None otherwise
"""
from uuid import UUID as parse_uuid
if org_id is None:
org_member = await OrgMemberStore.get_org_member_for_current_org_async(
parse_uuid(user_id)
)
else:
org_member = await OrgMemberStore.get_org_member_async(
org_id, parse_uuid(user_id)
)
if not org_member:
return None
return await RoleStore.get_role_by_id_async(org_member.role_id)
def get_role_permissions(role_name: str) -> frozenset[Permission]:
"""
Get the permissions for a role.
Args:
role_name: Name of the role
Returns:
Set of permissions for the role
"""
try:
role_enum = RoleName(role_name)
return ROLE_PERMISSIONS.get(role_enum, frozenset())
except ValueError:
return frozenset()
def has_permission(user_role: Role, permission: Permission) -> bool:
"""
Check if a role has a specific permission.
Args:
user_role: User's Role object
permission: Permission to check
Returns:
True if the role has the permission
"""
permissions = get_role_permissions(user_role.name)
return permission in permissions
def require_permission(permission: Permission):
"""
Factory function that creates a dependency to require a specific permission.
This creates a FastAPI dependency that:
1. Extracts org_id from the path parameter
2. Gets the authenticated user_id
3. Checks if the user has the required permission in the organization
4. Returns the user_id if authorized, raises HTTPException otherwise
Usage:
@router.get('/{org_id}/settings')
async def get_settings(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
):
...
Args:
permission: The permission required to access the endpoint
Returns:
Dependency function that validates permission and returns user_id
"""
async def permission_checker(
org_id: UUID | None = None,
user_id: str | None = Depends(get_user_id),
) -> str:
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='User not authenticated',
)
user_role = await get_user_org_role_async(user_id, org_id)
if not user_role:
logger.warning(
'User not a member of organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='User is not a member of this organization',
)
if not has_permission(user_role, permission):
logger.warning(
'Insufficient permissions',
extra={
'user_id': user_id,
'org_id': str(org_id),
'user_role': user_role.name,
'required_permission': permission.value,
},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f'Requires {permission.value} permission',
)
return user_id
return permission_checker

View File

@@ -15,6 +15,11 @@ IS_FEATURE_ENV = (
) # Does not include the staging deployment
IS_LOCAL_ENV = bool(HOST == 'localhost')
# Role name constants
ROLE_OWNER = 'owner'
ROLE_ADMIN = 'admin'
ROLE_MEMBER = 'member'
# Deprecated - billing margins are now handled internally in litellm
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
@@ -25,7 +30,9 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
2: 'claude-3-7-sonnet-20250219',
3: 'claude-sonnet-4-20250514',
4: 'claude-sonnet-4-20250514',
5: 'claude-opus-4-5-20251101',
# Minimax is now the default as it gives results close to claude in terms of quality
# but at a much lower price
5: 'minimax-m2.5',
}
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
@@ -54,7 +61,6 @@ SUBSCRIPTION_PRICE_DATA = {
},
}
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')

View File

@@ -51,6 +51,14 @@ def custom_json_serializer(obj, **kwargs):
obj['stack_info'] = format_stack(stack_info)
result = json.dumps(obj, **kwargs)
# Swap out newlines to make things easier to read. This will produce
# invalid json but means we can have similar logs in local development
# to production, making things easier to correlate. Obviously,
# LOG_JSON_FOR_CONSOLE should not be used in production environments.
if LOG_JSON_FOR_CONSOLE:
result = result.replace('\\n', '\n')
return result

View File

@@ -160,6 +160,7 @@ class SetAuthCookieMiddleware:
'/api/billing/customer-setup-success',
'/api/billing/stripe-webhook',
'/api/email/resend',
'/api/organizations/members/invite/accept',
'/oauth/device/authorize',
'/oauth/device/token',
'/api/v1/web-client/config',

View File

@@ -2,10 +2,12 @@ from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, field_validator
from storage.api_key import ApiKey
from storage.api_key_store import ApiKeyStore
from storage.lite_llm_manager import LiteLlmManager
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.org_service import OrgService
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
@@ -52,7 +54,6 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None:
async def generate_byor_key(user_id: str) -> str | None:
"""Generate a new BYOR key for a user."""
try:
user = await UserStore.get_user_by_id_async(user_id)
if not user:
@@ -135,9 +136,9 @@ class ApiKeyCreate(BaseModel):
class ApiKeyResponse(BaseModel):
id: int
name: str | None = None
created_at: str
last_used_at: str | None = None
expires_at: str | None = None
created_at: datetime
last_used_at: datetime | None = None
expires_at: datetime | None = None
class ApiKeyCreateResponse(ApiKeyResponse):
@@ -148,8 +149,47 @@ class LlmApiKeyResponse(BaseModel):
key: str | None
@api_router.post('', response_model=ApiKeyCreateResponse)
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
class ByorPermittedResponse(BaseModel):
permitted: bool
class MessageResponse(BaseModel):
message: str
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
"""Convert an ApiKey model to an ApiKeyResponse."""
return ApiKeyResponse(
id=key.id,
name=key.name,
created_at=key.created_at,
last_used_at=key.last_used_at,
expires_at=key.expires_at,
)
@api_router.get('/llm/byor/permitted', tags=['Keys'])
async def check_byor_permitted(
user_id: str = Depends(get_user_id),
) -> ByorPermittedResponse:
"""Check if BYOR key export is permitted for the user's current org."""
try:
permitted = await OrgService.check_byor_export_enabled(user_id)
return ByorPermittedResponse(permitted=permitted)
except Exception as e:
logger.exception(
'Error checking BYOR export permission', extra={'error': str(e)}
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to check BYOR export permission',
)
@api_router.post('', tags=['Keys'])
async def create_api_key(
key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)
) -> ApiKeyCreateResponse:
"""Create a new API key for the authenticated user."""
try:
api_key = await api_key_store.create_api_key(
@@ -158,48 +198,29 @@ async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user
# Get the created key details
keys = await api_key_store.list_api_keys(user_id)
for key in keys:
if key['name'] == key_data.name:
return {
**key,
'key': api_key,
'created_at': (
key['created_at'].isoformat() if key['created_at'] else None
),
'last_used_at': (
key['last_used_at'].isoformat() if key['last_used_at'] else None
),
'expires_at': (
key['expires_at'].isoformat() if key['expires_at'] else None
),
}
if key.name == key_data.name:
return ApiKeyCreateResponse(
id=key.id,
name=key.name,
key=api_key,
created_at=key.created_at,
last_used_at=key.last_used_at,
expires_at=key.expires_at,
)
except Exception:
logger.exception('Error creating API key')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create API key',
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create API key',
)
@api_router.get('', response_model=list[ApiKeyResponse])
async def list_api_keys(user_id: str = Depends(get_user_id)):
@api_router.get('', tags=['Keys'])
async def list_api_keys(user_id: str = Depends(get_user_id)) -> list[ApiKeyResponse]:
"""List all API keys for the authenticated user."""
try:
keys = await api_key_store.list_api_keys(user_id)
return [
{
**key,
'created_at': (
key['created_at'].isoformat() if key['created_at'] else None
),
'last_used_at': (
key['last_used_at'].isoformat() if key['last_used_at'] else None
),
'expires_at': (
key['expires_at'].isoformat() if key['expires_at'] else None
),
}
for key in keys
]
return [api_key_to_response(key) for key in keys]
except Exception:
logger.exception('Error listing API keys')
raise HTTPException(
@@ -208,8 +229,10 @@ async def list_api_keys(user_id: str = Depends(get_user_id)):
)
@api_router.delete('/{key_id}')
async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
@api_router.delete('/{key_id}', tags=['Keys'])
async def delete_api_key(
key_id: int, user_id: str = Depends(get_user_id)
) -> MessageResponse:
"""Delete an API key."""
try:
# First, verify the key belongs to the user
@@ -217,7 +240,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
key_to_delete = None
for key in keys:
if key['id'] == key_id:
if key.id == key_id:
key_to_delete = key
break
@@ -235,7 +258,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to delete API key',
)
return {'message': 'API key deleted successfully'}
return MessageResponse(message='API key deleted successfully')
except HTTPException:
raise
except Exception:
@@ -246,22 +269,33 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
)
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
@api_router.get('/llm/byor', tags=['Keys'])
async def get_llm_api_key_for_byor(
user_id: str = Depends(get_user_id),
) -> LlmApiKeyResponse:
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
This endpoint validates that the key exists in LiteLLM before returning it.
If validation fails, it automatically generates a new key to ensure users
always receive a working key.
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
"""
try:
# Check if BYOR export is enabled for the user's org
if not await OrgService.check_byor_export_enabled(user_id):
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
)
# Check if the BYOR key exists in the database
byor_key = await get_byor_key_from_db(user_id)
if byor_key:
# Validate that the key is actually registered in LiteLLM
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
if is_valid:
return {'key': byor_key}
return LlmApiKeyResponse(key=byor_key)
else:
# Key exists in DB but is invalid in LiteLLM - regenerate it
logger.warning(
@@ -286,7 +320,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
'Successfully generated and stored new BYOR key',
extra={'user_id': user_id},
)
return {'key': key}
return LlmApiKeyResponse(key=key)
else:
logger.error(
'Failed to generate new BYOR LLM API key',
@@ -308,12 +342,24 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
)
@api_router.post('/llm/byor/refresh', response_model=LlmApiKeyResponse)
async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user."""
@api_router.post('/llm/byor/refresh', tags=['Keys'])
async def refresh_llm_api_key_for_byor(
user_id: str = Depends(get_user_id),
) -> LlmApiKeyResponse:
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
"""
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
try:
# Check if BYOR export is enabled for the user's org
if not await OrgService.check_byor_export_enabled(user_id):
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
)
# Get the existing BYOR key from the database
existing_byor_key = await get_byor_key_from_db(user_id)
@@ -352,7 +398,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
'BYOR LLM API key refresh completed successfully',
extra={'user_id': user_id},
)
return {'key': key}
return LlmApiKeyResponse(key=key)
except HTTPException as he:
logger.error(
'HTTP exception during BYOR LLM API key refresh',

View File

@@ -5,6 +5,7 @@ import warnings
from datetime import datetime, timezone
from typing import Annotated, Literal, Optional
from urllib.parse import quote
from uuid import UUID as parse_uuid
import posthog
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
@@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager
from server.config import sign_token
from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from server.services.org_invitation_service import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
OrgInvitationService,
UserAlreadyMemberError,
)
from storage.database import session_maker
from storage.user import User
from storage.user_store import UserStore
@@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
)
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
Returns:
Tuple of (redirect_url, recaptcha_token, invitation_token).
Tokens may be None.
"""
if not state:
return '', None, None
try:
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return (
state_data.get('redirect_url', ''),
state_data.get('recaptcha_token'),
state_data.get('invitation_token'),
)
except Exception:
# Old format - state is just the redirect URL
return state, None, None
# Keep alias for backward compatibility
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
"""Extract redirect URL and reCAPTCHA token from OAuth state.
Deprecated: Use _extract_oauth_state instead.
Returns:
Tuple of (redirect_url, recaptcha_token). Token may be None.
"""
if not state:
return '', None
try:
# Try to decode as JSON (new format with reCAPTCHA)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
except Exception:
# Old format - state is just the redirect URL
return state, None
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
return redirect_url, recaptcha_token
@oauth_router.get('/keycloak/callback')
@@ -130,8 +156,8 @@ async def keycloak_callback(
error: Optional[str] = None,
error_description: Optional[str] = None,
):
# Extract redirect URL and reCAPTCHA token from state
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
# Extract redirect URL, reCAPTCHA token, and invitation token from state
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
if not redirect_url:
redirect_url = str(request.base_url)
@@ -302,8 +328,13 @@ async def keycloak_callback(
from server.routes.email import verify_email
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
response = RedirectResponse(redirect_url, status_code=302)
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
# Preserve invitation token so it can be included in OAuth state after verification
if invitation_token:
verification_redirect_url = (
f'{verification_redirect_url}&invitation_token={invitation_token}'
)
response = RedirectResponse(verification_redirect_url, status_code=302)
return response
# default to github IDP for now.
@@ -381,14 +412,90 @@ async def keycloak_callback(
)
has_accepted_tos = user.accepted_tos is not None
# Process invitation token if present (after email verification but before TOS)
if invitation_token:
try:
logger.info(
'Processing invitation token during auth callback',
extra={
'user_id': user_id,
'invitation_token_prefix': invitation_token[:10] + '...',
},
)
await OrgInvitationService.accept_invitation(
invitation_token, parse_uuid(user_id)
)
logger.info(
'Invitation accepted during auth callback',
extra={'user_id': user_id},
)
except InvitationExpiredError:
logger.warning(
'Invitation expired during auth callback',
extra={'user_id': user_id},
)
# Add query param to redirect URL
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
except InvitationInvalidError as e:
logger.warning(
'Invalid invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
except UserAlreadyMemberError:
logger.info(
'User already member during invitation acceptance',
extra={'user_id': user_id},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
except EmailMismatchError as e:
logger.warning(
'Email mismatch during auth callback invitation acceptance',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
except Exception as e:
logger.exception(
'Unexpected error processing invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
# Don't fail the login if invitation processing fails
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_error=true'
else:
redirect_url = f'{redirect_url}?invitation_error=true'
# If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos:
encoded_redirect_url = quote(redirect_url, safe='')
tos_redirect_url = (
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
)
if invitation_token:
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
response = RedirectResponse(tos_redirect_url, status_code=302)
else:
if invitation_token:
redirect_url = f'{redirect_url}&invitation_success=true'
response = RedirectResponse(redirect_url, status_code=302)
set_response_cookie(

View File

@@ -9,14 +9,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from integrations import stripe_service
from pydantic import BaseModel
from server.constants import (
STRIPE_API_KEY,
)
from server.constants import STRIPE_API_KEY
from server.logger import logger
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.subscription_access import SubscriptionAccess
from storage.user_store import UserStore
@@ -94,9 +93,9 @@ async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse
user_team_info = await LiteLlmManager.get_user_team_info(
user_id, str(user.current_org_id)
)
# Update to use calculate_credits
spend = user_team_info.get('spend', 0)
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
user_team_info, user_id, str(user.current_org_id)
)
credits = max(max_budget - spend, 0)
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
@@ -148,7 +147,7 @@ async def create_customer_setup_session(
customer=customer_info['customer_id'],
mode='setup',
payment_method_types=['card'],
success_url=f'{base_url}?free_credits=success',
success_url=f'{base_url}?setup=success',
cancel_url=f'{base_url}',
)
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
@@ -250,15 +249,21 @@ async def success_callback(session_id: str, request: Request):
)
amount_subtotal = stripe_session.amount_subtotal or 0
add_credits = amount_subtotal / 100
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
max_budget, _ = LiteLlmManager.get_budget_from_team_info(
user_team_info, billing_session.user_id, str(user.current_org_id)
)
org = session.query(Org).filter(Org.id == user.current_org_id).first()
new_max_budget = max_budget + add_credits
await LiteLlmManager.update_team_and_users_budget(
str(user.current_org_id), new_max_budget
)
# Enable BYOR export for the org now that they've purchased credits
if org:
org.byor_export_enabled = True
# Store transaction status
billing_session.status = 'completed'
billing_session.price = add_credits

View File

@@ -0,0 +1,122 @@
"""
Pydantic models and custom exceptions for organization invitations.
"""
from pydantic import BaseModel, EmailStr
from storage.org_invitation import OrgInvitation
from storage.role_store import RoleStore
class InvitationError(Exception):
"""Base exception for invitation errors."""
pass
class InvitationAlreadyExistsError(InvitationError):
"""Raised when a pending invitation already exists for the email."""
def __init__(
self, message: str = 'A pending invitation already exists for this email'
):
super().__init__(message)
class UserAlreadyMemberError(InvitationError):
"""Raised when the user is already a member of the organization."""
def __init__(self, message: str = 'User is already a member of this organization'):
super().__init__(message)
class InvitationExpiredError(InvitationError):
"""Raised when the invitation has expired."""
def __init__(self, message: str = 'Invitation has expired'):
super().__init__(message)
class InvitationInvalidError(InvitationError):
"""Raised when the invitation is invalid or revoked."""
def __init__(self, message: str = 'Invitation is no longer valid'):
super().__init__(message)
class InsufficientPermissionError(InvitationError):
"""Raised when the user lacks permission to perform the action."""
def __init__(self, message: str = 'Insufficient permission'):
super().__init__(message)
class EmailMismatchError(InvitationError):
"""Raised when the accepting user's email doesn't match the invitation email."""
def __init__(self, message: str = 'Your email does not match the invitation'):
super().__init__(message)
class InvitationCreate(BaseModel):
"""Request model for creating invitation(s)."""
emails: list[EmailStr]
role: str = 'member' # Default to member role
class InvitationResponse(BaseModel):
"""Response model for invitation details."""
id: int
email: str
role: str
status: str
created_at: str
expires_at: str
inviter_email: str | None = None
@classmethod
def from_invitation(
cls,
invitation: OrgInvitation,
inviter_email: str | None = None,
) -> 'InvitationResponse':
"""Create an InvitationResponse from an OrgInvitation entity.
Args:
invitation: The invitation entity to convert
inviter_email: Optional email of the inviter
Returns:
InvitationResponse: The response model instance
"""
role_name = ''
if invitation.role:
role_name = invitation.role.name
elif invitation.role_id:
role = RoleStore.get_role_by_id(invitation.role_id)
role_name = role.name if role else ''
return cls(
id=invitation.id,
email=invitation.email,
role=role_name,
status=invitation.status,
created_at=invitation.created_at.isoformat(),
expires_at=invitation.expires_at.isoformat(),
inviter_email=inviter_email,
)
class InvitationFailure(BaseModel):
"""Response model for a failed invitation."""
email: str
error: str
class BatchInvitationResponse(BaseModel):
"""Response model for batch invitation creation."""
successful: list[InvitationResponse]
failed: list[InvitationFailure]

View File

@@ -0,0 +1,226 @@
"""API routes for organization invitations."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from server.routes.org_invitation_models import (
BatchInvitationResponse,
EmailMismatchError,
InsufficientPermissionError,
InvitationCreate,
InvitationExpiredError,
InvitationFailure,
InvitationInvalidError,
InvitationResponse,
UserAlreadyMemberError,
)
from server.services.org_invitation_service import OrgInvitationService
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.server.user_auth.user_auth import get_user_auth
# Router for invitation operations on an organization (requires org_id)
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
# Router for accepting invitations (no org_id required)
accept_router = APIRouter(prefix='/api/organizations/members/invite')
@invitation_router.post(
'/invite',
response_model=BatchInvitationResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_invitation(
org_id: UUID,
invitation_data: InvitationCreate,
request: Request,
user_id: str = Depends(get_user_id),
):
"""Create organization invitations for multiple email addresses.
Sends emails to invitees with secure links to join the organization.
Supports batch invitations - some may succeed while others fail.
Permission rules:
- Only owners and admins can create invitations
- Admins can only invite with 'member' or 'admin' role (not 'owner')
- Owners can invite with any role
Args:
org_id: Organization UUID
invitation_data: Invitation details (emails array, role)
request: FastAPI request
user_id: Authenticated user ID (from dependency)
Returns:
BatchInvitationResponse: Lists of successful and failed invitations
Raises:
HTTPException 400: Invalid role or organization not found
HTTPException 403: User lacks permission to invite
HTTPException 429: Rate limit exceeded
"""
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
await check_rate_limit_by_user_id(
request=request,
key_prefix='org_invitation_create',
user_id=user_id,
user_rate_limit_seconds=6,
)
try:
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=[str(email) for email in invitation_data.emails],
role_name=invitation_data.role,
inviter_id=UUID(user_id),
)
logger.info(
'Batch organization invitations created',
extra={
'org_id': str(org_id),
'total_emails': len(invitation_data.emails),
'successful': len(successful),
'failed': len(failed),
'inviter_id': user_id,
},
)
return BatchInvitationResponse(
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
failed=[
InvitationFailure(email=email, error=error) for email, error in failed
],
)
except InsufficientPermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.exception(
'Unexpected error creating batch invitations',
extra={'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@accept_router.get('/accept')
async def accept_invitation(
token: str,
request: Request,
):
"""Accept an organization invitation via token.
This endpoint is accessed via the link in the invitation email.
Flow:
1. If user is authenticated: Accept invitation directly and redirect to home
2. If user is not authenticated: Redirect to login page with invitation token
- Frontend stores token and includes it in OAuth state during login
- After authentication, keycloak_callback processes the invitation
Args:
token: The invitation token from the email link
request: FastAPI request
Returns:
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
or home page with error query params on failure
"""
base_url = str(request.base_url).rstrip('/')
# Try to get user_id from auth (may not be authenticated)
user_id = None
try:
user_auth = await get_user_auth(request)
if user_auth:
user_id = await user_auth.get_user_id()
except Exception:
pass
if not user_id:
# User not authenticated - redirect to login page with invitation token
# Frontend will store the token and include it in OAuth state during login
logger.info(
'Invitation accept: redirecting unauthenticated user to login',
extra={'token_prefix': token[:10] + '...'},
)
login_url = f'{base_url}/login?invitation_token={token}'
return RedirectResponse(login_url, status_code=302)
# User is authenticated - process the invitation directly
try:
await OrgInvitationService.accept_invitation(token, UUID(user_id))
logger.info(
'Invitation accepted successfully',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
},
)
# Redirect to home page on success
return RedirectResponse(f'{base_url}/', status_code=302)
except InvitationExpiredError:
logger.warning(
'Invitation accept failed: expired',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
except InvitationInvalidError as e:
logger.warning(
'Invitation accept failed: invalid',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
except UserAlreadyMemberError:
logger.info(
'Invitation accept: user already member',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
except EmailMismatchError as e:
logger.warning(
'Invitation accept failed: email mismatch',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
except Exception as e:
logger.exception(
'Unexpected error accepting invitation',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)

View File

@@ -1,7 +1,9 @@
from typing import Annotated
from pydantic import BaseModel, EmailStr, Field, StringConstraints
from pydantic import BaseModel, EmailStr, Field, SecretStr, StringConstraints
from storage.org import Org
from storage.org_member import OrgMember
from storage.role import Role
class OrgCreationError(Exception):
@@ -43,6 +45,16 @@ class OrgAuthorizationError(OrgDeletionError):
super().__init__(message)
class OrphanedUserError(OrgDeletionError):
"""Raised when deleting an org would leave users without any organization."""
def __init__(self, user_ids: list[str]):
self.user_ids = user_ids
super().__init__(
f'Cannot delete organization: {len(user_ids)} user(s) would have no remaining organization'
)
class OrgNotFoundError(Exception):
"""Raised when organization is not found or user doesn't have access."""
@@ -51,6 +63,61 @@ class OrgNotFoundError(Exception):
super().__init__(f'Organization with id "{org_id}" not found')
class OrgMemberNotFoundError(Exception):
"""Raised when a member is not found in an organization."""
def __init__(self, org_id: str, user_id: str):
self.org_id = org_id
self.user_id = user_id
super().__init__(f'Member "{user_id}" not found in organization "{org_id}"')
class RoleNotFoundError(Exception):
"""Raised when a role is not found."""
def __init__(self, role_id: int):
self.role_id = role_id
super().__init__(f'Role with id "{role_id}" not found')
class InvalidRoleError(Exception):
"""Raised when an invalid role name is specified."""
def __init__(self, role_name: str):
self.role_name = role_name
super().__init__(f'Invalid role: "{role_name}"')
class InsufficientPermissionError(Exception):
"""Raised when user lacks permission to perform an operation."""
def __init__(self, message: str = 'Insufficient permission'):
super().__init__(message)
class CannotModifySelfError(Exception):
"""Raised when user attempts to modify their own membership."""
def __init__(self, action: str = 'modify'):
self.action = action
super().__init__(f'Cannot {action} your own membership')
class LastOwnerError(Exception):
"""Raised when attempting to remove or demote the last owner."""
def __init__(self, action: str = 'remove'):
self.action = action
super().__init__(f'Cannot {action} the last owner of an organization')
class MemberUpdateError(Exception):
"""Raised when member update operation fails."""
def __init__(self, message: str = 'Failed to update member'):
super().__init__(message)
class OrgCreate(BaseModel):
"""Request model for creating a new organization."""
@@ -91,14 +158,18 @@ class OrgResponse(BaseModel):
enable_solvability_analysis: bool | None = None
v1_enabled: bool | None = None
credits: float | None = None
is_personal: bool = False
@classmethod
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
def from_org(
cls, org: Org, credits: float | None = None, user_id: str | None = None
) -> 'OrgResponse':
"""Create an OrgResponse from an Org entity.
Args:
org: The organization entity to convert
credits: Optional credits value (defaults to None)
user_id: Optional user ID to determine if org is personal (defaults to None)
Returns:
OrgResponse: The response model instance
@@ -134,6 +205,7 @@ class OrgResponse(BaseModel):
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
credits=credits,
is_personal=str(org.id) == user_id if user_id else False,
)
@@ -142,12 +214,17 @@ class OrgPage(BaseModel):
items: list[OrgResponse]
next_page_id: str | None = None
current_org_id: str | None = None
class OrgUpdate(BaseModel):
"""Request model for updating an organization."""
# Basic organization information (any authenticated user can update)
name: Annotated[
str | None,
StringConstraints(strip_whitespace=True, min_length=1, max_length=255),
] = None
contact_name: str | None = None
contact_email: EmailStr | None = None
conversation_expiration: int | None = None
@@ -173,3 +250,79 @@ class OrgUpdate(BaseModel):
confirmation_mode: bool | None = None
enable_default_condenser: bool | None = None
condenser_max_size: int | None = Field(default=None, ge=20)
class OrgMemberResponse(BaseModel):
"""Response model for a single organization member."""
user_id: str
email: str | None
role_id: int
role: str
role_rank: int
status: str | None
class OrgMemberPage(BaseModel):
"""Paginated response for organization members."""
items: list[OrgMemberResponse]
next_page_id: str | None = None
class OrgMemberUpdate(BaseModel):
"""Request model for updating an organization member."""
role: str | None = None # Role name: 'owner', 'admin', or 'member'
class MeResponse(BaseModel):
"""Response model for the current user's membership in an organization."""
org_id: str
user_id: str
email: str
role: str
llm_api_key: str
max_iterations: int | None = None
llm_model: str | None = None
llm_api_key_for_byor: str | None = None
llm_base_url: str | None = None
status: str | None = None
@staticmethod
def _mask_key(secret: SecretStr | None) -> str:
"""Mask an API key, showing only last 4 characters."""
if secret is None:
return ''
raw = secret.get_secret_value()
if not raw:
return ''
if len(raw) <= 4:
return '****'
return '****' + raw[-4:]
@classmethod
def from_org_member(cls, member: OrgMember, role: Role, email: str) -> 'MeResponse':
"""Create a MeResponse from an OrgMember, Role, and user email.
Args:
member: The OrgMember entity
role: The Role entity (provides role name)
email: The user's email address
Returns:
MeResponse with masked API keys
"""
return cls(
org_id=str(member.org_id),
user_id=str(member.user_id),
email=email,
role=role.name,
llm_api_key=cls._mask_key(member.llm_api_key),
max_iterations=member.max_iterations,
llm_model=member.llm_model,
llm_api_key_for_byor=cls._mask_key(member.llm_api_key_for_byor) or None,
llm_base_url=member.llm_base_url,
status=member.status,
)

View File

@@ -2,19 +2,37 @@ from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from server.auth.authorization import (
Permission,
require_permission,
)
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
CannotModifySelfError,
InsufficientPermissionError,
InvalidRoleError,
LastOwnerError,
LiteLLMIntegrationError,
MemberUpdateError,
MeResponse,
OrgAuthorizationError,
OrgCreate,
OrgDatabaseError,
OrgMemberNotFoundError,
OrgMemberPage,
OrgMemberResponse,
OrgMemberUpdate,
OrgNameExistsError,
OrgNotFoundError,
OrgPage,
OrgResponse,
OrgUpdate,
OrphanedUserError,
RoleNotFoundError,
)
from server.services.org_member_service import OrgMemberService
from storage.org_service import OrgService
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
@@ -61,6 +79,12 @@ async def list_user_orgs(
)
try:
# Fetch user to get current_org_id
user = await UserStore.get_user_by_id_async(user_id)
current_org_id = (
str(user.current_org_id) if user and user.current_org_id else None
)
# Fetch organizations from service layer
orgs, next_page_id = OrgService.get_user_orgs_paginated(
user_id=user_id,
@@ -69,7 +93,9 @@ async def list_user_orgs(
)
# Convert Org entities to OrgResponse objects
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
org_responses = [
OrgResponse.from_org(org, credits=None, user_id=user_id) for org in orgs
]
logger.info(
'Successfully retrieved organizations',
@@ -80,7 +106,11 @@ async def list_user_orgs(
},
)
return OrgPage(items=org_responses, next_page_id=next_page_id)
return OrgPage(
items=org_responses,
next_page_id=next_page_id,
current_org_id=current_org_id,
)
except Exception as e:
logger.exception(
@@ -136,7 +166,7 @@ async def create_org(
# Retrieve credits from LiteLLM
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits)
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
except OrgNameExistsError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -174,23 +204,26 @@ async def create_org(
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
async def get_org(
org_id: UUID,
user_id: str = Depends(get_user_id),
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
) -> OrgResponse:
"""Get organization details by ID.
This endpoint allows authenticated users who are members of an organization
to retrieve its details. Only members of the organization can access this endpoint.
This endpoint retrieves details for a specific organization. Access requires
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
(member, admin, and owner roles).
Args:
org_id: Organization ID (UUID)
user_id: Authenticated user ID (injected by dependency)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
OrgResponse: The organization details
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
HTTPException: 404 if organization not found
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 404 if organization not found or user is not a member
HTTPException: 500 if retrieval fails
"""
logger.info(
@@ -211,7 +244,7 @@ async def get_org(
# Retrieve credits from LiteLLM
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits)
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
except OrgNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -228,26 +261,86 @@ async def get_org(
)
@org_router.get('/{org_id}/me', response_model=MeResponse)
async def get_me(
org_id: UUID,
user_id: str = Depends(get_user_id),
) -> MeResponse:
"""Get the current user's membership record for an organization.
Returns the authenticated user's role, status, email, and LLM override
fields (with masked API keys) within the specified organization.
Args:
org_id: Organization ID (UUID)
user_id: Authenticated user ID (injected by dependency)
Returns:
MeResponse: The user's membership data
Raises:
HTTPException: 404 if user is not a member or org doesn't exist
HTTPException: 500 if retrieval fails
"""
logger.info(
'Retrieving current member details',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
try:
user_uuid = UUID(user_id)
return OrgMemberService.get_me(org_id, user_uuid)
except OrgMemberNotFoundError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Organization with id "{org_id}" not found',
)
except RoleNotFoundError as e:
logger.exception(
'Role not found for org member',
extra={
'user_id': user_id,
'org_id': str(org_id),
'role_id': e.role_id,
},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
except Exception as e:
logger.exception(
'Unexpected error retrieving member details',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
async def delete_org(
org_id: UUID,
user_id: str = Depends(get_admin_user_id),
user_id: str = Depends(require_permission(Permission.DELETE_ORGANIZATION)),
) -> dict:
"""Delete an organization.
This endpoint allows authenticated organization owners to delete their organization.
All associated data including organization members, conversations, billing data,
and external LiteLLM team resources will be permanently removed.
This endpoint permanently deletes an organization and all associated data including
organization members, conversations, billing data, and external LiteLLM team resources.
Access requires the DELETE_ORGANIZATION permission, which is granted only to owners.
Args:
org_id: Organization ID to delete
user_id: Authenticated user ID (injected by dependency)
org_id: Organization ID to delete (UUID)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
dict: Confirmation message with deleted organization details
Raises:
HTTPException: 403 if user is not the organization owner
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks DELETE_ORGANIZATION permission
HTTPException: 404 if organization not found
HTTPException: 500 if deletion fails
"""
@@ -303,6 +396,19 @@ async def delete_org(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrphanedUserError as e:
logger.warning(
'Cannot delete organization: users would be orphaned',
extra={
'user_id': user_id,
'org_id': str(org_id),
'orphaned_users': e.user_ids,
},
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database error during organization deletion',
@@ -327,25 +433,26 @@ async def delete_org(
async def update_org(
org_id: UUID,
update_data: OrgUpdate,
user_id: str = Depends(get_user_id),
user_id: str = Depends(require_permission(Permission.EDIT_ORG_SETTINGS)),
) -> OrgResponse:
"""Update an existing organization.
This endpoint allows authenticated users to update organization settings.
LLM-related settings require admin or owner role in the organization.
This endpoint updates organization settings. Access requires the EDIT_ORG_SETTINGS
permission, which is granted to admin and owner roles.
Args:
org_id: Organization ID to update (UUID validated by FastAPI)
org_id: Organization ID to update (UUID)
update_data: Organization update data
user_id: Authenticated user ID (injected by dependency)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
OrgResponse: The updated organization details
Raises:
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
HTTPException: 403 if user lacks permission for LLM settings
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks EDIT_ORG_SETTINGS permission
HTTPException: 404 if organization not found
HTTPException: 409 if organization name already exists
HTTPException: 422 if validation errors occur (handled by FastAPI)
HTTPException: 500 if update fails
"""
@@ -368,7 +475,7 @@ async def update_org(
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
credits = await OrgService.get_org_credits(user_id, updated_org.id)
return OrgResponse.from_org(updated_org, credits=credits)
return OrgResponse.from_org(updated_org, credits=credits, user_id=user_id)
except ValueError as e:
# Organization not found
@@ -376,6 +483,11 @@ async def update_org(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except OrgNameExistsError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e),
)
except PermissionError as e:
# User lacks permission for LLM settings
raise HTTPException(
@@ -400,3 +512,314 @@ async def update_org(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.get('/{org_id}/members')
async def get_org_members(
org_id: UUID,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(
title='The max number of results in the page',
gt=0,
lte=100,
),
] = 100,
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
) -> OrgMemberPage:
"""Get all members of an organization with cursor-based pagination.
This endpoint retrieves a paginated list of organization members. Access requires
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
(member, admin, and owner roles).
Args:
org_id: Organization ID (UUID)
page_id: Optional page ID (offset) for pagination
limit: Maximum number of members to return (1-100, default 100)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
OrgMemberPage: Paginated list of organization members
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
HTTPException: 400 if org_id or page_id format is invalid
HTTPException: 500 if retrieval fails
"""
try:
success, error_code, data = await OrgMemberService.get_org_members(
org_id=org_id,
current_user_id=UUID(user_id),
page_id=page_id,
limit=limit,
)
if not success:
error_map = {
'not_a_member': (
status.HTTP_403_FORBIDDEN,
'You are not a member of this organization',
),
'invalid_page_id': (
status.HTTP_400_BAD_REQUEST,
'Invalid page_id format',
),
}
status_code, detail = error_map.get(
error_code, (status.HTTP_500_INTERNAL_SERVER_ERROR, 'An error occurred')
)
raise HTTPException(status_code=status_code, detail=detail)
if data is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve members',
)
return data
except HTTPException:
raise
except ValueError:
logger.exception('Invalid UUID format')
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Invalid organization ID format',
)
except Exception:
logger.exception('Error retrieving organization members')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve members',
)
@org_router.delete('/{org_id}/members/{user_id}')
async def remove_org_member(
org_id: UUID,
user_id: str,
current_user_id: str = Depends(get_user_id),
):
"""Remove a member from an organization.
Only owners and admins can remove members:
- Owners can remove admins and regular users
- Admins can only remove regular users
Users cannot remove themselves. The last owner cannot be removed.
"""
try:
success, error = await OrgMemberService.remove_org_member(
org_id=org_id,
target_user_id=UUID(user_id),
current_user_id=UUID(current_user_id),
)
if not success:
error_map = {
'not_a_member': (
status.HTTP_403_FORBIDDEN,
'You are not a member of this organization',
),
'cannot_remove_self': (
status.HTTP_403_FORBIDDEN,
'Cannot remove yourself from an organization',
),
'member_not_found': (
status.HTTP_404_NOT_FOUND,
'Member not found in this organization',
),
'insufficient_permission': (
status.HTTP_403_FORBIDDEN,
'You do not have permission to remove this member',
),
'cannot_remove_last_owner': (
status.HTTP_400_BAD_REQUEST,
'Cannot remove the last owner of an organization',
),
'removal_failed': (
status.HTTP_500_INTERNAL_SERVER_ERROR,
'Failed to remove member',
),
}
status_code, detail = error_map.get(
error, (status.HTTP_500_INTERNAL_SERVER_ERROR, 'An error occurred')
)
raise HTTPException(status_code=status_code, detail=detail)
return {'message': 'Member removed successfully'}
except HTTPException:
raise
except ValueError:
logger.exception('Invalid UUID format')
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Invalid organization or user ID format',
)
except Exception:
logger.exception('Error removing organization member')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to remove member',
)
@org_router.post(
'/{org_id}/switch', response_model=OrgResponse, status_code=status.HTTP_200_OK
)
async def switch_org(
org_id: UUID,
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Switch to a different organization.
This endpoint allows authenticated users to switch their current active
organization. The user must be a member of the target organization.
Args:
org_id: Organization ID to switch to (UUID)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The organization details that was switched to
Raises:
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 403 if user is not a member of the organization
HTTPException: 404 if organization not found
HTTPException: 500 if switch fails
"""
logger.info(
'Switching organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to switch organization with membership validation
org = await OrgService.switch_org(
user_id=user_id,
org_id=org_id,
)
# Retrieve credits from LiteLLM for the new current org
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
except OrgNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except OrgAuthorizationError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database operation failed during organization switch',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to switch organization',
)
except Exception as e:
logger.exception(
'Unexpected error switching organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.patch('/{org_id}/members/{user_id}', response_model=OrgMemberResponse)
async def update_org_member(
org_id: UUID,
user_id: str,
update_data: OrgMemberUpdate,
current_user_id: str = Depends(get_user_id),
) -> OrgMemberResponse:
"""Update a member's role in an organization.
Permission rules:
- Admins can change roles of regular members to Admin or Member
- Admins cannot modify other Admins or Owners
- Owners can change roles of Admins and Members to any role (Owner, Admin, Member)
- Owners cannot modify other Owners
Members cannot modify their own role. The last owner cannot be demoted.
"""
try:
return await OrgMemberService.update_org_member(
org_id=org_id,
target_user_id=UUID(user_id),
current_user_id=UUID(current_user_id),
update_data=update_data,
)
except OrgMemberNotFoundError as e:
# Distinguish between requester not being a member vs target not found
if str(current_user_id) in str(e):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='You are not a member of this organization',
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Member not found in this organization',
)
except CannotModifySelfError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='Cannot modify your own role',
)
except RoleNotFoundError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Role configuration error',
)
except InvalidRoleError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Invalid role specified',
)
except InsufficientPermissionError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='You do not have permission to modify this member',
)
except LastOwnerError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Cannot demote the last owner of an organization',
)
except MemberUpdateError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to update member',
)
except ValueError:
logger.exception('Invalid UUID format')
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Invalid organization or user ID format',
)
except Exception:
logger.exception('Error updating organization member')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to update member',
)

View File

@@ -1139,6 +1139,71 @@ class SaasNestedConversationManager(ConversationManager):
}
update_conversation_metadata(conversation_id, metadata_content)
async def list_files(self, sid: str, path: str | None = None) -> list[str]:
"""List files in the workspace for a conversation.
Delegates to the nested container's list-files endpoint.
Args:
sid: The session/conversation ID.
path: Optional path to list files from. If None, lists from workspace root.
Returns:
A list of file paths.
Raises:
ValueError: If the conversation is not running.
httpx.HTTPError: If there's an error communicating with the nested runtime.
"""
runtime = await self._get_runtime(sid)
if runtime is None or runtime.get('status') != 'running':
raise ValueError(f'Conversation {sid} is not running')
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
session_api_key = runtime.get('session_api_key')
return await self._fetch_list_files_from_nested(
sid, nested_url, session_api_key, path
)
async def select_file(self, sid: str, file: str) -> tuple[str | None, str | None]:
"""Read a file from the workspace via nested container.
Raises:
ValueError: If the conversation is not running.
httpx.HTTPError: If there's an error communicating with the nested runtime.
"""
runtime = await self._get_runtime(sid)
if runtime is None or runtime.get('status') != 'running':
raise ValueError(f'Conversation {sid} is not running')
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
session_api_key = runtime.get('session_api_key')
return await self._fetch_select_file_from_nested(
sid, nested_url, session_api_key, file
)
async def upload_files(
self, sid: str, files: list[tuple[str, bytes]]
) -> tuple[list[str], list[dict[str, str]]]:
"""Upload files to the workspace via nested container.
Raises:
ValueError: If the conversation is not running.
httpx.HTTPError: If there's an error communicating with the nested runtime.
"""
runtime = await self._get_runtime(sid)
if runtime is None or runtime.get('status') != 'running':
raise ValueError(f'Conversation {sid} is not running')
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
session_api_key = runtime.get('session_api_key')
return await self._fetch_upload_files_to_nested(
sid, nested_url, session_api_key, files
)
def _last_updated_at_key(conversation: ConversationMetadata) -> float:
last_updated_at = conversation.last_updated_at

View File

@@ -0,0 +1,131 @@
"""Email service for sending transactional emails via Resend."""
import os
try:
import resend
RESEND_AVAILABLE = True
except ImportError:
RESEND_AVAILABLE = False
from openhands.core.logger import openhands_logger as logger
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
class EmailService:
"""Service for sending transactional emails."""
@staticmethod
def _get_resend_client() -> bool:
"""Initialize and return the Resend client.
Returns:
bool: True if client is ready, False otherwise
"""
if not RESEND_AVAILABLE:
logger.warning('Resend library not installed, skipping email')
return False
resend_api_key = os.environ.get('RESEND_API_KEY')
if not resend_api_key:
logger.warning('RESEND_API_KEY not configured, skipping email')
return False
resend.api_key = resend_api_key
return True
@staticmethod
def send_invitation_email(
to_email: str,
org_name: str,
inviter_name: str,
role_name: str,
invitation_token: str,
invitation_id: int,
) -> None:
"""Send an organization invitation email.
Args:
to_email: Recipient's email address
org_name: Name of the organization
inviter_name: Display name of the person who sent the invite
role_name: Role being offered (e.g., 'member', 'admin')
invitation_token: The secure invitation token
invitation_id: The invitation ID for logging
"""
if not EmailService._get_resend_client():
return
# Build invitation URL
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
params = {
'from': from_email,
'to': [to_email],
'subject': f"You're invited to join {org_name} on OpenHands",
'html': f"""
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
<p>Hi,</p>
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
<p>Click the button below to accept the invitation:</p>
<p style="margin: 30px 0;">
<a href="{invitation_url}"
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
text-decoration: none; border-radius: 8px; display: inline-block;
font-size: 14px; font-weight: 600;">
Accept Invitation
</a>
</p>
<p style="color: #666; font-size: 14px;">
Or copy and paste this link into your browser:<br>
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
</p>
<p style="color: #666; font-size: 14px;">
This invitation will expire in 7 days.
</p>
<p style="color: #666; font-size: 14px;">
If you weren't expecting this invitation, you can safely ignore this email.
</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
<p style="color: #999; font-size: 12px;">
Best,<br>
The OpenHands Team
</p>
</div>
""",
}
try:
response = resend.Emails.send(params)
logger.info(
'Invitation email sent',
extra={
'invitation_id': invitation_id,
'email': to_email,
'response_id': response.get('id') if response else None,
},
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation_id,
'email': to_email,
'error': str(e),
},
)
raise

View File

@@ -0,0 +1,397 @@
"""Service for managing organization invitations."""
import asyncio
from uuid import UUID
from server.auth.token_manager import TokenManager
from server.constants import ROLE_ADMIN, ROLE_OWNER
from server.routes.org_invitation_models import (
EmailMismatchError,
InsufficientPermissionError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.services.email_service import EmailService
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import OrgInvitationStore
from storage.org_member_store import OrgMemberStore
from storage.org_service import OrgService
from storage.org_store import OrgStore
from storage.role_store import RoleStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
class OrgInvitationService:
"""Service for organization invitation operations."""
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_name: str,
inviter_id: UUID,
) -> OrgInvitation:
"""Create a new organization invitation.
This method:
1. Validates the organization exists
2. Validates this is not a personal workspace
3. Checks inviter has owner/admin role
4. Validates role assignment permissions
5. Checks if user is already a member
6. Creates the invitation
7. Sends the invitation email
Args:
org_id: Organization UUID
email: Invitee's email address
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitation
Returns:
OrgInvitation: The created invitation
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
UserAlreadyMemberError: If email is already a member
InvitationAlreadyExistsError: If pending invitation exists
"""
email = email.lower().strip()
logger.info(
'Creating organization invitation',
extra={
'org_id': str(org_id),
'email': email,
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
# Step 2: Check this is not a personal workspace
# A personal workspace has org_id matching the user's id
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
# Step 3: Check inviter is a member and has permission
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
# Step 4: Validate role assignment permissions
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
# Get the target role
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 5: Check if user is already a member (by email)
existing_user = await UserStore.get_user_by_email_async(email)
if existing_user:
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
if existing_member:
raise UserAlreadyMemberError(
'User is already a member of this organization'
)
# Step 6: Create the invitation
invitation = await OrgInvitationStore.create_invitation(
org_id=org_id,
email=email,
role_id=target_role.id,
inviter_id=inviter_id,
)
# Step 7: Send invitation email
try:
# Get inviter info for the email
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
inviter_name = 'A team member'
if inviter_user and inviter_user.email:
inviter_name = inviter_user.email.split('@')[0]
EmailService.send_invitation_email(
to_email=email,
org_name=org.name,
inviter_name=inviter_name,
role_name=target_role.name,
invitation_token=invitation.token,
invitation_id=invitation.id,
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation.id,
'email': email,
'error': str(e),
},
)
# Don't fail the invitation creation if email fails
# The user can still access via direct link
return invitation
@staticmethod
async def create_invitations_batch(
org_id: UUID,
emails: list[str],
role_name: str,
inviter_id: UUID,
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
"""Create multiple organization invitations concurrently.
Validates permissions once upfront, then creates invitations in parallel.
Args:
org_id: Organization UUID
emails: List of invitee email addresses
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitations
Returns:
Tuple of (successful_invitations, failed_emails_with_errors)
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
"""
logger.info(
'Creating batch organization invitations',
extra={
'org_id': str(org_id),
'email_count': len(emails),
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate permissions upfront (shared for all emails)
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 2: Create invitations concurrently
async def create_single(
email: str,
) -> tuple[str, OrgInvitation | None, str | None]:
"""Create single invitation, return (email, invitation, error)."""
try:
invitation = await OrgInvitationService.create_invitation(
org_id=org_id,
email=email,
role_name=role_name,
inviter_id=inviter_id,
)
return (email, invitation, None)
except (UserAlreadyMemberError, ValueError) as e:
return (email, None, str(e))
results = await asyncio.gather(*[create_single(email) for email in emails])
# Step 3: Separate successes and failures
successful: list[OrgInvitation] = []
failed: list[tuple[str, str]] = []
for email, invitation, error in results:
if invitation:
successful.append(invitation)
elif error:
failed.append((email, error))
logger.info(
'Batch invitation creation completed',
extra={
'org_id': str(org_id),
'successful': len(successful),
'failed': len(failed),
},
)
return successful, failed
@staticmethod
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
"""Accept an organization invitation.
This method:
1. Validates the token and invitation status
2. Checks expiration
3. Verifies user is not already a member
4. Creates LiteLLM integration
5. Adds user to the organization
6. Marks invitation as accepted
Args:
token: The invitation token
user_id: The user accepting the invitation
Returns:
OrgInvitation: The accepted invitation
Raises:
InvitationInvalidError: If token is invalid or invitation not pending
InvitationExpiredError: If invitation has expired
UserAlreadyMemberError: If user is already a member
"""
logger.info(
'Accepting organization invitation',
extra={
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
'user_id': str(user_id),
},
)
# Step 1: Get and validate invitation
invitation = await OrgInvitationStore.get_invitation_by_token(token)
if not invitation:
raise InvitationInvalidError('Invalid invitation token')
if invitation.status != OrgInvitation.STATUS_PENDING:
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
raise InvitationInvalidError('Invitation has already been accepted')
elif invitation.status == OrgInvitation.STATUS_REVOKED:
raise InvitationInvalidError('Invitation has been revoked')
else:
raise InvitationInvalidError('Invitation is no longer valid')
# Step 2: Check expiration
if OrgInvitationStore.is_token_expired(invitation):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
raise InvitationExpiredError('Invitation has expired')
# Step 2.5: Verify user email matches invitation email
user = await UserStore.get_user_by_id_async(str(user_id))
if not user:
raise InvitationInvalidError('User not found')
user_email = user.email
# Fallback: fetch email from Keycloak if not in database (for existing users)
if not user_email:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
user_email = user_info.get('email') if user_info else None
if not user_email:
raise EmailMismatchError('Your account does not have an email address')
user_email = user_email.lower().strip()
invitation_email = invitation.email.lower().strip()
if user_email != invitation_email:
logger.warning(
'Email mismatch during invitation acceptance',
extra={
'user_id': str(user_id),
'user_email': user_email,
'invitation_email': invitation_email,
'invitation_id': invitation.id,
},
)
raise EmailMismatchError()
# Step 3: Check if user is already a member
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
if existing_member:
raise UserAlreadyMemberError(
'You are already a member of this organization'
)
# Step 4: Create LiteLLM integration for the user in the new org
try:
settings = await OrgService.create_litellm_integration(
invitation.org_id, str(user_id)
)
except Exception as e:
logger.error(
'Failed to create LiteLLM integration for invitation acceptance',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'error': str(e),
},
)
raise InvitationInvalidError(
'Failed to set up organization access. Please try again.'
)
# Step 5: Add user to organization
from storage.org_member_store import OrgMemberStore as OMS
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
# Don't override with org defaults - use invitation-specified role
org_member_kwargs.pop('llm_model', None)
org_member_kwargs.pop('llm_base_url', None)
OrgMemberStore.add_user_to_org(
org_id=invitation.org_id,
user_id=user_id,
role_id=invitation.role_id,
llm_api_key=settings.llm_api_key,
status='active',
)
# Step 6: Mark invitation as accepted
updated_invitation = await OrgInvitationStore.update_invitation_status(
invitation.id,
OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
logger.info(
'Organization invitation accepted',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'role_id': invitation.role_id,
},
)
return updated_invitation

View File

@@ -0,0 +1,342 @@
"""Service for managing organization members."""
from uuid import UUID
from server.constants import ROLE_ADMIN, ROLE_MEMBER, ROLE_OWNER
from server.routes.org_models import (
CannotModifySelfError,
InsufficientPermissionError,
InvalidRoleError,
LastOwnerError,
MemberUpdateError,
MeResponse,
OrgMemberNotFoundError,
OrgMemberPage,
OrgMemberResponse,
OrgMemberUpdate,
RoleNotFoundError,
)
from storage.org_member_store import OrgMemberStore
from storage.role_store import RoleStore
from storage.user_store import UserStore
from openhands.utils.async_utils import call_sync_from_async
class OrgMemberService:
"""Service for organization member operations."""
@staticmethod
def get_me(org_id: UUID, user_id: UUID) -> MeResponse:
"""Get the current user's membership record for an organization.
Retrieves the authenticated user's role, status, email, and LLM override
fields (with masked API keys) within the specified organization.
Args:
org_id: Organization ID (UUID)
user_id: User ID (UUID)
Returns:
MeResponse: The user's membership data with masked API keys
Raises:
OrgMemberNotFoundError: If user is not a member of the organization
RoleNotFoundError: If the role associated with the member is not found
"""
# Look up the user's membership in this org
org_member = OrgMemberStore.get_org_member(org_id, user_id)
if org_member is None:
raise OrgMemberNotFoundError(str(org_id), str(user_id))
# Resolve role name from role_id
role = RoleStore.get_role_by_id(org_member.role_id)
if role is None:
raise RoleNotFoundError(org_member.role_id)
# Get user email
user = UserStore.get_user_by_id(str(user_id))
email = user.email if user and user.email else ''
return MeResponse.from_org_member(org_member, role, email)
@staticmethod
async def get_org_members(
org_id: UUID,
current_user_id: UUID,
page_id: str | None = None,
limit: int = 100,
) -> tuple[bool, str | None, OrgMemberPage | None]:
"""Get organization members with authorization check.
Returns:
Tuple of (success, error_code, data). If success is True, error_code is None.
"""
# Verify current user is a member of the organization
requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id)
if not requester_membership:
return False, 'not_a_member', None
# Parse page_id to get offset (page_id is offset encoded as string)
offset = 0
if page_id is not None:
try:
offset = int(page_id)
if offset < 0:
return False, 'invalid_page_id', None
except ValueError:
return False, 'invalid_page_id', None
# Call store to get paginated members
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=offset, limit=limit
)
# Transform data to response format
items = []
for member in members:
# Access user and role relationships (eagerly loaded)
user = member.user
role = member.role
items.append(
OrgMemberResponse(
user_id=str(member.user_id),
email=user.email if user else None,
role_id=member.role_id,
role=role.name if role else '',
role_rank=role.rank if role else 0,
status=member.status,
)
)
# Calculate next_page_id
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
return True, None, OrgMemberPage(items=items, next_page_id=next_page_id)
@staticmethod
async def remove_org_member(
org_id: UUID,
target_user_id: UUID,
current_user_id: UUID,
) -> tuple[bool, str | None]:
"""Remove a member from an organization.
Returns:
Tuple of (success, error_message). If success is True, error_message is None.
"""
def _remove_member():
# Get current user's membership in the org
requester_membership = OrgMemberStore.get_org_member(
org_id, current_user_id
)
if not requester_membership:
return False, 'not_a_member'
# Check if trying to remove self
if str(current_user_id) == str(target_user_id):
return False, 'cannot_remove_self'
# Get target user's membership
target_membership = OrgMemberStore.get_org_member(org_id, target_user_id)
if not target_membership:
return False, 'member_not_found'
requester_role = RoleStore.get_role_by_id(requester_membership.role_id)
target_role = RoleStore.get_role_by_id(target_membership.role_id)
if not requester_role or not target_role:
return False, 'role_not_found'
# Check permission based on roles
if not OrgMemberService._can_remove_member(
requester_role.name, target_role.name
):
return False, 'insufficient_permission'
# Check if removing the last owner
if target_role.name == ROLE_OWNER:
if OrgMemberService._is_last_owner(org_id, target_user_id):
return False, 'cannot_remove_last_owner'
# Perform the removal
success = OrgMemberStore.remove_user_from_org(org_id, target_user_id)
if not success:
return False, 'removal_failed'
return True, None
return await call_sync_from_async(_remove_member)
@staticmethod
async def update_org_member(
org_id: UUID,
target_user_id: UUID,
current_user_id: UUID,
update_data: OrgMemberUpdate,
) -> OrgMemberResponse:
"""Update a member's role in an organization.
Permission rules:
- Admins can change roles of users (rank > ADMIN_RANK) to Admin or User
- Admins cannot modify other Admins or Owners
- Owners can change roles of non-owners (rank > OWNER_RANK) to any role
- Owners cannot modify other Owners
Args:
org_id: Organization ID
target_user_id: User ID of the member to update
current_user_id: User ID of the requester
update_data: Update data containing fields to modify
Returns:
OrgMemberResponse: The updated member data
Raises:
OrgMemberNotFoundError: If requester or target is not a member
CannotModifySelfError: If trying to modify self
RoleNotFoundError: If role configuration is invalid
InvalidRoleError: If new_role_name is not a valid role
InsufficientPermissionError: If requester lacks permission
LastOwnerError: If trying to demote the last owner
MemberUpdateError: If update operation fails
"""
new_role_name = update_data.role
def _update_member():
# Get current user's membership in the org
requester_membership = OrgMemberStore.get_org_member(
org_id, current_user_id
)
if not requester_membership:
raise OrgMemberNotFoundError(str(org_id), str(current_user_id))
# Check if trying to modify self
if str(current_user_id) == str(target_user_id):
raise CannotModifySelfError('modify')
# Get target user's membership
target_membership = OrgMemberStore.get_org_member(org_id, target_user_id)
if not target_membership:
raise OrgMemberNotFoundError(str(org_id), str(target_user_id))
# Get roles
requester_role = RoleStore.get_role_by_id(requester_membership.role_id)
target_role = RoleStore.get_role_by_id(target_membership.role_id)
if not requester_role:
raise RoleNotFoundError(requester_membership.role_id)
if not target_role:
raise RoleNotFoundError(target_membership.role_id)
# If no role change requested, return current state
if new_role_name is None:
user = UserStore.get_user_by_id(str(target_user_id))
return OrgMemberResponse(
user_id=str(target_membership.user_id),
email=user.email if user else None,
role_id=target_membership.role_id,
role=target_role.name,
role_rank=target_role.rank,
status=target_membership.status,
)
# Validate new role exists
new_role = RoleStore.get_role_by_name(new_role_name.lower())
if not new_role:
raise InvalidRoleError(new_role_name)
# Check permission to modify target
if not OrgMemberService._can_update_member_role(
requester_role.name, target_role.name, new_role.name
):
raise InsufficientPermissionError(
'You do not have permission to modify this member'
)
# Check if demoting the last owner
if (
target_role.name == ROLE_OWNER
and new_role.name != ROLE_OWNER
and OrgMemberService._is_last_owner(org_id, target_user_id)
):
raise LastOwnerError('demote')
# Perform the update
updated_member = OrgMemberStore.update_user_role_in_org(
org_id, target_user_id, new_role.id
)
if not updated_member:
raise MemberUpdateError('Failed to update member')
# Get user email for response
user = UserStore.get_user_by_id(str(target_user_id))
return OrgMemberResponse(
user_id=str(updated_member.user_id),
email=user.email if user else None,
role_id=updated_member.role_id,
role=new_role.name,
role_rank=new_role.rank,
status=updated_member.status,
)
return await call_sync_from_async(_update_member)
@staticmethod
def _can_update_member_role(
requester_role_name: str, target_role_name: str, new_role_name: str
) -> bool:
"""Check if requester can change target's role to new_role.
Permission rules:
- Owners can modify admins and users, can set any role
- Owners cannot modify other owners
- Admins can only modify users
- Admins can only set admin or user roles (not owner)
"""
is_requester_owner = requester_role_name == ROLE_OWNER
is_requester_admin = requester_role_name == ROLE_ADMIN
is_target_owner = target_role_name == ROLE_OWNER
is_target_admin = target_role_name == ROLE_ADMIN
is_new_role_owner = new_role_name == ROLE_OWNER
if is_requester_owner:
# Owners cannot modify other owners
if is_target_owner:
return False
# Owners can set any role (owner, admin, user)
return True
elif is_requester_admin:
# Admins cannot modify owners or other admins
if is_target_owner or is_target_admin:
return False
# Admins can only set admin or user roles (not owner)
return not is_new_role_owner
return False
@staticmethod
def _can_remove_member(requester_role_name: str, target_role_name: str) -> bool:
"""Check if requester can remove target based on roles."""
if requester_role_name == ROLE_OWNER:
return True
elif requester_role_name == ROLE_ADMIN:
# Admins can only remove members (not owners or other admins)
return target_role_name == ROLE_MEMBER
return False
@staticmethod
def _is_last_owner(org_id: UUID, user_id: UUID) -> bool:
"""Check if user is the last owner of the organization."""
members = OrgMemberStore.get_org_members(org_id)
owners = []
for m in members:
# Use role_id (column) instead of role (relationship) to avoid DetachedInstanceError
role = RoleStore.get_role_by_id(m.role_id)
if role and role.name == ROLE_OWNER:
owners.append(m)
return len(owners) == 1 and str(owners[0].user_id) == str(user_id)

View File

@@ -22,11 +22,63 @@ from openhands.app_server.app_conversation.app_conversation_models import (
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.errors import AuthError
from openhands.app_server.services.injector import InjectorState
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
"""Extended SQLAppConversationInfoService with user and organization-based filtering and SAAS metadata handling."""
async def _get_current_user(self) -> User | None:
"""Get the current user using the existing db_session.
Uses self.db_session to avoid opening a separate database session.
Returns:
User object or None if no user_id is available
"""
user_id_str = await self.user_context.get_user_id()
if not user_id_str:
return None
user_id_uuid = UUID(user_id_str)
result = await self.db_session.execute(
select(User).where(User.id == user_id_uuid)
)
return result.scalars().first()
async def _apply_user_and_org_filter(self, query):
"""Apply user_id and org_id filters to ensure conversation isolation.
Filters conversations by:
- user_id: Only show conversations belonging to the current user
- org_id: Only show conversations belonging to the user's current organization
Args:
query: SQLAlchemy query to apply filters to
Returns:
Query with user and organization filters applied
Raises:
AuthError: If no user_id is available (secure default: deny access)
"""
user_id_str = await self.user_context.get_user_id()
if not user_id_str:
# Secure default: no user means no access, not "show everything"
raise AuthError('User authentication required')
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
# Filter by organization ID to ensure conversations are isolated per organization
user = await self._get_current_user()
if user and user.current_org_id is not None:
query = query.where(
StoredConversationMetadataSaas.org_id == user.current_org_id
)
return query
async def _secure_select(self):
query = (
@@ -38,13 +90,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
return await self._apply_user_and_org_filter(query)
async def _secure_select_with_saas_metadata(self):
"""Select query that includes SAAS metadata for retrieving user_id."""
@@ -57,13 +103,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
return await self._apply_user_and_org_filter(query)
async def search_app_conversation_info(
self,
@@ -155,21 +195,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
"""Count conversations matching the given filters with SAAS metadata."""
query = (
select(func.count(StoredConversationMetadata.conversation_id))
.select_from(
StoredConversationMetadata.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
# Apply user filtering
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
# Apply user and organization filtering
query = await self._apply_user_and_org_filter(query)
query = self._apply_filters_with_saas_metadata(
query=query,
@@ -233,7 +268,13 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
result = result_set.first()
if result:
stored_metadata, saas_metadata = result
return self._to_info_with_user_id(stored_metadata, saas_metadata)
# Fetch sub-conversation IDs
sub_conversation_ids = await self.get_sub_conversation_ids(conversation_id)
return self._to_info_with_user_id(
stored_metadata,
saas_metadata,
sub_conversation_ids=sub_conversation_ids,
)
return None
async def batch_get_app_conversation_info(
@@ -262,8 +303,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
for conversation_id in conversation_id_strs:
if conversation_id in info_by_id:
stored_metadata, saas_metadata = info_by_id[conversation_id]
# Fetch sub-conversation IDs for each conversation
sub_conversation_ids = await self.get_sub_conversation_ids(
UUID(conversation_id)
)
results.append(
self._to_info_with_user_id(stored_metadata, saas_metadata)
self._to_info_with_user_id(
stored_metadata,
saas_metadata,
sub_conversation_ids=sub_conversation_ids,
)
)
else:
results.append(None)
@@ -316,10 +365,11 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
self,
stored: StoredConversationMetadata,
saas_metadata: StoredConversationMetadataSaas,
sub_conversation_ids: list[UUID] | None = None,
) -> AppConversationInfo:
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
# Use the base _to_info method to get the basic info
info = self._to_info(stored)
info = self._to_info(stored, sub_conversation_ids=sub_conversation_ids)
# Override the created_by_user_id with the user_id from SAAS metadata
info.created_by_user_id = (

View File

@@ -20,8 +20,10 @@ from storage.linear_workspace import LinearWorkspace
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.openhands_pr import OpenhandsPR
from storage.org import Org
from storage.org_invitation import OrgInvitation
from storage.org_member import OrgMember
from storage.proactive_convos import ProactiveConversation
from storage.resend_synced_user import ResendSyncedUser
from storage.role import Role
from storage.slack_conversation import SlackConversation
from storage.slack_team import SlackTeam
@@ -65,8 +67,10 @@ __all__ = [
'MaintenanceTaskStatus',
'OpenhandsPR',
'Org',
'OrgInvitation',
'OrgMember',
'ProactiveConversation',
'ResendSyncedUser',
'Role',
'SlackConversation',
'SlackTeam',

View File

@@ -126,7 +126,7 @@ class ApiKeyStore:
return True
async def list_api_keys(self, user_id: str) -> list[dict]:
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
"""List all API keys for a user."""
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
@@ -134,24 +134,14 @@ class ApiKeyStore:
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
with self.session_maker() as session:
keys = (
keys: list[ApiKey] = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
return [
{
'id': key.id,
'name': key.name,
'created_at': key.created_at,
'last_used_at': key.last_used_at,
'expires_at': key.expires_at,
}
for key in keys
if 'MCP_API_KEY' != key.name
]
return [key for key in keys if key.name != 'MCP_API_KEY']
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = await UserStore.get_user_by_id_async(user_id)

View File

@@ -10,7 +10,6 @@ import httpx
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.constants import (
DEFAULT_INITIAL_BUDGET,
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
@@ -44,6 +43,34 @@ def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str:
class LiteLlmManager:
"""Manage LiteLLM interactions."""
@staticmethod
def get_budget_from_team_info(
user_team_info: dict | None, user_id: str, org_id: str
) -> tuple[float, float]:
"""Extract max_budget and spend from user team info.
For personal orgs (user_id == org_id), uses litellm_budget_table.max_budget.
For team orgs, uses max_budget_in_team (populated by get_user_team_info).
Args:
user_team_info: The response from get_user_team_info
user_id: The user's ID
org_id: The organization's ID
Returns:
Tuple of (max_budget, spend)
"""
if not user_team_info:
return 0, 0
spend = user_team_info.get('spend', 0)
if user_id == org_id:
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
else:
max_budget = user_team_info.get('max_budget_in_team') or 0
return max_budget, spend
@staticmethod
async def create_entries(
org_id: str,
@@ -72,8 +99,33 @@ class LiteLlmManager:
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
# Check if team already exists and get its budget
# New users joining existing orgs should inherit the team's budget
team_budget = 0.0
try:
existing_team = await LiteLlmManager._get_team(client, org_id)
if existing_team:
team_info = existing_team.get('team_info', {})
team_budget = team_info.get('max_budget', 0.0) or 0.0
logger.info(
'LiteLlmManager:create_entries:existing_team_budget',
extra={
'org_id': org_id,
'user_id': keycloak_user_id,
'team_budget': team_budget,
},
)
except httpx.HTTPStatusError as e:
# Team doesn't exist yet (404) - this is expected for first user
if e.response.status_code != 404:
raise
logger.info(
'LiteLlmManager:create_entries:no_existing_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
client, keycloak_user_id, org_id, team_budget
)
if create_user:
@@ -82,7 +134,7 @@ class LiteLlmManager:
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
client, keycloak_user_id, org_id, team_budget
)
key = await LiteLlmManager._generate_key(
@@ -894,21 +946,31 @@ class LiteLlmManager:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
team_info = await LiteLlmManager._get_team(client, team_id)
if not team_info:
team_response = await LiteLlmManager._get_team(client, team_id)
if not team_response:
return None
# Filter team_memberships based on team_id and keycloak_user_id
user_membership = next(
(
membership
for membership in team_info.get('team_memberships', [])
for membership in team_response.get('team_memberships', [])
if membership.get('user_id') == keycloak_user_id
and membership.get('team_id') == team_id
),
None,
)
if not user_membership:
return None
# For team orgs (user_id != team_id), include team-level budget info
# The team's max_budget and spend are shared across all members
if keycloak_user_id != team_id:
team_info = team_response.get('team_info', {})
user_membership['max_budget_in_team'] = team_info.get('max_budget')
user_membership['spend'] = team_info.get('spend', 0)
return user_membership
@staticmethod

View File

@@ -46,10 +46,12 @@ class Org(Base): # type: ignore
v1_enabled = Column(Boolean, nullable=True)
conversation_expiration = Column(Integer, nullable=True)
condenser_max_size = Column(Integer, nullable=True)
byor_export_enabled = Column(Boolean, nullable=False, default=False)
# Relationships
org_members = relationship('OrgMember', back_populates='org')
current_users = relationship('User', back_populates='current_org')
invitations = relationship('OrgInvitation', back_populates='org')
billing_sessions = relationship('BillingSession', back_populates='org')
stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='org'

View File

@@ -0,0 +1,59 @@
"""
SQLAlchemy model for Organization Invitation.
"""
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.orm import relationship
from storage.base import Base
class OrgInvitation(Base): # type: ignore
"""Organization invitation model.
Represents an invitation for a user to join an organization.
Invitations are created by organization owners/admins and contain
a secure token that can be used to accept the invitation.
"""
__tablename__ = 'org_invitation'
id = Column(Integer, primary_key=True, autoincrement=True)
token = Column(String(64), nullable=False, unique=True, index=True)
org_id = Column(
UUID(as_uuid=True),
ForeignKey('org.id', ondelete='CASCADE'),
nullable=False,
index=True,
)
email = Column(String(255), nullable=False, index=True)
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
status = Column(
String(20),
nullable=False,
server_default=text("'pending'"),
)
created_at = Column(
DateTime,
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
expires_at = Column(DateTime, nullable=False)
accepted_at = Column(DateTime, nullable=True)
accepted_by_user_id = Column(
UUID(as_uuid=True),
ForeignKey('user.id'),
nullable=True,
)
# Relationships
org = relationship('Org', back_populates='invitations')
role = relationship('Role')
inviter = relationship('User', foreign_keys=[inviter_id])
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
# Status constants
STATUS_PENDING = 'pending'
STATUS_ACCEPTED = 'accepted'
STATUS_REVOKED = 'revoked'
STATUS_EXPIRED = 'expired'

View File

@@ -0,0 +1,227 @@
"""
Store class for managing organization invitations.
"""
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
from sqlalchemy import and_, select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker
from storage.org_invitation import OrgInvitation
from openhands.core.logger import openhands_logger as logger
# Invitation token configuration
INVITATION_TOKEN_PREFIX = 'inv-'
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
DEFAULT_EXPIRATION_DAYS = 7
class OrgInvitationStore:
"""Store for managing organization invitations."""
@staticmethod
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
"""Generate a secure invitation token.
Uses cryptographically secure random generation for tokens.
Pattern from api_key_store.py.
Args:
length: Length of the random part of the token
Returns:
str: Token with prefix (e.g., 'inv-aBcDeF123...')
"""
alphabet = string.ascii_letters + string.digits
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_id: int,
inviter_id: UUID,
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
) -> OrgInvitation:
"""Create a new organization invitation.
Args:
org_id: Organization UUID
email: Invitee's email address
role_id: Role ID to assign on acceptance
inviter_id: User ID of the person creating the invitation
expiration_days: Days until the invitation expires
Returns:
OrgInvitation: The created invitation record
"""
async with a_session_maker() as session:
token = OrgInvitationStore.generate_token()
# Use timezone-naive datetime for database compatibility
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
invitation = OrgInvitation(
token=token,
org_id=org_id,
email=email.lower().strip(),
role_id=role_id,
inviter_id=inviter_id,
status=OrgInvitation.STATUS_PENDING,
expires_at=expires_at,
)
session.add(invitation)
await session.commit()
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.role))
.filter(OrgInvitation.id == invitation.id)
)
invitation = result.scalars().first()
logger.info(
'Created organization invitation',
extra={
'invitation_id': invitation.id,
'org_id': str(org_id),
'email': email,
'inviter_id': str(inviter_id),
'expires_at': expires_at.isoformat(),
},
)
return invitation
@staticmethod
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
"""Get an invitation by its token.
Args:
token: The invitation token
Returns:
OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
.filter(OrgInvitation.token == token)
)
return result.scalars().first()
@staticmethod
async def get_pending_invitation(
org_id: UUID, email: str
) -> Optional[OrgInvitation]:
"""Get a pending invitation for an email in an organization.
Args:
org_id: Organization UUID
email: Email address to check
Returns:
OrgInvitation or None if no pending invitation exists
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(
and_(
OrgInvitation.org_id == org_id,
OrgInvitation.email == email.lower().strip(),
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
)
)
)
return result.scalars().first()
@staticmethod
async def update_invitation_status(
invitation_id: int,
status: str,
accepted_by_user_id: Optional[UUID] = None,
) -> Optional[OrgInvitation]:
"""Update an invitation's status.
Args:
invitation_id: The invitation ID
status: New status (pending, accepted, revoked, expired)
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
Returns:
Updated OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
)
invitation = result.scalars().first()
if not invitation:
return None
old_status = invitation.status
invitation.status = status
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
# Use timezone-naive datetime for database compatibility
invitation.accepted_at = datetime.utcnow()
invitation.accepted_by_user_id = accepted_by_user_id
await session.commit()
await session.refresh(invitation)
logger.info(
'Updated invitation status',
extra={
'invitation_id': invitation_id,
'old_status': old_status,
'new_status': status,
'accepted_by_user_id': (
str(accepted_by_user_id) if accepted_by_user_id else None
),
},
)
return invitation
@staticmethod
def is_token_expired(invitation: OrgInvitation) -> bool:
"""Check if an invitation token has expired.
Args:
invitation: The invitation to check
Returns:
bool: True if expired, False otherwise
"""
# Use timezone-naive datetime for comparison (database stores without timezone)
now = datetime.utcnow()
return invitation.expires_at < now
@staticmethod
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
"""Check if invitation is expired and update status if needed.
Args:
invitation: The invitation to check
Returns:
bool: True if invitation was marked as expired, False otherwise
"""
if (
invitation.status == OrgInvitation.STATUS_PENDING
and OrgInvitationStore.is_token_expired(invitation)
):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
return True
return False

View File

@@ -6,8 +6,10 @@ from typing import Optional
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker, session_maker
from storage.org_member import OrgMember
from storage.user import User
from storage.user_settings import UserSettings
from openhands.storage.data_models.settings import Settings
@@ -59,6 +61,51 @@ class OrgMemberStore:
)
return result.scalars().first()
@staticmethod
def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]:
"""Get the org member for a user's current organization.
Args:
user_id: The user's UUID.
Returns:
The OrgMember for the user's current organization, or None if not found.
"""
with session_maker() as session:
result = (
session.query(OrgMember)
.join(User, User.id == OrgMember.user_id)
.filter(
User.id == user_id,
OrgMember.org_id == User.current_org_id,
)
.first()
)
return result
@staticmethod
async def get_org_member_for_current_org_async(
user_id: UUID,
) -> Optional[OrgMember]:
"""Get the org member for a user's current organization (async version).
Args:
user_id: The user's UUID.
Returns:
The OrgMember for the user's current organization, or None if not found.
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgMember)
.join(User, User.id == OrgMember.user_id)
.filter(
User.id == user_id,
OrgMember.org_id == User.current_org_id,
)
)
return result.scalars().first()
@staticmethod
def get_user_orgs(user_id: UUID) -> list[OrgMember]:
"""Get all organizations for a user."""
@@ -135,3 +182,36 @@ class OrgMemberStore:
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
}
return kwargs
@staticmethod
async def get_org_members_paginated(
org_id: UUID,
offset: int = 0,
limit: int = 100,
) -> tuple[list[OrgMember], bool]:
"""Get paginated list of organization members with user and role info.
Returns:
Tuple of (members_list, has_more) where has_more indicates if there are more results.
"""
async with a_session_maker() as session:
# Query for limit + 1 items to determine if there are more results
# Order by user_id for consistent pagination
query = (
select(OrgMember)
.options(joinedload(OrgMember.user), joinedload(OrgMember.role))
.filter(OrgMember.org_id == org_id)
.order_by(OrgMember.user_id)
.offset(offset)
.limit(limit + 1)
)
result = await session.execute(query)
members = list(result.scalars().all())
# Check if there are more results
has_more = len(members) > limit
if has_more:
# Remove the extra item
members = members[:limit]
return members, has_more

View File

@@ -521,6 +521,7 @@ class OrgService:
Raises:
ValueError: If organization not found
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
OrgNameExistsError: If new name already exists for another organization
OrgDatabaseError: If database update fails
"""
logger.info(
@@ -550,6 +551,24 @@ class OrgService:
'User must be a member of the organization to update it'
)
# Check if name is being updated and validate uniqueness
if update_data.name is not None:
# Check if new name conflicts with another org
existing_org_with_name = OrgStore.get_org_by_name(update_data.name)
if (
existing_org_with_name is not None
and existing_org_with_name.id != org_id
):
logger.warning(
'Attempted to update organization with duplicate name',
extra={
'user_id': user_id,
'org_id': str(org_id),
'attempted_name': update_data.name,
},
)
raise OrgNameExistsError(update_data.name)
# Check if update contains any LLM settings
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
if llm_fields_being_updated:
@@ -637,10 +656,9 @@ class OrgService:
)
return None
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
user_team_info, user_id, str(org_id)
)
spend = user_team_info.get('spend', 0)
credits = max(max_budget - spend, 0)
logger.debug(
@@ -842,3 +860,94 @@ class OrgService:
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
@staticmethod
async def check_byor_export_enabled(user_id: str) -> bool:
"""Check if BYOR export is enabled for the user's current org.
Returns True if the user's current org has byor_export_enabled set to True.
Returns False if the user is not found, has no current org, or the flag is False.
Args:
user_id: User ID to check
Returns:
bool: True if BYOR export is enabled, False otherwise
"""
user = await UserStore.get_user_by_id_async(user_id)
if not user or not user.current_org_id:
return False
org = OrgStore.get_org_by_id(user.current_org_id)
if not org:
return False
return org.byor_export_enabled
@staticmethod
async def switch_org(user_id: str, org_id: UUID) -> Org:
"""
Switch user's current organization to the specified organization.
This method:
1. Validates that the organization exists
2. Validates that the user is a member of the organization
3. Updates the user's current_org_id
Args:
user_id: User ID (string that will be converted to UUID)
org_id: Organization ID to switch to
Returns:
Org: The organization that was switched to
Raises:
OrgNotFoundError: If organization doesn't exist
OrgAuthorizationError: If user is not a member of the organization
OrgDatabaseError: If database update fails
"""
logger.info(
'Switching user organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
# Step 1: Check if organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise OrgNotFoundError(str(org_id))
# Step 2: Validate user is a member of the organization
if not OrgService.is_org_member(user_id, org_id):
logger.warning(
'User attempted to switch to organization they are not a member of',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise OrgAuthorizationError(
'User must be a member of the organization to switch to it'
)
# Step 3: Update user's current_org_id
try:
updated_user = UserStore.update_current_org(user_id, org_id)
if not updated_user:
raise OrgDatabaseError('User not found')
logger.info(
'Successfully switched user organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
'org_name': org.name,
},
)
return org
except OrgDatabaseError:
raise
except Exception as e:
logger.error(
'Failed to switch user organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise OrgDatabaseError(f'Failed to switch organization: {str(e)}')

View File

@@ -10,6 +10,7 @@ from server.constants import (
ORG_SETTINGS_VERSION,
get_default_litellm_model,
)
from server.routes.org_models import OrphanedUserError
from sqlalchemy import text
from sqlalchemy.orm import joinedload
from storage.database import session_maker
@@ -320,17 +321,41 @@ class OrgStore:
{'org_id': str(org_id)},
)
# 3. Delete organization memberships
# 3. Handle users with this as current_org_id BEFORE deleting memberships
# Single query to find orphaned users (those with no alternative org)
orphaned_users = session.execute(
text("""
SELECT u.id
FROM "user" u
WHERE u.current_org_id = :org_id
AND NOT EXISTS (
SELECT 1 FROM org_member om
WHERE om.user_id = u.id AND om.org_id != :org_id
)
"""),
{'org_id': str(org_id)},
).fetchall()
if orphaned_users:
raise OrphanedUserError([str(row[0]) for row in orphaned_users])
# Batch update: reassign current_org_id to an alternative org for all affected users
session.execute(
text('DELETE FROM org_member WHERE org_id = :org_id'),
text("""
UPDATE "user" u
SET current_org_id = (
SELECT om.org_id FROM org_member om
WHERE om.user_id = u.id AND om.org_id != :org_id
LIMIT 1
)
WHERE u.current_org_id = :org_id
"""),
{'org_id': str(org_id)},
)
# 4. Handle users with this as current_org_id
# 4. Delete organization memberships (now safe)
session.execute(
text(
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
),
text('DELETE FROM org_member WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)

View File

@@ -0,0 +1,35 @@
"""SQLAlchemy model for tracking users synced to Resend audiences."""
from datetime import UTC, datetime
from uuid import uuid4
from sqlalchemy import Column, DateTime, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from storage.base import Base
class ResendSyncedUser(Base): # type: ignore
"""Tracks users that have been synced to a Resend audience.
This table ensures that once a user is synced to a Resend audience,
they won't be re-added even if they are later deleted from the
Resend UI. This respects manual deletions/unsubscribes.
"""
__tablename__ = 'resend_synced_users'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
email = Column(String, nullable=False, index=True)
audience_id = Column(String, nullable=False, index=True)
synced_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
nullable=False,
)
keycloak_user_id = Column(String, nullable=True)
__table_args__ = (
UniqueConstraint(
'email', 'audience_id', name='uq_resend_synced_email_audience'
),
)

View File

@@ -0,0 +1,125 @@
"""Store class for managing Resend synced users."""
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Optional, Set
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import sessionmaker
from storage.resend_synced_user import ResendSyncedUser
@dataclass
class ResendSyncedUserStore:
"""Store for tracking users synced to Resend audiences."""
session_maker: sessionmaker
def is_user_synced(self, email: str, audience_id: str) -> bool:
"""Check if a user has been synced to a specific audience.
Args:
email: The email address to check.
audience_id: The Resend audience ID.
Returns:
True if the user has been synced, False otherwise.
"""
with self.session_maker() as session:
stmt = select(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt).first()
return result is not None
def get_synced_emails_for_audience(self, audience_id: str) -> Set[str]:
"""Get all synced email addresses for a specific audience.
Args:
audience_id: The Resend audience ID.
Returns:
A set of lowercase email addresses that have been synced.
"""
with self.session_maker() as session:
stmt = select(ResendSyncedUser.email).where(
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt).scalars().all()
return set(result)
def mark_user_synced(
self,
email: str,
audience_id: str,
keycloak_user_id: Optional[str] = None,
) -> ResendSyncedUser:
"""Mark a user as synced to a specific audience.
Uses upsert to handle race conditions - if the user is already
marked as synced, this is a no-op.
Args:
email: The email address of the user.
audience_id: The Resend audience ID.
keycloak_user_id: Optional Keycloak user ID.
Returns:
The ResendSyncedUser record.
Raises:
RuntimeError: If the record could not be created or retrieved.
"""
with self.session_maker() as session:
stmt = (
insert(ResendSyncedUser)
.values(
email=email.lower(),
audience_id=audience_id,
keycloak_user_id=keycloak_user_id,
synced_at=datetime.now(UTC),
)
.on_conflict_do_nothing(constraint='uq_resend_synced_email_audience')
.returning(ResendSyncedUser)
)
result = session.execute(stmt)
session.commit()
row = result.first()
if row:
return row[0]
# on_conflict_do_nothing triggered, fetch the existing record
existing = session.execute(
select(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
).first()
if existing:
return existing[0]
raise RuntimeError(
f'Failed to create or retrieve synced user record for {email}'
)
def remove_synced_user(self, email: str, audience_id: str) -> bool:
"""Remove a user's synced status for a specific audience.
Args:
email: The email address of the user.
audience_id: The Resend audience ID.
Returns:
True if a record was deleted, False if no record existed.
"""
with self.session_maker() as session:
stmt = delete(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt)
session.commit()
return result.rowcount > 0

View File

@@ -29,6 +29,20 @@ class RoleStore:
with session_maker() as session:
return session.query(Role).filter(Role.id == role_id).first()
@staticmethod
async def get_role_by_id_async(
role_id: int,
session: Optional[AsyncSession] = None,
) -> Optional[Role]:
"""Get role by ID (async version)."""
if session is not None:
result = await session.execute(select(Role).where(Role.id == role_id))
return result.scalars().first()
async with a_session_maker() as session:
result = await session.execute(select(Role).where(Role.id == role_id))
return result.scalars().first()
@staticmethod
def get_role_by_name(name: str) -> Optional[Role]:
"""Get role by name."""

View File

@@ -24,6 +24,7 @@ from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.server.settings import Settings
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.llm import is_openhands_model
@dataclass
@@ -106,13 +107,13 @@ class SaasSettingsStore(SettingsStore):
},
}
kwargs['llm_api_key'] = org_member.llm_api_key
if org_member.max_iterations is not None:
if org_member.max_iterations:
kwargs['max_iterations'] = org_member.max_iterations
if org_member.llm_model is not None:
if org_member.llm_model:
kwargs['llm_model'] = org_member.llm_model
if org_member.llm_api_key_for_byor is not None:
if org_member.llm_api_key_for_byor:
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
if org_member.llm_base_url is not None:
if org_member.llm_base_url:
kwargs['llm_base_url'] = org_member.llm_base_url
if org.v1_enabled is None:
kwargs['v1_enabled'] = True
@@ -161,9 +162,12 @@ class SaasSettingsStore(SettingsStore):
return None
# Check if we need to generate an LLM key.
is_openhands_provider = self._is_openhands_provider(item)
if is_openhands_provider or item.llm_base_url == LITE_LLM_API_URL:
await self._ensure_api_key(item, str(org_id), is_openhands_provider)
if not item.llm_base_url:
item.llm_base_url = LITE_LLM_API_URL
if item.llm_base_url == LITE_LLM_API_URL:
await self._ensure_api_key(
item, str(org_id), openhands_type=is_openhands_model(item.llm_model)
)
kwargs = item.model_dump(context={'expose_secrets': True})
for model in (user, org, org_member):
@@ -228,10 +232,6 @@ class SaasSettingsStore(SettingsStore):
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
return Fernet(fernet_key)
def _is_openhands_provider(self, item: Settings) -> bool:
"""Check if the settings use the OpenHands provider."""
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
async def _ensure_api_key(
self, item: Settings, org_id: str, openhands_type: bool = False
) -> None:

View File

@@ -5,6 +5,7 @@ Store class for managing users.
import asyncio
import uuid
from typing import Optional
from uuid import UUID
from server.auth.token_manager import TokenManager
from server.constants import (
@@ -82,6 +83,8 @@ class UserStore:
role_id=role_id,
**user_kwargs,
)
user.email = user_info.get('email')
user.email_verified = user_info.get('email_verified')
session.add(user)
role = RoleStore.get_role_by_name('owner')
@@ -172,6 +175,19 @@ class UserStore:
)
decrypted_user_settings = UserSettings(**kwargs)
with session_maker() as session:
# Check if user has completed billing sessions to enable BYOR export
from storage.billing_session import BillingSession
has_completed_billing = (
session.query(BillingSession)
.filter(
BillingSession.user_id == user_id,
BillingSession.status == 'completed',
)
.first()
is not None
)
# create personal org
org = Org(
id=uuid.UUID(user_id),
@@ -180,6 +196,7 @@ class UserStore:
contact_name=resolve_display_name(user_info)
or user_info.get('username', ''),
contact_email=user_info['email'],
byor_export_enabled=has_completed_billing,
)
session.add(org)
@@ -753,12 +770,62 @@ class UserStore:
finally:
await UserStore._release_user_creation_lock(user_id)
@staticmethod
async def get_user_by_email_async(email: str) -> Optional[User]:
"""Get user by email address (async version).
This method looks up a user by their email address. Note that email
addresses may not be unique across all users in rare cases.
Args:
email: The email address to search for
Returns:
User: The user with the matching email, or None if not found
"""
if not email:
return None
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.email == email.lower().strip())
)
return result.scalars().first()
@staticmethod
def list_users() -> list[User]:
"""List all users."""
with session_maker() as session:
return session.query(User).all()
@staticmethod
def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
"""Update the user's current organization.
Args:
user_id: The user's ID (Keycloak user ID)
org_id: The organization ID to set as current
Returns:
User: The updated user object, or None if user not found
"""
with session_maker() as session:
user = (
session.query(User)
.filter(User.id == uuid.UUID(user_id))
.with_for_update()
.first()
)
if not user:
return None
user.current_org_id = org_id
session.commit()
session.refresh(user)
return user
@staticmethod
async def backfill_contact_name(user_id: str, user_info: dict) -> None:
"""Update contact_name on the personal org if it still has a username-style value.

View File

@@ -1,11 +1,19 @@
import asyncio
import asyncio # noqa: I001
from storage.proactive_conversation_store import ProactiveConversationStore
# This must be before the import of storage
# to set up logging and prevent alembic from
# running its mouth.
from openhands.core.logger import openhands_logger
from storage.proactive_conversation_store import (
ProactiveConversationStore,
)
OLDER_THAN = 30 # 30 minutes
async def main():
openhands_logger.info('clean_proactive_convo_table')
convo_store = ProactiveConversationStore()
await convo_store.clean_old_convos(older_than_minutes=OLDER_THAN)

View File

@@ -26,6 +26,7 @@ Optional environment variables:
"""
import os
import re
import sys
import time
from typing import Any, Dict, List, Optional
@@ -34,6 +35,7 @@ import resend
from keycloak.exceptions import KeycloakError
from resend.exceptions import ResendError
from server.auth.token_manager import get_keycloak_admin
from storage.resend_synced_user_store import ResendSyncedUserStore
from tenacity import (
retry,
retry_if_exception_type,
@@ -68,9 +70,6 @@ RATE_LIMIT = float(os.environ.get('RATE_LIMIT', '2')) # Requests per second
# Set up Resend API
resend.api_key = RESEND_API_KEY
print('resend module', resend)
print('has contacts', hasattr(resend, 'Contacts'))
class ResendSyncError(Exception):
"""Base exception for Resend sync errors."""
@@ -90,6 +89,31 @@ class ResendAPIError(ResendSyncError):
pass
# Email validation regex pattern - matches standard email format
# This pattern is intentionally strict to avoid Resend API validation errors
# It rejects special characters like ! that some email providers technically allow
# but Resend's API does not accept
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
def is_valid_email(email: str) -> bool:
"""Validate an email address format.
This uses a regex pattern that matches most valid email addresses
while rejecting addresses with special characters that Resend's API
does not accept (e.g., exclamation marks).
Args:
email: The email address to validate.
Returns:
True if the email is valid, False otherwise.
"""
if not email:
return False
return bool(EMAIL_REGEX.match(email))
def get_keycloak_users(offset: int = 0, limit: int = 100) -> List[Dict[str, Any]]:
"""Get users from Keycloak using the admin client.
@@ -173,8 +197,6 @@ def get_resend_contacts(audience_id: str) -> Dict[str, Dict[str, Any]]:
Raises:
ResendAPIError: If the API call fails.
"""
print('getting resend contacts')
print('has resend contacts', hasattr(resend, 'Contacts'))
try:
contacts = resend.Contacts.list(audience_id).get('data', [])
# Create a dictionary mapping email addresses to contact data for
@@ -291,8 +313,84 @@ def send_welcome_email(
raise
def _get_resend_synced_user_store() -> ResendSyncedUserStore:
"""Get the ResendSyncedUserStore instance.
This is separated into a function to allow for easier testing/mocking.
"""
from openhands.app_server.config import get_global_config
config = get_global_config()
db_session_injector = config.db_session
return ResendSyncedUserStore(session_maker=db_session_injector.get_session_maker())
def _backfill_existing_resend_contacts(
synced_user_store: ResendSyncedUserStore,
audience_id: str,
) -> int:
"""Backfill the synced_users table with contacts already in Resend.
This ensures that users who were added to Resend before the tracking
table existed are properly recorded, preventing duplicate welcome emails.
Args:
synced_user_store: The store for tracking synced users.
audience_id: The Resend audience ID.
Returns:
The number of contacts backfilled.
"""
logger.info('Starting backfill of existing Resend contacts...')
try:
resend_contacts = get_resend_contacts(audience_id)
logger.info(f'Found {len(resend_contacts)} contacts in Resend audience')
already_synced_emails = synced_user_store.get_synced_emails_for_audience(
audience_id
)
logger.info(
f'Found {len(already_synced_emails)} already synced emails in database'
)
backfilled_count = 0
for email in resend_contacts:
if email.lower() not in already_synced_emails:
synced_user_store.mark_user_synced(
email=email,
audience_id=audience_id,
keycloak_user_id=None, # We don't have this info during backfill
)
backfilled_count += 1
logger.debug(f'Backfilled existing Resend contact: {email}')
logger.info(
f'Backfill completed: {backfilled_count} contacts added to tracking'
)
return backfilled_count
except Exception:
logger.exception('Error during backfill of existing Resend contacts')
# Don't fail the entire sync if backfill fails - just log and continue
return 0
def sync_users_to_resend():
"""Sync users from Keycloak to Resend."""
"""Sync users from Keycloak to Resend.
This function syncs users from Keycloak to a Resend audience. It tracks
which users have been synced in the database to ensure that:
1. Users are only added once (even across multiple sync runs)
2. Users who are manually deleted from Resend are not re-added
The tracking is done via the resend_synced_users table, which records
each email/audience_id combination that has been synced.
On first run (or when new contacts exist in Resend), it will backfill
the tracking table with existing Resend contacts to avoid sending
duplicate welcome emails.
"""
# Check required environment variables
required_vars = {
'RESEND_API_KEY': RESEND_API_KEY,
@@ -318,27 +416,36 @@ def sync_users_to_resend():
)
try:
# Get the store for tracking synced users
synced_user_store = _get_resend_synced_user_store()
# Backfill existing Resend contacts into our tracking table
# This ensures users already in Resend don't get duplicate welcome emails
backfilled_count = _backfill_existing_resend_contacts(
synced_user_store, RESEND_AUDIENCE_ID
)
# Get the total number of users
total_users = get_total_keycloak_users()
logger.info(
f'Found {total_users} users in Keycloak realm {KEYCLOAK_REALM_NAME}'
)
# Get contacts from Resend
resend_contacts = get_resend_contacts(RESEND_AUDIENCE_ID)
logger.info(
f'Found {len(resend_contacts)} contacts in Resend audience '
f'{RESEND_AUDIENCE_ID}'
)
# Stats
stats = {
'total_users': total_users,
'existing_contacts': len(resend_contacts),
'backfilled_contacts': backfilled_count,
'already_synced': 0,
'added_contacts': 0,
'skipped_invalid_emails': 0,
'errors': 0,
}
synced_emails = synced_user_store.get_synced_emails_for_audience(
RESEND_AUDIENCE_ID
)
logger.info(f'Found {len(synced_emails)} already synced emails in database')
# Process users in batches
offset = 0
while offset < total_users:
@@ -351,39 +458,65 @@ def sync_users_to_resend():
continue
email = email.lower()
if email in resend_contacts:
logger.debug(f'User {email} already exists in Resend, skipping')
if email in synced_emails:
logger.debug(
f'User {email} was already synced to this audience, skipping'
)
stats['already_synced'] += 1
continue
# Validate email format before attempting to add to Resend
if not is_valid_email(email):
logger.warning(f'Skipping user with invalid email format: {email}')
stats['skipped_invalid_emails'] += 1
continue
first_name = user.get('first_name')
last_name = user.get('last_name')
keycloak_user_id = user.get('id')
# Mark as synced first (optimistic) to ensure consistency.
# If Resend API fails, we remove the record.
try:
synced_user_store.mark_user_synced(
email=email,
audience_id=RESEND_AUDIENCE_ID,
keycloak_user_id=keycloak_user_id,
)
except Exception:
logger.exception(f'Failed to mark user {email} as synced')
stats['errors'] += 1
continue
try:
first_name = user.get('first_name')
last_name = user.get('last_name')
# Add the contact to the Resend audience
add_contact_to_resend(
RESEND_AUDIENCE_ID, email, first_name, last_name
)
logger.info(f'Added user {email} to Resend')
stats['added_contacts'] += 1
# Sleep to respect rate limit after first API call
time.sleep(1 / RATE_LIMIT)
# Send a welcome email to the newly added contact
try:
send_welcome_email(email, first_name, last_name)
logger.info(f'Sent welcome email to {email}')
except Exception:
logger.exception(
f'Failed to send welcome email to {email}, but contact was added to audience'
)
# Continue with the sync process even if sending the welcome email fails
# Sleep to respect rate limit after second API call
time.sleep(1 / RATE_LIMIT)
except Exception:
logger.exception(f'Error adding user {email} to Resend')
synced_user_store.remove_synced_user(email, RESEND_AUDIENCE_ID)
stats['errors'] += 1
continue
synced_emails.add(email)
stats['added_contacts'] += 1
# Sleep to respect rate limit after first API call
time.sleep(1 / RATE_LIMIT)
# Send a welcome email to the newly added contact
try:
send_welcome_email(email, first_name, last_name)
logger.info(f'Sent welcome email to {email}')
except Exception:
logger.exception(
f'Failed to send welcome email to {email}, but contact was added to audience'
)
# Sleep to respect rate limit after second API call
time.sleep(1 / RATE_LIMIT)
offset += BATCH_SIZE

View File

@@ -126,3 +126,24 @@ def test_run_agent_variant_tests_v1_calls_handler_and_sets_system_prompt(monkeyp
# Should be a different instance than the original (copied after handler runs)
assert result is not agent
assert result.system_prompt_filename == 'system_prompt_long_horizon.j2'
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', True)
@patch('experiments.experiment_manager.EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', True)
def test_run_agent_variant_tests_v1_preserves_planning_agent_system_prompt():
"""Planning agents should retain their specialized system prompt and not be overwritten by the experiment."""
# Arrange
planning_agent = make_agent().model_copy(
update={'system_prompt_filename': 'system_prompt_planning.j2'}
)
conv_id = uuid4()
# Act
result: Agent = SaaSExperimentManager.run_agent_variant_tests__v1(
user_id='user-planning',
conversation_id=conv_id,
agent=planning_agent,
)
# Assert
assert result.system_prompt_filename == 'system_prompt_planning.j2'

View File

@@ -141,12 +141,14 @@ def test_custom_to_static_conversion():
def create_provider_tokens(
tokens_dict: dict[ProviderType, str],
) -> dict[ProviderType, ProviderToken]:
"""Helper to create provider tokens dictionary."""
return {
provider_type: ProviderToken(token=SecretStr(token_value))
for provider_type, token_value in tokens_dict.items()
}
) -> MappingProxyType:
"""Helper to create provider tokens as MappingProxyType."""
return MappingProxyType(
{
provider_type: ProviderToken(token=SecretStr(token_value))
for provider_type, token_value in tokens_dict.items()
}
)
@pytest.mark.asyncio
@@ -264,3 +266,63 @@ async def test_get_latest_token_can_be_used_with_static_secret(
# Assert - this should NOT raise a ValidationError
static_secret = StaticSecret(value=token, description='GITHUB authentication token')
assert static_secret.get_value() == token_value
# ---------------------------------------------------------------------------
# Tests for get_authenticated_git_url - ensuring proper authenticated URLs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_authenticated_git_url_raises_when_no_tokens(
resolver_context, mock_saas_user_auth
):
"""Test that get_authenticated_git_url raises error when no provider tokens available."""
# Arrange
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=None)
# Act & Assert
with pytest.raises(ValueError, match='No provider tokens available'):
await resolver_context.get_authenticated_git_url('owner/repo')
@pytest.mark.asyncio
async def test_get_provider_handler_caches_instance(
resolver_context, mock_saas_user_auth
):
"""Test that _get_provider_handler caches the handler instance."""
# Arrange
token_value = 'ghp_test_token'
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
mock_saas_user_auth.get_user_id = AsyncMock(return_value='test-user-id')
# Act - call _get_provider_handler twice
handler1 = await resolver_context._get_provider_handler()
handler2 = await resolver_context._get_provider_handler()
# Assert - should be the same instance (cached)
assert handler1 is handler2
# get_provider_tokens should only be called once
assert mock_saas_user_auth.get_provider_tokens.call_count == 1
@pytest.mark.asyncio
async def test_get_provider_handler_creates_handler_with_correct_params(
resolver_context, mock_saas_user_auth
):
"""Test that _get_provider_handler creates ProviderHandler with correct parameters."""
# Arrange
token_value = 'ghp_test_token'
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
mock_saas_user_auth.get_user_id = AsyncMock(return_value='test-user-id')
# Act
handler = await resolver_context._get_provider_handler()
# Assert
from openhands.integrations.provider import ProviderHandler
assert isinstance(handler, ProviderHandler)
assert handler.provider_tokens == provider_tokens

View File

@@ -6,6 +6,9 @@ import httpx
import pytest
from fastapi import HTTPException
from server.routes.api_keys import (
ByorPermittedResponse,
LlmApiKeyResponse,
check_byor_permitted,
delete_byor_key_from_litellm,
get_llm_api_key_for_byor,
)
@@ -182,16 +185,18 @@ class TestGetLlmApiKeyForByor:
"""Test the get_llm_api_key_for_byor endpoint."""
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_no_key_in_database_generates_new(
self, mock_get_key, mock_generate_key, mock_store_key
self, mock_get_key, mock_generate_key, mock_store_key, mock_check_enabled
):
"""Test that when no key exists in database, a new one is generated."""
# Arrange
user_id = 'user-123'
new_key = 'sk-new-generated-key'
mock_check_enabled.return_value = True
mock_get_key.return_value = None
mock_generate_key.return_value = new_key
mock_store_key.return_value = None
@@ -200,21 +205,24 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
assert result == LlmApiKeyResponse(key=new_key)
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_generate_key.assert_called_once_with(user_id)
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_valid_key_in_database_returns_key(
self, mock_get_key, mock_verify_key
self, mock_get_key, mock_verify_key, mock_check_enabled
):
"""Test that when a valid key exists in database, it is returned."""
# Arrange
user_id = 'user-123'
existing_key = 'sk-existing-valid-key'
mock_check_enabled.return_value = True
mock_get_key.return_value = existing_key
mock_verify_key.return_value = True
@@ -222,11 +230,13 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': existing_key}
assert result == LlmApiKeyResponse(key=existing_key)
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(existing_key, user_id)
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@@ -239,12 +249,14 @@ class TestGetLlmApiKeyForByor:
mock_delete_key,
mock_generate_key,
mock_store_key,
mock_check_enabled,
):
"""Test that when an invalid key exists in database, it is regenerated."""
# Arrange
user_id = 'user-123'
invalid_key = 'sk-invalid-key'
new_key = 'sk-new-generated-key'
mock_check_enabled.return_value = True
mock_get_key.return_value = invalid_key
mock_verify_key.return_value = False
mock_delete_key.return_value = True
@@ -255,7 +267,8 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
assert result == LlmApiKeyResponse(key=new_key)
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(invalid_key, user_id)
mock_delete_key.assert_called_once_with(user_id, invalid_key)
@@ -263,6 +276,7 @@ class TestGetLlmApiKeyForByor:
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@@ -275,12 +289,14 @@ class TestGetLlmApiKeyForByor:
mock_delete_key,
mock_generate_key,
mock_store_key,
mock_check_enabled,
):
"""Test that even if deletion fails, regeneration still proceeds."""
# Arrange
user_id = 'user-123'
invalid_key = 'sk-invalid-key'
new_key = 'sk-new-generated-key'
mock_check_enabled.return_value = True
mock_get_key.return_value = invalid_key
mock_verify_key.return_value = False
mock_delete_key.return_value = False # Deletion fails
@@ -291,20 +307,23 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
assert result == LlmApiKeyResponse(key=new_key)
mock_check_enabled.assert_called_once_with(user_id)
mock_delete_key.assert_called_once_with(user_id, invalid_key)
mock_generate_key.assert_called_once_with(user_id)
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_key_generation_failure_raises_exception(
self, mock_get_key, mock_generate_key
self, mock_get_key, mock_generate_key, mock_check_enabled
):
"""Test that when key generation fails, an HTTPException is raised."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = True
mock_get_key.return_value = None
mock_generate_key.return_value = None
@@ -316,11 +335,15 @@ class TestGetLlmApiKeyForByor:
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_database_error_raises_exception(self, mock_get_key):
async def test_database_error_raises_exception(
self, mock_get_key, mock_check_enabled
):
"""Test that database errors are properly handled."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = True
mock_get_key.side_effect = Exception('Database connection error')
# Act & Assert
@@ -330,6 +353,21 @@ class TestGetLlmApiKeyForByor:
assert exc_info.value.status_code == 500
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
async def test_byor_export_disabled_returns_402(self, mock_check_enabled):
"""Test that when BYOR export is disabled, 402 is returned."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = False
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_llm_api_key_for_byor(user_id=user_id)
assert exc_info.value.status_code == 402
assert 'BYOR key export is not enabled' in exc_info.value.detail
class TestDeleteByorKeyFromLitellm:
"""Test the delete_byor_key_from_litellm function with alias cleanup."""
@@ -425,3 +463,52 @@ class TestDeleteByorKeyFromLitellm:
# Assert
assert result is False
class TestCheckByorPermitted:
"""Test the check_byor_permitted endpoint."""
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
async def test_permitted_when_enabled(self, mock_check_enabled):
"""Test that permitted=True is returned when BYOR export is enabled."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = True
# Act
result = await check_byor_permitted(user_id=user_id)
# Assert
assert result == ByorPermittedResponse(permitted=True)
mock_check_enabled.assert_called_once_with(user_id)
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
async def test_not_permitted_when_disabled(self, mock_check_enabled):
"""Test that permitted=False is returned when BYOR export is disabled."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = False
# Act
result = await check_byor_permitted(user_id=user_id)
# Assert
assert result == ByorPermittedResponse(permitted=False)
mock_check_enabled.assert_called_once_with(user_id)
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
async def test_error_raises_500(self, mock_check_enabled):
"""Test that an exception raises 500 error."""
# Arrange
user_id = 'user-123'
mock_check_enabled.side_effect = Exception('Database error')
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await check_byor_permitted(user_id=user_id)
assert exc_info.value.status_code == 500
assert 'Failed to check BYOR export permission' in exc_info.value.detail

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,158 @@
"""Unit tests for ResendSyncedUserStore."""
from unittest.mock import MagicMock
import pytest
# Import directly from the module files to avoid loading all of storage/__init__.py
# which has many dependencies
from storage.resend_synced_user import ResendSyncedUser
from storage.resend_synced_user_store import ResendSyncedUserStore
@pytest.fixture
def mock_session():
"""Mock database session."""
session = MagicMock()
return session
@pytest.fixture
def mock_session_maker(mock_session):
"""Mock session maker."""
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = mock_session
session_maker.return_value.__exit__.return_value = None
return session_maker
@pytest.fixture
def store(mock_session_maker):
"""Create ResendSyncedUserStore instance."""
return ResendSyncedUserStore(session_maker=mock_session_maker)
class TestResendSyncedUserStore:
"""Test cases for ResendSyncedUserStore."""
def test_is_user_synced_returns_true_when_exists(self, store, mock_session):
"""Test is_user_synced returns True when user exists in database."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_row = MagicMock()
mock_session.execute.return_value.first.return_value = mock_row
result = store.is_user_synced(email, audience_id)
assert result is True
mock_session.execute.assert_called_once()
def test_is_user_synced_returns_false_when_not_exists(self, store, mock_session):
"""Test is_user_synced returns False when user doesn't exist."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_session.execute.return_value.first.return_value = None
result = store.is_user_synced(email, audience_id)
assert result is False
def test_is_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
"""Test that is_user_synced normalizes email to lowercase."""
email = 'TEST@EXAMPLE.COM'
audience_id = 'test-audience-123'
mock_session.execute.return_value.first.return_value = None
store.is_user_synced(email, audience_id)
# Verify the query was called (we can't easily check the exact SQL)
mock_session.execute.assert_called_once()
def test_mark_user_synced_creates_new_record(self, store, mock_session):
"""Test that mark_user_synced creates a new record."""
email = 'test@example.com'
audience_id = 'test-audience-123'
keycloak_user_id = 'kc-user-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
result = store.mark_user_synced(email, audience_id, keycloak_user_id)
assert result == mock_synced_user
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_mark_user_synced_handles_existing_record(self, store, mock_session):
"""Test that mark_user_synced handles conflict (existing record)."""
email = 'test@example.com'
audience_id = 'test-audience-123'
# First execute (insert) returns None (conflict occurred)
# Second execute (select existing) returns the record
mock_existing_user = MagicMock(spec=ResendSyncedUser)
mock_result_insert = MagicMock()
mock_result_insert.first.return_value = None
mock_result_select = MagicMock()
mock_result_select.first.return_value = (mock_existing_user,)
mock_session.execute.side_effect = [mock_result_insert, mock_result_select]
result = store.mark_user_synced(email, audience_id)
assert result == mock_existing_user
assert mock_session.execute.call_count == 2
mock_session.commit.assert_called_once()
def test_mark_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
"""Test that mark_user_synced normalizes email to lowercase."""
email = 'TEST@EXAMPLE.COM'
audience_id = 'test-audience-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
store.mark_user_synced(email, audience_id)
# Verify execute was called (the email normalization happens in the SQL)
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_mark_user_synced_without_keycloak_user_id(self, store, mock_session):
"""Test that mark_user_synced works without keycloak_user_id."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
result = store.mark_user_synced(email, audience_id)
assert result == mock_synced_user
mock_session.execute.assert_called_once()
class TestResendSyncedUser:
"""Test cases for ResendSyncedUser model."""
def test_model_has_required_fields(self):
"""Test that the model has all required fields."""
assert hasattr(ResendSyncedUser, 'id')
assert hasattr(ResendSyncedUser, 'email')
assert hasattr(ResendSyncedUser, 'audience_id')
assert hasattr(ResendSyncedUser, 'synced_at')
assert hasattr(ResendSyncedUser, 'keycloak_user_id')
def test_model_table_name(self):
"""Test the model's table name."""
assert ResendSyncedUser.__tablename__ == 'resend_synced_users'

View File

@@ -10,8 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.user import User
from enterprise.server.utils.saas_app_conversation_info_injector import (
SaasSQLAppConversationInfoService,
@@ -20,10 +24,15 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.service_types import ProviderType
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
# Test UUIDs
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
USER2_ID = UUID('b2222222-2222-2222-2222-222222222222')
ORG1_ID = UUID('c1111111-1111-1111-1111-111111111111')
ORG2_ID = UUID('d2222222-2222-2222-2222-222222222222')
@pytest.fixture
async def async_engine():
@@ -55,6 +64,41 @@ async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
yield db_session
@pytest.fixture
async def async_session_with_users(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session with pre-populated Org and User rows for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
# Insert Orgs first (required for User foreign key)
org1 = Org(
id=ORG1_ID,
name='test-org-1',
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
org2 = Org(
id=ORG2_ID,
name='test-org-2',
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
db_session.add(org1)
db_session.add(org2)
await db_session.flush()
# Insert Users
user1 = User(id=USER1_ID, current_org_id=ORG1_ID)
user2 = User(id=USER2_ID, current_org_id=ORG2_ID)
db_session.add(user1)
db_session.add(user2)
await db_session.commit()
yield db_session
@pytest.fixture
def service(async_session) -> SaasSQLAppConversationInfoService:
"""Create a SQLAppConversationInfoService instance for testing."""
@@ -178,15 +222,26 @@ class TestSaasSQLAppConversationInfoService:
assert user1_id != user2_id
@pytest.mark.asyncio
async def test_secure_select_includes_user_filtering(
async def test_secure_select_includes_user_and_org_filtering(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
async_session_with_users: AsyncSession,
):
"""Test that _secure_select method includes user filtering."""
# This test verifies that the _secure_select method exists and can be called
# The actual SQL generation is tested implicitly through integration
query = await saas_service_user1._secure_select()
assert query is not None
"""Test that _secure_select method includes both user_id and org_id filtering."""
service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
query = await service._secure_select()
# Convert query to string to verify filters are present
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
# Verify user_id filter is present
assert str(USER1_ID) in query_str or str(USER1_ID).replace('-', '') in query_str
# Verify org_id filter is present (user1 is in org1)
assert str(ORG1_ID) in query_str or str(ORG1_ID).replace('-', '') in query_str
@pytest.mark.asyncio
async def test_to_info_with_user_id_functionality(
@@ -241,100 +296,32 @@ class TestSaasSQLAppConversationInfoService:
assert result.sandbox_id == 'test-sandbox'
@pytest.mark.asyncio
async def test_user_isolation(
async def test_user_isolation_different_users(
self,
async_session: AsyncSession,
multiple_conversation_infos: list[AppConversationInfo],
async_session_with_users: AsyncSession,
):
"""Test that user isolation works correctly."""
from unittest.mock import MagicMock
from storage.user import User
# Mock the database session execute method to return mock users
# This mock intercepts User queries and returns a mock user object
# with user_id and org_id the same as the user_id_uuid from the query
original_execute = async_session.execute
async def mock_execute(query):
query_str = str(query)
# Check if this is a User query
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
# Extract the UUID from the query parameters
# The query will have bound parameters, we need to get the UUID value
if hasattr(query, 'compile'):
try:
compiled = query.compile(compile_kwargs={'literal_binds': True})
query_with_params = str(compiled)
# Extract UUID from the query string
import re
# Try both formats: with dashes and without dashes
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
uuid_match = re.search(
uuid_pattern_with_dashes, query_with_params
)
if not uuid_match:
uuid_match = re.search(
uuid_pattern_without_dashes, query_with_params
)
if uuid_match:
user_id_str = uuid_match.group(0)
# If the UUID doesn't have dashes, add them
if len(user_id_str) == 32 and '-' not in user_id_str:
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
user_id_uuid = UUID(user_id_str)
# Create a mock user with user_id and org_id the same as user_id_uuid
mock_user = MagicMock(spec=User)
mock_user.id = user_id_uuid
mock_user.current_org_id = user_id_uuid
# Create a mock result
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_user
return mock_result
except Exception:
# If there's any error in parsing, fall back to original execute
pass
# For all other queries, use the original execute method
return await original_execute(query)
# Apply the mock
async_session.execute = mock_execute
"""Test that different users cannot see each other's conversations."""
# Create services for different users
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(
user_id='a1111111-1111-1111-1111-111111111111'
),
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
user2_service = SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(
user_id='b2222222-2222-2222-2222-222222222222'
),
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
)
# Create conversations for different users
user1_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='a1111111-1111-1111-1111-111111111111',
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_user1',
title='User 1 Conversation',
)
user2_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='b2222222-2222-2222-2222-222222222222',
created_by_user_id=str(USER2_ID),
sandbox_id='sandbox_user2',
title='User 2 Conversation',
)
@@ -346,18 +333,12 @@ class TestSaasSQLAppConversationInfoService:
# User 1 should only see their conversation
user1_page = await user1_service.search_app_conversation_info()
assert len(user1_page.items) == 1
assert (
user1_page.items[0].created_by_user_id
== 'a1111111-1111-1111-1111-111111111111'
)
assert user1_page.items[0].created_by_user_id == str(USER1_ID)
# User 2 should only see their conversation
user2_page = await user2_service.search_app_conversation_info()
assert len(user2_page.items) == 1
assert (
user2_page.items[0].created_by_user_id
== 'b2222222-2222-2222-2222-222222222222'
)
assert user2_page.items[0].created_by_user_id == str(USER2_ID)
# User 1 should not be able to get user 2's conversation
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
@@ -366,3 +347,142 @@ class TestSaasSQLAppConversationInfoService:
# User 2 should not be able to get user 1's conversation
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
assert user1_from_user2 is None
@pytest.mark.asyncio
async def test_same_user_org_switching_isolation(
self,
async_session_with_users: AsyncSession,
):
"""Test that the same user switching orgs cannot see conversations from other orgs.
This tests the actual bug scenario: a user creates a conversation in org1,
then switches to org2, and should NOT see org1's conversations.
"""
# Create service for user1 in org1
user1_service_org1 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Create a conversation while user is in org1
conv_in_org1 = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_org1',
title='Conversation in Org 1',
)
await user1_service_org1.save_app_conversation_info(conv_in_org1)
# Verify user can see the conversation in org1
page_in_org1 = await user1_service_org1.search_app_conversation_info()
assert len(page_in_org1.items) == 1
assert page_in_org1.items[0].title == 'Conversation in Org 1'
# Simulate user switching to org2 by updating current_org_id using ORM
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG2_ID
await async_session_with_users.commit()
# Clear SQLAlchemy's identity map cache to simulate a new request
async_session_with_users.expire_all()
# Create new service instance (simulating a new request after org switch)
user1_service_org2 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# User should NOT see org1's conversations after switching to org2
page_in_org2 = await user1_service_org2.search_app_conversation_info()
assert (
len(page_in_org2.items) == 0
), 'User should not see conversations from org1 after switching to org2'
# User should not be able to get the specific conversation from org1
conv_from_org2 = await user1_service_org2.get_app_conversation_info(
conv_in_org1.id
)
assert (
conv_from_org2 is None
), 'User should not be able to access org1 conversation from org2'
# Now create a conversation in org2
conv_in_org2 = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_org2',
title='Conversation in Org 2',
)
await user1_service_org2.save_app_conversation_info(conv_in_org2)
# User should only see org2's conversation
page_in_org2_after = await user1_service_org2.search_app_conversation_info()
assert len(page_in_org2_after.items) == 1
assert page_in_org2_after.items[0].title == 'Conversation in Org 2'
# Switch back to org1 and verify isolation works both ways
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG1_ID
await async_session_with_users.commit()
async_session_with_users.expire_all()
user1_service_back_to_org1 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# User should only see org1's conversation now
page_back_in_org1 = (
await user1_service_back_to_org1.search_app_conversation_info()
)
assert len(page_back_in_org1.items) == 1
assert page_back_in_org1.items[0].title == 'Conversation in Org 1'
@pytest.mark.asyncio
async def test_count_respects_org_isolation(
self,
async_session_with_users: AsyncSession,
):
"""Test that count_app_conversation_info respects org isolation."""
# Create service for user1 in org1
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Create conversations in org1
for i in range(3):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id=f'sandbox_org1_{i}',
title=f'Org1 Conversation {i}',
)
await user1_service.save_app_conversation_info(conv)
# Count should be 3
count_org1 = await user1_service.count_app_conversation_info()
assert count_org1 == 3
# Switch to org2 using ORM
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG2_ID
await async_session_with_users.commit()
async_session_with_users.expire_all()
user1_service_org2 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Count should be 0 in org2
count_org2 = await user1_service_org2.count_app_conversation_info()
assert count_org2 == 0

View File

@@ -0,0 +1,117 @@
"""Tests for resend_keycloak email validation."""
from sync.resend_keycloak import is_valid_email
class TestIsValidEmail:
"""Test cases for is_valid_email function."""
def test_valid_simple_email(self):
"""Test that a simple valid email passes validation."""
assert is_valid_email('user@example.com') is True
def test_valid_email_with_plus(self):
"""Test that email with + modifier passes validation."""
assert is_valid_email('user+tag@example.com') is True
def test_valid_email_with_dots(self):
"""Test that email with dots in local part passes validation."""
assert is_valid_email('first.last@example.com') is True
def test_valid_email_with_numbers(self):
"""Test that email with numbers passes validation."""
assert is_valid_email('user123@example.com') is True
def test_valid_email_with_subdomain(self):
"""Test that email with subdomain passes validation."""
assert is_valid_email('user@mail.example.com') is True
def test_valid_email_with_hyphen_domain(self):
"""Test that email with hyphen in domain passes validation."""
assert is_valid_email('user@example-site.com') is True
def test_valid_email_with_underscore(self):
"""Test that email with underscore passes validation."""
assert is_valid_email('user_name@example.com') is True
def test_valid_email_with_percent(self):
"""Test that email with percent sign passes validation."""
assert is_valid_email('user%name@example.com') is True
def test_invalid_email_with_exclamation(self):
"""Test that email with exclamation mark fails validation.
This is the specific case from the bug report:
ethanjames3713+!@gmail.com
"""
assert is_valid_email('ethanjames3713+!@gmail.com') is False
def test_invalid_email_with_special_chars(self):
"""Test that email with other special characters fails validation."""
assert is_valid_email('user!name@example.com') is False
assert is_valid_email('user#name@example.com') is False
assert is_valid_email('user$name@example.com') is False
assert is_valid_email('user&name@example.com') is False
assert is_valid_email("user'name@example.com") is False
assert is_valid_email('user*name@example.com') is False
assert is_valid_email('user=name@example.com') is False
assert is_valid_email('user^name@example.com') is False
assert is_valid_email('user`name@example.com') is False
assert is_valid_email('user{name@example.com') is False
assert is_valid_email('user|name@example.com') is False
assert is_valid_email('user}name@example.com') is False
assert is_valid_email('user~name@example.com') is False
def test_invalid_email_no_at_symbol(self):
"""Test that email without @ symbol fails validation."""
assert is_valid_email('userexample.com') is False
def test_invalid_email_no_domain(self):
"""Test that email without domain fails validation."""
assert is_valid_email('user@') is False
def test_invalid_email_no_local_part(self):
"""Test that email without local part fails validation."""
assert is_valid_email('@example.com') is False
def test_invalid_email_no_tld(self):
"""Test that email without TLD fails validation."""
assert is_valid_email('user@example') is False
def test_invalid_email_single_char_tld(self):
"""Test that email with single character TLD fails validation."""
assert is_valid_email('user@example.c') is False
def test_invalid_email_empty_string(self):
"""Test that empty string fails validation."""
assert is_valid_email('') is False
def test_invalid_email_none(self):
"""Test that None fails validation."""
assert is_valid_email(None) is False
def test_invalid_email_whitespace(self):
"""Test that email with whitespace fails validation."""
assert is_valid_email('user @example.com') is False
assert is_valid_email('user@ example.com') is False
assert is_valid_email(' user@example.com') is False
assert is_valid_email('user@example.com ') is False
def test_invalid_email_double_at(self):
"""Test that email with double @ fails validation."""
assert is_valid_email('user@@example.com') is False
def test_email_double_dot_domain(self):
"""Test email with double dot in domain.
Note: The regex allows this as it's technically valid in some edge cases,
and Resend's API may accept it. The main goal is to reject special
characters like ! that Resend definitely rejects.
"""
# This is allowed by our regex - Resend may or may not accept it
assert is_valid_email('user@example..com') is True
def test_case_insensitive_validation(self):
"""Test that validation works for uppercase emails."""
assert is_valid_email('USER@EXAMPLE.COM') is True
assert is_valid_email('User@Example.Com') is True

View File

@@ -265,17 +265,17 @@ async def test_list_api_keys(
# Verify
mock_get_user.assert_called_once_with(user_id)
assert len(result) == 2
assert result[0]['id'] == 1
assert result[0]['name'] == 'Key 1'
assert result[0]['created_at'] == now
assert result[0]['last_used_at'] == now
assert result[0]['expires_at'] == now + timedelta(days=30)
assert result[0].id == 1
assert result[0].name == 'Key 1'
assert result[0].created_at == now
assert result[0].last_used_at == now
assert result[0].expires_at == now + timedelta(days=30)
assert result[1]['id'] == 2
assert result[1]['name'] == 'Key 2'
assert result[1]['created_at'] == now
assert result[1]['last_used_at'] is None
assert result[1]['expires_at'] is None
assert result[1].id == 2
assert result[1].name == 'Key 2'
assert result[1].created_at == now
assert result[1].last_used_at is None
assert result[1].expires_at is None
@pytest.mark.asyncio

View File

@@ -0,0 +1,181 @@
"""Tests for auth callback invitation acceptance - EmailMismatchError handling."""
import pytest
class TestAuthCallbackInvitationEmailMismatch:
"""Test cases for EmailMismatchError handling during auth callback."""
@pytest.fixture
def mock_redirect_url(self):
"""Base redirect URL."""
return 'https://app.example.com/'
@pytest.fixture
def mock_user_id(self):
"""Mock user ID."""
return '87654321-4321-8765-4321-876543218765'
def test_email_mismatch_appends_to_url_without_query_params(
self, mock_redirect_url, mock_user_id
):
"""Test that email_mismatch=true is appended correctly when URL has no query params."""
from server.routes.org_invitation_models import EmailMismatchError
# Simulate the logic from auth.py
redirect_url = mock_redirect_url
try:
raise EmailMismatchError('Your email does not match the invitation')
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert redirect_url == 'https://app.example.com/?email_mismatch=true'
def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id):
"""Test that email_mismatch=true is appended correctly when URL has existing query params."""
from server.routes.org_invitation_models import EmailMismatchError
redirect_url = 'https://app.example.com/?other_param=value'
try:
raise EmailMismatchError()
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert (
redirect_url
== 'https://app.example.com/?other_param=value&email_mismatch=true'
)
def test_email_mismatch_error_has_default_message(self):
"""Test that EmailMismatchError has the default message."""
from server.routes.org_invitation_models import EmailMismatchError
error = EmailMismatchError()
assert str(error) == 'Your email does not match the invitation'
def test_email_mismatch_error_accepts_custom_message(self):
"""Test that EmailMismatchError accepts a custom message."""
from server.routes.org_invitation_models import EmailMismatchError
custom_message = 'Custom error message'
error = EmailMismatchError(custom_message)
assert str(error) == custom_message
def test_email_mismatch_error_is_invitation_error(self):
"""Test that EmailMismatchError inherits from InvitationError."""
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationError,
)
error = EmailMismatchError()
assert isinstance(error, InvitationError)
class TestInvitationTokenInOAuthState:
"""Test cases for invitation token handling in OAuth state."""
def test_invitation_token_included_in_oauth_state(self):
"""Test that invitation token is included in OAuth state data."""
import base64
import json
# Simulate building OAuth state with invitation token
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert decoded_data['invitation_token'] == 'inv-test-token-12345'
assert decoded_data['redirect_url'] == 'https://app.example.com/'
def test_invitation_token_extracted_from_oauth_state(self):
"""Test that invitation token can be extracted from OAuth state."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
# Simulate decoding in callback
decoded_state = json.loads(base64.b64decode(encoded_state))
invitation_token = decoded_state.get('invitation_token')
assert invitation_token == 'inv-test-token-12345'
def test_oauth_state_without_invitation_token(self):
"""Test that OAuth state works without invitation token."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert 'invitation_token' not in decoded_data
assert decoded_data['redirect_url'] == 'https://app.example.com/'
class TestAuthCallbackInvitationErrors:
"""Test cases for various invitation error scenarios in auth callback."""
def test_invitation_expired_appends_flag(self):
"""Test that invitation_expired=true is appended for expired invitations."""
from server.routes.org_invitation_models import InvitationExpiredError
redirect_url = 'https://app.example.com/'
try:
raise InvitationExpiredError()
except InvitationExpiredError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
assert redirect_url == 'https://app.example.com/?invitation_expired=true'
def test_invitation_invalid_appends_flag(self):
"""Test that invitation_invalid=true is appended for invalid invitations."""
from server.routes.org_invitation_models import InvitationInvalidError
redirect_url = 'https://app.example.com/'
try:
raise InvitationInvalidError()
except InvitationInvalidError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
assert redirect_url == 'https://app.example.com/?invitation_invalid=true'
def test_already_member_appends_flag(self):
"""Test that already_member=true is appended when user is already a member."""
from server.routes.org_invitation_models import UserAlreadyMemberError
redirect_url = 'https://app.example.com/'
try:
raise UserAlreadyMemberError()
except UserAlreadyMemberError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
assert redirect_url == 'https://app.example.com/?already_member=true'

View File

@@ -0,0 +1,756 @@
"""
Unit tests for permission-based authorization (authorization.py).
Tests the FastAPI dependencies that validate user permissions within organizations.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import HTTPException
from server.auth.authorization import (
ROLE_PERMISSIONS,
Permission,
RoleName,
get_role_permissions,
get_user_org_role,
has_permission,
require_permission,
)
# =============================================================================
# Tests for Permission enum
# =============================================================================
class TestPermission:
"""Tests for Permission enum."""
def test_permission_values(self):
"""
GIVEN: Permission enum
WHEN: Accessing permission values
THEN: All expected permissions exist with correct string values
"""
assert Permission.MANAGE_SECRETS.value == 'manage_secrets'
assert Permission.MANAGE_MCP.value == 'manage_mcp'
assert Permission.MANAGE_INTEGRATIONS.value == 'manage_integrations'
assert (
Permission.MANAGE_APPLICATION_SETTINGS.value
== 'manage_application_settings'
)
assert Permission.MANAGE_API_KEYS.value == 'manage_api_keys'
assert Permission.VIEW_LLM_SETTINGS.value == 'view_llm_settings'
assert Permission.EDIT_LLM_SETTINGS.value == 'edit_llm_settings'
assert Permission.VIEW_BILLING.value == 'view_billing'
assert Permission.ADD_CREDITS.value == 'add_credits'
assert (
Permission.INVITE_USER_TO_ORGANIZATION.value
== 'invite_user_to_organization'
)
assert Permission.CHANGE_USER_ROLE_MEMBER.value == 'change_user_role:member'
assert Permission.CHANGE_USER_ROLE_ADMIN.value == 'change_user_role:admin'
assert Permission.CHANGE_USER_ROLE_OWNER.value == 'change_user_role:owner'
assert Permission.VIEW_ORG_SETTINGS.value == 'view_org_settings'
assert Permission.CHANGE_ORGANIZATION_NAME.value == 'change_organization_name'
assert Permission.DELETE_ORGANIZATION.value == 'delete_organization'
def test_permission_from_string(self):
"""
GIVEN: Valid permission string
WHEN: Creating Permission from string
THEN: Correct enum value is returned
"""
assert Permission('manage_secrets') == Permission.MANAGE_SECRETS
assert Permission('view_llm_settings') == Permission.VIEW_LLM_SETTINGS
assert Permission('delete_organization') == Permission.DELETE_ORGANIZATION
def test_permission_invalid_string(self):
"""
GIVEN: Invalid permission string
WHEN: Creating Permission from string
THEN: ValueError is raised
"""
with pytest.raises(ValueError):
Permission('invalid_permission')
# =============================================================================
# Tests for RoleName enum
# =============================================================================
class TestRoleName:
"""Tests for RoleName enum."""
def test_role_name_values(self):
"""
GIVEN: RoleName enum
WHEN: Accessing role name values
THEN: All expected roles exist with correct string values
"""
assert RoleName.OWNER.value == 'owner'
assert RoleName.ADMIN.value == 'admin'
assert RoleName.MEMBER.value == 'member'
def test_role_name_from_string(self):
"""
GIVEN: Valid role name string
WHEN: Creating RoleName from string
THEN: Correct enum value is returned
"""
assert RoleName('owner') == RoleName.OWNER
assert RoleName('admin') == RoleName.ADMIN
assert RoleName('member') == RoleName.MEMBER
def test_role_name_invalid_string(self):
"""
GIVEN: Invalid role name string
WHEN: Creating RoleName from string
THEN: ValueError is raised
"""
with pytest.raises(ValueError):
RoleName('invalid_role')
# =============================================================================
# Tests for ROLE_PERMISSIONS mapping
# =============================================================================
class TestRolePermissions:
"""Tests for role permission mappings."""
def test_owner_has_all_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking owner permissions
THEN: Owner has all permissions including owner-only permissions
"""
owner_perms = ROLE_PERMISSIONS[RoleName.OWNER]
assert Permission.MANAGE_SECRETS in owner_perms
assert Permission.MANAGE_MCP in owner_perms
assert Permission.VIEW_LLM_SETTINGS in owner_perms
assert Permission.EDIT_LLM_SETTINGS in owner_perms
assert Permission.VIEW_BILLING in owner_perms
assert Permission.ADD_CREDITS in owner_perms
assert Permission.INVITE_USER_TO_ORGANIZATION in owner_perms
assert Permission.CHANGE_USER_ROLE_MEMBER in owner_perms
assert Permission.CHANGE_USER_ROLE_ADMIN in owner_perms
assert Permission.CHANGE_USER_ROLE_OWNER in owner_perms
assert Permission.CHANGE_ORGANIZATION_NAME in owner_perms
assert Permission.DELETE_ORGANIZATION in owner_perms
def test_admin_has_admin_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking admin permissions
THEN: Admin has admin permissions but not owner-only permissions
"""
admin_perms = ROLE_PERMISSIONS[RoleName.ADMIN]
assert Permission.MANAGE_SECRETS in admin_perms
assert Permission.MANAGE_MCP in admin_perms
assert Permission.VIEW_LLM_SETTINGS in admin_perms
assert Permission.EDIT_LLM_SETTINGS in admin_perms
assert Permission.VIEW_BILLING in admin_perms
assert Permission.ADD_CREDITS in admin_perms
assert Permission.INVITE_USER_TO_ORGANIZATION in admin_perms
assert Permission.CHANGE_USER_ROLE_MEMBER in admin_perms
assert Permission.CHANGE_USER_ROLE_ADMIN in admin_perms
# Admin should NOT have owner-only permissions
assert Permission.CHANGE_USER_ROLE_OWNER not in admin_perms
assert Permission.CHANGE_ORGANIZATION_NAME not in admin_perms
assert Permission.DELETE_ORGANIZATION not in admin_perms
def test_member_has_limited_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking member permissions
THEN: Member has limited permissions
"""
member_perms = ROLE_PERMISSIONS[RoleName.MEMBER]
# Member has basic settings permissions
assert Permission.MANAGE_SECRETS in member_perms
assert Permission.MANAGE_MCP in member_perms
assert Permission.MANAGE_INTEGRATIONS in member_perms
assert Permission.MANAGE_APPLICATION_SETTINGS in member_perms
assert Permission.MANAGE_API_KEYS in member_perms
assert Permission.VIEW_LLM_SETTINGS in member_perms
assert Permission.VIEW_ORG_SETTINGS in member_perms
# Member should NOT have admin/owner permissions
assert Permission.EDIT_LLM_SETTINGS not in member_perms
assert Permission.VIEW_BILLING not in member_perms
assert Permission.ADD_CREDITS not in member_perms
assert Permission.INVITE_USER_TO_ORGANIZATION not in member_perms
assert Permission.CHANGE_USER_ROLE_MEMBER not in member_perms
assert Permission.CHANGE_USER_ROLE_ADMIN not in member_perms
assert Permission.CHANGE_USER_ROLE_OWNER not in member_perms
assert Permission.CHANGE_ORGANIZATION_NAME not in member_perms
assert Permission.DELETE_ORGANIZATION not in member_perms
# =============================================================================
# Tests for get_role_permissions function
# =============================================================================
class TestGetRolePermissions:
"""Tests for get_role_permissions function."""
def test_get_owner_permissions(self):
"""
GIVEN: Role name 'owner'
WHEN: get_role_permissions is called
THEN: Owner permissions are returned
"""
perms = get_role_permissions('owner')
assert Permission.DELETE_ORGANIZATION in perms
assert Permission.CHANGE_ORGANIZATION_NAME in perms
def test_get_admin_permissions(self):
"""
GIVEN: Role name 'admin'
WHEN: get_role_permissions is called
THEN: Admin permissions are returned
"""
perms = get_role_permissions('admin')
assert Permission.EDIT_LLM_SETTINGS in perms
assert Permission.DELETE_ORGANIZATION not in perms
def test_get_member_permissions(self):
"""
GIVEN: Role name 'member'
WHEN: get_role_permissions is called
THEN: Member permissions are returned
"""
perms = get_role_permissions('member')
assert Permission.VIEW_LLM_SETTINGS in perms
assert Permission.EDIT_LLM_SETTINGS not in perms
def test_get_invalid_role_permissions(self):
"""
GIVEN: Invalid role name
WHEN: get_role_permissions is called
THEN: Empty frozenset is returned
"""
perms = get_role_permissions('invalid_role')
assert perms == frozenset()
# =============================================================================
# Tests for has_permission function
# =============================================================================
class TestHasPermission:
"""Tests for has_permission function."""
def test_owner_has_delete_organization_permission(self):
"""
GIVEN: User with owner role
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'owner'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is True
def test_owner_has_view_llm_settings_permission(self):
"""
GIVEN: User with owner role
WHEN: Checking for VIEW_LLM_SETTINGS permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'owner'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
def test_admin_has_edit_llm_settings_permission(self):
"""
GIVEN: User with admin role
WHEN: Checking for EDIT_LLM_SETTINGS permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'admin'
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is True
def test_admin_lacks_delete_organization_permission(self):
"""
GIVEN: User with admin role
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'admin'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
def test_member_has_view_llm_settings_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for VIEW_LLM_SETTINGS permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
def test_member_lacks_edit_llm_settings_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for EDIT_LLM_SETTINGS permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is False
def test_member_lacks_delete_organization_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
def test_invalid_role_has_no_permissions(self):
"""
GIVEN: User with invalid role
WHEN: Checking for any permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'invalid_role'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is False
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
# =============================================================================
# Tests for get_user_org_role function
# =============================================================================
class TestGetUserOrgRole:
"""Tests for get_user_org_role function."""
def test_returns_role_when_member_exists(self):
"""
GIVEN: User is a member of organization with role
WHEN: get_user_org_role is called
THEN: Role object is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_org_member = MagicMock()
mock_org_member.role_id = 1
mock_role = MagicMock()
mock_role.name = 'admin'
with (
patch(
'server.auth.authorization.OrgMemberStore.get_org_member',
return_value=mock_org_member,
),
patch(
'server.auth.authorization.RoleStore.get_role_by_id',
return_value=mock_role,
),
):
result = get_user_org_role(user_id, org_id)
assert result == mock_role
def test_returns_none_when_not_member(self):
"""
GIVEN: User is not a member of organization
WHEN: get_user_org_role is called
THEN: None is returned
"""
user_id = str(uuid4())
org_id = uuid4()
with patch(
'server.auth.authorization.OrgMemberStore.get_org_member',
return_value=None,
):
result = get_user_org_role(user_id, org_id)
assert result is None
def test_returns_role_when_org_id_is_none(self):
"""
GIVEN: User with a current organization
WHEN: get_user_org_role is called with org_id=None
THEN: Role object is returned using get_org_member_for_current_org
"""
user_id = str(uuid4())
mock_org_member = MagicMock()
mock_org_member.role_id = 1
mock_role = MagicMock()
mock_role.name = 'admin'
with (
patch(
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
return_value=mock_org_member,
) as mock_get_current,
patch(
'server.auth.authorization.OrgMemberStore.get_org_member',
) as mock_get_org_member,
patch(
'server.auth.authorization.RoleStore.get_role_by_id',
return_value=mock_role,
),
):
result = get_user_org_role(user_id, None)
assert result == mock_role
mock_get_current.assert_called_once()
mock_get_org_member.assert_not_called()
def test_returns_none_when_org_id_is_none_and_no_current_org(self):
"""
GIVEN: User with no current organization membership
WHEN: get_user_org_role is called with org_id=None
THEN: None is returned
"""
user_id = str(uuid4())
with patch(
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
return_value=None,
):
result = get_user_org_role(user_id, None)
assert result is None
# =============================================================================
# Tests for require_permission dependency
# =============================================================================
class TestRequirePermission:
"""Tests for require_permission dependency factory."""
@pytest.mark.asyncio
async def test_returns_user_id_when_authorized(self):
"""
GIVEN: User with required permission
WHEN: Permission checker is called
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_raises_401_when_not_authenticated(self):
"""
GIVEN: No user ID (not authenticated)
WHEN: Permission checker is called
THEN: 401 Unauthorized is raised
"""
org_id = uuid4()
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=None)
assert exc_info.value.status_code == 401
assert 'not authenticated' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_raises_403_when_not_member(self):
"""
GIVEN: User is not a member of organization
WHEN: Permission checker is called
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=None),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'not a member' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_raises_403_when_insufficient_permission(self):
"""
GIVEN: User without required permission
WHEN: Permission checker is called
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'delete_organization' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_owner_can_delete_organization(self):
"""
GIVEN: User with owner role
WHEN: DELETE_ORGANIZATION permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'owner'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_admin_cannot_delete_organization(self):
"""
GIVEN: User with admin role
WHEN: DELETE_ORGANIZATION permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_logs_warning_on_insufficient_permission(self):
"""
GIVEN: User without required permission
WHEN: Permission checker is called
THEN: Warning is logged with details
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
),
patch('server.auth.authorization.logger') as mock_logger,
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException):
await permission_checker(org_id=org_id, user_id=user_id)
mock_logger.warning.assert_called()
call_args = mock_logger.warning.call_args
assert call_args[1]['extra']['user_id'] == user_id
assert call_args[1]['extra']['user_role'] == 'member'
assert call_args[1]['extra']['required_permission'] == 'delete_organization'
@pytest.mark.asyncio
async def test_returns_user_id_when_org_id_is_none(self):
"""
GIVEN: User with required permission in their current org
WHEN: Permission checker is called with org_id=None
THEN: User ID is returned
"""
user_id = str(uuid4())
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
) as mock_get_role:
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(org_id=None, user_id=user_id)
assert result == user_id
mock_get_role.assert_called_once_with(user_id, None)
@pytest.mark.asyncio
async def test_raises_403_when_org_id_is_none_and_not_member(self):
"""
GIVEN: User not a member of their current organization
WHEN: Permission checker is called with org_id=None
THEN: HTTPException with 403 status is raised
"""
user_id = str(uuid4())
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=None),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=None, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'not a member' in exc_info.value.detail
# =============================================================================
# Tests for permission-based access control scenarios
# =============================================================================
class TestPermissionScenarios:
"""Tests for real-world permission scenarios."""
@pytest.mark.asyncio
async def test_member_can_manage_secrets(self):
"""
GIVEN: User with member role
WHEN: MANAGE_SECRETS permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.MANAGE_SECRETS)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_member_cannot_invite_users(self):
"""
GIVEN: User with member role
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(
Permission.INVITE_USER_TO_ORGANIZATION
)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_admin_can_invite_users(self):
"""
GIVEN: User with admin role
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(
Permission.INVITE_USER_TO_ORGANIZATION
)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_admin_cannot_change_owner_role(self):
"""
GIVEN: User with admin role
WHEN: CHANGE_USER_ROLE_OWNER permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_owner_can_change_owner_role(self):
"""
GIVEN: User with owner role
WHEN: CHANGE_USER_ROLE_OWNER permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'owner'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id

View File

@@ -101,7 +101,7 @@ async def test_get_credits_success():
json={
'user_info': {
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
}
},
request=MagicMock(),
@@ -121,7 +121,7 @@ async def test_get_credits_success():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
},
),
):
@@ -299,6 +299,8 @@ async def test_success_callback_success():
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_org = MagicMock()
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
@@ -311,7 +313,7 @@ async def test_success_callback_success():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
},
),
patch(
@@ -319,7 +321,17 @@ async def test_success_callback_success():
) as mock_update_budget,
):
mock_db_session = MagicMock()
# First query: BillingSession (query().filter().filter().first())
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
# Second query: Org (query().filter().first()) - use side_effect for different return chains
mock_query_chain_billing = MagicMock()
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_query_chain_org = MagicMock()
mock_query_chain_org.filter.return_value.first.return_value = mock_org
mock_db_session.query.side_effect = [
mock_query_chain_billing,
mock_query_chain_org,
]
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
@@ -337,9 +349,12 @@ async def test_success_callback_success():
# Verify LiteLLM API calls
mock_update_budget.assert_called_once_with(
'mock_org_id',
125.0, # 100 + (25.00 from Stripe)
125.0, # 100 + 25.00
)
# Verify BYOR export is enabled for the org (updated in same session)
assert mock_org.byor_export_enabled is True
# Verify database updates
assert mock_billing_session.status == 'completed'
assert mock_billing_session.price == 25.0
@@ -387,6 +402,68 @@ async def test_success_callback_lite_llm_error():
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_success_callback_lite_llm_update_budget_error_rollback():
"""Test that database changes are not committed when update_team_and_users_budget fails.
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
the database transaction rolls back.
"""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_org = MagicMock()
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 0,
'max_budget_in_team': 0,
},
),
patch(
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget',
side_effect=Exception('LiteLLM API Error'),
),
):
mock_db_session = MagicMock()
mock_query_chain_billing = MagicMock()
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_query_chain_org = MagicMock()
mock_query_chain_org.filter.return_value.first.return_value = mock_org
mock_db_session.query.side_effect = [
mock_query_chain_billing,
mock_query_chain_org,
]
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete',
amount_subtotal=1000, # $10
customer='mock_customer_id',
)
with pytest.raises(Exception, match='LiteLLM API Error'):
await success_callback('test_session_id', mock_request)
# Verify no database commit occurred - the transaction should roll back
assert mock_billing_session.status == 'in_progress'
mock_db_session.merge.assert_not_called()
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_cancel_callback_session_not_found():
"""Test cancel callback when billing session is not found."""
@@ -502,6 +579,6 @@ async def test_create_customer_setup_session_success():
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='https://test.com/?free_credits=success',
success_url='https://test.com/?setup=success',
cancel_url='https://test.com/',
)

View File

@@ -48,7 +48,7 @@ async def test_create_customer_setup_session_uses_customer_id():
customer=customer_id,
mode='setup',
payment_method_types=['card'],
success_url=f'{request.base_url}?free_credits=success',
success_url=f'{request.base_url}?setup=success',
cancel_url=f'{request.base_url}',
)

View File

@@ -0,0 +1,192 @@
"""Tests for email service."""
import os
from unittest.mock import MagicMock, patch
from server.services.email_service import (
DEFAULT_WEB_HOST,
EmailService,
)
class TestEmailServiceInvitationUrl:
"""Test cases for invitation URL generation."""
def test_invitation_url_uses_correct_endpoint(self):
"""Test that invitation URL points to the correct API endpoint."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
# Get the call arguments
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify the URL in the email HTML contains the correct endpoint
assert (
'/api/organizations/members/invite/accept?token='
in email_params['html']
)
assert 'inv-test-token-12345' in email_params['html']
def test_invitation_url_uses_web_host_env_var(self):
"""Test that invitation URL uses WEB_HOST environment variable."""
custom_host = 'https://custom.example.com'
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(
os.environ,
{'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host},
),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
def test_invitation_url_uses_default_host_when_env_not_set(self):
"""Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
env_without_web_host = {'RESEND_API_KEY': 'test-key'}
# Remove WEB_HOST if it exists
env_without_web_host.pop('WEB_HOST', None)
with (
patch.dict(os.environ, env_without_web_host, clear=True),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
# Clear WEB_HOST from the environment
os.environ.pop('WEB_HOST', None)
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
class TestEmailServiceGetResendClient:
"""Test cases for Resend client initialization."""
def test_get_resend_client_returns_false_when_resend_not_available(self):
"""Test that _get_resend_client returns False when resend is not installed."""
with patch('server.services.email_service.RESEND_AVAILABLE', False):
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_false_when_api_key_not_configured(self):
"""Test that _get_resend_client returns False when API key is missing."""
with (
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch.dict(os.environ, {}, clear=True),
):
os.environ.pop('RESEND_API_KEY', None)
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_true_when_configured(self):
"""Test that _get_resend_client returns True when properly configured."""
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
result = EmailService._get_resend_client()
assert result is True
assert mock_resend.api_key == 'test-key'
class TestEmailServiceSendInvitationEmail:
"""Test cases for send_invitation_email method."""
def test_send_invitation_email_skips_when_client_not_ready(self):
"""Test that email sending is skipped when client is not ready."""
with patch.object(
EmailService, '_get_resend_client', return_value=False
) as mock_get_client:
# Should not raise, just return early
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token',
invitation_id=1,
)
mock_get_client.assert_called_once()
def test_send_invitation_email_includes_all_required_info(self):
"""Test that invitation email includes org name, inviter name, and role."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Acme Corp',
inviter_name='John Doe',
role_name='admin',
invitation_token='inv-test-token-12345',
invitation_id=42,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify email content
assert email_params['to'] == ['test@example.com']
assert 'Acme Corp' in email_params['subject']
assert 'John Doe' in email_params['html']
assert 'Acme Corp' in email_params['html']
assert 'admin' in email_params['html']

View File

@@ -142,44 +142,192 @@ class TestLiteLlmManager:
@pytest.mark.asyncio
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
"""Test create_entries in cloud deployment mode."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.TokenManager'
) as mock_token_manager:
mock_token_manager.return_value.get_user_info_from_user_id = (
AsyncMock(return_value={'email': 'test@example.com'})
)
mock_404_response = MagicMock()
mock_404_response.status_code = 404
mock_404_response.is_success = False
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.post.return_value = mock_response
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
result = await LiteLlmManager.create_entries(
'test-org-id',
'test-user-id',
mock_settings,
create_user=False,
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_404_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Not Found', request=MagicMock(), response=mock_404_response
)
)
mock_client.post.return_value = mock_response
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert (
result.llm_api_key.get_secret_value() == 'test-api-key'
)
assert result.llm_base_url == 'http://test.com'
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Verify API calls were made
assert (
mock_client.post.call_count == 3
) # create_team, create_user, add_user_to_team, generate_key
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert result.llm_api_key.get_secret_value() == 'test-api-key'
assert result.llm_base_url == 'http://test.com'
# Verify API calls were made (get_team + 3 posts)
assert mock_client.get.call_count == 1 # get_team
assert (
mock_client.post.call_count == 3
) # create_team, add_user_to_team, generate_key
@pytest.mark.asyncio
async def test_create_entries_inherits_existing_team_budget(
self, mock_settings, mock_response
):
"""Test that create_entries inherits budget from existing team."""
mock_team_response = MagicMock()
mock_team_response.is_success = True
mock_team_response.status_code = 200
mock_team_response.json.return_value = {
'team_info': {'max_budget': 30.0, 'spend': 5.0},
'team_memberships': [],
}
mock_team_response.raise_for_status = MagicMock()
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_team_response
mock_client.post.return_value = mock_response
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
# Verify _get_team was called first
mock_client.get.assert_called_once()
get_call_url = mock_client.get.call_args[0][0]
assert 'team/info' in get_call_url
assert 'test-org-id' in get_call_url
# Verify _create_team was called with inherited budget (30.0)
create_team_call = mock_client.post.call_args_list[0]
assert 'team/new' in create_team_call[0][0]
assert create_team_call[1]['json']['max_budget'] == 30.0
# Verify _add_user_to_team was called with inherited budget (30.0)
add_user_call = mock_client.post.call_args_list[1]
assert 'team/member_add' in add_user_call[0][0]
assert add_user_call[1]['json']['max_budget_in_team'] == 30.0
@pytest.mark.asyncio
async def test_create_entries_new_org_uses_zero_budget(
self, mock_settings, mock_response
):
"""Test that create_entries uses budget=0 for new org (team doesn't exist)."""
mock_404_response = MagicMock()
mock_404_response.status_code = 404
mock_404_response.is_success = False
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_404_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Not Found', request=MagicMock(), response=mock_404_response
)
)
mock_client.post.return_value = mock_response
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
# Verify _create_team was called with budget=0
create_team_call = mock_client.post.call_args_list[0]
assert 'team/new' in create_team_call[0][0]
assert create_team_call[1]['json']['max_budget'] == 0.0
# Verify _add_user_to_team was called with budget=0
add_user_call = mock_client.post.call_args_list[1]
assert 'team/member_add' in add_user_call[0][0]
assert add_user_call[1]['json']['max_budget_in_team'] == 0.0
@pytest.mark.asyncio
async def test_create_entries_propagates_non_404_errors(self, mock_settings):
"""Test that create_entries propagates non-404 errors from _get_team."""
mock_500_response = MagicMock()
mock_500_response.status_code = 500
mock_500_response.is_success = False
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_500_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Internal Server Error',
request=MagicMock(),
response=mock_500_response,
)
)
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert exc_info.value.response.status_code == 500
@pytest.mark.asyncio
async def test_migrate_entries_missing_config(self, mock_user_settings):

View File

@@ -0,0 +1,464 @@
"""Tests for organization invitation service - email validation."""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from server.routes.org_invitation_models import (
EmailMismatchError,
)
from server.services.org_invitation_service import OrgInvitationService
from storage.org_invitation import OrgInvitation
class TestAcceptInvitationEmailValidation:
"""Test cases for email validation during invitation acceptance."""
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation with pending status."""
invitation = MagicMock(spec=OrgInvitation)
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.status = OrgInvitation.STATUS_PENDING
invitation.org_id = UUID('12345678-1234-5678-1234-567812345678')
invitation.role_id = 1
return invitation
@pytest.fixture
def mock_user(self):
"""Create a mock user with email."""
user = MagicMock()
user.id = UUID('87654321-4321-8765-4321-876543218765')
user.email = 'alice@example.com'
return user
@pytest.mark.asyncio
async def test_accept_invitation_email_matches(self, mock_invitation, mock_user):
"""Test that invitation is accepted when user email matches invitation email."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
with patch.object(
OrgInvitationService, 'accept_invitation', new_callable=AsyncMock
) as mock_accept:
mock_accept.return_value = mock_invitation
# Act
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_accept.assert_called_once_with(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_email_mismatch_raises_error(
self, mock_invitation, mock_user
):
"""Test that EmailMismatchError is raised when emails don't match."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
mock_user.email = 'bob@example.com' # Different email
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Act & Assert
with pytest.raises(EmailMismatchError):
await OrgInvitationService.accept_invitation(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_user_no_email_keycloak_fallback_matches(
self, mock_invitation
):
"""Test that Keycloak email is used when user has no email in database."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager instance
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(
return_value=mock_keycloak_user_info
)
mock_token_manager_class.return_value = mock_token_manager
mock_get_member.return_value = None # Not already a member
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because Keycloak email matches
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
str(user_id)
)
@pytest.mark.asyncio
async def test_accept_invitation_no_email_anywhere_raises_error(
self, mock_invitation
):
"""Test that EmailMismatchError is raised when user has no email in database or Keycloak."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager to return no email
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={})
mock_token_manager_class.return_value = mock_token_manager
# Act & Assert
with pytest.raises(EmailMismatchError) as exc_info:
await OrgInvitationService.accept_invitation(token, user_id)
assert 'does not have an email address' in str(exc_info.value)
@pytest.mark.asyncio
async def test_accept_invitation_email_comparison_is_case_insensitive(
self, mock_invitation
):
"""Test that email comparison is case insensitive."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
mock_get_member.return_value = None
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because emails match case-insensitively
await OrgInvitationService.accept_invitation(token, user_id)
# Assert - invitation was accepted (update_invitation_status was called)
mock_update_status.assert_called_once()
class TestCreateInvitationsBatch:
"""Test cases for batch invitation creation."""
@pytest.fixture
def org_id(self):
"""Organization UUID for testing."""
return UUID('12345678-1234-5678-1234-567812345678')
@pytest.fixture
def inviter_id(self):
"""Inviter UUID for testing."""
return UUID('87654321-4321-8765-4321-876543218765')
@pytest.fixture
def mock_org(self):
"""Create a mock organization."""
org = MagicMock()
org.id = UUID('12345678-1234-5678-1234-567812345678')
org.name = 'Test Org'
return org
@pytest.fixture
def mock_inviter_member(self):
"""Create a mock inviter member with owner role."""
member = MagicMock()
member.user_id = UUID('87654321-4321-8765-4321-876543218765')
member.role_id = 1
return member
@pytest.fixture
def mock_owner_role(self):
"""Create a mock owner role."""
role = MagicMock()
role.id = 1
role.name = 'owner'
return role
@pytest.fixture
def mock_member_role(self):
"""Create a mock member role."""
role = MagicMock()
role.id = 3
role.name = 'member'
return role
@pytest.mark.asyncio
async def test_batch_creates_all_invitations_successfully(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch creation succeeds for all valid emails."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
mock_invitation_1 = MagicMock(spec=OrgInvitation)
mock_invitation_1.id = 1
mock_invitation_2 = MagicMock(spec=OrgInvitation)
mock_invitation_2.id = 2
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation_1, mock_invitation_2],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 2
assert len(failed) == 0
@pytest.mark.asyncio
async def test_batch_handles_partial_success(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch returns partial results when some emails fail."""
# Arrange
from server.routes.org_invitation_models import UserAlreadyMemberError
emails = ['alice@example.com', 'existing@example.com']
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation, UserAlreadyMemberError()],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 1
assert len(failed) == 1
assert failed[0][0] == 'existing@example.com'
@pytest.mark.asyncio
async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id):
"""Test that permission error fails the entire batch upfront."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
with patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=None, # Organization not found
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
assert 'not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_batch_fails_on_invalid_role(
self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role
):
"""Test that invalid role fails the entire batch."""
# Arrange
emails = ['alice@example.com']
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=None, # Invalid role
),
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='invalid_role',
inviter_id=inviter_id,
)
assert 'Invalid role' in str(exc_info.value)

View File

@@ -0,0 +1,308 @@
"""Tests for organization invitation store."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import (
INVITATION_TOKEN_LENGTH,
INVITATION_TOKEN_PREFIX,
OrgInvitationStore,
)
class TestGenerateToken:
"""Test cases for token generation."""
def test_generate_token_has_correct_prefix(self):
"""Test that generated tokens have the correct prefix."""
token = OrgInvitationStore.generate_token()
assert token.startswith(INVITATION_TOKEN_PREFIX)
def test_generate_token_has_correct_length(self):
"""Test that generated tokens have the correct total length."""
token = OrgInvitationStore.generate_token()
expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH
assert len(token) == expected_length
def test_generate_token_uses_alphanumeric_characters(self):
"""Test that generated tokens use only alphanumeric characters."""
token = OrgInvitationStore.generate_token()
# Remove prefix and check the rest is alphanumeric
random_part = token[len(INVITATION_TOKEN_PREFIX) :]
assert random_part.isalnum()
def test_generate_token_is_unique(self):
"""Test that generated tokens are unique (probabilistically)."""
tokens = [OrgInvitationStore.generate_token() for _ in range(100)]
assert len(set(tokens)) == 100
class TestIsTokenExpired:
"""Test cases for token expiration checking."""
def test_token_not_expired_when_future(self):
"""Test that tokens with future expiration are not expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() + timedelta(days=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is False
def test_token_expired_when_past(self):
"""Test that tokens with past expiration are expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() - timedelta(seconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
def test_token_expired_at_exact_boundary(self):
"""Test that tokens at exact expiration time are expired."""
# A token that expires "now" should be expired
now = datetime.utcnow()
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = now - timedelta(microseconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
class TestCreateInvitation:
"""Test cases for invitation creation."""
@pytest.mark.asyncio
async def test_create_invitation_normalizes_email(self):
"""Test that email is normalized (lowercase, stripped) on creation."""
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
mock_session.execute = AsyncMock()
# Mock the result of the re-fetch query
mock_result = MagicMock()
mock_invitation = MagicMock()
mock_invitation.id = 1
mock_invitation.email = 'test@example.com'
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute.return_value = mock_result
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.create_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
role_id=1,
inviter_id=UUID('87654321-4321-8765-4321-876543218765'),
)
# Verify that the OrgInvitation was created with normalized email
add_call = mock_session.add.call_args
created_invitation = add_call[0][0]
assert created_invitation.email == 'test@example.com'
class TestGetInvitationByToken:
"""Test cases for getting invitation by token."""
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_invitation(self):
"""Test that get_invitation_by_token returns the invitation when found."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.token = 'inv-test-token-12345'
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-test-token-12345'
)
assert result == mock_invitation
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_none_when_not_found(self):
"""Test that get_invitation_by_token returns None when not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-nonexistent-token'
)
assert result is None
class TestGetPendingInvitation:
"""Test cases for getting pending invitation."""
@pytest.mark.asyncio
async def test_get_pending_invitation_normalizes_email(self):
"""Test that email is normalized when querying for pending invitations."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.get_pending_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
)
# Verify the query was called (email normalization happens in the filter)
assert mock_session.execute.called
class TestUpdateInvitationStatus:
"""Test cases for updating invitation status."""
@pytest.mark.asyncio
async def test_update_status_sets_accepted_at_for_accepted(self):
"""Test that accepted_at is set when status is accepted."""
from uuid import UUID
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
user_id = UUID('87654321-4321-8765-4321-876543218765')
await OrgInvitationStore.update_invitation_status(
invitation_id=1,
status=OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
assert mock_invitation.accepted_at is not None
assert mock_invitation.accepted_by_user_id == user_id
@pytest.mark.asyncio
async def test_update_status_returns_none_when_not_found(self):
"""Test that update returns None when invitation not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.update_invitation_status(
invitation_id=999,
status=OrgInvitation.STATUS_ACCEPTED,
)
assert result is None
class TestMarkExpiredIfNeeded:
"""Test cases for marking expired invitations."""
@pytest.mark.asyncio
async def test_marks_expired_when_pending_and_past_expiry(self):
"""Test that pending expired invitations are marked as expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is True
mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED)
@pytest.mark.asyncio
async def test_does_not_mark_when_not_expired(self):
"""Test that non-expired invitations are not marked."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()
@pytest.mark.asyncio
async def test_does_not_mark_when_not_pending(self):
"""Test that non-pending invitations are not marked even if expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_ACCEPTED
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()

View File

@@ -0,0 +1,388 @@
"""Tests for organization invitations API router."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.routes.org_invitations import accept_router, invitation_router
@pytest.fixture
def app():
"""Create a FastAPI app with the invitation routers."""
app = FastAPI()
app.include_router(invitation_router)
app.include_router(accept_router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
class TestRouterPrefixes:
"""Test that router prefixes are configured correctly."""
def test_invitation_router_has_correct_prefix(self):
"""Test that invitation_router has /api/organizations/{org_id}/members prefix."""
assert invitation_router.prefix == '/api/organizations/{org_id}/members'
def test_accept_router_has_correct_prefix(self):
"""Test that accept_router has /api/organizations/members/invite prefix."""
assert accept_router.prefix == '/api/organizations/members/invite'
class TestAcceptInvitationEndpoint:
"""Test cases for the accept invitation endpoint."""
@pytest.fixture
def mock_user_auth(self):
"""Create a mock user auth."""
user_auth = MagicMock()
user_auth.get_user_id = AsyncMock(
return_value='87654321-4321-8765-4321-876543218765'
)
return user_auth
@pytest.mark.asyncio
async def test_accept_unauthenticated_redirects_to_login(self, client):
"""Test that unauthenticated users are redirected to login with invitation token."""
with patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=None,
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert '/login?invitation_token=inv-test-token-123' in response.headers.get(
'location', ''
)
@pytest.mark.asyncio
async def test_accept_authenticated_success_redirects_home(
self, client, mock_user_auth
):
"""Test that successful acceptance redirects to home page."""
mock_invitation = MagicMock()
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
return_value=mock_invitation,
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
location = response.headers.get('location', '')
assert location.endswith('/')
assert 'invitation_expired' not in location
assert 'invitation_invalid' not in location
assert 'email_mismatch' not in location
@pytest.mark.asyncio
async def test_accept_expired_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that expired invitation redirects with invitation_expired=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationExpiredError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_expired=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_invalid_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that invalid invitation redirects with invitation_invalid=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationInvalidError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_invalid=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_already_member_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that already member error redirects with already_member=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=UserAlreadyMemberError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'already_member=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_email_mismatch_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that email mismatch error redirects with email_mismatch=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=EmailMismatchError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'email_mismatch=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_unexpected_error_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that unexpected errors redirect with invitation_error=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=Exception('Unexpected error'),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_error=true' in response.headers.get('location', '')
class TestCreateInvitationBatchEndpoint:
"""Test cases for the batch invitation creation endpoint."""
@pytest.fixture
def batch_app(self):
"""Create a FastAPI app with dependency overrides for batch tests."""
from openhands.server.user_auth import get_user_id
app = FastAPI()
app.include_router(invitation_router)
# Override the get_user_id dependency
app.dependency_overrides[get_user_id] = (
lambda: '87654321-4321-8765-4321-876543218765'
)
return app
@pytest.fixture
def batch_client(self, batch_app):
"""Create a test client with dependency overrides."""
return TestClient(batch_app)
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation."""
from datetime import datetime
invitation = MagicMock()
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.role = MagicMock(name='member')
invitation.role.name = 'member'
invitation.role_id = 3
invitation.status = 'pending'
invitation.created_at = datetime(2026, 2, 17, 10, 0, 0)
invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0)
return invitation
@pytest.mark.asyncio
async def test_batch_create_returns_successful_invitations(
self, batch_client, mock_invitation
):
"""Test that batch creation returns successful invitations."""
mock_invitation_2 = MagicMock()
mock_invitation_2.id = 2
mock_invitation_2.email = 'bob@example.com'
mock_invitation_2.role = MagicMock()
mock_invitation_2.role.name = 'member'
mock_invitation_2.role_id = 3
mock_invitation_2.status = 'pending'
mock_invitation_2.created_at = mock_invitation.created_at
mock_invitation_2.expires_at = mock_invitation.expires_at
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation, mock_invitation_2], []),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'bob@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 2
assert len(data['failed']) == 0
@pytest.mark.asyncio
async def test_batch_create_returns_partial_success(
self, batch_client, mock_invitation
):
"""Test that batch creation returns both successful and failed invitations."""
failed_emails = [('existing@example.com', 'User is already a member')]
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation], failed_emails),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'existing@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 1
assert len(data['failed']) == 1
assert data['failed'][0]['email'] == 'existing@example.com'
assert 'already a member' in data['failed'][0]['error']
@pytest.mark.asyncio
async def test_batch_create_permission_denied_returns_403(self, batch_client):
"""Test that permission denied returns 403 for entire batch."""
from server.routes.org_invitation_models import InsufficientPermissionError
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=InsufficientPermissionError(
'Only owners and admins can invite'
),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'member'},
)
assert response.status_code == 403
assert 'owners and admins' in response.json()['detail']
@pytest.mark.asyncio
async def test_batch_create_invalid_role_returns_400(self, batch_client):
"""Test that invalid role returns 400."""
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=ValueError('Invalid role: superuser'),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'superuser'},
)
assert response.status_code == 400
assert 'Invalid role' in response.json()['detail']

View File

@@ -1,8 +1,15 @@
import uuid
from unittest.mock import patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
# Mock the database module before importing OrgMemberStore
with patch('storage.database.engine'), patch('storage.database.a_engine'):
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from storage.base import Base
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
@@ -10,6 +17,31 @@ with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.user import User
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker for testing."""
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
def test_get_org_members(session_maker):
# Test getting org_members by org ID
with session_maker() as session:
@@ -126,6 +158,57 @@ def test_get_org_member(session_maker):
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
def test_get_org_member_for_current_org(session_maker):
# Test getting org_member for user's current organization
with session_maker() as session:
# Create test data - user belongs to two orgs but current_org is org1
org1 = Org(name='test-org-1')
org2 = Org(name='test-org-2')
session.add_all([org1, org2])
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org1.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member1 = OrgMember(
org_id=org1.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-1',
status='active',
)
org_member2 = OrgMember(
org_id=org2.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-2',
status='active',
)
session.add_all([org_member1, org_member2])
session.commit()
user_id = user.id
org1_id = org1.id
# Test retrieval - should return org_member for current_org (org1)
with patch('storage.org_member_store.session_maker', session_maker):
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(user_id)
assert retrieved_org_member is not None
assert retrieved_org_member.org_id == org1_id
assert retrieved_org_member.user_id == user_id
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key-1'
def test_get_org_member_for_current_org_user_not_found(session_maker):
# Test getting org_member for non-existent user
with patch('storage.org_member_store.session_maker', session_maker):
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(
uuid.uuid4()
)
assert retrieved_org_member is None
def test_add_user_to_org(session_maker):
# Test adding a user to an org
with session_maker() as session:
@@ -251,3 +334,324 @@ def test_remove_user_from_org_not_found(session_maker):
with patch('storage.org_member_store.session_maker', session_maker):
result = OrgMemberStore.remove_user_from_org(uuid4(), 99999)
assert result is False
@pytest.mark.asyncio
async def test_get_org_members_paginated_basic(async_session_maker):
"""Test basic pagination returns correct number of items."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
# Create 5 users
users = [
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
for i in range(5)
]
session.add_all(users)
await session.flush()
# Create org members
org_members = [
OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key=f'test-key-{i}',
status='active',
)
for i, user in enumerate(users)
]
session.add_all(org_members)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=3
)
# Assert
assert len(members) == 3
assert has_more is True
# Verify user and role relationships are loaded
assert all(member.user is not None for member in members)
assert all(member.role is not None for member in members)
@pytest.mark.asyncio
async def test_get_org_members_paginated_no_more(async_session_maker):
"""Test pagination when there are no more results."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
# Create 3 users
users = [
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
for i in range(3)
]
session.add_all(users)
await session.flush()
# Create org members
org_members = [
OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key=f'test-key-{i}',
status='active',
)
for i, user in enumerate(users)
]
session.add_all(org_members)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=5
)
# Assert
assert len(members) == 3
assert has_more is False
@pytest.mark.asyncio
async def test_get_org_members_paginated_exact_limit(async_session_maker):
"""Test pagination when results exactly match limit."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
# Create exactly 5 users
users = [
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
for i in range(5)
]
session.add_all(users)
await session.flush()
# Create org members
org_members = [
OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key=f'test-key-{i}',
status='active',
)
for i, user in enumerate(users)
]
session.add_all(org_members)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=5
)
# Assert
assert len(members) == 5
assert has_more is False
@pytest.mark.asyncio
async def test_get_org_members_paginated_with_offset(async_session_maker):
"""Test pagination with offset skips correct number of items."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
# Create 10 users
users = [
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
for i in range(10)
]
session.add_all(users)
await session.flush()
# Create org members
org_members = [
OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key=f'test-key-{i}',
status='active',
)
for i, user in enumerate(users)
]
session.add_all(org_members)
await session.commit()
org_id = org.id
# Act - Get first page
with patch('storage.org_member_store.a_session_maker', async_session_maker):
first_page, has_more_first = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=3
)
# Get second page
second_page, has_more_second = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=3, limit=3
)
# Assert
assert len(first_page) == 3
assert has_more_first is True
assert len(second_page) == 3
assert has_more_second is True
# Verify no overlap between pages
first_user_ids = {member.user_id for member in first_page}
second_user_ids = {member.user_id for member in second_page}
assert first_user_ids.isdisjoint(second_user_ids)
@pytest.mark.asyncio
async def test_get_org_members_paginated_empty_org(async_session_maker):
"""Test pagination with empty organization returns empty list."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=10
)
# Assert
assert len(members) == 0
assert has_more is False
@pytest.mark.asyncio
async def test_get_org_members_paginated_ordering(async_session_maker):
"""Test that pagination orders results by user_id."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
# Create users with specific IDs to test ordering
user_ids = [uuid.uuid4() for _ in range(5)]
user_ids.sort() # Sort to verify ordering
users = [
User(id=user_id, current_org_id=org.id, email=f'user{i}@example.com')
for i, user_id in enumerate(user_ids)
]
session.add_all(users)
await session.flush()
# Create org members in reverse order to test that ordering works
org_members = [
OrgMember(
org_id=org.id,
user_id=user_id,
role_id=role.id,
llm_api_key=f'test-key-{i}',
status='active',
)
for i, user_id in enumerate(reversed(user_ids))
]
session.add_all(org_members)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=10
)
# Assert
assert len(members) == 5
# Verify members are ordered by user_id
member_user_ids = [member.user_id for member in members]
assert member_user_ids == sorted(member_user_ids)
@pytest.mark.asyncio
async def test_get_org_members_paginated_eager_loading(async_session_maker):
"""Test that user and role relationships are eagerly loaded."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='owner', rank=10)
session.add(role)
await session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id, email='test@example.com')
session.add(user)
await session.flush()
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key',
status='active',
)
session.add(org_member)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=10
)
# Assert
assert len(members) == 1
member = members[0]
# Verify relationships are loaded (not lazy)
assert member.user is not None
assert member.user.email == 'test@example.com'
assert member.role is not None
assert member.role.name == 'owner'
assert member.role.rank == 10

View File

@@ -482,7 +482,7 @@ async def test_get_org_credits_success(mock_litellm_api):
spend = 25.0
mock_team_info = {
'litellm_budget_table': {'max_budget': max_budget},
'max_budget_in_team': max_budget,
'spend': spend,
}
@@ -1535,6 +1535,118 @@ async def test_update_org_with_permissions_database_error(session_maker):
assert 'Failed to update organization' in str(exc_info.value)
@pytest.mark.asyncio
async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists_error(
session_maker,
):
"""
GIVEN: User updates org name to a name already used by another organization
WHEN: update_org_with_permissions is called
THEN: OrgNameExistsError is raised with the conflicting name
"""
# Arrange
org_id = uuid.uuid4()
other_org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
duplicate_name = 'Existing Org Name'
mock_current_org = Org(
id=org_id,
name='My Org',
contact_name='John Doe',
contact_email='john@example.com',
org_version=5,
)
mock_org_with_name = Org(
id=other_org_id,
name=duplicate_name,
contact_name='Jane Doe',
contact_email='jane@example.com',
)
from server.routes.org_models import OrgUpdate
update_data = OrgUpdate(name=duplicate_name)
with (
patch('storage.org_store.session_maker', session_maker),
patch('storage.org_member_store.session_maker', session_maker),
patch('storage.role_store.session_maker', session_maker),
patch(
'storage.org_service.OrgStore.get_org_by_id',
return_value=mock_current_org,
),
patch('storage.org_service.OrgService.is_org_member', return_value=True),
patch(
'storage.org_service.OrgStore.get_org_by_name',
return_value=mock_org_with_name,
),
):
# Act & Assert
with pytest.raises(OrgNameExistsError) as exc_info:
await OrgService.update_org_with_permissions(
org_id=org_id,
update_data=update_data,
user_id=user_id,
)
assert duplicate_name in str(exc_info.value)
@pytest.mark.asyncio
async def test_update_org_with_permissions_same_name_allowed(session_maker):
"""
GIVEN: User updates org with name unchanged (same as current org name)
WHEN: update_org_with_permissions is called
THEN: No OrgNameExistsError; update proceeds (name uniqueness allows same org)
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
current_name = 'My Org'
mock_org = Org(
id=org_id,
name=current_name,
contact_name='John Doe',
contact_email='john@example.com',
org_version=5,
)
from server.routes.org_models import OrgUpdate
update_data = OrgUpdate(name=current_name)
with (
patch('storage.org_store.session_maker', session_maker),
patch('storage.org_member_store.session_maker', session_maker),
patch('storage.role_store.session_maker', session_maker),
patch(
'storage.org_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch('storage.org_service.OrgService.is_org_member', return_value=True),
patch(
'storage.org_service.OrgStore.get_org_by_name',
return_value=mock_org,
),
patch(
'storage.org_service.OrgStore.update_org',
return_value=mock_org,
),
):
# Act
result = await OrgService.update_org_with_permissions(
org_id=org_id,
update_data=update_data,
user_id=user_id,
)
# Assert
assert result is not None
assert result.name == current_name
@pytest.mark.asyncio
async def test_update_org_with_permissions_only_llm_fields(session_maker):
"""
@@ -1657,3 +1769,258 @@ async def test_update_org_with_permissions_only_non_llm_fields(session_maker):
assert result.contact_name == 'Jane Doe'
assert result.conversation_expiration == 60
assert result.enable_proactive_conversation_starters is False
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_true_when_enabled():
"""
GIVEN: User has current_org with byor_export_enabled=True
WHEN: check_byor_export_enabled is called
THEN: Returns True
"""
# Arrange
user_id = 'test-user-123'
org_id = uuid.uuid4()
mock_user = MagicMock()
mock_user.current_org_id = org_id
mock_org = MagicMock()
mock_org.byor_export_enabled = True
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
),
patch(
'storage.org_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is True
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_false_when_disabled():
"""
GIVEN: User has current_org with byor_export_enabled=False
WHEN: check_byor_export_enabled is called
THEN: Returns False
"""
# Arrange
user_id = 'test-user-123'
org_id = uuid.uuid4()
mock_user = MagicMock()
mock_user.current_org_id = org_id
mock_org = MagicMock()
mock_org.byor_export_enabled = False
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
),
patch(
'storage.org_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_false_when_user_not_found():
"""
GIVEN: User does not exist
WHEN: check_byor_export_enabled is called
THEN: Returns False
"""
# Arrange
user_id = 'nonexistent-user'
with patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=None),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_false_when_no_current_org():
"""
GIVEN: User exists but has no current_org_id
WHEN: check_byor_export_enabled is called
THEN: Returns False
"""
# Arrange
user_id = 'test-user-123'
mock_user = MagicMock()
mock_user.current_org_id = None
with patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_false_when_org_not_found():
"""
GIVEN: User has current_org_id but org does not exist
WHEN: check_byor_export_enabled is called
THEN: Returns False
"""
# Arrange
user_id = 'test-user-123'
org_id = uuid.uuid4()
mock_user = MagicMock()
mock_user.current_org_id = org_id
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
),
patch(
'storage.org_service.OrgStore.get_org_by_id',
return_value=None,
),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_switch_org_success():
"""
GIVEN: Valid org_id and user_id where user is a member
WHEN: switch_org is called
THEN: User's current_org_id is updated and org is returned
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
mock_org = Org(
id=org_id,
name='Target Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
mock_updated_user = User(id=uuid.UUID(user_id), current_org_id=org_id)
with (
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
patch('storage.org_service.OrgService.is_org_member', return_value=True),
patch(
'storage.org_service.UserStore.update_current_org',
return_value=mock_updated_user,
),
):
# Act
result = await OrgService.switch_org(user_id, org_id)
# Assert
assert result is not None
assert result.id == org_id
assert result.name == 'Target Organization'
@pytest.mark.asyncio
async def test_switch_org_org_not_found():
"""
GIVEN: Organization does not exist
WHEN: switch_org is called
THEN: OrgNotFoundError is raised
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None):
# Act & Assert
with pytest.raises(OrgNotFoundError) as exc_info:
await OrgService.switch_org(user_id, org_id)
assert str(org_id) in str(exc_info.value)
@pytest.mark.asyncio
async def test_switch_org_user_not_member():
"""
GIVEN: User is not a member of the organization
WHEN: switch_org is called
THEN: OrgAuthorizationError is raised
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
mock_org = Org(
id=org_id,
name='Target Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
with (
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
patch('storage.org_service.OrgService.is_org_member', return_value=False),
):
# Act & Assert
with pytest.raises(OrgAuthorizationError) as exc_info:
await OrgService.switch_org(user_id, org_id)
assert 'member' in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_switch_org_user_not_found():
"""
GIVEN: User does not exist in database
WHEN: switch_org is called
THEN: OrgDatabaseError is raised
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
mock_org = Org(
id=org_id,
name='Target Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
with (
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
patch('storage.org_service.OrgService.is_org_member', return_value=True),
patch('storage.org_service.UserStore.update_current_org', return_value=None),
):
# Act & Assert
with pytest.raises(OrgDatabaseError) as exc_info:
await OrgService.switch_org(user_id, org_id)
assert 'User not found' in str(exc_info.value)

View File

@@ -786,3 +786,23 @@ def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
assert orgs[0].name == 'Apple Org'
assert orgs[1].name == 'Banana Org'
assert orgs[2].name == 'Zebra Org'
def test_orphaned_user_error_contains_user_ids():
"""
GIVEN: OrphanedUserError is created with a list of user IDs
WHEN: The error message is accessed
THEN: Message includes the count and stores user IDs
"""
# Arrange
from server.routes.org_models import OrphanedUserError
user_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# Act
error = OrphanedUserError(user_ids)
# Assert
assert error.user_ids == user_ids
assert '2 user(s)' in str(error)
assert 'no remaining organization' in str(error)

View File

@@ -3,7 +3,9 @@ from unittest.mock import MagicMock, patch
import pytest
# Mock the database module before importing
with patch('storage.database.engine'), patch('storage.database.a_engine'):
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from integrations.github.github_view import get_user_proactive_conversation_setting
from storage.org import Org

View File

@@ -398,6 +398,121 @@ async def test_create_user_contact_name_falls_back_to_username():
assert org.contact_name == 'jdoe'
# --- Tests for email fields in create_user() ---
# create_user() should populate user.email and user.email_verified from the
# Keycloak user_info, ensuring the user table has the correct email data.
class _StopAfterUserCreation(Exception):
"""Halt create_user() after User creation for email field inspection."""
pass
@pytest.mark.asyncio
async def test_create_user_sets_email_from_user_info():
"""create_user() should set user.email and user.email_verified from user_info."""
# Arrange
user_id = str(uuid.uuid4())
user_info = {
'preferred_username': 'testuser',
'email': 'testuser@example.com',
'email_verified': True,
}
mock_session = MagicMock()
mock_sm = MagicMock()
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
mock_settings = Settings(language='en')
mock_role = MagicMock()
mock_role.id = 1
with (
patch('storage.user_store.session_maker', mock_sm),
patch.object(
UserStore,
'create_default_settings',
new_callable=AsyncMock,
return_value=mock_settings,
),
patch('storage.org_store.OrgStore.get_kwargs_from_settings', return_value={}),
patch.object(UserStore, 'get_kwargs_from_settings', return_value={}),
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
patch(
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
return_value={'llm_model': None, 'llm_base_url': None},
),
patch.object(
mock_session,
'commit',
side_effect=_StopAfterUserCreation,
),
):
# Act
with pytest.raises(_StopAfterUserCreation):
await UserStore.create_user(user_id, user_info)
# Assert - User is the second object added to session (after Org)
user = mock_session.add.call_args_list[1][0][0]
assert isinstance(user, User)
assert user.email == 'testuser@example.com'
assert user.email_verified is True
@pytest.mark.asyncio
async def test_create_user_handles_missing_email_verified():
"""create_user() should handle missing email_verified in user_info gracefully."""
# Arrange
user_id = str(uuid.uuid4())
user_info = {
'preferred_username': 'testuser',
'email': 'testuser@example.com',
# email_verified is not present
}
mock_session = MagicMock()
mock_sm = MagicMock()
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
mock_settings = Settings(language='en')
mock_role = MagicMock()
mock_role.id = 1
with (
patch('storage.user_store.session_maker', mock_sm),
patch.object(
UserStore,
'create_default_settings',
new_callable=AsyncMock,
return_value=mock_settings,
),
patch('storage.org_store.OrgStore.get_kwargs_from_settings', return_value={}),
patch.object(UserStore, 'get_kwargs_from_settings', return_value={}),
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
patch(
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
return_value={'llm_model': None, 'llm_base_url': None},
),
patch.object(
mock_session,
'commit',
side_effect=_StopAfterUserCreation,
),
):
# Act
with pytest.raises(_StopAfterUserCreation):
await UserStore.create_user(user_id, user_info)
# Assert - User should have email but email_verified should be None
user = mock_session.add.call_args_list[1][0][0]
assert isinstance(user, User)
assert user.email == 'testuser@example.com'
assert user.email_verified is None
# --- Tests for backfill_contact_name on login ---
# Existing users created before the resolve_display_name fix may have
# username-style values in contact_name. The backfill updates these to
@@ -522,3 +637,46 @@ async def test_backfill_contact_name_preserves_custom_value(session_maker):
with session_maker() as session:
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
assert org.contact_name == 'Custom Corp Name'
def test_update_current_org_success(session_maker):
"""
GIVEN: User exists in database
WHEN: update_current_org is called with new org_id
THEN: User's current_org_id is updated and user is returned
"""
# Arrange
user_id = str(uuid.uuid4())
initial_org_id = uuid.uuid4()
new_org_id = uuid.uuid4()
with session_maker() as session:
user = User(id=uuid.UUID(user_id), current_org_id=initial_org_id)
session.add(user)
session.commit()
# Act
with patch('storage.user_store.session_maker', session_maker):
result = UserStore.update_current_org(user_id, new_org_id)
# Assert
assert result is not None
assert result.current_org_id == new_org_id
def test_update_current_org_user_not_found(session_maker):
"""
GIVEN: User does not exist in database
WHEN: update_current_org is called
THEN: None is returned
"""
# Arrange
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
# Act
with patch('storage.user_store.session_maker', session_maker):
result = UserStore.update_current_org(user_id, org_id)
# Assert
assert result is None

View File

@@ -0,0 +1,27 @@
import { describe, expect, it, vi } from "vitest";
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
const { mockGet } = vi.hoisted(() => ({ mockGet: vi.fn() }));
vi.mock("#/api/open-hands-axios", () => ({
openHands: { get: mockGet },
}));
describe("V1ConversationService", () => {
describe("readConversationFile", () => {
it("uses default plan path when filePath is not provided", async () => {
// Arrange
const conversationId = "conv-123";
mockGet.mockResolvedValue({ data: "# PLAN content" });
// Act
await V1ConversationService.readConversationFile(conversationId);
// Assert
expect(mockGet).toHaveBeenCalledTimes(1);
const callUrl = mockGet.mock.calls[0][0] as string;
expect(callUrl).toContain(
"file_path=%2Fworkspace%2Fproject%2F.agents_tmp%2FPLAN.md",
);
});
});
});

View File

@@ -125,7 +125,7 @@ describe("ChatInterface - Chat Suggestions", () => {
});
(useConfig as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
data: { APP_MODE: "local" },
data: { app_mode: "local" },
});
(useGetTrajectory as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
mutate: vi.fn(),
@@ -258,7 +258,7 @@ describe("ChatInterface - Empty state", () => {
errorMessage: null,
});
(useConfig as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
data: { APP_MODE: "local" },
data: { app_mode: "local" },
});
(useGetTrajectory as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
mutate: vi.fn(),

View File

@@ -113,15 +113,15 @@ describe("ExpandableMessage", () => {
it("should render the out of credits message when the user is out of credits", async () => {
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
// @ts-expect-error - We only care about the APP_MODE and FEATURE_FLAGS fields
// @ts-expect-error - We only care about the app_mode and feature_flags fields
getConfigSpy.mockResolvedValue({
APP_MODE: "saas",
FEATURE_FLAGS: {
ENABLE_BILLING: true,
HIDE_LLM_SETTINGS: false,
ENABLE_JIRA: false,
ENABLE_JIRA_DC: false,
ENABLE_LINEAR: false,
app_mode: "saas",
feature_flags: {
enable_billing: true,
hide_llm_settings: false,
enable_jira: false,
enable_jira_dc: false,
enable_linear: false,
},
});
const RouterStub = createRoutesStub([

View File

@@ -0,0 +1,250 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { screen } from "@testing-library/react";
import { QueryClient } from "@tanstack/react-query";
import { MemoryRouter, Route, Routes } from "react-router";
import { render } from "@testing-library/react";
import { QueryClientProvider } from "@tanstack/react-query";
import { useParamsMock, createUserMessageEvent } from "test-utils";
import { ChatInterface } from "#/components/features/chat/chat-interface";
import { useWsClient } from "#/context/ws-client-provider";
import { useConversationId } from "#/hooks/use-conversation-id";
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
import { useConversationWebSocket } from "#/contexts/conversation-websocket-context";
import { useConfig } from "#/hooks/query/use-config";
import { useGetTrajectory } from "#/hooks/mutation/use-get-trajectory";
import { useUnifiedUploadFiles } from "#/hooks/mutation/use-unified-upload-files";
import { useEventStore } from "#/stores/use-event-store";
import { useAgentState } from "#/hooks/use-agent-state";
import { AgentState } from "#/types/agent-state";
import { OpenHandsAction } from "#/types/core/actions";
// Module-level mocks
vi.mock("#/context/ws-client-provider");
vi.mock("#/hooks/query/use-config");
vi.mock("#/hooks/mutation/use-get-trajectory");
vi.mock("#/hooks/mutation/use-unified-upload-files");
vi.mock("#/hooks/use-conversation-id");
vi.mock("#/hooks/query/use-active-conversation");
vi.mock("#/contexts/conversation-websocket-context");
vi.mock("#/hooks/use-user-providers", () => ({
useUserProviders: () => ({
providers: [],
}),
}));
vi.mock("#/hooks/use-conversation-name-context-menu", () => ({
useConversationNameContextMenu: () => ({
isOpen: false,
contextMenuRef: { current: null },
handleContextMenu: vi.fn(),
handleClose: vi.fn(),
handleRename: vi.fn(),
handleDelete: vi.fn(),
}),
}));
vi.mock("#/hooks/use-agent-state", () => ({
useAgentState: vi.fn(() => ({
curAgentState: AgentState.AWAITING_USER_INPUT,
})),
}));
// Helper to render with QueryClient and route params
const renderWithQueryClient = (
ui: React.ReactElement,
queryClient: QueryClient,
route = "/test-conversation-id",
) =>
render(
<QueryClientProvider client={queryClient}>
<MemoryRouter initialEntries={[route]}>
<Routes>
<Route path="/:conversationId" element={ui} />
<Route path="/" element={ui} />
</Routes>
</MemoryRouter>
</QueryClientProvider>,
);
// V0 user event (numeric id, action property)
const createV0UserEvent = (): OpenHandsAction => ({
id: 1,
source: "user",
action: "message",
args: {
content: "Hello from V0",
image_urls: [],
file_urls: [],
},
message: "Hello from V0",
timestamp: "2025-07-01T00:00:00Z",
});
describe("ChatInterface message display continuity (spec 3.1)", () => {
let queryClient: QueryClient;
beforeEach(() => {
queryClient = new QueryClient({
defaultOptions: { queries: { retry: false } },
});
useParamsMock.mockReturnValue({ conversationId: "test-conversation-id" });
vi.mocked(useConversationId).mockReturnValue({
conversationId: "test-conversation-id",
});
// Default: V0, no loading, no events
(useWsClient as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
send: vi.fn(),
isLoadingMessages: false,
parsedEvents: [],
});
(useConfig as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
data: { app_mode: "local" },
});
(useGetTrajectory as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
mutate: vi.fn(),
mutateAsync: vi.fn(),
isLoading: false,
});
(
useUnifiedUploadFiles as unknown as ReturnType<typeof vi.fn>
).mockReturnValue({
mutateAsync: vi
.fn()
.mockResolvedValue({ skipped_files: [], uploaded_files: [] }),
isLoading: false,
});
// Default: no conversation (V0 behavior)
vi.mocked(useActiveConversation).mockReturnValue({
data: undefined,
} as ReturnType<typeof useActiveConversation>);
// Default: no websocket context
vi.mocked(useConversationWebSocket).mockReturnValue(null);
});
describe("V1 conversations", () => {
beforeEach(() => {
// Set up V1 conversation
vi.mocked(useActiveConversation).mockReturnValue({
data: { conversation_version: "V1" },
} as ReturnType<typeof useActiveConversation>);
});
it("shows messages immediately when V1 events exist in store, even while loading", () => {
// Simulate: history is loading but events already exist in store (e.g., remount)
vi.mocked(useConversationWebSocket).mockReturnValue({
isLoadingHistory: true,
connectionState: "OPEN",
sendMessage: vi.fn(),
});
// Put V1 user events in the store
const v1UserEvent = createUserMessageEvent("evt-1");
useEventStore.setState({
events: [v1UserEvent],
uiEvents: [v1UserEvent],
});
renderWithQueryClient(<ChatInterface />, queryClient);
// AC1: Messages should display immediately without skeleton
expect(
screen.queryByTestId("chat-messages-skeleton"),
).not.toBeInTheDocument();
expect(screen.queryByTestId("loading-spinner")).not.toBeInTheDocument();
});
it("shows skeleton when store is empty and loading", () => {
// Simulate: first load, no events yet
vi.mocked(useConversationWebSocket).mockReturnValue({
isLoadingHistory: true,
connectionState: "OPEN",
sendMessage: vi.fn(),
});
// Store is empty
useEventStore.setState({
events: [],
uiEvents: [],
});
renderWithQueryClient(<ChatInterface />, queryClient);
// AC5: Genuine first-load shows skeleton
expect(screen.getByTestId("chat-messages-skeleton")).toBeInTheDocument();
});
it("shows messages when loading is already false on mount (edge case)", () => {
// Simulate: component re-mounts when WebSocket has already finished loading
vi.mocked(useConversationWebSocket).mockReturnValue({
isLoadingHistory: false,
connectionState: "OPEN",
sendMessage: vi.fn(),
});
// V1 events in store
const v1UserEvent = createUserMessageEvent("evt-2");
useEventStore.setState({
events: [v1UserEvent],
uiEvents: [v1UserEvent],
});
renderWithQueryClient(<ChatInterface />, queryClient);
// Messages should display, no skeleton
expect(
screen.queryByTestId("chat-messages-skeleton"),
).not.toBeInTheDocument();
expect(screen.queryByTestId("loading-spinner")).not.toBeInTheDocument();
});
});
describe("V0 conversations", () => {
it("shows messages when V0 events exist in store even if isLoadingMessages is true", () => {
// Simulate: loading flag is still true but events already exist in store (e.g., remount)
(useWsClient as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
send: vi.fn(),
isLoadingMessages: true,
parsedEvents: [],
});
// Put V0 user events in the store
useEventStore.setState({
events: [createV0UserEvent()],
uiEvents: [],
});
renderWithQueryClient(<ChatInterface />, queryClient);
// AC1/AC4: Messages display immediately, no skeleton
expect(
screen.queryByTestId("chat-messages-skeleton"),
).not.toBeInTheDocument();
});
it("shows skeleton when store is empty and isLoadingMessages is true", () => {
// Simulate: genuine first load, no events yet
(useWsClient as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
send: vi.fn(),
isLoadingMessages: true,
parsedEvents: [],
});
// Store is empty
useEventStore.setState({
events: [],
uiEvents: [],
});
renderWithQueryClient(<ChatInterface />, queryClient);
// AC5: Genuine first-load shows skeleton
expect(screen.getByTestId("chat-messages-skeleton")).toBeInTheDocument();
});
});
});

View File

@@ -5,7 +5,7 @@ import { EventMessage } from "#/components/features/chat/event-message";
vi.mock("#/hooks/query/use-config", () => ({
useConfig: () => ({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
}),
}));

View File

@@ -0,0 +1,287 @@
import { fireEvent, render, screen, within } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { act } from "react";
import { MemoryRouter } from "react-router";
import { AlertBanner } from "#/components/features/alerts/alert-banner";
// Mock react-i18next
vi.mock("react-i18next", async () => {
const actual =
await vi.importActual<typeof import("react-i18next")>("react-i18next");
return {
...actual,
useTranslation: () => ({
t: (key: string, options?: { time?: string }) => {
const translations: Record<string, string> = {
MAINTENANCE$SCHEDULED_MESSAGE: `Scheduled maintenance will begin at ${options?.time || "{{time}}"}`,
ALERT$FAULTY_MODELS_MESSAGE:
"The following models are currently reporting errors:",
"ERROR$TRANSLATED_KEY": "This is a translated error message",
};
return translations[key] || key;
},
}),
};
});
describe("AlertBanner", () => {
afterEach(() => {
localStorage.clear();
});
describe("Maintenance alerts", () => {
it("renders maintenance banner with formatted time", () => {
const startTime = "2024-01-15T10:00:00-05:00";
const updatedAt = "2024-01-14T10:00:00Z";
const { container } = render(
<MemoryRouter>
<AlertBanner maintenanceStartTime={startTime} updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
expect(banner).toBeInTheDocument();
const svgIcon = container.querySelector("svg");
expect(svgIcon).toBeInTheDocument();
const button = within(banner!).queryByTestId("dismiss-button");
expect(button).toBeInTheDocument();
});
it("click on dismiss button removes banner", () => {
const startTime = "2024-01-15T10:00:00-05:00";
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner maintenanceStartTime={startTime} updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
const button = within(banner!).queryByTestId("dismiss-button");
act(() => {
fireEvent.click(button!);
});
expect(banner).not.toBeInTheDocument();
});
it("banner reappears when updatedAt changes", () => {
const startTime = "2024-01-15T10:00:00-05:00";
const updatedAt = "2024-01-14T10:00:00Z";
const newUpdatedAt = "2024-01-15T10:00:00Z";
const { rerender } = render(
<MemoryRouter>
<AlertBanner maintenanceStartTime={startTime} updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
const button = within(banner!).queryByTestId("dismiss-button");
act(() => {
fireEvent.click(button!);
});
expect(banner).not.toBeInTheDocument();
rerender(
<MemoryRouter>
<AlertBanner
maintenanceStartTime={startTime}
updatedAt={newUpdatedAt}
/>
</MemoryRouter>,
);
expect(screen.queryByTestId("alert-banner")).toBeInTheDocument();
});
});
describe("Faulty models alerts", () => {
it("renders banner with faulty models list", () => {
const faultyModels = ["gpt-4", "claude-3"];
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner faultyModels={faultyModels} updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
expect(banner).toBeInTheDocument();
expect(
screen.getByText(/The following models are currently reporting errors:/),
).toBeInTheDocument();
// Models are displayed in the order they are provided
expect(screen.getByText(/gpt-4/)).toBeInTheDocument();
expect(screen.getByText(/claude-3/)).toBeInTheDocument();
});
it("does not render banner when faulty models array is empty", () => {
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner faultyModels={[]} updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
expect(banner).not.toBeInTheDocument();
});
it("banner reappears when updatedAt changes", () => {
const faultyModels = ["gpt-4"];
const updatedAt = "2024-01-14T10:00:00Z";
const newUpdatedAt = "2024-01-15T10:00:00Z";
const { rerender } = render(
<MemoryRouter>
<AlertBanner faultyModels={faultyModels} updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
const button = within(banner!).queryByTestId("dismiss-button");
act(() => {
fireEvent.click(button!);
});
expect(banner).not.toBeInTheDocument();
rerender(
<MemoryRouter>
<AlertBanner faultyModels={faultyModels} updatedAt={newUpdatedAt} />
</MemoryRouter>,
);
expect(screen.queryByTestId("alert-banner")).toBeInTheDocument();
});
});
describe("Error message alerts", () => {
it("renders banner with translated error message", () => {
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner errorMessage="ERROR$TRANSLATED_KEY" updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
expect(banner).toBeInTheDocument();
expect(
screen.getByText("This is a translated error message"),
).toBeInTheDocument();
});
it("renders banner with raw error message when no translation exists", () => {
const rawErrorMessage = "This is a raw error without translation";
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner errorMessage={rawErrorMessage} updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
expect(banner).toBeInTheDocument();
expect(screen.getByText(rawErrorMessage)).toBeInTheDocument();
});
it("does not render banner when error message is empty", () => {
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner errorMessage="" updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
expect(banner).not.toBeInTheDocument();
});
it("does not render banner when error message is null", () => {
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner errorMessage={null} updatedAt={updatedAt} />
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
expect(banner).not.toBeInTheDocument();
});
});
describe("Multiple alerts", () => {
it("renders all alerts when multiple conditions are present", () => {
const startTime = "2024-01-15T10:00:00-05:00";
const faultyModels = ["gpt-4"];
const errorMessage = "ERROR$TRANSLATED_KEY";
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner
maintenanceStartTime={startTime}
faultyModels={faultyModels}
errorMessage={errorMessage}
updatedAt={updatedAt}
/>
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
expect(banner).toBeInTheDocument();
expect(
screen.getByText(/Scheduled maintenance will begin at/),
).toBeInTheDocument();
expect(
screen.getByText(/The following models are currently reporting errors:/),
).toBeInTheDocument();
expect(
screen.getByText("This is a translated error message"),
).toBeInTheDocument();
});
it("dismissing hides all alerts", () => {
const startTime = "2024-01-15T10:00:00-05:00";
const faultyModels = ["gpt-4"];
const updatedAt = "2024-01-14T10:00:00Z";
render(
<MemoryRouter>
<AlertBanner
maintenanceStartTime={startTime}
faultyModels={faultyModels}
updatedAt={updatedAt}
/>
</MemoryRouter>,
);
const banner = screen.queryByTestId("alert-banner");
const button = within(banner!).queryByTestId("dismiss-button");
act(() => {
fireEvent.click(button!);
});
expect(banner).not.toBeInTheDocument();
});
});
});

View File

@@ -151,8 +151,9 @@ describe("LoginContent", () => {
await user.click(githubButton);
// Wait for async handleAuthRedirect to complete
// The URL includes state parameter added by handleAuthRedirect
await waitFor(() => {
expect(window.location.href).toBe(mockUrl);
expect(window.location.href).toContain(mockUrl);
});
});
@@ -201,4 +202,103 @@ describe("LoginContent", () => {
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
});
it("should display invitation pending message when hasInvitation is true", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation
/>
</MemoryRouter>,
);
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
it("should not display invitation pending message when hasInvitation is false", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation={false}
/>
</MemoryRouter>,
);
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should call buildOAuthStateData when clicking auth button", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
expect(mockBuildOAuthStateData).toHaveBeenCalled();
const callArg = mockBuildOAuthStateData.mock.calls[0][0];
expect(callArg).toHaveProperty("redirect_url");
});
});
it("should encode state with invitation token when buildOAuthStateData provides token", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
const redirectUrl = window.location.href;
// The URL should contain an encoded state parameter
expect(redirectUrl).toContain("state=");
// Decode and verify the state contains invitation_token
const url = new URL(redirectUrl);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
});
});

View File

@@ -92,19 +92,21 @@ describe("PlanPreview", () => {
});
it("should render nothing when planContent is null", () => {
renderPlanPreview(<PlanPreview planContent={null} />);
// Arrange & Act
const { container } = renderPlanPreview(<PlanPreview planContent={null} />);
const contentDiv = screen.getByTestId("plan-preview-content");
expect(contentDiv).toBeInTheDocument();
expect(contentDiv.textContent?.trim() || "").toBe("");
// Assert
expect(container.firstChild).toBeNull();
});
it("should render nothing when planContent is undefined", () => {
renderPlanPreview(<PlanPreview planContent={undefined} />);
// Arrange & Act
const { container } = renderPlanPreview(
<PlanPreview planContent={undefined} />,
);
const contentDiv = screen.getByTestId("plan-preview-content");
expect(contentDiv).toBeInTheDocument();
expect(contentDiv.textContent?.trim() || "").toBe("");
// Assert
expect(container.firstChild).toBeNull();
});
it("should render markdown content when planContent is provided", () => {
@@ -170,7 +172,7 @@ describe("PlanPreview", () => {
// Arrange
const user = userEvent.setup();
const expectedPrompt =
"Execute the plan based on the workspace/project/PLAN.md file.";
"Execute the plan based on the .agents_tmp/PLAN.md file.";
renderPlanPreview(<PlanPreview planContent="Plan content" />);
const buildButton = screen.getByTestId("plan-preview-build-button");
@@ -201,7 +203,7 @@ describe("PlanPreview", () => {
useOptimisticUserMessageStore.setState({ optimisticUserMessage: null });
const user = userEvent.setup();
const expectedPrompt =
"Execute the plan based on the workspace/project/PLAN.md file.";
"Execute the plan based on the .agents_tmp/PLAN.md file.";
renderPlanPreview(<PlanPreview planContent="Plan content" />);
const buildButton = screen.getByTestId("plan-preview-build-button");

View File

@@ -0,0 +1,112 @@
import { render, screen } from "@testing-library/react";
import { describe, it, expect, vi, beforeEach } from "vitest";
// Mutable mock state for controlling breakpoint
let mockIsMobile = false;
// Track ChatInterface unmount via vi.fn()
const chatInterfaceUnmount = vi.fn();
vi.mock("#/hooks/use-breakpoint", () => ({
useBreakpoint: () => mockIsMobile,
}));
vi.mock("#/hooks/use-resizable-panels", () => ({
useResizablePanels: () => ({
leftWidth: 50,
rightWidth: 50,
isDragging: false,
containerRef: { current: null },
handleMouseDown: vi.fn(),
}),
}));
vi.mock("#/stores/conversation-store", () => ({
useConversationStore: () => ({
isRightPanelShown: false,
}),
}));
// Mock ChatInterface with useEffect to track mount/unmount lifecycle
vi.mock("#/components/features/chat/chat-interface", () => {
// eslint-disable-next-line @typescript-eslint/no-require-imports
const React = require("react");
return {
ChatInterface: () => {
React.useEffect(() => {
return () => chatInterfaceUnmount();
}, []);
return <div data-testid="chat-interface">Chat Interface</div>;
},
};
});
vi.mock(
"#/components/features/conversation/conversation-tabs/conversation-tab-content/conversation-tab-content",
() => ({
ConversationTabContent: () => <div data-testid="tab-content" />,
}),
);
import { ConversationMain } from "#/components/features/conversation/conversation-main/conversation-main";
describe("ConversationMain - Layout Transition Stability", () => {
beforeEach(() => {
mockIsMobile = false;
chatInterfaceUnmount.mockClear();
});
it("renders ChatInterface at desktop width", () => {
mockIsMobile = false;
render(<ConversationMain />);
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
});
it("renders ChatInterface at mobile width", () => {
mockIsMobile = true;
render(<ConversationMain />);
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
});
it("does not unmount ChatInterface when crossing from desktop to mobile", () => {
mockIsMobile = false;
const { rerender } = render(<ConversationMain />);
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
// Cross the breakpoint to mobile
mockIsMobile = true;
rerender(<ConversationMain />);
// ChatInterface must NOT have been unmounted and remounted
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
});
it("does not unmount ChatInterface when crossing from mobile to desktop", () => {
mockIsMobile = true;
const { rerender } = render(<ConversationMain />);
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
// Cross the breakpoint to desktop
mockIsMobile = false;
rerender(<ConversationMain />);
// ChatInterface must NOT have been unmounted and remounted
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
});
it("survives rapid back-and-forth resize without unmounting ChatInterface", () => {
mockIsMobile = false;
const { rerender } = render(<ConversationMain />);
// Simulate rapid resize back and forth across the breakpoint
for (const mobile of [true, false, true, false, true]) {
mockIsMobile = mobile;
rerender(<ConversationMain />);
}
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
});
});

View File

@@ -33,7 +33,7 @@ const {
})),
useConfigMock: vi.fn(() => ({
data: {
APP_MODE: "oss",
app_mode: "oss",
},
})),
}));
@@ -659,7 +659,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
useConfigMock.mockReturnValue({
data: {
APP_MODE: "saas",
app_mode: "saas",
},
});
@@ -685,7 +685,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
useConfigMock.mockReturnValue({
data: {
APP_MODE: "saas",
app_mode: "saas",
},
});
@@ -718,7 +718,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
useConfigMock.mockReturnValue({
data: {
APP_MODE: "saas",
app_mode: "saas",
},
});
@@ -751,7 +751,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
useConfigMock.mockReturnValue({
data: {
APP_MODE: "saas",
app_mode: "saas",
},
});
@@ -781,7 +781,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
useConfigMock.mockReturnValue({
data: {
APP_MODE: "saas",
app_mode: "saas",
},
});
@@ -810,7 +810,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
useConfigMock.mockReturnValue({
data: {
APP_MODE: "saas",
app_mode: "saas",
},
});
});

View File

@@ -177,10 +177,10 @@ describe("RepoConnector", () => {
it("should render the 'add github repos' link in dropdown if saas mode and github provider is set", async () => {
const getConfiSpy = vi.spyOn(OptionService, "getConfig");
// @ts-expect-error - only return the APP_MODE and APP_SLUG
// @ts-expect-error - only return the app_mode and github_app_slug
getConfiSpy.mockResolvedValue({
APP_MODE: "saas",
APP_SLUG: "openhands",
app_mode: "saas",
github_app_slug: "openhands",
});
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
@@ -224,10 +224,10 @@ describe("RepoConnector", () => {
it("should not render the 'add github repos' link if github provider is not set", async () => {
const getConfiSpy = vi.spyOn(OptionService, "getConfig");
// @ts-expect-error - only return the APP_MODE and APP_SLUG
// @ts-expect-error - only return the app_mode and github_app_slug for this test
getConfiSpy.mockResolvedValue({
APP_MODE: "saas",
APP_SLUG: "openhands",
app_mode: "saas",
github_app_slug: "openhands",
});
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
@@ -269,9 +269,9 @@ describe("RepoConnector", () => {
it("should not render the 'add github repos' link in dropdown if oss mode", async () => {
const getConfiSpy = vi.spyOn(OptionService, "getConfig");
// @ts-expect-error - only return the APP_MODE
// @ts-expect-error - only return the app_mode
getConfiSpy.mockResolvedValue({
APP_MODE: "oss",
app_mode: "oss",
});
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");

View File

@@ -30,7 +30,7 @@ vi.mock("#/hooks/query/use-is-authed", () => ({
vi.mock("#/hooks/query/use-config", () => ({
useConfig: () => ({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
}),
}));

View File

@@ -1,119 +0,0 @@
import { fireEvent, render, screen, within } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { act } from "react";
import { MemoryRouter } from "react-router";
import { MaintenanceBanner } from "#/components/features/maintenance/maintenance-banner";
// Mock react-i18next
vi.mock("react-i18next", async () => {
const actual =
await vi.importActual<typeof import("react-i18next")>("react-i18next");
return {
...actual,
useTranslation: () => ({
t: (key: string, options?: { time?: string }) => {
const translations: Record<string, string> = {
MAINTENANCE$SCHEDULED_MESSAGE: `Scheduled maintenance will begin at ${options?.time || "{{time}}"}`,
};
return translations[key] || key;
},
}),
};
});
describe("MaintenanceBanner", () => {
afterEach(() => {
localStorage.clear();
});
it("renders maintenance banner with formatted time", () => {
const startTime = "2024-01-15T10:00:00-05:00"; // EST timestamp
const { container } = render(
<MemoryRouter>
<MaintenanceBanner startTime={startTime} />
</MemoryRouter>,
);
// Check if the banner is rendered
const banner = screen.queryByTestId("maintenance-banner");
expect(banner).toBeInTheDocument();
// Check if the warning icon (SVG) is present
const svgIcon = container.querySelector("svg");
expect(svgIcon).toBeInTheDocument();
// Check if the button to close is present
const button = within(banner!).queryByTestId("dismiss-button");
expect(button).toBeInTheDocument();
});
it("handles invalid date gracefully", () => {
// Suppress expected console.warn for invalid date parsing
const consoleWarnSpy = vi
.spyOn(console, "warn")
.mockImplementation(() => {});
const invalidTime = "invalid-date";
render(
<MemoryRouter>
<MaintenanceBanner startTime={invalidTime} />
</MemoryRouter>,
);
// Check if the banner is rendered
const banner = screen.queryByTestId("maintenance-banner");
expect(banner).not.toBeInTheDocument();
// Restore console.warn
consoleWarnSpy.mockRestore();
});
it("click on dismiss button removes banner", () => {
const startTime = "2024-01-15T10:00:00-05:00"; // EST timestamp
render(
<MemoryRouter>
<MaintenanceBanner startTime={startTime} />
</MemoryRouter>,
);
// Check if the banner is rendered
const banner = screen.queryByTestId("maintenance-banner");
const button = within(banner!).queryByTestId("dismiss-button");
act(() => {
fireEvent.click(button!);
});
expect(banner).not.toBeInTheDocument();
});
it("banner reappears after dismissing on next maintenance event(future time)", () => {
const startTime = "2024-01-15T10:00:00-05:00"; // EST timestamp
const nextStartTime = "2025-01-15T10:00:00-05:00"; // EST timestamp
const { rerender } = render(
<MemoryRouter>
<MaintenanceBanner startTime={startTime} />
</MemoryRouter>,
);
// Check if the banner is rendered
const banner = screen.queryByTestId("maintenance-banner");
const button = within(banner!).queryByTestId("dismiss-button");
act(() => {
fireEvent.click(button!);
});
expect(banner).not.toBeInTheDocument();
rerender(
<MemoryRouter>
<MaintenanceBanner startTime={nextStartTime} />
</MemoryRouter>,
);
expect(screen.queryByTestId("maintenance-banner")).toBeInTheDocument();
});
});

View File

@@ -305,7 +305,7 @@ describe("MicroagentManagement", () => {
mockUseConfig.mockReturnValue({
data: {
APP_MODE: "oss",
app_mode: "oss",
},
});

View File

@@ -27,17 +27,17 @@ describe("PaymentForm", () => {
const renderPaymentForm = () => renderWithProviders(<PaymentForm />);
beforeEach(() => {
// useBalance hook will return the balance only if the APP_MODE is "saas" and the billing feature is enabled
// useBalance hook will return the balance only if the app_mode is "saas" and the billing feature is enabled
// @ts-expect-error - partial mock for testing
getConfigSpy.mockResolvedValue({
APP_MODE: "saas",
GITHUB_CLIENT_ID: "123",
POSTHOG_CLIENT_KEY: "456",
FEATURE_FLAGS: {
ENABLE_BILLING: true,
HIDE_LLM_SETTINGS: false,
ENABLE_JIRA: false,
ENABLE_JIRA_DC: false,
ENABLE_LINEAR: false,
app_mode: "saas",
posthog_client_key: "456",
feature_flags: {
enable_billing: true,
hide_llm_settings: false,
enable_jira: false,
enable_jira_dc: false,
enable_linear: false,
},
});
});

View File

@@ -9,27 +9,34 @@ import { Sidebar } from "#/components/features/sidebar/sidebar";
import SettingsService from "#/api/settings-service/settings-service.api";
import OptionService from "#/api/option-service/option-service.api";
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
import { GetConfigResponse } from "#/api/option-service/option.types";
import { WebClientConfig } from "#/api/option-service/option.types";
// Helper to create mock config with sensible defaults
const createMockConfig = (
overrides: Omit<Partial<GetConfigResponse>, "FEATURE_FLAGS"> & {
FEATURE_FLAGS?: Partial<GetConfigResponse["FEATURE_FLAGS"]>;
overrides: Omit<Partial<WebClientConfig>, "feature_flags"> & {
feature_flags?: Partial<WebClientConfig["feature_flags"]>;
} = {},
): GetConfigResponse => {
const { FEATURE_FLAGS: featureFlagOverrides, ...restOverrides } = overrides;
): WebClientConfig => {
const { feature_flags: featureFlagOverrides, ...restOverrides } = overrides;
return {
APP_MODE: "oss",
GITHUB_CLIENT_ID: "test-client-id",
POSTHOG_CLIENT_KEY: "test-posthog-key",
FEATURE_FLAGS: {
ENABLE_BILLING: false,
HIDE_LLM_SETTINGS: false,
ENABLE_JIRA: false,
ENABLE_JIRA_DC: false,
ENABLE_LINEAR: false,
app_mode: "oss",
posthog_client_key: "test-posthog-key",
feature_flags: {
enable_billing: false,
hide_llm_settings: false,
enable_jira: false,
enable_jira_dc: false,
enable_linear: false,
...featureFlagOverrides,
},
providers_configured: [],
maintenance_start_time: null,
auth_url: null,
recaptcha_site_key: null,
faulty_models: [],
error_message: null,
updated_at: "2024-01-14T10:00:00Z",
github_app_slug: null,
...restOverrides,
};
};
@@ -76,9 +83,9 @@ describe("Sidebar", () => {
});
describe("Settings modal auto-open behavior", () => {
it("should NOT open settings modal when HIDE_LLM_SETTINGS is true even with 404 error", async () => {
it("should NOT open settings modal when hide_llm_settings is true even with 404 error", async () => {
getConfigSpy.mockResolvedValue(
createMockConfig({ FEATURE_FLAGS: { HIDE_LLM_SETTINGS: true } }),
createMockConfig({ feature_flags: { hide_llm_settings: true } }),
);
getSettingsSpy.mockRejectedValue(createAxiosNotFoundErrorObject());
@@ -89,21 +96,21 @@ describe("Sidebar", () => {
expect(getSettingsSpy).toHaveBeenCalled();
});
// Settings modal should NOT appear when HIDE_LLM_SETTINGS is true
// Settings modal should NOT appear when hide_llm_settings is true
await waitFor(() => {
expect(screen.queryByTestId("ai-config-modal")).not.toBeInTheDocument();
});
});
it("should open settings modal when HIDE_LLM_SETTINGS is false and 404 error in OSS mode", async () => {
it("should open settings modal when hide_llm_settings is false and 404 error in OSS mode", async () => {
getConfigSpy.mockResolvedValue(
createMockConfig({ FEATURE_FLAGS: { HIDE_LLM_SETTINGS: false } }),
createMockConfig({ feature_flags: { hide_llm_settings: false } }),
);
getSettingsSpy.mockRejectedValue(createAxiosNotFoundErrorObject());
renderSidebar();
// Settings modal should appear when HIDE_LLM_SETTINGS is false
// Settings modal should appear when hide_llm_settings is false
await waitFor(() => {
expect(screen.getByTestId("ai-config-modal")).toBeInTheDocument();
});
@@ -112,8 +119,8 @@ describe("Sidebar", () => {
it("should NOT open settings modal in SaaS mode even with 404 error", async () => {
getConfigSpy.mockResolvedValue(
createMockConfig({
APP_MODE: "saas",
FEATURE_FLAGS: { HIDE_LLM_SETTINGS: false },
app_mode: "saas",
feature_flags: { hide_llm_settings: false },
}),
);
getSettingsSpy.mockRejectedValue(createAxiosNotFoundErrorObject());
@@ -133,7 +140,7 @@ describe("Sidebar", () => {
it("should NOT open settings modal when settings exist (no 404 error)", async () => {
getConfigSpy.mockResolvedValue(
createMockConfig({ FEATURE_FLAGS: { HIDE_LLM_SETTINGS: false } }),
createMockConfig({ feature_flags: { hide_llm_settings: false } }),
);
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
@@ -152,7 +159,7 @@ describe("Sidebar", () => {
it("should NOT open settings modal when on /settings path", async () => {
getConfigSpy.mockResolvedValue(
createMockConfig({ FEATURE_FLAGS: { HIDE_LLM_SETTINGS: false } }),
createMockConfig({ feature_flags: { hide_llm_settings: false } }),
);
getSettingsSpy.mockRejectedValue(createAxiosNotFoundErrorObject());

View File

@@ -22,7 +22,7 @@ describe("PostHogWrapper", () => {
// Mock the config fetch
// @ts-expect-error - partial mock
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
POSTHOG_CLIENT_KEY: "test-posthog-key",
posthog_client_key: "test-posthog-key",
});
});

View File

@@ -13,7 +13,7 @@ const useIsAuthedMock = vi
const useConfigMock = vi
.fn()
.mockReturnValue({ data: { APP_MODE: "saas" }, isLoading: false });
.mockReturnValue({ data: { app_mode: "saas" }, isLoading: false });
const useUserProvidersMock = vi
.fn()
@@ -46,7 +46,7 @@ describe("UserActions", () => {
// Reset all mocks to default values before each test
useIsAuthedMock.mockReturnValue({ data: true, isLoading: false });
useConfigMock.mockReturnValue({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
});
useUserProvidersMock.mockReturnValue({
@@ -89,7 +89,7 @@ describe("UserActions", () => {
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
// Keep other mocks with default values
useConfigMock.mockReturnValue({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
});
useUserProvidersMock.mockReturnValue({
@@ -126,7 +126,7 @@ describe("UserActions", () => {
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
// Keep other mocks with default values
useConfigMock.mockReturnValue({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
});
useUserProvidersMock.mockReturnValue({
@@ -154,7 +154,7 @@ describe("UserActions", () => {
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
// Keep other mocks with default values
useConfigMock.mockReturnValue({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
});
useUserProvidersMock.mockReturnValue({
@@ -179,7 +179,7 @@ describe("UserActions", () => {
useIsAuthedMock.mockReturnValue({ data: true, isLoading: false });
// Ensure config and providers are set correctly
useConfigMock.mockReturnValue({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
});
useUserProvidersMock.mockReturnValue({
@@ -211,7 +211,7 @@ describe("UserActions", () => {
// Start with authentication and providers
useIsAuthedMock.mockReturnValue({ data: true, isLoading: false });
useConfigMock.mockReturnValue({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
});
useUserProvidersMock.mockReturnValue({
@@ -236,7 +236,7 @@ describe("UserActions", () => {
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
// Keep other mocks with default values
useConfigMock.mockReturnValue({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
});
useUserProvidersMock.mockReturnValue({
@@ -265,7 +265,7 @@ describe("UserActions", () => {
// Ensure authentication and providers are set correctly
useIsAuthedMock.mockReturnValue({ data: true, isLoading: false });
useConfigMock.mockReturnValue({
data: { APP_MODE: "saas" },
data: { app_mode: "saas" },
isLoading: false,
});
useUserProvidersMock.mockReturnValue({

View File

@@ -0,0 +1,29 @@
import { WebClientConfig } from "#/api/option-service/option.types";
/**
* Creates a mock WebClientConfig with all required fields.
* Use this helper to create test config objects with sensible defaults.
*/
export const createMockWebClientConfig = (
overrides: Partial<WebClientConfig> = {},
): WebClientConfig => ({
app_mode: "oss",
posthog_client_key: "test-posthog-key",
feature_flags: {
enable_billing: false,
hide_llm_settings: false,
enable_jira: false,
enable_jira_dc: false,
enable_linear: false,
...overrides.feature_flags,
},
providers_configured: [],
maintenance_start_time: null,
auth_url: null,
recaptcha_site_key: null,
faulty_models: [],
error_message: null,
updated_at: new Date().toISOString(),
github_app_slug: null,
...overrides,
});

View File

@@ -0,0 +1,104 @@
import { renderHook, waitFor } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { describe, expect, it, vi } from "vitest";
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
import { SuggestedTask } from "#/utils/types";
vi.mock("#/hooks/query/use-settings", async () => {
const actual = await vi.importActual<typeof import("#/hooks/query/use-settings")>(
"#/hooks/query/use-settings",
);
return {
...actual,
useSettings: vi.fn().mockReturnValue({
data: {
v1_enabled: true,
},
isLoading: false,
}),
};
});
vi.mock("#/hooks/use-tracking", () => ({
useTracking: () => ({
trackConversationCreated: vi.fn(),
}),
}));
describe("useCreateConversation", () => {
it("passes suggested tasks to the V1 create conversation API", async () => {
const createConversationSpy = vi
.spyOn(V1ConversationService, "createConversation")
.mockResolvedValue({
id: "task-id",
created_by_user_id: null,
status: "READY",
detail: null,
app_conversation_id: null,
sandbox_id: null,
agent_server_url: "http://agent-server.local",
request: {
sandbox_id: null,
initial_message: {
role: "user",
content: [{ type: "text", text: "Please address the comments" }],
},
processors: [],
llm_model: null,
selected_repository: null,
selected_branch: null,
git_provider: "github",
suggested_task: null,
title: null,
trigger: null,
pr_number: [],
parent_conversation_id: null,
agent_type: "default",
},
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
});
const { result } = renderHook(() => useCreateConversation(), {
wrapper: ({ children }) => (
<QueryClientProvider client={new QueryClient()}>
{children}
</QueryClientProvider>
),
});
const suggestedTask: SuggestedTask = {
git_provider: "github",
issue_number: 42,
repo: "owner/repo",
title: "Resolve comments",
task_type: "UNRESOLVED_COMMENTS",
};
await result.current.mutateAsync({
query: "Please address the comments",
repository: {
name: "owner/repo",
gitProvider: "github",
branch: "main",
},
conversationInstructions: "Focus on review comments",
suggestedTask,
});
await waitFor(() => {
expect(createConversationSpy).toHaveBeenCalledWith(
"owner/repo",
"github",
"Please address the comments",
"main",
"Focus on review comments",
suggestedTask,
undefined,
undefined,
undefined,
);
});
});
});

View File

@@ -1,4 +1,4 @@
import { describe, it, expect, afterEach, vi } from "vitest";
import { describe, it, expect, afterEach, beforeEach, vi } from "vitest";
import React from "react";
import { renderHook, waitFor } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
@@ -112,3 +112,192 @@ describe("useConversationHistory", () => {
expect(EventService.searchEventsV1).not.toHaveBeenCalled();
});
});
describe("useConversationHistory cache key stability", () => {
let localQueryClient: QueryClient;
let localWrapper: ({
children,
}: {
children: React.ReactNode;
}) => React.ReactElement;
beforeEach(() => {
localQueryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
});
localWrapper = ({ children }: { children: React.ReactNode }) =>
React.createElement(
QueryClientProvider,
{ client: localQueryClient },
children,
);
});
afterEach(() => {
localQueryClient.clear();
vi.clearAllMocks();
});
it("does not refetch when conversation object changes but version stays the same", async () => {
const v1Spy = vi.spyOn(EventService, "searchEventsV1");
v1Spy.mockResolvedValue([makeEvent()]);
const conv1 = makeConversation("V1");
vi.mocked(useUserConversation).mockReturnValue({
data: conv1,
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
const { result, rerender } = renderHook(
() => useConversationHistory("conv-stable"),
{ wrapper: localWrapper },
);
await waitFor(() => {
expect(result.current.data).toBeDefined();
});
expect(v1Spy).toHaveBeenCalledTimes(1);
// Simulate background polling: new object reference with different mutable fields
// but the SAME conversation_version
const conv2: Conversation = {
...conv1,
last_updated_at: "2099-01-01T00:00:00Z",
status: "STOPPED",
runtime_status: "STATUS$STOPPED",
};
vi.mocked(useUserConversation).mockReturnValue({
data: conv2,
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
rerender();
// Allow any potential async refetch to trigger
await new Promise((r) => {
setTimeout(r, 50);
});
// Must NOT refetch — version hasn't changed, only mutable fields did
expect(v1Spy).toHaveBeenCalledTimes(1);
});
// Edge case: version change MUST trigger a refetch with the correct endpoint
it("refetches when conversation_version changes from V0 to V1", async () => {
const v0Spy = vi.spyOn(EventService, "searchEventsV0");
const v1Spy = vi.spyOn(EventService, "searchEventsV1");
v0Spy.mockResolvedValue([makeEvent()]);
v1Spy.mockResolvedValue([makeEvent()]);
// Start with V0
vi.mocked(useUserConversation).mockReturnValue({
data: makeConversation("V0"),
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
const { result, rerender } = renderHook(
() => useConversationHistory("conv-version-change"),
{ wrapper: localWrapper },
);
await waitFor(() => {
expect(result.current.data).toBeDefined();
});
expect(v0Spy).toHaveBeenCalledTimes(1);
// Switch to V1 — new version means new cache key, must refetch
vi.mocked(useUserConversation).mockReturnValue({
data: makeConversation("V1"),
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
rerender();
await waitFor(() => {
expect(v1Spy).toHaveBeenCalledTimes(1);
});
});
it("treats cached history as never stale (staleTime is Infinity)", async () => {
const v1Spy = vi.spyOn(EventService, "searchEventsV1");
v1Spy.mockResolvedValue([makeEvent()]);
vi.mocked(useUserConversation).mockReturnValue({
data: makeConversation("V1"),
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
const { result } = renderHook(
() => useConversationHistory("conv-stale-check"),
{ wrapper: localWrapper },
);
await waitFor(() => {
expect(result.current.data).toBeDefined();
});
// Check the query's staleTime option in the cache
const queries = localQueryClient.getQueryCache().findAll({
queryKey: ["conversation-history", "conv-stale-check"],
});
expect(queries).toHaveLength(1);
expect((queries[0].options as Record<string, unknown>).staleTime).toBe(
Infinity,
);
});
it("has gcTime of at least 30 minutes for navigation resilience", async () => {
const v1Spy = vi.spyOn(EventService, "searchEventsV1");
v1Spy.mockResolvedValue([makeEvent()]);
vi.mocked(useUserConversation).mockReturnValue({
data: makeConversation("V1"),
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
const { result } = renderHook(
() => useConversationHistory("conv-gc-check"),
{ wrapper: localWrapper },
);
await waitFor(() => {
expect(result.current.data).toBeDefined();
});
const queries = localQueryClient.getQueryCache().findAll({
queryKey: ["conversation-history", "conv-gc-check"],
});
expect(queries).toHaveLength(1);
expect(queries[0].options.gcTime).toBeGreaterThanOrEqual(30 * 60 * 1000);
});
});

View File

@@ -0,0 +1,180 @@
import { describe, expect, it, beforeEach, afterEach, vi } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { useBreakpoint } from "#/hooks/use-breakpoint";
// Helper to set window.innerWidth and dispatch resize event
function setWindowWidth(width: number) {
Object.defineProperty(window, "innerWidth", {
writable: true,
configurable: true,
value: width,
});
window.dispatchEvent(new Event("resize"));
}
describe("useBreakpoint", () => {
const originalInnerWidth = window.innerWidth;
beforeEach(() => {
// Start at a known desktop width
Object.defineProperty(window, "innerWidth", {
writable: true,
configurable: true,
value: 1200,
});
});
afterEach(() => {
Object.defineProperty(window, "innerWidth", {
writable: true,
configurable: true,
value: originalInnerWidth,
});
});
it("returns false (not mobile) when window width is above the breakpoint", () => {
Object.defineProperty(window, "innerWidth", { value: 1200 });
const { result } = renderHook(() => useBreakpoint());
expect(result.current).toBe(false);
});
it("returns true (mobile) when window width is at the breakpoint (1024)", () => {
Object.defineProperty(window, "innerWidth", { value: 1024 });
const { result } = renderHook(() => useBreakpoint());
expect(result.current).toBe(true);
});
it("returns true (mobile) when window width is below the breakpoint", () => {
Object.defineProperty(window, "innerWidth", { value: 800 });
const { result } = renderHook(() => useBreakpoint());
expect(result.current).toBe(true);
});
it("updates from false to true when window resizes below the breakpoint", () => {
Object.defineProperty(window, "innerWidth", { value: 1200 });
const { result } = renderHook(() => useBreakpoint());
expect(result.current).toBe(false);
act(() => {
setWindowWidth(800);
});
expect(result.current).toBe(true);
});
it("updates from true to false when window resizes above the breakpoint", () => {
Object.defineProperty(window, "innerWidth", { value: 800 });
const { result } = renderHook(() => useBreakpoint());
expect(result.current).toBe(true);
act(() => {
setWindowWidth(1200);
});
expect(result.current).toBe(false);
});
it("does NOT trigger re-render when width changes within the desktop range", () => {
Object.defineProperty(window, "innerWidth", { value: 1200 });
const renderCount = vi.fn();
const { result } = renderHook(() => {
renderCount();
return useBreakpoint();
});
expect(result.current).toBe(false);
const initialRenderCount = renderCount.mock.calls.length;
// Resize within desktop range (still above 1024) — should NOT re-render
act(() => {
setWindowWidth(1300);
});
act(() => {
setWindowWidth(1100);
});
act(() => {
setWindowWidth(1025);
});
expect(result.current).toBe(false);
// No additional renders beyond the initial render
expect(renderCount.mock.calls.length).toBe(initialRenderCount);
});
it("does NOT trigger re-render when width changes within the mobile range", () => {
Object.defineProperty(window, "innerWidth", { value: 800 });
const renderCount = vi.fn();
const { result } = renderHook(() => {
renderCount();
return useBreakpoint();
});
expect(result.current).toBe(true);
const initialRenderCount = renderCount.mock.calls.length;
// Resize within mobile range (still at or below 1024) — should NOT re-render
act(() => {
setWindowWidth(600);
});
act(() => {
setWindowWidth(1024);
});
act(() => {
setWindowWidth(900);
});
expect(result.current).toBe(true);
expect(renderCount.mock.calls.length).toBe(initialRenderCount);
});
it("handles rapid resize across the breakpoint without issues", () => {
Object.defineProperty(window, "innerWidth", { value: 1200 });
const { result } = renderHook(() => useBreakpoint());
expect(result.current).toBe(false);
// Rapid toggles across the breakpoint
act(() => {
setWindowWidth(800);
});
expect(result.current).toBe(true);
act(() => {
setWindowWidth(1200);
});
expect(result.current).toBe(false);
act(() => {
setWindowWidth(1024);
});
expect(result.current).toBe(true);
act(() => {
setWindowWidth(1025);
});
expect(result.current).toBe(false);
});
it("cleans up the resize event listener on unmount", () => {
const removeEventListenerSpy = vi.spyOn(window, "removeEventListener");
const { unmount } = renderHook(() => useBreakpoint());
unmount();
expect(removeEventListenerSpy).toHaveBeenCalledWith(
"resize",
expect.any(Function),
);
removeEventListenerSpy.mockRestore();
});
it("accepts a custom breakpoint value", () => {
Object.defineProperty(window, "innerWidth", { value: 768 });
const { result } = renderHook(() => useBreakpoint(768));
expect(result.current).toBe(true);
act(() => {
setWindowWidth(769);
});
expect(result.current).toBe(false);
});
});

View File

@@ -0,0 +1,383 @@
import { describe, expect, it, beforeEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { useFilteredEvents } from "#/hooks/use-filtered-events";
import { useEventStore } from "#/stores/use-event-store";
import type { OpenHandsAction } from "#/types/core/actions";
import type { ActionEvent, MessageEvent } from "#/types/v1/core";
import { SecurityRisk } from "#/types/v1/core";
// --- V0 event factories ---
function createV0UserMessage(id: number): OpenHandsAction {
return {
id,
source: "user",
action: "message",
args: { content: `User message ${id}`, image_urls: [], file_urls: [] },
message: `User message ${id}`,
timestamp: `2025-07-01T00:00:0${id}Z`,
};
}
function createV0AgentMessage(id: number): OpenHandsAction {
return {
id,
source: "agent",
action: "message",
args: {
thought: `Agent thought ${id}`,
image_urls: null,
file_urls: [],
wait_for_response: true,
},
message: `Agent response ${id}`,
timestamp: `2025-07-01T00:00:0${id}Z`,
};
}
function createV0SystemEvent(id: number): OpenHandsAction {
return {
id,
source: "environment",
action: "system",
args: {
content: "source .openhands/setup.sh",
tools: null,
openhands_version: null,
agent_class: null,
},
message: "Running setup script",
timestamp: `2025-07-01T00:00:0${id}Z`,
};
}
// --- V1 event factories ---
function createV1UserMessage(id: string): MessageEvent {
return {
id,
timestamp: "2025-07-01T00:00:01Z",
source: "user",
llm_message: {
role: "user",
content: [{ type: "text", text: `User message ${id}` }],
},
activated_microagents: [],
extended_content: [],
};
}
function createV1AgentAction(id: string): ActionEvent {
return {
id,
timestamp: "2025-07-01T00:00:02Z",
source: "agent",
thought: [{ type: "text", text: "Agent thought" }],
thinking_blocks: [],
action: {
kind: "ExecuteBashAction",
command: "echo test",
is_input: false,
timeout: null,
reset: false,
},
tool_name: "execute_bash",
tool_call_id: "call-1",
tool_call: {
id: "call-1",
type: "function",
function: { name: "execute_bash", arguments: '{"command": "echo test"}' },
},
llm_response_id: "response-1",
security_risk: SecurityRisk.UNKNOWN,
};
}
beforeEach(() => {
// Reset the event store before each test
useEventStore.setState({
events: [],
eventIds: new Set(),
uiEvents: [],
});
});
describe("useFilteredEvents", () => {
describe("referential stability", () => {
it("returns the same v0Events reference when storeEvents has not changed", () => {
const v0Event = createV0UserMessage(1);
useEventStore.setState({
events: [v0Event],
eventIds: new Set([1]),
uiEvents: [v0Event],
});
const { result, rerender } = renderHook(() => useFilteredEvents());
const firstV0Events = result.current.v0Events;
// Rerender without changing the store
rerender();
expect(result.current.v0Events).toBe(firstV0Events);
});
it("returns the same v1UiEvents reference when uiEvents has not changed", () => {
const v1Event = createV1UserMessage("msg-1");
useEventStore.setState({
events: [v1Event],
eventIds: new Set(["msg-1"]),
uiEvents: [v1Event],
});
const { result, rerender } = renderHook(() => useFilteredEvents());
const firstV1UiEvents = result.current.v1UiEvents;
rerender();
expect(result.current.v1UiEvents).toBe(firstV1UiEvents);
});
it("returns the same v1FullEvents reference when storeEvents has not changed", () => {
const v1Event = createV1UserMessage("msg-1");
useEventStore.setState({
events: [v1Event],
eventIds: new Set(["msg-1"]),
uiEvents: [v1Event],
});
const { result, rerender } = renderHook(() => useFilteredEvents());
const firstV1FullEvents = result.current.v1FullEvents;
rerender();
expect(result.current.v1FullEvents).toBe(firstV1FullEvents);
});
it("returns a new v0Events reference when storeEvents changes", () => {
const v0Event1 = createV0UserMessage(1);
useEventStore.setState({
events: [v0Event1],
eventIds: new Set([1]),
uiEvents: [v0Event1],
});
const { result } = renderHook(() => useFilteredEvents());
const firstV0Events = result.current.v0Events;
// Add a new event to the store (new array reference)
const v0Event2 = createV0AgentMessage(2);
act(() => {
useEventStore.setState({
events: [v0Event1, v0Event2],
eventIds: new Set([1, 2]),
uiEvents: [v0Event1, v0Event2],
});
});
expect(result.current.v0Events).not.toBe(firstV0Events);
expect(result.current.v0Events).toHaveLength(2);
});
});
describe("V0 event filtering", () => {
it("filters V0 events through isV0Event, isActionOrObservation, and shouldRenderEvent", () => {
const userMsg = createV0UserMessage(1);
const agentMsg = createV0AgentMessage(2);
useEventStore.setState({
events: [userMsg, agentMsg],
eventIds: new Set([1, 2]),
uiEvents: [userMsg, agentMsg],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.v0Events).toHaveLength(2);
expect(result.current.v0Events).toContainEqual(userMsg);
expect(result.current.v0Events).toContainEqual(agentMsg);
});
it("excludes V0 system events from v0Events", () => {
const userMsg = createV0UserMessage(1);
const systemEvent = createV0SystemEvent(2);
useEventStore.setState({
events: [userMsg, systemEvent],
eventIds: new Set([1, 2]),
uiEvents: [userMsg, systemEvent],
});
const { result } = renderHook(() => useFilteredEvents());
// System events are filtered out by shouldRenderEvent
expect(result.current.v0Events).toHaveLength(1);
expect(result.current.v0Events[0]).toEqual(userMsg);
});
it("does not include V1 events in v0Events", () => {
const v0Event = createV0UserMessage(1);
const v1Event = createV1UserMessage("msg-1");
useEventStore.setState({
events: [v0Event, v1Event],
eventIds: new Set([1, "msg-1"]),
uiEvents: [v0Event, v1Event],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.v0Events).toHaveLength(1);
expect(result.current.v0Events[0]).toEqual(v0Event);
});
});
describe("V1 event filtering", () => {
it("filters V1 events into v1FullEvents", () => {
const v1Event = createV1UserMessage("msg-1");
useEventStore.setState({
events: [v1Event],
eventIds: new Set(["msg-1"]),
uiEvents: [v1Event],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.v1FullEvents).toHaveLength(1);
expect(result.current.v1FullEvents[0]).toEqual(v1Event);
});
it("does not include V0 events in v1FullEvents", () => {
const v0Event = createV0UserMessage(1);
const v1Event = createV1UserMessage("msg-1");
useEventStore.setState({
events: [v0Event, v1Event],
eventIds: new Set([1, "msg-1"]),
uiEvents: [v0Event, v1Event],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.v1FullEvents).toHaveLength(1);
expect(result.current.v1FullEvents[0]).toEqual(v1Event);
});
});
describe("totalEvents", () => {
it("returns V0 event count when V0 events exist", () => {
const v0Event1 = createV0UserMessage(1);
const v0Event2 = createV0AgentMessage(2);
useEventStore.setState({
events: [v0Event1, v0Event2],
eventIds: new Set([1, 2]),
uiEvents: [v0Event1, v0Event2],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.totalEvents).toBe(2);
});
it("returns 0 when no events exist", () => {
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.totalEvents).toBe(0);
});
});
describe("hasSubstantiveAgentActions", () => {
it("returns false when no events exist", () => {
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.hasSubstantiveAgentActions).toBe(false);
});
it("returns false when only user events exist (V0)", () => {
const userMsg = createV0UserMessage(1);
useEventStore.setState({
events: [userMsg],
eventIds: new Set([1]),
uiEvents: [userMsg],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.hasSubstantiveAgentActions).toBe(false);
});
it("returns true when V0 agent message actions exist", () => {
const agentMsg = createV0AgentMessage(1);
useEventStore.setState({
events: [agentMsg],
eventIds: new Set([1]),
uiEvents: [agentMsg],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.hasSubstantiveAgentActions).toBe(true);
});
it("returns true when V1 agent action events exist", () => {
const agentAction = createV1AgentAction("action-1");
useEventStore.setState({
events: [agentAction],
eventIds: new Set(["action-1"]),
uiEvents: [agentAction],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.hasSubstantiveAgentActions).toBe(true);
});
});
describe("userEventsExist", () => {
it("returns false when no events exist", () => {
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.userEventsExist).toBe(false);
});
it("returns true when V0 user events exist", () => {
const userMsg = createV0UserMessage(1);
useEventStore.setState({
events: [userMsg],
eventIds: new Set([1]),
uiEvents: [userMsg],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.v0UserEventsExist).toBe(true);
expect(result.current.userEventsExist).toBe(true);
});
it("returns true when V1 user events exist", () => {
const userMsg = createV1UserMessage("msg-1");
useEventStore.setState({
events: [userMsg],
eventIds: new Set(["msg-1"]),
uiEvents: [userMsg],
});
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.v1UserEventsExist).toBe(true);
expect(result.current.userEventsExist).toBe(true);
});
});
describe("empty store", () => {
it("returns empty arrays and false flags for empty store", () => {
const { result } = renderHook(() => useFilteredEvents());
expect(result.current.v0Events).toEqual([]);
expect(result.current.v1UiEvents).toEqual([]);
expect(result.current.v1FullEvents).toEqual([]);
expect(result.current.totalEvents).toBe(0);
expect(result.current.hasSubstantiveAgentActions).toBe(false);
expect(result.current.v0UserEventsExist).toBe(false);
expect(result.current.v1UserEventsExist).toBe(false);
expect(result.current.userEventsExist).toBe(false);
});
});
});

View File

@@ -41,8 +41,7 @@ describe("useHandleBuildPlanClick", () => {
(createChatMessage as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
action: "message",
args: {
content:
"Execute the plan based on the workspace/project/PLAN.md file.",
content: "Execute the plan based on the .agents_tmp/PLAN.md file.",
image_urls: [],
file_urls: [],
timestamp: expect.any(String),
@@ -78,7 +77,7 @@ describe("useHandleBuildPlanClick", () => {
// Arrange
const { result } = renderHook(() => useHandleBuildPlanClick());
const expectedPrompt =
"Execute the plan based on the workspace/project/PLAN.md file.";
"Execute the plan based on the .agents_tmp/PLAN.md file.";
// Act
act(() => {
@@ -109,7 +108,7 @@ describe("useHandleBuildPlanClick", () => {
useOptimisticUserMessageStore.setState({ optimisticUserMessage: null });
const { result } = renderHook(() => useHandleBuildPlanClick());
const expectedPrompt =
"Execute the plan based on the workspace/project/PLAN.md file.";
"Execute the plan based on the .agents_tmp/PLAN.md file.";
// Act
act(() => {
@@ -155,7 +154,7 @@ describe("useHandleBuildPlanClick", () => {
expect(useConversationStore.getState().conversationMode).toBe("code");
expect(mockSend).toHaveBeenCalledTimes(1);
expect(useOptimisticUserMessageStore.getState().optimisticUserMessage).toBe(
"Execute the plan based on the workspace/project/PLAN.md file.",
"Execute the plan based on the .agents_tmp/PLAN.md file.",
);
});

View File

@@ -0,0 +1,170 @@
import { act, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
// Mock setSearchParams function
const mockSetSearchParams = vi.fn();
// Default mock searchParams
let mockSearchParamsData: Record<string, string> = {};
// Mock react-router
vi.mock("react-router", () => ({
useSearchParams: () => [
{
get: (key: string) => mockSearchParamsData[key] || null,
has: (key: string) => key in mockSearchParamsData,
},
mockSetSearchParams,
],
}));
// Import after mocking
import { useInvitation } from "#/hooks/use-invitation";
describe("useInvitation", () => {
beforeEach(() => {
// Clear localStorage before each test
localStorage.clear();
// Reset mock data
mockSearchParamsData = {};
mockSetSearchParams.mockClear();
});
afterEach(() => {
vi.clearAllMocks();
});
describe("initialization", () => {
it("should initialize with null token when localStorage is empty", () => {
// Arrange - localStorage is empty (cleared in beforeEach)
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
it("should initialize with token from localStorage if present", () => {
// Arrange
const storedToken = "inv-stored-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBe(storedToken);
expect(result.current.hasInvitation).toBe(true);
});
});
describe("URL token capture", () => {
it("should capture invitation_token from URL and store in localStorage", () => {
// Arrange
const urlToken = "inv-url-token-67890";
mockSearchParamsData = { invitation_token: urlToken };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBe(urlToken);
expect(mockSetSearchParams).toHaveBeenCalled();
});
});
describe("completion cleanup", () => {
it("should clear localStorage when email_mismatch param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { email_mismatch: "true" };
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(mockSetSearchParams).toHaveBeenCalled();
});
it("should clear localStorage when invitation_success param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { invitation_success: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
it("should clear localStorage when invitation_expired param is present", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token");
mockSearchParamsData = { invitation_expired: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
});
describe("buildOAuthStateData", () => {
it("should include invitation_token in OAuth state when token is present", () => {
// Arrange
const token = "inv-oauth-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, token);
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBe(token);
expect(stateData.redirect_url).toBe("/dashboard");
});
it("should not include invitation_token when no token is present", () => {
// Arrange - no token in localStorage
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBeUndefined();
expect(stateData.redirect_url).toBe("/dashboard");
});
});
describe("clearInvitation", () => {
it("should remove token from localStorage when called", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token-to-clear");
const { result } = renderHook(() => useInvitation());
// Act
act(() => {
result.current.clearInvitation();
});
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
});
});

View File

@@ -0,0 +1,126 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { useScrollToBottom } from "#/hooks/use-scroll-to-bottom";
import type { RefObject } from "react";
/**
* Creates a mock scroll element with a trackable scrollTop setter.
*
* state.scrollTop can be set directly (bypassing the spy) to position
* the element for onChatBodyScroll calls without polluting the spy.
*/
function createMockScrollElement(initialScrollHeight = 1000) {
const state = {
scrollTop: 0,
scrollHeight: initialScrollHeight,
clientHeight: 500,
};
const scrollTopSetter = vi.fn((value: number) => {
state.scrollTop = value;
});
const element = {
get scrollTop() {
return state.scrollTop;
},
set scrollTop(value: number) {
scrollTopSetter(value);
},
get scrollHeight() {
return state.scrollHeight;
},
get clientHeight() {
return state.clientHeight;
},
} as unknown as HTMLDivElement;
return { element, scrollTopSetter, state };
}
describe("useScrollToBottom", () => {
let mock: ReturnType<typeof createMockScrollElement>;
let ref: RefObject<HTMLDivElement>;
beforeEach(() => {
mock = createMockScrollElement(1000);
ref = { current: mock.element } as RefObject<HTMLDivElement>;
});
describe("no automatic scrolling on render", () => {
it("does NOT scroll on initial render", () => {
renderHook(() => useScrollToBottom(ref));
// No useLayoutEffect means no automatic scroll-to-bottom
expect(mock.scrollTopSetter).not.toHaveBeenCalled();
});
it("does NOT scroll when re-rendered (e.g., during resize)", () => {
const { rerender } = renderHook(() => useScrollToBottom(ref));
mock.state.scrollHeight = 1500;
rerender();
expect(mock.scrollTopSetter).not.toHaveBeenCalled();
});
});
describe("scroll position tracking", () => {
it("tracks hitBottom correctly via onChatBodyScroll", () => {
const { result } = renderHook(() => useScrollToBottom(ref));
// Position at bottom: scrollTop(480) + clientHeight(500) = 980 >= 1000 - 20
mock.state.scrollTop = 480;
act(() => {
result.current.onChatBodyScroll(mock.element);
});
expect(result.current.hitBottom).toBe(true);
// Position not at bottom: scrollTop(200) + clientHeight(500) = 700 < 980
mock.state.scrollTop = 200;
act(() => {
result.current.onChatBodyScroll(mock.element);
});
expect(result.current.hitBottom).toBe(false);
});
it("disables autoScroll when user scrolls up", () => {
const { result } = renderHook(() => useScrollToBottom(ref));
// First scroll to establish prevScrollTopRef
mock.state.scrollTop = 400;
act(() => {
result.current.onChatBodyScroll(mock.element);
});
// Scroll up (lower scrollTop than previous)
mock.state.scrollTop = 200;
act(() => {
result.current.onChatBodyScroll(mock.element);
});
expect(result.current.autoScroll).toBe(false);
});
it("re-enables autoScroll when user reaches bottom", () => {
const { result } = renderHook(() => useScrollToBottom(ref));
// Scroll up to disable autoScroll
mock.state.scrollTop = 400;
act(() => {
result.current.onChatBodyScroll(mock.element);
});
mock.state.scrollTop = 200;
act(() => {
result.current.onChatBodyScroll(mock.element);
});
expect(result.current.autoScroll).toBe(false);
// Scroll to bottom
mock.state.scrollTop = 500; // 500 + 500 = 1000 >= 980
act(() => {
result.current.onChatBodyScroll(mock.element);
});
expect(result.current.autoScroll).toBe(true);
});
});
});

View File

@@ -12,8 +12,8 @@ const wrapper = ({ children }: { children: React.ReactNode }) => (
const mockConfig = (appMode: "saas" | "oss", hideLlmSettings = false) => {
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
APP_MODE: appMode,
FEATURE_FLAGS: { HIDE_LLM_SETTINGS: hideLlmSettings },
app_mode: appMode,
feature_flags: { hide_llm_settings: hideLlmSettings },
} as Awaited<ReturnType<typeof OptionService.getConfig>>);
};
@@ -22,7 +22,7 @@ describe("useSettingsNavItems", () => {
queryClient.clear();
});
it("should return SAAS_NAV_ITEMS when APP_MODE is 'saas'", async () => {
it("should return SAAS_NAV_ITEMS when app_mode is 'saas'", async () => {
mockConfig("saas");
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
@@ -31,7 +31,7 @@ describe("useSettingsNavItems", () => {
});
});
it("should return OSS_NAV_ITEMS when APP_MODE is 'oss'", async () => {
it("should return OSS_NAV_ITEMS when app_mode is 'oss'", async () => {
mockConfig("oss");
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
@@ -40,7 +40,7 @@ describe("useSettingsNavItems", () => {
});
});
it("should filter out '/settings' item when HIDE_LLM_SETTINGS feature flag is enabled", async () => {
it("should filter out '/settings' item when hide_llm_settings feature flag is enabled", async () => {
mockConfig("saas", true);
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });

Some files were not shown because too many files have changed in this diff Show More